
# SketchANet model
### Packages

In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix, classification_report, top_k_accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder



### Model Training

In [None]:

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root directory in ImageFolder layout (one sub-directory per class).
data_root = '/user/5/toubalih/Lab_project_2/Lab_project_S8/png/'

# Training transform: resize to the 224x224 input SketchANet's fc layer is sized
# for, apply light random augmentation, then map each channel from [0, 1] to [-1, 1].
# NOTE(review): RandomRotation fills exposed corners with 0 (black); if the
# sketches are dark strokes on a white background, consider fill=255 — confirm.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation transform: identical preprocessing but no randomness, so validation
# and test metrics are deterministic and are not measured on augmented images.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Kept under its previous name for any later cell that still refers to it.
transform = train_transform

# Two views of the same directory. ImageFolder scans and sorts identically for
# both, so sample index i is the same file in each; only the transform differs.
dataset = ImageFolder(root=data_root, transform=train_transform)
eval_dataset = ImageFolder(root=data_root, transform=eval_transform)

# Class index of every sample, used for stratified splitting
targets = np.array(dataset.targets)

# Split fractions (must sum to 1)
train_size = 0.7
val_size = 0.15
test_size = 0.15

# First split: hold out the test set, stratified by class
train_val_idx, test_idx, y_train_val, y_test = train_test_split(
    range(len(targets)), targets, stratify=targets, test_size=test_size, random_state=42)

# The validation fraction must be re-expressed relative to the remaining train+val pool
val_size_adjusted = val_size / (train_size + val_size)

# Second split: train and validation, again stratified
train_idx, val_idx, y_train, y_val = train_test_split(
    train_val_idx, y_train_val, stratify=y_train_val, test_size=val_size_adjusted, random_state=42)

# Only the training subset sees augmentation; val/test use the deterministic view
train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(eval_dataset, val_idx)
test_dataset = Subset(eval_dataset, test_idx)

# Define data loaders (only training data is shuffled)
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

