# Imports

In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, classification_report

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets.folder import is_image_file

from vit_pytorch import ViT

from pytorch_tabnet.tab_model import TabNetClassifier
# from ft_transformer import FTTransformer
from transformers import BertModel

from AACN_Model import attention_augmented_resnet18, attention_augmented_inceptionv3,attention_augmented_vgg

In [2]:
# Assuming `crops` and directories (`train_dir`, `val_dir`, `test_dir`) are defined
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data Preprocessing

In [3]:
# Define main directories
base_dir = '/Users/izzymohamed/Desktop/Vision For Social Good/Project/Vision-For-Social-Good/DATA' 
crop_root = os.path.join(base_dir, 'color') # color tester
split_root = os.path.join(base_dir, 'split')

In [4]:
# Load CSV data
csv_path = os.path.join(base_dir, 'plant_disease_multimodal_dataset.csv')  # '/Users/izzymohamed/Desktop/Vision For Social Good/Project/Vision-For-Social-Good/DATA/plant_disease_multimodal_dataset.csv'
csv_data = pd.read_csv(csv_path)

In [5]:
# Separate the image paths and labels from the features
csv_image_paths = csv_data['Image Path'].values
csv_labels = csv_data['Mapped Label'].values
csv_features = csv_data.drop(columns=['Image Path', 'Mapped Label', 'Label']).values.astype(np.float32)

In [6]:
# Define function to remove .DS_Store files
def remove_ds_store(directory):
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file == '.DS_Store' or '.DS_Store' in file:
                file_path = os.path.join(root, file)
                print(f"Removing {file_path}")
                os.remove(file_path)

In [7]:
# Remove .DS_Store files from base directory
remove_ds_store(base_dir)

In [8]:
# Function to check if a file is an image file
def is_image_file(filename):
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif'))

In [9]:
# Function to split data into train, validation, and test sets
def split_data(base_dir, val_split=0.4, test_split=0.1):
    train_files = []
    val_files = []
    test_files = []

    classes = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
    for cls in classes:
        print(f'Processing class: {cls}')
        class_dir = os.path.join(base_dir, cls)

        images = [f for f in os.listdir(class_dir) if is_image_file(os.path.join(class_dir, f))]

        if len(images) == 0:
            print(f"No images found for class {cls}. Skipping...")
            continue

        # Shuffle images to randomize the selection
        random.shuffle(images)

        try:
            train, test = train_test_split(images, test_size=test_split)
            train, val = train_test_split(train, test_size=val_split / (1 - test_split))
        except ValueError as e:
            print(f"Not enough images to split for class {cls}: {e}")
            continue

        train_files.extend([(os.path.join(class_dir, img), cls) for img in train])
        val_files.extend([(os.path.join(class_dir, img), cls) for img in val])
        test_files.extend([(os.path.join(class_dir, img), cls) for img in test])

    return train_files, val_files, test_files, classes

In [10]:
# Split data
train_files, val_files, test_files, classes = split_data(crop_root)

Processing class: Corn_(maize)___healthy
Processing class: Tomato___Target_Spot
Processing class: Tomato___Late_blight
Processing class: Tomato___Tomato_mosaic_virus
Processing class: Pepper,_bell___healthy
Processing class: Orange___Haunglongbing_(Citrus_greening)
Processing class: Tomato___Leaf_Mold
Processing class: Tomato___Bacterial_spot
Processing class: Tomato___Early_blight
Processing class: Corn_(maize)___Common_rust_
Processing class: Tomato___healthy
Processing class: Tomato___Tomato_Yellow_Leaf_Curl_Virus
Processing class: Corn_(maize)___Northern_Leaf_Blight
Processing class: Tomato___Spider_mites Two-spotted_spider_mite
Processing class: Pepper,_bell___Bacterial_spot
Processing class: Tomato___Septoria_leaf_spot
Processing class: Squash___Powdery_mildew
Processing class: Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
Processing class: Soybean___healthy


In [11]:
# Use the lists of file paths for your dataset loading and transformations
print(f"Train files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")
print(f"Test files: {len(test_files)}")

Train files: 18444
Validation files: 14769
Test files: 3700


In [12]:
# Define the standard image sizes
inception_size = 299
other_size = 224

