# Mpox

In [3]:
import numpy as np
import torchvision.models as tvm

import torch.nn as nn
import torch
import random
from tqdm import tqdm
from torch.utils.data import DataLoader

import os

import torch.optim as to
import matplotlib.pyplot as plt




from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset

%run model.ipynb



cuda


In [4]:
# Release unoccupied GPU memory cached by PyTorch's allocator before building models.
torch.cuda.empty_cache()

In [5]:
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score

In [6]:
import timm
# Record the installed timm version alongside the results for reproducibility.
print(timm.__version__)

1.0.12


In [7]:
def evaluate(y_true, y_pred):
    """
    Score class predictions with macro-averaged precision, recall and F1.

    Macro averaging computes each metric per class and takes the unweighted mean,
    so every class counts equally regardless of how often it occurs.

    Parameters:
    y_true (list or array-like): Ground-truth class labels.
    y_pred (list or array-like): Labels predicted by the model.

    Returns:
    tuple: ``(precision, recall, f1)``.
    """
    metric_fns = (precision_score, recall_score, f1_score)
    return tuple(fn(y_true, y_pred, average='macro') for fn in metric_fns)


In [8]:


def replace_context_modules(model, Module, dilation, stage_idx=3, block_indices=range(1, 7)):
    """
    Swap ``context_module`` in selected blocks of one model stage for a custom module.

    For every targeted block, the input channel count is read from the existing
    ``block.context_module.main.qkv.conv.in_channels``. The module is then replaced with
    ``nn.Sequential(Module(in_channels, nn.ReLU, dilation=dilation))``. The model is
    modified in place and nothing is returned.

    Parameters:
    model: Model exposing ``model.stages[stage_idx].blocks`` whose blocks have a
        ``context_module.main.qkv.conv`` layer (presumably a timm EfficientViT; verify).
    Module: Replacement module class, called as ``Module(in_channels, nn.ReLU, dilation=dilation)``.
    dilation: Dilation setting forwarded unchanged to ``Module``.
    stage_idx: Index into ``model.stages`` of the stage to modify (default 3, the fourth stage).
    block_indices: Indices into that stage's ``blocks`` to modify (default 1..6 inclusive).
    """
    stage = model.stages[stage_idx]

    for block_idx in block_indices:
        block = stage.blocks[block_idx]

        # Size the replacement from the original module's input width.
        in_channels = block.context_module.main.qkv.conv.in_channels

        # The nn.Sequential wrapper makes parameter names ``context_module.0.*``.
        # Keep it so existing checkpoints still load.
        block.context_module = nn.Sequential(
            Module(in_channels, nn.ReLU, dilation=dilation)
        )





In [9]:

def change_classifier(model, model_name, num_classes=4, dropout=0.5,
                     neurons1=4096, neurons2=1024, neurons3=256, neurons4=512, n_layers=2):
    """
    Replace the classification head of a vision backbone with an MLP head.

    The head is ``Flatten`` followed by ``depth`` hidden blocks of
    ``Linear -> ReLU -> Dropout``, then a final ``Linear`` to ``num_classes``. It ends
    with ``Sigmoid`` when ``num_classes == 1`` and ``Softmax(dim=1)`` otherwise, so it
    outputs probabilities, not logits.
    NOTE(review): pair it with a loss that expects probabilities; nn.CrossEntropyLoss expects logits.

    Args:
        model: Backbone to modify (modified in place and also returned).
        model_name: Key used to look up the backbone's feature width (3072 if unknown).
        num_classes: Number of output classes.
        dropout: Dropout probability after every hidden layer.
        neurons1-4: Widths of hidden layers 1-4 (only the first ``depth`` are used).
        n_layers: Number of hidden layers. 1, 2 and 3 are honoured; any other
            value builds the 4-layer head.

    Returns:
        The same ``model`` object with its head replaced.

    Raises:
        AttributeError: if the model has none of ``head``, ``fc`` or ``classifier``.
    """
    feature_dims = {
        'resnet101.a1_in1k': 2048,
        'deit3_medium_patch16_224': 512,
        'coatnet_1_rw_224.sw_in1k': 768,
        'mobilenetv3_large_100.ra_in1k': 1280,
        'vit_base_patch16_224' : 768,
        'efficientvit_l1.r224_in1k': 3072
    }
    in_features = feature_dims.get(model_name, 3072)

    depth = n_layers if n_layers in (1, 2, 3) else 4
    hidden_widths = [neurons1, neurons2, neurons3, neurons4][:depth]

    layers = [nn.Flatten()]
    width_in = in_features
    for idx, width_out in enumerate(hidden_widths):
        # In-place ReLU is used on at most the first two hidden layers and never on the last one.
        use_inplace = idx < 2 and idx < depth - 1
        layers.extend([
            nn.Linear(width_in, width_out),
            nn.ReLU(inplace=use_inplace),
            nn.Dropout(dropout),
        ])
        width_in = width_out
    layers.append(nn.Linear(width_in, num_classes))
    layers.append(nn.Sigmoid() if num_classes == 1 else nn.Softmax(dim=1))
    classifier = nn.Sequential(*layers)

    # Resolve which object/attribute holds the head.
    # Priority: head.classifier, then head.fc, then head itself; otherwise fc, then classifier.
    if hasattr(model, 'head'):
        owner = model.head
        slot = next((name for name in ('classifier', 'fc') if hasattr(owner, name)), None)
        if slot is None:
            owner, slot = model, 'head'
    else:
        owner = model
        slot = next((name for name in ('fc', 'classifier') if hasattr(owner, name)), None)
        if slot is None:
            raise AttributeError("Model structure not supported. Cannot find classifier or head attribute.")

    setattr(owner, slot, classifier)
    return model




In [10]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device

device(type='cuda')

In [11]:
#

In [12]:
class_labels = {
    'monkey_pox' : [ 1 , 0 , 0 , 0],  # Label encoding for 'monkey_pox' class
    'normal' : [ 0 , 1 , 0 , 0],      # Label encoding for 'normal' class
    'chicken_pox' : [ 0 , 0 , 1 , 0], # Label encoding for 'chicken_pox' class
    'acne' : [ 0 , 0 , 0 , 1],        # Label encoding for 'acne' class
}

def acc_eval(outputs, labels, classwise=False):
    """
    Count correct predictions in a batch.

    Each prediction row becomes a (multi-)hot vector that marks every position equal
    to the row maximum. It counts as correct only when that vector equals the
    label row exactly, so ties at the maximum count as wrong.

    ``outputs`` is not modified.

    Args:
        outputs (torch.Tensor): Model scores/probabilities, assumed shape (N, C).
        labels (torch.Tensor): One-hot ground-truth rows, same shape as ``outputs``.
        classwise (bool): If True, also return per-class correct counts.

    Returns:
        int: Number of correct predictions if ``classwise`` is False.
        tuple: ``(right, *per_class_right)`` in ``class_labels`` key order if
            ``classwise`` is True.
    """
    # Built out of place, so the caller's tensor and its autograd graph are left untouched.
    row_max = outputs.amax(dim=1, keepdim=True)
    predicted = (outputs == row_max).to(outputs.dtype)
    hits = (predicted == labels).all(dim=1)
    right = int(hits.sum().item())

    if not classwise:
        return right

    class_rights = []
    for class_label in class_labels.values():
        reference = torch.as_tensor(class_label, dtype=labels.dtype, device=labels.device)
        is_class = (labels == reference).all(dim=1)
        class_rights.append(int((hits & is_class).sum().item()))
    return (right, *class_rights)


