In [1]:
# Run identifier: later cells prefix the log file, checkpoint, CSV, prediction folder and score file names with it.
model_name = 'UNet_lr_dynamic_batch_2_standard_scale_[1.0, 1.1]_colour_0.25_blur_0.03_linearup_weightedloss_wd_1e-7'

In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm
import json  # To save and read mean and std values as JSON

# Directory containing your training images
image_directory = 'train/train_image'

# Transform to convert the images to a tensor
transform = transforms.ToTensor()

# Path to store the mean and std
stats_file_path = 'mean_std.json'

# Function to calculate mean and std
def calculate_mean_std():
    """Compute per-channel mean and std over every pixel of the training images.

    Each file in ``image_directory`` is read as RGB and converted with
    ``transform``.  The per-channel sum and sum of squares are accumulated
    over all pixels, so the statistics describe the whole dataset (images of
    different sizes contribute in proportion to their area) rather than an
    average of per-image statistics.  The result is written to
    ``stats_file_path`` as JSON.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(mean, std)``, each a float32
        tensor with one entry per RGB channel.

    Raises:
        ValueError: if ``image_directory`` yields no pixels (e.g. it is empty).
    """
    # float64 accumulators keep E[x^2] - E[x]^2 numerically stable over many pixels
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    num_pixels = 0

    # Iterate through each image in the directory
    for image_name in tqdm(os.listdir(image_directory)):
        image_path = os.path.join(image_directory, image_name)
        with Image.open(image_path) as img:
            # Ensure the image is in RGB format before converting to a (3, H, W) tensor
            image_tensor = transform(img.convert('RGB')).double()

        # Accumulate sum and sum of squares for mean and std calculation
        channel_sum += image_tensor.sum(dim=(1, 2))
        channel_sq_sum += (image_tensor ** 2).sum(dim=(1, 2))
        num_pixels += image_tensor.shape[1] * image_tensor.shape[2]

    if num_pixels == 0:
        raise ValueError(f'No images found in {image_directory}')

    # Compute the mean and std for each channel
    mean = channel_sum / num_pixels
    # clamp guards against tiny negative values caused by rounding
    variance = (channel_sq_sum / num_pixels - mean ** 2).clamp(min=0.0)
    std = torch.sqrt(variance)

    mean = mean.float()
    std = std.float()

    # Convert to standard Python float types for saving to JSON
    mean_list = mean.tolist()
    std_list = std.tolist()

    # Save to a file
    stats = {'mean': mean_list, 'std': std_list}
    with open(stats_file_path, 'w') as f:
        json.dump(stats, f)
    print(f'Saved mean and std to {stats_file_path}')

    return mean, std

# Reuse the cached statistics when the JSON file is present; otherwise compute (and cache) them.
if not os.path.exists(stats_file_path):
    print(f'{stats_file_path} not found, calculating mean and std...')
    mean, std = calculate_mean_std()
else:
    print(f'{stats_file_path} found, loading mean and std...')
    with open(stats_file_path, 'r') as f:
        stats = json.load(f)
    # Rebuild tensors from the lists stored under the 'mean' and 'std' keys
    mean, std = (torch.tensor(stats[key]) for key in ('mean', 'std'))
    print(f'Mean: {mean}')
    print(f'Std: {std}')

# Report whichever values ended up in use
print(f'Final Mean: {mean}')
print(f'Final Std: {std}')

In [3]:
import random

# Set the random seed for reproducibility
seed = 42 # This is the answer to the ultimate question of life, the universe and everything
np.random.seed(seed)
random.seed(seed)
# torch has its own generator: it drives DataLoader shuffling, layer weight
# initialisation and the torch.randn noise added in JointTransform.
torch.manual_seed(seed)

In [4]:
import torchvision.transforms.functional as F1
from torchvision import transforms
import random
import numpy as np
import torch

class JointTransform:
    """Apply the same geometric transforms to an image and its mask, then augment the image.

    Geometric steps (resize, random scale, centre crop) are applied to both
    inputs; the mask always uses nearest-neighbour interpolation so no new
    label values are created.  Colour jitter, optional Gaussian blur, Gaussian
    noise and normalization are applied to the image only.

    Args:
        resize: (height, width) both outputs are resized / cropped to.
        rotation_degree: stored but not applied; no rotation is performed.
        scale: (min, max) range the random scale factor is drawn from.
        brightness, contrast, saturation, hue: forwarded to ``ColorJitter``.
        gaussian_noise_std: std of additive noise; noise is skipped when <= 0.
        apply_gaussian_blur: when True, blur is applied with probability 0.5.
        blur_kernel_size: kernel size passed to ``gaussian_blur``.
        mean, std: per-channel statistics used by the final normalization.
    """

    def __init__( 
        self, resize=(512, 512), rotation_degree=15, scale=(0.9, 1.1), 
        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, gaussian_noise_std=0.01,
        apply_gaussian_blur=False, blur_kernel_size=3, mean=(0.4914, 0.4822, 0.4465), 
        std=(0.2023, 0.1994, 0.2010)
    ):
        self.resize = resize
        self.rotation_degree = rotation_degree  # kept for interface compatibility; unused in __call__
        self.scale = scale
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.gaussian_noise_std = gaussian_noise_std
        self.apply_gaussian_blur = apply_gaussian_blur
        self.blur_kernel_size = blur_kernel_size
        self.mean = mean
        self.std = std

    def __call__(self, image, mask):
        """Transform one (image, mask) pair.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is the normalized float
            tensor produced by ``to_tensor`` and ``mask`` is a long tensor
            with the leading channel dimension squeezed away.
        """
        # torchvision's enum replaces the PIL integer constant, which
        # torchvision deprecates as an interpolation argument.
        nearest = transforms.InterpolationMode.NEAREST

        # Resize
        image = F1.resize(image, self.resize)
        mask = F1.resize(mask, self.resize, interpolation=nearest)

        # Random Scaling
        scale_factor = random.uniform(self.scale[0], self.scale[1])
        new_size = [int(self.resize[0] * scale_factor), int(self.resize[1] * scale_factor)]
        image = F1.resize(image, new_size)
        mask = F1.resize(mask, new_size, interpolation=nearest)

        # Center Crop to Original Size
        image = F1.center_crop(image, self.resize)
        mask = F1.center_crop(mask, self.resize)

        # Color Jitter (image only)
        color_jitter = transforms.ColorJitter(
            brightness=self.brightness,
            contrast=self.contrast,
            saturation=self.saturation,
            hue=self.hue
        )
        image = color_jitter(image)

        # Gaussian Blur, applied to half of the samples when enabled
        if self.apply_gaussian_blur and random.random() < 0.5:
            image = F1.gaussian_blur(image, kernel_size=self.blur_kernel_size)

        # Convert to tensor; pil_to_tensor keeps the raw label values of the mask
        image = F1.to_tensor(image)
        mask = F1.pil_to_tensor(mask).squeeze(0).long()

        # Add Gaussian Noise
        if self.gaussian_noise_std > 0:
            noise = torch.randn(image.size()) * self.gaussian_noise_std
            image = image + noise

        # Clamp image values to [0, 1] after adding noise
        image = torch.clamp(image, 0.0, 1.0)

        # Normalize the image
        image = F1.normalize(image, mean=self.mean, std=self.std)

        return image, mask

