In [None]:
import copy
import os
import random
import shutil

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
from IPython.display import display
from PIL import Image
from PIL.ExifTags import TAGS, GPSTAGS
from sklearn.metrics import precision_recall_fscore_support, classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import is_image_file
from vit_pytorch import ViT

from AACN_Model import attention_augmented_resnet152, attention_augmented_efficientnetb0, attention_augmented_inceptionv3, attention_augmented_vit

In [None]:
# Define main directories.
# The dataset root can be overridden through the PLANT_DISEASE_DATA_DIR environment
# variable; the original local path is kept as the default so existing setups still work.
base_dir = os.environ.get(
    'PLANT_DISEASE_DATA_DIR',
    '/Users/izzymohamed/Desktop/Vision For Social Good/EXTRA/CODE/shubham10divakar Multimodal-Plant-Disease-Dataset/Data',
)
crop_root = os.path.join(base_dir, 'color')   # root of the per-class image folders handed to split_data()
split_root = os.path.join(base_dir, 'split')

In [None]:
# Load CSV data.
# The CSV sits directly inside base_dir, so derive its path from base_dir instead of
# repeating the absolute path (the two literals could otherwise drift apart).
csv_path = os.path.join(base_dir, 'plant_disease_multimodal_dataset.csv')
csv_data = pd.read_csv(csv_path)

In [None]:
# Break the CSV into its three parts: the numeric feature matrix (every column except
# the path and the two label columns, cast to float32), the mapped labels, and the image paths.
csv_features = (
    csv_data
    .drop(columns=['Image Path', 'Mapped Label', 'Label'])
    .values
    .astype(np.float32)
)
csv_labels = csv_data['Mapped Label'].values
csv_image_paths = csv_data['Image Path'].values

In [None]:
# Define function to remove .DS_Store files
def remove_ds_store(directory):
    """Walk ``directory`` recursively and delete every file whose name contains '.DS_Store'.

    The match is a substring test, so variants such as '._.DS_Store' are removed too.
    Each deleted path is printed. Returns None.
    """
    for root, _subdirs, filenames in os.walk(directory):
        for name in filenames:
            if '.DS_Store' not in name:
                continue
            file_path = os.path.join(root, name)
            print(f"Removing {file_path}")
            os.remove(file_path)

In [None]:
# Clean macOS .DS_Store artefacts out of the whole dataset tree before listing folders.
remove_ds_store(directory=base_dir)

