## Dependencies Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import os

## Hyperparameter Configuration

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Optimisation settings.
BATCH_SIZE = 128  # samples per mini-batch handed to the DataLoaders
EPOCHS = 10       # number of epochs passed to the training loop
LR = 1e-3         # initial learning rate for the optimizer

# Settings for the plateau-based learning-rate scheduler.
LR_SCHEDULER_FACTOR = 0.5   # `factor` argument of the scheduler
LR_SCHEDULER_PATIENCE = 3   # `patience` argument of the scheduler

## Define Transform for the Data

In [None]:
# NOTE: currently unused - the dataset used in this notebook is already stored
# in tensor format. The pipeline converts an input to a tensor and then
# normalises it with mean 0.5 and std 0.5 (one value, i.e. one channel).
transform = transforms.Compose(
    [
        # transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

## Dataset Setup

In [None]:
# Build an in-memory dataset from "../data_npy": every .npy file is treated as
# one class, and its position in the sorted file list becomes the class label.
data_dir = "../data_npy"
class_files = sorted(
    entry for entry in os.listdir(data_dir) if entry.endswith(".npy")
)

all_data = []
all_labels = []
for label_idx, fname in enumerate(class_files):
    arr = np.load(os.path.join(data_dir, fname))  # expected shape (N, H, W)
    # Insert a channel axis so the samples become (N, 1, H, W).
    arr = arr[:, np.newaxis, :, :]
    all_data.append(arr)
    # One label per sample, all equal to this file's class index.
    all_labels.append(np.full(len(arr), label_idx))

# Stack the per-class arrays into one data array and one label vector.
all_data = np.concatenate(all_data, axis=0)      # (total_samples, C, H, W)
all_labels = np.concatenate(all_labels, axis=0)  # (total_samples,)

# Wrap both as tensors: float inputs, long (int64) class labels.
features = torch.from_numpy(all_data).float()
targets = torch.from_numpy(all_labels).long()
dataset = TensorDataset(features, targets)

In [None]:
# Build the dataset from "../data_tensors_10000": every .pt file holds the
# samples of one class, and its position in the sorted file list is its label.
data_dir = "../data_tensors_10000"
class_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".pt"))
if not class_files:
    # Fail early with a clear message instead of an opaque torch.cat error.
    raise FileNotFoundError(f"No .pt class files found in {data_dir!r}")

all_data, all_labels = [], []

for label_idx, fname in enumerate(class_files):
    # NOTE(review): torch.load unpickles the file - only load files you trust.
    # map_location="cpu" keeps the samples in host memory regardless of the
    # device they were saved from, which is what the DataLoaders created later
    # (worker processes + pinned memory) require.
    tensor = torch.load(
        os.path.join(data_dir, fname), map_location="cpu"
    )  # expected shape: (N, C, H, W)
    all_data.append(tensor)
    # One label per sample, all equal to this file's class index.
    labels = torch.full((tensor.size(0),), label_idx, dtype=torch.long)
    all_labels.append(labels)

# Concatenate all classes into a single large tensor.
all_data = torch.cat(all_data, dim=0)      # (total_samples, C, H, W)
all_labels = torch.cat(all_labels, dim=0)  # (total_samples,)

# Create the TensorDataset (inputs as float32, labels as int64).
dataset = TensorDataset(all_data.float(), all_labels)


In [None]:
num_classes = len(class_files)
print(f"Number of classes: {num_classes}")

# Split into train/val/test: 70% / 10% / whatever remains (about 20%).
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

print(f"Total size: {total_size}")
print(f"Train: {train_size}\t Val: {val_size}\t Test: {test_size}")

# Seed the split so the same samples end up in each subset on every run;
# otherwise the "held-out" test set changes whenever the notebook is re-run.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DataLoader parameters
NUM_WORKERS = 7
# Pinned host memory only speeds up host-to-GPU copies, so request it only
# when training on CUDA (avoids a warning on CPU-only machines).
PIN_MEMORY = DEVICE.type == "cuda"
loader_kwargs = {
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_WORKERS,
    "pin_memory": PIN_MEMORY,
}

# Only the training data is shuffled; validation/test order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Show the shape of one batch from the test loader (nothing is printed if the
# loader is empty).
for images, labels in test_loader:
    print(f"Image batch shape: {images.size()}")
    print(f"Label batch shape: {labels.size()}")
    break

## Model Definition