In [13]:
# Update the data transformations
data_transforms = {
    'InceptionV3': {
        'train': transforms.Compose([
            transforms.Resize((inception_size, inception_size)),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((inception_size, inception_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize((inception_size, inception_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    },
    'Others': {
        'train': transforms.Compose([
            transforms.Resize((other_size, other_size)),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((other_size, other_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize((other_size, other_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
}

In [14]:
# Class to create the datasets and data loaders and to map between the different modalities
class CustomMultimodalDataset(Dataset):
    def __init__(self, file_paths, csv_features, csv_labels, class_to_idx, transform=None):
        """
        Initializes the dataset with image paths, CSV features, labels, class mapping, and optional transforms.
        
        Args:
            file_paths (list of tuples): List of (image_path, class_label) tuples.
            csv_features (ndarray): Array of CSV feature rows.
            csv_labels (ndarray): Array of CSV labels corresponding to csv_features.
            class_to_idx (dict): Mapping from class labels to indices.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.file_paths = file_paths
        self.csv_features = csv_features
        self.csv_labels = csv_labels
        self.class_to_idx = class_to_idx
        self.transform = transform

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        return len(self.file_paths)

    def __getitem__(self, idx):
        """
        Retrieves the sample at the given index.
        
        Args:
            idx (int): Index of the sample to retrieve.
        
        Returns:
            tuple: (image, csv_row, label) where image is the transformed image tensor,
                   csv_row is the corresponding CSV feature row, and label is the class index.
        """
        img_path, cls = self.file_paths[idx]  # Get image path and class label
        image = Image.open(img_path).convert('RGB')  # Open image and convert to RGB
        label = self.class_to_idx[cls]  # Map class label to index
        csv_row = self.csv_features[idx]  # Get the corresponding CSV feature row
        
        if self.transform:
            image = self.transform(image)  # Apply image transformations if provided
        
        return image, csv_row, label  # Return the image, CSV features, and label

In [15]:
# Create a mapping from class names to indices
class_to_idx = {cls: idx for idx, cls in enumerate(classes)}

In [16]:
# Create datasets and data loaders
train_dataset_inception = CustomMultimodalDataset(train_files, csv_features, csv_labels, class_to_idx, transform=data_transforms['InceptionV3']['train'])
val_dataset_inception = CustomMultimodalDataset(val_files, csv_features, csv_labels, class_to_idx, transform=data_transforms['InceptionV3']['val'])
test_dataset_inception = CustomMultimodalDataset(test_files, csv_features, csv_labels, class_to_idx, transform=data_transforms['InceptionV3']['test'])

train_loader_inception = DataLoader(train_dataset_inception, batch_size=32, shuffle=True)
val_loader_inception = DataLoader(val_dataset_inception, batch_size=32, shuffle=True)
test_loader_inception = DataLoader(test_dataset_inception, batch_size=32, shuffle=False)

In [17]:
# Loaders for other models
train_dataset_others = CustomMultimodalDataset(train_files, csv_features, csv_labels, class_to_idx, transform=data_transforms['Others']['train'])
val_dataset_others = CustomMultimodalDataset(val_files, csv_features, csv_labels, class_to_idx, transform=data_transforms['Others']['val'])
test_dataset_others = CustomMultimodalDataset(test_files, csv_features, csv_labels, class_to_idx, transform=data_transforms['Others']['test'])

train_loader_others = DataLoader(train_dataset_others, batch_size=32, shuffle=True)
val_loader_others = DataLoader(val_dataset_others, batch_size=32, shuffle=True)
test_loader_others = DataLoader(test_dataset_others, batch_size=32, shuffle=False)

# Modalities Fusion

### CSV feature extractor models

In [18]:
# Define the TabNet CSV feature extractor
class TabNetCSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        """
        Initializes the TabNetCSVFeatureExtractor.

        Args:
            input_dim (int): Number of input features.
            hidden_dim (int): Desired output dimension of the feature extractor.
        """
        super(TabNetCSVFeatureExtractor, self).__init__()
        self.tabnet = TabNetClassifier(
            input_dim=input_dim,
            output_dim=hidden_dim,  # Set output dimension to match desired feature size
            n_d=hidden_dim,
            n_a=hidden_dim,
            n_steps=3,
            gamma=1.3,
            n_independent=2,
            n_shared=2,
            lambda_sparse=1e-3,
            optimizer_fn=torch.optim.Adam,
            optimizer_params=dict(lr=2e-2),
            scheduler_fn=None,
            scheduler_params=None,
            mask_type='sparsemax'
        )
        # Placeholder for a linear layer if needed to match dimensions
        self.linear = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, x):
        # Ensure x is on CPU since TabNetClassifier might not support GPU
        x = x.cpu().numpy()
        # Predict on input features
        predictions = self.tabnet.predict(x)
        # Convert predictions to tensor
        feature_output = torch.tensor(predictions, dtype=torch.float32)
        # Optionally pass through a linear layer if needed
        return self.linear(feature_output)

# Define the FT-Transformer CSV feature extractor
# class FTTransformerCSVFeatureExtractor(nn.Module):
#     def __init__(self, input_dim, hidden_dim):
#         """
#         Initializes the FTTransformerCSVFeatureExtractor.

#         Args:
#             input_dim (int): Number of input features.
#             hidden_dim (int): Number of hidden units.
#         """
#         super(FTTransformerCSVFeatureExtractor, self).__init__()
#         self.ft_transformer = FTTransformer(input_dim=input_dim, output_dim=hidden_dim)

#     def forward(self, x):
#         return self.ft_transformer(x)

# Define the Multimodal Transformer (BERT-based) CSV feature extractor
class BERTCSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        """
        Initializes the BERTCSVFeatureExtractor.

        Args:
            input_dim (int): Number of input features.
            hidden_dim (int): Number of hidden units.
        """
        super(BERTCSVFeatureExtractor, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(self.bert.config.hidden_size, hidden_dim)

    def forward(self, x):
        # BERT expects tokenized input, so we need to process x appropriately
        # Assuming x is tokenized and of shape (batch_size, seq_length)
        output = self.bert(input_ids=x)[1]  # [1] corresponds to the pooled output
        return self.fc(output)

# Define the MLP CSV feature extractor
class MLP_CSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        """
        Initializes the MLP_CSVFeatureExtractor.

        Args:
            input_dim (int): Number of input features.
            hidden_dim (int): Number of hidden units.
        """
        super(MLP_CSVFeatureExtractor, self).__init__()
        self.extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.extractor(x)

In [19]:
# Define the Simple MLP CSV feature extractor
class SimpleCSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        """
        Initializes the SimpleCSVFeatureExtractor.

        Args:
            input_dim (int): Number of input features.
            hidden_dim (int): Number of hidden units.
        """
        super(SimpleCSVFeatureExtractor, self).__init__()
        self.extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.extractor(x)

class DeepCSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        """
        Initializes the DeepCSVFeatureExtractor.

        Args:
            input_dim (int): Number of input features.
            hidden_dim (int): Number of hidden units.
        """
        super(DeepCSVFeatureExtractor, self).__init__()
        self.extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.extractor(x)

class ConvCSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim, seq_len):
        """
        Initializes the ConvCSVFeatureExtractor.

        Args:
            input_dim (int): Number of input features.
            hidden_dim (int): Number of hidden units.
            seq_len (int): Length of the input sequence.
        """
        super(ConvCSVFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
        self.fc = nn.Linear(hidden_dim * seq_len, hidden_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc(x)
        return x

### Fusion Model

In [20]:
# Define the fusion methods for combining image and CSV features
class FusionModel(nn.Module):
    def __init__(self, model_name, base_model, csv_input_dim, csv_hidden_dim, num_classes, fusion_method, csv_model_type='simple', seq_len=None):
        """
        Initializes the FusionModel.

        Args:
            model_name (str): Name of the base model architecture.
            base_model (nn.Module): The base model to be used.
            csv_input_dim (int): Number of features in the CSV data.
            csv_hidden_dim (int): Number of hidden units in the CSV feature extractor.
            num_classes (int): Number of classes for classification.
            fusion_method (str): Method of fusion ('early', 'intermediate', 'late').
            csv_model_type (str): Type of CSV feature extractor ('simple', 'deep', 'conv').
            seq_len (int, optional): Length of the input sequence for convolutional model.
        """
        super(FusionModel, self).__init__()
        self.model_name = model_name
        self.base_model = base_model
        self.csv_input_dim = csv_input_dim
        self.csv_hidden_dim = csv_hidden_dim
        self.num_classes = num_classes
        self.fusion_method = fusion_method

        # Initialize the base model and get the feature size
        self.base_model, self.feature_size = self.initialize_base_model(model_name, base_model)

        # Define the CSV feature extractor based on the type
        if csv_model_type == 'simple':
            self.csv_feature_extractor = SimpleCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim)
        elif csv_model_type == 'deep':
            self.csv_feature_extractor = DeepCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim)
        elif csv_model_type == 'conv':
            if seq_len is None:
                raise ValueError("seq_len must be provided for the convolutional model")
            self.csv_feature_extractor = ConvCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim, seq_len)
        else:
            raise ValueError("Unsupported csv_model_type")

        # Define additional layers for fusion based on the fusion method
        if self.fusion_method == 'late':
            self.fusion_layer = nn.Linear(self.feature_size + self.csv_hidden_dim, self.num_classes)
        elif self.fusion_method == 'intermediate':
            self.intermediate_layer = nn.Linear(self.feature_size, 512)
            self.fusion_layer = nn.Linear(512 + self.csv_hidden_dim, self.num_classes)
        elif self.fusion_method == 'early':
            # Calculate the new feature size after early fusion
            early_fusion_size = self.feature_size + self.csv_hidden_dim
            self.early_fusion_layer = nn.Linear(early_fusion_size, self.feature_size)
            self.classifier = nn.Linear(self.feature_size, num_classes)
        else:
            raise ValueError("Unsupported fusion method")

    def initialize_base_model(self, model_name, base_model):
        """
        Initializes the base model for fusion, replacing the final classification layer.

        Args:
            model_name (str): Name of the base model architecture.
            base_model (nn.Module): The base model to be modified.

        Returns:
            tuple: (modified_base_model, feature_size)
        """
        if model_name == 'InceptionV3':
            feature_size = base_model.fc.in_features  # Access in_features before replacing
            base_model.aux_logits = False  # Disable auxiliary logits
            base_model.AuxLogits = None  # Remove auxiliary logits
            base_model.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'ResNet152' or model_name == 'AttentionAugmentedResNet18':
            feature_size = base_model.fc.in_features  # Access in_features before replacing
            base_model.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'VGG19' or model_name == 'AttentionAugmentedVGG19':
            feature_size = base_model.classifier[6].in_features  # Access in_features before replacing
            base_model.classifier[6] = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'ViT':
            # Generalized approach to identify and replace the classification head
            if hasattr(base_model, 'heads'):
                feature_size = base_model.heads.head.in_features
                base_model.heads.head = nn.Identity()
            elif hasattr(base_model, 'classifier'):
                feature_size = base_model.classifier.in_features
                base_model.classifier = nn.Identity()
            elif hasattr(base_model, 'head'):
                feature_size = base_model.head.in_features
                base_model.head = nn.Identity()
            else:
                # Fallback: Inspect all attributes and find a suitable final layer
                for attr_name in dir(base_model):
                    attr = getattr(base_model, attr_name)
                    if isinstance(attr, nn.Linear):
                        feature_size = attr.in_features
                        setattr(base_model, attr_name, nn.Identity())
                        break
                else:
                    raise ValueError(f"Unsupported ViT model structure for model: {model_name}")
        elif model_name == 'AttentionAugmentedInceptionV3':
            inception_model = base_model.inception
            feature_size = inception_model.fc.in_features  # Access in_features before replacing
            inception_model.aux_logits = False  # Disable auxiliary logits
            inception_model.AuxLogits = None  # Remove auxiliary logits
            inception_model.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'AttentionAugmentedVGG19':
            # Access the VGG model within the wrapper
            vgg_model = base_model.features
            feature_size = vgg_model[-2].out_channels  # Assuming the penultimate layer is the feature extractor
            # Replace the last layer with an identity function
            vgg_model[-1] = nn.Identity()
        elif model_name == 'AttentionAugmentedResNet18':
            # Access the ResNet model within the wrapper
            resnet_model = base_model
            feature_size = resnet_model.fc.in_features
            resnet_model.fc = nn.Identity()
        else:
            raise ValueError(f"Unsupported model type: {model_name}")

        return base_model, feature_size

    def forward(self, img, csv):
        """
        Forward pass for the FusionModel.

        Args:
            img (Tensor): Image tensor input.
            csv (Tensor): CSV feature input.

        Returns:
            Tensor: Output logits.
        """
        # Extract CSV features
        csv_features = self.csv_feature_extractor(csv)

        if self.fusion_method == 'early':
            # Early fusion: Concatenate features before passing through the base model
            img = img.view(img.size(0), -1)  # Flatten the image tensor
            img_csv_combined = torch.cat((img, csv_features), dim=1)  # Concatenate image and CSV features
            img_csv_features = self.early_fusion_layer(img_csv_combined)  # Pass through the early fusion layer
            output = self.classifier(img_csv_features)  # Classify the fused features
        else:
            img_features = self.base_model(img)  # Extract image features using the base model
            if self.model_name == 'AttentionAugmentedInceptionV3':
                # Check if img_features is of type InceptionOutputs and extract the tensor
                if isinstance(img_features, tuple):  # Handle Inception model outputs
                    img_features = img_features.logits
                    img_features = img_features.view(img_features.size(0), -1)  # Flatten the image features

            if self.fusion_method == 'late':
                combined_features = torch.cat((img_features, csv_features), dim=1)  # Concatenate image and CSV features
                output = self.fusion_layer(combined_features)  # Classify the fused features
            elif self.fusion_method == 'intermediate':
                img_features = self.intermediate_layer(img_features)  # Pass image features through intermediate layer
                combined_features = torch.cat((img_features, csv_features), dim=1)  # Concatenate features
                output = self.fusion_layer(combined_features)  # Classify the fused features
            else:
                raise ValueError("Unsupported fusion method")
        return output  # Return the output logits

In [21]:
# Define the fusion methods for combining image and CSV features
class FusionModel1(nn.Module):
    def __init__(self, model_name, base_model, csv_input_dim, csv_hidden_dim, num_classes, fusion_method):
        """
        Initializes the FusionModel.

        Args:
            model_name (str): Name of the base model architecture.
            base_model (nn.Module): The base model to be used.
            csv_input_dim (int): Number of features in the CSV data.
            csv_hidden_dim (int): Number of hidden units in the CSV feature extractor.
            num_classes (int): Number of classes for classification.
            fusion_method (str): Method of fusion ('early', 'intermediate', 'late').
        """
        super(FusionModel1, self).__init__()
        self.model_name = model_name
        self.base_model = base_model
        self.csv_input_dim = csv_input_dim
        self.csv_hidden_dim = csv_hidden_dim
        self.num_classes = num_classes
        self.fusion_method = fusion_method

        # Initialize the base model and get the feature size
        self.base_model, self.feature_size = self.initialize_base_model(model_name, base_model)

        # Define the CSV feature extractor
        self.csv_feature_extractor = nn.Sequential(
            nn.Linear(self.csv_input_dim, self.csv_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.csv_hidden_dim, self.csv_hidden_dim),
            nn.ReLU()
        )

        # Define additional layers for fusion based on the fusion method
        if self.fusion_method == 'late':
            self.fusion_layer = nn.Linear(self.feature_size + self.csv_hidden_dim, self.num_classes)
        elif self.fusion_method == 'intermediate':
            self.intermediate_layer = nn.Linear(self.feature_size, 512)
            self.fusion_layer = nn.Linear(512 + self.csv_hidden_dim, self.num_classes)
        elif self.fusion_method == 'early':
            # Calculate the new feature size after early fusion
            early_fusion_size = self.feature_size + self.csv_hidden_dim
            self.early_fusion_layer = nn.Linear(early_fusion_size, self.feature_size)
            self.classifier = nn.Linear(self.feature_size, num_classes)
        else:
            raise ValueError("Unsupported fusion method")

    def initialize_base_model(self, model_name, base_model):
        """
        Initializes the base model for fusion, replacing the final classification layer.

        Args:
            model_name (str): Name of the base model architecture.
            base_model (nn.Module): The base model to be modified.

        Returns:
            tuple: (modified_base_model, feature_size)
        """
        if model_name == 'InceptionV3':
            feature_size = base_model.fc.in_features  # Access in_features before replacing
            base_model.aux_logits = False  # Disable auxiliary logits
            base_model.AuxLogits = None  # Remove auxiliary logits
            base_model.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'ResNet152' or model_name == 'AttentionAugmentedResNet18':
            feature_size = base_model.fc.in_features  # Access in_features before replacing
            base_model.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'VGG19' or model_name == 'AttentionAugmentedVGG19':
            feature_size = base_model.classifier[6].in_features  # Access in_features before replacing
            base_model.classifier[6] = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'ViT':
            # Generalized approach to identify and replace the classification head
            if hasattr(base_model, 'heads'):
                feature_size = base_model.heads.head.in_features
                base_model.heads.head = nn.Identity()
            elif hasattr(base_model, 'classifier'):
                feature_size = base_model.classifier.in_features
                base_model.classifier = nn.Identity()
            elif hasattr(base_model, 'head'):
                feature_size = base_model.head.in_features
                base_model.head = nn.Identity()
            else:
                # Fallback: Inspect all attributes and find a suitable final layer
                for attr_name in dir(base_model):
                    attr = getattr(base_model, attr_name)
                    if isinstance(attr, nn.Linear):
                        feature_size = attr.in_features
                        setattr(base_model, attr_name, nn.Identity())
                        break
                else:
                    raise ValueError(f"Unsupported ViT model structure for model: {model_name}")
        elif model_name == 'AttentionAugmentedInceptionV3':
            inception_model = base_model.inception
            feature_size = inception_model.fc.in_features  # Access in_features before replacing
            inception_model.aux_logits = False  # Disable auxiliary logits
            inception_model.AuxLogits = None  # Remove auxiliary logits
            inception_model.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'AttentionAugmentedVGG19':
            # Access the VGG model within the wrapper
            vgg_model = base_model.features
            feature_size = vgg_model[-2].out_channels  # Assuming the penultimate layer is the feature extractor
            # Replace the last layer with an identity function
            vgg_model[-1] = nn.Identity()
        elif model_name == 'AttentionAugmentedResNet18':
            # Access the ResNet model within the wrapper
            resnet_model = base_model
            feature_size = resnet_model.fc.in_features
            resnet_model.fc = nn.Identity()
        else:
            raise ValueError(f"Unsupported model type: {model_name}")

        return base_model, feature_size

    def forward(self, img, csv):
        """
        Forward pass for the FusionModel.

        Args:
            img (Tensor): Image tensor input.
            csv (Tensor): CSV feature input.

        Returns:
            Tensor: Output logits.
        """
        # Extract CSV features
        csv_features = self.csv_feature_extractor(csv)

        if self.fusion_method == 'early':
            # Early fusion: Concatenate features before passing through the base model
            img = img.view(img.size(0), -1)  # Flatten the image tensor
            img_csv_combined = torch.cat((img, csv_features), dim=1)  # Concatenate image and CSV features
            img_csv_features = self.early_fusion_layer(img_csv_combined)  # Pass through the early fusion layer
            output = self.classifier(img_csv_features)  # Classify the fused features
        else:
            img_features = self.base_model(img)  # Extract image features using the base model
            if self.model_name == 'AttentionAugmentedInceptionV3':
                # Check if img_features is of type InceptionOutputs and extract the tensor
                if isinstance(img_features, tuple):  # Handle Inception model outputs
                    img_features = img_features.logits
                    img_features = img_features.view(img_features.size(0), -1)  # Flatten the image features

            if self.fusion_method == 'late':
                combined_features = torch.cat((img_features, csv_features), dim=1)  # Concatenate image and CSV features
                output = self.fusion_layer(combined_features)  # Classify the fused features
            elif self.fusion_method == 'intermediate':
                img_features = self.intermediate_layer(img_features)  # Pass image features through intermediate layer
                combined_features = torch.cat((img_features, csv_features), dim=1)  # Concatenate features
                output = self.fusion_layer(combined_features)  # Classify the fused features
            else:
                raise ValueError("Unsupported fusion method")
        return output  # Return the output logits

In [22]:
# Define the fusion method model
class FusionModel2(nn.Module):
    def __init__(self, model_name, image_feature_extractor, csv_feature_extractor, num_classes, fusion_method):
        """
        Initializes the FusionModel.

        Args:
            model_name (str): Name of the image feature extractor architecture.
            image_feature_extractor (nn.Module): The model for extracting features from images.
            csv_feature_extractor (nn.Module): The model for extracting features from CSV data.
            num_classes (int): Number of classes for classification.
            fusion_method (str): Method of fusion ('early', 'intermediate', 'late').
        """
        super(FusionModel2, self).__init__()
        self.model_name = model_name
        self.image_feature_extractor = image_feature_extractor
        self.csv_feature_extractor = csv_feature_extractor
        self.num_classes = num_classes
        self.fusion_method = fusion_method

        # Initialize the image feature extractor and get the feature size
        self.image_feature_extractor, self.feature_size = self.initialize_image_feature_extractor(model_name, image_feature_extractor)

        # Initialize the CSV feature extractor
        # Determine csv_hidden_dim based on the CSV feature extractor
        if hasattr(self.csv_feature_extractor, 'output_dim'):
            self.csv_hidden_dim = self.csv_feature_extractor.output_dim
        else:
            # Default or custom handling if 'output_dim' is not present
            self.csv_hidden_dim = 512  # Example default

        # Define additional layers for fusion based on the fusion method
        if self.fusion_method == 'late':
            self.fusion_layer = nn.Linear(self.feature_size + self.csv_hidden_dim, self.num_classes)
        elif self.fusion_method == 'intermediate':
            self.intermediate_layer = nn.Linear(self.feature_size, 512)
            self.fusion_layer = nn.Linear(512 + self.csv_hidden_dim, self.num_classes)
        elif self.fusion_method == 'early':
            # Calculate the new feature size after early fusion
            early_fusion_size = self.feature_size + self.csv_hidden_dim
            self.early_fusion_layer = nn.Linear(early_fusion_size, self.feature_size)
            self.classifier = nn.Linear(self.feature_size, num_classes)
        else:
            raise ValueError("Unsupported fusion method")

    def initialize_image_feature_extractor(self, model_name, image_feature_extractor):
        """
        Initializes the image feature extractor for fusion, replacing the final classification layer.

        Args:
            model_name (str): Name of the image feature extractor architecture.
            image_feature_extractor (nn.Module): The model to be modified.

        Returns:
            tuple: (modified_image_feature_extractor, feature_size)
        """
        if model_name == 'InceptionV3':
            feature_size = image_feature_extractor.fc.in_features  # Access in_features before replacing
            image_feature_extractor.aux_logits = False  # Disable auxiliary logits
            image_feature_extractor.AuxLogits = None  # Remove auxiliary logits
            image_feature_extractor.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'ResNet152' or model_name == 'AttentionAugmentedResNet18':
            feature_size = image_feature_extractor.fc.in_features  # Access in_features before replacing
            image_feature_extractor.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'VGG19' or model_name == 'AttentionAugmentedVGG19':
            feature_size = image_feature_extractor.classifier[6].in_features  # Access in_features before replacing
            image_feature_extractor.classifier[6] = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'ViT':
            # Generalized approach to identify and replace the classification head
            if hasattr(image_feature_extractor, 'heads'):
                feature_size = image_feature_extractor.heads.head.in_features
                image_feature_extractor.heads.head = nn.Identity()
            elif hasattr(image_feature_extractor, 'classifier'):
                feature_size = image_feature_extractor.classifier.in_features
                image_feature_extractor.classifier = nn.Identity()
            elif hasattr(image_feature_extractor, 'head'):
                feature_size = image_feature_extractor.head.in_features
                image_feature_extractor.head = nn.Identity()
            else:
                # Fallback: Inspect all attributes and find a suitable final layer
                for attr_name in dir(image_feature_extractor):
                    attr = getattr(image_feature_extractor, attr_name)
                    if isinstance(attr, nn.Linear):
                        feature_size = attr.in_features
                        setattr(image_feature_extractor, attr_name, nn.Identity())
                        break
                else:
                    raise ValueError(f"Unsupported ViT model structure for model: {model_name}")
        elif model_name == 'AttentionAugmentedInceptionV3':
            inception_model = image_feature_extractor.inception
            feature_size = inception_model.fc.in_features  # Access in_features before replacing
            inception_model.aux_logits = False  # Disable auxiliary logits
            inception_model.AuxLogits = None  # Remove auxiliary logits
            inception_model.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif model_name == 'AttentionAugmentedVGG19':
            # Access the VGG model within the wrapper
            vgg_model = image_feature_extractor.features
            feature_size = vgg_model[-2].out_channels  # Assuming the penultimate layer is the feature extractor
            # Replace the last layer with an identity function
            vgg_model[-1] = nn.Identity()
        elif model_name == 'AttentionAugmentedResNet18':
            # Access the ResNet model within the wrapper
            resnet_model = image_feature_extractor
            feature_size = resnet_model.fc.in_features
            resnet_model.fc = nn.Identity()
        else:
            raise ValueError(f"Unsupported model type: {model_name}")

        return image_feature_extractor, feature_size

    def forward(self, img, csv):
        # Extract features from the image using the image feature extractor
        img_features = self.image_feature_extractor(img)
        
        # Handle specific cases for certain models
        if self.model_name == 'AttentionAugmentedInceptionV3':
            # Check if img_features is of type InceptionOutputs and extract the tensor
            if isinstance(img_features, tuple):  # Handle Inception model outputs
                img_features = img_features.logits
                img_features = img_features.view(img_features.size(0), -1)
        
        # Extract features from the CSV data
        csv_features = self.csv_feature_extractor(csv)

        # Perform fusion based on the specified method
        if self.fusion_method == 'early':
            # Early fusion: Concatenate features before passing through the base model
            img = img.view(img.size(0), -1)  # Flatten the image tensor
            img_csv_combined = torch.cat((img, csv_features), dim=1)
            img_csv_features = self.early_fusion_layer(img_csv_combined)
            output = self.classifier(img_csv_features)
        elif self.fusion_method == 'late':
            # Late fusion: Combine features after processing through image feature extractor and CSV extractor
            combined_features = torch.cat((img_features, csv_features), dim=1)
            output = self.fusion_layer(combined_features)
        elif self.fusion_method == 'intermediate':
            # Intermediate fusion: Process image features through an intermediate layer before fusion
            img_features = self.intermediate_layer(img_features)
            combined_features = torch.cat((img_features, csv_features), dim=1)
            output = self.fusion_layer(combined_features)
        else:
            raise ValueError("Unsupported fusion method")
        
        return output

# Model Training and Evaluation

In [23]:
# Clear cache function
def clear_cache():
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
    else:
        torch.cache.empty_cache()

In [24]:
# Function to adjust learning rate
def adjust_learning_rate(optimizer, epoch, learning_rate):
    """Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
    lr = learning_rate * (0.1 ** (epoch // 10))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [25]:
# Training function with mixed precision and gradient accumulation
def train_model(model, criterion, optimizer, train_loader, val_loader, num_classes, csv_input_dim, device, fusion_method='late', num_epochs=40, initial_lr=0.001):
    # optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr, weight_decay=5e-4, momentum=0.9)
    early_stopping_patience = 5
    best_val_loss = float('inf')
    patience_counter = 0

    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(train_loader):
            inputs_img, inputs_csv, labels = data
            inputs_img, inputs_csv, labels = inputs_img.to(device), inputs_csv.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs_img, inputs_csv)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data in val_loader:
                inputs_img, inputs_csv, labels = data
                inputs_img, inputs_csv, labels = inputs_img.to(device), inputs_csv.to(device), labels.to(device)
                outputs = model(inputs_img, inputs_csv)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = 100 * correct / total

        print(f'Epoch {epoch + 1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping due to no improvement in validation loss.")
                break

    return model

In [26]:
# Function to create and train the model
def create_and_train_fusion_model(model, train_loader, val_loader, num_classes, csv_input_dim, device, fusion_method='late', num_epochs=40, initial_lr=0.001):
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
    return train_model(model, criterion, optimizer, train_loader, val_loader, num_classes, csv_input_dim, device, fusion_method=fusion_method, num_epochs=num_epochs, initial_lr=initial_lr)

In [27]:
# Function to evaluate the fusion model
def evaluate_fusion_model(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs_img, inputs_csv, labels = data
            inputs_img, inputs_csv, labels = inputs_img.to(device), inputs_csv.to(device), labels.to(device)
            outputs = model(inputs_img, inputs_csv)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_loss /= len(test_loader)
    test_accuracy = 100 * correct / total
    return test_loss, test_accuracy

### Train the Model

In [28]:
# Count the number of classes
num_classes_inception = len(class_to_idx)
num_classes_others = len(class_to_idx)

In [29]:
# Set the number of features in the CSV data
num_heads = 8

In [30]:
# Define the csv input dimensions
csv_input_dim = csv_features.shape[1]

In [31]:
# Define the csv hidden dimensions
csv_hidden_dim = 256  # Adjust based on your model requirements

In [32]:
# Define the results dictionary
crop_results = {}

In [33]:
# Define sequence length for BERT
seq_len = 128  # Adjust this value according to your dataset and BERT model requirements

In [35]:
# Iterate over all combinations of fusion methods, CNN models, and CSV feature extractors
for fusion_method in ['intermediate', 'late']:
    # Define CSV feature extractor models
    csv_feature_extractors = {
        # 'tabnet': lambda: TabNetCSVFeatureExtractor(csv_input_dim, csv_hidden_dim),
        # 'ft_transformer': lambda: FTTransformerCSVFeatureExtractor(csv_input_dim, csv_hidden_dim),
        'bert': lambda: BERTCSVFeatureExtractor(csv_input_dim, csv_hidden_dim),
        'mlp': lambda: MLP_CSVFeatureExtractor(csv_input_dim, csv_hidden_dim),
    }

    # Define CNN models with pretrained weights
    cnn_feature_extractors = {
        'InceptionV3': models.inception_v3(pretrained=True).to(device),
        'ResNet152': models.resnet152(pretrained=True).to(device),
        'VGG19': models.vgg19(pretrained=True).to(device),
        'ViT': ViT(
            image_size=224,
            patch_size=16,
            num_classes=num_classes_others,
            dim=1024,
            depth=6,
            heads=16,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1
        ).to(device),
        "AttentionAugmentedInceptionV3": attention_augmented_inceptionv3(attention=True).to(device),
        # 'AttentionAugmentedVGG19': attention_augmented_vgg('VGG19', num_classes=num_classes_others).to(device),
        # "AttentionAugmentedResNet18": attention_augmented_resnet18(num_classes=num_classes_others, attention=[False, True, True, True], num_heads=8).to(device),
    }

    # Disable auxiliary logits for InceptionV3
    if 'InceptionV3' in cnn_feature_extractors:
        cnn_feature_extractors['InceptionV3'].aux_logits = False

    print(f'-------------------------- Fusion Method: {fusion_method} --------------------------')
    for model_name, base_model in cnn_feature_extractors.items():
        base_model.to(device)  # Ensure the CNN model is on the correct device

        # Loop over each CSV feature extractor model type
        # for csv_model_type, csv_feature_extractor_factory in csv_feature_extractors.items():
        
        print(f'---------------- CNN Base Model: {model_name} ----------------')
        for csv_model_type in ['simple','deep','conv']:
            print(f'Training {model_name} with {fusion_method} fusion and CSV extractor {csv_model_type}')

            # Initialize the CSV feature extractor using the factory function
            # csv_feature_extractor = csv_feature_extractor_factory()

            # Create the FusionModel
            fusion_model = FusionModel(
                model_name=model_name,
                base_model=base_model,
                csv_input_dim=csv_input_dim,
                csv_hidden_dim=csv_hidden_dim,
                # image_feature_extractor=base_model,
                # csv_feature_extractor=csv_feature_extractor,
                num_classes=num_classes_others,
                fusion_method=fusion_method,
                csv_model_type=csv_model_type, 
                seq_len=seq_len
            )

            # Train the FusionModel
            model = create_and_train_fusion_model(
                fusion_model,
                train_loader_others,
                val_loader_others,
                num_classes_others,
                csv_input_dim,
                device,
                fusion_method,
                initial_lr=0.001
            )

            # Evaluate the trained FusionModel
            test_loss, test_accuracy = evaluate_fusion_model(model, test_loader_others, nn.CrossEntropyLoss(), device)

            # Store results for this model configuration
            crop_results[f"{model_name}_{fusion_method}_{csv_model_type}"] = {
                'model': model,
                'model_name': model_name,
                'fusion_method': fusion_method,
                'csv_model_type': csv_model_type,
                'test_loss': test_loss,
                'test_accuracy': test_accuracy
            }
            print(f'{model_name} with {fusion_method} fusion and CSV extractor {csv_model_type} Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

            # Clean up: delete the model to free up memory (optional)
            del model
            clear_cache()
        print('\n')
    print('----------------------------------------------------------------------------------------------------------------------')
    print('\n')

-------------------------- Fusion Method: intermediate --------------------------




---------------- CNN Base Model: InceptionV3 ----------------
Training InceptionV3 with intermediate fusion and CSV extractor simple


# Display Results

In [None]:
results_base_dir = "/Users/izzymohamed/Desktop/Vision For Social Good/Project/Vision-For-Social-Good/RESULTS/Multimodal"
results_folder = os.path.join(results_base_dir, 'T1')
os.makedirs(results_folder, exist_ok=True)

In [None]:
# Function to save figures
def save_figure(fig, filename):
    fig.savefig(os.path.join(results_folder, filename))
    plt.close(fig)

### Accuracy Comparision

In [None]:
for model_name, model_info in results.items():
    print(fusion_method)

In [None]:
# Plot comparison of accuracy for each model for each crop
def plot_accuracy_comparison(results):
    accuracies = [result['test_accuracy'] for result in results.values()]
    model_names = list(results.keys())

    fig = plt.figure(figsize=(20, 10))
    plt.bar(model_names, accuracies)
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Model')
    plt.show()
    save_figure(fig, 'all_fusion_accuracy_comparison.png')

    if fusion_method == 'late':
        fig = plt.figure(figsize=(20, 10))
        plt.bar(model_names, accuracies)
        plt.ylabel('Accuracy (%)')
        plt.xlabel('Model')
        plt.show()
        save_figure(fig, 'late_accuracy_comparison.png')
    elif fusion_method == 'intermediate':
        fig = plt.figure(figsize=(20, 10))
        plt.bar(model_names, accuracies)
        plt.ylabel('Accuracy (%)')
        plt.xlabel('Model')
        plt.show()
        save_figure(fig, 'intermediate_accuracy_comparison.png')


In [None]:
# Plot comparison of accuracy for each model for each crop
plot_accuracy_comparison(crop_results)

### Metrics Table

In [None]:
# Function to display F1, precision, and recall of all models as a table
def display_model_metrics_table(results, test_loader_inception, test_loader_others):
        
    metrics_data = []
    
    for model_name, model_info in results.items():

        if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3']:  # Adjust model names as needed
            test_loader = test_loader_inception
        else:
            test_loader = test_loader_others

        model = model_info['model']
        device = next(model.parameters()).device  # Get the device of the model
        model.eval()  # Set the model to evaluation mode

        all_labels = []
        all_predicted = []

        for images, csv_features, labels in test_loader:
            images, csv_features, labels = images.to(device), csv_features.to(device), labels.to(device)

            with torch.no_grad():
                outputs = model(images, csv_features)
                _, predicted = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_predicted.extend(predicted.cpu().numpy())

        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predicted, average='macro')
        
        metrics_data.append({
            'Model': model_name,
            'Precision': precision,
            'Recall': recall,
            'F1-score': f1
        })

    metrics_df = pd.DataFrame(metrics_data)
    display(metrics_df)  # Display the DataFrame in Jupyter Notebook
    metrics_df.to_csv(os.path.join(results_folder, 'model_metrics.csv'), index=False)

In [None]:
# Display the table of metrics for all models
display_model_metrics_table(crop_results, test_loader_inception, test_loader_others)

### Classification Results

In [None]:
# Display some correctly and incorrectly classified images
def display_classification_results(model, test_loader, num_images=5):
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode
    class_labels = list(test_loader.dataset.class_to_idx.keys())
    
    images, labels = next(iter(test_loader))
    images, labels = images[:num_images].to(device), labels[:num_images]  # Move tensors to the model's device
    
    with torch.no_grad():
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
    
    fig, axes = plt.subplots(1, num_images, figsize=(20, 8))
    # fig.suptitle(f'{model_name} - Classification Results', fontsize=28)
    
    for i in range(num_images):
        ax = axes[i]
        img = images[i].cpu().numpy().transpose((1, 2, 0))  # Move tensor back to CPU for visualization
        img = np.clip(img, 0, 1)
        ax.imshow(img)
        ax.set_title(f'True: {class_labels[labels[i]]}\n Pred: {class_labels[predicted[i].cpu()]}')  # Access CPU tensor for labels
        ax.axis('off')

    plt.show()
    save_figure(fig, f'{model_name}_classification_results.png')

In [None]:
# Display results for each crop
for model_name in results.keys():
    if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3']:  # Adjust model names as needed
        test_loader = test_loader_inception
    else:
        test_loader = test_loader_others
    
    print(f'Displaying results for {model_name}')
    display_classification_results(results[model_name]['model'], test_loader)

### Classification Report

In [None]:
# Function to display the classification report of a given model
def display_classification_report(model, test_loader, model_name):
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode

    all_labels = []
    all_predicted = []

    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

        all_labels.extend(labels.cpu().numpy())
        all_predicted.extend(predicted.cpu().numpy())

    report = classification_report(all_labels, all_predicted, target_names=list(test_loader.dataset.class_to_idx.keys()))
    
    print(report)
    
    report_filename = os.path.join(results_folder, f'{model_name}_classification_report.txt')
    with open(report_filename, 'w') as f:
        f.write(report)
        

In [None]:
# Display results for each crop
for model_name in crop_results.keys():
    if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3']:  # Adjust model names as needed
        test_loader = test_loader_inception
    else:
        test_loader = test_loader_others
        
    print(f'Displaying classification report for {model_name}')
    display_classification_report([model_name]['model'], test_loader, model_name)

### Confusion Metrics

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def plot_confusion_matrix(labels, pred_labels, classes, model_name):
    fig = plt.figure(figsize=(50, 50))
    # fig.suptitle(f'{model_name} - Confusion Matrix\n', fontsize=28, y=0.83)
    ax = fig.add_subplot(1, 1, 1)
    cm = confusion_matrix(labels, pred_labels)
    cm_display = ConfusionMatrixDisplay(cm, display_labels=classes)
    cm_display.plot(values_format='d', cmap='Blues', ax=ax)
    fig.delaxes(fig.axes[1])  # Delete colorbar
    plt.xticks(rotation=90)
    plt.xlabel('Predicted Label', fontsize=50)
    plt.ylabel('True Label', fontsize=50)

    plt.show()
    save_figure(fig, f'{model_name}_confusion_matrix.png')

In [None]:
# Function to extract all labels and predictions
def get_all_labels_and_preds(model, test_loader):
    all_labels = []
    all_preds = []
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    return all_labels, all_preds

In [None]:
# Generate and plot confusion matrices
def generate_confusion_matrices(results, test_loader_inception, test_loader_others):

    classes = list(test_loader.dataset.class_to_idx.keys())
    for model_name, model_info in results.items():
        if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3']:  # Adjust model names as needed
            test_loader = test_loader_inception
        else:
            test_loader = test_loader_others

        model = model_info['model']
        labels, pred_labels = get_all_labels_and_preds(model, test_loader)
        plot_confusion_matrix(labels, pred_labels, classes, model_name)

In [None]:
generate_confusion_matrices(crop_results, test_loader_inception, test_loader_others)

### Incorrect Predictions

In [None]:
# Function to normalize images
def normalize_image(image):
    image = image - image.min()
    image = image / image.max()
    return image

In [None]:
# Function to plot the most incorrect predictions
def plot_most_incorrect(incorrect, classes, n_images, model_name, normalize=True):
    rows = int(np.ceil(np.sqrt(n_images)))
    cols = int(np.ceil(n_images / rows))

    fig = plt.figure(figsize=(25, 20))
    # fig.suptitle(f'{model_name} - Most Incorrect\n', fontsize=28)

    for i in range(rows * cols):
        if i >= len(incorrect):
            break
        ax = fig.add_subplot(rows, cols, i + 1)
        image, true_label, probs = incorrect[i]
        image = image.permute(1, 2, 0)
        true_prob = probs[true_label]
        incorrect_prob, incorrect_label = torch.max(probs, dim=0)
        true_class = classes[true_label]
        incorrect_class = classes[incorrect_label]

        if normalize:
            image = normalize_image(image)

        ax.imshow(image.cpu().numpy())
        ax.set_title(f'true label:\n{true_class} ({true_prob:.3f})\n'
                     f'pred label:\n{incorrect_class} ({incorrect_prob:.3f})', fontsize=10)
        ax.axis('off')

    plt.tight_layout()
    fig.subplots_adjust(hspace=0.7)
    
    plt.show()
    save_figure(fig, f'{model_name}_most_incorrect.png')


In [None]:
def get_all_details(model, test_loader):
    all_labels = []
    all_preds = []
    all_probs = []
    all_images = []
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            probs = F.softmax(outputs, dim=1)

            all_images.extend(images.cpu())
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu())

    return all_images, all_labels, all_preds, all_probs


In [None]:
# Define the number of images to display
N_IMAGES = 36

In [None]:
# Use this function to get the details
def plot_most_incorrect_predictions(results, test_loader_inception, test_loader_others, n_images=36):
    classes = list(test_loader.dataset.class_to_idx.keys())
    for model_name, model_info in results.items():
        if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3']:  # Adjust model names as needed
            test_loader = test_loader_inception
        else:
            test_loader = test_loader_others

        model = model_info['model']
        images, labels, pred_labels, probs = get_all_details(model, test_loader)
        corrects = torch.eq(torch.tensor(labels), torch.tensor(pred_labels))
        incorrect_examples = []

        for image, label, prob, correct in zip(images, labels, probs, corrects):
            if not correct:
                incorrect_examples.append((image, label, prob))

    incorrect_examples.sort(key=lambda x: torch.max(x[2], dim=0)[0], reverse=True)
    plot_most_incorrect(incorrect_examples[:n_images], classes, n_images, model_name)

In [None]:
plot_most_incorrect_predictions(crop_results, test_loader_inception, test_loader_others, N_IMAGES)

### Representations and Dimensionality Reduction

In [None]:
from sklearn import decomposition, manifold

def get_representations(model, iterator):
    model.eval()
    outputs = []
    labels = []

    with torch.no_grad():
        for x, y in iterator:
            x = x.to(device)
            y_pred = model(x)
            outputs.append(y_pred.cpu())
            labels.append(y)

    outputs = torch.cat(outputs, dim=0)
    labels = torch.cat(labels, dim=0)
    return outputs, labels

In [None]:
def get_pca(data, n_components=2):
    pca = decomposition.PCA(n_components=n_components)
    pca_data = pca.fit_transform(data)
    return pca_data

In [None]:
def plot_representations(data, labels, classes, n_images=None):
    if n_images is not None:
        data = data[:n_images]
        labels = labels[:n_images]

    fig = plt.figure(figsize=(15, 15))
    # fig.suptitle(f'{model_name} - PCA', fontsize=28, y=0.95)
    ax = fig.add_subplot(111)
    scatter = ax.scatter(data[:, 0], data[:, 1], c=labels, cmap='hsv')
    plt.show()
    save_figure(fig, f'{model_name}_pca.png')

In [None]:
outputs, labels = get_representations(model, train_loader)
for model_name in crop_results.keys():
    output_pca_data = get_pca(outputs)
    plot_representations(output_pca_data, labels, classes)  # Adjusted to pass only three arguments

In [None]:
def get_tsne(data, n_components=2, n_images=None):
    if n_images is not None:
        data = data[:n_images]
    tsne = manifold.TSNE(n_components=n_components, random_state=0)
    tsne_data = tsne.fit_transform(data)
    return tsne_data

In [None]:
for model_name in crop_results.keys():
    output_tsne_data = get_tsne(outputs)
    plot_representations(output_tsne_data, labels, classes)

### Filter Visualization

In [None]:
# Function to plot filtered images
def plot_filtered_images(images, filters, model_name, n_filters=None, normalize=True):
    images = torch.cat([i.unsqueeze(0) for i in images], dim=0).cpu()
    filters = filters.cpu()

    if n_filters is not None:
        filters = filters[:n_filters]

    n_images = images.shape[0]
    n_filters = filters.shape[0]

    filtered_images = F.conv2d(images, filters)

    fig = plt.figure(figsize=(30, 30))
    # fig.suptitle(f'{model_name} - Filtered Images', fontsize=28, y=0.8)

    for i in range(n_images):
        image = images[i]
        if normalize:
            image = normalize_image(image)
        ax = fig.add_subplot(n_images, n_filters + 1, i + 1 + (i * n_filters))
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title('Original')
        ax.axis('off')

        for j in range(n_filters):
            image = filtered_images[i][j]
            if normalize:
                image = normalize_image(image)
            ax = fig.add_subplot(n_images, n_filters + 1, i + 1 + (i * n_filters) + j + 1)
            ax.imshow(image.numpy(), cmap='bone')
            ax.set_title(f'Filter {j + 1}')
            ax.axis('off')

    fig.subplots_adjust(hspace=-0.7)
    plt.show()
    save_figure(fig, f'{model_name}_filtered_images.png')

In [None]:
N_FILTERS = 7

In [None]:
# Example usage within the existing loop
conv_models = ['ResNet152', 'VGG19', 'InceptionV3', 'AttentionAugmentedInceptionV3']  # Add models expected to have conv layers
for model_name, model_info in crop_results.items():
    model = model_info['model']
    if model_name in conv_models:
        if hasattr(model, 'conv1'):
            filters = model.conv1.weight.data
        elif hasattr(model, 'features') and hasattr(model.features, '0'):
            filters = model.features[0].weight.data
        else:
            print(f"Model {model_name} structure is not recognized for convolutional layers.")
            filters = None
    else:
        filters = None  # No convolutional filters in models like ViT

    if filters is not None:
        images = [image for image, label in [train_dataset_others[i] for i in range(N_IMAGES)]]
        plot_filtered_images(images, filters, model_name, n_filters=N_FILTERS)

### Filter Plotting

In [None]:
def plot_filters(filters, normalize=True):
    filters = filters.cpu()
    n_filters = filters.shape[0]
    rows = int(np.sqrt(n_filters))
    cols = int(np.sqrt(n_filters))

    fig = plt.figure(figsize=(30, 15))
    # fig.suptitle(f'{model_name} - Filters', fontsize=28, y=0.95)

    for i in range(rows * cols):
        image = filters[i]
        if normalize:
            image = normalize_image(image)
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(image.permute(1, 2, 0))
        ax.axis('off')

    fig.subplots_adjust(wspace=-0.9)
    plt.show()
    save_figure(fig, f'{model_name}_filters.png')

In [None]:
# Example usage within the existing loop
conv_models = ['ResNet152', 'VGG19', 'InceptionV3', 'AttentionAugmentedInceptionV3']  # Add models expected to have conv layers
for model_name, model_info in crop_results.items():
    model = model_info['model']
    if model_name in conv_models:
        if hasattr(model, 'conv1'):
            filters = model.conv1.weight.data
        elif hasattr(model, 'features') and hasattr(model.features, '0'):
            filters = model.features[0].weight.data
        else:
            print(f"Model {model_name} structure is not recognized for convolutional layers.")
            filters = None
    else:
        filters = None  # No convolutional filters in models like ViT

    if filters is not None:
        images = [image for image, label in [train_dataset_others[i] for i in range(N_IMAGES)]]
        plot_filters(filters)

### Generate all results

In [None]:
# Function to generate all results
def generate_all_results(results, test_loader):
    plot_accuracy_comparison(results)
    display_model_metrics_table(results, test_loader)
    generate_confusion_matrices(results, test_loader)
    plot_most_incorrect_predictions(results, test_loader, n_images=36)

    for model_name, model_info in results.items():
        model = model_info['model']
        display_classification_results(model, test_loader, model_name)
        display_classification_report(model, test_loader, model_name)