In [None]:
# Function to check if a file is an image file
def is_image_file(filename):
    """Return True when ``filename`` ends with a known image extension (case-insensitive).

    Only the name is inspected; the file is never opened.
    Note: this definition replaces the ``is_image_file`` imported from torchvision
    at the top of the notebook.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')
    lowered = filename.lower()
    return any(lowered.endswith(suffix) for suffix in image_suffixes)

In [None]:
# Function to split data into train, validation, and test sets
def split_data(base_dir, val_split=0.4, test_split=0.1, seed=None):
    """Split the images of every class folder under ``base_dir`` into train/val/test lists.

    Args:
        base_dir: directory whose immediate sub-directories are the class folders.
        val_split: share of each class used for validation. It is rescaled by
            ``1 - test_split`` for the second split, so it is relative to the whole class.
        test_split: share of each class held out for testing.
        seed: optional integer. When given, the shuffle and both ``train_test_split``
            calls are seeded so the split is reproducible. ``None`` keeps the
            previous unseeded behaviour.

    Returns:
        ``(train_files, val_files, test_files, classes)``. The three file lists hold
        ``(image_path, class_name)`` tuples. ``classes`` lists every class folder found,
        including folders that were skipped because they had no (or too few) images.
    """
    train_files, val_files, test_files = [], [], []
    # The ``random`` module itself exposes shuffle(), so it doubles as the unseeded RNG.
    rng = random.Random(seed) if seed is not None else random

    # os.listdir order is arbitrary; sort so the class order (and any index mapping
    # built from it) is the same on every run and machine.
    classes = sorted(d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)))
    for cls in classes:
        print(f'Processing class: {cls}')
        class_dir = os.path.join(base_dir, cls)

        # Sorted for the same reason: a fixed starting order makes a seeded shuffle repeatable.
        images = sorted(f for f in os.listdir(class_dir) if is_image_file(os.path.join(class_dir, f)))

        if not images:
            print(f"No images found for class {cls}. Skipping...")
            continue

        # Shuffle images to randomize the selection
        rng.shuffle(images)

        try:
            train, test = train_test_split(images, test_size=test_split, random_state=seed)
            train, val = train_test_split(train, test_size=val_split / (1 - test_split), random_state=seed)
        except ValueError as e:
            # train_test_split raises when a class is too small to fill both sides.
            print(f"Not enough images to split for class {cls}: {e}")
            continue

        train_files.extend((os.path.join(class_dir, img), cls) for img in train)
        val_files.extend((os.path.join(class_dir, img), cls) for img in val)
        test_files.extend((os.path.join(class_dir, img), cls) for img in test)

    return train_files, val_files, test_files, classes

In [None]:
# Split data.
# Seed the global RNGs first so the split is reproducible across kernel restarts:
# split_data() shuffles with the `random` module, and train_test_split falls back to
# NumPy's global RNG when no random_state is passed.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)
np.random.seed(SPLIT_SEED)
train_files, val_files, test_files, classes = split_data(crop_root)

In [None]:
# Report how many (path, class) pairs ended up in each split.
print(
    f"Train files: {len(train_files)}",
    f"Validation files: {len(val_files)}",
    f"Test files: {len(test_files)}",
    sep="\n",
)

In [None]:
# Square input side lengths in pixels: one for the InceptionV3 pipeline, one for everything else.
inception_size, other_size = 299, 224

In [None]:
# Preprocessing pipelines, keyed first by model family and then by split.
# Every pipeline resizes to a square, converts to a tensor and normalises with the same
# mean/std triplets; only the 'train' pipelines add a random horizontal flip.
data_transforms = {
    family: {
        phase: transforms.Compose(
            [transforms.Resize((side, side))]
            + ([transforms.RandomHorizontalFlip()] if phase == 'train' else [])
            + [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        for phase in ('train', 'val', 'test')
    }
    for family, side in (('InceptionV3', inception_size), ('Others', other_size))
}

In [None]:
class CustomMultimodalDataset(Dataset):
    """Dataset yielding ``(image, csv_feature_row, label)`` triples.

    Args:
        file_paths: sequence of ``(image_path, class_name)`` tuples.
        csv_features: indexable collection of per-image feature rows.
        csv_labels: stored on the instance but not read by this class.
        class_to_idx: mapping from class name to integer label.
        transform: optional callable applied to the loaded RGB image.
        csv_image_paths: optional sequence with one image path per row of
            ``csv_features``. When given, the feature row for a sample is found by
            matching the sample's image path against this sequence instead of by
            position.

    Without ``csv_image_paths`` the feature row is ``csv_features[idx]``, i.e. the
    position of the sample inside ``file_paths``. That is only correct when
    ``file_paths`` is in the same order as the CSV rows, which a shuffled
    train/val/test split is not. Pass ``csv_image_paths`` to pair each image with
    its own row.
    """

    def __init__(self, file_paths, csv_features, csv_labels, class_to_idx, transform=None, csv_image_paths=None):
        self.file_paths = file_paths
        self.csv_features = csv_features
        self.csv_labels = csv_labels
        self.class_to_idx = class_to_idx
        self.transform = transform

        # Path -> CSV row lookups; left as None in the positional (legacy) mode.
        self._row_by_path = None
        self._row_by_name = None
        if csv_image_paths is not None:
            if len(csv_image_paths) != len(csv_features):
                raise ValueError(
                    f"csv_image_paths has {len(csv_image_paths)} entries but "
                    f"csv_features has {len(csv_features)} rows"
                )
            self._row_by_path = {}
            rows_by_name = {}
            for row, path in enumerate(csv_image_paths):
                normalised = os.path.normpath(str(path))
                self._row_by_path[normalised] = row
                rows_by_name.setdefault(os.path.basename(normalised), []).append(row)
            # A file name is only usable as a fallback key when it is unique in the CSV.
            # NOTE(review): the format of the CSV path column is not visible here; the
            # basename fallback covers CSVs that store relative paths -- confirm.
            self._row_by_name = {name: rows[0] for name, rows in rows_by_name.items() if len(rows) == 1}

    def __len__(self):
        return len(self.file_paths)

    def _feature_row(self, idx):
        """Return the index into ``csv_features`` that belongs to sample ``idx``."""
        if self._row_by_path is None:
            return idx
        img_path = os.path.normpath(str(self.file_paths[idx][0]))
        row = self._row_by_path.get(img_path)
        if row is None:
            row = self._row_by_name.get(os.path.basename(img_path))
        if row is None:
            raise KeyError(f"No CSV feature row found for image {img_path!r}")
        return row

    def __getitem__(self, idx):
        img_path, cls = self.file_paths[idx]
        # Context manager closes the file handle once the RGB copy has been made.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.class_to_idx[cls]
        csv_row = self.csv_features[self._feature_row(idx)]
        if self.transform:
            image = self.transform(image)
        return image, csv_row, label

In [None]:
# Map every class name to its position in `classes`; that position is the integer label.
class_to_idx = dict(zip(classes, range(len(classes))))

In [None]:
# Datasets and loaders for the InceptionV3 preprocessing pipeline.
# NOTE(review): CustomMultimodalDataset reads csv_features by the sample's position in
# the split list, so the full csv_features array is shared by all three datasets.
train_dataset_inception, val_dataset_inception, test_dataset_inception = (
    CustomMultimodalDataset(
        split_files,
        csv_features,
        csv_labels,
        class_to_idx,
        transform=data_transforms['InceptionV3'][split_name],
    )
    for split_name, split_files in (('train', train_files), ('val', val_files), ('test', test_files))
)

# Batches of 32; train and validation are reshuffled each epoch, test keeps a fixed order.
train_loader_inception, val_loader_inception, test_loader_inception = (
    DataLoader(dataset, batch_size=32, shuffle=reshuffle)
    for dataset, reshuffle in (
        (train_dataset_inception, True),
        (val_dataset_inception, True),
        (test_dataset_inception, False),
    )
)

In [None]:
# Datasets and loaders for every model that uses the 'Others' preprocessing pipeline.
train_dataset_others = CustomMultimodalDataset(
    train_files, csv_features, csv_labels, class_to_idx,
    transform=data_transforms['Others']['train'],
)
val_dataset_others = CustomMultimodalDataset(
    val_files, csv_features, csv_labels, class_to_idx,
    transform=data_transforms['Others']['val'],
)
test_dataset_others = CustomMultimodalDataset(
    test_files, csv_features, csv_labels, class_to_idx,
    transform=data_transforms['Others']['test'],
)

# Same loader settings as the InceptionV3 pipeline: batch size 32, only the test loader is unshuffled.
train_loader_others = DataLoader(dataset=train_dataset_others, shuffle=True, batch_size=32)
val_loader_others = DataLoader(dataset=val_dataset_others, shuffle=True, batch_size=32)
test_loader_others = DataLoader(dataset=test_dataset_others, shuffle=False, batch_size=32)


In [None]:
class InceptionAux(nn.Module):
    """Small auxiliary classification head.

    A 1x1 convolution (to 128 channels) followed by a 3x3 convolution (to 768 channels),
    global average pooling to 1x1, and a linear layer producing ``num_classes`` outputs.
    No activation is applied between the layers.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv0 = nn.Conv2d(in_channels, 128, kernel_size=1)
        self.conv1 = nn.Conv2d(128, 768, kernel_size=3)  # 3x3 kernel, no padding
        self.fc = nn.Linear(768, num_classes)

    def forward(self, x):
        """Map a 4-D feature map (N, in_channels, H, W) to (N, num_classes) outputs."""
        features = self.conv1(self.conv0(x))
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        return self.fc(pooled.flatten(1))

