# Installing required libraries

In [1]:
# !pip install coremltools
# !pip install -U 'comet-ml>=3.31.0'
# !pip install pycocotools
# !pip install transformers
# !pip install scikit-learn
# !pip install pytorch-lightning
# !pip install torch torchvision torchaudio

# conda install -c conda-forge gcc libstdcxx-ng

# Importing required packages

In [2]:
import os
import cv2
import torch
import joblib
import numpy as np
import torchmetrics
import torch.nn as nn
import coremltools as ct
from datetime import datetime
import pytorch_lightning as pl
from PIL import Image,ImageDraw
import torch.nn.functional as F
from pycocotools.coco import COCO
from matplotlib import pyplot as plt
from pytorch_lightning import Trainer
from torchmetrics import JaccardIndex
import torchvision.transforms as transforms
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, IterableDataset
from transformers import SegformerForSemanticSegmentation
from scipy.ndimage import label, find_objects, distance_transform_edt

scikit-learn version 1.4.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.3.0+cu121 has not been tested with coremltools. You may run into unexpected errors. Torch 2.2.0 is the most recent version that has been tested.
Failed to load _MLModelProxy: No module named 'coremltools.libcoremlpython'


# Setting SLURM variables

In [3]:
# Pin the per-node task count and make sure SLURM_NTASKS is absent for this kernel.
# NOTE(review): presumably this keeps SLURM auto-detection from disagreeing with
# the single-device Trainer configured later in the notebook - confirm.
os.environ.update({'SLURM_NTASKS_PER_NODE': '8'})
if 'SLURM_NTASKS' in os.environ:
    del os.environ['SLURM_NTASKS']

# Uncomment to inspect the SLURM environment seen by the kernel
# print("SLURM Environment Variables:")
# print("SLURM_JOB_ID:", os.environ.get('SLURM_JOB_ID', 'Not Set'))
# print("SLURM_NTASKS:", os.environ.get('SLURM_NTASKS', 'Not Set'))
# print("SLURM_NTASKS_PER_NODE:", os.environ.get('SLURM_NTASKS_PER_NODE', 'Not Set'))
# print("SLURM_JOB_NODELIST:", os.environ.get('SLURM_JOB_NODELIST', 'Not Set'))
# print("SLURM_JOB_NAME:", os.environ.get('SLURM_JOB_NAME', 'Not Set'))

# Use the 'high' float32 matmul precision setting (Tensor Core friendly)
torch.set_float32_matmul_precision('high')

# Defining file locations

In [4]:
# Image directories, one per dataset split
train_dir = './images/merged_train'
val_dir = './images/merged_val'
test_dir = './images/merged_test'

# Annotation files (loaded with pycocotools' COCO) matching the splits above
train_ann_file = './annotations/merged_train_01_22_25.json'
val_ann_file = './annotations/merged_val_01_22_25.json'
test_ann_file = './annotations/merged_test_01_22_25.json'

# Normalization

In [5]:
# def dynamic_normalization(img_dir, annotation_file):
#     coco = COCO(annotation_file)
#     img_ids = list(coco.imgs.keys())
#     means = []
#     stds = []

#     for img_id in img_ids:
#         img_info = coco.imgs[img_id]
#         path = os.path.join(img_dir, img_info['file_name'])
#         image = Image.open(path).convert('RGB')

#         # Convert image to tensor without normalization
#         to_tensor = transforms.ToTensor()
#         image_tensor = to_tensor(image)
        
#         means.append(image_tensor.mean(dim=(1, 2)))
#         stds.append(image_tensor.std(dim=(1, 2)))

#     mean = torch.stack(means).mean(dim=0)
#     std = torch.stack(stds).mean(dim=0)
    
#         # Print or log the mean and std
#     print(f"Calculated mean: {mean.tolist()}")
#     print(f"Calculated std: {std.tolist()}")

#     return transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean, std),
#     ])

# Fixed preprocessing applied to every image: tensor conversion followed by
# per-channel normalization with the constants below
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Defining COCO loader 

