In [None]:
"""
# -*- coding: utf-8 -*-
-----------------------------------------------------------------------------------
# Author: Zakria Mehmood
# DoC: 2025.04.15
# email: zakriamehmood2001@gmail.com
-----------------------------------------------------------------------------------
# Description: COMPLETE NOTEBOOK OF THE CODE FOR EASE OF EXECUTION
# This code is a PyTorch implementation of a Vision Transformer (ViT) autoencoder for a research article on semi- and self-supervised learning.

# The code includes the following components:
# 1. Imports: Necessary libraries and modules for the implementation.
# 2. Configuration: A configuration file that contains hyperparameters and settings for the model.
# 3. Data Preparation: Functions to load and preprocess the dataset.
# 4. Utility functions: Helpers for logging, saving checkpoints, loading data, and other utility tasks.
# 5. Model Definition: The Vision Transformer model architecture.
# 6. Training and Evaluation: Functions to train the model and evaluate its performance.
# 7. Testing: Functions to test the model on a specific dataset.

"""

# PyTorch and related imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset, random_split
import torch.utils.data.distributed
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter


# Vision and image processing
import torchvision
import torchvision.models as models
from torchvision import transforms
from PIL import Image
# Data manipulation and utilities
import numpy as np
from einops import rearrange, einops
from tqdm import tqdm
from easydict import EasyDict as edict
from torchinfo import summary

# System and miscellaneous
import os
import sys
import time
import math
import glob
import copy
import random
import logging
import warnings
import yaml
import matplotlib.pyplot as plt


## Network configs

In [None]:
# Architecture hyper-parameters per named model variant. get_configs() copies every
# key of the entry selected by `configs.model` onto the run configuration.
MODEL_CONFIGS = {
    'model_base': dict(
        # Base model configuration
        in_channels=3,
        image_size=256,

        # Standard architecture parameters
        embed_dim=512,
        depth=12,
        num_heads=8,
        decoder_embed_dim=256,
        decoder_depth=4,
        decoder_num_heads=8,
        mlp_ratio=4,
        activation='gelu',
        recon_method='method1',
        dropout=0.05,

        # Additional parameters
        segmentation_features=True,
        edge_aware_loss=True,
    ),
}


## Configs

In [None]:


def get_configs():
    """
    Build the run configuration as an EasyDict.

    Settings are declared in thematic sections (task, general, dataset, training,
    learning rate, distributed, paths, extra), merged in that order, and then the
    MODEL_CONFIGS entry named by ``configs.model`` (if present) is copied on top.

    Returns:
        EasyDict: attribute-accessible configuration.
    """
    task_cfg = dict(
        task_type='segmentation',  # pretext (reconstruction), segmentation_pretraining, or segmentation
    )
    general_cfg = dict(
        patch_size=16,
        model='model_base',
        saved_fn='vit_autoencoder',
        arch='vit',
    )
    dataset_cfg = dict(
        root_dir='another_one/Augmented_Dataset',
        split_ratios=[0.8, 0.15, 0.05],
        shuffle_data=True,
        seed=42,
    )
    training_cfg = dict(
        num_epochs=100,
        start_epoch=0,
        optimizer='adamw',
        device='cuda',
        loss_function='bce',  # only works on segmentation task
        batch_size=16,
        num_workers=4,
        activation_function='gelu',
        early_stopping_patience=20,
        clip_grad_norm=1.0,
        use_amp=True,
        evaluate=False,
        print_freq=50,
        checkpoint_freq=5,
        save_best_freq=1,
    )
    lr_cfg = dict(
        lr=5e-5,
        lr_type='cosine',
        burn_in=10,
        steps=[1500, 4000],
        lr_step_size=10,
        lr_gamma=0.1,
        weight_decay=1e-4,
        minimum_lr=1e-6,
        milestones=[30, 60, 90],
        e_gamma=0.95,
        t_max=10,
        min_lr=1e-6,
        base_lr=1e-6,
        max_lr=1e-3,
        step_size_up=2000,
        cyclic_mode='triangular',
    )
    distributed_cfg = dict(
        world_size=-1,
        rank=-1,
        dist_url='tcp://127.0.0.1:29500',
        dist_backend='nccl',
        gpu_idx=None,
        no_cuda=False,
        distributed=False,
    )
    paths_cfg = dict(
        checkpoints_dir='checkpoints',
        pretrained_path='another_one/checkpoints/Model_vit_autoencoder_best_epoch_95_segmentation.pth',
        logs_dir='logs_4',
        results_dir='results_4',
        model_dir='model',
    )
    extra_cfg = dict(
        early_stopping=True,
        patience=15,
        save_checkpoint_freq=5,
        step_lr_in_epoch=True,
        warmup_epochs=5,
        normalize_data=True,
    )

    merged = {}
    for section in (task_cfg, general_cfg, dataset_cfg, training_cfg, lr_cfg,
                    distributed_cfg, paths_cfg, extra_cfg):
        merged.update(section)
    configs = edict(merged)

    # Layer the selected model's architecture settings on top (same-named keys are overwritten).
    for key, value in MODEL_CONFIGS.get(configs.model, {}).items():
        setattr(configs, key, value)

    return configs


## Utilities

In [None]:
def mask_transform_fn(mask):
    """Binarize a mask: every value > 0 becomes 1.0, everything else 0.0 (float32 tensor)."""
    foreground = np.asarray(mask) > 0
    return torch.as_tensor(foreground, dtype=torch.float32)

def cleanup():
    """Destroy the default distributed process group; a no-op when none was initialized."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()

def calculate_metrics(outputs, masks):
    """
    Calculate binary segmentation metrics (IoU, Dice, precision, recall, F1).

    Args:
        outputs: Model logits, shaped (B, C, H, W) or (B, H, W). A single channel is
            used directly; otherwise the first three channels are collapsed to one map
            with the weights 0.299 / 0.587 / 0.114 before the sigmoid.
        masks: Ground-truth masks, shaped (B, 1, H, W) or (B, H, W); presumably 0/1
            valued (e.g. from mask_transform_fn) — values are used as-is, not thresholded.

    Returns:
        dict: Batch means of 'iou', 'dice', 'precision', 'recall' and 'f1'. Every ratio
        is smoothed with 1e-6, so empty prediction/target pairs score 1.
    """
    eps = 1e-6

    # Bring both tensors to (B, C, H, W)
    if outputs.ndim == 3:
        outputs = outputs.unsqueeze(1)
    if masks.ndim == 3:
        masks = masks.unsqueeze(1)

    # Single-channel logits are used directly; indexing channels 1 and 2 would fail for them.
    if outputs.size(1) == 1:
        logits = outputs
    else:
        logits = (0.299 * outputs[:, 0, :, :]
                  + 0.587 * outputs[:, 1, :, :]
                  + 0.114 * outputs[:, 2, :, :]).unsqueeze(1)

    # Threshold sigmoid probabilities to get a binary mask
    predictions = (torch.sigmoid(logits) > 0.5).float()
    masks = masks.float()

    # Ensure predictions and masks have the same shape
    if predictions.shape != masks.shape:
        predictions = predictions.squeeze(1)
        masks = masks.squeeze(1)

    # Reduce over every non-batch dimension -> one value per sample
    reduce_dims = tuple(range(1, predictions.ndim))
    pred_sum = predictions.sum(reduce_dims)
    mask_sum = masks.sum(reduce_dims)

    intersection = (predictions * masks).sum(reduce_dims)
    union = (predictions + masks).clamp(0, 1).sum(reduce_dims)

    # IoU (Jaccard) and Dice coefficient
    iou = (intersection + eps) / (union + eps)
    dice = (2 * intersection + eps) / (pred_sum + mask_sum + eps)

    # Precision and Recall
    true_positives = intersection
    false_positives = pred_sum - intersection
    false_negatives = mask_sum - intersection
    precision = (true_positives + eps) / (true_positives + false_positives + eps)
    recall = (true_positives + eps) / (true_positives + false_negatives + eps)

    # F1 Score from the per-sample precision/recall
    f1 = 2 * (precision * recall) / (precision + recall + eps)

    return {
        'iou': iou.mean().item(),
        'dice': dice.mean().item(),
        'precision': precision.mean().item(),
        'recall': recall.mean().item(),
        'f1': f1.mean().item()
    }

class Logger():
    """
    Logger that writes training messages to both a log file and the console.

    The file is ``<logs_dir>/logger_<saved_fn>.txt`` and is truncated on construction.

    Args:
        logs_dir: Directory for the log file (created if it does not exist).
        saved_fn: Run name embedded in the log file name.
    """

    def __init__(self, logs_dir, saved_fn):
        if logs_dir:
            os.makedirs(logs_dir, exist_ok=True)
        logger_fn = 'logger_{}.txt'.format(saved_fn)
        logger_path = os.path.join(logs_dir, logger_fn)

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # getLogger(__name__) always returns the same object, so constructing another
        # Logger (e.g. re-running a notebook cell) would stack duplicate handlers and emit
        # every message several times. Detach and close the previous handlers first.
        for old_handler in list(self.logger.handlers):
            self.logger.removeHandler(old_handler)
            old_handler.close()

        formatter = logging.Formatter(
            '%(asctime)s: %(module)s.py - %(funcName)s(), at Line %(lineno)d:%(levelname)s:\n%(message)s')

        # mode='w' truncates any log left over from a previous run
        file_handler = logging.FileHandler(logger_path, mode='w')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)

        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)

        self.logger.addHandler(file_handler)
        self.logger.addHandler(stream_handler)

    def info(self, message):
        """Log ``message`` at INFO level to the file and the console."""
        self.logger.info(message)




def calculate_iou(preds, targets, num_classes=1):
    """
    Mean per-class IoU between argmax predictions and integer target labels.

    Args:
        preds: Per-class scores; argmax is taken over dim 1.
        targets: Label map compared element-wise against each class id.
        num_classes: Class ids 0..num_classes-1 that are averaged.

    Returns:
        float: Mean IoU over the evaluated classes (0 if ``num_classes`` is 0).

    NOTE(review): with the default num_classes=1 only class 0 is scored, and a
    single-channel ``preds`` always argmaxes to 0 — confirm callers pass multi-channel
    scores and the intended class count.
    """
    predicted_labels = preds.argmax(dim=1)
    per_class_iou = []
    for class_id in range(num_classes):
        pred_is_class = predicted_labels == class_id
        target_is_class = targets == class_id
        inter = (pred_is_class & target_is_class).sum().item()
        uni = (pred_is_class | target_is_class).sum().item()
        per_class_iou.append(inter / (uni + 1e-6))  # epsilon avoids division by zero
    if not per_class_iou:
        return 0
    return sum(per_class_iou) / len(per_class_iou)


def calculate_dice(preds, targets):
    """
    Foreground Dice coefficient between argmax predictions and a target mask.

    Any predicted label > 0 and any target value > 0 counts as foreground. For 0/1
    labels this matches the previous bitwise formulation, but it also accepts the
    float32 masks produced by mask_transform_fn (``&`` is undefined for float tensors).

    Args:
        preds: Per-class scores; argmax is taken over dim 1.
        targets: Mask with the same shape as ``preds.argmax(dim=1)``.

    Returns:
        float: 2|P ∩ T| / (|P| + |T|), with 1e-6 added to the denominator.
    """
    pred_fg = preds.argmax(dim=1) > 0
    target_fg = targets > 0
    intersection = (pred_fg & target_fg).sum().item()
    return (2. * intersection) / (pred_fg.sum().item() + target_fg.sum().item() + 1e-6)


def calculate_mse(preds, targets):
    """Mean squared error between ``preds`` and ``targets`` (mean reduction), as a tensor."""
    squared_error = F.mse_loss(preds, targets, reduction='mean')
    return squared_error


def calculate_psnr(mse, max_val=1.0):
    """
    Peak signal-to-noise ratio in dB: ``20 * log10(max_val / sqrt(mse))``.

    Args:
        mse (torch.Tensor): Mean squared error (e.g. from calculate_mse). Zero yields +inf.
        max_val (float): Peak signal value; the default 1.0 corresponds to images in [0, 1].

    Returns:
        torch.Tensor: PSNR in decibels.
    """
    return 20 * torch.log10(max_val / torch.sqrt(mse))




def make_folder(folder_name):
    """
    Create ``folder_name`` (including missing parents) if nothing exists at that path.

    ``exist_ok=True`` keeps the call safe if another process (e.g. a second distributed
    rank) creates the directory between the existence check and ``makedirs``. A path
    that already exists, even as a regular file, is left untouched.
    """
    if not os.path.exists(folder_name):
        os.makedirs(folder_name, exist_ok=True)


class AverageMeter(object):
    """Track the latest value of a metric together with its count-weighted running mean."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt  # format spec applied to both the current value and the average
        self.reset()

    def reset(self):
        """Zero the current value, running sum, count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Build a tab-separated progress line: ``<prefix>[ batch/total]`` followed by each meter's str()."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for ``batch``."""
        print(self.get_message(batch))

    def get_message(self, batch):
        """Return the progress line for ``batch`` without printing it."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        return '\t'.join([header] + [str(meter) for meter in self.meters])

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total so lines stay aligned.
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))


def time_synchronized():
    """Return ``time.time()`` after waiting for queued CUDA work (when CUDA is available)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()





def early_stopping(val_losses, patience, logger):
    """
    Implements early stopping based on validation loss.

    Training should stop once ``patience`` epochs have passed since the best (lowest)
    validation loss, i.e. when ``len(val_losses) - best_epoch >= patience``. A patience
    of 0 behaves like 1 (stop after the first non-improving epoch).

    Args:
        val_losses (list): Validation loss per epoch, oldest first.
        patience (int): Number of non-improving epochs to tolerate before stopping.
        logger: Object with an ``info(str)`` method, used to report the stop.

    Returns:
        (bool, int): (stop_training_flag, best_epoch). best_epoch is 1-based (first
        occurrence of the minimum), or -1 when there are too few epochs to decide.
    """
    if not val_losses or len(val_losses) < patience:
        return False, -1  # Not enough epochs to decide

    min_loss = min(val_losses)
    best_epoch = val_losses.index(min_loss) + 1  # Convert to 1-based index
    epochs_since_best = len(val_losses) - best_epoch

    # `>=` stops after exactly `patience` non-improving epochs; a strict comparison
    # would wait one epoch longer than requested.
    if epochs_since_best >= max(patience, 1):
        logger.info(
            f"Early stopping triggered at epoch {len(val_losses)}. Best epoch: {best_epoch}")
        return True, best_epoch
    return False, best_epoch


def save_best_model(model, optimizer, lr_scheduler, epoch, best_loss, val_loss, model_type, logger, configs):
    """
    Save a checkpoint when ``val_loss`` improves on ``best_loss`` and prune older best
    checkpoints of the same run.

    Files are written to ``configs.model_dir`` (created if missing) as
    ``Model_<saved_fn>_epoch_<epoch>_<model_type>.pth`` and the matching ``Utils_`` file.
    Only earlier files following that exact pattern (same saved_fn and model_type) are
    deleted, so checkpoints of other runs or tasks sharing the directory survive.

    Args:
        model (torch.nn.Module): The model being trained.
        optimizer (torch.optim.Optimizer): The optimizer used for training.
        lr_scheduler (torch.optim.lr_scheduler): The learning rate scheduler.
        epoch (int): Current epoch number.
        best_loss (float): The best validation loss recorded.
        val_loss (float): The current validation loss.
        model_type (str): Task tag appended to the file names.
        logger: Object with an ``info(str)`` method.
        configs (Namespace): Needs ``model_dir`` and ``saved_fn``.

    Returns:
        float: The (possibly updated) best loss. A NaN ``val_loss`` never counts as better.
    """
    import re

    # Written as `not <` so a NaN val_loss is treated as "no improvement".
    if not val_loss < best_loss:
        return best_loss

    model_dir = configs.model_dir
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    model_state_dict, utils_state_dict = get_saved_state(
        model, optimizer, lr_scheduler, epoch, configs
    )

    model_file = f'Model_{configs.saved_fn}_epoch_{epoch}_' + model_type + '.pth'
    utils_file = f'Utils_{configs.saved_fn}_epoch_{epoch}_' + model_type + '.pth'

    # Save the new best model
    torch.save(model_state_dict, os.path.join(model_dir, model_file))
    torch.save(utils_state_dict, os.path.join(model_dir, utils_file))
    logger.info(
        f"New best model saved at epoch {epoch} with val_loss: {val_loss:.4f}")

    # Delete previous best checkpoints of this run/task only (exact name pattern).
    run_pattern = re.compile(
        r'(?:Model|Utils)_' + re.escape(str(configs.saved_fn))
        + r'_epoch_\d+_' + re.escape(str(model_type)) + r'\.pth')
    keep = {model_file, utils_file}
    for file in os.listdir(model_dir or '.'):
        if file in keep or not run_pattern.fullmatch(file):
            continue
        os.remove(os.path.join(model_dir, file))
        logger.info(f"Deleted previous model checkpoint: {file}")

    return val_loss