class SketchANet(nn.Module):
    """
    Sketch-a-Net style convolutional classifier.

    Seven conv + batch-norm + ReLU stages (three followed by 3x3/stride-2 max
    pooling, the last two followed by dropout) feed a single linear layer.

    The linear layer is sized for 224x224 inputs: spatial size goes
    224 -> conv1 70 -> pool 34 -> conv2 30 -> pool 14 -> conv3..5 14 -> pool 6,
    and conv6/conv7 keep it at 6, leaving a 512x6x6 feature map.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=15, stride=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(512)
        self.conv7 = nn.Conv2d(512, 512, kernel_size=1)
        self.bn7 = nn.BatchNorm2d(512)
        self.fc = nn.Linear(512 * 6 * 6, num_classes)  # 512x6x6 map for 224x224 input

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 3, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 3, 2)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.max_pool2d(x, 3, 2)
        # Functional dropout defaults to training=True; tie it to the module's
        # mode so model.eval() actually disables it during validation/testing.
        x = F.dropout(F.relu(self.bn6(self.conv6(x))), p=0.5, training=self.training)
        x = F.dropout(F.relu(self.bn7(self.conv7(x))), p=0.5, training=self.training)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# Initialize the model, loss function, optimizer (Adam with L2 weight decay) and
# a schedule that multiplies the learning rate by 0.1 every 10 epochs.
model = SketchANet(num_classes=len(dataset.classes)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

num_epochs = 50
best_val_loss = float('inf')  # mean validation loss of the best checkpoint so far
best_model_path = 'best_model.pt'

# Per-epoch history, consumed by the accuracy/loss curve cell.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# Track start time (the reported run time covers training and testing)
start_time = time.time()

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        predicted = outputs.argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    epoch_train_loss = running_loss / len(train_loader)
    train_accuracy = 100 * correct_train / total_train
    scheduler.step()  # once per epoch, after the optimizer steps

    # ---- Validation pass ----
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    epoch_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * correct_val / total_val

    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    # Print training and validation statistics
    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Training Loss: {epoch_train_loss:.4f}, '
          f'Training Accuracy: {train_accuracy:.2f}%, '
          f'Validation Loss: {epoch_val_loss:.4f}, '
          f'Validation Accuracy: {val_accuracy:.2f}%')

    # Checkpoint the weights with the lowest validation loss seen so far
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(model.state_dict(), best_model_path)

print('Finished Training')

# Test the best checkpoint; map_location lets a GPU-saved file load on CPU too
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        predicted = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on test set: {100 * correct / total:.2f}%')

# Calculate and print total run time
end_time = time.time()
total_time = end_time - start_time
print(f'Total run time: {total_time:.2f} seconds')


### Confusion Matrix and Top-k Accuracy

In [None]:

def get_class_names(folder_path):
    """
    List the class labels found under ``folder_path``.

    Each immediate sub-directory of ``folder_path`` counts as one class, and its
    directory name is the label. Plain files are skipped. The labels are
    returned as a list in ascending alphabetical order.
    """
    return sorted(
        entry
        for entry in os.listdir(folder_path)
        if os.path.isdir(os.path.join(folder_path, entry))
    )

# Read class names from the same directory the dataset was loaded from, so that
# name i in the reports below is the class ImageFolder assigned label i.
# (Previously a different folder, 'updated_png', was used, which can misalign
# the names with the labels or change their count.)
folder_path = dataset.root
class_names = get_class_names(folder_path)
print(class_names)

def plot_confusion_matrix_and_metrics(model, device, data_loader, class_names):
    """
    Print a classification report and plot a confusion matrix for ``model`` on ``data_loader``.

    Args:
        model: classifier returning one logit per class.
        device: device the model lives on; inputs are moved there.
        data_loader: iterable of ``(data, target)`` batches.
        class_names: name of each class, indexed by label id.

    Returns:
        The confusion matrix as a square array with one row/column per entry of
        ``class_names`` (rows: true labels, columns: predicted labels).
    """
    model.eval()
    true_labels = []
    pred_labels = []

    with torch.no_grad():
        for data, target in data_loader:
            outputs = model(data.to(device))
            preds = torch.argmax(outputs, dim=1)
            # Targets are only needed on the CPU for sklearn
            true_labels.extend(target.cpu().numpy())
            pred_labels.extend(preds.cpu().numpy())

    # Pin the full label set so the matrix/report always have one entry per class,
    # even when a class never occurs (or is never predicted) in this split.
    label_ids = np.arange(len(class_names))
    conf_mat = confusion_matrix(true_labels, pred_labels, labels=label_ids)
    class_report = classification_report(
        true_labels, pred_labels,
        labels=label_ids, target_names=class_names, zero_division=0)

    # Print the classification report
    print("Classification Report:")
    print(class_report)

    # Plot the confusion matrix
    plt.figure(figsize=(10, 10))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix for SketchNet')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

    return conf_mat

plot_confusion_matrix_and_metrics(model, device, test_loader, class_names)

def evaluate_model_with_sklearn_top_k(model, device, test_loader, classes, k=5):
    """
    Evaluate the model on the test set using scikit-learn's top_k_accuracy_score.

    Parameters:
    - model (torch.nn.Module): The trained model to evaluate.
    - device (torch.device): The device (CPU/GPU) on which to perform the evaluation.
    - test_loader (torch.utils.data.DataLoader): DataLoader for the test set.
    - classes (list): List of all class labels; its length fixes the label set 0..len-1.
    - k (int): The top-k accuracy to calculate.

    Returns:
    - float: the Top-K accuracy as a fraction in [0, 1] (it is also printed as a percentage).
    """
    model.eval()  # Ensure the model is in evaluation mode
    true_labels = []
    all_scores = []

    with torch.no_grad():  # Disable gradient computation
        for data, target in test_loader:
            outputs = model(data.to(device))
            probabilities = torch.softmax(outputs, dim=1)
            all_scores.extend(probabilities.cpu().numpy())
            # Targets stay on the CPU; they are only used by sklearn
            true_labels.extend(target.cpu().numpy())

    true_labels = np.array(true_labels)
    all_scores = np.array(all_scores)
    # Explicit labels keep the score columns aligned with class ids even if some
    # classes are missing from the test split.
    labels = np.arange(len(classes))

    top_k_accuracy = top_k_accuracy_score(true_labels, all_scores, k=k, labels=labels)
    print(f"Top-{k} Accuracy on Test Set: {top_k_accuracy * 100:.2f}%")
    return top_k_accuracy


evaluate_model_with_sklearn_top_k(model, device, test_loader, class_names, k=5)


### Accuracy and Loss Curve

In [None]:

# Plot the per-epoch history recorded during training: accuracy on the left
# axis (blue), loss on the right axis (red); dashed lines are validation.
# The x-axis follows the recorded history rather than num_epochs, so the plot
# still works if training stopped early or num_epochs was changed afterwards.
epochs = range(1, len(train_accuracies) + 1)

fig, ax1 = plt.subplots()

# Plot training and validation accuracy
color = 'tab:blue'
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Accuracy', color=color)
ax1.plot(epochs, train_accuracies, label='Training Accuracy', color=color, marker='o')
ax1.plot(epochs, val_accuracies, label='Validation Accuracy', linestyle='dashed', color=color, marker='o')
ax1.tick_params(axis='y', labelcolor=color)
ax1.legend(loc='upper left')
ax1.grid(True)

# Second y-axis sharing the epoch axis, for training and validation loss
ax2 = ax1.twinx()
color = 'tab:red'
ax2.set_ylabel('Loss', color=color)
ax2.plot(epochs, train_losses, label='Training Loss', color=color, marker='x')
ax2.plot(epochs, val_losses, label='Validation Loss', linestyle='dashed', color=color, marker='x')
ax2.tick_params(axis='y', labelcolor=color)
ax2.legend(loc='upper right')

plt.title('Accuracy and Loss Curve of sketchNet(50 dataset)')
plt.tight_layout()
plt.show()