In [13]:
def cal_total(labels):
    """
    Count how many label rows belong to each class in ``class_labels``.

    Args:
        labels (torch.Tensor): One-hot label rows. A row is attributed to a class only
            when its values equal that class's encoding exactly.

    Returns:
        tuple: Per-class counts in ``class_labels`` key order. Rows that match no
        class are ignored.
    """
    counts = dict.fromkeys(class_labels, 0)
    for row in labels.detach().cpu().tolist():
        owner = next((name for name, code in class_labels.items() if row == code), None)
        if owner is not None:
            counts[owner] += 1
    return tuple(counts.values())


In [14]:
def norm(X_train, X_val, X_test):
    """
    Standardise three splits using statistics from the training split only.

    The mean and standard deviation are computed once, over all elements of
    ``X_train``, via the container's own ``.mean()``/``.std()``. Every split is then
    mapped to ``(x - mean) / std``.

    Returns:
        tuple: ``(X_train, X_val, X_test)`` after standardisation.
    """
    train_mean = X_train.mean()
    train_std = X_train.std()

    def scale(split):
        return (split - train_mean) / train_std

    return tuple(scale(split) for split in (X_train, X_val, X_test))


In [15]:
def test(model, training_loader, model_num):
    """
    Evaluate ``model`` on a loader and report accuracy, macro precision/recall/F1
    and per-class accuracy.

    Despite its name, ``training_loader`` is simply the loader being evaluated
    (e.g. a validation or test split). For the duration of the call the model is
    put in eval mode (dropout off, BatchNorm uses running statistics) and run under
    ``torch.no_grad()``. Its previous train/eval mode is restored afterwards, even
    if an exception is raised.

    Args:
        model: ``nn.Module`` mapping an input batch to per-class scores, assumed (N, C).
        training_loader: Iterable of ``(inputs, labels)`` batches. Labels are one-hot
            rows matching ``class_labels``.
        model_num: Unused; kept for call-site compatibility.

    Returns:
        tuple: ``(accuracy, precision, recall, f1, *class_accuracies)``. Per-class
        accuracies follow ``class_labels`` order and are 0 for classes with no samples.
    """
    right_total = 0  # Total correct predictions
    total = 0  # Total number of samples
    y_pred = []  # Predicted class indices
    y_true = []  # Ground-truth class indices
    class_totals = {class_name: 0 for class_name in class_labels.keys()}  # Per-class sample counts
    class_rights = {class_name: 0 for class_name in class_labels.keys()}  # Per-class correct predictions

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in tqdm(training_loader):
                total += inputs.shape[0]
                outputs = model(inputs)

                # Take argmax indices before acc_eval, which in some versions overwrites
                # ``outputs`` in place. The argmax (first maximal index) is the same either way.
                y_pred.extend(outputs.argmax(dim=1).cpu().tolist())
                y_true.extend(labels.argmax(dim=1).cpu().tolist())

                class_totals_batch = cal_total(labels)
                right, *class_rights_batch = acc_eval(outputs, labels, classwise=True)

                right_total += right
                for k, class_name in enumerate(class_labels.keys()):
                    class_totals[class_name] += class_totals_batch[k]
                    class_rights[class_name] += class_rights_batch[k]
    finally:
        model.train(was_training)

    accuracy = right_total / total  # Overall accuracy
    class_accuracies = {class_name: class_rights[class_name] / class_totals[class_name]
                        if class_totals[class_name] > 0 else 0
                        for class_name in class_labels.keys()}

    # evaluate() expects (y_true, y_pred). Passing them in the other order swaps
    # macro precision and recall.
    precision, recall, f1 = evaluate(y_true, y_pred)

    print("Accuracy:", accuracy)
    print("Total Right:", right_total)
    for class_name, class_accuracy in class_accuracies.items():
        print(f"{class_name} Accuracy: {class_accuracy}")

    return (accuracy, precision, recall, f1, *class_accuracies.values())


In [16]:
def train_one_epoch(model, epoch_index, model_num, training_loader, loss_fn, loss_fn1, w, optimizer, loss_dict_train):
    """
    Train ``model`` for one pass over ``training_loader``.

    Each batch is optimised on ``(1 - w) * loss_fn(outputs, labels) + w * loss_fn1(outputs, labels)``.
    The mean loss of every 10-batch window is printed. The model's train/eval mode
    is not changed here; callers are expected to have called ``model.train()``.

    Args:
        model: The network to train.
        epoch_index: Unused; kept for call-site compatibility.
        model_num: Unused; kept for call-site compatibility.
        training_loader: Iterable of ``(inputs, labels)`` batches. Labels are one-hot
            rows matching ``class_labels``.
        loss_fn: Primary loss, weighted by ``1 - w``.
        loss_fn1: Secondary loss, weighted by ``w``.
        w: Mixing weight between the two losses.
        optimizer: Optimizer stepped once per batch.
        loss_dict_train: Unused; kept for call-site compatibility.

    Returns:
        tuple: ``(last_loss, overall_accuracy, right_total, *class_accuracies)`` where
        - last_loss: mean batch loss of the last window. This is the last full 10-batch
          window, or the trailing partial window if the epoch did not end on a multiple
          of 10 batches. It is 0.0 for an empty loader.
        - overall_accuracy: correct / samples, or None for an empty loader.
        - right_total: number of correct predictions in the epoch.
        - class_accuracies: per-class accuracy in ``class_labels`` order, None for
          classes with no samples.
    """
    running_loss = 0.  # Sum of batch losses in the current window
    window_batches = 0  # Number of batches accumulated in running_loss
    last_loss = 0.  # Mean loss of the most recent window
    right_total = 0  # Total correct predictions
    total = 0  # Total samples processed

    class_totals = {class_name: 0 for class_name in class_labels.keys()}  # Samples per class
    class_rights = {class_name: 0 for class_name in class_labels.keys()}  # Correct predictions per class

    for i, (inputs, labels) in enumerate(tqdm(training_loader)):
        optimizer.zero_grad()
        outputs = model(inputs)

        total += inputs.shape[0]
        loss = (1-w) * loss_fn(outputs, labels) + w * loss_fn1(outputs, labels)
        loss.backward()
        optimizer.step()

        class_totals_batch = cal_total(labels)
        # Score a detached tensor so the accuracy bookkeeping records no autograd ops.
        right, *class_rights_batch = acc_eval(outputs.detach(), labels, classwise=True)

        right_total += right
        for j, class_name in enumerate(class_labels.keys()):
            class_totals[class_name] += class_totals_batch[j]
            class_rights[class_name] += class_rights_batch[j]

        running_loss += loss.item()
        window_batches += 1
        if i % 10 == 9:
            last_loss = running_loss / 10  # Mean over the last 10 batches
            print(f'  batch {i + 1} loss: {last_loss}')
            running_loss = 0.
            window_batches = 0

    # Without this, loaders with < 10 batches always reported 0.0 and trailing batches were dropped.
    if window_batches:
        last_loss = running_loss / window_batches

    print("Class totals:", class_totals)
    class_accuracies = {}
    for class_name in class_labels.keys():
        if class_totals[class_name] == 0:
            class_accuracies[class_name] = None  # No data for this class
        else:
            class_accuracies[class_name] = class_rights[class_name] / class_totals[class_name]

    overall_accuracy = right_total / total if total > 0 else None

    return (last_loss, overall_accuracy, right_total,
            *[class_accuracies[class_name] for class_name in class_labels.keys()])