def get_saved_state(model, optimizer, lr_scheduler, epoch, configs):
    """
    Collect what a checkpoint stores.

    Returns:
        (dict, dict): The model's state_dict (unwrapped from ``.module`` when the model
        exposes one, e.g. DataParallel/DDP wrappers) and a utils dict with 'epoch',
        'configs', and deep copies of the optimizer and scheduler state_dicts.
    """
    core_model = model.module if hasattr(model, 'module') else model
    model_state_dict = core_model.state_dict()
    # Deep copies so later training steps cannot mutate the saved optimizer/scheduler state.
    utils_state_dict = dict(
        epoch=epoch,
        configs=configs,
        optimizer=copy.deepcopy(optimizer.state_dict()),
        lr_scheduler=copy.deepcopy(lr_scheduler.state_dict()),
    )
    return model_state_dict, utils_state_dict


def save_checkpoint(checkpoints_dir, saved_fn, model_state_dict, utils_state_dict, epoch, model_type):
    """
    Write a periodic checkpoint pair into ``checkpoints_dir``:
    ``Model_<saved_fn>_epoch_<epoch>_<model_type>.pth`` (model weights) and the
    matching ``Utils_...`` file (optimizer/scheduler/bookkeeping), then print the model path.
    """
    name_tail = f'_epoch_{epoch}_' + model_type + '.pth'
    model_save_path = os.path.join(checkpoints_dir, f'Model_{saved_fn}' + name_tail)
    utils_save_path = os.path.join(checkpoints_dir, f'Utils_{saved_fn}' + name_tail)

    for payload, destination in ((model_state_dict, model_save_path),
                                 (utils_state_dict, utils_save_path)):
        torch.save(payload, destination)

    print('save a checkpoint at {}'.format(model_save_path))


def reduce_tensor(tensor, world_size):
    """
    Average ``tensor`` across the ranks of the default process group.

    A copy of ``tensor`` is summed with ``all_reduce`` and then divided in place by
    ``world_size``; the input tensor is not modified.

    Args:
        tensor (torch.Tensor): Local value (floating point, since the division is in place).
        world_size (int): Number of participating processes.

    Returns:
        torch.Tensor: The cross-rank mean.
    """
    rt = tensor.clone()
    # dist.ReduceOp is the supported enum; dist.reduce_op is a deprecated alias.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt


def to_python_float(t):
    """Return ``t.item()`` when ``t`` provides it (e.g. a one-element tensor), else ``t[0]``."""
    return t.item() if hasattr(t, 'item') else t[0]


def get_tensorboard_log(model):
    """
    Collect ``metrics`` dicts from the model's ``yolo_layers`` into a TensorBoard-ready dict.

    Returns:
        dict: ``{'Average_All_Layers': {metric: mean over layers},
        'YOLO_Layer1': {...}, 'YOLO_Layer2': {...}, ...}``.

    NOTE(review): this appears to be carried over from a YOLO codebase; verify the model
    used here actually exposes ``yolo_layers`` (with ``.metrics``) before calling it.
    Metric names missing from the first layer raise KeyError when averaging.
    """
    net = model.module if hasattr(model, 'module') else model
    layers = net.yolo_layers
    n_layers = len(layers)

    log = {'Average_All_Layers': {}}
    averages = log['Average_All_Layers']
    for position, layer in enumerate(layers, start=1):
        per_layer = {}
        log['YOLO_Layer{}'.format(position)] = per_layer
        for metric_name, value in layer.metrics.items():
            key = '{}'.format(metric_name)
            per_layer[key] = value
            share = value / n_layers
            if position == 1:
                averages[key] = share
            else:
                averages[key] += share

    return log


def plot_lr_scheduler(optimizer, scheduler, num_epochs=300, save_dir=''):
    """
    Plot the learning-rate curve a scheduler produces over ``num_epochs`` steps.

    The scheduler is stepped ``num_epochs`` times, the LR of the first param group of
    ``optimizer`` is recorded after each step, and the curve is saved as
    ``<save_dir>/LR.png``. The scheduler's state and the learning rates of the optimizer
    it drives are restored afterwards, so the caller's training setup is unchanged.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose LR is plotted (normally the
            one ``scheduler`` was built on).
        scheduler: Scheduler whose ``step()`` takes no arguments (ReduceLROnPlateau,
            which needs a metric, is not supported).
        num_epochs (int): Number of scheduler steps to simulate.
        save_dir (str): Output directory (created if missing); '' means the working directory.
    """
    # copy.copy() would be shallow (the copies share param_groups with the originals, so
    # stepping them changes the real LR). Snapshot, simulate, and restore instead.
    stepped_optimizer = getattr(scheduler, 'optimizer', optimizer)
    saved_lrs = [group['lr'] for group in stepped_optimizer.param_groups]
    saved_state = copy.deepcopy(scheduler.state_dict())

    lr_history = []
    try:
        for _ in range(num_epochs):
            scheduler.step()
            lr_history.append(optimizer.param_groups[0]['lr'])
    finally:
        scheduler.load_state_dict(saved_state)
        for group, lr in zip(stepped_optimizer.param_groups, saved_lrs):
            group['lr'] = lr

    # A fresh figure per call avoids drawing on top of whatever figure is currently active.
    fig, ax = plt.subplots()
    ax.plot(lr_history, '.-', label='LR')
    ax.set_xlabel('epoch')
    ax.set_ylabel('LR')
    ax.grid(True)
    ax.set_xlim(0, num_epochs)
    ax.set_ylim(bottom=0)
    fig.tight_layout()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, 'LR.png'), dpi=200)





def get_loss_function(loss_name, smooth=1e-6):
    """
    Get a loss function by name.

    Supported names: 'mse', 'l1', 'bce' (BCEWithLogitsLoss), 'perceptual', 'ssim',
    'dice' (soft Dice on sigmoid(pred)), 'combined', 'l1_ssim' (0.5 * L1 + 0.5 * SSIM).
    Unknown names print a warning and fall back to CombinedLoss.

    Only the requested loss is constructed. Building every candidate up front would also
    instantiate PerceptualLoss, which loads pretrained VGG16 weights, on every call.

    Args:
        loss_name (str): Name of the loss function
        smooth (float): Smoothing factor for dice loss

    Returns:
        callable: Loss function
    """
    if loss_name == 'dice':
        def dice_loss(pred, target, smooth=smooth):
            pred = torch.sigmoid(pred)
            # reshape (not view) so non-contiguous tensors, e.g. permuted outputs, work too
            pred_flat = pred.reshape(-1)
            target_flat = target.reshape(-1)
            intersection = (pred_flat * target_flat).sum()
            return 1 - ((2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth))
        return dice_loss

    if loss_name == 'combined':
        return CombinedLoss(alpha=0.8, beta=0.2, gamma=0.1)

    if loss_name == 'l1_ssim':
        # Combination of L1 and SSIM
        l1_loss = nn.L1Loss()
        ssim_loss = SSIMLoss()

        def l1_ssim_loss(pred, target):
            return 0.5 * l1_loss(pred, target) + 0.5 * ssim_loss(pred, target)
        return l1_ssim_loss

    # Zero-argument constructors, invoked only for the selected name
    loss_factories = {
        'mse': nn.MSELoss,
        'l1': nn.L1Loss,
        'bce': nn.BCEWithLogitsLoss,
        'perceptual': lambda: PerceptualLoss(),
        'ssim': lambda: SSIMLoss(),
    }
    if loss_name in loss_factories:
        return loss_factories[loss_name]()

    print(
        f"Unknown loss function: {loss_name}, using combined loss instead")
    return CombinedLoss(alpha=0.8, beta=0.2, gamma=0.1)


def get_activation_function(activation_name):
    """
    Return a fresh activation module for ``activation_name`` (case-insensitive).

    Known names: relu, gelu, sigmoid, tanh, leaky_relu, silu, mish. Any other name
    yields nn.ReLU().
    """
    constructors = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
        "leaky_relu": nn.LeakyReLU,
        "silu": nn.SiLU,  # a.k.a. Swish
        "mish": nn.Mish,
    }
    return constructors.get(activation_name.lower(), nn.ReLU)()


def get_optimizer(optimizer_name, model_params, learning_rate=1e-4, weight_decay=1e-5, momentum=0.9):
    """
    Build an optimizer by (case-insensitive) name over ``model_params``.

    Supported: sgd (uses ``momentum``), adam, adamw, rmsprop, adagrad. Unknown names print
    a warning and fall back to AdamW. ``learning_rate`` and ``weight_decay`` apply to all.

    Returns:
        torch.optim.Optimizer: Optimizer
    """
    key = optimizer_name.lower()
    shared_kwargs = {'lr': learning_rate, 'weight_decay': weight_decay}

    if key == "sgd":
        return optim.SGD(model_params, momentum=momentum, **shared_kwargs)

    builders = {
        "adam": optim.Adam,
        "adamw": optim.AdamW,
        "rmsprop": optim.RMSprop,
        "adagrad": optim.Adagrad,
    }
    builder = builders.get(key)
    if builder is None:
        print(f"Unknown optimizer: {key}, using AdamW instead")
        builder = optim.AdamW
    return builder(model_params, **shared_kwargs)


def create_optimizer(configs, model):
    """
    Build the training optimizer with two parameter groups.

    Parameters whose name contains "bias", "norm", "LayerNorm.weight" or
    "layer_norm.weight" get weight_decay 0.0; all others get ``configs.weight_decay``.

    Args:
        configs: Needs ``optimizer`` (sgd / adam / adamw / rmsprop / adagrad, unknown ->
            AdamW), ``lr``, ``weight_decay`` and optionally ``momentum`` (SGD, default 0.9).
        model (nn.Module): Model to optimize; a wrapper exposing ``.module`` is unwrapped.

    Returns:
        torch.optim.Optimizer: Configured optimizer
    """
    # Unwrap DataParallel/DistributedDataParallel-style wrappers for clean parameter names
    net = model.module if hasattr(model, 'module') else model
    named_params = dict(net.named_parameters())

    skip_decay_tokens = ("bias", "norm", "LayerNorm.weight", "layer_norm.weight")
    decay_params, no_decay_params = [], []
    for param_name, param in named_params.items():
        bucket = no_decay_params if any(tok in param_name for tok in skip_decay_tokens) else decay_params
        bucket.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": configs.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    optimizer_name = configs.optimizer.lower()
    learning_rate = configs.lr

    if optimizer_name == 'sgd':
        optimizer = optim.SGD(
            param_groups,
            lr=learning_rate,
            momentum=getattr(configs, 'momentum', 0.9),
            nesterov=True
        )
    else:
        builders = {
            'adam': optim.Adam,
            'adamw': optim.AdamW,
            'rmsprop': optim.RMSprop,
            'adagrad': optim.Adagrad,
        }
        builder = builders.get(optimizer_name)
        if builder is None:
            print(f"Unknown optimizer: {optimizer_name}, using AdamW instead")
            builder = optim.AdamW
        optimizer = builder(param_groups, lr=learning_rate)

    print(f'Optimizer: {optimizer_name}, Learning rate: {learning_rate}')
    print(
        f'Parameter groups: {len(decay_params)} parameters with weight decay, {len(no_decay_params)} parameters without weight decay')

    return optimizer


def create_lr_scheduler(optimizer, configs):
    """
    Create a learning rate scheduler based on configuration.

    ``configs.lr_type`` (case-insensitive) selects one of: 'multi_step' (burn-in then
    step factors), 'cosin' / 'cosine' (cosine decay to 10% of the base LR with a built-in
    linear warmup), 'step_lr', 'multistep_lr', 'exponential_lr', 'cosine_annealing_lr',
    'reduce_lr_on_plateau', 'cyclic_lr'. Anything else falls back to CosineAnnealingLR
    over ``configs.num_epochs``.

    When ``configs.warmup_epochs > 0`` the main scheduler is preceded by a linear warmup
    via SequentialLR. The cosine schedule folds warmup into its own lambda. Warmup is
    skipped for ReduceLROnPlateau, which SequentialLR cannot chain because its ``step``
    needs a metric.

    Args:
        optimizer: The optimizer to schedule
        configs: Configuration object with scheduler settings

    Returns:
        The learning rate scheduler (an LRScheduler, or ReduceLROnPlateau).
    """
    scheduler_name = configs.lr_type.lower()
    warmup_epochs = getattr(configs, 'warmup_epochs', 0)
    use_warmup = warmup_epochs > 0

    # 'cosin' is the historical key; get_configs() uses 'cosine', so accept both.
    is_cosine = scheduler_name in ('cosin', 'cosine')
    is_plateau = scheduler_name == 'reduce_lr_on_plateau'

    warmup_scheduler = None
    if use_warmup:
        print(f"Using {configs.warmup_epochs} epochs of warmup")
        if is_plateau:
            print("Warmup is not supported with ReduceLROnPlateau; training without warmup")
        elif not is_cosine:
            # Built before the main scheduler (construction order affects initial LRs).
            def warmup_lambda(epoch):
                if epoch < warmup_epochs:
                    return float(epoch) / float(max(1, warmup_epochs))
                return 1.0

            warmup_scheduler = optim.lr_scheduler.LambdaLR(
                optimizer, lr_lambda=warmup_lambda)

    # Create main scheduler
    if scheduler_name == 'multi_step':
        def burnin_schedule(i):
            if i < configs.burn_in:
                factor = pow(i / configs.burn_in, 4)
            elif i < configs.steps[0]:
                factor = 1.0
            elif i < configs.steps[1]:
                factor = 0.1
            else:
                factor = 0.01
            return factor

        main_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, burnin_schedule)

    elif is_cosine:
        cosine_warmup = warmup_epochs if use_warmup else 0
        # max(1, ...) avoids dividing by zero when warmup covers every epoch
        decay_epochs = max(1, configs.num_epochs - cosine_warmup)

        def cosine_lambda(epoch):
            if epoch < cosine_warmup:
                # Linear warmup computed directly; asking a separate, never-stepped warmup
                # scheduler would always return its epoch-0 factor (0) and freeze the LR.
                return float(epoch) / float(cosine_warmup)
            # Cosine decay from https://arxiv.org/pdf/1812.01187.pdf, floored at 10% of base LR
            progress = epoch - cosine_warmup
            return (1 + math.cos(progress * math.pi / decay_epochs)) / 2 * 0.9 + 0.1

        main_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=cosine_lambda)

    elif scheduler_name == 'step_lr':
        main_scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=configs.lr_step_size, gamma=configs.lr_gamma)

    elif scheduler_name == 'multistep_lr':
        main_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=configs.milestones, gamma=configs.lr_gamma)

    elif scheduler_name == 'exponential_lr':
        main_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=configs.e_gamma)

    elif scheduler_name == 'cosine_annealing_lr':
        main_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=configs.t_max, eta_min=configs.min_lr)

    elif is_plateau:
        main_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=configs.lr_gamma, patience=configs.patience, min_lr=configs.min_lr)

    elif scheduler_name == 'cyclic_lr':
        main_scheduler = optim.lr_scheduler.CyclicLR(
            optimizer, base_lr=configs.base_lr, max_lr=configs.max_lr,
            step_size_up=configs.step_size_up, mode=configs.cyclic_mode)
    else:
        print(
            f"Unknown scheduler: {scheduler_name}, using CosineAnnealingLR instead")
        main_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=configs.num_epochs, eta_min=getattr(configs, 'min_lr', 1e-6))

    # Chain the separate warmup in front of the main scheduler when one was built
    if warmup_scheduler is not None:
        return optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[warmup_epochs]
        )

    return main_scheduler


