In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import os
from PIL import Image
import numpy as np

## Define segmentation dataclass

In [2]:
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted([f for f in os.listdir(image_dir) if not f.startswith('.')])
        self.mask_files = sorted([f for f in os.listdir(mask_dir) if not f.startswith('.')])

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        image = self.transform(image)
        mask = self.transform(mask)

        return image, mask

## Define the UNet model

In [3]:
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        def double_conv(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )

        self.dconv_down1 = double_conv(in_channels, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        # Modified maxpool with ceil_mode
        self.maxpool = nn.MaxPool2d(2, ceil_mode=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(64 + 128, 64)

        self.conv_last = nn.Conv2d(64, out_channels, 1)

    def crop_tensor(self, target_tensor, tensor_to_crop):
        _, _, H, W = tensor_to_crop.size()
        return target_tensor[:, :, :H, :W]

    def forward(self, x):
        # Encoder path
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        # Bottleneck
        x = self.dconv_down4(x)

        # Decoder path with cropping
        x = self.upsample(x)
        conv3 = self.crop_tensor(conv3, x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)

        x = self.upsample(x)
        conv2 = self.crop_tensor(conv2, x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = self.upsample(x)
        conv1 = self.crop_tensor(conv1, x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        out = self.conv_last(x)
        return out

## Define a collate function with padding

In [4]:
def collate_fn(batch):
    images, masks = zip(*batch)

    max_height = max(img.shape[-2] for img in images)
    max_width = max(img.shape[-1] for img in images)

    def pad_to_size(tensor, target_height, target_width, mode='reflect'):
        if tensor.ndim == 3:
            tensor = tensor.unsqueeze(0)  # Add batch dimension
        _, _, h, w = tensor.shape
        pad_h = target_height - h
        pad_w = target_width - w
        padded_tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode=mode)
        return padded_tensor.squeeze(0)  # Remove batch dimension

    padded_images = [pad_to_size(img, max_height, max_width, mode='reflect') for img in images]
    padded_masks = [pad_to_size(msk.unsqueeze(0), max_height, max_width, mode='constant') for msk in masks]

    # Check if all padded images have the same shape
    assert all(img.shape == padded_images[0].shape for img in padded_images), "Padded images have inconsistent shapes"

    # Check if all padded masks have the same shape 
    assert all(msk.shape == padded_masks[0].shape for msk in padded_masks), "Padded masks have inconsistent shapes"

    # Print shapes for debugging
    print("Padded images shapes:", [img.shape for img in padded_images])
    print("Padded masks shapes:", [msk.shape for msk in padded_masks])

    return torch.stack(padded_images), torch.stack(padded_masks)


In [32]:
def collate_fn(batch):
    images, masks = zip(*batch)

    max_height = max(img.shape[-2] for img in images)
    max_width = max(img.shape[-1] for img in images)

    # Ensure the dimensions are divisible by 8
    target_height = ((max_height + 7) // 8) * 8
    target_width = ((max_width + 7) // 8) * 8

    def pad_to_size(tensor, target_height, target_width, mode='reflect'):
        if tensor.ndim == 3:
            tensor = tensor.unsqueeze(0)  # Add batch dimension
        _, _, h, w = tensor.shape
        pad_h = target_height - h
        pad_w = target_width - w
        padded_tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode=mode)
        return padded_tensor.squeeze(0)  # Remove batch dimension

    padded_images = [pad_to_size(img, target_height, target_width, mode='reflect') for img in images]
    padded_masks = [pad_to_size(msk.unsqueeze(0), target_height, target_width, mode='constant') for msk in masks]

    # Check if all padded images have the same shape
    assert all(img.shape == padded_images[0].shape for img in padded_images), "Padded images have inconsistent shapes"

    # Check if all padded masks have the same shape 
    assert all(msk.shape == padded_masks[0].shape for msk in padded_masks), "Padded masks have inconsistent shapes"

    # Print shapes for debugging
    print("Padded images shapes:", [img.shape for img in padded_images])
    print("Padded masks shapes:", [msk.shape for msk in padded_masks])

    return torch.stack(padded_images), torch.stack(padded_masks)

### 0 reduction update

In [22]:
def collate_fn(batch):
    images, masks = zip(*batch)

    max_height = max(img.shape[-2] for img in images)
    max_width = max(img.shape[-1] for img in images)

    # Ensure the dimensions are divisible by 8
    target_height = ((max_height + 7) // 8) * 8
    target_width = ((max_width + 7) // 8) * 8

    def pad_to_multiple(tensor, target_height, target_width):
        _, _, h, w = tensor.shape
        pad_h = max(target_height - h, 0)
        pad_w = max(target_width - w, 0)
        return F.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect')

    padded_images = [pad_to_multiple(img, target_height, target_width) for img in images]
    padded_masks = [pad_to_multiple(msk.unsqueeze(0), target_height, target_width) for msk in masks]

    # Check if all padded images have the same shape
    assert all(img.shape == padded_images[0].shape for img in padded_images), "Padded images have inconsistent shapes"

    # Check if all padded masks have the same shape 
    assert all(msk.shape == padded_masks[0].shape for msk in padded_masks), "Padded masks have inconsistent shapes"

    # Print shapes for debugging
    print("Padded images shapes:", [img.shape for img in padded_images])
    print("Padded masks shapes:", [msk.shape for msk in padded_masks])

    return torch.stack(padded_images), torch.stack(padded_masks)


## upload train/val sets

In [33]:
# Initialize datasets
train_dataset = SegmentationDataset(
    image_dir="/Users/aja294/Documents/Hemp_local/leaf_morphometrics/semantic_seg_template/data/train/images",
    mask_dir="/Users/aja294/Documents/Hemp_local/leaf_morphometrics/semantic_seg_template/data/train/masks"
)

In [34]:
val_dataset = SegmentationDataset(
    image_dir="/Users/aja294/Documents/Hemp_local/leaf_morphometrics/semantic_seg_template/data/val/images",
    mask_dir="/Users/aja294/Documents/Hemp_local/leaf_morphometrics/semantic_seg_template/data/val/masks"
)

## Data loader setup


### With padding

In [35]:
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

### With Metal/Cuda

In [36]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()
device

Using device: mps


device(type='mps')

In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss(reduction='none')
device

Using device: mps


device(type='mps')

### With CPU

In [38]:
device = torch.device('cpu')
print(f"Using device: {device}")
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()
device

Using device: cpu


device(type='cpu')

### Run Training

###  Updated model train

In [39]:
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Starting training epoch {epoch+1}")  # Debug: Ensure entering the training loop
    model.train()
    train_loss = 0
    for images, masks in train_loader:
        print(f"Training batch shapes - Images: {images.shape}, Masks: {masks.shape}")  # Debug: Check input shapes
        images = images.to(device)
        masks = masks.to(device)
        print(f"Training batch shapes (on device) - Images: {images.shape}, Masks: {masks.shape}")  # Debug: Check shapes after moving to device

        outputs = model(images)
        print("Model output shape:", outputs.shape)  # Debug: Check model output shape
        print("Target masks shape:", masks.shape)  # Debug: Check target masks shape

        # Temporarily remove loss computation and backpropagation for debugging
        # _, _, h, w = masks.shape
        # loss_mask = torch.zeros_like(outputs)
        # loss_mask[:, :, :h, :w] = 1
        # loss = F.binary_cross_entropy_with_logits(outputs, masks, reduction='none')
        # masked_loss = (loss * loss_mask).sum() / loss_mask.sum()
        # optimizer.zero_grad()
        # masked_loss.backward()
        # optimizer.step()
        # train_loss += masked_loss.item()

    # Validation
    print(f"Starting validation epoch {epoch+1}")  # Debug: Ensure entering the validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            print(f"Validation batch shapes - Images: {images.shape}, Masks: {masks.shape}")  # Debug: Check input shapes
            images = images.to(device)
            masks = masks.to(device)
            print(f"Validation batch shapes (on device) - Images: {images.shape}, Masks: {masks.shape}")  # Debug: Check shapes after moving to device

            outputs = model(images)
            print("Model output shape:", outputs.shape)  # Debug: Check model output shape
            print("Target masks shape:", masks.shape)  # Debug: Check target masks shape

            # Temporarily remove loss computation for debugging
            # _, _, h, w = masks.shape
            # loss_mask = torch.zeros_like(outputs)
            # loss_mask[:, :, :h, :w] = 1
            # loss = F.binary_cross_entropy_with_logits(outputs, masks, reduction='none')
            # masked_loss = (loss * loss_mask).sum() / loss_mask.sum()
            # val_loss += masked_loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")


Starting training epoch 1
Padded images shapes: [torch.Size([3, 2528, 2728]), torch.Size([3, 2528, 2728]), torch.Size([3, 2528, 2728]), torch.Size([3, 2528, 2728])]
Padded masks shapes: [torch.Size([1, 2528, 2728]), torch.Size([1, 2528, 2728]), torch.Size([1, 2528, 2728]), torch.Size([1, 2528, 2728])]
Training batch shapes - Images: torch.Size([4, 3, 2528, 2728]), Masks: torch.Size([4, 1, 2528, 2728])
Training batch shapes (on device) - Images: torch.Size([4, 3, 2528, 2728]), Masks: torch.Size([4, 1, 2528, 2728])


: 

### Built to only train on originals while padding degrades

In [28]:
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Starting training epoch {epoch+1}")  # Debug: Ensure entering the training loop
    model.train()
    train_loss = 0
    for images, masks in train_loader:
        print(f"Training batch shapes - Images: {images.shape}, Masks: {masks.shape}")  # Debug: Check input shapes
        images = images.to(device)
        masks = masks.to(device)
        print(f"Training batch shapes (on device) - Images: {images.shape}, Masks: {masks.shape}")  # Debug: Check shapes after moving to device

        outputs = model(images)
        print("Model output shape:", outputs.shape)  # Debug: Check model output shape
        print("Target masks shape:", masks.shape)  # Debug: Check target masks shape

        # Create mask for loss computation
        _, _, h, w = masks.shape
        loss_mask = torch.zeros_like(outputs)
        loss_mask[:, :, :h, :w] = 1

        loss = criterion(outputs, masks)
        masked_loss = (loss * loss_mask).sum() / loss_mask.sum()

        optimizer.zero_grad()
        masked_loss.backward()
        optimizer.step()
        train_loss += masked_loss.item()

    # Validation
    print(f"Starting validation epoch {epoch+1}")  # Debug: Ensure entering the validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            print(f"Validation batch shapes - Images: {images.shape}, Masks: {masks.shape}")  # Debug: Check input shapes
            images = images.to(device)
            masks = masks.to(device)
            print(f"Validation batch shapes (on device) - Images: {images.shape}, Masks: {masks.shape}")  # Debug: Check shapes after moving to device

            outputs = model(images)
            print("Model output shape:", outputs.shape)  # Debug: Check model output shape
            print("Target masks shape:", masks.shape)  # Debug: Check target masks shape

            # Create mask for loss computation  
            _, _, h, w = masks.shape
            loss_mask = torch.zeros_like(outputs)
            loss_mask[:, :, :h, :w] = 1

            loss = criterion(outputs, masks)
            masked_loss = (loss * loss_mask).sum() / loss_mask.sum()
            val_loss += masked_loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")


Starting training epoch 1


ValueError: not enough values to unpack (expected 4, got 3)

In [19]:
# Save model
torch.save(model.state_dict(), "unet_model.pth")