In [5]:
# Training pipeline: mild up-scaling, colour jitter, noise and occasional blur.
train_transform = JointTransform(
    mean=mean,
    std=std,
    resize=(512, 512),
    rotation_degree=0,
    scale=(1.0, 1.1),
    brightness=0.25,
    contrast=0.25,
    saturation=0.25,
    hue=0.1,
    gaussian_noise_std=0.03,
    apply_gaussian_blur=True,
    blur_kernel_size=3,
)

# Validation pipeline: same resize and normalization, every random augmentation switched off.
val_transform = JointTransform(
    mean=mean,
    std=std,
    resize=(512, 512),
    rotation_degree=0,
    scale=(1.0, 1.0),
    brightness=0.0,
    contrast=0.0,
    saturation=0.0,
    hue=0.0,
    gaussian_noise_std=0.0,
    apply_gaussian_blur=False,
    blur_kernel_size=0,
)

In [6]:
import torch
import torchvision.transforms.functional as F1
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class CelebAMaskHQDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    An image ``<stem>.<image_ext>`` in ``image_dir`` is paired with the mask
    ``<stem>.<mask_ext>`` in ``mask_dir``.

    Args:
        image_dir: directory listed for image files.
        mask_dir: directory holding the masks.
        image_ext: image file extension, with or without a leading dot.
        mask_ext: mask file extension, with or without a leading dot.
        transform: optional callable ``(image, mask) -> (image, mask)``.
    """

    def __init__(self, image_dir, mask_dir, image_ext='jpg', mask_ext='png', transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_ext = image_ext.lstrip('.')
        self.mask_ext = mask_ext.lstrip('.')
        self.transform = transform
        # Listing order is kept as returned by os.listdir; other cells index
        # `images` positionally to recover file names.
        self.images = [file for file in os.listdir(image_dir) if file.endswith(f'.{self.image_ext}')]

    def __len__(self):
        return len(self.images)

    def _paths(self, idx):
        """Return ``(image_path, mask_path)`` for the sample at ``idx``."""
        image_name = self.images[idx]
        stem = os.path.splitext(image_name)[0]
        img_path = os.path.join(self.image_dir, image_name)
        # Honour mask_ext instead of assuming a '.jpg' -> '.png' rename
        mask_path = os.path.join(self.mask_dir, f'{stem}.{self.mask_ext}')
        return img_path, mask_path

    def __getitem__(self, idx):
        img_path, mask_path = self._paths(idx)

        # Load image and mask
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)

        # Apply transforms
        if self.transform:
            image, mask = self.transform(image, mask)
        else:
            # Without a transform, fall back to plain tensor conversion
            image = F1.to_tensor(image)
            mask = F1.pil_to_tensor(mask).squeeze(0).long()

        return image, mask

# Training split, augmented by train_transform.
# For a quick smoke test, point the two directories at
# 'train_tiny/train_image' and 'train_tiny/train_mask' instead.
train_dataset = CelebAMaskHQDataset(
    image_dir='train/train_image',
    mask_dir='train/train_mask',
    image_ext='jpg',
    mask_ext='png',
    transform=train_transform,
)

# Validation split, passed through val_transform (no random augmentation).
val_dataset = CelebAMaskHQDataset(
    image_dir='val/val_image',
    mask_dir='val/val_mask',
    image_ext='jpg',
    mask_ext='png',
    transform=val_transform,
)

# Batches of two; only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    pin_memory=True,
)

In [7]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_dataset(loader, num_examples=5, mean=None, std=None):
    """Plot image / mask pairs from the first batch of ``loader``.

    Images coming out of the loader may have been normalized; they are mapped
    back to the [0, 1] range before display so colours look natural.

    Args:
        loader: iterable yielding ``(images, masks)`` batches, images shaped
            (batch, channels, height, width).
        num_examples: maximum number of pairs to show; capped at the batch size.
        mean, std: per-channel normalization statistics to undo.  When left
            as None they are looked up on ``loader.dataset.transform``; if
            they cannot be found the images are only clipped to [0, 1].
    """
    # Get one batch of data
    images, masks = next(iter(loader))

    # Ensure the number of examples does not exceed the batch size
    num_examples = min(num_examples, images.size(0))

    # Fall back to the statistics stored on the dataset's transform, if any
    transform_obj = getattr(getattr(loader, 'dataset', None), 'transform', None)
    if mean is None:
        mean = getattr(transform_obj, 'mean', None)
    if std is None:
        std = getattr(transform_obj, 'std', None)

    denormalize = mean is not None and std is not None
    if denormalize:
        # Shape (1, 1, C) so the arrays broadcast over (H, W, C) images
        mean_hwc = np.array([float(v) for v in mean], dtype=np.float32).reshape(1, 1, -1)
        std_hwc = np.array([float(v) for v in std], dtype=np.float32).reshape(1, 1, -1)

    # Convert tensors to NumPy arrays for plotting
    images = images.numpy()
    masks = masks.numpy()

    for idx in range(num_examples):
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))

        # Image
        img = images[idx].transpose(1, 2, 0)  # (C, H, W) to (H, W, C)
        if denormalize:
            img = img * std_hwc + mean_hwc  # undo (x - mean) / std
        img = np.clip(img, 0, 1)  # Ensure values are between 0 and 1
        ax[0].imshow(img)
        ax[0].set_title('Image')
        ax[0].axis('off')

        # Mask
        mask = masks[idx]
        ax[1].imshow(mask, cmap='jet', interpolation='nearest')
        ax[1].set_title('Mask')
        ax[1].axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
# visualize_dataset caps num_examples at the size of the first batch, so with
# the loader's batch_size of 2 at most two pairs are shown.
print("Training examples:")
visualize_dataset(train_loader, num_examples=10)

In [None]:
# visualize_dataset caps num_examples at the size of the first batch, so with
# the loader's batch_size of 2 at most two pairs are shown.
print("Validation examples:")
visualize_dataset(val_loader, num_examples=10)

In [10]:
import torch
from torch import nn
from torchsummary import summary

class SmallerUNet(nn.Module):
    """Four-level U-Net with narrow feature maps.

    The encoder doubles the channel count at every level while max-pooling
    halves the spatial size; the decoder mirrors it with transposed
    convolutions and concatenated skip connections.  Because of the four 2x
    poolings the input height and width must be divisible by 16, otherwise
    the skip concatenations do not line up.

    Args:
        in_channels: number of channels of the input image.
        num_classes: number of output channels (one logit map per class).
        base_channels: width of the first encoder level; deeper levels use
            2x, 4x, 8x and 16x this value.

    With the default arguments the layer names, shapes and creation order are
    the same as the original fixed-size network, so existing checkpoints load.
    """

    def __init__(self, in_channels=3, num_classes=19, base_channels=16):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # Encoder: double convolution followed by 2x max-pool (downsampling)
        self.encoder1 = self._double_conv(in_channels, c1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder2 = self._double_conv(c1, c2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder3 = self._double_conv(c2, c3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder4 = self._double_conv(c3, c4)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = self._double_conv(c4, c5)

        # Decoder: 2x transposed convolution (upsampling), then a double
        # convolution over [skip, upsampled], hence twice the input channels.
        self.upconv4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.decoder4 = self._double_conv(2 * c4, c4)

        self.upconv3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.decoder3 = self._double_conv(2 * c3, c3)

        self.upconv2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.decoder2 = self._double_conv(2 * c2, c2)

        self.upconv1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.decoder1 = self._double_conv(2 * c1, c1)

        # Output layer: 1x1 convolution mapping features to class logits
        self.output = nn.Conv2d(c1, num_classes, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        """Return per-pixel class logits with the same height and width as ``x``."""
        # Encoder path; keep each level's output for the skip connections
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        # Bottleneck
        x = self.bottleneck(self.pool4(enc4))

        # Decoder path: upsample, concatenate the matching encoder output, refine
        x = self.decoder4(torch.cat((enc4, self.upconv4(x)), dim=1))
        x = self.decoder3(torch.cat((enc3, self.upconv3(x)), dim=1))
        x = self.decoder2(torch.cat((enc2, self.upconv2(x)), dim=1))
        x = self.decoder1(torch.cat((enc1, self.upconv1(x)), dim=1))

        # Output layer
        return self.output(x)

In [None]:
# Report the visible CUDA devices, or say so when there are none
if not torch.cuda.is_available():
    print("CUDA is not available.")
else:
    # One line per device: its index and its name
    num_gpus = torch.cuda.device_count()
    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

In [None]:
# Number of segmentation classes.
# NOTE(review): SmallerUNet() below is constructed without arguments, so this
# value is not what sizes the network's output layer.
num_classes = 19  # Replace with the number of classes in your dataset

# Use the first CUDA device ("cuda:0") when CUDA is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate the smaller model
model = SmallerUNet()

# Move the model's parameters to the selected device
model.to(device)

# Print model summary using torchsummary (input size: 3x512x512 for RGB images)
summary(model, input_size=(3, 512, 512))

In [None]:
# Count the total number of parameters in the model: numel() is summed over
# every tensor returned by model.parameters(), with no requires_grad filtering.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in the model: {total_params}")

In [None]:
import torch
import torch.nn as nn
import os
import json
from tqdm import tqdm

# File to save/load class distribution
class_dist_file_path = 'class_distribution.json'

# Function to compute class distribution
def calculate_class_distribution(train_loader, num_classes=19):
    """Count how many mask pixels carry each class label across ``train_loader``.

    Labels outside ``[0, num_classes)`` are ignored.  The counts are also
    written to ``class_dist_file_path`` as JSON under the key
    ``'class_distribution'``.

    Args:
        train_loader: iterable yielding ``(images, masks)`` batches where the
            masks hold integer class labels.
        num_classes: number of classes to count.

    Returns:
        torch.Tensor: long tensor of length ``num_classes`` with pixel counts.
    """
    # Initialize an empty tensor to hold the count of pixels for each of the classes
    class_distribution = torch.zeros(num_classes, dtype=torch.long)

    # Iterate through all batches in the train loader
    for _, masks in tqdm(train_loader):
        # reshape (unlike view) also accepts non-contiguous masks
        labels = masks.reshape(-1).long().cpu()

        # Drop out-of-range labels, then count all classes in a single pass
        labels = labels[(labels >= 0) & (labels < num_classes)]
        class_distribution += torch.bincount(labels, minlength=num_classes)

    # Convert to a list for saving
    class_distribution_list = class_distribution.tolist()

    # Save the class distribution to a JSON file
    with open(class_dist_file_path, 'w') as f:
        json.dump({'class_distribution': class_distribution_list}, f)

    print(f'Saved class distribution to {class_dist_file_path}')

    return class_distribution

# Check if class distribution file already exists
if os.path.exists(class_dist_file_path):
    print(f'{class_dist_file_path} found, loading class distribution...')
    with open(class_dist_file_path, 'r') as f:
        class_dist_data = json.load(f)
        class_distribution = torch.tensor(class_dist_data['class_distribution'], dtype=torch.long)
    print(f'Loaded class distribution: {class_distribution}')
else:
    print(f'{class_dist_file_path} not found, calculating class distribution...')
    class_distribution = calculate_class_distribution(train_loader, num_classes=19)

# Print out the pixel distribution across the classes
for class_idx in range(len(class_distribution)):
    print(f"Class {class_idx}: {class_distribution[class_idx]} pixels")

# Now we compute the class weights dynamically based on class_distribution
total_pixels = class_distribution.sum().item()  # Total number of pixels across all classes

# Per-class pixel counts as floats
class_frequencies = class_distribution.float()

# Weight = 1 / ln(count).  ln(count) is <= 0 for counts below e, which would
# give a negative weight for an empty class and a huge one for a count of 1,
# so the argument is clamped to e (weight 1.0 before normalisation).  Counts
# of 3 or more are unaffected.
import math
epsilon = 1e-6
class_weights = 1.0 / torch.log(torch.clamp(class_frequencies + epsilon, min=math.e))

# Normalise the weights to sum to 1 (optional)
class_weights = class_weights / class_weights.sum()

# Print the computed class weights
for class_idx in range(len(class_weights)):
    print(f"Weight for Class {class_idx}: {class_weights[class_idx].item()}")

# Move the class weights to the device the model lives on
class_weights = class_weights.to(device)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Per-class pixel counts as floats (class_distribution comes from the previous cell).
# This rebinds class_frequencies to the same values it already had there.
class_frequencies = class_distribution.float()

# Create a plot
plt.figure(figsize=(15, 7.5))

# One bar per class index (0-18), bar height = pixel count
sns.barplot(x=np.arange(19), y=class_frequencies)

# Logarithmic y-axis so rare and frequent classes are both readable
plt.yscale('log')

# Axis labels
plt.xlabel('Class Index', fontsize=30)
plt.ylabel('Number of Pixels in Training Set', fontsize=30)

# Dashed grid on major and minor ticks, one x tick per class
plt.grid(True, which="both", linestyle='--', linewidth=0.5)
plt.xticks(np.arange(19))
plt.tight_layout()

# Save the plot at dpi=300 (must happen before plt.show())
plt.savefig('class_frequency_distribution.png', dpi=300)

# Show the plot
plt.show()

In [None]:
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import logging
import numpy as np

# Set up the log file path
log_file = model_name + 'training_output.log'

# logging.basicConfig does nothing when the root logger already has handlers,
# which is the case whenever this cell is re-run in the same kernel.  Detach
# and close the old handlers first so the new log file is really used (and so
# the old file is no longer held open when it is deleted below).
_root_logger = logging.getLogger()
for _old_handler in list(_root_logger.handlers):
    _root_logger.removeHandler(_old_handler)
    _old_handler.close()

# If the file exists, delete it
if os.path.exists(log_file):
    os.remove(log_file)

# Set up logging to save the output to a new log file
logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(message)s')

def log_message(message):
    """Print ``message`` and record it in the log file at INFO level."""
    print(message)  # Print the message to console
    logging.info(message)  # Save the message to the log file

def calculate_intersect_union(pred_mask, gt_mask, num_classes=19):
    """Per-class intersection and union pixel counts between two label maps.

    Args:
        pred_mask: array of predicted class labels.
        gt_mask: array of ground-truth class labels, same shape as ``pred_mask``.
        num_classes: labels ``0 .. num_classes - 1`` are evaluated.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(intersections, unions)``, two float
        arrays of length ``num_classes``.
    """
    intersections = np.zeros(num_classes)
    unions = np.zeros(num_classes)

    for label in range(num_classes):
        pred_hits = pred_mask == label
        gt_hits = gt_mask == label

        overlap = np.sum(pred_hits & gt_hits)
        # |A or B| = |A| + |B| - |A and B|
        combined = np.sum(pred_hits) + np.sum(gt_hits) - overlap

        intersections[label] += overlap
        unions[label] += combined

    return intersections, unions

def _build_scheduler(optimizer, scheduler_type, t_max, base_lr):
    """Return the LR scheduler named by ``scheduler_type``, or None for any other value."""
    if scheduler_type == 'cosine':
        return CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.5 * base_lr)
    if scheduler_type == 'exponential':
        return ExponentialLR(optimizer, gamma=0.9)
    return None


def _mean_iou_percent(area_intersect, area_union):
    """Mean IoU in percent over the classes whose accumulated union is non-zero."""
    iou_per_class = area_intersect / (area_union + 1e-6)
    valid_classes = area_union > 0
    return iou_per_class[valid_classes].mean() * 100


# Combined training function with dynamic learning rate switching, patience handling, and mIoU calculation
def train_dynamic(model, device, train_loader, val_loader, class_weights, initial_lr=0.001, min_sgd_lr=0.001, weight_decay=1e-6, scheduler_type='cosine', epochs=100, patience=10, checkpoint_path='best_model.pth', csv_file_name='losses.csv', load_weight=False, weight_path=None):
    """Train ``model`` with Adam, then restart with SGD each time validation stalls.

    Training starts with Adam.  Whenever the validation loss has not improved
    for ``patience`` consecutive epochs, the best checkpoint is reloaded and
    training continues with SGD (momentum 0.5) at a learning rate that starts
    at 0.02 and is halved after every switch.  Training stops when that rate
    would fall below ``min_sgd_lr`` or after ``epochs`` epochs.  The best
    weights are reloaded at the end, per-epoch losses and mIoU are written to
    ``csv_file_name`` and plotted to a PNG with the same base name.

    Args:
        model: network mapping image batches to per-class logits.
        device: device the batches are moved to.
        train_loader, val_loader: iterables of ``(images, masks)`` batches.
        class_weights: per-class weights for the cross-entropy loss, or None.
        initial_lr: Adam learning rate.
        min_sgd_lr: smallest SGD learning rate that is still used.
        weight_decay: weight decay for both optimisers.
        scheduler_type: 'cosine', 'exponential', or anything else for no scheduler.
        epochs: maximum number of epochs.
        patience: epochs without validation improvement before switching.
        checkpoint_path: where the best ``state_dict`` is saved.
        csv_file_name: CSV file receiving the per-epoch metrics.
        load_weight, weight_path: when both are truthy, weights are loaded
            from ``weight_path`` before training.
    """
    # Define the loss function
    criterion = nn.CrossEntropyLoss(weight=class_weights)  # Multi-class segmentation loss

    # One IoU slot per class; falls back to 19 when no weights are given
    num_classes = len(class_weights) if class_weights is not None else 19

    # Start with Adam optimiser
    optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay)
    best_val_loss = float('inf')
    patience_counter = 0
    sgd_lr = 0.02  # Initial learning rate for SGD phase

    # Initialize scheduler (None when scheduler_type is not recognised)
    scheduler = _build_scheduler(optimizer, scheduler_type, t_max=50, base_lr=initial_lr)

    # Load pre-trained weights if specified
    if load_weight and weight_path:
        model.load_state_dict(torch.load(weight_path, map_location=device))
        log_message(f"Loaded weights from {weight_path}")

    # Lists to store train and validation losses and mIoU for each epoch
    train_losses, val_losses = [], []
    train_miou_list, val_miou_list = [], []

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        # Initialize total intersection and union areas for the epoch
        total_area_intersect_train = np.zeros(num_classes)
        total_area_union_train = np.zeros(num_classes)

        # Training loop
        for images, masks in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update train loss
            train_loss += loss.item()

            # Intersection / union are plain pixel counts, so the whole batch
            # can be accumulated in one call.
            pred_masks = torch.argmax(outputs, dim=1).cpu().numpy()
            gt_masks = masks.cpu().numpy()
            area_intersect, area_union = calculate_intersect_union(pred_masks, gt_masks, num_classes)
            total_area_intersect_train += area_intersect
            total_area_union_train += area_union

        # Mean IoU over the classes seen during this epoch
        train_miou = _mean_iou_percent(total_area_intersect_train, total_area_union_train)

        # Validation loop
        model.eval()
        val_loss = 0
        # Initialize total intersection and union areas for validation
        total_area_intersect_val = np.zeros(num_classes)
        total_area_union_val = np.zeros(num_classes)

        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()

                # Compute mIoU for validation
                pred_masks = torch.argmax(outputs, dim=1).cpu().numpy()
                gt_masks = masks.cpu().numpy()
                area_intersect, area_union = calculate_intersect_union(pred_masks, gt_masks, num_classes)
                total_area_intersect_val += area_intersect
                total_area_union_val += area_union

        # Average the summed batch losses
        val_loss /= len(val_loader)
        train_loss /= len(train_loader)

        # Save losses for plotting
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        val_miou = _mean_iou_percent(total_area_intersect_val, total_area_union_val)

        train_miou_list.append(train_miou)
        val_miou_list.append(val_miou)

        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']

        log_message(f"Epoch {epoch+1}/{epochs}, LR: {current_lr:.6f}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train mIoU: {train_miou:.2f}%, Val mIoU: {val_miou:.2f}%")

        # Update the learning rate at the end of each epoch; there is no
        # scheduler when scheduler_type is neither 'cosine' nor 'exponential'.
        if scheduler is not None:
            scheduler.step()

        # Check for improvement in validation loss and save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0  # Reset patience counter
            # Save model checkpoint
            torch.save(model.state_dict(), checkpoint_path)
            log_message(f"Model saved at epoch {epoch+1}, with val loss: {val_loss:.4f}, val mIoU: {val_miou:.2f}%")
        else:
            patience_counter += 1
            log_message(f"Patience counter: {patience_counter}/{patience}")

        # Early stopping and optimiser switch
        if patience_counter >= patience:
            log_message(f"Patience reached at epoch {epoch+1}, switching optimiser...")

            # Load the best model checkpoint before switching the optimiser
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))

            # Switch to SGD optimiser with decreasing learning rate
            if sgd_lr >= min_sgd_lr:
                optimizer = optim.SGD(model.parameters(), lr=sgd_lr, weight_decay=weight_decay, momentum=0.5)
                log_message(f"Switched to SGD optimiser with learning rate: {sgd_lr:.4f}")

                # Reset the scheduler so it is bound to the new optimizer
                scheduler = _build_scheduler(optimizer, scheduler_type, t_max=25, base_lr=sgd_lr)

                log_message(f"Scheduler reset after switching to SGD optimizer.")

                sgd_lr /= 2  # Halve the learning rate for the next phase
                patience_counter = 0  # Reset patience counter
            else:
                log_message("SGD learning rate has dropped below the minimum threshold, stopping training.")
                break  # Stop training if learning rate drops below minimum threshold

    # Load the best model weights after training
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    log_message(f"Training finished. Best model loaded with val loss: {best_val_loss:.4f}")

    # Save train/val losses and mIoU to a CSV file
    loss_df = pd.DataFrame({
        'Epoch': range(1, len(train_losses) + 1),
        'Train Loss': train_losses,
        'Val Loss': val_losses,
        'Train mIoU (%)': train_miou_list,
        'Val mIoU (%)': val_miou_list
    })
    loss_df.to_csv(csv_file_name, index=False)
    log_message(f"Losses and mIoU saved to {csv_file_name}")

    # Plotting the train and validation loss and mIoU curve
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Val Loss')
    plt.plot(range(1, len(train_miou_list) + 1), train_miou_list, label='Train mIoU (%)')
    plt.plot(range(1, len(val_miou_list) + 1), val_miou_list, label='Val mIoU (%)')
    plt.xlabel('Epochs')
    plt.ylabel('Loss / mIoU (%)')
    plt.title('Training and Validation Loss & mIoU Curve')
    plt.legend()
    plt.grid(True)

    # Save the figure next to the CSV, with a .png extension
    plot_file_name = csv_file_name.replace('.csv', '.png')
    plt.savefig(plot_file_name)
    log_message(f"Loss and mIoU curve figure saved to {plot_file_name}")
    plt.show()

# With min_sgd_lr=0.005 the SGD phases run at 0.02, 0.01 and 0.005; the next halving stops training.
train_dynamic(model, device, train_loader, val_loader, class_weights, initial_lr=0.001, min_sgd_lr=0.005, weight_decay=1e-7, scheduler_type='cosine', epochs=120, patience=12, checkpoint_path=model_name + '_best.pth', csv_file_name=model_name + '_losses.csv')

In [None]:
# map_location lets a checkpoint saved on a GPU be restored onto whatever device is in use now
model.load_state_dict(torch.load(model_name+'_best.pth', map_location=device))

In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt

def visualize_predictions(model, val_loader, device, image_folder='val/val_image', num_images=4):
    """Show input image, ground-truth mask and predicted mask for the first batch.

    Args:
        model: network mapping image batches to per-class logits.
        val_loader: loader yielding ``(images, masks)``; only its first batch
            is used.  It must not shuffle, otherwise the original images read
            from disk will not correspond to the tensors in the batch.
        device: device the model and batch are moved to.
        image_folder: directory the original (untransformed) images are read from.
        num_images: maximum number of rows to draw; capped at the batch size.
    """
    model.eval()
    model = model.to(device)

    images, masks = next(iter(val_loader))
    images, masks = images.to(device), masks.to(device)

    with torch.no_grad():
        outputs = model(images)

    # Ensure we do not exceed the number of images in the batch
    num_images = min(num_images, images.size(0))

    # Prefer the dataset's own file list: it is the sequence the loader
    # iterates over (already filtered by extension).  Fall back to a single
    # directory listing when the dataset does not expose one.
    dataset_names = getattr(getattr(val_loader, 'dataset', None), 'images', None)
    image_names = list(dataset_names) if dataset_names is not None else os.listdir(image_folder)

    # squeeze=False keeps axes two-dimensional even for a single row
    fig, axes = plt.subplots(num_images, 3, figsize=(15, num_images * 5), squeeze=False)

    for i in range(num_images):
        # Load the original image that belongs to the i-th sample of the batch
        image_path = os.path.join(image_folder, image_names[i])
        with Image.open(image_path) as original_image:
            # Show original image
            axes[i][0].imshow(original_image)
        axes[i][0].set_title('Input Image', fontsize=30)
        axes[i][0].axis('off')

        # Show ground truth mask (labels 0-18 mapped onto the tab20 palette)
        axes[i][1].imshow(masks[i].cpu(), cmap='tab20', vmin=0, vmax=18)
        axes[i][1].set_title('Ground Truth Mask', fontsize=30)
        axes[i][1].axis('off')

        # Convert output to predicted class indices
        pred_mask = torch.argmax(outputs[i], dim=0).cpu()
        axes[i][2].imshow(pred_mask, cmap='tab20', vmin=0, vmax=18)
        axes[i][2].set_title('Predicted Mask', fontsize=30)
        axes[i][2].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize a few predictions
visualize_predictions(model, val_loader, device, num_images=10)

In [19]:
import os
import numpy as np
from PIL import Image
import torch



# Function to save a prediction mask as a PNG where pixel values are class labels
def save_prediction(pred_mask, original_filename, output_dir='prediction'):
    """Write ``pred_mask`` to ``output_dir`` as ``<stem of original_filename>.png``.

    Args:
        pred_mask: tensor of class labels; values are cast to uint8.
        original_filename: name of the source image; only its stem is reused.
        output_dir: destination directory, created if missing.
    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Convert the tensor to an 8-bit unsigned integer NumPy array
    pred_mask_np = pred_mask.cpu().numpy().astype(np.uint8)

    # Always write a .png, whatever the source extension: a lossy format
    # would alter the label values.
    stem = os.path.splitext(original_filename)[0]
    pred_img = Image.fromarray(pred_mask_np)
    pred_img.save(os.path.join(output_dir, stem + '.png'))