# Custom loss functions for image reconstruction
class PerceptualLoss(nn.Module):
    """
    Perceptual loss: weighted MSE between frozen VGG16 feature maps of prediction and target.
    """

    def __init__(self, layers=(4, 9), weights=(1.0, 1.0)):
        """
        Args:
            layers (sequence of int): Exclusive end indices into ``vgg16().features``. The
                slices are consecutive ([0:4], [4:9] by default) and applied in sequence.
            weights (sequence of float): Weight of each slice's MSE; needs at least one
                entry per layer.

        Raises:
            ValueError: If fewer weights than layers are given.
        """
        super().__init__()
        # Tuple defaults + list copies avoid sharing a mutable default between instances.
        layers = list(layers)
        weights = list(weights)
        if len(weights) < len(layers):
            raise ValueError(
                f"PerceptualLoss needs one weight per layer, got {len(weights)} weights for {len(layers)} layers")

        # torchvision >= 0.13 replaced `pretrained=True` with weight enums; IMAGENET1K_V1
        # is the checkpoint `pretrained=True` maps to (per torchvision docs).
        vgg_weights = getattr(models, 'VGG16_Weights', None)
        if vgg_weights is not None:
            vgg = models.vgg16(weights=vgg_weights.IMAGENET1K_V1).features
        else:
            vgg = models.vgg16(pretrained=True).features

        # Create consecutive slices for feature extraction
        vgg_layers = list(vgg.children())
        self.slices = nn.ModuleList()
        start_idx = 0
        for end_idx in layers:
            self.slices.append(nn.Sequential(*vgg_layers[start_idx:end_idx]))
            start_idx = end_idx

        self.weights = weights

        # Freeze parameters: VGG only serves as a fixed feature extractor
        for param in self.parameters():
            param.requires_grad = False

        # Per-channel mean and std (the standard ImageNet statistics) for input normalization
        self.register_buffer('mean', torch.tensor(
            [0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor(
            [0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _normalize(self, img):
        """Normalize image for VGG (assumes 3-channel input scaled to [0, 1] — TODO confirm)."""
        return (img - self.mean) / self.std

    def forward(self, input, target):
        """
        Args:
            input (torch.Tensor): Predicted image
            target (torch.Tensor): Target image

        Returns:
            torch.Tensor: Sum over slices of weight * MSE(features(input), features(target))
        """
        # Make sure input and target are on the same device as the model
        device = next(self.parameters()).device
        input = input.to(device)
        target = target.to(device)

        input_features = self._normalize(input)
        target_features = self._normalize(target)

        # Each slice continues from the previous slice's output
        loss = 0.0
        for i, feature_slice in enumerate(self.slices):
            input_features = feature_slice(input_features)
            target_features = feature_slice(target_features)
            loss += self.weights[i] * \
                F.mse_loss(input_features, target_features)

        return loss


class SSIMLoss(nn.Module):
    """
    Structural Similarity Index (SSIM) loss, returned as ``1 - SSIM``.

    Local statistics use a normalized 2-D Gaussian window applied per channel (depthwise
    convolution). The stability constants C1 = 0.01**2 and C2 = 0.03**2 assume images
    scaled to [0, 1].
    """

    def __init__(self, window_size=11, size_average=True, sigma=1.5):
        """
        Args:
            window_size (int): Size of the Gaussian window
            size_average (bool): True -> scalar mean over everything; False -> one value per sample
            sigma (float): Standard deviation of the Gaussian window
        """
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.sigma = sigma
        self.channel = 1
        self.window = self._create_window(window_size, sigma)

    def _create_window(self, window_size, sigma=1.5):
        """Return a (1, 1, window_size, window_size) Gaussian kernel that sums to 1."""
        coords = torch.arange(window_size, dtype=torch.float32) - (window_size - 1) / 2.0
        gauss_1d = torch.exp(-(coords ** 2) / (2.0 * sigma ** 2))
        gauss_1d = gauss_1d / gauss_1d.sum()
        # Outer product of the normalized 1-D kernel is separable and also sums to 1
        window_2d = gauss_1d.unsqueeze(1) @ gauss_1d.unsqueeze(0)
        return window_2d.unsqueeze(0).unsqueeze(0)

    def forward(self, img1, img2):
        """
        Args:
            img1 (torch.Tensor): First image, (B, C, H, W)
            img2 (torch.Tensor): Second image, same shape as ``img1``

        Returns:
            torch.Tensor: 1 - SSIM (as a loss, lower is better)
        """
        channels = img1.size(1)
        pad = self.window_size // 2
        # Match the input's device and dtype so non-float32 inputs work too
        window = self.window.to(device=img1.device, dtype=img1.dtype)
        window = window.expand(channels, 1, self.window_size, self.window_size)

        # Calculate means
        mu1 = F.conv2d(img1, window, padding=pad, groups=channels)
        mu2 = F.conv2d(img2, window, padding=pad, groups=channels)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Calculate variances and covariance
        sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channels) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channels) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channels) - mu1_mu2

        # Constants for stability
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        # Calculate SSIM
        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
            ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        # Return 1 - SSIM to convert to a loss (lower is better)
        if self.size_average:
            return 1 - ssim_map.mean()
        else:
            return 1 - ssim_map.mean(1).mean(1).mean(1)


class CombinedLoss(nn.Module):
    """
    Reconstruction loss mixing pixel-wise and perceptual terms:

        loss = alpha * L1(pred, target) + beta * MSE(pred, target) + gamma * Perceptual(pred, target)
    """

    def __init__(self, alpha=0.8, beta=0.2, gamma=0.1):
        """
        Args:
            alpha (float): Weight applied to the L1 term.
            beta (float): Weight applied to the MSE term.
            gamma (float): Weight applied to the perceptual term.
        """
        super().__init__()
        # Term weights
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # Individual criteria
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        self.perceptual_loss = PerceptualLoss()

    def forward(self, pred, target):
        """
        Args:
            pred (torch.Tensor): Reconstructed image batch.
            target (torch.Tensor): Ground-truth image batch.

        Returns:
            torch.Tensor: The weighted sum of the three loss terms.
        """
        l1_term = self.alpha * self.l1_loss(pred, target)
        mse_term = self.beta * self.mse_loss(pred, target)
        perceptual_term = self.gamma * self.perceptual_loss(pred, target)
        return l1_term + mse_term + perceptual_term





def visualize_reconstructions(model, dataloader, device, epoch, save_dir, max_images=8):
    """
    Visualize and save model reconstructions during training.

    Takes the first batch of ``dataloader``, runs the model on it without
    gradients, and writes two PNGs to ``save_dir`` (created if missing):
    ``reconstruction_epoch_{epoch}.png`` (raw torchvision grid) and
    ``reconstruction_plot_epoch_{epoch}.png`` (matplotlib render).
    Row 1 of the grid holds the inputs, row 2 the matching reconstructions.

    Args:
        model (nn.Module): Model mapping an image batch to reconstructions.
        dataloader (DataLoader): Yields (images, _) batches.
        device (torch.device | str): Device the model runs on.
        epoch (int): Used in the output file names.
        save_dir (str): Output directory.
        max_images (int): Maximum number of samples shown (default 8).

    Returns:
        torch.Tensor: Reconstructions for the whole batch, on ``device``.
    """
    # Remember the caller's mode so it is restored even if plotting fails
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            images, _ = next(iter(dataloader))
            images = images.to(device)

            reconstructions = model(images).to(device)

            # One row per kind: nrow must equal the number of samples shown,
            # otherwise small batches wrap reconstructions into the input row.
            n_show = min(max_images, images.size(0))
            comparison = torch.cat([images[:n_show], reconstructions[:n_show]])
            grid = torchvision.utils.make_grid(comparison, nrow=n_show, normalize=True)

            os.makedirs(save_dir, exist_ok=True)
            torchvision.utils.save_image(
                grid, f"{save_dir}/reconstruction_epoch_{epoch}.png")

            # Also save as matplotlib figure for better visualization
            fig = plt.figure(figsize=(20, 10))
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
            plt.title(f"Reconstructions at Epoch {epoch}")
            plt.axis('off')
            plt.savefig(f"{save_dir}/reconstruction_plot_epoch_{epoch}.png")
            plt.close(fig)
    finally:
        model.train(was_training)

    return reconstructions




def denormalize(img, mean, std):
    """
    Undo per-channel normalization for visualization.

    Computes ``img[c] * std[c] + mean[c]`` for every channel ``c`` that has a
    (mean, std) pair, then clamps to [0, 1]. The input tensor is not modified.

    Args:
        img (torch.Tensor): Image tensor of shape (C, H, W).
        mean (sequence[float]): Per-channel means used for normalization.
        std (sequence[float]): Per-channel standard deviations used for normalization.

    Returns:
        torch.Tensor: De-normalized copy of ``img`` clamped to [0, 1].

    Raises:
        IndexError: If ``img`` has fewer channels than ``mean``/``std`` entries.
    """
    img = img.clone()
    # Channel count follows the statistics (3 for ImageNet) instead of being hard-coded
    for channel, (channel_mean, channel_std) in enumerate(zip(mean, std)):
        img[channel] = img[channel] * channel_std + channel_mean
    return torch.clamp(img, 0, 1)  # Clamp to [0, 1] for safe visualization

def visualize_segmentation(model, dataloader, device, epoch, save_dir):
    """
    Plot input / ground-truth mask / predicted mask for the first batch of ``dataloader``.

    The model output's first three channels are combined into a grayscale map
    (0.299 R + 0.587 G + 0.114 B), passed through a sigmoid, and thresholded at 0.5.
    Inputs are de-normalized with ImageNet mean/std. At most 8 samples are shown,
    and the figure is saved as ``segmentation_epoch_{epoch}.png`` in ``save_dir``.

    Args:
        model (nn.Module): Segmentation model; output must have >= 3 channels.
        dataloader (DataLoader): Yields (images, masks); masks may be (B, H, W) or (B, 1, H, W).
        device (torch.device | str): Device the model runs on.
        epoch (int): Used in the output file name.
        save_dir (str): Output directory (created if missing).
    """
    # Remember the caller's mode so it is restored even if plotting fails
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            images, masks = next(iter(dataloader))
            images = images.to(device)
            masks = masks.to(device)

            # Fix masks shape
            if masks.ndim == 3:
                masks = masks.unsqueeze(1)

            predictions = model(images)
            grayscale_outputs = 0.299 * predictions[:, 0, :, :] + 0.587 * predictions[:, 1, :, :] + 0.114 * predictions[:, 2, :, :]
            grayscale_outputs = grayscale_outputs.unsqueeze(1)
            predictions = torch.sigmoid(grayscale_outputs)
            predictions = (predictions > 0.5).float()

            # Move tensors to CPU
            images = images.cpu()
            masks = masks.cpu()
            predictions = predictions.cpu()

            # Denormalization parameters (standard ImageNet mean/std)
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]

            num_samples = min(8, images.size(0))
            # squeeze=False keeps `axes` 2-D even when num_samples == 1
            fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

            for i in range(num_samples):
                # De-normalize input image
                img = denormalize(images[i], mean, std)

                # Input image
                axes[i, 0].imshow(img.permute(1, 2, 0).numpy())
                axes[i, 0].set_title("Input Image")
                axes[i, 0].axis("off")

                # Ground truth mask
                axes[i, 1].imshow(masks[i, 0].numpy(), cmap="gray")
                axes[i, 1].set_title("Ground Truth Mask")
                axes[i, 1].axis("off")

                # Predicted mask
                axes[i, 2].imshow(predictions[i, 0].numpy(), cmap="gray")
                axes[i, 2].set_title("Predicted Mask")
                axes[i, 2].axis("off")

            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f"segmentation_epoch_{epoch}.png")
            plt.tight_layout()
            plt.savefig(save_path)
            plt.close(fig)

            print(f"Segmentation visualization saved at {save_path}")
    finally:
        model.train(was_training)

def visualize_reconstructions_2(reconstructions, images, save_dir, batch, max_images=8):
    """
    Visualize and save model reconstructions during testing.

    Writes two PNGs to ``save_dir`` (created if missing):
      * ``reconstruction_results_{batch}.png``      - raw torchvision grid
      * ``reconstruction_results_plot_{batch}.png`` - the same grid rendered with matplotlib
    Row 1 of the grid holds the inputs, row 2 the matching reconstructions.

    Args:
        reconstructions (torch.Tensor): Model outputs, same layout as ``images``.
        images (torch.Tensor): Input image batch.
        save_dir (str): Output directory.
        batch (int | str): Identifier used in the file names.
        max_images (int): Maximum number of samples shown (default 8).
    """
    # nrow must equal the number of samples shown so inputs and reconstructions
    # each get their own row even for batches smaller than max_images.
    n_show = min(max_images, images.size(0))
    comparison = torch.cat([images[:n_show].detach(), reconstructions[:n_show].detach()])
    grid = torchvision.utils.make_grid(comparison, nrow=n_show, normalize=True)

    os.makedirs(save_dir, exist_ok=True)
    torchvision.utils.save_image(
        grid, f"{save_dir}/reconstruction_results_{batch}.png")

    # Also save as matplotlib figure; distinct file name so it does not overwrite the grid above
    fig = plt.figure(figsize=(20, 10))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.title(f"Reconstructions of test batch {batch}")
    plt.axis('off')
    plt.savefig(f"{save_dir}/reconstruction_results_plot_{batch}.png")
    plt.close(fig)


def visualize_segmentation_2(predictions, images, masks, batch, save_dir):
    """
    Visualize and save segmentation results during testing (best effort).

    Any failure is printed and the function returns None instead of raising,
    so a plotting problem never aborts an evaluation loop.

    The first three channels of ``predictions`` are combined into a grayscale map
    (0.299 R + 0.587 G + 0.114 B), passed through a sigmoid, and thresholded at 0.5.
    Inputs are de-normalized with ImageNet mean/std. At most 16 samples are shown,
    and the figure is saved as ``segmentation_result_batch_{batch}.png``.

    Args:
        predictions (torch.Tensor): Raw model outputs with >= 3 channels.
        images (torch.Tensor): Normalized input batch with 3 channels.
        masks (torch.Tensor): Ground-truth masks, (B, H, W) or (B, 1, H, W).
        batch (int | str): Identifier used in the output file name.
        save_dir (str): Output directory (created if missing).
    """
    try:
        # Fix masks shape
        if masks.ndim == 3:
            masks = masks.unsqueeze(1)

        grayscale_outputs = 0.299 * predictions[:, 0, :, :] + 0.587 * predictions[:, 1, :, :] + 0.114 * predictions[:, 2, :, :]
        grayscale_outputs = grayscale_outputs.unsqueeze(1)
        predictions = torch.sigmoid(grayscale_outputs)
        predictions = (predictions > 0.5).float()

        # Move tensors to CPU
        try:
            images = images.cpu()
            masks = masks.cpu()
            predictions = predictions.cpu()
        except Exception as e:
            print(f"Error moving tensors to CPU: {e}")
            return

        # Denormalization parameters (standard ImageNet mean/std)
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        num_samples = min(16, images.size(0))
        try:
            # squeeze=False keeps `axes` 2-D even when num_samples == 1
            fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
        except Exception as e:
            print(f"Error creating subplot figure: {e}")
            return

        try:
            for i in range(num_samples):
                # De-normalize input image
                img = denormalize(images[i], mean, std)

                # Input image
                axes[i, 0].imshow(img.permute(1, 2, 0).numpy())
                axes[i, 0].set_title("Input Image")
                axes[i, 0].axis("off")

                # Ground truth mask
                axes[i, 1].imshow(masks[i, 0].numpy(), cmap="gray")
                axes[i, 1].set_title("Ground Truth Mask")
                axes[i, 1].axis("off")

                # Predicted mask
                axes[i, 2].imshow(predictions[i, 0].numpy(), cmap="gray")
                axes[i, 2].set_title("Predicted Mask")
                axes[i, 2].axis("off")

        except Exception as e:
            print(f"Error plotting images: {e}")
            plt.close(fig)
            return

        try:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f"segmentation_result_batch_{batch}.png")
            plt.tight_layout()
            plt.savefig(save_path)
            plt.close(fig)
        except Exception as e:
            print(f"Error saving visualization: {e}")
            plt.close(fig)
            return

    except Exception as e:
        print(f"Unexpected error in visualize_segmentation_2: {e}")
        # Ensure figure is closed in case of error; never let cleanup itself raise,
        # but do not swallow KeyboardInterrupt/SystemExit.
        try:
            plt.close()
        except Exception:
            pass
        return