In [None]:
# Leftover inspection cell: a bare expression that only makes the notebook display the
# torchvision Inception3 class; it assigns nothing and has no effect on later cells.
models.Inception3

In [None]:
class FusionModel(nn.Module):
    """Classifier that fuses image-backbone features with tabular (CSV) features.

    The backbone output and a 128-d ReLU projection of the CSV row are concatenated.
    ``'late'`` fusion feeds the concatenation to a single linear layer;
    ``'intermediate'`` fusion uses a 256-unit hidden layer before the output layer.

    Args:
        base_model: image backbone. For recognised architectures its classification
            head is replaced in place with ``nn.Identity``, so one backbone instance
            must not be shared between several ``FusionModel`` objects.
        csv_input_dim: number of tabular features per sample.
        num_classes: number of output classes.
        fusion_method: ``'late'`` (default) or ``'intermediate'``.
        feature_size: optional width of ``base_model``'s output, used only when the
            backbone is not one of the recognised architectures.

    Raises:
        ValueError: ``fusion_method`` is not one of the two supported values.
        NotImplementedError: the backbone is not recognised and no ``feature_size``
            was supplied.
    """

    def __init__(self, base_model, csv_input_dim, num_classes, fusion_method='late', feature_size=None):
        super().__init__()
        # Fail early: an unknown method would otherwise build no head and only blow up in forward().
        if fusion_method not in ('late', 'intermediate'):
            raise ValueError(f"Unknown fusion_method {fusion_method!r}; expected 'late' or 'intermediate'")
        self.base_model = base_model
        self.fusion_method = fusion_method

        # isinstance() raises TypeError when given a function. The attention-augmented
        # names are lower-case and called like factories elsewhere in this notebook, so
        # keep only those that really are classes.
        attention_types = tuple(
            t for t in (attention_augmented_resnet152, attention_augmented_inceptionv3)
            if isinstance(t, type)
        )

        # Add support for the specific model
        if isinstance(self.base_model, models.Inception3):
            self.base_model.aux_logits = False
            # Strip the classifier so the backbone emits pooled features instead of logits;
            # 2048 is the fallback when the head has already been replaced.
            self.feature_size = getattr(self.base_model.fc, 'in_features', 2048)
            self.base_model.fc = nn.Identity()
        elif isinstance(self.base_model, models.ResNet):
            self.feature_size = self.base_model.fc.in_features  # Output feature size for ResNet
            self.base_model.fc = nn.Identity()  # Replace the final fully connected layer with identity
        elif isinstance(self.base_model, models.VGG):
            self.feature_size = self.base_model.classifier[0].in_features  # Output feature size for VGG
            self.base_model.classifier = nn.Identity()  # Replace the final classifier layer with identity
        elif isinstance(self.base_model, ViT):
            # NOTE(review): assumes vit_pytorch.ViT exposes `pos_embedding` with the embedding
            # width as its last dimension and a `mlp_head` classifier -- confirm against the
            # installed version. A `dim` attribute is honoured when it exists.
            vit_dim = getattr(self.base_model, 'dim', None)
            if vit_dim is None:
                vit_dim = self.base_model.pos_embedding.shape[-1]
            self.feature_size = vit_dim
            self.base_model.mlp_head = nn.Identity()  # emit the embedding rather than class logits
        elif attention_types and isinstance(self.base_model, attention_types):
            # NOTE(review): 2048 is carried over unchanged; whether these models really
            # output 2048 features cannot be checked from this notebook.
            self.feature_size = 2048
        elif feature_size is not None:
            self.feature_size = feature_size
        else:
            raise NotImplementedError("Model not supported")

        # Define CSV feature extractor
        self.csv_fc = nn.Linear(csv_input_dim, 128)

        if fusion_method == 'intermediate':
            self.fc1 = nn.Linear(128 + self.feature_size, 256)
            self.fc2 = nn.Linear(256, num_classes)
        else:  # 'late'
            self.fc = nn.Linear(self.feature_size + 128, num_classes)

    def forward(self, x_img, x_csv):
        """Return class logits for a batch of images and their tabular feature rows."""
        # Extract features from image
        x_img = self.base_model(x_img)
        if isinstance(x_img, tuple):
            x_img = x_img[0]
        # No-op for 2-D outputs; collapses trailing singleton dims such as (N, C, 1, 1).
        x_img = torch.flatten(x_img, 1)

        # Extract features from CSV data
        x_csv = F.relu(self.csv_fc(x_csv))

        x = torch.cat((x_img, x_csv), dim=1)
        if self.fusion_method == 'intermediate':
            return self.fc2(F.relu(self.fc1(x)))
        return self.fc(x)