In [17]:
import numpy as np

class CAGA(nn.Module):
    """
    Multi-head, multi-dilation spatial self-attention block with a residual connection.

    Pipeline as implemented in ``forward``:
      1. ``get_begin`` maps the input to ``heads * head_dim`` channels, which are split
         into ``heads`` chunks of ``head_dim`` channels.
      2. Each chunk goes through one ``get_qkv`` branch per dilation: a 3x3 depthwise
         conv with that dilation and no padding, then a 1x1 conv to ``3 * dim`` channels.
      3. Each branch output is split into q, k, v and passed through ``attention``
         (softmax attention over spatial positions). The result is bilinearly resized
         back to the input's H x W.
      4. Within a head, each dilation's attention output is resized, passed through
         ``mix`` and added to the next dilation's q/k/v tensor (the "cascade").
      5. All outputs are concatenated, projected back to ``in_channels`` by ``proj``,
         added to the input, and passed through BatchNorm.

    NOTE(review): ``forward`` builds ``all_heads_after_op`` as ``[[]] * self.heads``, so
    every head entry is the SAME list. Consequences:
      - ``all_heads_after_op[i][j]`` always reads head 0's branch outputs.
      - The ``mix`` cascade accumulates into the same entry once per head.
      - The head-to-head ``convert_to_headdim`` step only appends entries that are never read.
      - The ``[j + 1]`` look-ahead at the last dilation reads another head's entry
        instead of raising IndexError; with ``heads == 1`` it would raise.
    Trained models depend on this behaviour, so it is documented here rather than changed.

    Constructing an instance reseeds torch/numpy/random and forces deterministic
    cuDNN process-wide (``_set_global_seeds``).

    Uses names not defined in this cell (presumably provided by ``%run model.ipynb``):
    ``DepthWiseSeperableConvLayer``, ``F`` (torch.nn.functional) and ``torchvision``.

      Attributes:
          heads: Number of head chunks produced from ``get_begin``'s output.
          dim: Channels of each of Q, K and V.
          scale: ``dim ** -0.5``, multiplied into Q before the attention product.
          head_dim: Channels per head chunk.
          dilation: Tuple of dilation rates; one ``get_qkv`` branch per entry.
          total_layer: Constant 4; not read anywhere in this class.
          get_begin: ``DepthWiseSeperableConvLayer`` from ``in_channels`` to ``heads * head_dim``.
          get_qkv: Per-dilation branches (3x3 depthwise dilated conv, then 1x1 conv to ``3 * dim``).
          convert_to_headdim: 1x1 conv ``len(dilation) * dim`` -> ``head_dim``; its output never
              reaches the returned tensor (see NOTE above).
          mix: 1x1 conv ``dim`` -> ``3 * dim`` followed by ReLU; used for the dilation cascade.
          proj: 1x1 conv ``heads * dim * len(dilation)`` -> ``in_channels``.
          norm: BatchNorm2d over ``in_channels``, applied after the residual add.
      """
    def __init__(self,
            in_channels,
            activation,
            heads = 3,
            dim = 8,
            expand_ratio = 4,
            head_dim = 16,
            dilation = (1,2),
            random_seed = 82
            ):
        """
        Args:
            in_channels: Channels of the input feature map (also the output channels).
            activation: Accepted for interface compatibility; not used.
            heads, dim, head_dim, dilation: See the class docstring.
            expand_ratio: Accepted but not used.
            random_seed: Passed to ``_set_global_seeds`` before any layer is created.
        """

        # Set global random seeds for reproducibility.
        # Called before super().__init__(); this is safe only because it assigns no attributes on self.
        self._set_global_seeds(random_seed)

        super(CAGA, self).__init__()
        self.heads = heads
        self.dim = dim
        scale = dim  # unused local
        self.scale = dim ** -0.5
        self.head_dim = head_dim
        self.dilation = dilation

        self.total_layer = 4  # not read anywhere in this class

        # Reproducible layer initialization
        self.get_begin = self._init_depthwise_separable_conv(
            DepthWiseSeperableConvLayer(in_channels, self.heads * self.head_dim)
        )

        # One branch per dilation: depthwise 3x3 (padding 0, so spatial size shrinks by
        # 2 * dilation), then pointwise to 3 * dim channels (q, k, v stacked).
        self.get_qkv = nn.ModuleList([
            nn.Sequential(
                self._init_conv(nn.Conv2d(
                    self.head_dim,
                    self.head_dim,
                    3,
                    groups=self.head_dim,
                    dilation=di
                )),
                self._init_conv(nn.Conv2d(self.head_dim, 3 * self.dim, 1, groups=1))
            )
            for di in dilation
        ])

        self.convert_to_headdim = self._init_conv(
            nn.Conv2d(len(dilation) * self.dim, self.head_dim, 1)
        )

        # Maps a dim-channel attention output to 3 * dim channels so it can be added to the next q/k/v tensor.
        self.mix = nn.Sequential(
            self._init_conv(nn.Conv2d(
                self.dim,
                self.dim * 3,
                1
            )),
            nn.ReLU()
        )

        self.proj = self._init_conv(
            nn.Conv2d(self.heads * self.dim * len(self.dilation), in_channels, 1)
        )

        # Deterministic BatchNorm
        self.norm = nn.BatchNorm2d(num_features=in_channels, affine=True)
        nn.init.constant_(self.norm.weight, 1)
        nn.init.constant_(self.norm.bias, 0)

    def _set_global_seeds(self, seed):
        """Seed torch (CPU and all CUDA devices), numpy and ``random`` with ``seed``.

        Also forces deterministic, non-benchmark cuDNN. All of these are process-wide
        side effects.
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)

    def _init_conv(self, conv_layer):
        """Apply Xavier-uniform weights and zero bias to ``conv_layer`` and return it."""
        nn.init.xavier_uniform_(conv_layer.weight)
        if conv_layer.bias is not None:
            nn.init.constant_(conv_layer.bias, 0)
        return conv_layer

    def _init_depthwise_separable_conv(self, conv_layer):
        """Apply Xavier-uniform weights and zero bias to every nn.Conv2d inside ``conv_layer``; return it."""
        # Assuming DepthWiseSeperableConvLayer has similar structure to standard conv
        for m in conv_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        return conv_layer
        # self.norm = nn.LayerNorm([8, 256, 14, 14])



    def attention(self,q,k,v , shape):
        """
        Softmax attention over spatial positions for one head/dilation branch.

        Args:
            q, k, v: Tensors flattened over space, (B, dim, N) with N = H' * W'
                (``forward`` calls ``flatten(2)`` before this).
            shape: Pre-flatten 4-D shape (B, C, H', W'), used to restore the spatial layout.

        Returns:
            float32 tensor of shape (B, dim, H', W').
        """

        B, C, H, W = shape



        # Attention is computed in float32 regardless of the input dtype.
        q, k, v = q.float(), k.float(), v.float()


        q = q * self.scale
        # (B, N, dim) @ (B, dim, N) -> (B, N, N) position-to-position scores
        att_map = q.transpose(-2, -1) @ k


        att_map = att_map.softmax(dim=-1)


        # (B, dim, N) @ (B, N, N) -> (B, dim, N)
        out = v @ att_map.transpose(-2, -1)

        out = out.view(B , -1 , H , W)



        return out

    def forward(self,x):
        """
        Args:
            x: Feature map unpacked as (B, C, H, W); C must equal ``in_channels`` for the residual add.

        Returns:
            ``norm(proj(concatenated head/dilation outputs) + x)``, with the same shape as ``x``.
        """
        B, C, H, W = x.shape



            # print(op)

        x_copy = x

        all_heads  = self.get_begin(x)

        # ``heads`` chunks of ``head_dim`` channels each.
        multi_layer = all_heads.split([self.head_dim]*self.heads , dim=1)

        # NOTE(review): ``[[]] * self.heads`` repeats ONE list object, so every
        # ``all_heads_after_op[j]`` is the same list. After the loop below it holds
        # heads * len(dilation) tensors in the order head0/dil0, head0/dil1, head1/dil0, ...
        # ``all_heads_after_op[i][j]`` is therefore always head 0's dilation-j entry.
        all_heads_after_op = [[]]*self.heads



        for j in range(self.heads):

            for op in self.get_qkv:

                all_heads_after_op[j].append(op(multi_layer[j]))


        all_final = []


        for i in range(self.heads):
            out_all = []
            for j in range(len(self.dilation ) ):



                # The 3 * dim branch channels split into q, k, v of dim channels each.
                q , k , v = all_heads_after_op[i][j].split([self.dim, self.dim, self.dim], dim=1)
                shape = q.shape
                q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
                out = self.attention(q , k , v , shape)



                # print(out.shape)
                # NOTE(review): for the last j this reads index len(dilation) of the shared list,
                # i.e. another head's entry. ``temp`` is then discarded, so only this lookup's
                # success matters. It would raise IndexError if heads == 1.
                shape_ahead = all_heads_after_op[i][j+1].shape[3]

                # Uses the next branch's width for both sides, i.e. assumes square feature maps.
                temp = torchvision.transforms.functional.resize(out , (shape_ahead,shape_ahead))

                # Back to the input resolution; the unpadded dilated conv made the branch smaller.
                out = F.interpolate(out, size=( H , W ), mode='bilinear')
                out = out.view(B, self.dim, H, W)


                if j+1 != len(self.dilation ):

                    # Cascade: add this branch's (mixed) attention output to the next dilation's q/k/v.
                    # The list is shared, so this adds into the same entry once per head.
                    temp = self.mix(temp).clone()
                    all_heads_after_op[i][j+1] = all_heads_after_op[i][j+1] + temp

                out_all.append(out)




            out_all_one = torch.cat(out_all, dim=1)
            all_final.append(out_all_one)
            if i+1 != self.heads:
                # NOTE(review): ``list += tensor`` EXTENDS the shared list with the B slices of the
                # tensor along dim 0. Those entries are never indexed, so this head-to-head step
                # does not affect the output (and convert_to_headdim gets no gradient from it).
                all_heads_after_op[i+1] += self.convert_to_headdim(out_all_one)




        # Every branch output was already bilinearly resized to (H, W) above.
        all_concat = torch.cat(all_final, dim=1)


        # Residual add (proj output must match x's shape), then BatchNorm. TODO: try concat instead of add.
        x_final = self.proj(all_concat) + x_copy # try oncat later

        return self.norm (x_final)

In [18]:
import numpy as np

class CAGA_nocascading_in_CAA(nn.Module):
    """
    Ablation variant of ``CAGA`` with the within-head dilation cascade disabled.

    It is identical to ``CAGA`` except that, in ``forward``, the ``mix`` output is no
    longer added to the next dilation's q/k/v tensor (that code is commented out).
    ``mix`` is still constructed but never used, and ``temp``/``shape_ahead`` are
    computed and discarded.

    NOTE(review): ``forward`` builds ``all_heads_after_op`` as ``[[]] * self.heads``, so
    every head entry is the SAME list. ``all_heads_after_op[i][j]`` always reads head 0's
    branch outputs, and in this variant nothing modifies those entries. As a result
    every head produces an identical tensor, and ``proj`` receives ``heads`` identical
    copies. The head-to-head ``convert_to_headdim`` step only appends entries that are
    never read. ``heads == 1`` would raise IndexError at the ``[j + 1]`` look-ahead.
    Trained models depend on this behaviour, so it is documented rather than changed.

    Constructing an instance reseeds torch/numpy/random and forces deterministic
    cuDNN process-wide (``_set_global_seeds``).

    Uses names not defined in this cell (presumably provided by ``%run model.ipynb``):
    ``DepthWiseSeperableConvLayer``, ``F`` (torch.nn.functional) and ``torchvision``.

      Attributes:
          heads: Number of head chunks produced from ``get_begin``'s output.
          dim: Channels of each of Q, K and V.
          scale: ``dim ** -0.5``, multiplied into Q before the attention product.
          head_dim: Channels per head chunk.
          dilation: Tuple of dilation rates; one ``get_qkv`` branch per entry.
          total_layer: Constant 4; not read anywhere in this class.
          get_begin: ``DepthWiseSeperableConvLayer`` from ``in_channels`` to ``heads * head_dim``.
          get_qkv: Per-dilation branches (3x3 depthwise dilated conv, then 1x1 conv to ``3 * dim``).
          convert_to_headdim: 1x1 conv ``len(dilation) * dim`` -> ``head_dim``; output never
              reaches the returned tensor.
          mix: 1x1 conv ``dim`` -> ``3 * dim`` followed by ReLU; unused in this variant.
          proj: 1x1 conv ``heads * dim * len(dilation)`` -> ``in_channels``.
          norm: BatchNorm2d over ``in_channels``, applied after the residual add.
      """
    def __init__(self,
            in_channels,
            activation,
            heads = 3,
            dim = 8,
            expand_ratio = 4,
            head_dim = 16,
            dilation = (1,2),
            random_seed = 82
            ):
        """
        Args:
            in_channels: Channels of the input feature map (also the output channels).
            activation: Accepted for interface compatibility; not used.
            heads, dim, head_dim, dilation: See the class docstring.
            expand_ratio: Accepted but not used.
            random_seed: Passed to ``_set_global_seeds`` before any layer is created.
        """

        # Set global random seeds for reproducibility.
        # Called before super().__init__(); this is safe only because it assigns no attributes on self.
        self._set_global_seeds(random_seed)

        super(CAGA_nocascading_in_CAA, self).__init__()
        self.heads = heads
        self.dim = dim
        scale = dim  # unused local
        self.scale = dim ** -0.5
        self.head_dim = head_dim
        self.dilation = dilation

        self.total_layer = 4  # not read anywhere in this class

        # Reproducible layer initialization
        self.get_begin = self._init_depthwise_separable_conv(
            DepthWiseSeperableConvLayer(in_channels, self.heads * self.head_dim)
        )

        # One branch per dilation: depthwise 3x3 (padding 0), then pointwise to 3 * dim channels.
        self.get_qkv = nn.ModuleList([
            nn.Sequential(
                self._init_conv(nn.Conv2d(
                    self.head_dim,
                    self.head_dim,
                    3,
                    groups=self.head_dim,
                    dilation=di
                )),
                self._init_conv(nn.Conv2d(self.head_dim, 3 * self.dim, 1, groups=1))
            )
            for di in dilation
        ])

        self.convert_to_headdim = self._init_conv(
            nn.Conv2d(len(dilation) * self.dim, self.head_dim, 1)
        )

        # Kept so the parameter layout matches CAGA; not used by this variant's forward.
        self.mix = nn.Sequential(
            self._init_conv(nn.Conv2d(
                self.dim,
                self.dim * 3,
                1
            )),
            nn.ReLU()
        )

        self.proj = self._init_conv(
            nn.Conv2d(self.heads * self.dim * len(self.dilation), in_channels, 1)
        )

        # Deterministic BatchNorm
        self.norm = nn.BatchNorm2d(num_features=in_channels, affine=True)
        nn.init.constant_(self.norm.weight, 1)
        nn.init.constant_(self.norm.bias, 0)

    def _set_global_seeds(self, seed):
        """Seed torch (CPU and all CUDA devices), numpy and ``random`` with ``seed``.

        Also forces deterministic, non-benchmark cuDNN. All of these are process-wide
        side effects.
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)

    def _init_conv(self, conv_layer):
        """Apply Xavier-uniform weights and zero bias to ``conv_layer`` and return it."""
        nn.init.xavier_uniform_(conv_layer.weight)
        if conv_layer.bias is not None:
            nn.init.constant_(conv_layer.bias, 0)
        return conv_layer

    def _init_depthwise_separable_conv(self, conv_layer):
        """Apply Xavier-uniform weights and zero bias to every nn.Conv2d inside ``conv_layer``; return it."""
        # Assuming DepthWiseSeperableConvLayer has similar structure to standard conv
        for m in conv_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        return conv_layer
        # self.norm = nn.LayerNorm([8, 256, 14, 14])



    def attention(self,q,k,v , shape):
        """
        Softmax attention over spatial positions for one head/dilation branch.

        Args:
            q, k, v: Tensors flattened over space, (B, dim, N) with N = H' * W'.
            shape: Pre-flatten 4-D shape (B, C, H', W'), used to restore the spatial layout.

        Returns:
            float32 tensor of shape (B, dim, H', W').
        """

        B, C, H, W = shape



        # Attention is computed in float32 regardless of the input dtype.
        q, k, v = q.float(), k.float(), v.float()


        q = q * self.scale
        # (B, N, dim) @ (B, dim, N) -> (B, N, N) position-to-position scores
        att_map = q.transpose(-2, -1) @ k


        att_map = att_map.softmax(dim=-1)


        # (B, dim, N) @ (B, N, N) -> (B, dim, N)
        out = v @ att_map.transpose(-2, -1)

        out = out.view(B , -1 , H , W)



        return out

    def forward(self,x):
        """
        Args:
            x: Feature map unpacked as (B, C, H, W); C must equal ``in_channels`` for the residual add.

        Returns:
            ``norm(proj(concatenated head/dilation outputs) + x)``, with the same shape as ``x``.
        """
        B, C, H, W = x.shape



            # print(op)

        x_copy = x

        all_heads  = self.get_begin(x)

        # ``heads`` chunks of ``head_dim`` channels each.
        multi_layer = all_heads.split([self.head_dim]*self.heads , dim=1)

        # NOTE(review): ``[[]] * self.heads`` repeats ONE list object; every head therefore
        # reads head 0's branch outputs below (see class docstring).
        all_heads_after_op = [[]]*self.heads



        for j in range(self.heads):

            for op in self.get_qkv:

                all_heads_after_op[j].append(op(multi_layer[j]))


        all_final = []


        for i in range(self.heads):
            out_all = []
            for j in range(len(self.dilation ) ):



                q , k , v = all_heads_after_op[i][j].split([self.dim, self.dim, self.dim], dim=1)
                shape = q.shape
                q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
                out = self.attention(q , k , v , shape)



                # print(out.shape)
                # Left over from CAGA: shape_ahead/temp are computed but unused here, since the cascade is disabled.
                shape_ahead = all_heads_after_op[i][j+1].shape[3]

                temp = torchvision.transforms.functional.resize(out , (shape_ahead,shape_ahead))

                # Back to the input resolution; the unpadded dilated conv made the branch smaller.
                out = F.interpolate(out, size=( H , W ), mode='bilinear')
                out = out.view(B, self.dim, H, W)


                # Cascade disabled for this ablation:
               # if j+1 != len(self.dilation ):

                #    temp = self.mix(temp).clone()
                 #   all_heads_after_op[i][j+1] = all_heads_after_op[i][j+1] + temp

                out_all.append(out)




            out_all_one = torch.cat(out_all, dim=1)
            all_final.append(out_all_one)
            if i+1 != self.heads:
                # NOTE(review): ``list += tensor`` EXTENDS the shared list with the tensor's dim-0
                # slices. Those entries are never read, so this has no effect on the output.
                all_heads_after_op[i+1] += self.convert_to_headdim(out_all_one)




        # Every branch output was already bilinearly resized to (H, W) above.
        all_concat = torch.cat(all_final, dim=1)


        # Residual add (proj output must match x's shape), then BatchNorm. TODO: try concat instead of add.
        x_final = self.proj(all_concat) + x_copy # try oncat later

        return self.norm (x_final)

