In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import time

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [4]:
# 1. Define a CNN architecture
class BasicCNN(nn.Module):
    def __init__(self):
        super(BasicCNN, self).__init__()
        # First convolutional layer
        # Input: 3 channels (RGB), Output: 16 feature maps, 3x3 kernel
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Second convolutional layer
        # Input: 16 feature maps, Output: 32 feature maps, 3x3 kernel
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=3, padding=1
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding=1
        )
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        # Fully connected layers
        # Assuming input images are 32x32, after two 2x2 pooling layers, we have 32/2/2 = 8
        # So the feature maps are 8x8 with 32 channels: 32 * 8 * 8 = 2048
        self.fc1 = nn.Linear(64 * 4 * 4, 128)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)  # 10 output classes (e.g., for CIFAR-10 dataset)

    def forward(self, x):
        # Apply convolutional layers
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))

        # Flatten the feature maps
        x = x.view(-1, 64 * 4 * 4)

        # Apply fully connected layers
        x = self.fc2(self.relu4(self.fc1(x)))

        return x


# 2. Create an instance of the model
model = BasicCNN()
model = model.to(device)

# 3. Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.04, momentum=0.5)

# 4. Load and preprocess data (example with CIFAR-10)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

trainset = torchvision.datasets.ImageFolder(
    "../../data/processed/train_augmented_rotated", transform=transform
)
valset = torchvision.datasets.ImageFolder("../../data/raw/valid/", transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=32, shuffle=True, num_workers=2
)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=32, shuffle=False, num_workers=2
)

In [5]:
class EarlyStopper:
    def __init__(self, max_iter: int = 10, patience: int = 2, min_delta: float = 0):
        self.max_iter = max_iter
        self.min_error = float("Inf")
        self.patience = patience
        self.counter = 0

    def check(self, val_error: float, model):

        if self.counter >= self.max_iter:
            return True

        if val_error < self.min_error:
            self.min_error = val_error
            torch.save(model.state_dict(), "best_model.pth")
            self.counter = 0
        else:
            if self.counter >= self.patience:
                return True
            self.counter = self.counter + 1
        return False

In [6]:
ES = EarlyStopper()
import json
import os
from datetime import datetime


def train(epochs=5):
    time_started = time.time()

    # Initialize JSON structure
    log_data = {
        "metadata": {
            "model": str(model),
            "freq_bins": 64,
            "time_steps": 64,
            "batch_size": trainloader.batch_size,
            "train_set_size": len(trainloader.dataset),
            "optimizer": optimizer.__class__.__name__,
            "loss_function": str(criterion),
            "num_epochs": epochs,
        },
        "data": {},
    }

    # Setup filenames and paths
    timestamp = datetime.now().strftime("%d-%m-%Y")
    arch = f"{log_data['metadata']['freq_bins']}x{log_data['metadata']['time_steps']}"
    bsz = log_data["metadata"]["batch_size"]
    opt = log_data["metadata"]["optimizer"]
    lr = optimizer.param_groups[0]["lr"]
    size = log_data["metadata"]["train_set_size"]
    filename = f"{arch}x{bsz}x{lr}x{opt}x{size}_{timestamp}.json"

    logs_dir = "../logs/model_on_rotated_images"
    models_dir = "../saved_models"
    os.makedirs(logs_dir, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)

    filepath = os.path.join(logs_dir, filename)

    for epoch in range(epochs):
        model.train()
        correct = 0
        total = 0
        running_loss = 0.0
        log_data["data"][str(epoch + 1)] = {"batches": {}}

        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
            total += len(outputs)

            if i % 200 == 199:
                print(
                    f"[{epoch + 1}, {i + 1}], time {time.time() - time_started:.1f}s, loss: {running_loss / 200:.3f}, acc: {correct/total * 100:.2f}%"
                )
                running_loss = 0.0

        train_acc = correct / total

        # Validation
        model.eval()
        val_error = 0
        correct = 0
        with torch.no_grad():
            for images, labels in valloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_error += criterion(outputs, labels) * images.size(0)
                correct += (torch.argmax(outputs, 1) == labels).float().sum().item()

        val_error = val_error / len(valloader.dataset)
        acc_val = correct / len(valloader.dataset)

        # Save epoch summary
        log_data["data"][str(epoch + 1)]["summary"] = {
            "train_acc": train_acc,
            "val_loss": val_error.item(),
            "val_acc": acc_val,
            "lr": optimizer.param_groups[0]["lr"],
        }

        # Save model for this epoch
        model_path = os.path.join(models_dir, f"model_epoch_{epoch + 1}.pt")
        torch.save(model.state_dict(), model_path)

        # Save updated log file
        with open(filepath, "w") as f:
            json.dump(log_data, f, indent=4)

        print(
            f"Epoch {epoch+1} | val loss: {val_error:.4f}, val acc: {acc_val:.4f} | model saved to {model_path}"
        )

        if ES.check(val_error, model):
            print("Early stopping triggered. Finished Training.")
            break

    print(f"Final training log written to: {filepath}")

In [7]:
train(15)

[1, 200], time 29.8s, loss: 2.238, acc: 16.08%
[1, 400], time 65.0s, loss: 1.973, acc: 21.34%
[1, 600], time 96.4s, loss: 1.833, acc: 24.43%
[1, 800], time 116.5s, loss: 1.757, acc: 26.93%
[1, 1000], time 136.1s, loss: 1.696, acc: 28.96%
[1, 1200], time 156.3s, loss: 1.645, acc: 30.65%
[1, 1400], time 176.8s, loss: 1.634, acc: 31.88%
[1, 1600], time 196.8s, loss: 1.602, acc: 33.01%
[1, 1800], time 217.2s, loss: 1.585, acc: 33.86%
[1, 2000], time 237.7s, loss: 1.542, acc: 34.79%
[1, 2200], time 259.1s, loss: 1.536, acc: 35.55%
[1, 2400], time 283.0s, loss: 1.479, acc: 36.40%
[1, 2600], time 305.3s, loss: 1.482, acc: 37.06%
[1, 2800], time 326.7s, loss: 1.461, acc: 37.71%
[1, 3000], time 346.4s, loss: 1.430, acc: 38.32%
[1, 3200], time 365.1s, loss: 1.419, acc: 38.95%
[1, 3400], time 383.8s, loss: 1.426, acc: 39.43%
[1, 3600], time 402.8s, loss: 1.393, acc: 39.98%
[1, 3800], time 422.8s, loss: 1.392, acc: 40.45%
[1, 4000], time 441.7s, loss: 1.369, acc: 40.95%
[1, 4200], time 461.2s, los

In [10]:
model.load_state_dict(torch.load("../saved_models/model_epoch_7.pt"))

  model.load_state_dict(torch.load("../saved_models/model_epoch_7.pt"))


<All keys matched successfully>

In [11]:
val_error = 0
correct = 0
with torch.no_grad():
    model.eval()
    for images, labels in valloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        val_error = val_error + criterion(outputs, labels) * images.size(0)
        correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
        val_error = val_error / len(valloader.dataset)
    print(f"epoch NONE val error: {val_error}, acc: {correct/len(valloader.dataset)}")

epoch NONE val error: 0.0002579218416940421, acc: 0.5961222222222222
