In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import os
from PIL import Image
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt

JPEG Images are the base images and Segmentation Class has masked images

In [2]:
#check with neil for pylance errors

class UNet(nn.Module):
    """
    UNet architecture implementation from the original paper:
    "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    input_channels:
    num_classes: number of classes in pixels
    """

    def __init__(self, input_channels, num_classes=2):
        super(UNet, self).__init__()

        # Encoder (Contracting Path)
        # Each encoder block: Conv -> ReLU -> Conv -> ReLU
        #For example 572
        self.enc1_conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)

        self.enc1_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        self.enc2_conv1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.enc2_conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.enc3_conv1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.enc3_conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.enc4_conv1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.enc4_conv2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        # Bottleneck
        self.bottleneck_conv1 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.bottleneck_conv2 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        # Decoder (Expansive Path)
        # Each decoder block: UpConv -> Concat -> Conv -> ReLU -> Conv -> ReLU
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4_conv1 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.dec4_conv2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3_conv1 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.dec3_conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2_conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.dec2_conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1_conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.dec1_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Final output layer
        self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        # Max pooling layer (shared across encoder blocks)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # Encoder Block 1
        enc1 = F.relu(self.enc1_conv1(x))
        enc1 = F.relu(self.enc1_conv2(enc1))
        enc1_pool = self.pool(enc1)

        # Encoder Block 2
        enc2 = F.relu(self.enc2_conv1(enc1_pool))
        enc2 = F.relu(self.enc2_conv2(enc2))
        enc2_pool = self.pool(enc2)
        
        # Encoder Block 3
        enc3 = F.relu(self.enc3_conv1(enc2_pool))
        enc3 = F.relu(self.enc3_conv2(enc3))
        enc3_pool = self.pool(enc3)

        # Encoder Block 4
        enc4 = F.relu(self.enc4_conv1(enc3_pool))
        enc4 = F.relu(self.enc4_conv2(enc4))
        enc4_pool = self.pool(enc4)


        # Bottleneck
        bottleneck = F.relu(self.bottleneck_conv1(enc4_pool))
        bottleneck = F.relu(self.bottleneck_conv2(bottleneck))


        # Decoder Block 4
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat([dec4, enc4], dim=1)  # Skip connection
        dec4 = F.relu(self.dec4_conv1(dec4))
        dec4 = F.relu(self.dec4_conv2(dec4))

        # Decoder Block 3
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat([dec3, enc3], dim=1)  # Skip connection
        dec3 = F.relu(self.dec3_conv1(dec3))
        dec3 = F.relu(self.dec3_conv2(dec3))

        # Decoder Block 2
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat([dec2, enc2], dim=1)  # Skip connection
        dec2 = F.relu(self.dec2_conv1(dec2))
        dec2 = F.relu(self.dec2_conv2(dec2))

        # Decoder Block 1
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat([dec1, enc1], dim=1)  # Skip connection
        dec1 = F.relu(self.dec1_conv1(dec1))
        dec1 = F.relu(self.dec1_conv2(dec1))
        out = self.out_conv(dec1)

        return out

In [None]:
#need to convert image data into tensors?
device = 'cuda'
class PascalVocDataset(Dataset):
  def __init__(self, root_dir, annotation_dir, image_transform=None, mask_transform=None):
    self.root_dir = root_dir
    self.annotation_dir = annotation_dir
    self.image_transform = image_transform
    self.mask_transform =  mask_transform

    img_files = {os.path.splitext(f)[0]: f for f in os.listdir(root_dir) if not f.startswith('.')}
    mask_files = {os.path.splitext(f)[0]: f for f in os.listdir(annotation_dir) if not f.startswith('.')}

    self.filenames = sorted(list(set(img_files.keys()) & set(mask_files.keys())))
    self.img_map = img_files
    self.mask_map = mask_files

    if len(self.filenames) == 0:
        print(f"Error: 0 pairs found")
        print(f"Sample images found: {list(img_files.keys())[:5]}")
        print(f"Sample masks found: {list(mask_files.keys())[:5]}")
    else:
        print(f"Success: Found {len(self.filenames)} matching image-mask pairs.")

    


  def __len__(self):
    return len(self.filenames)

  def __getitem__(self,idx):
    fname = self.filenames[idx]
   
       
                
    img_path = os.path.join(self.root_dir, self.img_map[fname])
    mask_path = os.path.join(self.annotation_dir, self.mask_map[fname])

    

    image = Image.open(img_path).convert("RGB")
    mask = Image.open(mask_path).convert("L")

    if self.image_transform:
        image = self.image_transform(image)
    if self.mask_transform:
        mask = self.mask_transform(mask)

    return image, mask

   
    