In [6]:
# Loader for COCO formatted data
def get_coco_loader(img_dir, annotation_file, transform, augment = False, num_augmentations = 1):
    """Build a generator factory that streams (image, mask) pairs from a COCO dataset.

    Args:
        img_dir: Directory that the images' ``file_name`` entries are relative to.
        annotation_file: Path handed to ``COCO(...)``.
        transform: Callable applied to each RGB PIL image (and to each augmented image).
        augment: When True, every sample is followed by augmented copies of itself.
        num_augmentations: Number of augmented copies yielded per sample.

    Returns:
        A zero-argument callable. Each call returns a fresh generator yielding
        ``(transform(image), mask_tensor)`` where ``mask_tensor`` is an int64 tensor
        of shape (height, width) holding, per pixel, the largest ``category_id``
        among the annotations covering it (0 where there is none).

    When the generator is consumed inside a ``DataLoader`` worker process it only
    walks that worker's share of the image ids. Without this, every worker of a
    multi-worker loader over an iterable dataset would yield the full dataset and
    each epoch would contain every sample ``num_workers`` times.
    """
    coco = COCO(annotation_file)
    img_ids = list(coco.imgs.keys())

    def loader():
        # get_worker_info() is None in the main process, so single-process
        # iteration still covers every image.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            shard_ids = img_ids
        else:
            shard_ids = img_ids[worker_info.id::worker_info.num_workers]

        for img_id in shard_ids:
            img_info = coco.imgs[img_id]
            path = os.path.join(img_dir, img_info['file_name'])
            # convert() returns a new image, so the file can be closed right away
            with Image.open(path) as source_image:
                image = source_image.convert('RGB')
            image_tensor = transform(image)

            anns = coco.loadAnns(coco.getAnnIds(imgIds=[img_id]))
            mask = np.zeros((img_info['height'], img_info['width']), dtype=np.int64)
            for ann in anns:
                # Widen the binary mask before scaling so a large category_id
                # cannot overflow a narrow integer dtype.
                ann_mask = coco.annToMask(ann).astype(np.int64) * ann['category_id']
                mask = np.maximum(mask, ann_mask)
            mask_tensor = torch.tensor(mask)
            yield image_tensor, mask_tensor

            if augment:
                for _ in range(num_augmentations):
                    aug_image, aug_mask = apply_transform(image, mask)
                    aug_image_tensor = transform(aug_image)
                    aug_mask_tensor = torch.tensor(aug_mask, dtype=torch.int64)
                    yield aug_image_tensor, aug_mask_tensor
    return loader

def get_position_transform():
    """Return the geometric augmentation pipeline: RandomHorizontalFlip, then RandomRotation(15)."""
    steps = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
    ]
    return transforms.Compose(steps)

def get_color_transform():
    """Return the photometric augmentation pipeline (a single ColorJitter step)."""
    # Brightness, contrast, saturation and hue are each jittered with a setting of 0.2
    jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
    return transforms.Compose([jitter])

def apply_transform(image, mask):
    """Augment an image and its label mask with the same geometric transform.

    Args:
        image: Image accepted by the position and color transforms (the loader
            passes an RGB PIL image).
        mask: Numpy array of integer labels; values must lie in [0, 255] because
            the mask is routed through an 8-bit PIL image.

    Returns:
        ``(image, mask)`` where ``image`` has had the position transform and then
        the color transform applied, and ``mask`` is a uint8 numpy array that
        received the identical position transform only.

    Raises:
        ValueError: If ``mask`` holds a label outside [0, 255]. The uint8 cast
            would otherwise wrap it around silently and corrupt the labels.
    """
    if mask.size and (mask.min() < 0 or mask.max() > 255):
        raise ValueError("mask labels must be within [0, 255] to be augmented as an 8-bit image")

    # Ensure the mask is in uint8 format then convert to PIL Image
    mask_image = Image.fromarray(mask.astype(np.uint8))

    # Seed for random transformations. Re-seeding before each call makes the
    # image and the mask draw the same random flip/rotation parameters.
    # Note that torch.manual_seed reseeds the global torch RNG.
    seed = torch.randint(0, 2**32, (1,)).item()
    position_transform = get_position_transform()

    torch.manual_seed(seed)
    image = position_transform(image)
    torch.manual_seed(seed)
    mask_image = position_transform(mask_image)

    # Color changes apply to the image only; labels must stay untouched
    image = get_color_transform()(image)

    # Convert mask back to numpy array
    return image, np.array(mask_image)