In [19]:
class EarlyStopper:
    """
    Decide when to stop training from a stream of validation accuracies.

    Two regimes:

    * Normal: a new best accuracy resets the counter. An accuracy at least
      ``min_delta`` below the best increments it, and training stops once it
      reaches ``patience``. Accuracies in between leave the counter unchanged.
    * High-accuracy: the first call with accuracy above ``acc_threshold`` resets
      the counter and switches regimes. From then on every call increments the
      counter, whatever the accuracy, and training stops once it reaches
      ``perfect_val_count``.

    ``early_stop`` returns ``(should_stop, counter)``.
    """

    def __init__(self, patience=1, min_delta=0, acc_threshold=0.97, perfect_val_count=6):
        """
        Args:
            patience: Non-improving epochs tolerated in the normal regime.
            min_delta: Drop below the best accuracy that counts as "not improving".
            acc_threshold: Accuracy above which the high-accuracy regime starts.
            perfect_val_count: Epochs to keep training once the high-accuracy regime starts.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # Consecutive epochs counted towards stopping
        self.max_validation_acc = float('-inf')  # Best validation accuracy seen in the normal regime
        self.val_acc_1 = 0  # 0/1 flag: set to 1 once accuracy first exceeds acc_threshold
        self.acc_threshold = acc_threshold
        self.perfect_val_count = perfect_val_count

    def early_stop(self, validation_acc):
        """Update state with one epoch's validation accuracy; return ``(should_stop, counter)``."""
        if validation_acc > self.acc_threshold and self.val_acc_1 == 0:
            # Entering the high-accuracy regime.
            self.counter = 0
            self.val_acc_1 = 1

        elif self.val_acc_1:
            # High-accuracy regime: count epochs unconditionally.
            self.counter += 1
            if self.counter >= self.perfect_val_count:
                return (True, self.counter)

        else:
            if validation_acc > self.max_validation_acc:
                self.max_validation_acc = validation_acc
                self.counter = 0
            elif validation_acc <= (self.max_validation_acc - self.min_delta):
                print("max_acc", self.max_validation_acc)  # Debug print statement
                self.counter += 1
                if self.counter >= self.patience:
                    return (True, self.counter)

        return (False, self.counter)


