# Mpox

In [4]:
import numpy as np
import torchvision.models as tvm

import torch.nn as nn
import torch
import random
from tqdm import tqdm
from torch.utils.data import DataLoader

import os

import torch.optim as to
import matplotlib.pyplot as plt




from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset

%run model.ipynb



cuda


In [5]:
# Ask PyTorch to release cached GPU memory that is not currently in use before
# the data and model are loaded.
torch.cuda.empty_cache()

In [6]:
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score

In [7]:
# timm supplies the backbones built in training(); print the installed version
# so that it is recorded alongside the results.
import timm
print(timm.__version__)

1.0.12


In [9]:
import os
from typing import Tuple, List
import numpy as np
import torch
from torch import Tensor
from sklearn.model_selection import train_test_split
from PIL import Image
import torchvision.transforms as transforms

class MedicalImagePreprocessor:
    """
    Load a two-class (mpox vs. other) image dataset from two directories and
    split it into stratified train/validation/test subsets.

    Images are resized, converted to tensors and normalised with the ImageNet
    channel statistics. Labels are returned one-hot encoded with two classes:
    index 0 = other, index 1 = mpox.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(
        self,
        mpox_dir: str,
        others_dir: str,
        image_size: Tuple[int, int] = (224, 224),
        random_seed: int = 42
    ):
        """
        Initialize preprocessor with mpox and others image directories.

        Args:
            mpox_dir (str): Path to mpox image directory
            others_dir (str): Path to other images directory
            image_size (Tuple[int, int], optional): Resize dimensions. Defaults to (224, 224)
            random_seed (int, optional): Seed passed to both stratified splits. Defaults to 42
        """
        self.mpox_dir = mpox_dir
        self.others_dir = others_dir
        self.image_size = image_size
        self.random_seed = random_seed

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # Standard ImageNet normalization
                std=[0.229, 0.224, 0.225]
            )
        ])

    def _list_images(self, directory: str) -> List[str]:
        """
        Return the full paths of the image files directly inside ``directory``.

        The listing is sorted because ``os.listdir`` returns entries in
        arbitrary order; without sorting, the seeded split would depend on the
        filesystem rather than only on ``random_seed``.
        """
        return [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
            if filename.lower().endswith(self.IMAGE_EXTENSIONS)
        ]

    def _collect_images(self) -> Tuple[List[str], List[int]]:
        """
        Collect image paths and corresponding labels from both directories.

        Returns:
            Tuple of image paths and their labels (1 = mpox, 0 = other);
            mpox images come first, then the others.
        """
        mpox_paths = self._list_images(self.mpox_dir)
        other_paths = self._list_images(self.others_dir)

        image_paths = mpox_paths + other_paths
        labels = [1] * len(mpox_paths) + [0] * len(other_paths)
        return image_paths, labels

    def _load_image(self, image_path: str) -> Tensor:
        """
        Load and preprocess a single image.

        Args:
            image_path (str): Full path to image file

        Returns:
            Tensor: Preprocessed image tensor
        """
        # Keep the context manager on the image returned by Image.open so the
        # underlying file handle is closed; convert() returns a separate copy.
        with Image.open(image_path) as img:
            return self.transform(img.convert('RGB'))

    def _load_images(self, paths) -> Tensor:
        """Load every path and stack the resulting image tensors along a new first axis."""
        return torch.stack([self._load_image(path) for path in paths])

    @staticmethod
    def _one_hot(labels) -> Tensor:
        """One-hot encode integer class labels (two classes) as a float tensor."""
        return torch.nn.functional.one_hot(torch.tensor(labels), num_classes=2).float()

    def prepare_dataset(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """
        Prepare dataset with stratified 70/20/10 train/val/test split.

        Returns:
            Tuple of PyTorch tensors: (X_train, y_train, X_val, y_val, X_test, y_test)

        Raises:
            ValueError: If neither directory contains a supported image file.
        """
        # Collect image paths and labels
        image_paths, labels = self._collect_images()

        # Validate dataset
        if not image_paths:
            raise ValueError("No valid images found in the specified directories.")

        # Numpy conversion for stratified splitting
        image_paths = np.array(image_paths)
        labels = np.array(labels)

        # First split: 70% train, 30% held out
        X_train, X_temp, y_train, y_temp = train_test_split(
            image_paths, labels,
            test_size=0.3,
            stratify=labels,
            random_state=self.random_seed
        )

        # Second split: one third of the held-out 30% becomes the test set
        # (10% overall), the remaining two thirds the validation set (20%).
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp,
            test_size=1/3,
            stratify=y_temp,
            random_state=self.random_seed
        )

        # Every image is decoded eagerly and kept in memory as one stacked tensor per split.
        return (self._load_images(X_train), self._one_hot(y_train),
                self._load_images(X_val), self._one_hot(y_val),
                self._load_images(X_test), self._one_hot(y_test))

def main():
    """
    Build the preprocessor for the augmented MSLD folders, prepare the six
    dataset tensors, print their shapes and return them.

    Returns:
        Tuple (X_train, y_train, X_val, y_val, X_test, y_test).
    """
    preprocessor = MedicalImagePreprocessor(
        '/content/MSLD/Augmented Images/Augmented Images/Monkeypox_augmented',
        '/content/MSLD/Augmented Images/Augmented Images/Others_augmented',
    )

    X_train, y_train, X_val, y_val, X_test, y_test = preprocessor.prepare_dataset()

    # Report the shape of every split, features first and labels second.
    shape_report = (
        ("Training data", X_train),
        ("Training labels", y_train),
        ("Validation data", X_val),
        ("Validation labels", y_val),
        ("Test data", X_test),
        ("Test labels", y_test),
    )
    for description, tensor in shape_report:
        print(f"{description} shape: {tensor.shape}")

    return X_train, y_train, X_val, y_val, X_test, y_test


if __name__ == '__main__':
    X_train, y_train, X_val, y_val, X_test, y_test = main()

Training data shape: torch.Size([2234, 3, 224, 224])
Training labels shape: torch.Size([2234, 2])
Validation data shape: torch.Size([638, 3, 224, 224])
Validation labels shape: torch.Size([638, 2])
Test data shape: torch.Size([320, 3, 224, 224])
Test labels shape: torch.Size([320, 2])


In [10]:
# Display the first one-hot encoded training label as a quick sanity check.
y_train[0]

tensor([1., 0.])

In [11]:
def evaluate(y_true, y_pred):
    """
    Compute macro-averaged precision, recall and F1 for a set of predictions.

    Macro averaging scores every class separately and then takes the
    unweighted mean, so each class counts equally regardless of its size.

    Parameters:
    y_true (list or array-like): True labels.
    y_pred (list or array-like): Predicted labels by the model.

    Returns:
    tuple: (precision, recall, f1), in that order.
    """
    # The three scorers share one call signature, so apply them in sequence.
    scorers = (precision_score, recall_score, f1_score)
    precision, recall, f1 = (
        scorer(y_true, y_pred, average='macro') for scorer in scorers
    )
    return precision, recall, f1


In [12]:


def replace_context_modules(model, Module, dilation):
    """
    Replace the context_module of blocks 1-6 in the model's fourth stage with a custom module.

    The model is modified in place; nothing is returned.

    Parameters:
    model: Model exposing ``stages[3].blocks[1..6]``, each block having a
           ``context_module`` whose ``main.qkv.conv.in_channels`` gives the channel count.
    Module: Callable (typically an ``nn.Module`` subclass) invoked as
            ``Module(in_channels, nn.ReLU, dilation=dilation)`` to build each replacement.
    dilation: Dilation setting forwarded unchanged to ``Module``.
    """
    # Fourth stage (index 3) of the model
    stage = model.stages[3]

    # Blocks with indices 1 to 6 are replaced; block 0 is left untouched
    for block_index in range(1, 7):
        block = stage.blocks[block_index]

        # The replacement keeps the channel count of the module it replaces
        in_channels = block.context_module.main.qkv.conv.in_channels

        # nn.ReLU is passed as a class (not an instance); the new module is
        # wrapped in a one-element nn.Sequential.
        block.context_module = nn.Sequential(
            Module(in_channels, nn.ReLU, dilation=dilation)
        )





In [13]:

def change_classifier(model, model_name, num_classes=2, dropout=0.5,
                     neurons1=4096, neurons2=1024, neurons3=256, neurons4=512, n_layers=2):
    """
    Replace the classifier head of a vision model with a fully connected stack.

    The new head is ``Flatten -> [Linear -> ReLU -> Dropout] x depth -> Linear ->
    Sigmoid/Softmax``, so the model outputs probabilities rather than logits.

    Args:
        model: The base model to modify (changed in place)
        model_name: Name/type of the model, used to look up the head's input
            width; names missing from the table silently fall back to 3072
        num_classes: Number of output classes (1 selects Sigmoid, otherwise Softmax over dim 1)
        dropout: Dropout rate
        neurons1-4: Number of neurons in each hidden layer
        n_layers: Number of hidden layers (1-4); any value other than 1, 2 or 3
            builds the 4-layer head

    Returns:
        The same model object with its head replaced.

    Raises:
        AttributeError: If the model has none of ``head``, ``fc`` or ``classifier``.
    """
    # Define input features based on model architecture
    input_features = {
        'resnet101.a1_in1k': 2048,
        'deit3_medium_patch16_224': 512,
        'coatnet_1_rw_224.sw_in1k': 768,
        'mobilenetv3_large_100.ra_in1k': 1280,
        'vit_base_patch16_224' : 768,
        'efficientvit_l1.r224_in1k': 3072
    }

    in_features = input_features.get(model_name, 3072)  # Default to 3072 if model not found

    # Hidden widths in use: the first `depth` of neurons1..neurons4.
    depth = int(n_layers) if n_layers in (1, 2, 3) else 4
    hidden_widths = [neurons1, neurons2, neurons3, neurons4][:depth]

    layers = [nn.Flatten()]
    previous_width = in_features
    for index, width in enumerate(hidden_widths):
        # In-place ReLU is used only on the first two hidden layers and never
        # on the last hidden layer, matching the heads this function has
        # always produced (so printed model summaries stay the same).
        relu_in_place = index < min(depth - 1, 2)
        layers.append(nn.Linear(previous_width, width))
        layers.append(nn.ReLU(inplace=relu_in_place))
        layers.append(nn.Dropout(dropout))
        previous_width = width

    layers.append(nn.Linear(previous_width, num_classes))
    layers.append(nn.Sigmoid() if num_classes == 1 else nn.Softmax(dim=1))
    classifier = nn.Sequential(*layers)

    # Attach the classifier where this model family keeps its head:
    # head.classifier, head.fc, head itself, fc, or classifier (in that priority).
    if hasattr(model, 'head'):
        if hasattr(model.head, 'classifier'):
            model.head.classifier = classifier
        elif hasattr(model.head, 'fc'):
            model.head.fc = classifier
        else:
            model.head = classifier
    elif hasattr(model, 'fc'):
        model.fc = classifier
    elif hasattr(model, 'classifier'):
        model.classifier = classifier
    else:
        raise AttributeError("Model structure not supported. Cannot find classifier or head attribute.")

    return model




In [14]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Bare expressions below are notebook-style displays; they have no other effect.
device

torch.cuda.is_available()

True

In [15]:
# One-hot target vector of every class. The insertion order of this dict fixes
# the order of the per-class values returned by acc_eval.
class_labels = {
    'monkey_pox' : [0.,1.],
    'other': [1.,0.]
}


def acc_eval(outputs, labels, classwise=False):
    """
    Count correct predictions in a batch.

    A row is correct when marking its maximum score(s) with 1 and everything
    else with 0 reproduces the one-hot label exactly. A row whose maximum is
    tied across several classes therefore never counts as correct.

    ``outputs`` is not modified.

    Args:
        outputs (torch.Tensor): Model scores, one row per sample and one column per class.
        labels (torch.Tensor): Ground truth labels (one-hot encoded), same shape as outputs.
        classwise (bool): If True, also return the correct count of every class.

    Returns:
        int: Total correct predictions if classwise=False.
        tuple: (total correct, correct count per class in class_labels order) if classwise=True.
    """
    # Work on a detached view so that neither autograd nor the caller's tensor is affected.
    scores = outputs.detach()
    row_max = scores.max(dim=1, keepdim=True).values
    predicted = (scores == row_max).to(labels.dtype)
    correct = (predicted == labels).all(dim=1)
    right = int(correct.sum().item())

    if not classwise:
        return right

    class_rights = []
    for class_label in class_labels.values():
        target = torch.tensor(class_label, dtype=labels.dtype, device=labels.device)
        belongs_to_class = (labels == target).all(dim=1)
        class_rights.append(int((correct & belongs_to_class).sum().item()))
    return (right, *class_rights)


In [16]:
def cal_total(labels):
    """
    Count how many labels in a batch belong to each class.

    A label is attributed to the first class in class_labels whose one-hot
    vector it equals; labels matching no class are not counted.

    Args:
        labels (torch.Tensor): Ground truth labels (one-hot encoded).

    Returns:
        tuple: Total counts for each class, in the order of class_labels keys.
    """
    counts = dict.fromkeys(class_labels, 0)

    # Convert the whole batch to plain Python lists once, then match row by row.
    for row in labels.detach().cpu().tolist():
        matched_class = next(
            (name for name, encoding in class_labels.items() if row == encoding),
            None,
        )
        if matched_class is not None:
            counts[matched_class] += 1

    return tuple(counts.values())


In [17]:
def norm(X_train, X_val, X_test):
    """
    Standardise the three splits with the mean and standard deviation of the
    training split (one scalar each, computed over all of its elements).

    Returns:
        tuple: (X_train, X_val, X_test) after subtracting the training mean
        and dividing by the training standard deviation.
    """
    train_mean, train_std = X_train.mean(), X_train.std()

    def standardise(data):
        # Statistics come from the training split only, for every split.
        return (data - train_mean) / train_std

    return standardise(X_train), standardise(X_val), standardise(X_test)


In [18]:
def test(model, training_loader, model_num):
    """
    Evaluate the model on a dataset and compute accuracy, macro precision/recall/F1
    and per-class accuracies.

    The function neither switches the model to eval mode nor disables gradient
    tracking; the caller is expected to do both.

    Args:
        model: Model to evaluate; called as ``model(inputs)``.
        training_loader: Iterable of (inputs, labels) batches to evaluate on
            (despite the name, callers pass the test loader). Labels are one-hot encoded.
        model_num: Unused; kept for interface compatibility.

    Returns:
        tuple: (accuracy, precision, recall, f1, *per-class accuracies in class_labels order).
    """
    right_total = 0  # Total correct predictions
    total = 0  # Total number of samples
    predicted_classes = []  # Predicted class index of every sample
    true_classes = []  # True class index of every sample
    class_totals = {class_name: 0 for class_name in class_labels.keys()}  # Per-class sample counts
    class_rights = {class_name: 0 for class_name in class_labels.keys()}  # Per-class correct predictions

    for inputs, labels in tqdm(training_loader):
        total += inputs.shape[0]

        outputs = model(inputs)

        # Take the class indices first so they never depend on what acc_eval
        # does with the tensors it is handed.
        predicted_classes.extend(outputs.detach().argmax(dim=1).cpu().tolist())
        true_classes.extend(labels.detach().argmax(dim=1).cpu().tolist())

        # Per-class totals and correct predictions of this batch
        class_totals_batch = cal_total(labels)
        right, *class_rights_batch = acc_eval(outputs, labels, classwise=True)

        right_total += right
        for class_index, class_name in enumerate(class_labels.keys()):
            class_totals[class_name] += class_totals_batch[class_index]
            class_rights[class_name] += class_rights_batch[class_index]

    accuracy = right_total / total  # Overall accuracy
    class_accuracies = {class_name: class_rights[class_name] / class_totals[class_name]
                        if class_totals[class_name] > 0 else 0
                        for class_name in class_labels.keys()}

    # evaluate() expects (y_true, y_pred); passing them the other way round
    # swaps the reported precision and recall.
    precision, recall, f1 = evaluate(true_classes, predicted_classes)

    print("Accuracy:", accuracy)
    print("Total Right:", right_total)
    for class_name, class_accuracy in class_accuracies.items():
        print(f"{class_name} Accuracy: {class_accuracy}")

    return (accuracy, precision, recall, f1, *class_accuracies.values())


In [19]:
def train_one_epoch(model, epoch_index, model_num, training_loader, loss_fn, loss_fn1, w, optimizer, loss_dict_train):
    """
    Train the model for one epoch, computing losses and accuracies, and updating model weights.

    Args:
        model: The neural network model to train.
        epoch_index: The index of the current epoch (unused; kept for interface compatibility).
        model_num: Identifier for the model (unused; kept for interface compatibility).
        training_loader: Iterable of (inputs, labels) training batches.
        loss_fn: Primary loss function used for training.
        loss_fn1: Secondary loss function used for training (combined with loss_fn).
        w: Weighting factor; the loss is ``(1 - w) * loss_fn + w * loss_fn1``.
        optimizer: Optimizer to update model parameters.
        loss_dict_train: Unused; kept for interface compatibility.

    Returns:
        A flat tuple ``(last_loss, overall_accuracy, right_total, *class_accuracies)``:
        last_loss: Mean loss of the last completed window of 10 batches; when the
            epoch has fewer than 10 batches, the mean over all of them; 0.0 when
            the loader is empty.
        overall_accuracy: Correct predictions divided by samples seen, or None if no samples.
        right_total: Total number of correct predictions in the epoch.
        class_accuracies: One value per class in class_labels order (None for a
            class that did not occur).
    """
    running_loss = 0.  # Loss accumulated over the current window of batches
    last_loss = 0.  # Mean loss of the most recently completed window
    right_total = 0  # Total correct predictions
    total = 0  # Total samples processed
    batches_seen = 0  # Number of batches processed in this epoch

    class_totals = {class_name: 0 for class_name in class_labels.keys()}  # Samples seen per class
    class_rights = {class_name: 0 for class_name in class_labels.keys()}  # Correct predictions per class

    for i, (inputs, labels) in enumerate(tqdm(training_loader)):
        optimizer.zero_grad()  # Reset gradients before backpropagation
        outputs = model(inputs)

        total += inputs.shape[0]
        batches_seen += 1

        # Weighted combination of the two loss functions
        loss = (1-w) * loss_fn(outputs, labels) + w * loss_fn1(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy bookkeeping needs no gradients, so hand over a detached view.
        class_totals_batch = cal_total(labels)
        right, *class_rights_batch = acc_eval(outputs.detach(), labels, classwise=True)

        right_total += right
        for j, class_name in enumerate(class_labels.keys()):
            class_totals[class_name] += class_totals_batch[j]
            class_rights[class_name] += class_rights_batch[j]

        running_loss += loss.item()
        # Report the mean loss after every 10th batch
        if i % 10 == 9:
            last_loss = running_loss / 10
            print(f'  batch {i + 1} loss: {last_loss}')
            running_loss = 0.

    # With fewer than 10 batches no window ever completes and running_loss
    # still holds the sum over every batch; report their mean instead of 0.
    if 0 < batches_seen < 10:
        last_loss = running_loss / batches_seen

    # Per-class accuracies
    print("Class totals:", class_totals)
    class_accuracies = {}
    for class_name in class_labels.keys():
        if class_totals[class_name] == 0:
            class_accuracies[class_name] = None  # No data for this class
        else:
            class_accuracies[class_name] = class_rights[class_name] / class_totals[class_name]

    overall_accuracy = right_total / total if total > 0 else None

    return (last_loss, overall_accuracy, right_total,
            *[class_accuracies[class_name] for class_name in class_labels.keys()])


In [20]:
import numpy as np

class CAGA(nn.Module):
    """
    Attention block built from dilated depthwise convolutions.

    ``forward`` maps an input of shape (B, in_channels, H, W) to a
    batch-normalised tensor of the same shape: the projected attention output
    is added to the input (residual connection) and passed through
    BatchNorm2d.

    The class relies on ``DepthWiseSeperableConvLayer``, ``F`` and the
    top-level name ``torchvision`` being available in the enclosing namespace.
    None of them is imported in the visible part of this file; presumably they
    come from ``%run model.ipynb`` -- TODO confirm.

    Constructor arguments:
        in_channels: Number of channels of the input (and of the output).
        activation: Accepted but not used inside this class.
        heads: Number of attention heads.
        dim: Number of channels of each of Q, K and V.
        expand_ratio: Accepted but not used inside this class.
        head_dim: Number of channels per head after ``get_begin``.
        dilation: Iterable of dilation rates; one Q/K/V branch is built per rate.
        random_seed: Seed applied to torch, numpy and ``random`` *globally*
            each time an instance is constructed (see ``_set_global_seeds``).

    Attributes:
        heads: Number of attention heads in the multi-head attention mechanism.
        dim: Dimensionality of Q, K, V.
        scale: ``dim ** -0.5``; multiplied into Q before the attention product.
        head_dim: The number of channels per attention head.
        dilation: Dilation values for the dilated convolutions.
        total_layer: Constant 4; not read anywhere inside this class.
        get_begin: DepthWiseSeperableConvLayer mapping in_channels to heads * head_dim channels.
        get_qkv: One branch per dilation rate: a depthwise 3x3 dilated conv
            (no padding argument is passed) followed by a 1x1 conv to 3 * dim channels.
        convert_to_headdim: 1x1 conv from len(dilation) * dim to head_dim channels.
        mix: 1x1 conv from dim to 3 * dim channels followed by ReLU.
        proj: 1x1 conv from heads * dim * len(dilation) channels back to in_channels.
        norm: BatchNorm2d over in_channels, weight initialised to 1 and bias to 0.
    """
    def __init__(self,
            in_channels,
            activation,
            heads = 3,
            dim = 8,
            expand_ratio = 4,
            head_dim = 16,
            dilation = (1,2),
            random_seed = 82
            ):

        # Set global random seeds for reproducibility
        # (done before the layers below are created, so their random
        # initialisation draws from the freshly seeded generators).
        self._set_global_seeds(random_seed)

        super(CAGA, self).__init__()
        self.heads = heads
        self.dim = dim
        # NOTE(review): the local `scale` is never read; only self.scale is used.
        scale = dim
        self.scale = dim ** -0.5
        self.head_dim = head_dim
        self.dilation = dilation

        self.total_layer = 4

        # Reproducible layer initialization
        self.get_begin = self._init_depthwise_separable_conv(
            DepthWiseSeperableConvLayer(in_channels, self.heads * self.head_dim)
        )

        # One Q/K/V branch per dilation rate. The same branches are applied to
        # every head in forward().
        self.get_qkv = nn.ModuleList([
            nn.Sequential(
                self._init_conv(nn.Conv2d(
                    self.head_dim,
                    self.head_dim,
                    3,
                    groups=self.head_dim,
                    dilation=di
                )),
                self._init_conv(nn.Conv2d(self.head_dim, 3 * self.dim, 1, groups=1))
            )
            for di in dilation
        ])

        self.convert_to_headdim = self._init_conv(
            nn.Conv2d(len(dilation) * self.dim, self.head_dim, 1)
        )

        self.mix = nn.Sequential(
            self._init_conv(nn.Conv2d(
                self.dim,
                self.dim * 3,
                1
            )),
            nn.ReLU()
        )

        self.proj = self._init_conv(
            nn.Conv2d(self.heads * self.dim * len(self.dilation), in_channels, 1)
        )

        # Deterministic BatchNorm
        self.norm = nn.BatchNorm2d(num_features=in_channels, affine=True)
        nn.init.constant_(self.norm.weight, 1)
        nn.init.constant_(self.norm.bias, 0)

    def _set_global_seeds(self, seed):
        """
        Set seeds for reproducibility across libraries.

        This changes process-wide state: it reseeds torch (CPU and all CUDA
        devices), numpy and ``random``, and switches cuDNN to deterministic
        mode with benchmarking disabled.
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)

    def _init_conv(self, conv_layer):
        """Apply Xavier-uniform init to the weight and zero the bias (if any); return the layer."""
        nn.init.xavier_uniform_(conv_layer.weight)
        if conv_layer.bias is not None:
            nn.init.constant_(conv_layer.bias, 0)
        return conv_layer

    def _init_depthwise_separable_conv(self, conv_layer):
        """
        Apply the same initialisation as ``_init_conv`` to every nn.Conv2d
        found among the submodules of ``conv_layer``; return the layer.
        """
        # Assuming DepthWiseSeperableConvLayer has similar structure to standard conv
        for m in conv_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        return conv_layer



    def attention(self,q,k,v , shape):
        """
        Scaled dot-product attention over flattened spatial positions.

        Args:
            q, k, v: Tensors with the spatial axes already flattened by the
                caller, i.e. (B, dim, N) as passed from forward().
            shape: The 4-D shape of q before flattening; only B, H and W are used.

        Returns:
            The attended values reshaped to (B, -1, H, W).
        """

        B, C, H, W = shape



        q, k, v = q.float(), k.float(), v.float()


        q = q * self.scale
        # (B, N, dim) @ (B, dim, N): one score per pair of spatial positions.
        att_map = q.transpose(-2, -1) @ k


        att_map = att_map.softmax(dim=-1)


        out = v @ att_map.transpose(-2, -1)

        out = out.view(B , -1 , H , W)



        return out

    def forward(self,x):
        """
        Run the block on ``x`` of shape (B, C, H, W) and return
        ``norm(proj(concatenated attention outputs) + x)``.
        """
        B, C, H, W = x.shape




        # Kept for the residual connection at the end.
        x_copy = x

        all_heads  = self.get_begin(x)

        # Split the channels into `heads` chunks of head_dim channels each.
        multi_layer = all_heads.split([self.head_dim]*self.heads , dim=1)

        # NOTE(review): `[[]] * n` repeats a reference to ONE list, so every
        # index of all_heads_after_op is the same list object. All appends in
        # the loop below land in that single list (heads * len(dilation)
        # entries, head 0's branches first), and the `[i][j]` lookups further
        # down read from it regardless of `i`. The `[i][j+1]` lookup at the
        # last dilation index only finds an element because of this sharing.
        # Left untouched because changing it alters what the block computes --
        # confirm the intended behaviour before fixing.
        all_heads_after_op = [[]]*self.heads



        for j in range(self.heads):

            for op in self.get_qkv:

                all_heads_after_op[j].append(op(multi_layer[j]))


        all_final = []


        for i in range(self.heads):
            out_all = []
            for j in range(len(self.dilation ) ):



                # Split the branch output into Q, K and V (dim channels each),
                # flatten the spatial axes and attend.
                q , k , v = all_heads_after_op[i][j].split([self.dim, self.dim, self.dim], dim=1)
                shape = q.shape
                q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
                out = self.attention(q , k , v , shape)



                # Width of the next list entry; used as the side of a square resize target.
                shape_ahead = all_heads_after_op[i][j+1].shape[3]

                temp = torchvision.transforms.functional.resize(out , (shape_ahead,shape_ahead))

                # Bring this branch's output back to the input resolution.
                out = F.interpolate(out, size=( H , W ), mode='bilinear')
                out = out.view(B, self.dim, H, W)


                if j+1 != len(self.dilation ):

                    # Feed this branch's (resized, mixed) output into the next list entry.
                    temp = self.mix(temp).clone()
                    all_heads_after_op[i][j+1] = all_heads_after_op[i][j+1] + temp

                out_all.append(out)




            out_all_one = torch.cat(out_all, dim=1)
            all_final.append(out_all_one)
            if i+1 != self.heads:
                # NOTE(review): the left-hand side is a Python list, so `+=`
                # extends that list with the slices obtained by iterating the
                # tensor along dim 0; it does not add the tensor to a feature
                # map. Confirm whether an element-wise addition was intended.
                all_heads_after_op[i+1] += self.convert_to_headdim(out_all_one)




        # Outputs were bilinearly interpolated to (H, W) before being collected.
        all_concat = torch.cat(all_final, dim=1)


        x_final = self.proj(all_concat) + x_copy # residual add; concatenation is an untried alternative

        return self.norm (x_final)

