# Imports

In [None]:
import os
import shutil
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display

from efficientnet_pytorch import EfficientNet
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, classification_report

from PIL import Image
from PIL.ExifTags import TAGS, GPSTAGS

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import is_image_file
# Ensure these imports are included
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import StepLR

from vit_pytorch import ViT

from AACN_Model import attention_augmented_resnet18, attention_augmented_efficientnetb0, attention_augmented_inceptionv3, attention_augmented_vit, attention_augmented_vgg

# Data Preprocessing

In [None]:
# Define main directories
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
base_dir = '/Users/izzymohamed/Desktop/Vision For Social Good/EXTRA/CODE/shubham10divakar Multimodal-Plant-Disease-Dataset/Data' #'/Users/izzymohamed/Downloads/shubham10divakar Multimodal-Plant-Disease-Dataset/Data'
# Root that is later scanned for one sub-directory per class.
crop_root = os.path.join(base_dir, 'color')
# NOTE(review): split_root is not referenced anywhere else in the visible code.
split_root = os.path.join(base_dir, 'split')

In [None]:
# Define function to remove .DS_Store files
def remove_ds_store(directory):
    """Recursively delete every file under *directory* whose name contains '.DS_Store'.

    Each removed path is printed before deletion. This is destructive:
    files are removed from disk with ``os.remove``.
    """
    for current_dir, _subdirs, filenames in os.walk(directory):
        for name in filenames:
            # A substring test also covers the exact-name case.
            if '.DS_Store' not in name:
                continue
            file_path = os.path.join(current_dir, name)
            print(f"Removing {file_path}")
            os.remove(file_path)

In [None]:
# Remove .DS_Store files from base directory
# Walks base_dir recursively and deletes the matching files on disk.
remove_ds_store(base_dir)