In [20]:
class FocalCELoss(nn.Module):
    """
    Focal cross-entropy loss with optional per-class weights.

    Per sample: ``loss = alpha_t * (1 - p_t) ** gamma * CE``, where ``CE`` and
    ``p_t = exp(-CE)`` are computed WITHOUT class weights. For class-index or one-hot
    targets, ``p_t`` is then the predicted probability of the true class. The class
    weight ``alpha_t`` is applied afterwards.

    Previously ``alpha`` was passed as ``weight=`` to ``cross_entropy``. That made
    ``p_t`` equal to ``p ** alpha_t``, which distorted the focusing term.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean', num_classes=4, device=None):
        """
        Args:
            alpha (sequence or torch.Tensor, optional): Weight per class, length ``num_classes``.
                Integer values are converted to float.
            gamma (float): Focusing parameter; larger values down-weight well-classified samples.
            reduction (str): 'mean', 'sum', or anything else for per-sample losses ('none').
            num_classes (int): Expected number of classes (checked against ``alpha``).
            device (torch.device, optional): Device for ``alpha`` and for the inputs in
                ``forward``. Defaults to CUDA when available, else CPU.
        """
        super(FocalCELoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if alpha is not None:
            alpha = torch.as_tensor(alpha)
            if not alpha.is_floating_point():
                # torch.tensor([1, 2, ...]) would be int64, which cannot weight float losses.
                alpha = alpha.float()
            assert len(alpha) == num_classes, "Alpha size must match the number of classes"
            self.alpha = alpha.to(self.device)
        else:
            self.alpha = None  # no class weighting

    def forward(self, inputs, targets):
        """
        Compute the focal cross-entropy loss.

        Args:
            inputs: Logits of shape (N, C) (or (N, C, d1, ...)).
            targets: Either class indices of shape (N,) (or (N, d1, ...)), or class
                probabilities / one-hot vectors with the same shape as ``inputs``.
                This follows ``torch.nn.functional.cross_entropy``'s convention.

        Returns:
            torch.Tensor: Scalar for 'mean' / 'sum', otherwise the per-sample losses.
        """
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        # Unweighted CE so that exp(-CE) is the true-class probability.
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            alpha = self.alpha.to(dtype=ce_loss.dtype)
            if targets.shape == inputs.shape:
                # Probability / one-hot targets: expected class weight under the target
                # distribution (equals alpha[true class] for one-hot targets).
                class_dim = 1 if inputs.dim() > 1 else 0
                alpha_view = alpha.view([-1] + [1] * (inputs.dim() - class_dim - 1))
                alpha_t = (targets * alpha_view).sum(dim=class_dim)
            else:
                # Class-index targets.
                alpha_t = alpha[targets]
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:  # 'none'
            return focal_loss


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
import random

class CAA(nn.Module):
    """
    Cascaded Atrous Attention (CAA) module for the ablation study.

    Data flow:

    1. The input is projected to ``heads * head_dim`` channels and split into one chunk per head.
    2. Each head applies a 3x3 depthwise conv (its own dilation, no padding). It then applies
       a 1x1 conv producing q, k and v with ``dim`` channels each.
    3. Spatial self-attention follows, and the result is bilinearly resized back to the
       input's H x W.
    4. Heads are cascaded: from the second head on, the previous head's output is projected
       back to ``head_dim`` channels and added to that head's input chunk.
    5. The concatenated head outputs are projected to ``in_channels`` and added to the input
       as a residual. The sum is batch-normalized.
    """

    def __init__(self,
            in_channels,
            activation,
            heads=3,  # Number of attention heads, set to 3 for experimentation
            dim=8,  # Dimensionality of query, key, and value
            expand_ratio=4,
            head_dim=16,  # Dimensionality of each attention head
            dilation = (1, 2, None),
            random_seed=82  # Seed for reproducibility
            ):
        """
        Initialize the CAA module.

        Args:
            in_channels (int): Number of input (and output) channels.
            activation (nn.Module): Accepted for interface compatibility; not used here.
            heads (int): Number of cascaded attention heads.
            dim (int): Channels of each of the query, key and value tensors.
            expand_ratio (int): Accepted for interface compatibility; not used here.
            head_dim (int): Channels of each head's input chunk.
            dilation (sequence): Dilation of each head's 3x3 depthwise conv, indexed by head.
                ``None`` means 1. It must have at least ``heads`` entries.
            random_seed (int): Seed applied to the global torch/CUDA/NumPy/random RNGs
                before the layers are created (this also makes cuDNN deterministic).
        """

        # Seed globally so the layer initializations below are reproducible.
        self._set_global_seeds(random_seed)

        super(CAA, self).__init__()
        self.heads = heads
        self.dim = dim
        self.scale = dim ** -0.5
        self.head_dim = head_dim

        # Per-head dilation for the depthwise conv; the default (1, 2, None) gives 1, 2, 1.
        self.dilations = dilation
        self.total_layer = 3  # not used within this class; kept for compatibility

        # Initial depthwise separable convolution: in_channels -> heads * head_dim.
        self.get_begin = self._init_depthwise_separable_conv(
            DepthWiseSeperableConvLayer(in_channels, self.heads * self.head_dim)
        )

        # Per head: 3x3 depthwise conv (no padding, so H/W shrink by 2 * dilation),
        # then a 1x1 conv producing q, k and v stacked along the channel axis.
        self.get_qkv = nn.ModuleList([
            nn.Sequential(
                self._init_conv(nn.Conv2d(
                    self.head_dim,
                    self.head_dim,
                    3,
                    groups=self.head_dim,
                    dilation=self.dilations[i] if self.dilations[i] is not None else 1
                )),
                self._init_conv(nn.Conv2d(self.head_dim, 3 * self.dim, 1, groups=1))
            )
            for i in range(self.heads)
        ])

        # Projections feeding head i-1's output (dim channels) into head i (head_dim channels).
        self.head_projections = nn.ModuleList([
            self._init_conv(nn.Conv2d(self.dim, self.head_dim, 1))
            for _ in range(self.heads - 1)
        ])

        # Final projection back to the input channel count.
        self.proj = self._init_conv(
            nn.Conv2d(self.heads * self.dim, in_channels, 1)
        )

        # Batch normalization applied after the residual addition.
        self.norm = nn.BatchNorm2d(num_features=in_channels, affine=True)
        nn.init.constant_(self.norm.weight, 1)
        nn.init.constant_(self.norm.bias, 0)

    def _set_global_seeds(self, seed):
        """Set seeds for reproducibility across libraries (global side effect)."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)

    def _init_conv(self, conv_layer):
        """Xavier-uniform initialize a conv layer's weight and zero its bias; returns the layer."""
        nn.init.xavier_uniform_(conv_layer.weight)
        if conv_layer.bias is not None:
            nn.init.constant_(conv_layer.bias, 0)
        return conv_layer

    def _init_depthwise_separable_conv(self, conv_layer):
        """Apply the same Xavier/zero-bias initialization to every Conv2d inside ``conv_layer``."""
        for m in conv_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        return conv_layer

    def attention(self, q, k, v, shape):
        """
        Spatial self-attention over flattened positions.

        Args:
            q, k, v: Tensors of shape (B, dim, N), where N = H * W of ``shape``.
            shape: Unflattened shape (B, C, H, W) of the query.

        Returns:
            Tensor of shape (B, dim, H, W) in the dtype of ``q``.
        """
        B, C, H, W = shape
        out_dtype = q.dtype

        # Compute in float32 for a numerically stable softmax.
        q, k, v = q.float(), k.float(), v.float()
        q = q * self.scale

        # (B, N, N) attention over spatial positions.
        att_map = (q.transpose(-2, -1) @ k).softmax(dim=-1)

        out = v @ att_map.transpose(-2, -1)
        # Restore the caller's dtype so later convs see the dtype they expect.
        # This is a no-op for float32 inputs and under autocast.
        return out.view(B, -1, H, W).to(out_dtype)

    def forward(self, x):
        """
        Forward pass of the CAA module.

        Args:
            x: Input tensor of shape (B, C, H, W) with C == in_channels.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, C, H, W = x.shape

        head_inputs = self.get_begin(x).split([self.head_dim] * self.heads, dim=1)

        head_outputs = []
        for i in range(self.heads):
            head_in = head_inputs[i]
            if i > 0:
                # Cascade: inject the previous head's output into this head's input.
                head_in = head_in + self.head_projections[i - 1](head_outputs[-1])

            q, k, v = self.get_qkv[i](head_in).split([self.dim, self.dim, self.dim], dim=1)
            shape = q.shape
            head_out = self.attention(q.flatten(2), k.flatten(2), v.flatten(2), shape)
            # The unpadded dilated conv shrank H/W; resize back to the input resolution.
            head_outputs.append(F.interpolate(head_out, size=(H, W), mode='bilinear'))

        # Project the concatenated heads, add the residual, normalize.
        return self.norm(self.proj(torch.cat(head_outputs, dim=1)) + x)


In [23]:
def analyze_model(model, input_size):
    """
    Report parameter counts and compute cost of a model (best effort).

    Parameters:
    - model (nn.Module): PyTorch model to analyze. It is profiled in eval mode, and its
      previous train/eval mode is restored afterwards (as a single top-level flag).
    - input_size (tuple): Input tensor size, e.g. (batch_size, channels, height, width).

    Returns:
    - dict with "Total Parameters", "Trainable Parameters", "FLOPs" and "MACs", or
      None if profiling fails (the error is printed).

    Notes:
    ``thop.profile`` returns ``(macs, params)``. "MACs" is therefore its first return
    value, and "FLOPs" uses the common convention FLOPs = 2 * MACs. The earlier version
    reported MACs as FLOPs and the parameter count as MACs.
    """
    # Build the dummy input on the model's device (CPU for a parameter-less model).
    # It is drawn on CPU and then moved, as before, so the CPU RNG stream is unchanged.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    input_tensor = torch.randn(input_size).to(device)

    # 1. Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # 2. MACs / FLOPs. Profile in eval mode so a forward pass does not update
    # BatchNorm running statistics.
    was_training = model.training
    model.eval()
    try:
        import thop  # local import: the notebook only imports thop in a later cell
        with torch.no_grad():
            macs, _ = thop.profile(model, inputs=(input_tensor,), verbose=False)
    except Exception as e:
        # Best effort: analysis must never abort the caller.
        print(f"Error in analysis: {e}")
        return None
    finally:
        model.train(was_training)

    return {
        "Total Parameters": total_params,
        "Trainable Parameters": trainable_params,
        "FLOPs": 2 * macs,
        "MACs": macs,
    }



In [24]:
import os
import torch
import numpy as np
from datetime import datetime
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold
import csv
import timm
import random
import thop

def set_seed(seed):
    """
    Seed every RNG used by the training pipeline.

    Seeds PyTorch's CPU generator, NumPy's global generator and Python's ``random``
    module. When CUDA is available, it also seeds all CUDA devices and forces cuDNN into
    deterministic, non-autotuned mode.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _write_csv_row(path, row, mode='a'):
    """Write one CSV row to ``path`` ('w' truncates the file, 'a' appends)."""
    with open(path, mode, newline='') as csvfile:
        csv.writer(csvfile).writerow(row)