## DATALOADER AND DATASET

In [None]:
class ImageMaskDataset(Dataset):
    """
    Paired image / mask dataset spread over several sub-datasets.

    Expected layout::

        root_dir/
            <dataset_a>/images/*.png|jpg|jpeg|bmp
            <dataset_a>/masks/*.png|jpg|jpeg|bmp
            <dataset_b>/...

    Images and masks are paired by their position in each folder's sorted file
    list, not by name. Pairs whose files fail ``PIL.Image.verify`` are dropped at
    construction time.
    """

    # File extensions (lower-case) treated as images or masks
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_dir, image_size=(256, 256)):
        """
        Dataset class for loading images and corresponding masks.
        Args:
        - root_dir (str): Root directory containing images and masks
        - image_size (tuple[int, int]): Tuple specifying image resize dimensions
        """
        assert os.path.exists(root_dir), f"Error: Directory '{root_dir}' does not exist."

        # (height, width) of the fallback tensors returned for unreadable pairs.
        # NOTE(review): an int size makes Resize keep the aspect ratio, so the fallback
        # shape can only be guessed as (n, n) in that case.
        if isinstance(image_size, int):
            self.image_size = (image_size, image_size)
        else:
            self.image_size = tuple(image_size)

        self.image_paths, self.mask_paths = self._load_paths(root_dir)
        self.valid_pairs = self._validate_pairs()
        print(f"Found {len(self.valid_pairs)} valid image-mask pairs")

        self.image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.mask_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.Lambda(mask_transform_fn)
        ])

    def _load_paths(self, root_dir):
        """
        Load paths for images and their corresponding masks from multiple dataset directories.

        Sub-directories without both an ``images`` and a ``masks`` folder, or whose
        image and mask counts differ, are skipped with a warning.

        Args:
            root_dir (str): Root directory containing multiple dataset folders

        Returns:
            tuple[list, list]: Lists of image and mask file paths (index-aligned)

        Raises:
            ValueError: If no usable pairs are found.
        """
        image_paths = []
        mask_paths = []

        # Get all dataset directories
        dataset_dirs = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]

        for dataset_dir in dataset_dirs:
            dataset_path = os.path.join(root_dir, dataset_dir)
            images_dir = os.path.join(dataset_path, 'images')
            masks_dir = os.path.join(dataset_path, 'masks')

            # Skip if images or masks directory doesn't exist
            if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
                print(f"Warning: Skipping {dataset_dir} - missing images or masks directory")
                continue

            # Sorted so that the i-th image is paired with the i-th mask
            image_files = sorted(f for f in os.listdir(images_dir)
                                 if f.lower().endswith(self._IMAGE_EXTENSIONS))
            mask_files = sorted(f for f in os.listdir(masks_dir)
                                if f.lower().endswith(self._IMAGE_EXTENSIONS))

            # Verify matching numbers of images and masks
            if len(image_files) != len(mask_files):
                print(f"Warning: Skipping {dataset_dir} - number of images ({len(image_files)}) "
                      f"does not match number of masks ({len(mask_files)})")
                continue

            # Add full paths to lists
            image_paths.extend(os.path.join(images_dir, f) for f in image_files)
            mask_paths.extend(os.path.join(masks_dir, f) for f in mask_files)

            print(f"Added {len(image_files)} pairs from {dataset_dir}")

        if not image_paths:
            raise ValueError(f"No valid image-mask pairs found in {root_dir}")

        print(f"Found total of {len(image_paths)} image-mask pairs across all datasets")
        return image_paths, mask_paths

    def _validate_image(self, image_path):
        """Return True if PIL can open and verify ``image_path``, otherwise print a warning and return False."""
        try:
            with Image.open(image_path) as img:
                img.verify()
            return True
        except Exception:
            # Best effort: any read/decode failure marks the file as unusable
            print(f"Warning: Corrupted or unreadable image file: {image_path}")
            return False

    def _validate_pairs(self):
        """Validate all image-mask pairs and return only valid ones"""
        valid_pairs = []
        for img_path, mask_path in tqdm(zip(self.image_paths, self.mask_paths),
                                        desc="Validating image-mask pairs",
                                        total=len(self.image_paths)):
            if self._validate_image(img_path) and self._validate_image(mask_path):
                valid_pairs.append((img_path, mask_path))
            else:
                print(f"Skipping corrupted pair:\nImage: {img_path}\nMask: {mask_path}\n")
        return valid_pairs

    def __getitem__(self, idx):
        """
        Return a single image-mask pair.
        Args:
        - idx (int): Index of the desired image-mask pair
        Returns:
        - tuple[torch.Tensor, torch.Tensor]: Transformed image and mask. If loading
          fails, zero tensors of shape (3, H, W) and (H, W) sized from ``image_size``.
        Raises:
        - IndexError: If ``idx`` is out of range.
        """
        # Looked up outside the try so an invalid index raises IndexError
        # (needed by the sequence protocol) instead of being masked by the fallback.
        image_path, mask_path = self.valid_pairs[idx]

        try:
            # Load and convert image
            with Image.open(image_path) as image:
                image_tensor = self.image_transform(image.convert("RGB"))

            # Load and convert mask
            with Image.open(mask_path) as mask:
                mask_tensor = self.mask_transform(mask.convert("L"))

            return image_tensor, mask_tensor

        except Exception as e:
            print(f"Error loading pair {idx}:\nImage: {image_path}\nMask: {mask_path}\nError: {str(e)}")
            # Fallback sized like a normal sample so default collation still works
            height, width = getattr(self, 'image_size', (256, 256))
            return torch.zeros(3, height, width), torch.zeros(height, width)

    def __len__(self):
        return len(self.valid_pairs)
    

import cv2
import numpy as np


class DatasetLoader:
    """
    Build an ImageMaskDataset (optionally from an offline-augmented copy of the
    data) and split it into train / validation / test DataLoaders.
    """

    def __init__(self, image_size, root_dir, split_ratios=(0.7, 0.15, 0.15), seed=42,
                 Shuffle=True, batch_size=16, num_workers=2, task='pretext'):
        """
        Args:
            image_size (int): Side length; samples are resized to (image_size, image_size).
            root_dir (str): Directory holding one sub-folder per dataset, each with
                'images' and 'masks' sub-directories.
            split_ratios (sequence[float]): Train / val / test fractions. Train and val
                sizes are truncated to int; test receives the remainder.
            seed (int): Seed for the generator passed to ``random_split``.
            Shuffle (bool): If True split randomly, otherwise by contiguous index ranges.
                The training DataLoader always shuffles its batches.
            batch_size (int): Batch size for all three loaders.
            num_workers (int): DataLoader worker processes.
            task (str): 'pretext' first writes an augmented copy of the data to
                ``<root_dir>/augmented`` and loads that; any other value loads
                ``root_dir`` directly.
        """
        self.image_size = image_size
        self.task = task

        if task == 'pretext':
            # Perform augmentation before creating the dataset
            self.augment_dataset(root_dir)
            # Use augmented directory for dataset
            self.dataset = ImageMaskDataset(
                os.path.join(root_dir, 'augmented'), (image_size, image_size))
        else:
            self.dataset = ImageMaskDataset(root_dir, (image_size, image_size))

        self.split_ratios = split_ratios
        self.seed = seed
        self.shuffle_data = Shuffle
        self.batch_size = batch_size
        self.num_workers = num_workers

    def augment_dataset(self, root_dir):
        """
        Write an augmented copy of every dataset under ``root_dir`` to ``root_dir/augmented``.

        For each image with a same-named mask, the original pair is copied unchanged.
        Rotated, scaled, shifted and flipped variants are then written, resized to
        768x576 (width x height). Variant images are saved as ``.jpg`` and variant
        masks as ``.png``. Unreadable files are skipped with a warning.

        Args:
            root_dir (str): Directory containing dataset sub-folders.
        """
        augmented_dir = os.path.join(root_dir, 'augmented')
        os.makedirs(augmented_dir, exist_ok=True)

        # cv2.resize expects dsize as (width, height)
        target_resolution = (768, 576)

        for dataset_dir in os.listdir(root_dir):
            dataset_path = os.path.join(root_dir, dataset_dir)
            if not os.path.isdir(dataset_path) or dataset_dir == 'augmented':
                continue

            images_dir = os.path.join(dataset_path, 'images')
            masks_dir = os.path.join(dataset_path, 'masks')
            if not (os.path.exists(images_dir) and os.path.exists(masks_dir)):
                continue

            # Create output directories for augmented data
            aug_images_dir = os.path.join(augmented_dir, dataset_dir, 'images')
            aug_masks_dir = os.path.join(augmented_dir, dataset_dir, 'masks')
            os.makedirs(aug_images_dir, exist_ok=True)
            os.makedirs(aug_masks_dir, exist_ok=True)

            for img_name in os.listdir(images_dir):
                if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue

                img_path = os.path.join(images_dir, img_name)
                mask_path = os.path.join(masks_dir, img_name)  # mask must share the file name
                if not os.path.exists(mask_path):
                    continue

                image = cv2.imread(img_path)
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                # cv2.imread returns None instead of raising on unreadable files
                if image is None or mask is None:
                    print(f"Warning: Skipping unreadable image/mask pair: {img_path}")
                    continue

                # Save original
                cv2.imwrite(os.path.join(aug_images_dir, img_name), image)
                cv2.imwrite(os.path.join(aug_masks_dir, img_name), mask)

                stem = os.path.splitext(img_name)[0]
                for suffix, aug_image, aug_mask in self._iter_augmentations(
                        image, mask, target_resolution):
                    cv2.imwrite(os.path.join(aug_images_dir, f"{stem}_{suffix}.jpg"), aug_image)
                    cv2.imwrite(os.path.join(aug_masks_dir, f"{stem}_{suffix}.png"), aug_mask)

    def _iter_augmentations(self, image, mask, target_resolution):
        """Yield (file-name suffix, image, mask) for every rotation, scale, shift and flip variant."""
        for angle in (45, 90, 270):
            aug_image, aug_mask = self.rotate_image_and_mask(image, mask, angle, target_resolution)
            yield f"rot{angle}", aug_image, aug_mask

        for scale in (0.7, 1.4):
            aug_image, aug_mask = self.scale_image_and_mask(image, mask, scale, target_resolution)
            yield f"scale{scale}", aug_image, aug_mask

        for shift_x, shift_y in ((25, 0), (-25, 0), (0, 25), (0, -25)):
            aug_image, aug_mask = self.shift_image_and_mask(
                image, mask, shift_x, shift_y, target_resolution)
            yield f"shift{shift_x}_{shift_y}", aug_image, aug_mask

        # cv2.flip codes: 0 = vertical, 1 = horizontal, -1 = both
        for flip_code in (0, 1, -1):
            aug_image, aug_mask = self.flip_image_and_mask(image, mask, flip_code, target_resolution)
            yield f"flip{flip_code}", aug_image, aug_mask

    @staticmethod
    def rotate_image_and_mask(image, mask, angle, target_resolution):
        """
        Rotate an image/mask pair, then resize both to ``target_resolution`` (width, height).

        90/180/270 use lossless ``cv2.rotate``. 45 pads with a replicated border,
        rotates about the centre (image: bilinear, white fill; mask: nearest, 0 fill),
        and crops back to the original size. Any other angle returns an unrotated copy.
        """
        if angle == 90:
            rotated_image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
            rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        elif angle == 180:
            rotated_image = cv2.rotate(image, cv2.ROTATE_180)
            rotated_mask = cv2.rotate(mask, cv2.ROTATE_180)
        elif angle == 270:
            rotated_image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
            rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif angle == 45:
            (h, w) = image.shape[:2]
            diag_length = int(np.sqrt(h**2 + w**2))
            padding = (diag_length - h) // 2

            padded_image = cv2.copyMakeBorder(image, padding, padding, padding, padding,
                                              borderType=cv2.BORDER_REPLICATE)
            padded_mask = cv2.copyMakeBorder(mask, padding, padding, padding, padding,
                                             borderType=cv2.BORDER_REPLICATE)

            (ph, pw) = padded_image.shape[:2]
            center = (pw // 2, ph // 2)
            M = cv2.getRotationMatrix2D(center, angle, 1.0)
            rotated_image = cv2.warpAffine(padded_image, M, (pw, ph),
                flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255))
            rotated_mask = cv2.warpAffine(padded_mask, M, (pw, ph),
                flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0)

            # Crop the centre back to the original size
            start_x, start_y = (pw - w) // 2, (ph - h) // 2
            rotated_image = rotated_image[start_y:start_y + h, start_x:start_x + w]
            rotated_mask = rotated_mask[start_y:start_y + h, start_x:start_x + w]
        else:
            rotated_image = image.copy()
            rotated_mask = mask.copy()

        resized_image = cv2.resize(rotated_image, target_resolution, interpolation=cv2.INTER_LINEAR)
        resized_mask = cv2.resize(rotated_mask, target_resolution, interpolation=cv2.INTER_NEAREST)

        return resized_image, resized_mask

    @staticmethod
    def scale_image_and_mask(image, mask, scale_factor, target_resolution):
        """Rescale by ``scale_factor`` (mask: nearest neighbour), then resize both to ``target_resolution``."""
        scaled_image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor,
                                  interpolation=cv2.INTER_LINEAR)
        scaled_mask = cv2.resize(mask, None, fx=scale_factor, fy=scale_factor,
                                 interpolation=cv2.INTER_NEAREST)

        resized_image = cv2.resize(scaled_image, target_resolution, interpolation=cv2.INTER_LINEAR)
        resized_mask = cv2.resize(scaled_mask, target_resolution, interpolation=cv2.INTER_NEAREST)

        return resized_image, resized_mask

    @staticmethod
    def shift_image_and_mask(image, mask, shift_x, shift_y, target_resolution):
        """Translate by (shift_x, shift_y) pixels (image: white fill, mask: 0 fill), then resize both."""
        M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        shifted_image = cv2.warpAffine(image, M, (image.shape[1], image.shape[0]),
            borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255))
        shifted_mask = cv2.warpAffine(mask, M, (mask.shape[1], mask.shape[0]),
            borderMode=cv2.BORDER_CONSTANT, borderValue=0)

        resized_image = cv2.resize(shifted_image, target_resolution, interpolation=cv2.INTER_LINEAR)
        resized_mask = cv2.resize(shifted_mask, target_resolution, interpolation=cv2.INTER_NEAREST)

        return resized_image, resized_mask

    @staticmethod
    def flip_image_and_mask(image, mask, flip_code, target_resolution):
        """Flip with ``cv2.flip(flip_code)``, then resize both to ``target_resolution``."""
        flipped_image = cv2.flip(image, flip_code)
        flipped_mask = cv2.flip(mask, flip_code)

        resized_image = cv2.resize(flipped_image, target_resolution, interpolation=cv2.INTER_LINEAR)
        resized_mask = cv2.resize(flipped_mask, target_resolution, interpolation=cv2.INTER_NEAREST)

        return resized_image, resized_mask

    def get_dataloaders(self):
        """
        Split the dataset and return DataLoaders for train, val and test.

        Returns:
            tuple[DataLoader, DataLoader, DataLoader]: (train, val, test). Only the
            training loader shuffles its batches.
        """
        dataset_size = len(self.dataset)
        train_size = int(dataset_size * self.split_ratios[0])
        val_size = int(dataset_size * self.split_ratios[1])
        test_size = dataset_size - train_size - val_size  # remainder goes to test

        generator = torch.Generator().manual_seed(self.seed)

        if self.shuffle_data:
            train_dataset, val_dataset, test_dataset = random_split(
                self.dataset, [train_size, val_size, test_size], generator=generator)
        else:
            train_dataset = torch.utils.data.Subset(
                self.dataset, range(0, train_size))
            val_dataset = torch.utils.data.Subset(
                self.dataset, range(train_size, train_size + val_size))
            test_dataset = torch.utils.data.Subset(
                self.dataset, range(train_size + val_size, dataset_size))

        print("Creating DataLoaders...")
        train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        val_loader = DataLoader(
            val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        test_loader = DataLoader(
            test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

        return train_loader, val_loader, test_loader

## MODEL CODE

### PATCH EMBEDDING

In [None]:
class PatchEmbedding(nn.Module):
    """
    Turn an image into a sequence of patch tokens.

    A strided convolution (kernel = stride = patch_size) projects each
    non-overlapping patch to ``embed_dim``, and a learnable positional
    embedding is added to every token.
    """

    def __init__(self, image_size=256, patch_size=16, in_channels=3, embed_dim=1024):
        """
        Args:
            image_size (int): Side length of the (square) input image. Default 256.
            patch_size (int): Side length of each square patch. Default 16.
            in_channels (int): Channels of the input image (3 for RGB). Default 3.
            embed_dim (int): Size of each patch embedding. Default 1024.
        """
        super(PatchEmbedding, self).__init__()

        patches_per_side = image_size // patch_size

        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.num_patches_h = patches_per_side
        self.num_patches_w = patches_per_side
        self.num_patches = patches_per_side * patches_per_side

        # One conv step per patch == a linear projection of each flattened patch
        self.proj = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                              kernel_size=patch_size, stride=patch_size)

        # Learnable positional embedding, one vector per patch
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Image batch of shape (B, C, H, W).

        Returns:
            torch.Tensor: Tokens of shape (B, num_patches, embed_dim), patches in
            row-major order, with positional embeddings added.
        """
        feature_map = self.proj(x)                          # (B, E, H', W')
        tokens = feature_map.flatten(2).transpose(1, 2)     # (B, H'*W', E)
        return tokens + self.pos_embed


## CUSTOM TRANSFORMER BLOCK

In [None]:
class CustomTransformerBlock(nn.Module):
    """
    Pre-norm transformer block: self-attention, optional cross-attention, and a
    feed-forward network, each wrapped in a residual connection.

        x = x + Dropout(SelfAttn(LN1(x)))
        x = x + Dropout(CrossAttn(LN2(x), memory))     # decoder only
        x = x + FFN(LN2(x))  (encoder)  /  FFN(LN3(x)) (decoder)

    Args:
        embed_dim (int): The embedding dimension of the input.
        num_heads (int): The number of attention heads.
        mlp_ratio (float): Hidden size of the feed-forward network relative to embed_dim.
        activation (str): Name of the activation used in the feed-forward network.
        dropout (float): Dropout probability.
        is_decoder (bool): Whether this block also cross-attends to encoder memory.
        batch_first (bool): If True (default) inputs are (batch, seq, embed_dim), the
            layout produced by PatchEmbedding. Set False for (seq, batch, embed_dim).
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio, activation, dropout, is_decoder=False,
                 batch_first=True):
        super(CustomTransformerBlock, self).__init__()

        # batch_first must match the token layout; with the PyTorch default (False)
        # a (B, N, E) input would make attention mix different samples of the batch.
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.is_decoder = is_decoder

        if is_decoder:
            self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)
            self.norm3 = nn.LayerNorm(embed_dim)

        hidden_dim = int(embed_dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            get_activation_function(activation),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, memory=None, src_mask=None, memory_mask=None, src_key_padding_mask=None, memory_key_padding_mask=None):
        """
        Args:
            src (torch.Tensor): Input tokens.
            memory (torch.Tensor, optional): Encoder output; required when is_decoder=True.
            src_mask (torch.Tensor, optional): Attention mask for self-attention.
            memory_mask (torch.Tensor, optional): Attention mask for cross-attention.
            src_key_padding_mask (torch.Tensor, optional): Key padding mask for ``src``.
            memory_key_padding_mask (torch.Tensor, optional): Key padding mask for ``memory``.

        Returns:
            torch.Tensor: Output with the same shape as ``src``.

        Raises:
            ValueError: If the block is a decoder and ``memory`` is None.
        """
        x = src

        # Self-attention sub-layer (residual keeps the un-normalized stream)
        normed = self.norm1(x)
        attn_output, _ = self.self_attn(
            normed, normed, normed, attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask, need_weights=False)
        x = x + self.dropout(attn_output)

        if self.is_decoder:
            if memory is None:
                raise ValueError("memory must be provided when is_decoder=True")
            # Cross-attention sub-layer
            normed = self.norm2(x)
            cross_output, _ = self.cross_attn(
                normed, memory, memory, attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask, need_weights=False)
            x = x + self.dropout(cross_output)
            ffn_input = self.norm3(x)
        else:
            ffn_input = self.norm2(x)

        # Feed-forward sub-layer with residual connection
        return x + self.ffn(ffn_input)