# Iterable dataset wrapper around get_coco_loader
class COCOLoader(IterableDataset):
    """Iterable dataset that streams the samples produced by ``get_coco_loader``.

    The constructor only stores its arguments; a new sample generator is built
    from them each time the dataset is iterated.
    """

    def __init__(self, img_dir, annotation_file, transform, augment=False, num_augmentations=1):
        super().__init__()
        self.img_dir, self.annotation_file = img_dir, annotation_file
        self.transform = transform
        self.augment, self.num_augmentations = augment, num_augmentations

    def __iter__(self):
        make_stream = get_coco_loader(
            self.img_dir,
            self.annotation_file,
            self.transform,
            self.augment,
            self.num_augmentations,
        )
        return make_stream()

# Defining Test-Train-Val loaders

In [7]:
# Per-split dynamic normalization (disabled; the shared `transform` is used instead)
# test_transform = dynamic_normalization(test_dir, test_ann_file)
# train_transform = dynamic_normalization(train_dir, train_ann_file)
# val_transform = dynamic_normalization(val_dir, val_ann_file)


# Setup data loaders: one per split, all built without augmentation
train_loader = DataLoader(
    COCOLoader(train_dir, train_ann_file, transform),
    batch_size=128,
    num_workers=4,
)
val_loader = DataLoader(
    COCOLoader(val_dir, val_ann_file, transform),
    batch_size=128,
    num_workers=4,
)
test_loader = DataLoader(
    COCOLoader(test_dir, test_ann_file, transform),
    batch_size=128,
    num_workers=4,
)

# Visualize Images with annotations

In [8]:
# for images, masks in val_loader:
#     # Convert tensors to numpy arrays
#     images_np = images.permute(0, 2, 3, 1).numpy()  # Convert from (N, C, H, W) to (N, H, W, C)
#     masks_np = masks.numpy()

#     # Plot images and masks
#     for i in range(len(images)):
#         image_np = images_np[i]
#         mask_np = masks_np[i]

#         # Overlay mask on image
#         plt.figure(figsize=(10, 5))
#         plt.subplot(1, 2, 2)
#         plt.imshow(image_np)
#         plt.imshow(mask_np, alpha=0.5, cmap='jet', interpolation='nearest')  # Overlay mask on image
#         plt.title('Image with Pupil Highlighted')
#         plt.axis('off')

#         plt.show()
#         break

# Segformer Model Definition