In [None]:
def is_image_file(filename):
    """Return True when *filename* ends with a known image extension.

    The comparison is case-insensitive. Defining this name here shadows the
    ``is_image_file`` imported from ``torchvision.datasets.folder`` at the
    top of the file.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')
    lowered = filename.lower()
    return lowered.endswith(image_suffixes)

In [None]:
# Function to split data into train, validation, and test sets
def split_data(base_dir, val_split=0.4, test_split=0.1, random_state=None):
    """Split the images of every class folder into train/validation/test lists.

    Each non-hidden sub-directory of *base_dir* is treated as one class and
    is split independently, so every split keeps the per-class proportions.

    Args:
        base_dir: Directory containing one sub-directory per class.
        val_split: Fraction of each class assigned to validation.
        test_split: Fraction of each class assigned to test.
        random_state: Optional seed forwarded to ``train_test_split``. With
            the default ``None`` the split differs from run to run.

    Returns:
        ``(train_files, val_files, test_files, classes)`` where the first
        three are lists of ``(image_path, class_name)`` tuples and *classes*
        is the sorted list of class directory names. Classes that were
        skipped (no images, or too few to split) still appear in *classes*.
    """
    train_files = []
    val_files = []
    test_files = []

    # Sort so that class -> index mappings derived from this list are stable
    # across runs and machines (os.listdir order is arbitrary). Hidden
    # directories are ignored, matching find_classes().
    classes = sorted(
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d)) and not d.startswith('.')
    )

    for cls in classes:
        print(f'Processing class: {cls}')
        class_dir = os.path.join(base_dir, cls)

        # Sorted input makes a seeded split reproducible; train_test_split
        # shuffles on its own, so no separate shuffle is needed.
        images = sorted(
            f for f in os.listdir(class_dir)
            if is_image_file(os.path.join(class_dir, f))
        )

        if not images:
            print(f"No images found for class {cls}. Skipping...")
            continue

        try:
            train, test = train_test_split(
                images, test_size=test_split, random_state=random_state
            )
            # The validation share is taken from what is left after the test
            # split, hence the rescaling by (1 - test_split).
            train, val = train_test_split(
                train,
                test_size=val_split / (1 - test_split),
                random_state=random_state,
            )
        except ValueError as e:
            print(f"Not enough images to split for class {cls}: {e}")
            continue

        train_files.extend((os.path.join(class_dir, img), cls) for img in train)
        val_files.extend((os.path.join(class_dir, img), cls) for img in val)
        test_files.extend((os.path.join(class_dir, img), cls) for img in test)

    return train_files, val_files, test_files, classes

In [None]:
# Split data
# Uses split_data's default fractions (val_split=0.4, test_split=0.1).
train_files, val_files, test_files, classes = split_data(crop_root)

In [None]:
# Report how many (path, class) samples ended up in each split
print(f"Train files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")
print(f"Test files: {len(test_files)}")

In [None]:
# Define the standard image sizes
# Side length (pixels) used by the 'InceptionV3' transform pipelines.
inception_size = 299
# Side length (pixels) used by the 'Others' transform pipelines.
other_size = 224

In [None]:
# Update the data transformations
def _make_transform(size, augment):
    """Build one preprocessing pipeline: resize, optional flip, tensor, normalise."""
    steps = [transforms.Resize((size, size))]
    if augment:
        # Only the training pipeline gets random augmentation.
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    return transforms.Compose(steps)


# data_transforms[family][split] -> torchvision pipeline.
# 'InceptionV3' uses inception_size inputs, 'Others' uses other_size inputs.
data_transforms = {
    family: {
        split: _make_transform(size, augment=(split == 'train'))
        for split in ('train', 'val', 'test')
    }
    for family, size in (('InceptionV3', inception_size), ('Others', other_size))
}

In [None]:
# Create a custom dataset class to load images from the file lists
class CustomDataset(Dataset):
    """Dataset backed by a list of ``(image_path, class_name)`` tuples.

    Args:
        file_paths: Sequence of ``(image_path, class_name)`` pairs.
        class_to_idx: Mapping from class name to integer label.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, file_paths, class_to_idx, transform=None):
        self.file_paths = file_paths
        self.class_to_idx = class_to_idx
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image, label)``.

        Raises:
            KeyError: If the sample's class name is missing from ``class_to_idx``.
        """
        img_path, cls = self.file_paths[idx]
        # Open inside a context manager so the file handle is released right
        # away; convert() returns an in-memory image that outlives the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.class_to_idx[cls]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Create a mapping from class names to indices
# Indices follow the order of `classes` as returned by split_data.
class_to_idx = {cls: idx for idx, cls in enumerate(classes)}

In [None]:
# Create datasets and data loaders
# InceptionV3 variants: same file lists, 'InceptionV3' transform pipelines.
train_dataset_inception = CustomDataset(train_files, class_to_idx, transform=data_transforms['InceptionV3']['train'])
val_dataset_inception = CustomDataset(val_files, class_to_idx, transform=data_transforms['InceptionV3']['val'])
test_dataset_inception = CustomDataset(test_files, class_to_idx, transform=data_transforms['InceptionV3']['test'])

# NOTE(review): the validation loader is shuffled on every pass — confirm this is intended.
train_loader_inception = DataLoader(train_dataset_inception, batch_size=32, shuffle=True)
val_loader_inception = DataLoader(val_dataset_inception, batch_size=32, shuffle=True)
test_loader_inception = DataLoader(test_dataset_inception, batch_size=32, shuffle=False)

In [None]:
# Loaders for other models
# Same file lists as above, but with the 'Others' transform pipelines.
train_dataset_others = CustomDataset(train_files, class_to_idx, transform=data_transforms['Others']['train'])
val_dataset_others = CustomDataset(val_files, class_to_idx, transform=data_transforms['Others']['val'])
test_dataset_others = CustomDataset(test_files, class_to_idx, transform=data_transforms['Others']['test'])

# NOTE(review): the validation loader is shuffled on every pass — confirm this is intended.
train_loader_others = DataLoader(train_dataset_others, batch_size=32, shuffle=True)
val_loader_others = DataLoader(val_dataset_others, batch_size=32, shuffle=True)
test_loader_others = DataLoader(test_dataset_others, batch_size=32, shuffle=False)

# Model Training and Evaluation

In [None]:
# Select the compute device: prefer Apple MPS, then CUDA, otherwise CPU.
# `device` is read as a module-level global by train_model and get_representations.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Device: {device}")

In [None]:
# Function to check if a file is valid
def is_valid_file(path):
    """Return True unless *path* refers to macOS ``.DS_Store`` metadata.

    The previous condition joined its two checks with ``or``, which made the
    substring check dead code: any path that merely contained ``DS_Store``
    without ending in ``.DS_Store`` was accepted. Both checks are meant to
    reject, and the substring test alone covers the suffix case.
    """
    return 'DS_Store' not in path

In [None]:
# Function to adjust learning rate
def adjust_learning_rate(optimizer, epoch, learning_rate):
    """Apply a step decay: *learning_rate* times 0.1 for every 10 completed epochs.

    The resulting rate is written into every parameter group of *optimizer*.
    """
    decayed_lr = learning_rate * (0.1 ** (epoch // 10))
    for group in optimizer.param_groups:
        group.update(lr=decayed_lr)

In [None]:
# Training function with mixed precision and gradient accumulation
def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=40, accumulation_steps=4, initial_lr=0.001):
    """Train *model* with gradient accumulation, StepLR decay and early stopping.

    The previous version called ``backward()`` twice on the same loss (once
    directly, once through the scaler), which raises a RuntimeError because
    the graph is freed after the first call. It also zeroed the gradients on
    every batch, which defeated accumulation. This version performs exactly
    one backward pass per batch and one optimizer step per accumulation window.

    Args:
        model: Network to train; moved to the module-level ``device``.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_loader: Iterable of ``(inputs, labels)`` training batches; must support ``len()``.
        val_loader: Iterable of ``(inputs, labels)`` validation batches; must support ``len()``.
        num_epochs: Maximum number of epochs.
        accumulation_steps: Number of batches whose gradients are summed
            before one optimizer step.
        initial_lr: Unused; kept so existing callers keep working. The
            learning rate is whatever *optimizer* was constructed with.

    Returns:
        The same *model* instance, holding the weights of the last epoch run.
    """
    # torch.cuda.amp only does something on CUDA; elsewhere the scaler and
    # autocast are switched off and training runs in full precision.
    use_amp = device.type == 'cuda'
    scaler = GradScaler(enabled=use_amp)
    early_stopping_patience = 5
    best_val_loss = float('inf')
    patience_counter = 0

    model.to(device)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.7)  # Example scheduler

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        num_batches = len(train_loader)
        optimizer.zero_grad()
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            with autocast(enabled=use_amp):
                outputs = model(inputs)
                # Some models return a tuple of outputs; the first entry is used.
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                loss = criterion(outputs, labels)

            # Divide so the accumulated gradient is an average over the
            # window instead of a sum that grows with accumulation_steps.
            scaler.scale(loss / accumulation_steps).backward()

            # Step at the end of each window, and on the final batch so a
            # trailing partial window is not discarded.
            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                with autocast(enabled=use_amp):
                    outputs = model(inputs)
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = 100 * correct / total

        print(f'Epoch {epoch + 1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping: stop after `early_stopping_patience` consecutive
        # epochs without a new best validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping due to no improvement in validation loss.")
                break

        scheduler.step()

    return model

In [None]:
# Function to create and train the model
def create_and_train_model(model, train_loader, val_loader, num_classes, device, num_epochs=40, initial_lr=0.001):
    """Train *model* with cross-entropy loss and an Adam optimizer.

    Args:
        model: Network to train.
        train_loader: Training batches, forwarded to ``train_model``.
        val_loader: Validation batches, forwarded to ``train_model``.
        num_classes: Accepted for interface compatibility; not used here.
        device: Device the loss module is moved to.
        num_epochs: Maximum number of epochs.
        initial_lr: Learning rate given to Adam (and forwarded to ``train_model``).

    Returns:
        Whatever ``train_model`` returns.
    """
    loss_fn = nn.CrossEntropyLoss().to(device)
    adam = torch.optim.Adam(model.parameters(), lr=initial_lr)
    trained = train_model(
        model,
        loss_fn,
        adam,
        train_loader,
        val_loader,
        num_epochs=num_epochs,
        initial_lr=initial_lr,
    )
    return trained

In [None]:
# Function to evaluate the models
def evaluate_model(model, test_loader, criterion, device):
    """Compute the per-sample average loss and the accuracy (%) on *test_loader*.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Loader yielding ``(inputs, labels)``; must expose ``.dataset``.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(test_loss, test_accuracy)``: the loss is weighted by batch size and
        divided by ``len(test_loader.dataset)``; the accuracy is a percentage.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            if isinstance(logits, tuple):
                logits = logits[0]

            # Weight each batch loss by its size so a short last batch does
            # not skew the dataset-level average.
            batch_loss = criterion(logits, batch_labels)
            loss_sum += batch_loss.item() * batch_inputs.size(0)

            predicted = torch.max(logits.data, 1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

    return loss_sum / len(test_loader.dataset), 100 * n_correct / n_seen

In [None]:
# Clear cache function
def clear_cache():
    """Release cached accelerator memory, if an accelerator is in use.

    The previous CPU branch called ``torch.cache.empty_cache()``, which does
    not exist and raised AttributeError on machines without MPS or CUDA.
    There is no allocator cache to empty on CPU, so that case now only runs
    the Python garbage collector.
    """
    import gc

    # Drop unreachable Python objects first so their tensors can be freed.
    gc.collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Save the best model
def save_model(model, model_name, crop):
    """Write ``model.state_dict()`` to ``<base_dir>/saved_models/<model_name>.pth``.

    The target directory is created if needed. *crop* is accepted for
    interface compatibility and is not used.
    """
    target_dir = os.path.join(base_dir, 'saved_models')
    os.makedirs(target_dir, exist_ok=True)

    model_path = os.path.join(target_dir, f'{model_name}.pth')
    weights = model.state_dict()
    torch.save(weights, model_path)
    print(f"Model saved at {model_path}")

In [None]:
# Load the best model
def load_model(model, model_name, crop):
    """Load weights from ``<base_dir>/saved_models/<model_name>.pth`` into *model*.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        model_name: File stem of the checkpoint.
        crop: Accepted for interface compatibility; not used.

    Returns:
        The same *model* instance with the loaded weights.
    """
    model_dir = os.path.join(base_dir, 'saved_models')
    model_path = os.path.join(model_dir, f'{model_name}.pth')
    # Deserialize onto the CPU so a checkpoint written on MPS/CUDA can be
    # read on a machine without that device; load_state_dict then copies
    # the values into the model's parameters on whatever device they live.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    print(f"Model loaded from {model_path}")
    return model

In [None]:
# Function to find classes in a directory
def find_classes(dir):
    """List the class sub-directories of *dir* and assign each an index.

    If *dir* does not exist it is created (and a message is printed), which
    yields an empty result. Entries starting with '.' are ignored.

    Returns:
        ``(classes, class_to_idx)``: the sorted directory names and a dict
        mapping each name to its position in that list.
    """
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)
        print(f"Created directory: {dir}")

    classes = sorted(
        entry
        for entry in os.listdir(dir)
        if not entry.startswith('.') and os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx

### Train the Model

In [None]:
# Count the number of classes
# Both values are derived from the same mapping, so they are always equal.
num_classes_inception = len(class_to_idx)
num_classes_others = len(class_to_idx)

In [None]:
# Number of attention heads.
# NOTE(review): this variable is not referenced in the visible code; the
# attention-augmented ResNet18 below is built with a literal num_heads=8.
num_heads = 8

In [None]:
# Define the results dictionary
# Filled by the training loop: model name -> {'model', 'test_loss', 'test_accuracy'}.
crop_results = {}

In [None]:
# Define the pretrained models
# Every model is instantiated and moved to `device` right here, so all of
# them occupy device memory at the same time before any training starts.
# NOTE(review): newer torchvision releases prefer `weights=` over
# `pretrained=True` — confirm against the installed version.
# The ViT entry is built from hyperparameters only; no weights are loaded
# for it in this cell.
# NOTE(review): AttentionAugmentedInceptionV3 is created without a
# num_classes argument and the training loop does not replace its head —
# verify the factory's default output size matches the dataset.
pretrained_models = {
    'InceptionV3': models.inception_v3(pretrained=True).to(device),
    'ResNet152': models.resnet152(pretrained=True).to(device),
    'VGG19': models.vgg19(pretrained=True).to(device),
    'ViT': ViT(
        image_size=224,
        patch_size=16,
        num_classes=num_classes_others,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1
    ).to(device),
    "AttentionAugmentedInceptionV3": attention_augmented_inceptionv3(attention=True).to(device),
    'AttentionAugmentedVGG19': attention_augmented_vgg('VGG19',num_classes=num_classes_others).to(device),
    "AttentionAugmentedResNet18": attention_augmented_resnet18(num_classes=num_classes_others, attention=[False, True, True, True], num_heads=8).to(device),
}

In [None]:
# Train the models
# For each entry: swap the classification head where needed, pick the
# loaders matching the model's input size, train, evaluate on the test
# split and record the result in crop_results.
# NOTE: model_name, train_loader, val_loader and test_loader stay bound at
# module level after the loop ends (to the values of the last iteration);
# later cells read them.
for model_name, base_model in pretrained_models.items():

    base_model.to(device)  # Ensure the model is on the correct device
    
    # Replacement heads are created after the .to(device) call above;
    # train_model() moves the whole model to `device` again before training.
    if model_name == 'InceptionV3':
        base_model.AuxLogits.fc = nn.Linear(base_model.AuxLogits.fc.in_features, num_classes_inception)
        base_model.fc = nn.Linear(base_model.fc.in_features, num_classes_inception)
        train_loader = train_loader_inception
        val_loader = val_loader_inception
        test_loader = test_loader_inception
    elif model_name == 'ViT':
        # NOTE(review): assumes mlp_head exposes `in_features` — verify for the installed vit_pytorch.
        base_model.mlp_head = nn.Linear(base_model.mlp_head.in_features, num_classes_others)
        train_loader = train_loader_others
        val_loader = val_loader_others
        test_loader = test_loader_others
    elif model_name == 'ResNet152':
        base_model.fc = nn.Linear(base_model.fc.in_features, num_classes_others)
        train_loader = train_loader_others
        val_loader = val_loader_others
        test_loader = test_loader_others
    elif model_name == 'VGG19':
        base_model.classifier[-1] = nn.Linear(base_model.classifier[-1].in_features, num_classes_others)
        train_loader = train_loader_others
        val_loader = val_loader_others
        test_loader = test_loader_others
    elif model_name == 'AttentionAugmentedResNet18':
        base_model.fc = nn.Linear(base_model.fc.in_features, num_classes_others)
        train_loader = train_loader_others
        val_loader = val_loader_others
        test_loader = test_loader_others
    else:
        # Remaining models keep their own head and use the 'Others' loaders.
        train_loader = train_loader_others
        val_loader = val_loader_others
        test_loader = test_loader_others

    print(f'--------------- Training model: {model_name}')
    model = create_and_train_model(base_model, train_loader, val_loader, num_classes_others, device, initial_lr=0.001)

    test_loss, test_accuracy = evaluate_model(model, test_loader, nn.CrossEntropyLoss(), device)

    crop_results[model_name] = {
        'model': model,
        'test_loss': test_loss,
        'test_accuracy': test_accuracy
    }
    print(f'{model_name} Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

    # Clean up: delete the model to free up memory (optional)
    # NOTE: this only unbinds the name `model`; crop_results and
    # pretrained_models still reference the object, so its memory is kept.
    del model
    clear_cache()
    
    print(f'\n')


# Display Results

In [None]:
# Output location for figures, reports and CSV files written by the cells below.
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
results_base_dir = "/Users/izzymohamed/Desktop/Vision For Social Good/EXTRA/CODE/shubham10divakar Multimodal-Plant-Disease-Dataset/Results/Single Modal"
results_folder = os.path.join(results_base_dir, 'T3')
os.makedirs(results_folder, exist_ok=True)

In [None]:
# Function to save figures
def save_figure(fig, filename):
    """Save *fig* as *filename* inside ``results_folder`` and then close it."""
    target_path = os.path.join(results_folder, filename)
    fig.savefig(target_path)
    plt.close(fig)

### Accuracy Comparison

In [None]:
# Plot comparison of accuracy for each model for each crop
def plot_accuracy_comparison(results):
    """Draw a bar chart of test accuracy per model and save it.

    Args:
        results: Mapping of model name to a dict containing 'test_accuracy'.

    The figure is shown and then written to 'accuracy_comparison.png' via
    ``save_figure``.
    """
    accuracies = [entry['test_accuracy'] for entry in results.values()]
    model_names = list(results.keys())

    fig = plt.figure(figsize=(20, 10))
    plt.bar(model_names, accuracies)
    plt.xlabel('Model')
    plt.ylabel('Accuracy (%)')
    plt.show()
    save_figure(fig, 'accuracy_comparison.png')


In [None]:
# Plot the test accuracy of every trained model side by side
plot_accuracy_comparison(crop_results)

### Metrics Table

In [None]:
# Function to display F1, precision, and recall of all models as a table
def display_model_metrics_table(results, test_loader):
    """Show and save macro precision/recall/F1 for every model in *results*.

    The previous version ignored its *results* argument and always iterated
    the module-level ``crop_results``; it now uses the argument.

    Args:
        results: Mapping of model name to a dict containing the trained 'model'.
        test_loader: Loader yielding ``(images, labels)`` batches; the same
            loader is used for every model.

    Side effects:
        Displays the table and writes 'model_metrics.csv' into ``results_folder``.
    """
    metrics_data = []

    for model_name, model_info in results.items():
        model = model_info['model']
        device = next(model.parameters()).device  # Run on the model's own device
        model.eval()

        all_labels = []
        all_predicted = []

        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model(images.to(device))
                # Some models return a tuple of outputs; the first entry is used.
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                _, predicted = torch.max(outputs, 1)

                all_labels.extend(labels.cpu().numpy())
                all_predicted.extend(predicted.cpu().numpy())

        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predicted, average='macro')

        metrics_data.append({
            'Model': model_name,
            'Precision': precision,
            'Recall': recall,
            'F1-score': f1
        })

    metrics_df = pd.DataFrame(metrics_data)
    display(metrics_df)  # Display the DataFrame in Jupyter Notebook
    metrics_df.to_csv(os.path.join(results_folder, 'model_metrics.csv'), index=False)

In [None]:
# Display the table of metrics for all models
# `test_loader` is the module-level name left bound by the training loop.
display_model_metrics_table(crop_results, test_loader)

### Classification Results

In [None]:
# Display some correctly and incorrectly classified images
def display_classification_results(model, test_loader, num_images=5, model_name=None):
    """Show the first images of the first test batch with true and predicted labels.

    Args:
        model: Trained network; evaluated on its own device.
        test_loader: Loader whose dataset exposes ``class_to_idx``.
        num_images: Maximum number of images to show (capped at the batch size).
        model_name: Name used in the saved file name. When omitted, the
            module-level ``model_name`` is used if it exists (the previous,
            implicit behaviour), otherwise the model's class name.

    Side effects:
        Shows the figure and saves '<model_name>_classification_results.png'.
    """
    if model_name is None:
        model_name = globals().get('model_name', type(model).__name__)

    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode
    class_labels = list(test_loader.dataset.class_to_idx.keys())

    images, labels = next(iter(test_loader))
    # Never ask for more images than the batch actually holds.
    num_images = min(num_images, images.size(0))
    images, labels = images[:num_images].to(device), labels[:num_images]

    with torch.no_grad():
        outputs = model(images)
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        _, predicted = torch.max(outputs, 1)
    predicted = predicted.cpu()

    # The transforms defined in this file normalise with these statistics;
    # undo that so colours look natural instead of being clipped.
    # NOTE(review): adjust if the loader uses a different normalisation.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # squeeze=False keeps `axes` two-dimensional even for a single image.
    fig, axes = plt.subplots(1, num_images, figsize=(20, 8), squeeze=False)
    # fig.suptitle(f'{model_name} - Classification Results', fontsize=28)

    for i, ax in enumerate(axes[0]):
        img = images[i].cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow
        img = np.clip(img * std + mean, 0, 1)
        ax.imshow(img)
        ax.set_title(f'True: {class_labels[int(labels[i])]}\n Pred: {class_labels[int(predicted[i])]}')
        ax.axis('off')

    plt.show()
    save_figure(fig, f'{model_name}_classification_results.png')

In [None]:
# Display sample predictions for each trained model
# The loop variable `model_name` is bound at module level while this runs.
for model_name in crop_results.keys():
    print(f'Displaying results for {model_name}')
    display_classification_results(crop_results[model_name]['model'], test_loader)

### Classification Report

In [None]:
# Function to display the classification report of a given model
def display_classification_report(model, test_loader, model_name):
    """Print and save scikit-learn's classification report for *model*.

    Args:
        model: Trained network; evaluated on its own device.
        test_loader: Loader whose dataset exposes ``class_to_idx``.
        model_name: Used in the report file name.

    Side effects:
        Writes '<model_name>_classification_report.txt' into ``results_folder``.
    """
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode

    all_labels = []
    all_predicted = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            _, predicted = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_predicted.extend(predicted.cpu().numpy())

    class_to_idx = test_loader.dataset.class_to_idx
    # Passing `labels` explicitly keeps the report aligned with target_names
    # even when some class never occurs in the test labels or predictions;
    # without it scikit-learn raises on the length mismatch.
    report = classification_report(
        all_labels,
        all_predicted,
        labels=list(class_to_idx.values()),
        target_names=list(class_to_idx.keys()),
    )

    print(report)

    report_filename = os.path.join(results_folder, f'{model_name}_classification_report.txt')
    with open(report_filename, 'w') as f:
        f.write(report)
        

In [None]:
# Print and save the classification report of each trained model
for model_name in crop_results.keys():
    print(f'Displaying classification report for {model_name}')
    display_classification_report(crop_results[model_name]['model'], test_loader, model_name)

### Confusion Matrices

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def plot_confusion_matrix(labels, pred_labels, classes, model_name):
    """Plot and save the confusion matrix for one model.

    Args:
        labels: True integer labels.
        pred_labels: Predicted integer labels.
        classes: Class names used as tick labels; assumed to be ordered so
            that position i is the name of label i.
        model_name: Used in the saved file name.

    Side effects:
        Shows the figure and saves '<model_name>_confusion_matrix.png'.
    """
    fig = plt.figure(figsize=(50, 50))
    # fig.suptitle(f'{model_name} - Confusion Matrix\n', fontsize=28, y=0.83)
    ax = fig.add_subplot(1, 1, 1)
    # Fix the label set so the matrix always has one row/column per class,
    # even if a class is absent from both labels and predictions; otherwise
    # the matrix size would not match display_labels.
    cm = confusion_matrix(labels, pred_labels, labels=list(range(len(classes))))
    cm_display = ConfusionMatrixDisplay(cm, display_labels=classes)
    cm_display.plot(values_format='d', cmap='Blues', ax=ax)
    fig.delaxes(fig.axes[1])  # Delete colorbar
    plt.xticks(rotation=90)
    plt.xlabel('Predicted Label', fontsize=50)
    plt.ylabel('True Label', fontsize=50)

    plt.show()
    save_figure(fig, f'{model_name}_confusion_matrix.png')

In [None]:
# Function to extract all labels and predictions
def get_all_labels_and_preds(model, test_loader):
    """Run *model* over *test_loader* and collect true and predicted labels.

    Returns:
        ``(all_labels, all_preds)``: two flat lists with one entry per sample,
        in loader order.
    """
    device = next(model.parameters()).device  # Evaluate on the model's own device
    model.eval()

    true_labels, predictions = [], []
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_preds = torch.max(logits, 1)[1]

            true_labels += list(batch_labels.cpu().numpy())
            predictions += list(batch_preds.cpu().numpy())

    return true_labels, predictions

In [None]:
# Generate and plot confusion matrices
def generate_confusion_matrices(results, test_loader):
    """Plot a confusion matrix for every model stored in *results*.

    Args:
        results: Mapping of model name to a dict containing the trained 'model'.
        test_loader: Loader whose dataset exposes ``class_to_idx``.
    """
    class_names = list(test_loader.dataset.class_to_idx.keys())
    for name, info in results.items():
        true_labels, predicted = get_all_labels_and_preds(info['model'], test_loader)
        plot_confusion_matrix(true_labels, predicted, class_names, name)

In [None]:
# Plot and save one confusion matrix per trained model.
generate_confusion_matrices(crop_results, test_loader)

### Incorrect Predictions

In [None]:
# Function to normalize images
def normalize_image(image):
    """Rescale *image* linearly so its values span [0, 1].

    A constant image (all values equal) is returned as all zeros; the
    previous version divided by zero in that case and produced NaNs.
    """
    image = image - image.min()
    peak = image.max()
    if peak > 0:
        image = image / peak
    return image

In [None]:
# Function to plot the most incorrect predictions
def plot_most_incorrect(incorrect, classes, n_images, model_name, normalize=True):
    """Plot misclassified examples in a near-square grid and save the figure.

    Args:
        incorrect: Sequence of ``(image, true_label, probs)`` tuples, where
            *probs* holds the per-class probabilities for that image.
        classes: Class names indexed by label.
        n_images: Number of grid cells to lay out.
        model_name: Used in the saved file name.
        normalize: Rescale each image to [0, 1] before display.

    Side effects:
        Shows the figure and saves '<model_name>_most_incorrect.png'.
    """
    n_rows = int(np.ceil(np.sqrt(n_images)))
    n_cols = int(np.ceil(n_images / n_rows))

    fig = plt.figure(figsize=(25, 20))
    # fig.suptitle(f'{model_name} - Most Incorrect\n', fontsize=28)

    # Fill grid cells until either the grid or the examples run out.
    n_shown = min(n_rows * n_cols, len(incorrect))
    for slot in range(n_shown):
        ax = fig.add_subplot(n_rows, n_cols, slot + 1)
        image, true_label, probs = incorrect[slot]
        image = image.permute(1, 2, 0)

        true_prob = probs[true_label]
        incorrect_prob, incorrect_label = torch.max(probs, dim=0)
        true_class = classes[true_label]
        incorrect_class = classes[incorrect_label]

        if normalize:
            image = normalize_image(image)

        ax.imshow(image.cpu().numpy())
        ax.set_title(f'true label:\n{true_class} ({true_prob:.3f})\n'
                     f'pred label:\n{incorrect_class} ({incorrect_prob:.3f})', fontsize=10)
        ax.axis('off')

    plt.tight_layout()
    fig.subplots_adjust(hspace=0.7)

    plt.show()
    save_figure(fig, f'{model_name}_most_incorrect.png')


In [None]:
def get_all_details(model, test_loader):
    """Run *model* over *test_loader* and collect per-sample details.

    Returns:
        ``(all_images, all_labels, all_preds, all_probs)``: four flat lists
        with one entry per sample — the input image (on CPU), the true label,
        the predicted label and the softmax probability vector (on CPU).
    """
    device = next(model.parameters()).device  # Evaluate on the model's own device
    model.eval()

    collected_images, collected_labels = [], []
    collected_preds, collected_probs = [], []

    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_preds = torch.max(logits, 1)[1]
            batch_probs = F.softmax(logits, dim=1)

            # Extending with a tensor appends one entry per sample (dim 0).
            collected_images.extend(batch_images.cpu())
            collected_labels.extend(batch_labels.cpu().numpy())
            collected_preds.extend(batch_preds.cpu().numpy())
            collected_probs.extend(batch_probs.cpu())

    return collected_images, collected_labels, collected_preds, collected_probs


In [None]:
# Define the number of images to display
# Also used below as the number of training images fed to the filter visualisation.
N_IMAGES = 36

In [None]:
# Use this function to get the details
def plot_most_incorrect_predictions(results, test_loader, n_images=36):
    """Plot, for every model, its most confidently wrong test predictions.

    The sort and plot steps used to sit outside the loop, so only the last
    model in *results* was ever plotted (and an empty *results* raised
    NameError). They now run once per model.

    Args:
        results: Mapping of model name to a dict containing the trained 'model'.
        test_loader: Loader whose dataset exposes ``class_to_idx``.
        n_images: Maximum number of misclassified images to show per model.
    """
    classes = list(test_loader.dataset.class_to_idx.keys())
    for model_name, model_info in results.items():
        model = model_info['model']
        images, labels, pred_labels, probs = get_all_details(model, test_loader)

        incorrect_examples = [
            (image, label, prob)
            for image, label, pred, prob in zip(images, labels, pred_labels, probs)
            if label != pred
        ]
        if not incorrect_examples:
            print(f"No incorrect predictions to plot for {model_name}.")
            continue

        # Most confident mistakes first: sort by the highest predicted probability.
        incorrect_examples.sort(key=lambda x: float(torch.max(x[2], dim=0)[0]), reverse=True)
        plot_most_incorrect(incorrect_examples[:n_images], classes, n_images, model_name)

In [None]:
# Plot up to N_IMAGES misclassified test images using the trained models.
plot_most_incorrect_predictions(crop_results, test_loader, N_IMAGES)

### Representations and Dimensionality Reduction

In [None]:
from sklearn import decomposition, manifold

def get_representations(model, iterator):
    """Collect the model outputs and labels for every batch in *iterator*.

    Inputs are moved to the device the model's parameters live on, like the
    other evaluation helpers in this file, instead of depending on a
    module-level ``device`` variable.

    Args:
        model: Network to run in eval mode.
        iterator: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(outputs, labels)``: both concatenated along dim 0 and on the CPU.
    """
    model_device = next(model.parameters()).device
    model.eval()
    outputs = []
    labels = []

    with torch.no_grad():
        for x, y in iterator:
            y_pred = model(x.to(model_device))
            # Some models return a tuple of outputs; the first entry is used.
            if isinstance(y_pred, tuple):
                y_pred = y_pred[0]
            outputs.append(y_pred.cpu())
            labels.append(y.cpu())

    outputs = torch.cat(outputs, dim=0)
    labels = torch.cat(labels, dim=0)
    return outputs, labels

In [None]:
def get_pca(data, n_components=2):
    """Project *data* onto its first *n_components* principal components."""
    reducer = decomposition.PCA(n_components=n_components)
    return reducer.fit_transform(data)

In [None]:
def plot_representations(data, labels, classes, n_images=None, model_name=None, method='pca'):
    """Scatter-plot 2-D representations coloured by label and save the figure.

    Args:
        data: Array-like with at least two columns (x and y coordinates).
        labels: Per-point labels used for colouring.
        classes: Class names; accepted for interface compatibility, not used.
        n_images: If given, only the first *n_images* points are plotted.
        model_name: Name used in the saved file name. When omitted, the
            module-level ``model_name`` is used if it exists (the previous,
            implicit behaviour), otherwise 'model'.
        method: Suffix of the saved file name. Defaults to 'pca', which was
            the previously hard-coded value; pass e.g. 'tsne' so a t-SNE
            figure does not overwrite the PCA one.

    Side effects:
        Shows the figure and saves '<model_name>_<method>.png'.
    """
    if model_name is None:
        model_name = globals().get('model_name', 'model')

    if n_images is not None:
        data = data[:n_images]
        labels = labels[:n_images]

    fig = plt.figure(figsize=(15, 15))
    # fig.suptitle(f'{model_name} - PCA', fontsize=28, y=0.95)
    ax = fig.add_subplot(111)
    ax.scatter(data[:, 0], data[:, 1], c=labels, cmap='hsv')
    plt.show()
    save_figure(fig, f'{model_name}_{method}.png')

In [None]:
# Plot a 2-D PCA projection of each trained model's outputs on the training data.
# Representations are computed inside the loop: the earlier version computed
# them once from the stale name `model` (deleted at the end of the training
# loop), so it either failed or produced the same plot for every model.
for model_name, model_info in crop_results.items():
    outputs, labels = get_representations(model_info['model'], train_loader)
    output_pca_data = get_pca(outputs)
    plot_representations(output_pca_data, labels, classes)  # Adjusted to pass only three arguments

In [None]:
def get_tsne(data, n_components=2, n_images=None):
    """Embed *data* with t-SNE (fixed ``random_state=0``).

    Args:
        data: Samples to embed.
        n_components: Dimensionality of the embedding.
        n_images: If given, only the first *n_images* samples are embedded.
    """
    subset = data if n_images is None else data[:n_images]
    embedder = manifold.TSNE(n_components=n_components, random_state=0)
    return embedder.fit_transform(subset)

In [None]:

# Plot a 2-D t-SNE embedding of each trained model's outputs on the training data.
# Representations are computed per model; the earlier version reused one
# `outputs` tensor, so every model produced the same plot.
# NOTE: with its default arguments plot_representations saves to
# '<model_name>_pca.png', so these figures overwrite the PCA ones on disk.
for model_name, model_info in crop_results.items():
    outputs, labels = get_representations(model_info['model'], train_loader)
    output_tsne_data = get_tsne(outputs)
    plot_representations(output_tsne_data, labels, classes)

### Filter Visualization

In [None]:
# Function to plot filtered images
def plot_filtered_images(images, filters, model_name, n_filters=None, normalize=True):
    """Convolve *images* with *filters* and plot originals next to the results.

    Each row shows one original image followed by its response to every
    selected filter.

    Args:
        images: Iterable of image tensors; each is permuted as (C, H, W) for display.
        filters: Convolution weight tensor usable as the weight of ``F.conv2d``.
        model_name: Used in the saved file name.
        n_filters: If given, only the first *n_filters* filters are used.
        normalize: Rescale each displayed image to [0, 1].

    Side effects:
        Shows the figure and saves '<model_name>_filtered_images.png'.
    """
    batch = torch.stack(list(images), dim=0).cpu()
    kernels = filters.cpu()
    if n_filters is not None:
        kernels = kernels[:n_filters]

    n_images = batch.shape[0]
    n_filters = kernels.shape[0]
    n_cols = n_filters + 1  # one extra column for the original image

    responses = F.conv2d(batch, kernels)

    fig = plt.figure(figsize=(30, 30))
    # fig.suptitle(f'{model_name} - Filtered Images', fontsize=28, y=0.8)

    for row in range(n_images):
        # Subplot indices are 1-based and run row by row.
        first_cell = row * n_cols + 1

        original = batch[row]
        if normalize:
            original = normalize_image(original)
        ax = fig.add_subplot(n_images, n_cols, first_cell)
        ax.imshow(original.permute(1, 2, 0).numpy())
        ax.set_title('Original')
        ax.axis('off')

        for col in range(n_filters):
            response = responses[row][col]
            if normalize:
                response = normalize_image(response)
            ax = fig.add_subplot(n_images, n_cols, first_cell + col + 1)
            ax.imshow(response.numpy(), cmap='bone')
            ax.set_title(f'Filter {col + 1}')
            ax.axis('off')

    fig.subplots_adjust(hspace=-0.7)
    plt.show()
    save_figure(fig, f'{model_name}_filtered_images.png')

In [None]:
# Number of first-layer filters applied per image in the filtered-image plots.
N_FILTERS = 7

In [None]:
# Visualise how the first convolution layer of each CNN responds to sample images.
conv_models = ['ResNet152', 'VGG19', 'InceptionV3', 'AttentionAugmentedInceptionV3']  # Add models expected to have conv layers

for model_name, model_info in crop_results.items():
    model = model_info['model']
    if model_name in conv_models:
        # Look for the first convolution either as `conv1` or as the first
        # entry of a `features` container.
        # NOTE(review): models exposing neither attribute fall through to the
        # "not recognized" branch — verify which entries of conv_models that affects.
        if hasattr(model, 'conv1'):
            filters = model.conv1.weight.data
        elif hasattr(model, 'features') and hasattr(model.features, '0'):
            filters = model.features[0].weight.data
        else:
            print(f"Model {model_name} structure is not recognized for convolutional layers.")
            filters = None
    else:
        filters = None  # No convolutional filters in models like ViT

    if filters is not None:
        # First N_IMAGES training samples, using the 'Others' transforms for every model.
        images = [image for image, label in [train_dataset_others[i] for i in range(N_IMAGES)]]
        plot_filtered_images(images, filters, model_name, n_filters=N_FILTERS)

### Filter Plotting

In [None]:
def plot_filters(filters, normalize=True, model_name=None):
    """Plot convolution filters in a square grid and save the figure.

    Only the first ``floor(sqrt(n))**2`` filters are drawn, so the grid is
    always square.

    Args:
        filters: Weight tensor indexed by filter along dim 0.
            NOTE(review): each filter is shown as an image after
            ``permute(1, 2, 0)``, which presumably requires 3 input channels — confirm.
        normalize: Rescale each filter to [0, 1] before display.
        model_name: Name used in the saved file name. When omitted, the
            module-level ``model_name`` is used if it exists (the previous,
            implicit behaviour), otherwise 'model'.

    Side effects:
        Shows the figure and saves '<model_name>_filters.png'.
    """
    if model_name is None:
        model_name = globals().get('model_name', 'model')

    filters = filters.cpu()
    n_filters = filters.shape[0]
    rows = int(np.sqrt(n_filters))
    cols = int(np.sqrt(n_filters))

    fig = plt.figure(figsize=(30, 15))
    # fig.suptitle(f'{model_name} - Filters', fontsize=28, y=0.95)

    for i in range(rows * cols):
        image = filters[i]
        if normalize:
            image = normalize_image(image)
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(image.permute(1, 2, 0))
        ax.axis('off')

    fig.subplots_adjust(wspace=-0.9)
    plt.show()
    save_figure(fig, f'{model_name}_filters.png')

In [None]:
# Plot the first-layer convolution filters of each CNN.
conv_models = ['ResNet152', 'VGG19', 'InceptionV3', 'AttentionAugmentedInceptionV3']  # Add models expected to have conv layers

for model_name, model_info in crop_results.items():
    model = model_info['model']
    filters = None  # Stays None for models without a recognised first conv layer (e.g. ViT)

    if model_name in conv_models:
        # Look for the first convolution either as `conv1` or as the first
        # entry of a `features` container.
        if hasattr(model, 'conv1'):
            filters = model.conv1.weight.data
        elif hasattr(model, 'features') and hasattr(model.features, '0'):
            filters = model.features[0].weight.data
        else:
            print(f"Model {model_name} structure is not recognized for convolutional layers.")

    if filters is not None:
        # The earlier version also loaded N_IMAGES training images here and
        # never used them; plotting the filters needs only the weights.
        plot_filters(filters)