# Imports

In [42]:
# Imports
import json
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, classification_report

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets.folder import is_image_file

from vit_pytorch import ViT

import mlflow
import mlflow.pytorch

from AACN_Model import attention_augmented_resnet18, attention_augmented_inceptionv3, attention_augmented_vgg

In [43]:
# Assuming `crops` and directories (`train_dir`, `val_dir`, `test_dir`) are defined
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data Preprocessing

In [44]:
date_list = ['04_11_21',
'14_09_21',
'14_09_22',
'15_07_22',
'25_05_22',
'27_07_21']

In [45]:
# Define main directories
base_dir = '/Users/izzymohamed/Desktop/Vision For Social Good/Project/Vision-For-Social-Good/DATA/Peach/'
date = date_list[0]
date_dir = os.path.join(base_dir, date)

In [46]:
# Define directories for the images
uav_dir = os.path.join(date_dir, "Aerial_UAV_Photos")
rgb_dir = os.path.join(date_dir, "Ground_RGB_Photos")
multispectral_dir = os.path.join(date_dir, "Ground_Multispectral_Photos")

In [47]:
# Load CSV data
multimodal_data_path = os.path.join(base_dir, "combined_multimodal_data.csv")
multimodal_df = pd.read_csv(multimodal_data_path)

In [48]:
multimodal_df

Unnamed: 0,Date,Tree_ID,Orchard_Mapping_Image,Aerial_UAV_Image,Ground_RGB_Image,Ground_RGB_Image_with_Bounding_Boxes,Ground_RGB_Image_Annotations,Multispectral_RGB_Image,Multispectral_REG_Image,Multispectral_RED_Image,...,Multispectral_NDRE,Multispectral_SAVI,Multispectral_GNDVI,Multispectral_RVI,Multispectral_TVI,Multispectral_NDVI_Image,Multispectral_GNDVI_Image,Multispectral_NDRE_Image,Multispectral_SAVI_Image,date
0,04_11_21,29-1,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,...,0.054946,0.578614,0.308877,798.861030,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,04_11_21
1,04_11_21,29-2,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,...,0.098110,0.661596,0.420410,815.203405,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,04_11_21
2,04_11_21,29-4,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,...,0.085527,0.585869,0.422484,361.728795,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,04_11_21
3,04_11_21,29-3,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,...,0.059559,0.602647,0.414713,442.763371,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,04_11_21
4,04_11_21,28-10,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,...,0.075293,0.628999,0.412634,789.667606,,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,04_11_21
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5323,27_07_21,30-17,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,,,,,...,,,,,,,,,,27_07_21
5324,27_07_21,30-18,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,,,,,...,,,,,,,,,,27_07_21
5325,27_07_21,30-19,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,,,,,...,,,,,,,,,,27_07_21
5326,27_07_21,30-20,/Users/izzymohamed/Desktop/Vision For Social G...,/Users/izzymohamed/Desktop/Vision For Social G...,,,,,,,...,,,,,,,,,,27_07_21


In [49]:
multimodal_df.columns

Index(['Date', 'Tree_ID', 'Orchard_Mapping_Image', 'Aerial_UAV_Image',
       'Ground_RGB_Image', 'Ground_RGB_Image_with_Bounding_Boxes',
       'Ground_RGB_Image_Annotations', 'Multispectral_RGB_Image',
       'Multispectral_REG_Image', 'Multispectral_RED_Image',
       'Multispectral_NIR_Image', 'Multispectral_GRE_Image',
       'Multispectral_RGB_Bounding_Box_Image',
       'Multispectral_RGB_Bounding_Box_Annotation', 'Label', 'UAV_NDVI',
       'UAV_EVI', 'UAV_NDRE', 'UAV_SAVI', 'UAV_GNDVI', 'UAV_RVI', 'UAV_TVI',
       'x1', 'x2', 'y1', 'y2', 'Multispectral_NDVI', 'Multispectral_EVI',
       'Multispectral_NDRE', 'Multispectral_SAVI', 'Multispectral_GNDVI',
       'Multispectral_RVI', 'Multispectral_TVI', 'Multispectral_NDVI_Image',
       'Multispectral_GNDVI_Image', 'Multispectral_NDRE_Image',
       'Multispectral_SAVI_Image', 'date'],
      dtype='object')

In [50]:
# Image columns
image_columns = [
    'Ground_RGB_Image_with_Bounding_Boxes','Ground_RGB_Image_Annotations', 'Multispectral_NDVI_Image','Multispectral_RGB_Bounding_Box_Image', 'Multispectral_RGB_Bounding_Box_Annotation',
    # 'Multispectral_RGB_Image', 'Multispectral_REG_Image', 'Multispectral_RED_Image', 'Multispectral_NIR_Image', 'Multispectral_GRE_Image',
    # 'Multispectral_NDVI_Image', 'Multispectral_GNDVI_Image', 'Multispectral_NDRE_Image', 'Multispectral_SAVI_Image'
]

In [51]:
# Feature columns
feature_columns = [
    # 'UAV_NDVI', 'UAV_EVI', 'UAV_NDRE', 'UAV_SAVI', 'UAV_GNDVI', 'UAV_RVI', 'UAV_TVI', 'Multispectral_SAVI',
    'Multispectral_NDVI', 'Multispectral_GNDVI', 'Multispectral_SAVI',
    # 'Multispectral_RVI' 'Multispectral_EVI', 'Multispectral_NDRE',
]
csv_features = multimodal_df[feature_columns].values.astype(np.float32)

In [52]:
# Remove rows with label = 0
multimodal_df = multimodal_df[multimodal_df['Label'] != 0]

In [53]:
# Clean the dataset to ensure all image paths are strings
def is_valid_image_path(path):
    return isinstance(path, str) and path.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif'))

In [54]:
# Function to check if all image columns contain valid strings
def valid_image_paths(row):
    return all(isinstance(row[col], str) for col in image_columns)

In [55]:
# Apply the filter
valid_rows = multimodal_df.apply(valid_image_paths, axis=1)
multimodal_df = multimodal_df[valid_rows]
multimodal_df.reset_index(drop=True, inplace=True)

In [56]:
# Identify and exclude non-numeric columns from the features
numeric_columns = multimodal_df.select_dtypes(include=[np.number]).columns
csv_image_paths = multimodal_df[image_columns].values
csv_labels = multimodal_df['Label'].values

In [57]:
# Remove .DS_Store files if necessary (if working in a directory with such files)
def remove_ds_store(directory):
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file == '.DS_Store' or '.DS_Store' in file:
                file_path = os.path.join(root, file)
                print(f"Removing {file_path}")
                os.remove(file_path)

In [58]:
# Remove .DS_Store files from base directory
remove_ds_store(base_dir)

In [59]:
# Split data into train, validation, and test sets directly from the CSV
train_df, test_df = train_test_split(multimodal_df, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.25, random_state=42)  # 0.25 x 0.8 = 0.2