## ENCODER

In [None]:
class Encoder(nn.Module):
    """
    Stack of ``depth`` CustomTransformerBlock layers followed by a final LayerNorm.
    """

    def __init__(self, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, activation='gelu', dropout=0.1):
        """
        Args:
            embed_dim (int): Token embedding size. Default 1024.
            depth (int): Number of transformer blocks. Default 24.
            num_heads (int): Attention heads per block. Default 16.
            mlp_ratio (int | float): Feed-forward hidden size relative to embed_dim. Default 4.
            activation (str): Activation used inside each block's feed-forward network. Default 'gelu'.
            dropout (float): Dropout probability in [0, 1]. Default 0.1.
        """
        super(Encoder, self).__init__()

        blocks = [
            CustomTransformerBlock(embed_dim, num_heads, mlp_ratio, activation, dropout)
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Patch embeddings, (B, num_patches, embed_dim).

        Returns:
            torch.Tensor: Tensor of the same shape after all blocks and the final LayerNorm.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)



## RECONSTRUCTION

In [None]:
"""
# -*- coding: utf-8 -*-
-----------------------------------------------------------------------------------
# Author: Zakria Mehmood
# DoC: 2025.04.15
# email: zakriamehmood2001@gmail.com
-----------------------------------------------------------------------------------
# Description: All types of Reconstruction methods
"""


# linear mapping method

class ReconstructionMethod1(nn.Module):
    """Linear reconstruction head: map every token to a flattened pixel patch and tile the patches.

    Args:
        decoder_dim (int): Width of the incoming decoder tokens.
        patch_size (int): Side length, in pixels, of each square patch.
        image_size (int): Side length of the square output image (divisible by ``patch_size``).
        output_channels (int): Number of channels of the reconstructed image.
    """

    def __init__(self, decoder_dim, patch_size, image_size, output_channels):
        super(ReconstructionMethod1, self).__init__()
        self.decoder_dim = decoder_dim
        self.patch_size = patch_size
        self.image_size = image_size
        self.out_channels = output_channels
        # Patches along one side of the image, and in total.
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size ** 2

        self.linear = nn.Linear(decoder_dim, patch_size * patch_size * output_channels, bias=True)

    def forward(self, x):
        """Decode tokens into an image.

        Args:
            x (torch.Tensor): Tokens of shape (B, num_patches, decoder_dim), ordered row-major
                over the patch grid.

        Returns:
            torch.Tensor: Image of shape (B, output_channels, image_size, image_size).

        Raises:
            AssertionError: If the number of tokens does not match ``num_patches``.
        """
        batch_size, num_patches, _ = x.shape
        # num_patches is always set in __init__, so this is a plain consistency check.
        assert num_patches == self.num_patches, f"Expected {self.num_patches} patches but got {num_patches}"

        p = self.patch_size
        g = self.grid_size
        c = self.out_channels

        x = self.linear(x)  # (B, N, p*p*c), per-token layout (p1, p2, c)
        # Equivalent to einops 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)'.
        x = x.reshape(batch_size, g, g, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4)  # (B, c, h, p1, w, p2)
        return x.reshape(batch_size, c, g * p, g * p)


