In [53]:
# Imports

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

In [54]:
# Define the UNet architecture
class UNet(nn.Module):
    """Small single-level encoder/decoder convolutional network.

    The contracting path applies two 3x3 convolutions (64 filters each, ReLU
    after each) and then halves the spatial resolution with a 2x2 max-pool.
    The expansive path applies two more 3x3 convolutions and restores the
    resolution with a stride-2 transposed convolution. No activation follows
    the last layer, so the output holds raw scores (logits).

    NOTE(review): despite the name, there are no skip connections between the
    contracting and expansive paths, and only one resolution level is used.

    Args:
        in_channels: Number of channels of the input tensor (default 3).
        out_channels: Number of channels of the output tensor (default 3).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()
        # Contracting path: conv -> ReLU -> conv -> ReLU -> 2x downsample
        self.contracting_path = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # ceil_mode keeps the trailing row/column of odd-sized inputs
            # instead of dropping it; even-sized inputs are unaffected.
            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        )
        # Expansive path: conv -> ReLU -> conv -> ReLU -> 2x upsample
        self.expansive_path = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2)
        )

    def forward(self, x):
        """Run the input through the contracting then the expansive path.

        Args:
            x: Input tensor whose last two dimensions are height and width
               and whose channel dimension has ``in_channels`` entries.

        Returns:
            Tensor with ``out_channels`` channels and the same height and
            width as ``x``.
        """
        height, width = x.shape[-2:]
        x = self.contracting_path(x)
        x = self.expansive_path(x)
        # For odd input sizes the upsampled map is one pixel too large
        # (pool rounds up, transposed conv doubles); crop back to the input
        # size so the output always lines up with a same-sized target.
        return x[..., :height, :width]


In [55]:
# Define custom dataset class
class CustomDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel lists of paths.

    Item ``idx`` is built from ``image_paths[idx]`` and ``mask_paths[idx]``.
    Both files are opened with PIL and converted to RGB.

    Args:
        image_paths: Sequence of paths to the input images.
        mask_paths: Sequence of paths to the masks, in the same order and of
            the same length as ``image_paths``.
        transform: Optional callable applied to the image and to the mask.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        # Fail early: a length mismatch would otherwise surface later as an
        # IndexError inside a DataLoader worker, or silently ignore masks.
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f'image_paths and mask_paths must have the same length, '
                f'got {len(image_paths)} and {len(mask_paths)}'
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair at position ``idx``."""
        # Open inside a context manager so the underlying file is always
        # closed; convert() returns a new, fully loaded image.
        with Image.open(self.image_paths[idx]) as image_file:
            image = image_file.convert('RGB')
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = mask_file.convert('RGB')
        # NOTE: the transform is invoked separately for the image and the
        # mask, so a transform with random behaviour is not guaranteed to
        # apply the same parameters to both.
        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)
        return image, mask


In [56]:
# Training hyperparameters: mini-batch size, optimizer learning rate, epoch count
batch_size, lr, num_epochs = 3, 1e-3, 10

In [57]:
# Build the dataset and the DataLoader that batches it
# Placeholder paths: point these at your own image and mask files
image_paths = ['data/images/image0.jpg']
mask_paths = ['data/masks/image0.jpg']
# Preprocessing used for both images and masks:
# resize to a fixed size (given as (height, width)), then convert to a tensor
transform = transforms.Compose([
    transforms.Resize(size=(1632, 1224)),
    transforms.ToTensor(),
])
dataset = CustomDataset(
    image_paths=image_paths,
    mask_paths=mask_paths,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)


In [58]:
# Set up the training objective, the network, and its optimizer
# The loss takes the network's raw outputs as logits
criterion = nn.BCEWithLogitsLoss()
model = UNet()
optimizer = optim.Adam(params=model.parameters(), lr=lr)


In [59]:
# Optimise the model for num_epochs passes over the dataloader
model.train()
for epoch in range(num_epochs):
    # Accumulates the batch losses, each weighted by its batch size
    running_loss = 0.0
    for images, masks in dataloader:
        # Forward pass and loss for the current batch
        outputs = model(images)
        loss = criterion(outputs, masks)
        # Clear old gradients, backpropagate, then update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight the batch loss by the number of samples in the batch
        running_loss += images.shape[0] * loss.item()
    # Normalise by the dataset size to report a per-sample figure
    epoch_loss = running_loss / len(dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')


Epoch [1/10], Loss: 0.6873
Epoch [2/10], Loss: 0.6832
Epoch [3/10], Loss: 0.6791
Epoch [4/10], Loss: 0.6718
Epoch [5/10], Loss: 0.6621
Epoch [6/10], Loss: 0.6491
Epoch [7/10], Loss: 0.6344
Epoch [8/10], Loss: 0.6173
Epoch [9/10], Loss: 0.5977
Epoch [10/10], Loss: 0.5775


In [60]:
# Persist the trained weights (the state_dict only) to disk
torch.save(obj=model.state_dict(), f='trained_unet_model.pth')