In [9]:
class SegformerWrapper(nn.Module):
    """Thin adapter that reduces the wrapped model's output to its logits tensor.

    Args:
        model: Callable/module invoked as ``model(pixel_values=..., labels=...)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pixel_values, labels=None):
        """Run the wrapped model and return only the logits.

        Supported output shapes of the wrapped model:
          * dict-like output: the value under the ``'logits'`` key is returned
            (key lookup also works for plain dicts, which have no ``.logits``);
          * tuple/list output: the logits element is returned. This assumes the
            Hugging Face convention that the loss comes first when labels were
            supplied, i.e. ``(loss, logits, ...)`` vs ``(logits, ...)``;
          * anything else is returned unchanged (already treated as logits).
        """
        outputs = self.model(pixel_values=pixel_values, labels=labels)
        if isinstance(outputs, dict):
            return outputs['logits']
        if isinstance(outputs, (tuple, list)):
            return outputs[1] if labels is not None else outputs[0]
        return outputs

class SegformerModule(pl.LightningModule):
    """LightningModule that fine-tunes a pretrained SegFormer for semantic segmentation.

    The loss is cross-entropy on logits bilinearly upsampled to the label
    resolution. IoU (macro-averaged multiclass Jaccard index) is tracked with a
    separate metric object per stage so that train, validation and test
    statistics never accumulate into each other.

    Args:
        num_classes: Number of segmentation classes; forwarded as ``num_labels``
            when loading the pretrained model.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            'nvidia/segformer-b0-finetuned-ade-512-512',
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )
        # NOTE: the same model is reachable as self.model and self.wrapper.model;
        # both attribute names are kept so existing checkpoints still load.
        self.wrapper = SegformerWrapper(self.model)
        self.test_losses = []
        self.test_ious = []

        # One metric instance per stage. Sharing a single instance lets
        # validation batches (which run inside a training epoch) leak into the
        # training IoU and vice versa.
        self.train_metrics = self._make_iou_metric()
        self.val_metrics = self._make_iou_metric()
        self.metrics = self._make_iou_metric()  # test-stage metric (original attribute name)

        # Initialize example input array with a dummy input
        self.example_input_array = torch.randn(1, 3, 640, 480)

    def _make_iou_metric(self):
        """Create a fresh macro-averaged multiclass Jaccard index for this model's label count."""
        return torchmetrics.JaccardIndex(
            task='multiclass',
            num_classes=self.model.config.num_labels,
            average='macro',
        )

    def forward(self, pixel_values, labels=None):
        """Return the logits produced by the wrapped model."""
        return self.wrapper(pixel_values=pixel_values, labels=labels)

    def compute_loss(self, outputs, labels):
        """Upsample logits to the label size and return ``(cross_entropy_loss, upsampled_logits)``."""
        logits = nn.functional.interpolate(outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
        loss = nn.functional.cross_entropy(logits, labels)
        return loss, logits

    def _shared_step(self, batch):
        """Run one batch and return ``(loss, predicted_class_map, labels)``."""
        images, labels = batch
        # Labels are not forwarded to the model: the loss is computed once, in
        # compute_loss, so an internally computed loss would only be discarded.
        outputs = self(images)
        loss, logits = self.compute_loss(outputs, labels)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, labels

    def training_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.log('train_loss', loss)

        self.train_metrics.update(preds, labels)
        self.log('train_iou', self.train_metrics, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        val_loss, preds, labels = self._shared_step(batch)
        self.log('val_loss', val_loss)

        self.val_metrics.update(preds, labels)
        self.log('val_iou', self.val_metrics, on_step=False, on_epoch=True)

        return val_loss

    def test_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.log('test_loss', loss)

        # Only accumulate here. The metric is computed, logged and reset by hand
        # in on_test_epoch_end; also logging the metric object would let
        # Lightning reset it before that manual compute() runs.
        self.metrics.update(preds, labels)

        self.test_losses.append(loss.detach())
        return {'test_loss': loss}

    def on_test_epoch_end(self):
        # Guard against a test run that produced no batches (stack() needs >= 1 tensor)
        if self.test_losses:
            avg_loss = torch.stack(self.test_losses).mean()
            avg_iou = self.metrics.compute()  # IoU over every test batch seen
            self.log('avg_test_loss', avg_loss)
            self.log('test_iou', avg_iou)
            self.log('avg_iou_score', avg_iou)
        self.test_losses.clear()
        self.test_ious.clear()
        self.metrics.reset()  # Reset metric states for the next test run

    def configure_optimizers(self):
        """Use AdamW over all parameters with a learning rate of 1e-4."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.0001)
        return optimizer

In [10]:
# Check if CUDA is available
# print(torch.cuda.is_available())
# print(torch.cuda.device_count())

# Model Initialization

In [None]:
# Build the Lightning module and a checkpoint callback that tracks the lowest val_loss
num_classes = 2
segformer = SegformerModule(num_classes=num_classes)
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',
    monitor='val_loss',
    mode='min',
)

# Defining trainer attributes

In [None]:
# Single-device trainer; falls back to CPU when CUDA is unavailable
trainer = Trainer(
    accelerator=('gpu' if torch.cuda.is_available() else 'cpu'),
    devices=1,
    max_epochs=128,
    callbacks=[checkpoint_callback],
)

# Training the model

In [None]:
# Fit on the training split, validating on the validation split
trainer.fit(
    model=segformer,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Testing the model

In [None]:
# Evaluate the in-memory model on the held-out test split
trainer.test(segformer, dataloaders=test_loader)

# Saving the model

In [None]:
model_save_path = './checkpoints/segformer_model.pth'
# Create the target directory up front so saving does not depend on the
# checkpoint callback having already written something there.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(segformer.state_dict(), model_save_path)

# Loading model from checkpoint

In [None]:
# Reload the weights from the best checkpoint recorded by `checkpoint_callback`
# (assumes `checkpoint_callback.best_model_path` points at a saved checkpoint),
# then switch to eval mode on the available device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = checkpoint_callback.best_model_path
segformer = SegformerModule.load_from_checkpoint(model_path, num_classes=2)
segformer.eval()
segformer.to(device);

# Visualizing model's predictions

In [None]:
# Function to de-normalize the image
# def denormalize(image_tensor, mean, std):
#     image_np = image_tensor.permute(1, 2, 0).cpu().numpy()  # Convert to HWC format
#     mean = np.array(mean)
#     std = np.array(std)
#     image_np = std * image_np + mean  # De-normalize
#     image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)  # Convert to uint8
#     return image_np

def denormalize(image_tensor, mean=None, std=None):
    """Undo channel normalization and return a displayable uint8 HWC image.

    Args:
        image_tensor: Tensor laid out as (C, H, W) holding normalized values.
        mean: Optional per-channel means used for normalization. Defaults to
            [0.485, 0.456, 0.406].
        std: Optional per-channel standard deviations used for normalization.
            Defaults to [0.229, 0.224, 0.225].

    Returns:
        Numpy uint8 array of shape (H, W, C) with values clipped to [0, 255].
    """
    mean = np.array([0.485, 0.456, 0.406] if mean is None else mean)
    std = np.array([0.229, 0.224, 0.225] if std is None else std)
    image_np = image_tensor.detach().permute(1, 2, 0).cpu().numpy()  # Convert to HWC format
    image_np = std * image_np + mean  # De-normalize
    return np.clip(image_np * 255, 0, 255).astype(np.uint8)  # Convert to uint8

# Normalization constants, kept available for de-normalization
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# RGB overlay colors keyed by role
color_map = {
    0: (0, 0, 0),   # Background
    1: (255, 0, 0), # Original mask
    2: (0, 255, 0), # Predicted mask
}

# Put the existing model instance in evaluation mode on the available device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
segformer.eval()
segformer.to(device)

# Paint the mask onto the image and blend it with the original
def overlay_mask(image_np, mask_np, color):
    """Return a PIL image blending ``image_np`` 50/50 with a painted copy.

    In the painted copy, every pixel where ``mask_np == 1`` has its first three
    channels replaced by ``color``; ``image_np`` itself is not modified.
    """
    painted = image_np.copy()
    selected = mask_np == 1
    for channel in range(3):
        plane = painted[:, :, channel]
        painted[:, :, channel] = np.where(selected, color[channel], plane)
    base = Image.fromarray(image_np)
    return Image.blend(base, Image.fromarray(painted), alpha=0.5)

# Show ground-truth vs predicted overlays for the first test batch only
with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device)

        outputs = segformer(images)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs

        # Bring the logits up to the label resolution before taking the argmax
        upsampled_logits = torch.nn.functional.interpolate(logits, size=masks.shape[-2:], mode="bilinear", align_corners=False)
        predicted = upsampled_logits.argmax(dim=1).cpu().numpy()
        # Masks are only needed on the host, so they are never moved to the device
        masks = masks.cpu().numpy()

        for i in range(len(images)):
            # denormalize takes the tensor alone and applies the normalization
            # constants itself; passing extra positional arguments raised TypeError
            image_np = denormalize(images[i])
            original_mask_np = masks[i]
            pred_mask_np = predicted[i]

            # color_map[1] is the ground-truth color, color_map[2] the prediction color
            original_overlay = overlay_mask(image_np, original_mask_np, color_map[1])
            predicted_overlay = overlay_mask(image_np, pred_mask_np, color_map[2])

            fig, (original_ax, predicted_ax) = plt.subplots(1, 2, figsize=(10, 5))
            original_ax.imshow(original_overlay)
            original_ax.set_title('Image with Original Mask')
            original_ax.axis('off')

            predicted_ax.imshow(predicted_overlay)
            predicted_ax.set_title('Image with Predicted Mask')
            predicted_ax.axis('off')

            plt.show()
            plt.close(fig)  # release the figure; one is created per sample
        break