# Convolutional Reconstruction
class ReconstructionMethod2(nn.Module):
    """Convolutional reconstruction head.

    Tokens are laid out on their patch grid and refined by two conv-BN-ReLU blocks at grid
    resolution. The features are then bilinearly upsampled by ``patch_size`` so the final
    3x3 conv + Tanh produce a full-resolution image. Previously the output stayed at grid
    resolution (H/p x W/p).

    Args:
        decoder_dim (int): Width of the incoming decoder tokens.
        patch_size (int): Patch side length; also the upsampling factor grid -> image.
        output_channels (int): Number of channels of the reconstructed image.
    """

    def __init__(self, decoder_dim, patch_size, output_channels):
        super().__init__()
        self.patch_size = patch_size

        # Layer layout (and thus state_dict keys) is unchanged.
        self.conv_decoder = nn.Sequential(
            nn.Conv2d(decoder_dim, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, output_channels, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Decode tokens (B, N, decoder_dim), N a perfect square, into (B, C, sqrt(N)*p, sqrt(N)*p).

        Raises:
            ValueError: If N is not a perfect square.
        """
        B, N, C = x.shape
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(f"Expected a square number of patches, got {N}")
        x = x.transpose(1, 2).reshape(B, C, side, side)

        layers = list(self.conv_decoder)
        # Feature refinement at grid resolution (everything but the final conv + Tanh).
        for layer in layers[:-2]:
            x = layer(x)
        # Upsample the 128-channel features to image resolution before the output conv.
        out_side = side * self.patch_size
        x = F.interpolate(x, size=(out_side, out_side), mode='bilinear', align_corners=False)
        for layer in layers[-2:]:
            x = layer(x)
        return x


# Progressive Upsampling Reconstruction
class ReconstructionMethod3(nn.Module):
    """Progressive-upsampling reconstruction head.

    Each token is projected to a ``64 x (patch_size/4) x (patch_size/4)`` feature tile. The
    tiles are laid out on the patch grid into one feature map 4x smaller than the image. Two
    stride-2 transposed convolutions then upsample it to full resolution.

    The previous version unflattened to a 5-D tensor, which ``ConvTranspose2d`` rejects, so
    its forward always raised. It also never assembled the patches into an image.

    Args:
        decoder_dim (int): Width of the incoming decoder tokens.
        patch_size (int): Patch side length; must be divisible by 4 (the head upsamples 4x).
        output_channels (int): Number of channels of the reconstructed image.

    Raises:
        ValueError: If ``patch_size`` is not a multiple of 4.
    """

    hidden_channels = 64

    def __init__(self, decoder_dim, patch_size, output_channels):
        super().__init__()
        if patch_size % 4 != 0:
            raise ValueError(
                f"ReconstructionMethod3 upsamples by 4x, so patch_size must be divisible by 4 (got {patch_size})")
        self.patch_size = patch_size
        self.base_size = patch_size // 4
        c = self.hidden_channels

        self.token_proj = nn.Sequential(
            nn.Linear(decoder_dim, self.base_size * self.base_size * c),
            nn.ReLU(),
        )
        # Each ConvTranspose2d(k=4, s=2, p=1) exactly doubles the spatial size.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(c, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, output_channels, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Decode tokens (B, N, decoder_dim), N a perfect square, into (B, C, sqrt(N)*p, sqrt(N)*p).

        Raises:
            ValueError: If N is not a perfect square.
        """
        B, N, _ = x.shape
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(f"Expected a square number of patches, got {N}")
        s = self.base_size
        c = self.hidden_channels

        feats = self.token_proj(x)  # (B, N, c*s*s), per-token layout (c, s, s)
        # Tile the per-token feature patches row-major onto the grid.
        feats = feats.reshape(B, side, side, c, s, s).permute(0, 3, 1, 4, 2, 5)
        feats = feats.reshape(B, c, side * s, side * s)
        return self.upsample(feats)


# Residual Reconstruction
class ReconstructionMethod4(nn.Module):
    """Residual per-patch reconstruction head (the "segmentation-optimized" option).

    Every token is decoded independently into an ``output_channels x patch_size x patch_size``
    tile by Linear -> Unflatten -> two residual conv blocks -> 1x1 conv -> Tanh. The tiles are
    then placed row-major on the patch grid to form the image. The previous version had no
    ``forward`` at all, so calling it raised NotImplementedError.

    Args:
        decoder_dim (int): Width of the incoming decoder tokens.
        patch_size (int): Patch side length in pixels.
        output_channels (int): Number of channels of the reconstructed image.
    """

    def __init__(self, decoder_dim, patch_size, output_channels):
        super().__init__()
        self.patch_size = patch_size

        class ResBlock(nn.Module):
            """conv3x3-BN-ReLU-conv3x3-BN branch added to an identity skip connection."""

            def __init__(self, channels):
                super().__init__()
                self.conv = nn.Sequential(
                    nn.Conv2d(channels, channels, 3, padding=1),
                    nn.BatchNorm2d(channels),
                    nn.ReLU(),
                    nn.Conv2d(channels, channels, 3, padding=1),
                    nn.BatchNorm2d(channels)
                )

            def forward(self, x):
                return x + self.conv(x)

        # Operates on flattened tokens: (B*N, decoder_dim) -> (B*N, output_channels, p, p).
        # Layer layout (and thus state_dict keys) is unchanged.
        self.decoder = nn.Sequential(
            nn.Linear(decoder_dim, patch_size * patch_size * 64),
            nn.Unflatten(1, (64, patch_size, patch_size)),
            ResBlock(64),
            ResBlock(64),
            nn.Conv2d(64, output_channels, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Decode tokens (B, N, decoder_dim), N a perfect square, into (B, C, sqrt(N)*p, sqrt(N)*p).

        Raises:
            ValueError: If N is not a perfect square.
        """
        B, N, D = x.shape
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(f"Expected a square number of patches, got {N}")
        p = self.patch_size

        # Fold tokens into the batch so the Sequential sees (B*N, D), as Unflatten(1, ...) expects.
        tiles = self.decoder(x.reshape(B * N, D))  # (B*N, C, p, p)
        c = tiles.shape[1]
        image = tiles.reshape(B, side, side, c, p, p).permute(0, 3, 1, 4, 2, 5)
        return image.reshape(B, c, side * p, side * p)

# Attention-guided Reconstruction


class ReconstructionMethod5(nn.Module):
    """Attention-guided reconstruction head.

    A single-head self-attention mixes the tokens. Each mixed token is then projected to an
    ``output_channels x patch_size x patch_size`` tile, and the tiles are placed row-major on
    the patch grid. Previously the raw 5-D tile tensor (B, N, C, p, p) was returned instead
    of an image.

    Args:
        decoder_dim (int): Width of the incoming decoder tokens.
        patch_size (int): Patch side length in pixels.
        output_channels (int): Number of channels of the reconstructed image.
    """

    def __init__(self, decoder_dim, patch_size, output_channels):
        super().__init__()
        self.patch_size = patch_size

        self.query = nn.Linear(decoder_dim, decoder_dim)
        self.key = nn.Linear(decoder_dim, decoder_dim)
        self.value = nn.Linear(decoder_dim, decoder_dim)

        self.final = nn.Sequential(
            nn.Linear(decoder_dim, patch_size * patch_size * output_channels),
            nn.Unflatten(2, (output_channels, patch_size, patch_size))
        )

    def forward(self, x):
        """Decode tokens (B, N, decoder_dim), N a perfect square, into (B, C, sqrt(N)*p, sqrt(N)*p).

        Raises:
            ValueError: If N is not a perfect square.
        """
        B, N, C = x.shape
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(f"Expected a square number of patches, got {N}")

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scaled dot-product attention over the token axis.
        attention = (q @ k.transpose(-2, -1)) / math.sqrt(C)
        attention = F.softmax(attention, dim=-1)
        x = attention @ v

        tiles = self.final(x)  # (B, N, out_channels, p, p)
        c = tiles.shape[2]
        p = self.patch_size
        image = tiles.reshape(B, side, side, c, p, p).permute(0, 3, 1, 4, 2, 5)
        return image.reshape(B, c, side * p, side * p)


# Hybrid Reconstruction
class ReconstructionMethod6(nn.Module):
    """Hybrid reconstruction head: one transformer decoder layer, then a transposed-conv upsampler.

    Args:
        decoder_dim (int): Width of the incoming decoder tokens; must be divisible by 8 (nhead).
        patch_size (int): Patch side length; the output is resized to grid_side * patch_size.
        output_channels (int): Number of channels of the reconstructed image.
    """

    def __init__(self, decoder_dim, patch_size, output_channels):
        super().__init__()
        self.patch_size = patch_size

        # batch_first=True: inputs are (B, N, C). With the default (False) the layer treated the
        # batch axis as the sequence axis and attended across different samples.
        self.transformer_part = nn.TransformerDecoderLayer(
            d_model=decoder_dim,
            nhead=8,
            batch_first=True
        )

        # Two stride-2 transposed convs: 4x spatial upsampling of the token grid.
        self.cnn_part = nn.Sequential(
            nn.ConvTranspose2d(decoder_dim, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, output_channels, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Decode tokens (B, N, decoder_dim), N a perfect square, into (B, C, sqrt(N)*p, sqrt(N)*p).

        Raises:
            ValueError: If N is not a perfect square.
        """
        B, N, C = x.shape
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(f"Expected a square number of patches, got {N}")

        x = self.transformer_part(x, x)  # self-conditioned: tokens are also the memory
        x = x.transpose(1, 2).reshape(B, C, side, side)
        out = self.cnn_part(x)  # (B, out_channels, 4*side, 4*side)

        # The CNN upsamples by a fixed 4x; resize when patch_size != 4 so the output is image-sized.
        target = side * self.patch_size
        if out.shape[-1] != target or out.shape[-2] != target:
            out = F.interpolate(out, size=(target, target), mode='bilinear', align_corners=False)
        return out


## DECODER

In [None]:

class Decoder(nn.Module):
    """Transformer decoder that turns encoder tokens into an output via a pluggable reconstruction head.

    Pipeline: Linear projection (embed_dim -> decoder_embed_dim) + learnable positional table,
    then ``depth`` ``CustomTransformerBlock`` layers (``is_decoder=True``), a LayerNorm, and the
    reconstruction head selected by ``recon_method``.

    Args:
        embed_dim (int): Width of the encoder output tokens (default 1024).
        decoder_embed_dim (int): Width used inside the decoder (default 512).
        depth (int): Number of decoder blocks (default 8).
        num_heads (int): Attention heads per decoder block (default 16).
        mlp_ratio (int): Feed-forward width multiplier per block (default 4).
        image_size (int): Side length of the square image (default 256).
        activation (str): Activation name forwarded to the decoder blocks (default 'gelu').
        patch_size (int): Patch side length (default 16).
        recon_method (str): One of 'method1' ... 'method6' (default 'method1');
            'method4' is the segmentation-oriented head.
        dropout (float): Dropout probability in [0, 1] for the decoder blocks (default 0.1).
        output_channels (int): Channels produced by the reconstruction head (default 3).

    Raises:
        ValueError: If ``recon_method`` is not one of the six supported names.
    """

    def __init__(self, embed_dim=1024, decoder_embed_dim=512, depth=8, num_heads=16, mlp_ratio=4, image_size=256, activation='gelu', patch_size=16, recon_method="method1", dropout=0.1, output_channels=3):
        super().__init__()
        token_count = (image_size // patch_size) ** 2

        # Encoder width -> decoder width.
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)

        # One learnable positional row per patch; created as zeros.
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, token_count, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList(
            CustomTransformerBlock(
                embed_dim=decoder_embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                activation=activation,
                dropout=dropout,
                is_decoder=True
            )
            for _ in range(depth)
        )

        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)

        self.recon_method = recon_method

        # Arguments shared by every reconstruction head; method1 additionally needs image_size.
        common = dict(decoder_dim=decoder_embed_dim, patch_size=patch_size, output_channels=output_channels)
        head_factories = {
            "method1": lambda: ReconstructionMethod1(image_size=image_size, **common),
            "method2": lambda: ReconstructionMethod2(**common),
            "method3": lambda: ReconstructionMethod3(**common),
            "method4": lambda: ReconstructionMethod4(**common),  # segmentation-oriented head
            "method5": lambda: ReconstructionMethod5(**common),
            "method6": lambda: ReconstructionMethod6(**common),
        }
        factory = next((build for name, build in head_factories.items() if recon_method == name), None)
        if factory is None:
            raise ValueError(
                f"Unknown reconstruction method: {recon_method}. Choose from 'method1', 'method2', 'method3', 'method4', 'method5' or 'method6'.")
        self.reconstruction = factory()

    def forward(self, x, memory=None):
        """Decode encoder tokens.

        Args:
            x (torch.Tensor): Encoder tokens whose last dim is ``embed_dim``; the token count must
                broadcast against ``decoder_pos_embed`` (i.e. (image_size // patch_size) ** 2).
            memory (torch.Tensor | None): Passed as the second argument of every decoder block.
                When None, each block receives its own input as that second argument.

        Returns:
            torch.Tensor: Output of the selected reconstruction head.
        """
        tokens = self.decoder_embed(x) + self.decoder_pos_embed

        for block in self.decoder_blocks:
            context = tokens if memory is None else memory
            tokens = block(tokens, context)

        return self.reconstruction(self.decoder_norm(tokens))


## Model Base Class

In [None]:
class VisionTransformerAutoencoder(nn.Module):
    """ViT autoencoder: PatchEmbedding -> transformer Encoder -> transformer Decoder (+ reconstruction head).

    Args:
        image_size (int): Side length of the square input image (default 256); must be divisible
            by ``patch_size``.
        patch_size (int): Side length of each extracted patch (default 16).
        in_channels (int): Input image channels (default 3, RGB).
        embed_dim (int): Encoder token width (default 512).
        depth (int): Number of encoder blocks (default 12).
        num_heads (int): Encoder attention heads (default 8).
        decoder_embed_dim (int): Decoder token width (default 256).
        decoder_depth (int): Number of decoder blocks (default 4).
        decoder_num_heads (int): Decoder attention heads (default 8).
        mlp_ratio (int): Feed-forward width multiplier in every block (default 4).
        activation (str): Activation name for encoder and decoder blocks (default 'gelu').
        recon_method (str): Reconstruction head name passed to the decoder (default 'method3').
        dropout (float): Dropout probability in [0, 1] (default 0.1).
        output_channels (int): Channels produced by the reconstruction head (default 3).
    """

    def __init__(self, image_size=256, patch_size=16, in_channels=3, embed_dim=512, depth=12,
                 num_heads=8, decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=8,
                 mlp_ratio=4, activation='gelu', recon_method="method3", dropout=0.1, output_channels=3):
        super().__init__()

        # Images are square, so a single divisibility check covers both dimensions.
        assert image_size % patch_size == 0, \
            f"Image dimensions ({image_size}x{image_size}) must be divisible by patch size ({patch_size})"

        self.patch_embed = PatchEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )

        self.encoder = Encoder(
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            activation=activation,
            dropout=dropout
        )

        self.decoder = Decoder(
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            image_size=image_size,
            activation=activation,
            patch_size=patch_size,
            recon_method=recon_method,
            dropout=dropout,
            output_channels=output_channels
        )

    def forward(self, x):
        """Encode and reconstruct an image batch.

        Args:
            x (torch.Tensor): Images of shape (B, in_channels, image_size, image_size).

        Returns:
            torch.Tensor: Output of the decoder's reconstruction head; its exact shape is
                determined by the chosen ``recon_method``.
        """
        tokens = self.patch_embed(x)
        latent = self.encoder(tokens)
        return self.decoder(latent)


## Model Utilities

In [None]:
def initialize_weights(model):
    """
    Re-initialise a model in place: Xavier-uniform weights, zero biases, unit norm scales.

    - ``nn.Linear`` / ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Xavier-uniform weight, zero bias.
    - ``nn.LayerNorm`` / ``nn.BatchNorm2d``: weight = 1, bias = 0, when those parameters exist.
      Layers built with ``elementwise_affine=False`` / ``affine=False`` (or ``bias=False``) have
      ``None`` there and are skipped instead of crashing.
    - ``nn.MultiheadAttention``: Xavier-uniform packed or separate Q/K/V projections, and
      out-projection weight, with zero out-projection bias.
    - ``model.patch_embed.pos_embed`` and ``model.decoder.decoder_pos_embed`` (if present):
      Xavier-uniform.

    Args:
        model (nn.Module): The model to initialize

    Returns:
        model (nn.Module): The same model object, initialized
    """
    for m in model.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # Standard initialisation for normalisation layers (only when they are affine).
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

        elif isinstance(m, nn.MultiheadAttention):
            # Packed Q/K/V projection (used when q/k/v share the embedding dim).
            if getattr(m, 'in_proj_weight', None) is not None:
                nn.init.xavier_uniform_(m.in_proj_weight)

            # Output projection. Its Linear is also visited by the branch above;
            # it is re-initialised here, as before.
            if hasattr(m, 'out_proj') and hasattr(m.out_proj, 'weight'):
                nn.init.xavier_uniform_(m.out_proj.weight)
                if m.out_proj.bias is not None:
                    nn.init.zeros_(m.out_proj.bias)

            # Separate Q, K, V projections (used when kdim/vdim differ from embed_dim).
            for weight_name in ('q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
                weight = getattr(m, weight_name, None)
                if weight is not None:
                    nn.init.xavier_uniform_(weight)

    # Encoder positional embeddings, if present.
    if hasattr(model, 'patch_embed') and hasattr(model.patch_embed, 'pos_embed'):
        nn.init.xavier_uniform_(model.patch_embed.pos_embed)

    # Decoder positional embeddings, if present.
    if hasattr(model, 'decoder') and hasattr(model.decoder, 'decoder_pos_embed'):
        nn.init.xavier_uniform_(model.decoder.decoder_pos_embed)

    return model


def verify_initialization(model):
    """
    Heuristically check that trainable weights look Xavier-uniform initialised.

    For every trainable parameter whose name contains 'weight' and that has >= 2 dims, it checks:
    - the values lie within the Xavier-uniform bound;
    - the sample std is within +/-30% of the Xavier std;
    - no output unit has an all-zero weight row/filter.
    Fans follow ``torch.nn.init``: fan_in = size(1) * receptive field and
    fan_out = size(0) * receptive field, where the receptive field is the product of dims 2..
    Ignoring the receptive field made every correctly initialised conv layer fail.

    Parameters named like norm layers must be ~1; parameters named 'bias' must be ~0.
    Positional embeddings (if present) must lie within an Xavier-style bound computed from
    their last two dims.

    Args:
        model: The model to verify

    Returns:
        bool: True if every check passed, False otherwise (details are printed)
    """

    def _xavier_std(tensor):
        # Mirrors torch.nn.init's fan computation for tensors with >= 2 dims.
        receptive_field = 1
        for size in tensor.shape[2:]:
            receptive_field *= size
        fan_in = tensor.size(1) * receptive_field
        fan_out = tensor.size(0) * receptive_field
        return math.sqrt(2.0 / (fan_in + fan_out))

    initialization_ok = True

    # Pure inspection: no autograd graph needed.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue

            if 'weight' in name and param.dim() >= 2:
                std = _xavier_std(param)
                bound = math.sqrt(3.0) * std

                # Check if weights are within Xavier bounds
                if torch.any(torch.abs(param) > bound):
                    print(f"Warning: {name} has weights outside Xavier bounds (-{bound:.4f}, {bound:.4f})")
                    initialization_ok = False

                # Check weight distribution
                actual_std = torch.std(param).item()
                expected_std = std
                if not (0.7 * expected_std <= actual_std <= 1.3 * expected_std):
                    print(f"Warning: {name} has unexpected standard deviation: {actual_std:.4f} (expected ≈ {expected_std:.4f})")
                    initialization_ok = False

                # Dead output units: an entire row (Linear) or filter (Conv) of zeros.
                if torch.any(torch.all(param.flatten(1) == 0, dim=1)):
                    print(f"Warning: {name} has dead neurons (all-zero weights)")
                    initialization_ok = False

            # Check normalization layers
            elif any(x in name for x in ['norm', 'ln', 'batch_norm']) and 'weight' in name:
                if not torch.allclose(param, torch.ones_like(param), rtol=1e-3):
                    print(f"Warning: {name} normalization weights not initialized to ones")
                    initialization_ok = False

            # Verify bias initialization
            elif 'bias' in name:
                if not torch.allclose(param, torch.zeros_like(param), rtol=1e-3):
                    print(f"Warning: {name} bias not initialized to zeros")
                    initialization_ok = False

        # Positional embeddings: bound from the last two dims. This is looser than the bound
        # xavier_uniform_ actually uses for a 3-D (1, N, D) tensor, so it only catches gross outliers.
        if hasattr(model, 'patch_embed') and hasattr(model.patch_embed, 'pos_embed'):
            pos_embed = model.patch_embed.pos_embed
            fan_in = pos_embed.size(-1)
            fan_out = pos_embed.size(-2)
            std = math.sqrt(2.0 / (fan_in + fan_out))
            bound = math.sqrt(3.0) * std

            if torch.any(torch.abs(pos_embed) > bound):
                print(f"Warning: Positional embeddings outside Xavier bounds (-{bound:.4f}, {bound:.4f})")
                initialization_ok = False

        # Verify decoder positional embeddings
        if hasattr(model, 'decoder') and hasattr(model.decoder, 'decoder_pos_embed'):
            dec_pos_embed = model.decoder.decoder_pos_embed
            fan_in = dec_pos_embed.size(-1)
            fan_out = dec_pos_embed.size(-2)
            std = math.sqrt(2.0 / (fan_in + fan_out))
            bound = math.sqrt(3.0) * std

            if torch.any(torch.abs(dec_pos_embed) > bound):
                print(f"Warning: Decoder positional embeddings outside Xavier bounds (-{bound:.4f}, {bound:.4f})")
                initialization_ok = False

    if initialization_ok:
        print("All weights properly initialized with Xavier uniform distribution!")
        print("Note: Weights are bounded by their respective fan-in/fan-out values")

    return initialization_ok



def log_initialization(model, logger=None):
    """
    Report parameter counts (and, with a logger, the model's top-level structure).

    Without a truthy ``logger`` only the three summary lines are printed to stdout. With a
    logger, the same lines plus a structure overview go to ``logger.info``. The overview reads
    ``model.patch_embed``, ``model.encoder.layers`` and ``model.decoder.decoder_blocks``, and
    ``model.decoder.reconstruction`` when present.

    Args:
        model (nn.Module): The initialized model
        logger: Optional object with an ``info(str)`` method
    """
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)

    summary_lines = [
        "Model initialized with Xavier weights",
        f"Total parameters: {total_params:,}",
        f"Trainable parameters: {trainable_params:,}",
    ]

    if not logger:
        for line in summary_lines:
            print(line)
        return

    for line in summary_lines:
        logger.info(line)

    encoder, decoder = model.encoder, model.decoder
    logger.info("Model structure:")
    logger.info(f"- Patch Embedding: {model.patch_embed.__class__.__name__}")
    logger.info(f"- Encoder: {encoder.__class__.__name__} with {len(encoder.layers)} layers")
    logger.info(f"- Decoder: {decoder.__class__.__name__} with {len(decoder.decoder_blocks)} layers")
    if hasattr(decoder, 'reconstruction'):
        logger.info(f"- Reconstruction method: {decoder.reconstruction.__class__.__name__}")





def create_model(configs, logger=None):
    """
    Build the ViT autoencoder, initialise/verify its weights, and pick the matching loss.

    Architecture hyper-parameters come from the module-level ``MODEL_CONFIGS['model_base']``.
    Image geometry comes from ``configs.image_size``, ``configs.patch_size`` and
    ``configs.in_channels``.

    Args:
        configs: Configuration object; reads image_size, patch_size, in_channels, task_type and
            loss_function.
        logger: Optional logger for logging information.

    Returns:
        tuple: (model, loss_function). For ``task_type == 'pretext'`` the loss comes from
            ``create_reconstruction_loss``; otherwise from ``get_loss_function(configs.loss_function)``.
    """
    network = MODEL_CONFIGS['model_base']

    if logger is not None:
        logger.info('Using vision transformer...')
        logger.info('Model Network: {}'.format(network))

    # Get reconstruction method from config or use default
    recon_method = network.get('recon_method', 'method1')

    output_channels = 3  # Default output channels for RGB images

    model = VisionTransformerAutoencoder(configs.image_size,
                                         configs.patch_size,
                                         configs.in_channels,
                                         int(network['embed_dim']),
                                         int(network['depth']),
                                         int(network['num_heads']),
                                         int(network['decoder_embed_dim']),
                                         int(network['decoder_depth']),
                                         int(network['decoder_num_heads']),
                                         int(network['mlp_ratio']),
                                         network['activation'],
                                         recon_method,
                                         # Dropout is a probability: int() truncated e.g. 0.1 to 0.
                                         float(network['dropout']),
                                         output_channels=output_channels)

    # Initialize model weights using Xavier initialization
    model = initialize_weights(model)

    # Log initialization information
    log_initialization(model, logger)

    # Verify initialization (prints warnings; result is informational only)
    verify_initialization(model)

    # Create appropriate loss function
    if configs.task_type == 'pretext':
        loss_fn = create_reconstruction_loss(recon_method)
    else:
        loss_fn = get_loss_function(configs.loss_function)

    # All logging is guarded: previously the parameter count was logged even when logger is None.
    if logger is not None:
        logger.info(f'Using reconstruction method: {recon_method}')
        logger.info(f'Using loss function: {configs.loss_function}')
        logger.info(f"Model architecture: {model.__class__.__name__}")
        logger.info(f"Total parameters: {get_num_parameters(model):,}")

    return model, loss_fn


def get_num_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters.

    If the model has a ``module`` attribute (e.g. a DataParallel/DDP wrapper), the wrapped
    module is counted instead.
    """
    target = model.module if hasattr(model, 'module') else model
    return sum(tensor.numel() for tensor in target.parameters() if tensor.requires_grad)


def make_data_parallel(model, configs):
    """Move ``model`` to GPU(s) and wrap it according to ``configs``.

    - distributed + gpu_idx: pin this process to that GPU, wrap in DDP with
      ``device_ids=[gpu_idx]``, and divide ``configs.batch_size`` / ``configs.num_workers``
      (rounded up) by ``configs.ngpus_per_node``. Note: this mutates ``configs``.
    - distributed, no gpu_idx: DDP over all visible devices.
    - not distributed + gpu_idx: plain single-GPU model.
    - otherwise: ``DataParallel`` across all visible GPUs.

    Returns:
        The (possibly wrapped) model.
    """
    distributed = configs.distributed
    gpu = configs.gpu_idx

    if not distributed:
        if gpu is not None:
            torch.cuda.set_device(gpu)
            return model.cuda(gpu)
        # DataParallel splits every batch across all available GPUs.
        return torch.nn.DataParallel(model).cuda()

    if gpu is None:
        model.cuda()
        # Without device_ids, DDP spreads the batch over all available devices.
        return torch.nn.parallel.DistributedDataParallel(model)

    # One GPU per process: fix the device scope so DDP does not grab every device,
    # and split the global batch / worker budget across the node's GPUs ourselves.
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    gpus_per_node = configs.ngpus_per_node
    configs.batch_size = int(configs.batch_size / gpus_per_node)
    configs.num_workers = int((configs.num_workers + gpus_per_node - 1) / gpus_per_node)
    return torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])


def create_reconstruction_loss(recon_method="method1"):
    """
    Create a loss function suitable for the specified reconstruction method.

    Args:
        recon_method (str): The reconstruction method being used. ``"method4"`` gets an
            edge-aware loss that also accepts a ``(recon_img, semantic_features)`` tuple;
            every other value gets ``0.8 * L1 + 0.2 * MSE``.

    Returns:
        function: ``loss_fn(output, target) -> torch.Tensor`` returning a scalar loss.
    """
    if recon_method == "method4":
        def segmentation_pretraining_loss(output, target):
            """L1 loss, plus (for tuple outputs) an edge-weighted feature-consistency term."""
            if not isinstance(output, tuple):
                # Plain tensor output: reconstruction L1 loss only.
                return nn.functional.l1_loss(output, target)

            recon_img, semantic_features = output

            # Reconstruction loss (L1 for better edge preservation)
            recon_loss = nn.functional.l1_loss(recon_img, target)

            # Grayscale target: BT.601 luma weights for RGB; single-channel targets are used as-is
            # (indexing channels 1 and 2 previously raised IndexError for them).
            if target.shape[1] == 1:
                target_gray = target
            else:
                target_gray = (0.299 * target[:, 0] + 0.587 * target[:, 1] + 0.114 * target[:, 2]).unsqueeze(1)

            # Sobel kernels in the target's dtype/device (hard-coded float32 broke half-precision targets).
            kernel_opts = dict(dtype=target_gray.dtype, device=target.device)
            sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], **kernel_opts).view(1, 1, 3, 3)
            sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], **kernel_opts).view(1, 1, 3, 3)

            grad_x = nn.functional.conv2d(target_gray, sobel_x, padding=1)
            grad_y = nn.functional.conv2d(target_gray, sobel_y, padding=1)
            grad_magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2)

            # Normalise by the batch-wide maximum into [0, 1].
            grad_magnitude = grad_magnitude / (torch.max(grad_magnitude) + 1e-8)

            # Feature consistency loss - higher weight on edge regions.
            # NOTE(review): grad_magnitude has a single channel, so normalize(..., dim=1) maps each
            # pixel to ~1 (any non-zero gradient) or 0, discarding edge strength. semantic_features
            # must broadcast against (B, 1, H, W) — confirm the intended target here.
            edge_weight = 1.0 + 5.0 * grad_magnitude
            feature_loss = torch.mean(edge_weight * torch.abs(
                nn.functional.normalize(semantic_features, dim=1) -
                nn.functional.normalize(grad_magnitude, dim=1)
            ))

            # Total loss with weighting
            return recon_loss + 0.5 * feature_loss

        return segmentation_pretraining_loss

    # For other reconstruction methods, use a weighted combination of L1 and MSE loss
    def standard_reconstruction_loss(output, target):
        # L1 loss for pixel-wise accuracy
        l1_loss = nn.functional.l1_loss(output, target)

        # MSE loss penalises large errors more strongly
        mse_loss = nn.functional.mse_loss(output, target)

        return 0.8 * l1_loss + 0.2 * mse_loss

    return standard_reconstruction_loss


def transfer_to_segmentation_model(pretrained_model, num_classes, logger=None):
    """
    Transfer a pretrained VisionTransformerAutoencoder model to a segmentation model.

    The returned model shares (not copies) ``patch_embed`` and ``encoder`` with
    ``pretrained_model``. It adds a per-token Linear head that emits a
    ``num_classes x patch_size x patch_size`` tile per patch, tiles these into a full-resolution
    map, and refines it with a small conv stack. Only the newly created layers are
    (re-)initialised.

    Args:
        pretrained_model (nn.Module): The pretrained model; must expose ``patch_embed`` (with
            ``patch_size``, ``image_size``, ``num_patches``) and ``encoder``.
        num_classes (int): Number of segmentation classes
        logger: Optional logger for logging information

    Returns:
        nn.Module: A segmentation model producing (B, num_classes, image_size, image_size) logits
    """

    def _infer_embed_dim(pretrained_vit):
        # The encoder's final LayerNorm width is the token width. Fall back to the old heuristic
        # (input dim of the first encoder parameter), which breaks if that parameter is 1-D.
        norm = getattr(pretrained_vit.encoder, 'norm', None)
        normalized_shape = getattr(norm, 'normalized_shape', None)
        if normalized_shape:
            return int(normalized_shape[-1])
        return next(pretrained_vit.encoder.parameters()).size(1)

    class SegmentationModel(nn.Module):
        def __init__(self, pretrained_vit, num_classes):
            super().__init__()

            # Reuse patch embedding and encoder from the pretrained model (shared modules).
            self.patch_embed = pretrained_vit.patch_embed
            self.encoder = pretrained_vit.encoder

            self.num_classes = num_classes
            self.embed_dim = _infer_embed_dim(pretrained_vit)
            self.patch_size = pretrained_vit.patch_embed.patch_size
            self.image_size = pretrained_vit.patch_embed.image_size
            self.num_patches = pretrained_vit.patch_embed.num_patches

            # Per-token head: one num_classes x p x p tile per patch.
            self.segmentation_head = nn.Sequential(
                nn.Linear(self.embed_dim, 256),
                nn.GELU(),
                nn.Linear(256, num_classes * self.patch_size * self.patch_size)
            )

            # Full-resolution refinement of the tiled logits.
            self.refine = nn.Sequential(
                nn.Conv2d(num_classes, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 32, kernel_size=3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.Conv2d(32, num_classes, kernel_size=1)
            )

        def forward(self, x):
            x = self.patch_embed(x)
            x = self.encoder(x)

            batch_size = x.shape[0]
            # [B, num_patches, num_classes * patch_size^2]
            x = self.segmentation_head(x)

            # Tile the per-patch logits row-major into an image.
            patches_per_side = int(math.sqrt(self.num_patches))
            x = x.reshape(batch_size, patches_per_side, patches_per_side,
                          self.num_classes, self.patch_size, self.patch_size)
            x = x.permute(0, 3, 1, 4, 2, 5)
            x = x.reshape(batch_size, self.num_classes,
                          self.image_size, self.image_size)

            return self.refine(x)

    seg_model = SegmentationModel(pretrained_model, num_classes)

    # Initialise only layers that are not shared with the pretrained model.
    # Module identity is what `in` on modules() tested; an id-set avoids the O(n^2) scan.
    pretrained_ids = {id(m) for m in pretrained_model.modules()}
    for m in seg_model.modules():
        if id(m) in pretrained_ids:
            continue
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    if logger:
        logger.info(f"Created segmentation model with {num_classes} classes")
        logger.info(
            f"Transferred weights from pretrained VisionTransformerAutoencoder")

        total_params = sum(p.numel() for p in seg_model.parameters())
        trainable_params = sum(p.numel()
                               for p in seg_model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")

    return seg_model


def load_pretrained_for_segmentation(pretrained_path, num_classes, configs, logger=None):
    """
    Load a pretrained model and convert it to a segmentation model.

    The checkpoint may be either a raw ``state_dict`` or a dict wrapping the
    state dict under a ``'model'`` key. Keys prefixed with ``'module.'`` (as
    produced when the model was wrapped by ``make_data_parallel`` /
    ``nn.DataParallel`` at save time) are stripped so the weights load into the
    bare base model returned by ``create_model``.

    Args:
        pretrained_path (str): Path to the pretrained model checkpoint
        num_classes (int): Number of segmentation classes
        configs: Configuration object with model parameters (passed to ``create_model``)
        logger: Optional logger for logging information

    Returns:
        nn.Module: A segmentation model initialized with pretrained weights

    Raises:
        RuntimeError: If the checkpoint keys/shapes do not match the base model
            (propagated from ``load_state_dict``).
    """
    if logger:
        logger.info(f"Loading pretrained model from {pretrained_path}")

    # Create the base model first (the loss function returned with it is not needed here)
    base_model, _ = create_model(configs, logger)

    # Load on CPU so this works regardless of where the checkpoint was saved
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        # Wrapped checkpoint: the weights live under the 'model' key
        state_dict = checkpoint['model']
    else:
        # Otherwise assume the checkpoint is the model state dict directly
        state_dict = checkpoint

    # Weights saved from a DataParallel-wrapped model carry a 'module.' prefix on
    # every key; strip it so the keys match the unwrapped base model.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    base_model.load_state_dict(state_dict)

    if logger:
        logger.info("Pretrained weights loaded successfully")

    # Transfer to segmentation model
    seg_model = transfer_to_segmentation_model(base_model, num_classes, logger)

    return seg_model


## Training

In [None]:
def train_one_epoch(train_dataloader, model, optimizer, lr_scheduler, epoch, configs, logger, tb_writer, device, scaler, loss_fn):
    """
    Train ``model`` for a single epoch over ``train_dataloader``.

    For ``configs.task_type == 'pretext'`` the model output is compared with the
    input images (reconstruction); a tuple output is reduced to its first
    element. Otherwise the first three output channels are collapsed to one
    channel with luminance weights (0.299, 0.587, 0.114) and compared with the
    masks, which gain a channel dimension via ``unsqueeze(1)``.

    When ``scaler`` is given, the forward pass runs under ``autocast`` and the
    loss (computed in float32) is scaled before backpropagation. Gradients are
    clipped to ``configs.clip_grad_norm`` in both modes.

    Args:
        train_dataloader: Iterable of ``(imgs, masks)`` batches.
        model: Model to train (switched to train mode here).
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Scheduler stepped once at the start of the epoch.
        epoch (int): Zero-based epoch index.
        configs: Config providing ``num_epochs``, ``task_type`` and ``clip_grad_norm``.
        logger: Optional logger for progress messages.
        tb_writer: Optional TensorBoard ``SummaryWriter``.
        device: Device that each batch is moved to.
        scaler: Optional ``GradScaler``; enables mixed precision when not ``None``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.

    Returns:
        float: Batch-size-weighted average training loss for the epoch.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    grad_norm = AverageMeter('GradNorm', ':.4e')

    num_iters_per_epoch = len(train_dataloader)
    progress = ProgressMeter(num_iters_per_epoch, [batch_time, data_time, losses, grad_norm],
                             prefix="Train - Epoch: [{}/{}]".format(epoch, configs.num_epochs))

    # NOTE(review): the scheduler is stepped before any optimizer.step() of this
    # epoch, which skips the first value of the schedule (PyTorch warns about this
    # ordering). Kept as-is because the lambda from create_lr_scheduler may rely on
    # it; confirm before moving the step to the end of the epoch.
    lr_scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    if logger is not None:
        logger.info(
            f"Epoch {epoch+1}/{configs.num_epochs}: Learning rate = {current_lr:.8f}")

    use_amp = scaler is not None

    # switch to train mode
    model.train()
    start_time = time.time()

    for batch_idx, (imgs, masks) in enumerate(train_dataloader):
        # Time spent waiting for the dataloader to produce this batch
        data_time.update(time.time() - start_time)
        # Monotonic step across epochs so TensorBoard curves do not overwrite each other
        global_step = epoch * num_iters_per_epoch + batch_idx

        # Move data to the correct device
        imgs = imgs.to(device)
        masks = masks.to(device)

        # Forward pass (mixed precision only when a GradScaler is supplied)
        with autocast(enabled=use_amp):
            outputs = model(imgs)

        # Handle the case where outputs is a tuple (for method4)
        if configs.task_type == 'pretext' and isinstance(outputs, tuple):
            outputs = outputs[0]

        outputs = outputs.to(device)
        if use_amp:
            # Compute the loss in float32: some losses (e.g. BCELoss) are unsafe in half precision
            outputs = outputs.float()

        # Calculate loss
        if configs.task_type == 'pretext':
            total_loss = loss_fn(outputs, imgs)
        else:
            grayscale_outputs = (0.299 * outputs[:, 0, :, :]
                                 + 0.587 * outputs[:, 1, :, :]
                                 + 0.114 * outputs[:, 2, :, :]).unsqueeze(1)  # (batch_size, 1, H, W)
            masks = masks.unsqueeze(1)
            total_loss = loss_fn(grayscale_outputs, masks)

        optimizer.zero_grad()

        if use_amp:
            scaler.scale(total_loss).backward()
            # Unscale so that clipping and the reported norm use the true gradient magnitudes
            scaler.unscale_(optimizer)
        else:
            total_loss.backward()

        # clip_grad_norm_ returns the total (pre-clipping) L2 norm over all gradients,
        # which is the value we monitor; this avoids one host sync per parameter.
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), configs.clip_grad_norm)
        grad_norm.update(float(total_norm))

        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        losses.update(to_python_float(total_loss.detach()), imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - start_time)

        if tb_writer is not None:
            tb_writer.add_scalar('avg_loss', losses.avg, global_step)
            tb_writer.add_scalar('GradNorm', grad_norm.avg, global_step)

        # Log message
        if logger is not None and batch_idx % 10 == 0:
            logger.info(progress.get_message(batch_idx))

        # More detailed logging every 50 batches
        if logger is not None and batch_idx % 50 == 0:
            logger.info(
                f"Epoch {epoch+1} - Batch {batch_idx}/{num_iters_per_epoch} - "
                f"Loss: {losses.val:.4f}, Grad Norm: {grad_norm.val:.4f}, "
                f"LR: {optimizer.param_groups[0]['lr']:.8f}")

        start_time = time.time()

    # One point per epoch for the epoch-level training curve (pairs with 'Loss/val')
    if tb_writer is not None:
        tb_writer.add_scalar('Loss/train', losses.avg, epoch)

    return losses.avg