In [21]:
class EarlyStopper:
    """
    Track the best validation accuracy and signal when it has stopped improving.

    A call counts as "no improvement" when the accuracy is at most
    ``best - min_delta``; an accuracy strictly above the best resets the count.
    """

    def __init__(self, patience=1, min_delta=0):
        # patience: number of counted non-improving calls that triggers a stop.
        # min_delta: how far below the best an accuracy must be to be counted.
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.max_validation_acc = float('-inf')
        # The two attributes below are initialised here but not used by early_stop.
        self.val_acc_1 = 0
        self.perfect_val_count =6

    def early_stop(self, validation_acc):
        """
        Record one validation accuracy.

        Returns:
            tuple: (stop, counter) where stop is True once the counter has
            reached ``patience`` on a non-improving call.
        """
        best_so_far = self.max_validation_acc

        # New best: remember it and start counting from zero again.
        if validation_acc > best_so_far:
            self.max_validation_acc = validation_acc
            self.counter = 0
            return (False, self.counter)

        # Within min_delta of the best: neither an improvement nor a strike.
        if not (validation_acc <= (best_so_far - self.min_delta)):
            return (False, self.counter)

        print("max_acc" , self.max_validation_acc)
        self.counter += 1
        should_stop = self.counter >= self.patience
        return (should_stop, self.counter)


