# DEPENDENCIES

In [1]:
import time
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import multiprocessing as mp
from multiprocessing import Pool
import numpy as np
import warnings

# Suppress all UserWarnings
warnings.simplefilter("ignore", category=UserWarning)

## Data Preparation 

In [2]:
def load_data(index, main_dir, batch_size, img_height, img_width, train_indices, valid_indices):
    """
    Build a DataLoader over the training or validation subset of an image folder.

    Args:
        index (int): 0 selects the training subset; any other value selects
            the validation subset.
        main_dir (str): Root directory passed to ``ImageFolder``.
        batch_size (int): Batch size for the returned DataLoader.
        img_height (int): Height every image is resized to.
        img_width (int): Width every image is resized to.
        train_indices (list): Sample indices that make up the training subset.
        valid_indices (list): Sample indices that make up the validation subset.

    Returns:
        tuple: ``(loader, classes)`` where ``loader`` is a DataLoader over the
        selected subset and ``classes`` is the class list of the underlying
        ``ImageFolder``.
    """
    is_training = index == 0

    # Random augmentation belongs to the training pipeline only: flipping
    # validation images at random makes validation metrics noisy and not
    # comparable from one epoch to the next.
    steps = [transforms.Resize((img_height, img_width))]
    if is_training:
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([
        transforms.ToTensor(),
        # Channel statistics of ImageNet, as expected by torchvision's
        # pretrained backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform = transforms.Compose(steps)

    full_dataset = datasets.ImageFolder(root=main_dir, transform=transform)

    subset_indices = train_indices if is_training else valid_indices
    subset_dataset = torch.utils.data.Subset(full_dataset, subset_indices)
    classes = full_dataset.classes

    # Only the training subset is reshuffled each epoch; validation is read in
    # a fixed order.
    loader = DataLoader(
        subset_dataset,
        batch_size=batch_size,
        shuffle=is_training,
        pin_memory=True,
    )
    return loader, classes

## Function using Multiprocessing

In [3]:
def get_data_loaders(main_dir, batch_size, img_height, img_width):
    """
    Split an image folder 90/10 and build training and validation loaders.

    The sample indices are shuffled with NumPy's global random state, so seed
    ``np.random`` beforehand if a reproducible split is required. The two
    loaders are built by ``load_data`` in worker processes of a
    ``multiprocessing.Pool``.

    Args:
        main_dir (str): The main directory containing the image dataset.
        batch_size (int): The batch size for data loaders.
        img_height (int): The height of the input images.
        img_width (int): The width of the input images.

    Returns:
        tuple: ``(train_result, valid_result, full_dataset, train_indices,
        valid_indices)``. ``train_result`` and ``valid_result`` are the values
        returned by ``load_data`` for each subset, i.e. ``(DataLoader, classes)``
        pairs, so callers take element ``0`` to get the DataLoader.
        ``full_dataset`` is the ``ImageFolder`` (built without a transform)
        that was used to size the split, and the two index lists describe the
        split itself.
    """
    # Report the machine's real CPU count rather than a hard-coded figure.
    num_cpus = mp.cpu_count()
    print(f"Number of CPUs: {num_cpus}")

    full_dataset = datasets.ImageFolder(root=main_dir)
    total_size = len(full_dataset)
    indices = list(range(total_size))
    np.random.shuffle(indices)
    train_size = int(0.9 * total_size)
    train_indices, valid_indices = indices[:train_size], indices[train_size:]

    # One task per subset: index 0 is training, index 1 is validation.
    task_args = [
        (subset, main_dir, batch_size, img_height, img_width, train_indices, valid_indices)
        for subset in (0, 1)
    ]

    # Only two tasks are ever submitted, so more than two worker processes
    # would just add start-up cost.
    with Pool(processes=min(len(task_args), num_cpus)) as pool:
        pending = [pool.apply_async(load_data, args) for args in task_args]
        # Collect results while the pool is still alive.
        train_loader, valid_loader = [job.get() for job in pending]

    return train_loader, valid_loader, full_dataset, train_indices, valid_indices

In [4]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def train_model(model, train_loader, valid_loader, criterion, optimizer, device, epochs, full_dataset, train_indices, valid_indices, patience=3, early_stopping=False):
    """
    Train ``model`` for up to ``epochs`` epochs, validating after each one.

    Args:
        model: Module to optimise; called as ``model(images)``.
        train_loader: Sized iterable of ``(images, labels)`` training batches.
        valid_loader: Sized iterable of ``(images, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device each batch is moved to before the forward pass.
        epochs (int): Maximum number of epochs to run.
        full_dataset: Sized collection; only its length is printed.
        train_indices: Sized collection; only its length is printed.
        valid_indices: Sized collection; only its length is printed.
        patience (int): Number of consecutive epochs without a validation
            accuracy improvement tolerated before stopping. Only used when
            ``early_stopping`` is true.
        early_stopping (bool): Enable patience-based early stopping. Defaults
            to ``False``, in which case all ``epochs`` epochs are always run.

    Returns:
        tuple: ``(train_losses, train_accuracies, valid_losses,
        valid_accuracies)``, each a list with one float per completed epoch.
        Losses are the mean of the per-batch losses; accuracies are the
        fraction of correctly classified samples. A metric computed over an
        empty loader is reported as NaN instead of raising
        ``ZeroDivisionError``.
    """
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []

    # Start below any reachable accuracy so the first epoch always counts as
    # an improvement.
    best_val_acc = float("-inf")
    stale_epochs = 0

    start_time = time.time()

    print(f'Total Samples: {len(full_dataset)}')
    print(f'Total Training Samples: {len(train_indices)}, Total Validation Samples: {len(valid_indices)}\n')

    for epoch in range(epochs):
        epoch_start_time = time.time()

        # Training
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = _safe_ratio(running_loss, len(train_loader))
        epoch_acc = _safe_ratio(correct, total)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_accuracy = _safe_ratio(val_correct, val_total)
        val_loss = _safe_ratio(val_loss, len(valid_loader))
        valid_losses.append(val_loss)
        valid_accuracies.append(val_accuracy)

        epoch_time = time.time() - epoch_start_time

        print(f"Epoch {epoch + 1}/{epochs} - Training Loss: {epoch_loss:.4f} - Training Accuracy: {epoch_acc:.4f} - Validation Loss: {val_loss:.4f} - Validation Accuracy: {val_accuracy:.4f} - Time: {epoch_time:.2f}s")

        if early_stopping:
            # A NaN accuracy never compares greater, so it counts as
            # "no improvement".
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                stale_epochs = 0
            else:
                stale_epochs += 1
                if stale_epochs >= patience:
                    print("Early stopping.")
                    break

    total_time = time.time() - start_time
    print(f"Total training time: {total_time:.2f}s")

    return train_losses, train_accuracies, valid_losses, valid_accuracies

In [5]:
# Dataset location and run configuration
main_dir = "/home/gurram.ri/Project/CASIA"

batch_size = 32
img_height, img_width = 224, 224
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The first two items returned by get_data_loaders are tuples whose element 0
# is the DataLoader itself.
train_bundle, valid_bundle, full_dataset, train_indices, valid_indices = get_data_loaders(
    main_dir, batch_size, img_height, img_width
)
train_loader = train_bundle[0]
valid_loader = valid_bundle[0]

# Define the EfficientNetB3-based model with custom classifier
class EfficientNetB3(nn.Module):
    """
    EfficientNet-B3 backbone followed by a single linear classification head.

    The backbone's own classifier is replaced with ``nn.Identity`` so that the
    backbone emits pooled feature vectors, which ``self.classifier`` maps to
    ``num_classes`` logits.

    Args:
        num_classes (int): Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=True` — confirm against the installed
        # version before switching.
        self.base_model = models.efficientnet_b3(pretrained=True)

        # Read the feature width from the stock head's Linear layer before
        # discarding it. Probing with a dummy forward pass would run the
        # freshly constructed (training-mode) backbone with autograd enabled,
        # updating its BatchNorm running statistics from an all-zero image and
        # tying the constructor to one input resolution.
        stock_head = next(
            layer for layer in self.base_model.classifier.modules()
            if isinstance(layer, nn.Linear)
        )
        num_features = stock_head.in_features

        self.base_model.classifier = nn.Identity()  # removes final fully connected layer
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.base_model(x)
        return self.classifier(features)

# Build the network with one output per class of the underlying image folder
# (train_loader.dataset is the subset; its .dataset is the wrapped folder dataset).
image_folder = train_loader.dataset.dataset
classes = image_folder.classes
num_classes = len(classes)
model = EfficientNetB3(num_classes=num_classes)
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adamax(model.parameters(), lr=0.001)

# Run training and unpack the per-epoch metric histories
history = train_model(
    model,
    train_loader,
    valid_loader,
    criterion,
    optimizer,
    device,
    epochs=20,
    full_dataset=full_dataset,
    train_indices=train_indices,
    valid_indices=valid_indices,
)
train_losses, train_accuracies, valid_losses, valid_accuracies = history

Number of CPUs: 32
Total Samples: 12610
Total Training Samples: 11349, Total Validation Samples: 1261

Epoch 1/20 - Training Loss: 0.5060 - Training Accuracy: 0.7507 - Validation Loss: 0.4801 - Validation Accuracy: 0.7748 - Time: 724.18s
Epoch 2/20 - Training Loss: 0.3854 - Training Accuracy: 0.8158 - Validation Loss: 0.4596 - Validation Accuracy: 0.8017 - Time: 685.80s
Epoch 3/20 - Training Loss: 0.3383 - Training Accuracy: 0.8410 - Validation Loss: 1.2159 - Validation Accuracy: 0.7914 - Time: 676.71s
Epoch 4/20 - Training Loss: 0.3134 - Training Accuracy: 0.8514 - Validation Loss: 0.4229 - Validation Accuracy: 0.8128 - Time: 690.16s
Epoch 5/20 - Training Loss: 0.2873 - Training Accuracy: 0.8610 - Validation Loss: 0.4325 - Validation Accuracy: 0.8073 - Time: 677.12s
Epoch 6/20 - Training Loss: 0.2723 - Training Accuracy: 0.8708 - Validation Loss: 0.4276 - Validation Accuracy: 0.8033 - Time: 677.47s
Epoch 7/20 - Training Loss: 0.2527 - Training Accuracy: 0.8792 - Validation Loss: 0.447