# Function to iterate through validation dataset and save predictions
def save_predictions(model, val_loader, val_dataset, output_dir='prediction', device=None):
    """Run ``model`` over ``val_loader`` and save one label PNG per input image.

    File names are taken from ``val_dataset.images`` by position, so
    ``val_loader`` must iterate ``val_dataset`` in order (no shuffling).

    Args:
        model: network mapping image batches to per-class logits.
        val_loader: loader over ``val_dataset``.
        val_dataset: dataset exposing an ``images`` list of file names.
        output_dir: directory the PNG masks are written to.
        device: device to run on; defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    model = model.to(device)

    # Iterate over the validation data loader
    for idx, (images, masks) in enumerate(val_loader):
        images = images.to(device)

        with torch.no_grad():
            outputs = model(images)

        # Get the corresponding filenames for this batch
        batch_filenames = val_dataset.images[idx * val_loader.batch_size : (idx + 1) * val_loader.batch_size]

        for i in range(images.size(0)):
            # Convert output to predicted class indices
            pred_mask = torch.argmax(outputs[i], dim=0)

            # Save the predicted mask with the original filename
            original_filename = batch_filenames[i]
            save_prediction(pred_mask, original_filename, output_dir)

# Call the function to save predictions
save_predictions(model, val_loader, val_dataset, output_dir=model_name+'prediction')

In [20]:
import os
import numpy as np
from PIL import Image


def read_masks(path):
    """Load the mask image stored at ``path`` and return its pixels as a NumPy array."""
    # The context manager closes the underlying file as soon as the pixels
    # have been copied into the array, instead of leaving one open handle per
    # mask until garbage collection.
    with Image.open(path) as mask_img:
        return np.array(mask_img)


# Directory holding the predicted masks to score
submit_dir = model_name + 'prediction'

# Directory holding the ground-truth masks
truth_dir = 'val/val_mask'

# Directory that receives the score files ('...scores.txt' and '...detailed_scores.txt')
output_dir = 'output'

# Number of mask pairs to score; files are expected to be named
# 0.png ... (num_eval_images - 1).png in both directories.
num_eval_images = 1000

if not os.path.isdir(submit_dir):
    print(f"{submit_dir} doesn't exist")

# Report a missing ground-truth directory too, instead of silently skipping
# the whole evaluation.
if not os.path.isdir(truth_dir):
    print(f"{truth_dir} doesn't exist")

if os.path.isdir(submit_dir) and os.path.isdir(truth_dir):
    os.makedirs(output_dir, exist_ok=True)

    # When the submission directory contains a single entry, that entry is
    # expected to be the directory that actually holds the masks.
    submit_dir_list = os.listdir(submit_dir)
    if len(submit_dir_list) == 1:
        submit_dir = os.path.join(submit_dir, f"{submit_dir_list[0]}")
        assert os.path.isdir(submit_dir)

    # Class names based on the label list provided; the list index is the label value
    class_names = [
        'background', 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow',
        'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r',
        'neck_l', 'neck', 'cloth'
    ]
    num_eval_classes = len(class_names)

    # Per-class intersection / union pixel counts accumulated over all images
    area_intersect_all = np.zeros(num_eval_classes)
    area_union_all = np.zeros(num_eval_classes)

    for idx in range(num_eval_images):
        pred_mask = read_masks(os.path.join(submit_dir, f"{idx}.png"))
        gt_mask = read_masks(os.path.join(truth_dir, f"{idx}.png"))

        # Pixels where prediction and ground truth agree. This does not depend
        # on the class, so compute it once per image.
        agree = pred_mask == gt_mask

        for cls_idx in range(num_eval_classes):
            pred_is_cls = pred_mask == cls_idx

            area_intersect = np.sum(agree & pred_is_cls)
            area_pred_label = np.sum(pred_is_cls)
            area_gt_label = np.sum(gt_mask == cls_idx)
            area_union = area_pred_label + area_gt_label - area_intersect

            area_intersect_all[cls_idx] += area_intersect
            area_union_all[cls_idx] += area_union

    # A class that never occurs in either the predictions or the ground truth
    # has a union of 0. Dividing would produce NaN (plus a runtime warning) and
    # turn the whole mIoU into NaN, so such classes stay NaN individually and
    # are left out of the mean.
    iou_all = np.full(num_eval_classes, np.nan)
    np.divide(area_intersect_all, area_union_all, out=iou_all, where=area_union_all > 0)
    iou_all *= 100.0
    miou = np.nanmean(iou_all)

    # Create the evaluation score path for mIOU
    output_filename = os.path.join(output_dir, model_name + 'scores.txt')
    with open(output_filename, 'w') as f3:
        f3.write(f'mIOU: {miou:.2f}%\n')

    # Write detailed IoU for each class to a separate file
    detailed_output_filename = os.path.join(output_dir, model_name + 'detailed_scores.txt')
    with open(detailed_output_filename, 'w') as f4:
        f4.write('Class-wise IoU scores:\n')
        for cls_idx in range(num_eval_classes):
            f4.write(f'{class_names[cls_idx]}: {iou_all[cls_idx]:.2f}%\n')
        f4.write(f'\nMean IoU (mIOU): {miou:.2f}%\n')

# !Test!

In [22]:
import torch
from torch import nn
from torchsummary import summary

class SmallerUNet(nn.Module):
    """Compact U-Net for per-pixel classification.

    The network has four encoder stages (16, 32, 64 and 128 channels), a
    256-channel bottleneck and four decoder stages. Each decoder stage
    upsamples with a stride-2 transposed convolution and concatenates the
    matching encoder output (skip connection). A final 1x1 convolution maps
    the 16 decoder channels to ``num_classes`` score maps.

    Input height and width must be divisible by 16: the input is halved four
    times by max-pooling and doubled four times on the way up, and the skip
    concatenations require matching spatial sizes.

    Args:
        num_classes: Number of output channels (segmentation classes).
            Defaults to 19, the value that used to be hard-coded.
        in_channels: Number of channels of the input image. Defaults to 3.
    """

    def __init__(self, num_classes=19, in_channels=3):
        super().__init__()

        # NOTE: attribute names and creation order are kept exactly as before
        # so state-dict keys (and therefore saved checkpoints) stay compatible.

        # Encoder: double convolution followed by 2x2 max-pooling (downsampling)
        self.encoder1 = self._double_conv(in_channels, 16)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # e.g. 512x512 -> 256x256

        self.encoder2 = self._double_conv(16, 32)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 256x256 -> 128x128

        self.encoder3 = self._double_conv(32, 64)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 128x128 -> 64x64

        self.encoder4 = self._double_conv(64, 128)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x64 -> 32x32

        # Bottleneck
        self.bottleneck = self._double_conv(128, 256)

        # Decoder: transposed convolution (upsampling) + double convolution.
        # Each decoder block takes twice the upsampled channel count because
        # the skip connection is concatenated along the channel dimension.
        self.upconv4 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)  # 32x32 -> 64x64
        self.decoder4 = self._double_conv(256, 128)

        self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)  # 64x64 -> 128x128
        self.decoder3 = self._double_conv(128, 64)

        self.upconv2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)  # 128x128 -> 256x256
        self.decoder2 = self._double_conv(64, 32)

        self.upconv1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)  # 256x256 -> 512x512
        self.decoder1 = self._double_conv(32, 16)

        # Output layer: 1x1 convolution mapping features to per-class scores
        self.output = nn.Conv2d(16, num_classes, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Return two size-preserving 3x3 convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        """Return per-class score maps with the same height and width as ``x``."""
        # Encoder path; keep each stage's output for the skip connections
        enc1 = self.encoder1(x)
        x = self.pool1(enc1)

        enc2 = self.encoder2(x)
        x = self.pool2(enc2)

        enc3 = self.encoder3(x)
        x = self.pool3(enc3)

        enc4 = self.encoder4(x)
        x = self.pool4(enc4)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path: upsample, concatenate the skip tensor on the channel
        # dimension (encoder features first), then refine
        x = self.upconv4(x)
        x = torch.cat((enc4, x), dim=1)
        x = self.decoder4(x)

        x = self.upconv3(x)
        x = torch.cat((enc3, x), dim=1)
        x = self.decoder3(x)

        x = self.upconv2(x)
        x = torch.cat((enc2, x), dim=1)
        x = self.decoder2(x)

        x = self.upconv1(x)
        x = torch.cat((enc1, x), dim=1)
        x = self.decoder1(x)

        # Output layer
        x = self.output(x)

        return x

In [None]:
# Number of segmentation classes in the dataset. Kept for reference: the
# SmallerUNet() call below relies on the model's own 19-class output layer.
num_classes = 19

# Use the first CUDA GPU when one is available, otherwise run on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate the smaller model and move it to the selected device
model = SmallerUNet()
model.to(device)

# Print model summary using torchsummary (input size: 3x512x512 for RGB images).
# torchsummary builds its dummy input on "cuda" by default, which fails when the
# model lives on the CPU, so pass the selected device type explicitly.
summary(model, input_size=(3, 512, 512), device=device.type)

In [24]:
import torchvision.transforms.functional as F1
from torchvision import transforms
import torch

class TestTransform:
    """Deterministic test-time preprocessing: resize, convert to tensor, normalise.

    Args:
        resize: Target size handed to ``F1.resize``.
        mean: Per-channel mean handed to ``F1.normalize``.
        std: Per-channel standard deviation handed to ``F1.normalize``.
    """

    def __init__(
        self, resize=(512, 512), mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)
    ):
        self.resize, self.mean, self.std = resize, mean, std

    def __call__(self, image):
        """Return ``image`` resized, converted to a tensor and normalised."""
        resized = F1.resize(image, self.resize)
        as_tensor = F1.to_tensor(resized)
        return F1.normalize(as_tensor, mean=self.mean, std=self.std)

In [25]:
import torch
import torchvision.transforms.functional as F1
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

In [26]:
class CelebAMaskHQTestDataset(Dataset):
    """Unlabelled image-folder dataset that yields ``(image, filename)`` pairs.

    Args:
        image_dir: Directory containing the test images.
        image_ext: File extension (without the dot) of the images to include.
        transform: Optional callable applied to the loaded RGB PIL image. When
            omitted, the image is only converted to a tensor.
    """

    def __init__(self, image_dir, image_ext='jpg', transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order (and anything derived from it, such as which images get
        # visualised) reproducible across runs and machines.
        self.images = sorted(
            file for file in os.listdir(image_dir) if file.endswith(f'.{image_ext}')
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        filename = self.images[idx]
        img_path = os.path.join(self.image_dir, filename)

        # Load image. convert() returns an independent in-memory image, so the
        # file handle can be closed right away instead of lingering until
        # garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Apply transforms
        if self.transform:
            image = self.transform(image)
        else:
            image = F1.to_tensor(image)

        # Return the image and its filename for saving the results later
        return image, filename

In [27]:
# Define the test transform to match val_transform parameters.
# NOTE(review): these mean/std values look like the widely used CIFAR-10
# statistics, while the first cells of this notebook compute dataset-specific
# values into mean_std.json. Confirm they are the ones the model was trained
# and validated with.
test_transform = TestTransform(
    resize=(512, 512), 
    mean=(0.4914, 0.4822, 0.4465), 
    std=(0.2023, 0.1994, 0.2010)
)

# Build the test dataset from the .jpg files in 'test_image', applying the
# transform above to every image
test_dataset = CelebAMaskHQTestDataset(
    image_dir='test_image',
    image_ext='jpg',
    transform=test_transform
)

# Create DataLoader for the test set; shuffle=False keeps the dataset order
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, pin_memory=True)

In [None]:
# Load the best checkpoint. map_location remaps tensors that were saved on
# another device (e.g. a GPU) onto the currently selected `device`, so the
# checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load(model_name+'_best.pth', map_location=device))

In [29]:
import os
import numpy as np
from PIL import Image
import torch

# Function to save a prediction mask as a PNG where pixel values are class labels
def save_prediction(pred_mask, original_filename, output_dir='prediction'):
    """Write a predicted label mask into ``output_dir`` as a PNG file.

    Args:
        pred_mask: Tensor of predicted class indices. It is moved to the CPU
            and cast to ``uint8`` before saving, so each pixel of the written
            image stores a class label.
        original_filename: File name of the source image. Its extension,
            whatever it is, is replaced with ``.png``.
        output_dir: Directory that receives the PNG; created when missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # uint8 keeps the label values intact as 8-bit pixels.
    pred_mask_np = pred_mask.cpu().numpy().astype(np.uint8)

    # Swap the extension with splitext rather than str.replace('.jpg', '.png'):
    # replace() left names such as 'a.jpeg' or 'a.JPG' unchanged, and because
    # PIL picks the file format from the extension the label mask would then
    # be JPEG-compressed. replace() could also rewrite a '.jpg' that appears
    # in the middle of a name.
    stem, _ = os.path.splitext(original_filename)
    Image.fromarray(pred_mask_np).save(os.path.join(output_dir, stem + '.png'))

