In [1]:
import copy
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
from IPython.display import display
from PIL import Image
from PIL.ExifTags import TAGS, GPSTAGS
from sklearn.metrics import precision_recall_fscore_support, classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import is_image_file
from vit_pytorch import ViT

from AACN_Model import attention_augmented_resnet50, attention_augmented_efficientnetb0, attention_augmented_inceptionv3, attention_augmented_vit

In [2]:
# Define main directories
# The dataset root can be overridden with the CHERRY_DATA_DIR environment variable so the
# notebook is not tied to one machine; the original location remains the default.
base_dir = os.environ.get('CHERRY_DATA_DIR', '/Users/izzymohamed/Downloads/Cherry v2')
# Source images, one sub-folder per class
crop_root = os.path.join(base_dir, 'Ground_RGB_Photos/All')
# Parent folder of the generated train/val/test copies
split_root = os.path.join(base_dir, 'Ground_RGB_Photos/Split')

In [3]:
# Helper that purges .DS_Store artefacts from a directory tree
def remove_ds_store(directory):
    """Delete every file whose name contains ``.DS_Store`` below ``directory``.

    The tree is traversed with ``os.walk``; the path of each deleted file is
    printed before it is removed. Nothing is returned.
    """
    for root, _subdirs, files in os.walk(directory):
        # A substring test also covers the exact name '.DS_Store'.
        doomed = [name for name in files if '.DS_Store' in name]
        for name in doomed:
            file_path = os.path.join(root, name)
            print(f"Removing {file_path}")
            os.remove(file_path)

In [4]:
# Remove .DS_Store files from base directory.
# remove_ds_store walks base_dir recursively, printing and deleting every matching file.
remove_ds_store(base_dir)

In [5]:
def is_image_file(filename):
    """Return True when ``filename`` ends with a recognised image extension.

    The comparison is case-insensitive and only looks at the end of the string,
    so the file is never opened. Defining this function replaces the
    ``is_image_file`` name imported from ``torchvision.datasets.folder`` at the
    top of the notebook.
    """
    recognised_suffixes = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')
    lowered = filename.lower()
    return any(lowered.endswith(suffix) for suffix in recognised_suffixes)

In [6]:
# Function to split data into train, validation, and test sets
def split_data(base_dir, train_dir, val_dir, test_dir, val_split=0.2, test_split=0.1, random_state=42):
    """Copy the images of every class folder in ``base_dir`` into train/val/test folders.

    Any existing ``train_dir``, ``val_dir`` and ``test_dir`` are deleted and rebuilt.
    Each class is split independently, so every output folder receives a
    sub-folder per class that could be split.

    Args:
        base_dir: directory containing one sub-directory per class.
        train_dir: destination root of the training images (wiped first).
        val_dir: destination root of the validation images (wiped first).
        test_dir: destination root of the test images (wiped first).
        val_split: fraction of a class's images used for validation.
        test_split: fraction of a class's images used for testing.
        random_state: seed forwarded to ``train_test_split`` so repeated runs
            produce the same partition; pass ``None`` for a fresh random split.

    Raises:
        ValueError: if wiping an output directory would also delete ``base_dir``.
    """
    output_dirs = (train_dir, val_dir, test_dir)

    # The output directories are removed below; refuse to run when that would
    # take the source images with them.
    source_root = os.path.abspath(base_dir)
    for out_dir in output_dirs:
        out_root = os.path.abspath(out_dir)
        if source_root == out_root or source_root.startswith(out_root + os.sep):
            raise ValueError(
                f"Output directory {out_dir} contains the source directory {base_dir}; "
                f"refusing to delete it."
            )

    # Remove existing directories if they exist
    for out_dir in output_dirs:
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)

    # Create new directories
    for out_dir in output_dirs:
        os.makedirs(out_dir, exist_ok=True)

    # Sorted so the processing order (and, with a seed, the split) does not
    # depend on the order in which the file system lists entries.
    classes = sorted(d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)))
    for cls in classes:
        print(f'Processing class: {cls}')
        class_dir = os.path.join(base_dir, cls)

        # Only regular files are candidates; a directory with an image-like
        # name could not be copied with shutil.copy.
        images = sorted(
            f for f in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, f)) and is_image_file(os.path.join(class_dir, f))
        )

        if len(images) == 0:
            print(f"No images found for class {cls}. Skipping...")
            continue

        try:
            train, test = train_test_split(images, test_size=test_split, random_state=random_state)
            # val_split is relative to the whole class, so rescale it to the
            # part that remains after the test images were taken out.
            train, val = train_test_split(train, test_size=val_split / (1 - test_split),
                                          random_state=random_state)
        except ValueError as e:
            print(f"Not enough images to split for class {cls}: {e}")
            continue

        # Class folders are created only after a successful split: an empty
        # class folder would later be rejected by torchvision's ImageFolder.
        for subset_dir, subset_images in ((train_dir, train), (val_dir, val), (test_dir, test)):
            subset_cls_dir = os.path.join(subset_dir, cls)
            os.makedirs(subset_cls_dir, exist_ok=True)
            for img in subset_images:
                shutil.copy(os.path.join(class_dir, img), os.path.join(subset_cls_dir, img))