In [60]:
# Use the lists of file paths for your dataset loading and transformations
print(f"Train files: {len(train_df)}")
print(f"Validation files: {len(val_df)}")
print(f"Test files: {len(test_df)}")

Train files: 96
Validation files: 32
Test files: 32


In [61]:
# Define the standard image sizes
inception_size = 299
other_size = 224

In [62]:
# Update the data transformations
data_transforms = {
    'InceptionV3': {
        'train': transforms.Compose([
            transforms.Resize((inception_size, inception_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((inception_size, inception_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize((inception_size, inception_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    },
    'Others': {
        'train': transforms.Compose([
            transforms.Resize((other_size, other_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((other_size, other_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize((other_size, other_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
}

In [63]:
class CustomMultimodalDataset(Dataset):
    def __init__(self, df, image_columns, feature_columns, class_to_idx, transform=None):
        self.df = df
        self.image_columns = image_columns
        self.feature_columns = feature_columns
        self.class_to_idx = class_to_idx
        self.transform = transform
        
        # Define the mapping from integers to class names
        self.label_to_class_name = {
            0: 'Healthy',
            1: 'Grapholita molesta',
            2: 'Anarsia lineatella',
            3: 'Dead Tree'
        }

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        images = []
        for img_col in self.image_columns:
            img_path = row[img_col]
            
            # Ensure the path is a string and points to a valid image file
            if isinstance(img_path, str) and img_path.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
                try:
                    image = Image.open(img_path).convert('RGB')
                    
                    # Handle bounding boxes for specific image columns
                    if img_col == 'Ground_RGB_Image_with_Bounding_Boxes':
                        annotation_path = row['Ground_RGB_Image_Annotations']
                        if os.path.exists(annotation_path):
                            try:
                                with open(annotation_path, 'r') as f:
                                    bbox_data = json.load(f)
                                
                                # Apply each bounding box found in the annotation
                                for region in bbox_data.get('regions', []):
                                    x1, y1 = min(region['points']['x']), min(region['points']['y'])
                                    x2, y2 = max(region['points']['x']), max(region['points']['y'])
                                    cropped_image = image.crop((x1, y1, x2, y2))
                                    
                                    if self.transform:
                                        cropped_image = self.transform(cropped_image)
                                    images.append(cropped_image)
                                    
                            except json.JSONDecodeError:
                                print(f"Skipping invalid JSON file for {img_col} at index {idx}: {annotation_path}")
                        else:
                            print(f"Annotation file not found for {img_col} at index {idx}: {annotation_path}")
                    
                    elif img_col == 'Multispectral_RGB_Bounding_Box_Image':
                        annotation_path = row['Multispectral_RGB_Bounding_Box_Annotation']
                        if os.path.exists(annotation_path):
                            try:
                                with open(annotation_path, 'r') as f:
                                    bbox_data = json.load(f)
                                
                                # Apply each bounding box found in the annotation
                                for region in bbox_data.get('regions', []):
                                    x1, y1 = min(region['points']['x']), min(region['points']['y'])
                                    x2, y2 = max(region['points']['x']), max(region['points']['y'])
                                    cropped_image = image.crop((x1, y1, x2, y2))
                                    
                                    if self.transform:
                                        cropped_image = self.transform(cropped_image)
                                    images.append(cropped_image)
                                    
                            except json.JSONDecodeError:
                                print(f"Skipping invalid JSON file for {img_col} at index {idx}: {annotation_path}")
                        else:
                            print(f"Annotation file not found for {img_col} at index {idx}: {annotation_path}")

                    # Apply transformation to the image if no bounding boxes are required
                    elif self.transform:
                        image = self.transform(image)
                        images.append(image)

                except Exception as e:
                    raise ValueError(f"Could not open image at path {img_path} at index {idx}: {e}")
            else:
                continue  # Skip non-image files

        if not images:
            raise ValueError(f"No valid images found for index {idx}")

        # Stack images into a tensor and average along the channel dimension
        images = torch.stack(images).mean(dim=0)  # Average the images along the channel dimension to maintain 3 channels

        # Extract features from the DataFrame
        csv_row = row[self.feature_columns].values.astype(np.float32)
        
        # Convert the integer label to its corresponding class name
        label_int = row['Label']
        label_name = self.label_to_class_name.get(label_int, None)
        if label_name is None or label_name not in self.class_to_idx:
            raise KeyError(f"Label '{label_int}' not found in label_to_class_name mapping or class_to_idx dictionary.")
        label = self.class_to_idx[label_name]

        return images, torch.tensor(csv_row, dtype=torch.float32), label

In [64]:
# Create a mapping from class names to indices
classes = ['Healthy', 'Grapholita molesta', 'Anarsia lineatella', 'Dead Tree']
class_to_idx = {
    'Healthy': 0,
    'Grapholita molesta': 1,
    'Anarsia lineatella': 2,
    'Dead Tree': 3
}

# Assuming the Label column contains integers (0, 1, 2, 3)
label_to_class = {0: 'Healthy', 1: 'Grapholita molesta', 2: 'Anarsia lineatella', 3: 'Dead Tree'}

In [65]:

# Initialize the datasets
train_dataset_inception = CustomMultimodalDataset(train_df, image_columns, feature_columns, class_to_idx, transform=data_transforms['InceptionV3']['train'])
val_dataset_inception = CustomMultimodalDataset(val_df, image_columns, feature_columns, class_to_idx, transform=data_transforms['InceptionV3']['val'])
test_dataset_inception = CustomMultimodalDataset(test_df, image_columns, feature_columns, class_to_idx, transform=data_transforms['InceptionV3']['test'])

train_loader_inception = DataLoader(train_dataset_inception, batch_size=16, shuffle=True)
val_loader_inception = DataLoader(val_dataset_inception, batch_size=16, shuffle=True)
test_loader_inception = DataLoader(test_dataset_inception, batch_size=16, shuffle=False)

In [66]:
# Similarly for other models
train_dataset_others = CustomMultimodalDataset(train_df, image_columns, feature_columns, class_to_idx, transform=data_transforms['Others']['train'])
val_dataset_others = CustomMultimodalDataset(val_df, image_columns, feature_columns, class_to_idx, transform=data_transforms['Others']['val'])
test_dataset_others = CustomMultimodalDataset(test_df, image_columns, feature_columns, class_to_idx, transform=data_transforms['Others']['test'])

train_loader_others = DataLoader(train_dataset_others, batch_size=16, shuffle=True)
val_loader_others = DataLoader(val_dataset_others, batch_size=16, shuffle=True)
test_loader_others = DataLoader(test_dataset_others, batch_size=16, shuffle=False)

# Modalities Fusion

### Tabular Training Models

In [67]:
# Define the Deep MLP CSV feature extractor
class MLPCSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers=10):
        super(MLPCSVFeatureExtractor, self).__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        self.extractor = nn.Sequential(*layers)

    def forward(self, x):
        if x.dim() == 1:  # Ensure x is at least 2D
            x = x.unsqueeze(0)
        if x.size(1) != self.extractor[0].in_features:
            raise ValueError(f"Expected input with {self.extractor[0].in_features} features, but got {x.size(1)}")
        return self.extractor(x)

In [68]:
# Define the Convolutional CSV feature extractor
class ConvCSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim, seq_len, num_layers=2):
        super(ConvCSVFeatureExtractor, self).__init__()
        self.num_layers = num_layers

        # Define the convolutional layers
        layers = [nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.append(nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*layers)

        # Calculate the size of the output after the convolutions
        with torch.no_grad():
            dummy_input = torch.zeros(1, seq_len).unsqueeze(1)  # (batch_size=1, channels=1, seq_len)
            dummy_output = self.conv_layers(dummy_input)
            conv_output_size = dummy_output.view(1, -1).size(1)  # Flatten the output

        # Initialize the fully connected layer with the correct input size
        self.fc = nn.Linear(conv_output_size, hidden_dim)

    def forward(self, x):
        x = x.unsqueeze(1)  # Adding channel dimension for Conv1D
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        # print(f'Shape before FC layer: {x.shape}')  # Print the shape for debugging
        x = self.fc(x)
        return x

### Fusion Model

In [69]:
# Define the fusion methods for combining image and CSV features
class FusionModel(nn.Module):
    def __init__(self, model_name, image_feature_extractor, csv_input_dim, csv_hidden_dim, num_classes, fusion_method, csv_model_type='medium', seq_len=None):
        super(FusionModel, self).__init__()
        self.model_name = model_name
        self.image_feature_extractor = image_feature_extractor
        self.csv_input_dim = csv_input_dim
        self.csv_hidden_dim = csv_hidden_dim
        self.num_classes = num_classes
        self.fusion_method = fusion_method

        self.image_feature_extractor, self.feature_size = self.initialize_image_feature_extractor(model_name, image_feature_extractor)

        if csv_model_type == 'small':
            self.csv_feature_extractor = MLPCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim, num_layers=2)
        elif csv_model_type == 'medium':
            self.csv_feature_extractor = MLPCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim, num_layers=50)
        elif csv_model_type == 'deep':
            self.csv_feature_extractor = MLPCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim, num_layers=100)
        elif csv_model_type == 'conv_small':
            if seq_len is None:
                raise ValueError("seq_len must be provided for the convolutional model")
            self.csv_feature_extractor = ConvCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim, seq_len, num_layers=2)
        elif csv_model_type == 'conv_medium':
            if seq_len is None:
                raise ValueError("seq_len must be provided for the convolutional model")
            self.csv_feature_extractor = ConvCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim, seq_len, num_layers=50)
        elif csv_model_type == 'conv_deep':
            if seq_len is None:
                raise ValueError("seq_len must be provided for the convolutional model")
            self.csv_feature_extractor = ConvCSVFeatureExtractor(self.csv_input_dim, self.csv_hidden_dim, seq_len, num_layers=100)
        else:
            raise ValueError("Unsupported csv_model_type")

        if self.fusion_method == 'late':
            self.fusion_layer = nn.Linear(self.feature_size + self.csv_hidden_dim, self.num_classes)
        elif self.fusion_method == 'intermediate':
            self.intermediate_layer = nn.Linear(self.feature_size, 512)
            self.fusion_layer = nn.Linear(512 + self.csv_hidden_dim, self.num_classes)
        elif self.fusion_method == 'early':
            early_fusion_size = self.feature_size + self.csv_hidden_dim
            self.early_fusion_layer = nn.Linear(early_fusion_size, self.feature_size)
            self.classifier = nn.Linear(self.feature_size, num_classes)
        else:
            raise ValueError("Unsupported fusion method")

    def initialize_image_feature_extractor(self, model_name, image_feature_extractor):
        if model_name == 'InceptionV3':
            feature_size = image_feature_extractor.fc.in_features
            image_feature_extractor.aux_logits = False
            image_feature_extractor.AuxLogits = None
            image_feature_extractor.fc = nn.Identity()
        elif model_name == 'ResNet152' or model_name == 'AttentionAugmentedResNet18':
            feature_size = image_feature_extractor.fc.in_features
            image_feature_extractor.fc = nn.Identity()
        elif model_name == 'VGG19' or model_name == 'AttentionAugmentedVGG19':
            feature_size = image_feature_extractor.classifier[6].in_features
            image_feature_extractor.classifier[6] = nn.Identity()
        elif model_name == 'ViT':
            if hasattr(image_feature_extractor, 'heads'):
                feature_size = image_feature_extractor.heads.head.in_features
                image_feature_extractor.heads.head = nn.Identity()
            elif hasattr(image_feature_extractor, 'classifier'):
                feature_size = image_feature_extractor.classifier.in_features
                image_feature_extractor.classifier = nn.Identity()
            elif hasattr(image_feature_extractor, 'head'):
                feature_size = image_feature_extractor.head.in_features
                image_feature_extractor.head = nn.Identity()
            else:
                for attr_name in dir(image_feature_extractor):
                    attr = getattr(image_feature_extractor, attr_name)
                    if isinstance(attr, nn.Linear):
                        feature_size = attr.in_features
                        setattr(image_feature_extractor, attr_name, nn.Identity())
                        break
                else:
                    raise ValueError(f"Unsupported ViT model structure for model: {model_name}")
        elif model_name == 'AttentionAugmentedInceptionV3':
            inception_model = image_feature_extractor.inception
            feature_size = inception_model.fc.in_features
            inception_model.aux_logits = False
            inception_model.AuxLogits = None
            inception_model.fc = nn.Identity()
        elif model_name == 'AttentionAugmentedVGG19':
            vgg_model = image_feature_extractor.features
            feature_size = vgg_model[-2].out_channels
            vgg_model[-1] = nn.Identity()
        elif model_name == 'AttentionAugmentedResNet18':
            resnet_model = image_feature_extractor
            feature_size = resnet_model.fc.in_features
            resnet_model.fc = nn.Identity()
        else:
            raise ValueError(f"Unsupported model type: {model_name}")

        return image_feature_extractor, feature_size

    def forward(self, img, csv):
        csv_features = self.csv_feature_extractor(csv)

        # Ensure csv_features has the correct shape
        if len(csv_features.shape) == 1:
            csv_features = csv_features.unsqueeze(0)

        if self.fusion_method == 'early':
            img = img.view(img.size(0), -1)
            img_csv_combined = torch.cat((img, csv_features), dim=1)
            img_csv_features = self.early_fusion_layer(img_csv_combined)
            output = self.classifier(img_csv_features)
        else:
            img_features = self.image_feature_extractor(img)
            if self.model_name == 'AttentionAugmentedInceptionV3':
                if isinstance(img_features, tuple):
                    img_features = img_features.logits
                    img_features = img_features.view(img_features.size(0), -1)

            if self.fusion_method == 'late':
                combined_features = torch.cat((img_features, csv_features), dim=1)
                output = self.fusion_layer(combined_features)
            elif self.fusion_method == 'intermediate':
                img_features = self.intermediate_layer(img_features)
                combined_features = torch.cat((img_features, csv_features), dim=1)
                output = self.fusion_layer(combined_features)
            else:
                raise ValueError("Unsupported fusion method")
        return output

# Model Training and Evaluation

In [70]:
# Clear cache function
def clear_cache():
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
    else:
        torch.cache.empty_cache()

In [71]:
# Function to adjust learning rate
def adjust_learning_rate(optimizer, epoch, learning_rate):
    lr = learning_rate * (0.1 ** (epoch // 10))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [72]:
# Function to train the model
def train_model(cnn_model_name, csv_model_name, fusion_mehod, model, criterion, optimizer, train_loader, val_loader, num_classes, csv_input_dim, device, num_epochs=40, initial_lr=0.001, save_path = ''):
    early_stopping_patience = 5
    best_val_loss = float('inf')
    patience_counter = 0

    model.to(device)

    # Start an MLflow run
    with mlflow.start_run(run_name=f"{cnn_model_name}_{csv_model_name}_{fusion_mehod}"):

        # Log parameters
        mlflow.log_param("model_name", cnn_model_name)
        mlflow.log_param("fusion_method", fusion_mehod)
        mlflow.log_param("csv_model_type", csv_model_name)
        mlflow.log_param("num_epochs", num_epochs)
        mlflow.log_param("learning_rate", initial_lr)

        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0
            for i, data in enumerate(train_loader):
                inputs_img, inputs_csv, labels = data
                inputs_img, inputs_csv, labels = inputs_img.to(device), inputs_csv.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs_img, inputs_csv)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            train_loss = running_loss / len(train_loader)
            train_accuracy = 100 * correct / total

            # Validation
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0
            with torch.no_grad():
                for data in val_loader:
                    inputs_img, inputs_csv, labels = data
                    inputs_img, inputs_csv, labels = inputs_img.to(device), inputs_csv.to(device), labels.to(device)
                    outputs = model(inputs_img, inputs_csv)
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            val_loss /= len(val_loader)
            val_accuracy = 100 * correct / total

            # Log metrics
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_accuracy", train_accuracy, step=epoch)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_accuracy", val_accuracy, step=epoch)

            print(f'Epoch {epoch + 1}/{num_epochs}, '
                  f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
                  f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

            # Save the best checkpoint
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                checkpoint = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_loss': best_val_loss,
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'train_accuracy': train_accuracy,
                    'val_accuracy': val_accuracy
                }
                torch.save(checkpoint, os.path.join(save_path, f'{cnn_model_name}_{csv_model_name}_{fusion_mehod}_model_checkpoint.pth'))
                print(f"Checkpoint saved to model_checkpoint.pth")
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    print("Early stopping due to no improvement in validation loss.")
                    break

        # Log the trained model
        mlflow.pytorch.log_model(model, f"{cnn_model_name}_{csv_model_name}_{fusion_mehod}_model")

    return model

In [73]:
# Function to create and train the model
def create_and_train_fusion_model(cnn_model_name, csv_model_name, fusion_mehod, model, train_loader, val_loader, num_classes, csv_input_dim, device, num_epochs=40, initial_lr=0.001, save_path=None):
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
    return train_model(cnn_model_name, csv_model_name, fusion_mehod, model, criterion, optimizer, train_loader, val_loader, num_classes, csv_input_dim, device, num_epochs=num_epochs, initial_lr=initial_lr, save_path=save_path)


In [74]:
# Function to evaluate the fusion model
def evaluate_fusion_model(model, test_loader, criterion, device, num_classes):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    all_confidences = [[] for _ in range(num_classes)]
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for data in test_loader:
            inputs_img, inputs_csv, labels = data
            inputs_img, inputs_csv, labels = inputs_img.to(device), inputs_csv.to(device), labels.to(device)
            outputs = model(inputs_img, inputs_csv)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            
            # Calculate confidence level
            softmax_outputs = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(softmax_outputs, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for i in range(len(labels)):
                class_label = labels[i].item()
                all_confidences[class_label].append(softmax_outputs[i, class_label].item())

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_loss /= len(test_loader)
    test_accuracy = 100 * correct / total

    # Calculate average confidence per class
    avg_confidences_per_class = [np.mean(confidences) if confidences else 0 for confidences in all_confidences]

    return test_loss, test_accuracy, avg_confidences_per_class, all_predictions, all_labels

In [75]:
# Train the Model
num_classes_inception = len(class_to_idx)
num_classes_others = len(class_to_idx)

In [76]:
# Set the number of features in the CSV data
num_heads = 8

In [77]:
# Define the csv input dimensions
csv_input_dim = csv_features.shape[1]

In [78]:
# Define the csv hidden dimensions
csv_hidden_dim = 256

In [79]:
# Define the results dictionary
crop_results = {}

In [80]:
seq_len = len(feature_columns)
seq_len

3

In [81]:
# Create a results folder
results_base_dir = "/Users/izzymohamed/Desktop/Vision For Social Good/Project/Vision-For-Social-Good/RESULTS/Multimodal"
results_folder = os.path.join(results_base_dir,'Peach', 'CNN+BB', 'T2+NoES+BS016')
os.makedirs(results_folder, exist_ok=True)

In [82]:
# Train and evaluate the models
with open(os.path.join(results_folder, 'train_val_test_results8.txt'), 'w') as f:
    for fusion_method in ['intermediate', 'late']:
        output_line = f'---------------------------------------------------- Fusion Method: {fusion_method} ----------------------------------------------------'
        print(output_line)
        f.write(output_line + '\n')

        for csv_model_type in ['small','medium','conv_small', 'conv_medium']:
            cnn_feature_extractors = {
                'InceptionV3': models.inception_v3(pretrained=True).to(device),
                'ResNet152': models.resnet152(pretrained=True).to(device),
                'VGG19': models.vgg19(pretrained=True).to(device),
                'ViT': ViT(
                    image_size=224,
                    patch_size=16,
                    num_classes=num_classes_others,
                    dim=256,
                    depth=6,
                    heads=24,
                    mlp_dim=2048,
                    dropout=0.1,
                    emb_dropout=0.1
                ).to(device),
                "AttentionAugmentedInceptionV3": attention_augmented_inceptionv3(attention=True).to(device),
                "AttentionAugmentedResNet18": attention_augmented_resnet18(num_classes=num_classes_others, attention=[False, True, True, True], num_heads=8).to(device),
            }

            if 'InceptionV3' in cnn_feature_extractors:
                cnn_feature_extractors['InceptionV3'].aux_logits = False

            output_line = f'-------------------------------- CSV Model Type: {csv_model_type} --------------------------------'
            print(output_line)
            f.write(output_line + '\n')

            for model_name, image_feature_extractor in cnn_feature_extractors.items():
                output_line = f'Training {model_name} with {fusion_method} fusion and CSV extractor {csv_model_type}'
                print(output_line)
                f.write(output_line + '\n')

                # Initialize and train the fusion model
                fusion_model = FusionModel(
                    model_name=model_name,
                    image_feature_extractor=image_feature_extractor,
                    csv_input_dim=csv_input_dim,
                    csv_hidden_dim=csv_hidden_dim,
                    num_classes=num_classes_others,
                    fusion_method=fusion_method,
                    csv_model_type=csv_model_type,
                    seq_len=seq_len
                ).to(device)

                checkpoint_path = os.path.join(results_folder, f'{model_name}_{csv_model_type}_{fusion_method}_model_checkpoint.pth')
                if os.path.exists(checkpoint_path):
                    checkpoint = torch.load(checkpoint_path)
                    fusion_model.load_state_dict(checkpoint['model_state_dict'])
                    trained_model = fusion_model
                    train_loss = checkpoint['train_loss']
                    # trained_model.eval()

                    # print(trained_model.eval())

                else:
                    trained_model = create_and_train_fusion_model(
                        model_name,
                        csv_model_type,
                        fusion_method,
                        fusion_model,
                        train_loader_others,
                        val_loader_others,
                        num_classes_others,
                        csv_input_dim,
                        device,
                        num_epochs=40,
                        initial_lr=0.001,
                        save_path=results_folder
                    )

                # Evaluate the loaded model
                test_loss, test_accuracy, all_confidences, all_predictions, all_labels = evaluate_fusion_model(
                    fusion_model, test_loader_others, nn.CrossEntropyLoss(), device, 4
                )

                # Store the results
                crop_results[f"{model_name}_{fusion_method}_{csv_model_type}"] = {
                    'model': trained_model,
                    'model_name': model_name,
                    'fusion_method': fusion_method,
                    'csv_model_type': csv_model_type,
                    'test_loss': test_loss,
                    'test_accuracy': test_accuracy
                }

                output_line = f'{model_name} with {fusion_method} fusion and CSV extractor {csv_model_type} Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%'
                print(output_line)
                f.write(output_line + '\n')

                # Clear GPU memory
                del fusion_model
                del trained_model
                torch.cuda.empty_cache()

                print('\n')
                f.write('\n')

            print('\n')
            f.write('\n')

        print('----------------------------------------------------------------------------------------------------------------------')
        f.write('----------------------------------------------------------------------------------------------------------------------\n')
        print('\n')
        f.write('\n')

---------------------------------------------------- Fusion Method: intermediate ----------------------------------------------------




-------------------------------- CSV Model Type: small --------------------------------
Training InceptionV3 with intermediate fusion and CSV extractor small
InceptionV3 with intermediate fusion and CSV extractor small Test Loss: 0.3684, Test Accuracy: 93.75%


Training ResNet152 with intermediate fusion and CSV extractor small
ResNet152 with intermediate fusion and CSV extractor small Test Loss: 0.6574, Test Accuracy: 93.75%


Training VGG19 with intermediate fusion and CSV extractor small
VGG19 with intermediate fusion and CSV extractor small Test Loss: 0.8472, Test Accuracy: 59.38%


Training ViT with intermediate fusion and CSV extractor small
ViT with intermediate fusion and CSV extractor small Test Loss: 0.3478, Test Accuracy: 90.62%


Training AttentionAugmentedInceptionV3 with intermediate fusion and CSV extractor small
AttentionAugmentedInceptionV3 with intermediate fusion and CSV extractor small Test Loss: 0.3300, Test Accuracy: 93.75%


Training AttentionAugmentedResNet18 wi



-------------------------------- CSV Model Type: medium --------------------------------
Training InceptionV3 with intermediate fusion and CSV extractor medium
InceptionV3 with intermediate fusion and CSV extractor medium Test Loss: 0.5515, Test Accuracy: 87.50%


Training ResNet152 with intermediate fusion and CSV extractor medium
ResNet152 with intermediate fusion and CSV extractor medium Test Loss: 0.2363, Test Accuracy: 93.75%


Training VGG19 with intermediate fusion and CSV extractor medium
VGG19 with intermediate fusion and CSV extractor medium Test Loss: 0.3769, Test Accuracy: 93.75%


Training ViT with intermediate fusion and CSV extractor medium
ViT with intermediate fusion and CSV extractor medium Test Loss: 0.2296, Test Accuracy: 96.88%


Training AttentionAugmentedInceptionV3 with intermediate fusion and CSV extractor medium
AttentionAugmentedInceptionV3 with intermediate fusion and CSV extractor medium Test Loss: 3.1540, Test Accuracy: 90.62%


Training AttentionAugmented



-------------------------------- CSV Model Type: conv_small --------------------------------
Training InceptionV3 with intermediate fusion and CSV extractor conv_small
InceptionV3 with intermediate fusion and CSV extractor conv_small Test Loss: 0.8972, Test Accuracy: 90.62%


Training ResNet152 with intermediate fusion and CSV extractor conv_small
ResNet152 with intermediate fusion and CSV extractor conv_small Test Loss: 0.4638, Test Accuracy: 93.75%


Training VGG19 with intermediate fusion and CSV extractor conv_small
VGG19 with intermediate fusion and CSV extractor conv_small Test Loss: 0.3126, Test Accuracy: 93.75%


Training ViT with intermediate fusion and CSV extractor conv_small
ViT with intermediate fusion and CSV extractor conv_small Test Loss: 0.3349, Test Accuracy: 90.62%


Training AttentionAugmentedInceptionV3 with intermediate fusion and CSV extractor conv_small
Epoch 1/40, Train Loss: 0.8037, Train Accuracy: 73.96%, Val Loss: 11.9470, Val Accuracy: 62.50%
Checkpoint sav



AttentionAugmentedInceptionV3 with intermediate fusion and CSV extractor conv_small Test Loss: 2.9762, Test Accuracy: 71.88%


Training AttentionAugmentedResNet18 with intermediate fusion and CSV extractor conv_small
Epoch 1/40, Train Loss: 1.0878, Train Accuracy: 72.92%, Val Loss: 1.8535, Val Accuracy: 96.88%
Checkpoint saved to model_checkpoint.pth
Epoch 2/40, Train Loss: 0.3592, Train Accuracy: 86.46%, Val Loss: 3.0153, Val Accuracy: 93.75%
Epoch 3/40, Train Loss: 0.3812, Train Accuracy: 80.21%, Val Loss: 2.3764, Val Accuracy: 93.75%
Epoch 4/40, Train Loss: 0.3962, Train Accuracy: 88.54%, Val Loss: 0.2689, Val Accuracy: 93.75%
Checkpoint saved to model_checkpoint.pth
Epoch 5/40, Train Loss: 0.5055, Train Accuracy: 79.17%, Val Loss: 0.1748, Val Accuracy: 96.88%
Checkpoint saved to model_checkpoint.pth
Epoch 6/40, Train Loss: 0.4148, Train Accuracy: 83.33%, Val Loss: 0.2032, Val Accuracy: 93.75%
Epoch 7/40, Train Loss: 0.2876, Train Accuracy: 89.58%, Val Loss: 0.2436, Val Accuracy: 93

Python(11272) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


AttentionAugmentedResNet18 with intermediate fusion and CSV extractor conv_small Test Loss: 0.3607, Test Accuracy: 90.62%








-------------------------------- CSV Model Type: conv_medium --------------------------------
Training InceptionV3 with intermediate fusion and CSV extractor conv_medium
Epoch 1/40, Train Loss: 0.6522, Train Accuracy: 75.00%, Val Loss: 2.0429, Val Accuracy: 71.88%
Checkpoint saved to model_checkpoint.pth
Epoch 2/40, Train Loss: 0.3304, Train Accuracy: 84.38%, Val Loss: 10.1900, Val Accuracy: 68.75%
Epoch 3/40, Train Loss: 0.2268, Train Accuracy: 93.75%, Val Loss: 0.1074, Val Accuracy: 96.88%
Checkpoint saved to model_checkpoint.pth
Epoch 4/40, Train Loss: 0.3006, Train Accuracy: 89.58%, Val Loss: 0.7810, Val Accuracy: 84.38%
Epoch 5/40, Train Loss: 0.1617, Train Accuracy: 93.75%, Val Loss: 0.0947, Val Accuracy: 96.88%
Checkpoint saved to model_checkpoint.pth
Epoch 6/40, Train Loss: 0.1914, Train Accuracy: 92.71%, Val Loss: 0.5900, Val Accuracy: 96.88%
Epoch 7/40, Train Loss: 0.6004, Train Accuracy: 80.21%, Val Loss: 10.7425, Val Accuracy: 65.62%
Epoch 8/40, Train Loss: 0.3087, Train Ac

Python(12986) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


InceptionV3 with intermediate fusion and CSV extractor conv_medium Test Loss: 2.8001, Test Accuracy: 90.62%


Training ResNet152 with intermediate fusion and CSV extractor conv_medium
Epoch 1/40, Train Loss: 0.8556, Train Accuracy: 68.75%, Val Loss: 1161.5923, Val Accuracy: 34.38%
Checkpoint saved to model_checkpoint.pth
Epoch 2/40, Train Loss: 0.3004, Train Accuracy: 86.46%, Val Loss: 906.2516, Val Accuracy: 68.75%
Checkpoint saved to model_checkpoint.pth
Epoch 3/40, Train Loss: 0.3851, Train Accuracy: 85.42%, Val Loss: 546.2539, Val Accuracy: 62.50%
Checkpoint saved to model_checkpoint.pth
Epoch 4/40, Train Loss: 0.2863, Train Accuracy: 91.67%, Val Loss: 109.1504, Val Accuracy: 6.25%
Checkpoint saved to model_checkpoint.pth
Epoch 5/40, Train Loss: 0.2136, Train Accuracy: 94.79%, Val Loss: 43.6733, Val Accuracy: 21.88%
Checkpoint saved to model_checkpoint.pth
Epoch 6/40, Train Loss: 0.2561, Train Accuracy: 90.62%, Val Loss: 0.8497, Val Accuracy: 81.25%
Checkpoint saved to model_checkp

Python(14752) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


ResNet152 with intermediate fusion and CSV extractor conv_medium Test Loss: 0.5376, Test Accuracy: 84.38%


Training VGG19 with intermediate fusion and CSV extractor conv_medium
Epoch 1/40, Train Loss: 11.8923, Train Accuracy: 36.46%, Val Loss: 1.2074, Val Accuracy: 62.50%
Checkpoint saved to model_checkpoint.pth
Epoch 2/40, Train Loss: 1.4736, Train Accuracy: 53.12%, Val Loss: 1.1217, Val Accuracy: 34.38%
Checkpoint saved to model_checkpoint.pth
Epoch 3/40, Train Loss: 1.2143, Train Accuracy: 52.08%, Val Loss: 1.0326, Val Accuracy: 18.75%
Checkpoint saved to model_checkpoint.pth
Epoch 4/40, Train Loss: 0.9942, Train Accuracy: 48.96%, Val Loss: 0.7788, Val Accuracy: 62.50%
Checkpoint saved to model_checkpoint.pth
Epoch 5/40, Train Loss: 1.1280, Train Accuracy: 54.17%, Val Loss: 1.0320, Val Accuracy: 34.38%
Epoch 6/40, Train Loss: 1.3283, Train Accuracy: 41.67%, Val Loss: 0.9429, Val Accuracy: 62.50%
Epoch 7/40, Train Loss: 1.0814, Train Accuracy: 54.17%, Val Loss: 0.7854, Val Accuracy:

Python(17286) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


VGG19 with intermediate fusion and CSV extractor conv_medium Test Loss: 0.5104, Test Accuracy: 93.75%


Training ViT with intermediate fusion and CSV extractor conv_medium
Epoch 1/40, Train Loss: 0.8411, Train Accuracy: 77.08%, Val Loss: 0.5027, Val Accuracy: 87.50%
Checkpoint saved to model_checkpoint.pth
Epoch 2/40, Train Loss: 0.5257, Train Accuracy: 89.58%, Val Loss: 0.6762, Val Accuracy: 84.38%
Epoch 3/40, Train Loss: 0.3557, Train Accuracy: 89.58%, Val Loss: 0.3377, Val Accuracy: 90.62%
Checkpoint saved to model_checkpoint.pth
Epoch 4/40, Train Loss: 0.3258, Train Accuracy: 90.62%, Val Loss: 0.4675, Val Accuracy: 87.50%
Epoch 5/40, Train Loss: 0.2319, Train Accuracy: 90.62%, Val Loss: 0.4168, Val Accuracy: 90.62%
Epoch 6/40, Train Loss: 0.2111, Train Accuracy: 91.67%, Val Loss: 0.4440, Val Accuracy: 90.62%
Epoch 7/40, Train Loss: 0.1874, Train Accuracy: 92.71%, Val Loss: 0.5999, Val Accuracy: 90.62%
Epoch 8/40, Train Loss: 0.1578, Train Accuracy: 93.75%, Val Loss: 0.8062, Val Acc

Python(18474) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


ViT with intermediate fusion and CSV extractor conv_medium Test Loss: 0.7882, Test Accuracy: 81.25%


Training AttentionAugmentedInceptionV3 with intermediate fusion and CSV extractor conv_medium
Epoch 1/40, Train Loss: 0.7852, Train Accuracy: 72.92%, Val Loss: 19.0833, Val Accuracy: 62.50%
Checkpoint saved to model_checkpoint.pth
Epoch 2/40, Train Loss: 0.6674, Train Accuracy: 78.12%, Val Loss: 9.6629, Val Accuracy: 62.50%
Checkpoint saved to model_checkpoint.pth
Epoch 3/40, Train Loss: 0.3009, Train Accuracy: 86.46%, Val Loss: 2.3552, Val Accuracy: 71.88%
Checkpoint saved to model_checkpoint.pth
Epoch 4/40, Train Loss: 0.1114, Train Accuracy: 93.75%, Val Loss: 0.3461, Val Accuracy: 90.62%
Checkpoint saved to model_checkpoint.pth
Epoch 5/40, Train Loss: 0.1176, Train Accuracy: 96.88%, Val Loss: 0.2105, Val Accuracy: 93.75%
Checkpoint saved to model_checkpoint.pth
Epoch 6/40, Train Loss: 0.0919, Train Accuracy: 96.88%, Val Loss: 0.3559, Val Accuracy: 93.75%
Epoch 7/40, Train Loss: 0.25

Python(19651) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


AttentionAugmentedInceptionV3 with intermediate fusion and CSV extractor conv_medium Test Loss: 1.3394, Test Accuracy: 71.88%


Training AttentionAugmentedResNet18 with intermediate fusion and CSV extractor conv_medium
Epoch 1/40, Train Loss: 0.9418, Train Accuracy: 63.54%, Val Loss: 0.6806, Val Accuracy: 96.88%
Checkpoint saved to model_checkpoint.pth
Epoch 2/40, Train Loss: 0.4529, Train Accuracy: 86.46%, Val Loss: 0.2910, Val Accuracy: 93.75%
Checkpoint saved to model_checkpoint.pth
Epoch 3/40, Train Loss: 0.3985, Train Accuracy: 84.38%, Val Loss: 0.4055, Val Accuracy: 90.62%
Epoch 4/40, Train Loss: 0.3231, Train Accuracy: 86.46%, Val Loss: 0.2497, Val Accuracy: 96.88%
Checkpoint saved to model_checkpoint.pth
Epoch 5/40, Train Loss: 0.3598, Train Accuracy: 89.58%, Val Loss: 0.1323, Val Accuracy: 96.88%
Checkpoint saved to model_checkpoint.pth
Epoch 6/40, Train Loss: 0.3079, Train Accuracy: 89.58%, Val Loss: 0.1461, Val Accuracy: 93.75%
Epoch 7/40, Train Loss: 0.2481, Train Accuracy:

Python(20671) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


AttentionAugmentedResNet18 with intermediate fusion and CSV extractor conv_medium Test Loss: 0.2379, Test Accuracy: 93.75%




----------------------------------------------------------------------------------------------------------------------


---------------------------------------------------- Fusion Method: late ----------------------------------------------------




-------------------------------- CSV Model Type: small --------------------------------
Training InceptionV3 with late fusion and CSV extractor small
Epoch 1/40, Train Loss: 0.5688, Train Accuracy: 80.21%, Val Loss: 4.6170, Val Accuracy: 37.50%
Checkpoint saved to model_checkpoint.pth


KeyboardInterrupt: 

# Display Results

In [None]:
# Function to save figures
def save_figure(fig, filename):
    fig.savefig(os.path.join(results_folder, filename))
    plt.close(fig)

In [None]:
# Plot comparison of accuracy for each model for each crop
def plot_accuracy_comparison(results):
    accuracies = [result['test_accuracy'] for result in results.values()]
    model_names = list(results.keys())

    fig = plt.figure(figsize=(20, 10))
    plt.bar(model_names, accuracies)
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Model')
    plt.xticks(rotation=90)  # Make x-axis labels diagonal
    plt.show()
    save_figure(fig, 'all_fusion_accuracy_comparison.png')

    for fusion_method in ['late', 'intermediate']:
        # Filter models based on the fusion method
        filtered_results = {model: details for model, details in results.items() if details['fusion_method'] == fusion_method}

        if filtered_results:  # Check if there are models with this fusion method
            accuracies = [details['test_accuracy'] for details in filtered_results.values()]
            model_names = list(filtered_results.keys())

            fig = plt.figure(figsize=(20, 10))
            plt.bar(model_names, accuracies)
            plt.ylabel('Accuracy (%)')
            plt.xlabel('Model')
            plt.xticks(rotation=90)  # Make x-axis labels diagonal
            plt.title(f'Accuracy Comparison for {fusion_method.capitalize()} Fusion')
            plt.show()
            save_figure(fig, f'{fusion_method}_accuracy_comparison.png')

In [None]:
# Plot comparison of accuracy for each model for each crop
plot_accuracy_comparison(crop_results)

In [None]:
# Function to display F1, precision, and recall of all models as a table
def display_model_metrics_table(results, test_loader_inception, test_loader_others):
    metrics_data = []

    for model_name, model_info in results.items():
        test_loader = test_loader_inception if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3'] else test_loader_others

        model = model_info['model']
        device = next(model.parameters()).device
        model.eval()

        all_labels = []
        all_predicted = []

        for images, csv_features, labels in test_loader:
            images, csv_features, labels = images.to(device), csv_features.to(device), labels.to(device)

            with torch.no_grad():
                outputs = model(images, csv_features)
                _, predicted = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_predicted.extend(predicted.cpu().numpy())

        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predicted, average='macro')

        metrics_data.append({
            'Model': model_name,
            'Precision': precision,
            'Recall': recall,
            'F1-score': f1
        })

    metrics_df = pd.DataFrame(metrics_data)
    display(metrics_df)
    metrics_df.to_csv(os.path.join(results_folder, 'model_metrics.csv'), index=False)

In [None]:
# Display the table of metrics for all models
display_model_metrics_table(crop_results, test_loader_inception, test_loader_others)

In [None]:
# Display some correctly and incorrectly classified images
def display_classification_results(model, test_loader, model_name, num_images=5):
    device = next(model.parameters()).device
    model.eval()
    class_labels = list(test_loader.dataset.class_to_idx.keys())

    # Get a batch of images, CSV inputs, and labels
    batch = next(iter(test_loader))
    images, csv_inputs, labels = batch

    images, csv_inputs, labels = images[:num_images].to(device), csv_inputs[:num_images].to(device), labels[:num_images]

    with torch.no_grad():
        outputs = model(images, csv_inputs)
        softmax_outputs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(softmax_outputs, 1)

    fig, axes = plt.subplots(1, num_images, figsize=(20, 8))

    for i in range(num_images):
        ax = axes[i]
        img = images[i].cpu().numpy().transpose((1, 2, 0))
        img = np.clip(img, 0, 1)
        ax.imshow(img)
        ax.set_title(f"True: {class_labels[labels[i]]}\nPred: {class_labels[predicted[i].cpu()]}\nConf: {confidence[i].cpu():.2f}")
        ax.axis('off')

    plt.show()
    save_figure(fig, f'{model_name}_classification_results_with_confidence.png')

In [None]:
# Display results for each crop
for model_name in crop_results.keys():
    if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3']:
        test_loader = test_loader_inception
    else:
        test_loader = test_loader_others
    
    print(f'Displaying results for {model_name}')
    display_classification_results(crop_results[model_name]['model'], test_loader, model_name)

In [None]:
# Function to display the classification report of a given model
def display_classification_report(model, test_loader, model_name):
    all_labels = []
    all_predicted = []
    device = next(model.parameters()).device

    for batch in test_loader:
        images, csv_features, labels = batch
        images, csv_features, labels = images.to(device), csv_features.to(device), labels.to(device)

        with torch.no_grad():
            outputs = model(images, csv_features)
            predicted = torch.argmax(outputs, dim=1)

        all_labels.extend(labels.cpu().numpy())
        all_predicted.extend(predicted.cpu().numpy())

    # Ensure target names are in the correct order
    target_names = ['Healthy', 'Grapholita molesta', 'Anarsia lineatella', 'Dead Tree']
    labels_list = [class_to_idx[name] for name in target_names]  # Get indices in the specified order

    print(f'Classification Report for {model_name}:')
    report = classification_report(all_labels, all_predicted, labels=labels_list, target_names=target_names)
    print(report)

    report_filename = os.path.join(results_folder, f'{model_name}_classification_report.txt')
    with open(report_filename, 'w') as f:
        f.write(report)

In [None]:
# Display results for each crop
for model_name in crop_results.keys():
    if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3']:
        test_loader = test_loader_inception
    else:
        test_loader = test_loader_others
        
    print(f'Displaying classification report for {model_name}')
    display_classification_report(crop_results[model_name]['model'], test_loader, model_name)

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Confusion Matrix Plotting
def plot_confusion_matrix(labels, pred_labels, classes, model_name):
    # Ensure the classes are in the correct order
    classes = ['Healthy', 'Grapholita molesta', 'Anarsia lineatella', 'Dead Tree']
    fig = plt.figure(figsize=(50, 50))
    ax = fig.add_subplot(1, 1, 1)
    cm = confusion_matrix(labels, pred_labels, labels=[class_to_idx[cls] for cls in classes])
    cm_display = ConfusionMatrixDisplay(cm, display_labels=classes)
    cm_display.plot(values_format='d', cmap='Blues', ax=ax)
    fig.delaxes(fig.axes[1])
    plt.xticks(rotation=90)
    plt.xlabel('Predicted Label', fontsize=50)
    plt.ylabel('True Label', fontsize=50)

    plt.show()
    save_figure(fig, f'{model_name}_confusion_matrix.png')

def get_all_labels_and_preds(model, test_loader):
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for data in test_loader:
            images, csv_features, labels = data
            images, csv_features, labels = images.to(device), csv_features.to(device), labels.to(device)

            outputs = model(images, csv_features)
            _, preds = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy().astype(int))
            all_preds.extend(preds.cpu().numpy().astype(int))

    return all_labels, all_preds

def generate_confusion_matrices(results, test_loader_inception, test_loader_others):
    for model_name, model_info in results.items():
        test_loader = test_loader_inception if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3'] else test_loader_others

        classes = ['Healthy', 'Grapholita molesta', 'Anarsia lineatella', 'Dead Tree']
        model = model_info['model']
        labels, pred_labels = get_all_labels_and_preds(model, test_loader)
        plot_confusion_matrix(labels, pred_labels, classes, model_name)

In [None]:
generate_confusion_matrices(crop_results, test_loader_inception, test_loader_others)

In [None]:
# Function to normalize images
def normalize_image(image):
    image = image - image.min()
    image = image / image.max()
    return image

In [None]:
# Function to plot the most incorrect predictions
def plot_most_incorrect(incorrect, classes, n_images, model_name, normalize=True):
    rows = int(np.ceil(np.sqrt(n_images)))
    cols = int(np.ceil(n_images / rows))

    fig = plt.figure(figsize=(25, 20))

    for i in range(rows * cols):
        if i >= len(incorrect):
            break
        ax = fig.add_subplot(rows, cols, i + 1)
        image, true_label, probs = incorrect[i]
        image = image.permute(1, 2, 0)
        true_prob = probs[true_label]
        incorrect_prob, incorrect_label = torch.max(probs, dim=0)
        true_class = classes[true_label]
        incorrect_class = classes[incorrect_label]

        if normalize:
            image = normalize_image(image)

        ax.imshow(image.cpu().numpy())
        ax.set_title(f'true label:\n{true_class} ({true_prob:.3f})\n'
                     f'pred label:\n{incorrect_class} ({incorrect_prob:.3f})', fontsize=10)
        ax.axis('off')

    plt.tight_layout()
    fig.subplots_adjust(hspace=0.7)

    plt.show()
    save_figure(fig, f'{model_name}_most_incorrect.png')


def get_all_details(model, test_loader):
    all_labels = []
    all_preds = []
    all_probs = []
    all_images = []
    device = next(model.parameters()).device
    model.eval()

    with torch.no_grad():
        for images, csv_features, labels in test_loader:
            images, csv_features, labels = images.to(device), csv_features.to(device), labels.to(device)
            outputs = model(images, csv_features)
            _, preds = torch.max(outputs, 1)
            probs = F.softmax(outputs, dim=1)

            all_images.extend(images.cpu())
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu())

    return all_images, all_labels, all_preds, all_probs

def plot_most_incorrect_predictions(results, test_loader_inception, test_loader_others, n_images=36):
    classes = list(test_loader_inception.dataset.class_to_idx.keys())
    for model_name, model_info in results.items():
        test_loader = test_loader_inception if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3'] else test_loader_others

        model = model_info['model']
        images, labels, pred_labels, probs = get_all_details(model, test_loader)
        corrects = torch.eq(torch.tensor(labels), torch.tensor(pred_labels))
        incorrect_examples = []

        for image, label, prob, correct in zip(images, labels, probs, corrects):
            if not correct:
                incorrect_examples.append((image, label, prob))

        incorrect_examples.sort(key=lambda x: torch.max(x[2], dim=0)[0], reverse=True)
        plot_most_incorrect(incorrect_examples[:n_images], classes, n_images, model_name)


In [None]:
plot_most_incorrect_predictions(crop_results, test_loader_inception, test_loader_others, n_images=36)

In [None]:
# Function to generate all results
def generate_all_results(results, test_loader_inception, test_loader_others):
    plot_accuracy_comparison(results)
    display_model_metrics_table(results, test_loader_inception, test_loader_others)
    generate_confusion_matrices(results, test_loader_inception, test_loader_others)
    plot_most_incorrect_predictions(results, test_loader_inception, test_loader_others, n_images=36)

    for model_name, model_info in results.items():
        model = model_info['model']
        test_loader = test_loader_inception if model_name in ['InceptionV3', 'AttentionAugmentedInceptionV3'] else test_loader_others
        display_classification_results(model, test_loader, model_name)
        display_classification_report(model, test_loader, model_name)

In [None]:
# Call the generate_all_results function
generate_all_results(crop_results, test_loader_inception, test_loader_others)