In [None]:
# Function to create and train the fusion model
def _run_fusion_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, otherwise evaluate.

    Returns ``(mean_batch_loss, accuracy_percent, n_batches)``. Loss and accuracy are
    NaN when the loader yields no batches, instead of raising ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total, n_batches = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs_img, inputs_csv, labels in loader:
            inputs_img, inputs_csv, labels = inputs_img.to(device), inputs_csv.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs_img, inputs_csv)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            n_batches += 1

    if n_batches == 0 or total == 0:
        return float('nan'), float('nan'), n_batches
    return loss_sum / n_batches, 100 * correct / total, n_batches


def create_and_train_fusion_model(model, train_loader, val_loader, num_classes, csv_input_dim, device, fusion_method='late', num_epochs=1, initial_lr=0.001):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Args:
        model: module called as ``model(images, csv_rows)``; trained in place.
        train_loader, val_loader: iterables of ``(images, csv_rows, labels)`` batches.
        num_classes, csv_input_dim, fusion_method: accepted for call-site
            compatibility; not used inside this function.
        device: device the model and every batch are moved to.
        num_epochs: maximum number of epochs.
        initial_lr: Adam learning rate.

    Training stops early once the validation loss has failed to improve for 5
    consecutive epochs. Early stopping is skipped when ``val_loader`` is empty.

    Returns:
        The same ``model`` object, with the weights from the last epoch run.
    """
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
    early_stopping_patience = 5
    best_val_loss = float('inf')
    patience_counter = 0

    model.to(device)

    for epoch in range(num_epochs):
        train_loss, train_accuracy, _ = _run_fusion_epoch(model, train_loader, criterion, device, optimizer)

        # Validation
        val_loss, val_accuracy, val_batches = _run_fusion_epoch(model, val_loader, criterion, device)

        print(f'Epoch {epoch + 1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Without validation data there is nothing to base early stopping on.
        if val_batches == 0:
            continue

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping due to no improvement in validation loss.")
                break

    return model

In [None]:
# Function to evaluate the fusion model
def evaluate_fusion_model(model, test_loader, criterion, device):
    """Score ``model`` on ``test_loader`` without tracking gradients.

    Each batch is an ``(images, csv_rows, labels)`` triple that is moved to ``device``.
    If the model returns a tuple, only its first element is used as the logits.

    Returns:
        ``(test_loss, test_accuracy)``: the batch losses averaged over the number of
        batches, and the percentage of correctly classified samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_img, batch_csv, batch_labels in test_loader:
            batch_img = batch_img.to(device)
            batch_csv = batch_csv.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_img, batch_csv)
            if isinstance(logits, tuple):
                logits = logits[0]

            loss_sum += criterion(logits, batch_labels).item()
            predictions = logits.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()

    return loss_sum / len(test_loader), 100 * n_correct / n_seen