# Function to iterate through the test dataset and save predictions
def save_test_predictions(model, test_loader, output_dir='prediction'):
    """Predict a label mask for every test image and save it via ``save_prediction``.

    Args:
        model: Segmentation network; switched to eval mode and moved to the
            notebook-level ``device``.
        test_loader: Iterable yielding ``(images, filenames)`` batches.
        output_dir: Directory forwarded to ``save_prediction``.
    """
    model.eval()
    model = model.to(device)

    for batch, names in test_loader:
        batch = batch.to(device)

        # Inference only, so skip gradient tracking for the forward pass
        with torch.no_grad():
            logits = model(batch)

        for sample_idx in range(len(batch)):
            # Arg-max over the class dimension turns scores into a label map
            label_map = torch.argmax(logits[sample_idx], dim=0)
            save_prediction(label_map, names[sample_idx], output_dir)

# Call the function to save test predictions
save_test_predictions(model, test_loader, output_dir='./test_results/' + model_name + '_test_prediction')

In [30]:
import matplotlib.pyplot as plt

def visualize_predictions(model, test_loader, output_dir='prediction', num_pairs=5):
    """Show up to ``num_pairs`` test images next to their predicted masks.

    Args:
        model: Segmentation network; switched to eval mode and moved to the
            notebook-level ``device``.
        test_loader: Iterable yielding ``(images, filenames)`` batches. When
            ``test_loader.dataset.transform`` exposes ``mean`` and ``std``
            attributes, they are used to undo the normalisation for display.
        output_dir: Unused; kept so existing calls keep working. Predictions
            are recomputed from ``test_loader``, not read from disk.
        num_pairs: Number of image/mask rows to draw.
    """
    model.eval()
    model = model.to(device)  # Move the model to the specified device

    # squeeze=False keeps `axes` two-dimensional even for num_pairs == 1;
    # without it axes[count, 0] fails on the 1-D array matplotlib returns.
    fig, axes = plt.subplots(num_pairs, 2, figsize=(10, num_pairs * 5), squeeze=False)

    # Normalisation statistics of the dataset transform, if it exposes them
    transform = getattr(getattr(test_loader, 'dataset', None), 'transform', None)
    mean = getattr(transform, 'mean', None)
    std = getattr(transform, 'std', None)

    def _to_displayable(image):
        """Turn a (C, H, W) tensor into an (H, W, C) array with values in [0, 1]."""
        image = image.detach().cpu()
        if mean is not None and std is not None:
            # Undo `(x - mean) / std` so colours look like the source image
            mean_t = torch.as_tensor(mean, dtype=image.dtype).view(-1, 1, 1)
            std_t = torch.as_tensor(std, dtype=image.dtype).view(-1, 1, 1)
            image = image * std_t + mean_t
        # imshow only accepts floats in [0, 1]; clamp explicitly instead of
        # letting matplotlib clip with a warning
        return image.clamp(0, 1).permute(1, 2, 0).numpy()

    # Track the number of pairs visualized
    count = 0

    for images, filenames in test_loader:  # Unpack images and their filenames
        images = images.to(device)  # Move images to the specified device

        with torch.no_grad():
            outputs = model(images)

        for i in range(images.size(0)):
            if count >= num_pairs:
                break

            # Convert output to predicted class indices
            pred_mask = torch.argmax(outputs[i], dim=0).cpu().numpy().astype(np.uint8)

            # Plot the original image
            ax_image = axes[count, 0]
            ax_image.imshow(_to_displayable(images[i]))
            ax_image.set_title('Original Image')
            ax_image.axis('off')

            # Plot the predicted mask
            ax_mask = axes[count, 1]
            ax_mask.imshow(pred_mask, cmap='viridis')  # Use a colour map to visualise the mask
            ax_mask.set_title('Predicted Mask')
            ax_mask.axis('off')

            count += 1

        if count >= num_pairs:
            break

    # Hide rows that stayed empty because the loader held fewer images
    for unused_row in axes[count:]:
        for unused_ax in unused_row:
            unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