def _make_loader(X, y, batch_size, shuffle):
    """Wrap (X, y) in a DataLoader whose tensors are moved to the global ``device`` up front."""
    return DataLoader(
        TensorDataset(torch.Tensor(X).to(device), torch.Tensor(y).to(device)),
        batch_size=batch_size,
        shuffle=shuffle,
    )


def _attach_attention_module(model, Module, hyperparams):
    """Swap in the attention module named by ``Module``; any other value leaves the backbone untouched."""
    if Module == 'CAGA':
        replace_context_modules(model, CAGA, dilation=hyperparams['dilation'])
    elif Module == 'CAA':
        # This branch used to pass CAGA, so 'CAA' ablation runs silently trained CAGA.
        replace_context_modules(model, CAA, dilation=hyperparams['dilation'])
    elif Module == 'CAGA_nocascading_in_CAA':
        replace_context_modules(model, CAGA_nocascading_in_CAA, dilation=hyperparams['dilation'])


def _build_loss_fn(loss_name, alpha):
    """Return the primary loss for ``loss_name`` ('BCE', 'CE' or 'FCE'; 'FCE' uses ``alpha``)."""
    if loss_name == 'BCE':
        return torch.nn.BCELoss()
    if loss_name == 'CE':
        return torch.nn.CrossEntropyLoss()
    if loss_name == 'FCE':
        return FocalCELoss(alpha=alpha)
    raise ValueError(f"Unknown loss {loss_name!r}; expected 'BCE', 'CE' or 'FCE'")


