In [None]:
### IMPORT LIBRARIES ###


# Import necessary libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from IPython.display import display, clear_output
import seaborn as sns

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.transforms import Resize, CenterCrop, ToTensor, Normalize
from torchvision import models
from torchvision.models import efficientnet_b0
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
### DATA PREPARATION FUNCTION ###



def prepare_data(dataset_directory):
    """Index an image-folder dataset into a DataFrame.

    The dataset directory is expected to contain one sub-folder per category,
    each holding the images of that category.

    Args:
        dataset_directory: Path to the directory containing the category folders.

    Returns:
        A tuple ``(image_data, type_to_label)`` where ``image_data`` is a
        DataFrame with the columns ``'Image path'``, ``'Type'`` (category folder
        name) and ``'Label'`` (integer class index), and ``type_to_label`` maps
        each category name to its integer label.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    # Keep only real category folders. Hidden entries (e.g. ".DS_Store") and
    # stray files are skipped explicitly instead of blindly dropping the first
    # sorted entry, which would discard a genuine category whenever no system
    # file is present.
    categories = sorted(
        entry for entry in os.listdir(dataset_directory)
        if not entry.startswith('.')
        and os.path.isdir(os.path.join(dataset_directory, entry))
    )

    # Collect all rows first and build the DataFrame once; appending row by row
    # with .loc is quadratic in the number of images.
    records = []
    for category in categories:
        folder_path = os.path.join(dataset_directory, category)
        # Sorted so that the row order does not depend on the file system.
        for image_file in sorted(os.listdir(folder_path)):
            if image_file.startswith('.'):
                continue
            if not image_file.lower().endswith(image_extensions):
                continue
            records.append((os.path.join(folder_path, image_file), category))

    image_data = pd.DataFrame(records, columns=['Image path', 'Type'])

    # Labels are assigned in order of first appearance, i.e. sorted category order
    type_to_label = {type_name: index for index, type_name in enumerate(image_data['Type'].unique())}
    image_data['Label'] = image_data['Type'].map(type_to_label)

    return image_data, type_to_label

In [None]:
### DATASET AND DATALOADER SETUP FUNCTION ###


# Custom dataset class for loading and transforming images
class CustomDataset(Dataset):
    """Dataset that serves ``(image, label)`` pairs described by a DataFrame.

    The DataFrame must provide an ``'Image path'`` column (file to open) and a
    ``'Label'`` column (value returned alongside the image). Rows are addressed
    by position.
    """

    def __init__(self, dataframe, transform=None):
        # Optional callable applied to every loaded PIL image
        self.transform = transform
        self.dataframe = dataframe

    def __len__(self):
        """Number of rows (samples) in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image at row position ``idx`` and return it with its label."""
        row = self.dataframe.iloc[idx]
        label = row['Label']

        # Always decode to 3-channel RGB, whatever the source image mode is
        picture = Image.open(row['Image path']).convert('RGB')

        return (self.transform(picture) if self.transform else picture), label

    
def setup_data_loaders(image_data, data_transforms, batch_size=4, random_state=None):
    """Split the data 60/20/20 (stratified by label) and wrap each part in a DataLoader.

    Args:
        image_data: DataFrame with at least the columns ``'Image path'`` and ``'Label'``.
        data_transforms: Transform applied to every image by ``CustomDataset``.
        batch_size: Batch size used by all three loaders.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            split can be reproduced. ``None`` (default) keeps the previous
            behaviour of a different random split on every call.

    Returns:
        A tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles its samples.
    """
    # Initialize the custom dataset with the provided data and transformations
    custom_dataset = CustomDataset(dataframe=image_data, transform=data_transforms)

    # Work with row *positions* rather than DataFrame index labels: the dataset
    # reads rows with .iloc and the label array is indexed positionally, so
    # index labels would only be correct for a default 0..n-1 index.
    labels = image_data['Label'].to_numpy()
    positions = np.arange(len(image_data))

    # Hold out 20% for testing, then 25% of the remainder (20% overall) for validation
    train_val_positions, test_positions = train_test_split(
        positions, test_size=0.2, stratify=labels, random_state=random_state)
    train_positions, val_positions = train_test_split(
        train_val_positions, test_size=0.25, stratify=labels[train_val_positions],
        random_state=random_state)

    # Create subsets for training, validation, and testing
    train_dataset = Subset(custom_dataset, train_positions.tolist())
    val_dataset = Subset(custom_dataset, val_positions.tolist())
    test_dataset = Subset(custom_dataset, test_positions.tolist())

    # Initialize DataLoader for each set
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [None]:
### DISPLAY IMAGES FUNCTION ###


def show_images(images, labels):
    """Display a batch of image tensors side by side with their labels.

    Args:
        images: Sequence/tensor of images, each laid out channels-first
            (the code permutes axes ``(1, 2, 0)`` before plotting).
        labels: Sequence/tensor of per-image labels; each element must
            support ``.item()``.
    """
    # Determine the number of images
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot create a grid with zero columns
        return

    # squeeze=False always yields a 2-D array of axes, so a single-image batch
    # (e.g. the last, incomplete batch) works the same as a larger one.
    fig, axes = plt.subplots(1, num_images, figsize=(12, 4), squeeze=False)

    for i, ax in enumerate(axes[0]):
        # Move to CPU/NumPy, put channels last and clip the values to [0, 1] for imshow.
        # NOTE: clipping does not undo any normalization applied upstream, so
        # normalized inputs are shown with distorted colours.
        image = np.clip(images[i].detach().cpu().permute(1, 2, 0).numpy(), 0, 1)
        label = labels[i].item()
        # Display the image and its label
        ax.imshow(image)
        ax.set_title(f"Label: {label}")
        ax.axis('off')

    plt.show()

In [None]:
### MODEL SETUP FUNCTION ###


def setup_model(num_classes, device):
    """Build a frozen, ImageNet-pretrained ResNet18 with a new trainable classifier head.

    Args:
        num_classes: Number of output classes for the replaced final layer.
        device: Device the model is moved to.

    Returns:
        A tuple ``(model, criterion, optimizer)``: the model on ``device``, a
        cross-entropy loss, and an Adam optimizer (lr=0.001) that updates only
        the parameters of the new final layer.
    """
    # Load a pre-trained ResNet18 model. Newer torchvision releases deprecate
    # `pretrained=True` in favour of the weights enum; fall back to the legacy
    # flag only when the enum is not available.
    if hasattr(models, "ResNet18_Weights"):
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)

    # Freeze all model parameters to prevent updating during training
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final layer to match the number of classes; the freshly
    # created layer has requires_grad=True and is the only trainable part.
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)
    # Move the model to the specified device (GPU or CPU)
    model = model.to(device)

    # Set up the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

    return model, criterion, optimizer

In [None]:
### TRAINING LOOP WITH EARLY STOPPING FUNCTION ###


def train_model(model, criterion, optimizer, train_loader, val_loader, device, patience=5, max_epochs=None):
    """Train ``model`` with early stopping on the validation loss.

    After every epoch the train/validation loss curves are re-drawn in place
    and a one-line summary is printed.

    Args:
        model: Model to train; it is updated in place.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the trainable parameters.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        device: Device the batches are moved to.
        patience: Number of consecutive epochs without validation improvement
            after which training stops.
        max_epochs: Optional upper bound on the number of epochs. ``None``
            (default) trains until early stopping triggers.

    Returns:
        A tuple ``(model, best_model_state, train_loss_values, val_loss_values)``.
        ``model`` holds the weights of the *last* epoch; ``best_model_state`` is
        an independent copy of the state dict with the lowest validation loss.
    """
    best_val_loss = float('inf')  # Initialize best validation loss for early stopping
    # Initialised up front so both are defined even if the validation loss never
    # improves (e.g. it is NaN from the very first epoch).
    best_model_state = _clone_state_dict(model)
    epochs_no_improve = 0
    train_loss_values, val_loss_values = [], []  # Lists to track loss values
    epoch = 0

    while max_epochs is None or epoch < max_epochs:
        # Training phase
        model.train()  # Set model to training mode
        total_train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to device
            optimizer.zero_grad()  # Clear gradients
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Update model parameters
            total_train_loss += loss.item()  # Accumulate loss
        # Average over batches; max(..., 1) avoids dividing by zero for an empty loader
        train_loss_values.append(total_train_loss / max(len(train_loader), 1))

        # Validation phase
        model.eval()  # Set model to evaluation mode
        total_val_loss = 0.0
        with torch.no_grad():  # No gradient computation in validation phase
            for val_inputs, val_labels in val_loader:
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)  # Move data to device
                val_outputs = model(val_inputs)  # Forward pass
                total_val_loss += criterion(val_outputs, val_labels).item()  # Accumulate loss
        total_val_loss /= max(len(val_loader), 1)  # Compute average validation loss
        val_loss_values.append(total_val_loss)  # Track validation loss

        # Clear previous output so the plot is refreshed in place
        clear_output(wait=True)

        # Plot training and validation loss
        plt.figure(figsize=(10, 6))
        plt.plot(train_loss_values, marker='o', linestyle='-', color='b', label='Train Loss')
        plt.plot(val_loss_values, marker='o', linestyle='-', color='r', label='Val Loss')
        plt.title('Train vs Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, which="both", ls="-")
        plt.xticks(range(epoch + 1))
        display(plt.gcf())
        plt.close()

        print(f"Epoch {epoch+1}, Training Loss: {train_loss_values[-1]}, Validation Loss: {val_loss_values[-1]}")

        # Early stopping check
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss  # Update best validation loss
            # state_dict() returns references to the live parameters; clone them,
            # otherwise later epochs would silently overwrite the "best" state.
            best_model_state = _clone_state_dict(model)
            epochs_no_improve = 0  # Reset counter for epochs without improvement
        else:
            epochs_no_improve += 1  # Increment counter
            if epochs_no_improve >= patience:  # Check if patience limit is reached
                print(f"Early stopping at epoch {epoch+1}. Best validation loss: {best_val_loss:.4f}")
                break  # Exit training loop

        epoch += 1  # Increment epoch counter

    return model, best_model_state, train_loss_values, val_loss_values  # Return model, best state, and loss values


def _clone_state_dict(model):
    """Return a copy of ``model.state_dict()`` whose tensors are detached clones."""
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

In [None]:
### TEST MODEL FUNCTION ###


def test_model(model, test_loader, device, class_idx_to_label):
    """Evaluate ``model`` on the test set and report accuracy, per-class metrics and a confusion matrix.

    Args:
        model: Trained model to evaluate.
        test_loader: Iterable of ``(inputs, labels)`` test batches.
        device: Device the batches are moved to.
        class_idx_to_label: Mapping from class index ``0..n-1`` to class name.

    Returns:
        None. Results are printed and plotted.
    """
    # Ensure the model is in evaluation mode
    model.eval()

    # Lists to store all predictions and labels
    all_predictions, all_labels = [], []

    # No gradient computation needed
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to the appropriate device
            outputs = model(inputs)  # Forward pass
            _, predicted = torch.max(outputs, 1)  # Get the predicted classes

            # Append predictions and labels for later evaluation
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        # Nothing to score; avoids a division by zero below
        print("Test set is empty; nothing to evaluate.")
        return

    # Fix the class order once. Passing it explicitly to scikit-learn keeps the
    # report and the matrix at n classes even when a class is missing from the
    # test split (otherwise target_names / tick labels no longer line up).
    class_indices = list(range(len(class_idx_to_label)))
    class_names = [class_idx_to_label[i] for i in class_indices]

    # Compute the accuracy
    accuracy = 100 * np.sum(np.array(all_predictions) == np.array(all_labels)) / len(all_labels)
    print(f'Accuracy on the test set: {accuracy:.2f}%')

    # Detailed classification report
    print("\nDetailed classification report:")
    print(classification_report(all_labels, all_predictions,
                                labels=class_indices, target_names=class_names))

    # Plot the confusion matrix (rows: true classes, columns: predicted classes)
    conf_matrix = confusion_matrix(all_labels, all_predictions, labels=class_indices)
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

In [None]:
### PREPARE THE DATA ###

# Root folder of the dataset; prepare_data treats its entries as category folders
dataset_directory = './data/archive'  # Ensure this is the correct path to your dataset
# image_data: DataFrame with 'Image path', 'Type' and 'Label' columns
# type_to_label: dict mapping each category name to its integer label
image_data, type_to_label = prepare_data(dataset_directory)

In [None]:
### SET UP DATALOADERS ###

# Deterministic preprocessing shared by the training, validation and test sets
# (no random augmentation is applied): resize, centre-crop to 224x224, convert
# to a tensor and normalize each channel with the given mean / std.
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Batch size for all three loaders; tune it to the available memory
batch_size = 4
train_loader, val_loader, test_loader = setup_data_loaders(
    image_data,
    data_transforms,
    batch_size=batch_size,
)

In [None]:
### DISPLAY A BATCH OF IMAGES ###

# Show the first batch of two separate passes over the training loader
for _ in range(2):
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        batch_images, batch_labels = first_batch
        show_images(batch_images, batch_labels)

In [None]:
### SETUP THE MODEL ###

# Prefer the GPU when CUDA is available, otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# One output unit per category found in the dataset
num_classes = len(type_to_label)

# Build the model together with its loss function and optimizer
model, criterion, optimizer = setup_model(num_classes, device)

In [None]:
### TRAIN THE MODEL ###

# Run training with early stopping (duration depends on the dataset and hardware)
(model, best_model_state, train_loss_values, val_loss_values) = train_model(
    model,
    criterion,
    optimizer,
    train_loader,
    val_loader,
    device,
    patience=5,
)

In [None]:
### TEST THE MODEL ###

# Restore the weights that achieved the lowest validation loss
model.load_state_dict(best_model_state)

# Invert `type_to_label` so that class indices map back to category names
class_idx_to_label = dict(zip(type_to_label.values(), type_to_label.keys()))

# Evaluate on the held-out test set
test_model(model, test_loader, device, class_idx_to_label)