In [None]:
# Pick the compute device: Apple MPS when available, otherwise CUDA, otherwise CPU.
# `torch.backends.mps` is looked up defensively because older PyTorch builds do not have it.
if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Registry filled by the training cell: results[<label>] maps "<model>_<fusion>" to a dict
# holding the trained model, its test loss and its test accuracy.
results = dict()

In [None]:
# Train and evaluate every backbone with each fusion strategy.
# The datasets and loaders used here (train/val/test_loader_others) are built in the
# earlier "Create datasets and data loaders" / "Loaders for other models" cells, so
# they are not rebuilt in this cell.

# Label under which this run is stored in `results`. The notebook trains on the whole
# image tree rather than per crop, and `crop` was previously never defined (NameError).
crop = 'all_crops'

num_classes_inception = len(class_to_idx)
num_classes_others = len(class_to_idx)

num_heads = 8
csv_input_dim = csv_features.shape[1]

pretrained_models = {
    # 'InceptionV3': models.inception_v3(pretrained=True, aux_logits=True),
    # 'ResNet152': models.resnet152(pretrained=True),
    # 'VGG19': models.vgg19(pretrained=True),
    "AttentionAugmentedResNet152": attention_augmented_resnet152(num_classes=num_classes_others, attention=[False, True, True, True], num_heads=num_heads),
    "AttentionAugmentedInceptionV3": attention_augmented_inceptionv3(attention=True),
    'ViT': ViT(
        image_size=224,
        patch_size=16,
        num_classes=num_classes_others,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1
    ),
}