def _validate_epoch(model, validation_loader, loss_fn, loss_fn1, w):
    """
    Run one pass over the validation set in eval mode without gradients.

    Returns:
    tuple: (average batch loss, number of correct predictions, number of samples,
            {class_name: accuracy or None when the class has no samples}).
    """
    class_names = list(class_labels.keys())
    running_vloss = 0.0
    n_batches = 0
    vright_total = 0
    total_val = 0
    class_totals_val = dict.fromkeys(class_names, 0)
    class_rights_val = dict.fromkeys(class_names, 0)

    model.eval()
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            n_batches += 1
            total_val += vinputs.shape[0]

            voutputs = model(vinputs)
            vloss = (1 - w) * loss_fn(voutputs, vlabels) + w * loss_fn1(voutputs, vlabels)

            class_totals_batch = cal_total(vlabels)
            vright, *class_rights_batch = acc_eval(voutputs, vlabels, classwise=True)
            vright_total += vright

            for j, class_name in enumerate(class_names):
                class_totals_val[class_name] += class_totals_batch[j]
                class_rights_val[class_name] += class_rights_batch[j]

            running_vloss += vloss

    avg_vloss = running_vloss / n_batches
    class_accuracies_val = {
        name: (None if class_totals_val[name] == 0 else class_rights_val[name] / class_totals_val[name])
        for name in class_names
    }
    return avg_vloss, vright_total, total_val, class_accuracies_val


