# In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import YourDatasetClass
from U_NET import UNet
from model import Encoder

# Preprocessing applied to every sample: collapse to a single grayscale
# channel, resize to 572x572, then convert the image to a tensor.
# NOTE(review): 572x572 is presumably the input size UNet expects — confirm
# against U_NET.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(572, 572)),
        transforms.ToTensor(),
    ]
)

# One sample per batch, reshuffled on every pass over the data.
dataset = YourDatasetClass(transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

# Single-channel U-Net whose output is later fed to the project's Encoder.
# Construction order is kept (U-Net first) so that seeded weight
# initialisation stays reproducible.
unet_model = UNet(in_channels=1, out_channels=1)
encoder_model = Encoder()

# Loss function, plus one Adam optimizer that updates the parameters of both
# models together with a shared learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    [*unet_model.parameters(), *encoder_model.parameters()],
    lr=1e-3,
)

# Training loop: optimise the U-Net and the Encoder jointly for a fixed
# number of epochs, printing the mean per-batch loss after each epoch.
num_epochs = 10

# Make training mode explicit so that mode-dependent layers (dropout,
# batch-norm, if the models contain any) behave as intended even if a model
# was switched to eval mode earlier in the session.
unet_model.train()
encoder_model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    batches_seen = 0
    for images, labels in dataloader:
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass through U-Net
        unet_outputs = unet_model(images)

        # Forward pass through Encoder: collapse everything except the batch
        # dimension first. reshape() yields the same result as view() but
        # also works when the U-Net output is non-contiguous, where view()
        # raises a RuntimeError.
        flattened_outputs = unet_outputs.reshape(unet_outputs.size(0), -1)
        encoder_outputs = encoder_model(flattened_outputs)

        loss = criterion(encoder_outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_seen += 1

    # Average over the batches actually processed; if the dataloader yielded
    # nothing, report NaN instead of raising ZeroDivisionError.
    mean_loss = running_loss / batches_seen if batches_seen else float("nan")
    print(f"Epoch {epoch+1}, Loss: {mean_loss}")

# Persist the learned weights of both models (state_dict only, not the full
# module objects) to files in the current working directory.
for trained_model, checkpoint_path in (
    (unet_model, 'unet_model.pth'),
    (encoder_model, 'encoder_model.pth'),
):
    torch.save(trained_model.state_dict(), checkpoint_path)