root_dir = '/scratch/st-sielmann-1/agrobot/Image_Segmentation/pascal-voc-2012-dataset/VOC2012_train_val/VOC2012_train_val/JPEGImages'
annotation_dir = '/scratch/st-sielmann-1/agrobot/Image_Segmentation/pascal-voc-2012-dataset/VOC2012_train_val/VOC2012_train_val/SegmentationObject'

transform_image = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
])
transform_annotation = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: ((x * 255) > 0).long().squeeze(0)) #CHANGE #1: Added a squeeze to mask ans binary
])


customDataset = PascalVocDataset(root_dir, annotation_dir, transform_image, transform_annotation)
image_loader = DataLoader(customDataset, batch_size = 16, num_workers=8, shuffle=True,  pin_memory=True, persistent_workers=True)
# CHANGE #2: added workers, shuffling, and other params for quicker training

model = UNet(input_channels=3, num_classes=2).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001 )
scheduler = ReduceLROnPlateau(optimizer, "min", patience = 2, factor=0.5)
print(len(image_loader))

Success: Found 2913 matching image-mask pairs.
183


In [None]:
# Training loop
epoch_num = 500
best_loss = float('inf')

for epoch in range(epoch_num):
    model.train()
    running_loss = 0.0
    num_correct = 0
    total_pixels = 0
    
    for i, (image, mask) in enumerate(image_loader):
        image = image.to(device, non_blocking=True)
        annotation = mask.to(device, non_blocking=True)

        #CHANGE 3: Any time annotations are present for more than one class I squeeze, consistenly of actual use case
        if annotation.dim() == 4:
            annotation = annotation.squeeze(1)
        annotation = annotation.long()
        
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, annotation)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        running_loss += loss.item()
        
        with torch.no_grad():
            predicted = torch.argmax(output, dim=1)
            num_correct += (predicted == annotation).sum().item()
            total_pixels += annotation.numel()
        

        if i == 5 and epoch % 10 == 0:
            with torch.no_grad():
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                
                img_show = (image[0].cpu() * 0.5 + 0.5).permute(1, 2, 0).numpy()
                axes[0].imshow(img_show)
                axes[0].set_title('Input Image')
                
                axes[1].imshow(annotation[0].cpu(), cmap='gray')
                axes[1].set_title('Ground Truth')
                #CHANGE 4: Added an actual model output to the plt, instead save it so training is not interrupted (apparently plt.show() interrupts execution)
                axes[2].imshow(predicted[0].cpu(), cmap='gray')
                axes[2].set_title(f'Prediction (Epoch {epoch+1})')
                
                plt.tight_layout()
                plt.savefig(f'prediction_epoch_{epoch+1}.png', dpi=100)
                plt.close()
    
    accuracy = 100 * (num_correct / total_pixels)
    epoch_loss = running_loss / len(image_loader)
    scheduler.step(epoch_loss)
    
    print(f'Epoch {epoch+1:2d}/{epoch_num} | Loss: {epoch_loss:.4f} | Acc: {accuracy:.2f}%')
    #CHANGE 5: saves best weights automatically in case SOCKEYE TIME RUNS OUT OR SOME BUG LIKE KERNEL RESTARTING
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_unet.pth')

print(f'\nTraining complete! Best loss: {best_loss:.4f}')

Epoch  1/50 | Loss: 0.5885 | Acc: 68.73%
Epoch  2/50 | Loss: 0.5887 | Acc: 69.33%
Epoch  3/50 | Loss: 0.5269 | Acc: 71.94%
Epoch  4/50 | Loss: 0.5014 | Acc: 74.73%
Epoch  5/50 | Loss: 0.4891 | Acc: 75.67%
Epoch  6/50 | Loss: 0.4824 | Acc: 76.20%
Epoch  7/50 | Loss: 0.4682 | Acc: 77.03%
Epoch  8/50 | Loss: 0.4562 | Acc: 77.83%
Epoch  9/50 | Loss: 0.4576 | Acc: 77.99%
Epoch 10/50 | Loss: 0.4498 | Acc: 78.20%
Epoch 11/50 | Loss: 0.4402 | Acc: 78.88%
Epoch 12/50 | Loss: 0.4344 | Acc: 79.19%
Epoch 13/50 | Loss: 0.4307 | Acc: 79.27%
Epoch 14/50 | Loss: 0.4279 | Acc: 79.62%
Epoch 15/50 | Loss: 0.4227 | Acc: 79.82%
Epoch 16/50 | Loss: 0.4153 | Acc: 80.16%
Epoch 17/50 | Loss: 0.4187 | Acc: 79.84%
Epoch 18/50 | Loss: 0.4071 | Acc: 80.59%
Epoch 19/50 | Loss: 0.4017 | Acc: 80.93%