In [25]:
import os
import torch
import numpy as np
from datetime import datetime
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import csv



def training(model_name, model_config_name, HYPERPARAMS, X_train, y_train, X_val, y_val,
             X_test, y_test, Module='CAGA', pre_trained=True, run=1):
    """
    Build, train, validate and test one model, then append its test metrics to a CSV file.

    Args:
        model_name: Label written to the CSV and used in its file name.
        model_config_name: timm model identifier passed to ``timm.create_model``.
        HYPERPARAMS: Dict read for the keys 'dilation' (only when Module == 'CAGA'),
            'dropout', 'neurons1'..'neurons4', 'n_layers', 'batch_size_train',
            'batch_size_val', 'batch_size_test', 'learning_rate', 'weight_decay',
            'scheduler_gamma', 'patience', 'min_delta' and 'EPOCH'. All of its
            keys and values are also written to the CSV.
        X_train, y_train, X_val, y_val, X_test, y_test: Dataset splits; labels one-hot encoded.
        Module: 'CAGA' replaces the backbone's context modules with CAGA;
            any other value leaves the backbone unchanged.
        pre_trained: Whether timm loads pretrained weights (set False for the ablation).
        run: Run number; results go to ``/results/run_{run}``.

    Returns:
        None. Results are written to the CSV file and progress is printed.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Passed through to train_one_epoch, which currently does not fill it.
    loss_dict_train = {}

    # Results directory and CSV file for this run
    csv_dir = f'/results/run_{run}'
    os.makedirs(csv_dir, exist_ok=True)

    main_csv_path = os.path.join(csv_dir, f'results_{timestamp}_{model_name}.csv')

    # Initialize main results CSV
    header = ['Model', 'Accuracy', 'Precision', 'Recall', 'F1'] + \
         [f'{class_name} Accuracy' for class_name in class_labels.keys()] + \
         list(HYPERPARAMS.keys())

    with open(main_csv_path, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(header)

    model_num = 0
    print(f"####################################")

    X_train_norm, X_val_norm, X_test_norm = norm(X_train, X_val, X_test)

    model = timm.create_model(model_config_name, pretrained=pre_trained)  # for ablation turn pre-trained to False
    if Module == 'CAGA':
        # Only meaningful for the EfficientViT-CAGA variant; pass another Module value for other models.
        replace_context_modules(model, CAGA, dilation=HYPERPARAMS['dilation'])
    change_classifier(model, model_config_name, dropout=HYPERPARAMS['dropout'],
                      neurons1=HYPERPARAMS['neurons1'], neurons2=HYPERPARAMS['neurons2'],
                      neurons3=HYPERPARAMS['neurons3'], neurons4=HYPERPARAMS['neurons4'],
                      n_layers=HYPERPARAMS['n_layers'])

    print(model)
    model.to(device)

    def make_loader(features, targets, batch_size, shuffle):
        # The whole split is placed on the device once; TensorDataset then
        # serves (features, target) pairs without building a Python list of them.
        dataset = torch.utils.data.TensorDataset(
            torch.as_tensor(features, dtype=torch.float32).to(device),
            torch.as_tensor(targets, dtype=torch.float32).to(device),
        )
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    training_loader = make_loader(X_train_norm, y_train, HYPERPARAMS['batch_size_train'], True)
    validation_loader = make_loader(X_val_norm, y_val, HYPERPARAMS['batch_size_val'], True)
    test_loader = make_loader(X_test_norm, y_test, HYPERPARAMS['batch_size_test'], False)

    loss_fn = torch.nn.BCELoss()
    loss_fn1 = torch.nn.CrossEntropyLoss()
    # For the focal-loss ablation, swap in
    # FocalCELoss(alpha=torch.tensor([0.81, 1.12]).to(device), gamma=2.0, num_classes=2).
    # With w = 0 only loss_fn (BCE on the head's probabilities) contributes to the loss.
    w = 0
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=HYPERPARAMS['learning_rate'],
                                  weight_decay=HYPERPARAMS['weight_decay'])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=HYPERPARAMS['scheduler_gamma'])
    early_stopper = EarlyStopper(patience=HYPERPARAMS['patience'], min_delta=HYPERPARAMS['min_delta'])

    for epoch in range(HYPERPARAMS['EPOCH']):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        print(f'EPOCH {epoch}:')

        model.train(True)
        avg_loss, train_accuracy, right, *_ = train_one_epoch(
            model, epoch, model_num, training_loader, loss_fn, loss_fn1, w, optimizer, loss_dict_train)

        print(f"Average loss: {avg_loss}")

        # ---- validation ----
        model.eval()
        running_vloss = 0.0
        val_batches = 0
        vright_total = 0
        total_val = 0

        with torch.no_grad():
            for vinputs, vlabels in validation_loader:
                val_batches += 1
                total_val += vinputs.shape[0]

                voutputs = model(vinputs)
                vloss = (1 - w) * loss_fn(voutputs, vlabels) + w * loss_fn1(voutputs, vlabels)

                # .item() keeps the running sum a plain float instead of a device tensor.
                running_vloss += vloss.item()
                vright_total += acc_eval(voutputs, vlabels)

        # Guard against an empty validation loader.
        avg_vloss = running_vloss / val_batches if val_batches else float('nan')
        print(f"Validation loss: {avg_vloss}")

        scheduler.step()

        val_acc = vright_total / total_val if total_val else 0.0

        print(f"Validation accuracy: {val_acc}")
        print(f'LOSS train {avg_loss} valid {avg_vloss}')
        print(f'Right train {right} valid {vright_total}')
        print(f'Accuracy train {train_accuracy} valid {val_acc}')

        stop, stall_count = early_stopper.early_stop(val_acc)
        print("Current early_stop count", stall_count)
        if stop:
            print('Early stopping triggered')
            break

    # ---- final test ----
    # Make sure dropout/batch-norm are in inference mode even if no epoch ran.
    model.eval()
    with torch.no_grad():
        acc, precision, recall, f1, *class_accuracies_test = test(model, test_loader, model_num)
        row = [model_name, acc, precision, recall, f1] + \
                    class_accuracies_test + \
                    list(HYPERPARAMS.values())

        with open(main_csv_path, 'a', newline='') as csvfile:
            csvwriter = csv.writer(csvfile)
            csvwriter.writerow(row)