In [31]:
import numpy as np
import matplotlib.pyplot as plt

def plot_label_histogram(model, test_loader, num_classes=19):
    """Plot how many pixels the model assigns to each class over ``test_loader``.

    Args:
        model: Segmentation network; switched to eval mode and moved to the
            notebook-level ``device``.
        test_loader: Iterable yielding ``(images, filenames)`` batches; the
            filenames are ignored.
        num_classes: Number of class labels, i.e. histogram bins. Defaults to
            19, the value that used to be hard-coded.

    Raises:
        ValueError: If the model predicts a label >= ``num_classes``.
    """
    model.eval()
    model = model.to(device)  # Move the model to the specified device

    # One pixel counter per class label
    label_counts = np.zeros(num_classes, dtype=int)

    with torch.no_grad():
        for images, _ in test_loader:
            outputs = model(images.to(device))

            # Arg-max over the class dimension for the whole batch, then count
            # every label in one pass
            pred_labels = torch.argmax(outputs, dim=1).cpu().numpy().ravel()
            batch_counts = np.bincount(pred_labels, minlength=num_classes)

            if batch_counts.size > num_classes:
                raise ValueError(
                    f'model predicted label {batch_counts.size - 1}, '
                    f'but num_classes is {num_classes}'
                )
            label_counts += batch_counts

    plt.figure(figsize=(10, 6))
    plt.bar(range(len(label_counts)), label_counts)
    plt.xlabel('Class Label')
    plt.ylabel('Pixel Count')
    plt.title('Histogram of Predicted Labels')
    plt.yscale('log')  # Use a log scale if there are large differences in counts
    plt.show()

In [None]:
# Visualize several image and predicted mask pairs.
# NOTE: visualize_predictions never reads `output_dir`; the masks are
# recomputed from `test_loader` rather than loaded from the saved PNGs.
visualize_predictions(model, test_loader, output_dir='./test_results/' + model_name + '_test_prediction', num_pairs=5)

# Plot a histogram of predicted labels (pixel counts per class over the test set)
plot_label_histogram(model, test_loader)