if 'InceptionV3' in pretrained_models:
    pretrained_models['InceptionV3'].aux_logits = False

crop_results = {}

# NOTE(review): every model, including AttentionAugmentedInceptionV3, is trained and
# tested with the 224px 'Others' loaders; the 299px inception loaders are never used.
# Confirm that this is intended.
for model_name, base_model in pretrained_models.items():
    for fusion_method in ['late', 'intermediate']:
        # Give each run its own copy of the backbone: FusionModel replaces the backbone's
        # head in place and training updates its weights, so reusing one instance would
        # start the second fusion method from an already-modified, already-trained model.
        fusion_model = FusionModel(copy.deepcopy(base_model), csv_input_dim, num_classes_others, fusion_method)

        print(f'Training {model_name} with {fusion_method} fusion')
        model = create_and_train_fusion_model(fusion_model, train_loader_others, val_loader_others, num_classes_others, csv_input_dim, device, fusion_method, initial_lr=0.001)

        test_loss, test_accuracy = evaluate_fusion_model(model, test_loader_others, nn.CrossEntropyLoss(), device)

        crop_results[f"{model_name}_{fusion_method}"] = {
            'model': model,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy
        }
        print(f'{crop} - {model_name} with {fusion_method} fusion Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

results[crop] = crop_results

In [None]:
# Rebuild the InceptionV3-pipeline datasets and loaders (same configuration as the
# earlier dataset cell: batch size 32, test loader unshuffled).
train_dataset_inception, val_dataset_inception, test_dataset_inception = (
    CustomMultimodalDataset(
        split_files,
        csv_features,
        csv_labels,
        class_to_idx,
        transform=data_transforms['InceptionV3'][split_name],
    )
    for split_name, split_files in (('train', train_files), ('val', val_files), ('test', test_files))
)

train_loader_inception, val_loader_inception, test_loader_inception = (
    DataLoader(dataset, batch_size=32, shuffle=reshuffle)
    for dataset, reshuffle in (
        (train_dataset_inception, True),
        (val_dataset_inception, True),
        (test_dataset_inception, False),
    )
)

In [None]:
# Function to display F1, precision, and recall of all models as a table
def display_model_metrics_table(results, test_loader):
    """Display one row of macro-averaged precision, recall and F1 per stored model.

    Args:
        results: nested mapping ``{crop: {model_name: {'model': module, ...}}}``.
        test_loader: iterable of ``(images, csv_rows, labels)`` batches. Every model
            is evaluated on this same loader.

    Each model is run in eval mode on the device its parameters live on. The
    resulting table is rendered with ``display``; nothing is returned.
    """
    rows = []

    for crop, crop_results in results.items():
        for model_name, model_info in crop_results.items():
            model = model_info['model']
            device = next(model.parameters()).device  # run where the model already lives
            model.eval()

            y_true, y_pred = [], []
            for images, tabular, labels in test_loader:
                images = images.to(device)
                tabular = tabular.to(device)
                labels = labels.to(device)

                with torch.no_grad():
                    predicted = torch.max(model(images, tabular), 1)[1]

                y_true.extend(labels.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())

            precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')
            rows.append({
                'Crop': crop,
                'Model': model_name,
                'Precision': precision,
                'Recall': recall,
                'F1-score': f1
            })

    display(pd.DataFrame(rows))  # rich table output in the notebook

In [None]:
# Function to display the classification report of a given model
def display_classification_report(model, test_loader):
    """Print sklearn's per-class classification report for ``model`` on ``test_loader``.

    Args:
        model: module called as ``model(images, csv_rows)`` and returning logits.
        test_loader: loader whose ``dataset`` has a ``class_to_idx`` mapping and which
            yields ``(images, csv_rows, labels)`` batches.

    The full label set is passed to sklearn explicitly. Without it sklearn infers the
    labels from the data and raises when a class is missing from the test set, because
    the number of inferred labels no longer matches ``target_names``. Classes with no
    predictions or no samples are reported as 0 instead of emitting warnings.
    """
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode

    all_labels = []
    all_predicted = []

    with torch.no_grad():
        for images, csv_data, labels in test_loader:
            images, csv_data = images.to(device), csv_data.to(device)
            outputs = model(images, csv_data)
            _, predicted = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_predicted.extend(predicted.cpu().numpy())

    class_to_idx = test_loader.dataset.class_to_idx
    # Order the names by their integer label so name i really describes label i.
    class_names = sorted(class_to_idx, key=class_to_idx.get)
    report = classification_report(
        all_labels,
        all_predicted,
        labels=[class_to_idx[name] for name in class_names],
        target_names=class_names,
        zero_division=0,
    )
    print(report)

In [None]:
# Display the first few test images with their true and predicted class
def display_classification_results(model, test_loader, num_images=5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot up to ``num_images`` samples from the first test batch with true/predicted labels.

    Args:
        model: module called as ``model(images, csv_rows)`` and returning logits.
        test_loader: loader whose ``dataset`` has a ``class_to_idx`` mapping and which
            yields ``(images, csv_rows, labels)`` batches.
        num_images: maximum number of images to show. Fewer are shown when the first
            batch is smaller.
        mean, std: per-channel statistics used to undo the ``Normalize`` transform
            before plotting. The defaults match the values in ``data_transforms``.

    The samples are simply the first ones in the batch; they are not selected by
    whether the prediction was right or wrong.
    """
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode
    class_to_idx = test_loader.dataset.class_to_idx
    # Order names by integer label so class_labels[i] is the name of label i.
    class_labels = sorted(class_to_idx, key=class_to_idx.get)

    images, csv_data, labels = next(iter(test_loader))
    images, csv_data, labels = images[:num_images].to(device), csv_data[:num_images].to(device), labels[:num_images].to(device)

    shown = images.size(0)  # can be smaller than num_images when the batch is short
    if shown == 0:
        return

    with torch.no_grad():
        outputs = model(images, csv_data)
        _, predicted = torch.max(outputs, 1)

    # squeeze=False keeps `axes` two-dimensional even when a single image is shown.
    fig, axes = plt.subplots(1, shown, figsize=(20, 8), squeeze=False)
    fig.suptitle('Classification Results', fontsize=16)

    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)

    for i in range(shown):
        ax = axes[0, i]
        img = images[i].cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow
        # Undo the normalisation first; clipping the normalised values directly distorts colours.
        img = np.clip(img * std_arr + mean_arr, 0, 1)
        ax.imshow(img)
        ax.set_title(f'True: {class_labels[labels[i].item()]}\n Pred: {class_labels[predicted[i].item()]}')
        ax.axis('off')

    plt.show()

In [None]:
# Macro precision / recall / F1 for every trained model, all evaluated on the 'Others' test loader.
display_model_metrics_table(results=results, test_loader=test_loader_others)

In [None]:
# Plot a few sample predictions for every trained model.
for crop, crop_results in results.items():
    for model_name, model_entry in crop_results.items():
        print(f'Displaying results for {crop} - {model_name}')
        display_classification_results(model_entry['model'], test_loader_others)

In [None]:
# Print the per-class classification report for every trained model.
for crop, crop_results in results.items():
    for model_name, model_entry in crop_results.items():
        print(f'Displaying classification report for {crop} - {model_name}')
        display_classification_report(model_entry['model'], test_loader_others)