In [None]:
def _validate_one_epoch(model, val_dataloader, loss_fn, configs, device, use_amp, epoch):
    """Return the mean per-batch validation loss, or ``nan`` if the loader yields no batches."""
    model.eval()
    running_val_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, masks in tqdm(val_dataloader, desc=f"Validation Epoch {epoch}"):
            images = images.to(device)
            masks = masks.to(device)

            with autocast(enabled=use_amp):
                outputs = model(images)

            # Mirror train_one_epoch: tuple outputs carry the reconstruction first
            if configs.task_type == 'pretext' and isinstance(outputs, tuple):
                outputs = outputs[0]
            if use_amp:
                # Compute the loss in float32, outside autocast
                outputs = outputs.float()

            if configs.task_type == 'pretext':
                loss = loss_fn(outputs, images)
            else:
                grayscale_outputs = (0.299 * outputs[:, 0, :, :]
                                     + 0.587 * outputs[:, 1, :, :]
                                     + 0.114 * outputs[:, 2, :, :]).unsqueeze(1)  # (batch_size, 1, H, W)
                loss = loss_fn(grayscale_outputs, masks.unsqueeze(1))

            running_val_loss += loss.item()
            num_batches += 1

    return running_val_loss / num_batches if num_batches else float('nan')


def main():
    """
    Entry point for training: build everything from ``get_configs()`` and run the
    train/validate loop.

    Each epoch runs ``train_one_epoch``, then a validation pass whose mean
    per-batch loss is logged (and written to TensorBoard as ``'Loss/val'``).
    It then saves ``<saved_fn>_best`` whenever the validation loss improves, plus
    a periodic ``<saved_fn>`` checkpoint every ``configs.save_checkpoint_freq``
    epochs. Training stops early after ``configs.patience`` consecutive epochs in
    which the validation loss did not drop below the previous epoch's. A
    visualization is produced every 5 epochs.

    Setup failures are reported and terminate the process with ``sys.exit(1)``.
    ``KeyboardInterrupt`` calls ``cleanup()`` and exits with status 0.
    """
    import traceback

    try:
        configs = get_configs()  # Using our previously defined get_configs()

        # Create necessary directories
        try:
            for directory in (configs.checkpoints_dir, configs.logs_dir,
                              configs.results_dir, configs.model_dir):
                os.makedirs(directory, exist_ok=True)
        except Exception as e:
            print(f"Error creating directories: {e}")
            sys.exit(1)

        # Set seeds for reproducibility
        if configs.seed is not None:
            random.seed(configs.seed)
            np.random.seed(configs.seed)
            torch.manual_seed(configs.seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        # Initialize logging
        try:
            logger = Logger(configs.logs_dir, configs.saved_fn)
            writer = SummaryWriter(log_dir=os.path.join(configs.logs_dir, 'tensorboard'))
        except Exception as e:
            print(f"Error initializing logger or TensorBoard writer: {e}")
            sys.exit(1)

        # Create model and loss function
        try:
            model, loss_fn = create_model(configs, logger)
        except Exception as e:
            logger.info(f"Error creating model: {e}")
            sys.exit(1)

        # Load pretrained weights if specified
        if configs.pretrained_path is not None:
            try:
                checkpoint = torch.load(configs.pretrained_path, map_location=configs.device)
                # Accept both a raw state dict and one wrapped under a 'model' key
                if isinstance(checkpoint, dict) and 'model' in checkpoint:
                    checkpoint = checkpoint['model']
                # Strip the 'module.' prefix added when the model was saved inside DataParallel
                prefix = 'module.'
                new_state_dict = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in checkpoint.items()
                }
                model.load_state_dict(new_state_dict)
                logger.info(f'Loaded pretrained model from {configs.pretrained_path}')
            except Exception as e:
                logger.info(f"Failed to load pretrained model from {configs.pretrained_path}. Error: {e}")
                sys.exit(1)

        try:
            model = make_data_parallel(model, configs)
        except Exception as e:
            logger.info(f"Error setting up data parallelism: {e}")
            sys.exit(1)

        # Move model to device and print an architecture summary
        try:
            device = torch.device(configs.device)
            model = model.to(device)
            # Summarize at the configured resolution rather than a hard-coded 256x256
            image_size = configs.image_size
            height, width = ((image_size, image_size) if isinstance(image_size, int)
                             else tuple(image_size)[:2])
            summary_x = summary(model, input_size=(1, 3, height, width))  # (batch_size, channels, height, width)
            logger.info(f"Model moved to device: {device}")
            print(f"\n\nModel summary: {summary_x}\n\n")
        except Exception as e:
            logger.info(f"Error moving model to device: {e}")
            sys.exit(1)

        # Create optimizer and scheduler
        try:
            optimizer = create_optimizer(configs, model)
            lr_scheduler = create_lr_scheduler(optimizer, configs)
        except Exception as e:
            logger.info(f"Error creating optimizer or learning rate scheduler: {e}")
            sys.exit(1)

        # Load dataset
        try:
            dataset_loader = DatasetLoader(configs.image_size,
                                           configs.root_dir,
                                           configs.split_ratios,
                                           configs.seed,
                                           configs.shuffle_data,
                                           configs.batch_size,
                                           configs.num_workers,
                                           configs.task_type,)
            train_dataloader, val_dataloader, _ = dataset_loader.get_dataloaders()
        except Exception as e:
            logger.info(f"Error loading dataset: {e}")
            sys.exit(1)

        # Setup AMP if enabled
        scaler = GradScaler() if configs.use_amp else None

        def _save_checkpoint_as(name, epoch_idx):
            # Persist model weights plus optimizer/scheduler state for resuming
            save_checkpoint(configs.checkpoints_dir,
                            name,
                            model.state_dict(),
                            {'optimizer': optimizer.state_dict(),
                             'lr_scheduler': lr_scheduler.state_dict(),
                             'epoch': epoch_idx},
                            epoch_idx, model_type=configs.task_type)

        # Training loop
        best_val_loss = float("inf")
        val_losses = []

        for epoch in range(configs.start_epoch, configs.num_epochs):
            try:
                # Clear GPU cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Train for one epoch
                train_one_epoch(train_dataloader, model, optimizer,
                                lr_scheduler, epoch, configs, logger, writer,
                                device, scaler, loss_fn)

                # Validation phase
                avg_val_loss = _validate_one_epoch(model, val_dataloader, loss_fn, configs,
                                                   device, scaler is not None, epoch)
                val_losses.append(avg_val_loss)
                writer.add_scalar('Loss/val', avg_val_loss, epoch)
                logger.info(
                    f"Validation Loss: {avg_val_loss:.4f} at epoch {epoch+1}/{configs.num_epochs}")

                # Save the best-so-far model (by validation loss); nan never compares as smaller
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    _save_checkpoint_as(configs.saved_fn + '_best', epoch)
                    logger.info(f"New best validation loss {best_val_loss:.4f} at epoch {epoch+1}")

                # Periodic checkpoint, independent of validation performance
                if epoch % configs.save_checkpoint_freq == 0:
                    _save_checkpoint_as(configs.saved_fn, epoch)

                # Early stopping: `patience` consecutive epochs without improving on the previous one
                if len(val_losses) > configs.patience:
                    if all(val_losses[-i-1] >= val_losses[-i-2]
                           for i in range(configs.patience)):
                        print(f"Early stopping triggered after {epoch + 1} epochs")
                        logger.info(f"Early stopping triggered after {epoch + 1} epochs")
                        break

                if epoch % 5 == 0:
                    if configs.task_type == 'pretext':
                        visualize_reconstructions(model, train_dataloader,
                                                  device, epoch, configs.results_dir)
                    elif configs.task_type == 'segmentation':
                        visualize_segmentation(model, train_dataloader,
                                               device, epoch, configs.results_dir)

            except Exception as e:
                logger.info(f"Error during training or validation at epoch {epoch}: {e}")
                logger.info(traceback.format_exc())
                cleanup()
                sys.exit(1)

        writer.close()

    except KeyboardInterrupt:
        print("Training interrupted by user. Cleaning up...")
        cleanup()
        sys.exit(0)
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        cleanup()
        sys.exit(1)