In [None]:
class SimpleCNN(nn.Module):
    """Small two-convolution CNN for single-channel 28x28 inputs.

    Layout: conv(1->32, 3x3) -> ReLU -> max-pool(2) -> conv(32->64, 3x3)
    -> ReLU -> max-pool(2) -> linear(64*5*5 -> 128) -> ReLU
    -> linear(128 -> num_classes).

    The first linear layer is sized for 28x28 inputs
    (28 -> 26 -> 13 -> 11 -> 5 along each spatial axis).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce the
                64*5*5 features expected by ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Keeping the batch dimension fixed means a wrongly
        # sized input fails loudly in fc1 instead of being silently reshaped
        # into a different number of rows.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class BiggerCNN(nn.Module):
    """Four-convolution CNN with two max-pool stages and a dropout-regularised head.

    Layout: [conv(in->32) -> ReLU -> conv(32->64) -> ReLU -> pool]
    -> [conv(64->128) -> ReLU -> conv(128->128) -> ReLU -> pool]
    -> linear(flattened -> 256) -> ReLU -> dropout(0.5) -> linear(256 -> num_classes).
    All convolutions use 3x3 kernels with padding 1.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images (default 1).
        input_size: ``(height, width)`` of the input images (default 28x28).
            Used only to size the first fully connected layer.

    Attributes:
        flattened_size: Number of features entering ``fc1``.
    """

    def __init__(self, num_classes, in_channels=1, input_size=(28, 28)):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Infer the size of the flattened feature vector by pushing a dummy
        # input through the same convolutional stack that forward() uses.
        height, width = input_size
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, height, width)
            self.flattened_size = self._extract_features(dummy).size(1)

        self.fc1 = nn.Linear(self.flattened_size, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, num_classes)

    def _extract_features(self, x):
        """Run the convolutional stack and flatten the result per sample."""
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        return x.flatten(1)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``."""
        x = self._extract_features(x)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

## Training

In [None]:
def evaluate_model(model, loader, criterion, device):
    """Compute the average loss and accuracy of ``model`` over ``loader``.

    The model is switched to evaluation mode and gradients are disabled.
    The loss is averaged per sample: each batch loss is weighted by its batch
    size, so a smaller final batch does not skew the result. This assumes
    ``criterion`` returns the mean loss of the batch (the default reduction
    of ``nn.CrossEntropyLoss``).

    Args:
        model: Network mapping an input batch to per-class scores.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, accuracy)`` where ``accuracy`` is the fraction of
        correctly classified samples in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            batch_size = y.size(0)
            # Undo the per-batch mean so the final average is per sample.
            loss_sum += criterion(outputs, y).item() * batch_size
            correct += (outputs.argmax(dim=1) == y).sum().item()
            total += batch_size

    return loss_sum / total, correct / total

In [None]:
model = BiggerCNN(num_classes=num_classes).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
# The scheduler's `verbose` flag is deprecated in recent PyTorch releases, so
# it is not passed; learning-rate changes are reported by the training loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=LR_SCHEDULER_FACTOR,
    patience=LR_SCHEDULER_PATIENCE,
)


# Training loop
def train_cnn(num_epochs, save_path="../models/model_big_10000.pt"):
    """Train the module-level ``model`` and save its weights.

    Each epoch runs one pass over ``train_loader``, evaluates on
    ``val_loader`` and steps the plateau scheduler with the validation loss.

    Args:
        num_epochs: Number of epochs to train for.
        save_path: File the model's ``state_dict`` is written to once
            training finishes. Missing parent directories are created.

    Returns:
        Tuple ``(train_loss_history, val_loss_history, val_accuracy_history)``
        with one entry per epoch.
    """
    print("Starting training...")

    train_loss_history, val_loss_history, val_accuracy_history = [], [], []

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_losses = []
        for x, y in train_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        avg_train_loss = np.mean(train_losses)
        train_loss_history.append(avg_train_loss)

        # Validation
        avg_val_loss, accuracy = evaluate_model(model, val_loader, criterion, DEVICE)

        val_loss_history.append(avg_val_loss)
        val_accuracy_history.append(accuracy)

        # Adjust learning rate and report any change (stands in for the
        # scheduler's own verbose output).
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(avg_val_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"Learning rate reduced from {lr_before:.6g} to {lr_after:.6g}")

        # Print statistics
        print(f"Epoch {epoch+1}/{num_epochs}\t Train Loss: {avg_train_loss:.4f}\t Validation Loss: {avg_val_loss:.4f}\t Validation Accuracy: {accuracy:.4f}")

    print("Training complete.")
    # Save the model. Create the target directory first so a finished
    # training run is not lost to a missing folder.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return train_loss_history, val_loss_history, val_accuracy_history


train_losses, val_losses, val_accuracies = train_cnn(EPOCHS)

In [None]:
# Plot the training curves side by side: losses on the left, validation
# accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: training vs. validation loss per epoch.
loss_ax.plot(train_losses, label="Train Loss")
loss_ax.plot(val_losses, label="Validation Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Loss Evolution")
loss_ax.legend()

# Right panel: validation accuracy per epoch.
acc_ax.plot(val_accuracies, label="Validation Accuracy")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.set_title("Accuracy Evolution")
acc_ax.legend()

fig.tight_layout()
plt.show()

## Test

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Class names, one per line. Read with an explicit encoding so the result does
# not depend on the platform default, and skip blank lines so a trailing
# newline does not become an empty class name.
with open("../all_classes.txt", "r", encoding="utf-8") as f:
    CLASSES = [line.strip() for line in f if line.strip()]

model.eval()

total = 0
correct = 0
all_preds = []   # predicted class index for every test sample
all_labels = []  # true class index for every test sample

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        # The predicted class is the index of the largest output score.
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # Store plain Python ints rather than per-element numpy scalars.
        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

# Optional: confusion matrix over the test set (uncomment to display).
# cm = confusion_matrix(all_labels, all_preds)
# disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASSES)
# disp.plot(xticks_rotation=45, cmap="Blues")
# plt.tight_layout()
# plt.show()