In [7]:
# Locations of the three dataset partitions underneath split_root
train_dir, val_dir, test_dir = (
    os.path.join(split_root, subset_name)
    for subset_name in ('train_set', 'val_set', 'test_set')
)

In [8]:
# Split data: copy the class folders found in crop_root into train_dir / val_dir / test_dir.
# NOTE: split_data removes any existing train/val/test directories before copying.
split_data(crop_root, train_dir, val_dir, test_dir)

Processing class: Healthy
Processing class: Armillaria_Stage_1
Processing class: Armillaria_Stage_2
Processing class: Armillaria_Stage_3


In [9]:
def _build_split_transforms(side):
    """Return the train/val/test preprocessing pipelines for a ``side`` x ``side`` input.

    All three pipelines resize, centre-crop, convert to a tensor and normalise
    with the same channel statistics; the training pipeline additionally applies
    random flips, rotation, colour jitter, an affine shear/scale and random
    grayscale conversion between the crop and the tensor conversion.
    """
    def _eval_pipeline():
        # Deterministic preprocessing shared by the validation and test splits.
        return transforms.Compose([
            transforms.Resize(side),
            transforms.CenterCrop(side),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    train_pipeline = transforms.Compose([
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.RandomAffine(0, shear=10, scale=(0.8,1.2)),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    return {
        'train': train_pipeline,
        'val': _eval_pipeline(),
        'test': _eval_pipeline(),
    }


# Preprocessing per model family: 299-pixel inputs for InceptionV3, 224-pixel inputs otherwise.
data_transforms = {
    'InceptionV3': _build_split_transforms(299),
    'Others': _build_split_transforms(224),
}

In [10]:
# Choose the compute device: MPS when available, otherwise CUDA, otherwise the CPU
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [11]:
# Function to check if a file is valid
def is_valid_file(path):
    """Return True unless the file name of ``path`` contains ``DS_Store``.

    The previous condition joined its two tests with ``or``, which made the
    substring test ineffective: only names ending in ``.DS_Store`` were rejected,
    so variants such as ``.DS_Store.bak`` were accepted. Only the final path
    component is inspected, so files inside an oddly named directory stay valid.
    """
    return 'DS_Store' not in os.path.basename(path)

In [12]:
# Step-wise learning-rate schedule
def adjust_learning_rate(optimizer, epoch, learning_rate):
    """Write the step-decayed learning rate into every parameter group.

    The rate is ``learning_rate`` multiplied by 0.1 once for each completed
    block of 10 epochs (``epoch // 10``). ``optimizer.param_groups`` is updated
    in place and nothing is returned.
    """
    completed_decades = epoch // 10
    decayed_lr = learning_rate * (0.1 ** completed_decades)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

In [13]:
# Loss / accuracy evaluation over a data loader
def evaluate_model(model, test_loader, criterion, device):
    """Score ``model`` on every batch of ``test_loader`` without tracking gradients.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: sized iterable yielding ``(inputs, labels)`` batches.
        criterion: callable returning the loss for ``(outputs, labels)``.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(loss, accuracy)``: the mean of the per-batch losses and the
        percentage of correctly predicted labels.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # When the model returns a tuple, only its first element is scored.
            logits = logits[0] if isinstance(logits, tuple) else logits
            loss_sum += criterion(logits, labels).item()
            n_seen += labels.size(0)
            n_correct += (logits.argmax(dim=1) == labels).sum().item()

    mean_loss = loss_sum / len(test_loader)
    accuracy_pct = 100 * n_correct / n_seen
    return mean_loss, accuracy_pct

In [14]:
# Function to create and train the model
def create_and_train_model(model, train_loader, val_loader, num_classes, device, num_epochs=40, initial_lr=0.001,
                           early_stopping_patience=5):
    """Train ``model`` with Adam and early stopping and return it with its best weights.

    Args:
        model: network to optimise; it is moved to ``device``.
        train_loader: sized iterable of ``(inputs, labels)`` training batches.
        val_loader: sized iterable of ``(inputs, labels)`` validation batches.
        num_classes: not used by the training loop; kept for existing callers.
        device: device on which the model and the batches are placed.
        num_epochs: maximum number of epochs.
        initial_lr: starting learning rate; decayed each epoch by
            ``adjust_learning_rate``.
        early_stopping_patience: number of consecutive epochs without a new best
            validation loss after which training stops.

    Returns:
        The same model object, loaded with the weights of the epoch that reached
        the lowest validation loss (previously the last-epoch weights were
        returned, i.e. the weights after ``early_stopping_patience`` worse epochs).
    """
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0

    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        adjust_learning_rate(optimizer, epoch, initial_lr)
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            # When the model returns a tuple, only its first element is trained on.
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total

        # Validation: evaluate_model switches to eval mode and disables gradients;
        # model.train() at the top of the loop switches back for the next epoch.
        val_loss, val_accuracy = evaluate_model(model, val_loader, criterion, device)

        print(f'Epoch {epoch + 1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot the weights so the best epoch can be restored at the end.
            best_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping due to no improvement in validation loss.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model

In [15]:
# Per-crop results container, and the list of crops the training loop iterates over.
# Names noted alongside the original list: 'Armillaria_Stage_1', 'Armillaria_Stage_2',
# 'Armillaria_Stage_3', 'Healthy'
results = dict()
crops = ['Cherry']

In [16]:
# Class discovery from a directory layout
def find_classes(dir):
    """List the class sub-directories of ``dir`` and number them.

    If ``dir`` does not exist it is created (and a message is printed), which
    yields an empty result. Entries whose name starts with ``.`` and entries
    that are not directories are ignored.

    Returns:
        ``(classes, class_to_idx)``: the sorted class names and a mapping from
        each name to its position in that list.
    """
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)
        print(f"Created directory: {dir}")
    classes = sorted(
        entry for entry in os.listdir(dir)
        if not entry.startswith('.') and os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx

In [17]:
# Iterate over each crop: build the datasets, train every architecture and record its test metrics
for crop in crops:
    crop_train_dir = train_dir  # os.path.join(train_dir, crop)
    crop_val_dir = val_dir  # os.path.join(val_dir, crop)
    crop_test_dir = test_dir  # os.path.join(test_dir, crop)

    if not os.path.exists(crop_train_dir) or not os.listdir(crop_train_dir):
        print(f"No data found in training directory {crop_train_dir}. Skipping {crop}.")
        continue

    if not os.path.exists(crop_val_dir) or not os.listdir(crop_val_dir):
        print(f"No data found in validation directory {crop_val_dir}. Skipping {crop}.")
        continue

    # The test split is loaded below as well, so it gets the same guard.
    if not os.path.exists(crop_test_dir) or not os.listdir(crop_test_dir):
        print(f"No data found in test directory {crop_test_dir}. Skipping {crop}.")
        continue

    # Load datasets with the new transformations
    train_dataset_inception = datasets.ImageFolder(crop_train_dir, transform=data_transforms['InceptionV3']['train'])
    val_dataset_inception = datasets.ImageFolder(crop_val_dir, transform=data_transforms['InceptionV3']['val'])
    test_dataset_inception = datasets.ImageFolder(crop_test_dir, transform=data_transforms['InceptionV3']['test'])

    # Only the training data is shuffled; evaluation order does not need to vary between runs.
    train_loader_inception = DataLoader(train_dataset_inception, batch_size=32, shuffle=True)
    val_loader_inception = DataLoader(val_dataset_inception, batch_size=32, shuffle=False)
    test_loader_inception = DataLoader(test_dataset_inception, batch_size=32, shuffle=False)

    # Loaders for other models
    train_dataset_others = datasets.ImageFolder(crop_train_dir, transform=data_transforms['Others']['train'])
    val_dataset_others = datasets.ImageFolder(crop_val_dir, transform=data_transforms['Others']['val'])
    test_dataset_others = datasets.ImageFolder(crop_test_dir, transform=data_transforms['Others']['test'])

    train_loader_others = DataLoader(train_dataset_others, batch_size=32, shuffle=True)
    val_loader_others = DataLoader(val_dataset_others, batch_size=32, shuffle=False)
    test_loader_others = DataLoader(test_dataset_others, batch_size=32, shuffle=False)

    # ImageFolder assigns label indices when the dataset is constructed, so overwriting
    # `class_to_idx` afterwards would not relabel any sample. Instead, verify that all
    # three splits expose the same class folders (the InceptionV3 datasets read the same
    # directories, so checking one family is enough).
    if (val_dataset_others.classes != train_dataset_others.classes
            or test_dataset_others.classes != train_dataset_others.classes):
        print(f"Class folders differ between the train, validation and test directories. Skipping {crop}.")
        continue

    num_classes_inception = len(train_dataset_inception.classes)
    num_classes_others = len(train_dataset_others.classes)

    num_heads = 8

    pretrained_models = {
        'EfficientNetB0': EfficientNet.from_pretrained('efficientnet-b0'),
        'InceptionV3': models.inception_v3(pretrained=True),
        'ResNet50': models.resnet50(pretrained=True),
        "AttentionAugmentedResNet50": attention_augmented_resnet50(num_classes=num_classes_others, attention=[False,True,True,True], num_heads=num_heads),
        # NOTE(review): no num_classes is passed here and this model is fed the 224-pixel
        # loaders below — confirm both against AACN_Model.
        "AttentionAugmentedInceptionV3": attention_augmented_inceptionv3(attention=True),
        # "AttentionAugmentedEfficientNetB0": attention_augmented_efficientnetb0(attention=True),
        # "AttentionAugmentedViT": attention_augmented_vit(attention=True),
        'ViT': ViT(
            image_size=224,
            patch_size=16,
            num_classes=num_classes_others,
            dim=1024,
            depth=6,
            heads=16,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1
        ),
    }

    crop_results = {}

    print(f'--------------------------------------------- Crop: {crop} ---------------------------------------------\n')

    for model_name, base_model in pretrained_models.items():
        if model_name == 'InceptionV3':
            # Both the auxiliary and the main classifier get a head sized for this dataset.
            base_model.AuxLogits.fc = nn.Linear(base_model.AuxLogits.fc.in_features, num_classes_inception)
            base_model.fc = nn.Linear(base_model.fc.in_features, num_classes_inception)
            train_loader = train_loader_inception
            val_loader = val_loader_inception
            test_loader = test_loader_inception
        else:
            if model_name == 'EfficientNetB0':
                base_model._fc = nn.Linear(base_model._fc.in_features, num_classes_others)
            elif model_name == 'ResNet50':
                base_model.fc = nn.Linear(base_model.fc.in_features, num_classes_others)
            # ViT is already constructed with num_classes=num_classes_others, so its head is
            # left untouched. Replacing `mlp_head` relied on it being a single nn.Linear, which
            # presumably depends on the installed vit_pytorch version — verify before re-adding.
            train_loader = train_loader_others
            val_loader = val_loader_others
            test_loader = test_loader_others

        print(f'--------------- Training model: {model_name}')
        model = create_and_train_model(base_model, train_loader, val_loader, num_classes_others, device, initial_lr=0.001)

        test_loss, test_accuracy = evaluate_model(model, test_loader, nn.CrossEntropyLoss(), device)

        crop_results[model_name] = {
            'model': model,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
            # Keep the loader whose preprocessing matches this model, so later
            # analysis can score each model on correctly sized inputs.
            'test_loader': test_loader,
        }
        print(f'{crop} - {model_name} Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
        print(f'\n')

    results[crop] = crop_results

Loaded pretrained weights for efficientnet-b0




--------------------------------------------- Crop: Cherry ---------------------------------------------

--------------- Training model: EfficientNetB0
Epoch 1/40, Train Loss: 0.4094, Train Accuracy: 90.87%, Val Loss: 0.5567, Val Accuracy: 92.72%
Epoch 2/40, Train Loss: 0.2886, Train Accuracy: 92.85%, Val Loss: 0.4340, Val Accuracy: 92.72%
Epoch 3/40, Train Loss: 0.2623, Train Accuracy: 92.70%, Val Loss: 0.3918, Val Accuracy: 92.72%
Epoch 4/40, Train Loss: 0.2615, Train Accuracy: 92.85%, Val Loss: 0.2980, Val Accuracy: 92.72%
Epoch 5/40, Train Loss: 0.2418, Train Accuracy: 92.95%, Val Loss: 0.3344, Val Accuracy: 92.89%
Epoch 6/40, Train Loss: 0.2406, Train Accuracy: 92.90%, Val Loss: 0.3261, Val Accuracy: 88.39%
Epoch 7/40, Train Loss: 0.2300, Train Accuracy: 93.10%, Val Loss: 0.2657, Val Accuracy: 92.55%
Epoch 8/40, Train Loss: 0.2372, Train Accuracy: 92.90%, Val Loss: 0.2541, Val Accuracy: 91.85%
Epoch 9/40, Train Loss: 0.2209, Train Accuracy: 93.50%, Val Loss: 0.2626, Val Accuracy:

In [None]:
# One bar chart per crop comparing the test accuracy of every trained model
for crop, crop_results in results.items():
    model_names = list(crop_results)
    accuracies = [crop_results[name]['test_accuracy'] for name in model_names]

    fig, ax = plt.subplots(figsize=(20, 10))
    ax.bar(model_names, accuracies)
    ax.set_title(f'Model test accuracy comparison for {crop}')
    ax.set_ylabel('Accuracy (%)')
    ax.set_xlabel('Model')
    plt.show()

In [None]:
# Function to display F1, precision, and recall of all models as a table
def display_model_metrics_table(results, test_loader):
    """Display macro-averaged precision, recall and F1 of every stored model as a table.

    Args:
        results: mapping ``crop -> {model_name -> info}`` where ``info['model']``
            is the trained network. If ``info`` also has a ``'test_loader'``
            entry, that loader is used for the model.
        test_loader: loader of ``(images, labels)`` batches used for every model
            that has no loader of its own.

    The table is rendered with ``display``; nothing is returned.
    """
    metrics_data = []

    for crop, crop_results in results.items():
        for model_name, model_info in crop_results.items():
            model = model_info['model']
            # Models trained on a different input size need their own loader.
            loader = model_info.get('test_loader', test_loader)
            device = next(model.parameters()).device  # Get the device of the model
            model.eval()  # Set the model to evaluation mode

            all_labels = []
            all_predicted = []

            with torch.no_grad():
                for images, labels in loader:
                    outputs = model(images.to(device))
                    # Same convention as evaluate_model: score the first element of a tuple output.
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    predicted = outputs.argmax(dim=1)

                    all_labels.extend(labels.cpu().numpy())
                    all_predicted.extend(predicted.cpu().numpy())

            precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predicted, average='macro')

            metrics_data.append({
                'Crop': crop,
                'Model': model_name,
                'Precision': precision,
                'Recall': recall,
                'F1-score': f1
            })

    metrics_df = pd.DataFrame(metrics_data)
    display(metrics_df)  # Display the DataFrame in Jupyter Notebook

In [None]:
# Function to display the classification report of a given model
def display_classification_report(model, test_loader):
    """Print scikit-learn's classification report for ``model`` on ``test_loader``.

    Args:
        model: trained network; evaluated on the device its parameters live on.
        test_loader: loader of ``(images, labels)`` batches whose
            ``dataset.classes`` lists the class names in label order.
    """
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode

    all_labels = []
    all_predicted = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            # Same convention as evaluate_model: score the first element of a tuple output.
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            predicted = outputs.argmax(dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predicted.extend(predicted.cpu().numpy())

    class_names = test_loader.dataset.classes
    # Passing the full label list keeps target_names aligned even when a class
    # occurs neither in the labels nor in the predictions.
    report = classification_report(
        all_labels,
        all_predicted,
        labels=list(range(len(class_names))),
        target_names=class_names,
    )
    print(report)

In [None]:
# Show the first images of the first test batch with their true and predicted class
def display_classification_results(model, test_loader, num_images=5,
                                   mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot up to ``num_images`` images from the first batch with true/predicted labels.

    Args:
        model: trained network; evaluated on the device its parameters live on.
        test_loader: loader of ``(images, labels)`` batches whose
            ``dataset.classes`` lists the class names in label order.
        num_images: maximum number of images to show; capped at the batch size.
        mean: per-channel mean used to undo the input normalisation. The default
            matches the ``Normalize`` step in ``data_transforms``.
        std: per-channel standard deviation used to undo the normalisation.

    Raises:
        ValueError: if ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    device = next(model.parameters()).device  # Get the device of the model
    model.eval()  # Set the model to evaluation mode
    class_labels = test_loader.dataset.classes

    images, labels = next(iter(test_loader))
    # The batch may hold fewer images than requested.
    num_shown = min(num_images, images.size(0))
    images, labels = images[:num_shown].to(device), labels[:num_shown]  # Move tensors to the model's device

    with torch.no_grad():
        outputs = model(images)
        # Same convention as evaluate_model: score the first element of a tuple output.
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        predicted = outputs.argmax(dim=1).cpu()

    # squeeze=False keeps `axes` two-dimensional even when a single image is shown.
    fig, axes = plt.subplots(1, num_shown, figsize=(20, 8), squeeze=False)
    fig.suptitle('Classification Results', fontsize=16)

    channel_mean = np.asarray(mean, dtype=np.float32)
    channel_std = np.asarray(std, dtype=np.float32)

    for i in range(num_shown):
        ax = axes[0, i]
        img = images[i].cpu().numpy().transpose((1, 2, 0))  # Move tensor back to CPU for visualization
        # Undo the normalisation before clipping so the colours are displayed correctly.
        img = np.clip(img * channel_std + channel_mean, 0, 1)
        ax.imshow(img)
        ax.set_title(f'True: {class_labels[labels[i].item()]}\n Pred: {class_labels[predicted[i].item()]}')
        ax.axis('off')

    plt.show()

In [None]:
# Display the table of metrics for all models.
# `test_loader` is the name last assigned inside the training loop.
# NOTE(review): confirm that loader's preprocessing suits every model stored in `results`.
display_model_metrics_table(results, test_loader)

In [None]:
# Display sample predictions for each crop and model
for crop, crop_results in results.items():
    for model_name, model_info in crop_results.items():
        print(f'Displaying results for {crop} - {model_name}')
        # Use the loader stored with the model when there is one (its preprocessing matches
        # the model's input size); otherwise fall back to the shared `test_loader`.
        display_classification_results(model_info['model'], model_info.get('test_loader', test_loader))

In [None]:
# Display the classification report for each crop and model
for crop, crop_results in results.items():
    for model_name, model_info in crop_results.items():
        print(f'Displaying classification report for {crop} - {model_name}')
        # Use the loader stored with the model when there is one (its preprocessing matches
        # the model's input size); otherwise fall back to the shared `test_loader`.
        display_classification_report(model_info['model'], model_info.get('test_loader', test_loader))