def training(model_name, model_config_name, hyperparams, Module='CAGA', pre_trained=True, run=1, EPOCH=35, n_splits=10, max_attempts=8):
    """
    Cross-validate one timm backbone and log per-fold and average test metrics to CSV.

    For each KFold split, the last 20% of the train/val portion is held out for
    validation and early stopping. The fold's test split is evaluated once training
    for that fold ends. Results go to
    ``/results/run_{run}/results_{timestamp}_{model_name}.csv``.

    Parameters:
    model_name (str): Label used in logs, CSV rows and the CSV file name.
    model_config_name (str): timm model identifier passed to ``timm.create_model``.
    hyperparams (dict): Must provide 'lr', 'weight_decay', 'loss' ('BCE' | 'CE' | 'FCE'),
        'neurons1', 'neurons2' and 'n_layers'. It must also provide 'dilation' when
        ``Module`` names an attention variant. All values are appended to each CSV row.
    Module (str): 'CAGA', 'CAA' or 'CAGA_nocascading_in_CAA' swaps in that module; any
        other value trains the unmodified backbone.
    pre_trained (bool): Load pretrained timm weights (False for ablations).
    run (int): Run index used in the output directory name.
    EPOCH (int): Maximum number of epochs per fold.
    n_splits (int): Number of KFold splits.
    max_attempts (int): Unused; kept for backward compatibility.

    Raises:
    ValueError: If ``hyperparams['loss']`` is not a supported loss name.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    class_names = list(class_labels.keys())

    # Create a directory for saving CSV files.
    csv_dir = f'/results/run_{run}'
    os.makedirs(csv_dir, exist_ok=True)
    main_csv_path = os.path.join(csv_dir, f'results_{timestamp}_{model_name}.csv')

    header = (['Model', 'Seed', 'Fold', 'Accuracy', 'Precision', 'Recall', 'F1']
              + [f'{class_name} Accuracy' for class_name in class_names]
              + list(hyperparams.keys()))
    _write_csv_row(main_csv_path, header, mode='w')

    # Load data. allow_pickle is only acceptable because these are trusted local files.
    X_data = torch.Tensor(np.load('/content/drive/MyDrive/X_train_final_multi_10_folds_40_each_equal.npy', allow_pickle=True))
    y_data = torch.Tensor(np.load('/content/drive/MyDrive/y_train_final_multi_10_folds_40_each__equal.npy', allow_pickle=True))

    # Merge the first two axes (presumably folds x samples, per the file name) so KFold
    # can re-split the pooled data.
    X_reshaped = X_data.reshape(X_data.shape[0] * X_data.shape[1], *X_data.shape[2:])
    y_reshaped = y_data.reshape(y_data.shape[0] * y_data.shape[1], *y_data.shape[2:])

    loss_dict_train = {}  # handed to train_one_epoch for every fold
    accuracy_avg = precision_avg = recall_avg = f1_avg = 0
    class_total_test_acc = {class_name: 0 for class_name in class_names}

    # One fixed seed controls the entire cross-validation.
    current_seed = 59
    set_seed(current_seed)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=current_seed)

    for fold, (train_val_index, test_index) in enumerate(kf.split(X_reshaped)):
        print(f"\nStarting Fold {fold + 1}/{n_splits}")
        epoch_number = 0
        print(fold)

        # Split into train_val and test.
        X_train_val, X_test = X_reshaped[train_val_index], X_reshaped[test_index]
        y_train_val, y_test = y_reshaped[train_val_index], y_reshaped[test_index]

        # Hold out the last 20% of train_val for validation.
        val_size = int(len(X_train_val) * 0.2)
        X_train, X_val = X_train_val[:-val_size], X_train_val[-val_size:]
        y_train, y_val = y_train_val[:-val_size], y_train_val[-val_size:]

        X_train, X_val, X_test = norm(X_train, X_val, X_test)

        print("Xtrain", X_train.shape)
        print("Xval", X_val.shape)
        print("Xtest", X_test.shape)

        training_loader = _make_loader(X_train, y_train, batch_size=16, shuffle=True)
        # Shuffling validation does not affect its metrics, but it advances the global
        # torch RNG each epoch. It is kept so seeded runs reproduce earlier results.
        validation_loader = _make_loader(X_val, y_val, batch_size=16, shuffle=True)
        test_loader = _make_loader(X_test, y_test, batch_size=8, shuffle=False)

        # Model setup (pretrained=False is used for ablations).
        model = timm.create_model(model_config_name, pretrained=pre_trained)
        _attach_attention_module(model, Module, hyperparams)
        change_classifier(model, model_config_name, dropout=0.111975, neurons1=hyperparams['neurons1'], neurons2=hyperparams['neurons2'], neurons3=768, neurons4=512, n_layers=hyperparams['n_layers'])
        model.to(device)

        print(model)

        # Per-class weights, only consumed by the 'FCE' loss.
        alpha = torch.Tensor([0.1412, 0.0992, 0.1409, 0.0872]).to(device)
        loss_fn = torch.nn.BCELoss()
        loss_fn1 = _build_loss_fn(hyperparams['loss'], alpha)
        # Validation loss is (1 - w) * loss_fn + w * loss_fn1, so w = 1 uses loss_fn1 alone.
        w = 1
        optimizer = torch.optim.AdamW(params=model.parameters(), lr=hyperparams['lr'], weight_decay=hyperparams['weight_decay'])
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)
        early_stopper = EarlyStopper(patience=9, min_delta=0)

        print(f'Model {fold}:')

        for _ in range(EPOCH):
            print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
            print(f'EPOCH {epoch_number + 1}:')

            model.train(True)
            avg_loss, train_accuracy, right, *class_accuracies_train = train_one_epoch(model, epoch_number, fold, training_loader, loss_fn, loss_fn1, w, optimizer, loss_dict_train)
            print(f"Average loss: {avg_loss}")

            avg_vloss, vright_total, total_val, class_accuracies_val = _validate_epoch(
                model, validation_loader, loss_fn, loss_fn1, w)
            print(f"Validation loss: {avg_vloss}")

            scheduler.step()

            val_acc = vright_total / total_val

            print(f"Validation accuracy: {val_acc}")
            print(f'LOSS train {avg_loss} valid {avg_vloss}')
            print(f'Right train {right} valid {vright_total}')
            print(f'Accuracy train {train_accuracy} valid {val_acc}')

            print('---------->classwise<-----------')
            for idx, class_name in enumerate(class_names):
                print(f'Accuracy {class_name} train {class_accuracies_train[idx]} valid {class_accuracies_val[class_name]}')

            epoch_number += 1
            should_stop, stop_count = early_stopper.early_stop(val_acc)
            print("Current early_stop count", stop_count)
            if should_stop:
                print('Early stopping triggered')
                break

        # Evaluate on the held-out test split of this fold.
        with torch.no_grad():
            acc, precision, recall, f1, *class_accuracies_test = test(model, test_loader, fold)

        print('##########################################################')
        print(f'Accuracy test {acc}')
        print(f'Precision test {precision}')
        print(f'Recall test {recall}')
        print(f'F1 test {f1}')

        print('---------->classwise test<-----------')
        for idx, class_name in enumerate(class_names):
            print(f'Accuracy {class_name} test {class_accuracies_test[idx]}')

        # Save this fold's results (with seed and hyperparameters).
        row = ([model_name, current_seed, fold, acc, precision, recall, f1]
               + class_accuracies_test
               + list(hyperparams.values()))
        _write_csv_row(main_csv_path, row)

        accuracy_avg += acc
        precision_avg += precision
        recall_avg += recall
        f1_avg += f1
        for idx, class_name in enumerate(class_names):
            class_total_test_acc[class_name] += class_accuracies_test[idx]

    print(f"\n=== Successfully completed all folds with seed {current_seed} ===")

    # Final averages over all folds.
    accuracy_avg /= n_splits
    precision_avg /= n_splits
    recall_avg /= n_splits
    f1_avg /= n_splits
    for class_name in class_names:
        class_total_test_acc[class_name] /= n_splits

    print(f'Average results for {model_name}:')
    print(f'Average Accuracy test {accuracy_avg}')
    print(f'Average precision test {precision_avg}')
    print(f'Average recall test {recall_avg}')
    print(f'Average f1 test {f1_avg}')

    print('---------->classwise test Average<-----------')
    for class_name, avg_acc in class_total_test_acc.items():
        print(f'Average {class_name} Accuracy test {avg_acc}')

    # Save the average results with hyperparameters.
    avg_row = ([model_name, current_seed, 'Average', accuracy_avg, precision_avg, recall_avg, f1_avg]
               + list(class_total_test_acc.values())
               + list(hyperparams.values()))
    _write_csv_row(main_csv_path, avg_row)