if __name__ == '__main__':
    # Script entry point: run training, and on Ctrl-C or any other error release
    # resources via cleanup() before exiting (status 0 for Ctrl-C, 1 otherwise).
    try:
        main()
    except (KeyboardInterrupt, Exception) as exc:
        if isinstance(exc, KeyboardInterrupt):
            print("Training interrupted by user. Cleaning up...")
            exit_status = 0
        else:
            print(f"An unexpected error occurred: {exc}")
            exit_status = 1
        cleanup()
        sys.exit(exit_status)

## Testing

In [None]:

def main():
    """
    Entry point for testing: evaluate a (pre)trained model on the test split.

    Builds the model from ``get_configs()`` and optionally loads weights from
    ``configs.pretrained_path``. It then runs inference over the test
    dataloader from ``DatasetLoader.get_dataloaders()`` and visualizes every
    batch. The sample-weighted average loss is logged for both ``'pretext'`` and
    ``'segmentation'``. For segmentation, the batch-wise mean and dataset-level
    metrics from ``calculate_metrics`` are logged as well.

    Setup failures are reported and terminate the process with ``sys.exit(1)``.
    Errors during evaluation print a traceback, call ``cleanup()`` and exit with
    status 1.
    """
    import traceback

    configs = get_configs()  # Using our previously defined get_configs()

    # Create necessary directories
    try:
        os.makedirs(configs.logs_dir, exist_ok=True)
        os.makedirs(configs.results_dir, exist_ok=True)
    except Exception as e:
        print(f"Error creating directories: {e}")
        sys.exit(1)

    # Initialize logging
    try:
        logger = Logger(configs.logs_dir, configs.saved_fn)
    except Exception as e:
        print(f"Error initializing logger: {e}")
        sys.exit(1)

    # log the configurations
    try:
        logger.info("Configurations:")
        for key, value in vars(configs).items():
            logger.info(f"{key}: {value}")
    except Exception as e:
        logger.info(f"Error logging configurations: {e}")
        sys.exit(1)

    if configs.task_type not in ('pretext', 'segmentation'):
        logger.info(f"Unsupported task_type for testing: {configs.task_type}")
        sys.exit(1)
    is_segmentation = configs.task_type == 'segmentation'

    # Create model and loss function
    try:
        model, loss_fn = create_model(configs, logger)
    except Exception as e:
        logger.info(f"Error creating model: {e}")
        sys.exit(1)

    # Load pretrained weights if specified
    if configs.pretrained_path is not None:
        try:
            checkpoint = torch.load(configs.pretrained_path, map_location=configs.device)
            # Accept both a raw state dict and one wrapped under a 'model' key
            if isinstance(checkpoint, dict) and 'model' in checkpoint:
                checkpoint = checkpoint['model']
            # Strip the 'module.' prefix added when the model was saved inside DataParallel
            prefix = 'module.'
            new_state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in checkpoint.items()
            }
            model.load_state_dict(new_state_dict)
            logger.info(f'Loaded pretrained model from {configs.pretrained_path}')
        except Exception as e:
            logger.info(f"Failed to load pretrained model from {configs.pretrained_path}. Error: {e}")
            sys.exit(1)

    try:
        model = make_data_parallel(model, configs)
    except Exception as e:
        logger.info(f"Error setting up data parallelism: {e}")
        sys.exit(1)

    # Move model to device and print an architecture summary
    try:
        device = torch.device(configs.device)
        model = model.to(device)
        # Summarize at the configured resolution rather than a hard-coded 256x256
        image_size = configs.image_size
        height, width = ((image_size, image_size) if isinstance(image_size, int)
                         else tuple(image_size)[:2])
        summary_x = summary(model, input_size=(1, 3, height, width))  # (batch_size, channels, height, width)
        logger.info(f"Model moved to device: {device}")
        print(f"\n\nModel summary: {summary_x}\n\n")
    except Exception as e:
        logger.info(f"Error moving model to device: {e}")
        sys.exit(1)

    # Load dataset
    try:
        dataset_loader = DatasetLoader(configs.image_size,
                                       configs.root_dir,
                                       configs.split_ratios,
                                       configs.seed,
                                       configs.shuffle_data,
                                       configs.batch_size,
                                       configs.num_workers,
                                       configs.task_type)
        _, _, test_dataloader = dataset_loader.get_dataloaders()
    except Exception as e:
        logger.info(f"Error loading dataset: {e}")
        sys.exit(1)

    try:
        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Evaluation phase
        model.eval()

        all_metrics = []       # per-batch metric dicts (segmentation only)
        all_predictions = []   # CPU copies of outputs for dataset-level metrics (segmentation only)
        all_masks = []
        running_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            batches = tqdm(test_dataloader, desc="Testing", unit="batch")
            for batch, (images, masks) in enumerate(batches):
                # Move data to device
                images = images.to(device)
                masks = masks.to(device)

                # Ensure masks have shape (B, 1, H, W)
                if masks.ndim == 2:
                    masks = masks.unsqueeze(0)  # Add batch dimension
                if masks.ndim == 3:
                    masks = masks.unsqueeze(1)  # Add channel dimension

                # Forward pass
                outputs = model(images)

                if is_segmentation:
                    all_predictions.append(outputs.cpu())
                    all_masks.append(masks.cpu())
                    visualize_segmentation_2(outputs, images, masks, batch, configs.results_dir)

                    # Metrics for the current batch
                    all_metrics.append(calculate_metrics(outputs, masks))

                    # Loss on the luminance-collapsed output, matching training
                    grayscale_outputs = (0.299 * outputs[:, 0, :, :]
                                         + 0.587 * outputs[:, 1, :, :]
                                         + 0.114 * outputs[:, 2, :, :]).unsqueeze(1)
                    loss = loss_fn(grayscale_outputs, masks)
                else:
                    # Mirror training: tuple outputs carry the reconstruction first
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    visualize_reconstructions_2(outputs, images, configs.results_dir, batch)
                    loss = loss_fn(outputs, images)

                batch_size = images.size(0)
                running_loss += loss.item() * batch_size
                total_samples += batch_size

        if total_samples == 0:
            logger.info("No test samples were evaluated; nothing to report.")
            return

        # Sample-weighted average loss, reported for every task type
        avg_test_loss = running_loss / total_samples
        logger.info("\nTest Results:")
        logger.info(f"Average Loss: {avg_test_loss:.4f}")

        if is_segmentation:
            # Dataset-level metrics over all predictions at once
            overall_metrics = calculate_metrics(torch.cat(all_predictions, dim=0),
                                                torch.cat(all_masks, dim=0))

            # Unweighted mean of the per-batch metrics
            metric_names = ('iou', 'dice', 'precision', 'recall', 'f1')
            mean_metrics = {name: np.mean([m[name] for m in all_metrics])
                            for name in metric_names}

            logger.info("\nBatch-wise Mean Metrics:")
            for metric, value in mean_metrics.items():
                logger.info(f"{metric.upper()}: {value:.4f}")

            logger.info("\nOverall Metrics:")
            for metric, value in overall_metrics.items():
                logger.info(f"{metric.upper()}: {value:.4f}")

    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()
        cleanup()
        sys.exit(1)


if __name__ == '__main__':
    # Script entry point for evaluation: on Ctrl-C or any other error, release
    # resources via cleanup() before exiting (status 0 for Ctrl-C, 1 otherwise).
    try:
        main()
    except KeyboardInterrupt:
        # This guard runs the *testing* main, so report it as such
        print("Testing interrupted by user. Cleaning up...")
        cleanup()
        sys.exit(0)
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        cleanup()
        sys.exit(1)