In [None]:
import torch
from torchvision import transforms
from torchvision.datasets import Imagenette
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from basic_cnn import BasicCNN  # Make sure to import your model here

# Preprocessing shared by training and evaluation: take the central 160x160
# crop, resize so the shorter edge is 64 px, convert to a tensor, normalize
# per channel, then collapse to a single grayscale channel.
# NOTE(review): the mean/std constants look like CIFAR-10 RGB statistics, and
# Normalize runs *before* Grayscale — confirm this is what BasicCNN was
# trained/designed to expect.
def _build_preprocessing():
    """Return a freshly constructed preprocessing pipeline.

    Each call creates new transform objects, so the training and evaluation
    pipelines are independent and can be changed separately (e.g. to add
    augmentation to training only).
    """
    steps = [
        transforms.CenterCrop(160),
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2470, 0.2435, 0.2616),
        ),
        transforms.Grayscale(),
    ]
    return transforms.Compose(steps)


train_transforms = _build_preprocessing()
test_transforms = _build_preprocessing()

# Prepare the datasets.
# The Imagenette archive contains both the "train" and "val" splits, so every
# split is read from one root and the archive is downloaded at most once.
DATA_ROOT = "data/imagenette/"


def _load_imagenette(split, transform):
    """Return the Imagenette ``split`` (160px variant) with ``transform`` applied.

    The local copy under ``DATA_ROOT`` is tried first, and the archive is only
    downloaded if that fails. Passing ``download=True`` unconditionally makes
    torchvision raise ``RuntimeError`` once the data has been extracted, so
    re-running this cell would crash.
    """
    try:
        return Imagenette(DATA_ROOT, split=split, size="160px", download=False, transform=transform)
    except RuntimeError:
        # Dataset not found locally: fetch and extract it, then load.
        return Imagenette(DATA_ROOT, split=split, size="160px", download=True, transform=transform)


# Two views of the training images that differ only in their transform.
# random_split returns Subsets that share one underlying dataset object, so
# assigning a transform through ``val_dataset.dataset`` would silently change
# the transform seen by the training subset as well.
full_train_dataset = _load_imagenette("train", train_transforms)
full_val_dataset = _load_imagenette("train", test_transforms)

# Use 10% of the training set for validation
train_set_size = int(len(full_train_dataset) * 0.9)
val_set_size = len(full_train_dataset) - train_set_size

# Fixed generator so the train/validation split is reproducible across runs.
seed = torch.Generator().manual_seed(42)
train_dataset, val_split = torch.utils.data.random_split(
    full_train_dataset, [train_set_size, val_set_size], generator=seed
)
# Same validation indices, but read through the evaluation-transform view.
val_dataset = torch.utils.data.Subset(full_val_dataset, val_split.indices)

# Use DataLoader to load the dataset
train_loader = DataLoader(train_dataset, batch_size=128, num_workers=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, num_workers=8, shuffle=False)

# Configure the test dataset: Imagenette's "val" split is the held-out test set,
# since validation data is carved out of the training split above.
test_dataset = _load_imagenette("val", test_transforms)

# Initialize the model
model = BasicCNN()

# Stop training once "val_loss" has not improved for 5 consecutive validation
# checks. NOTE(review): assumes BasicCNN logs a metric named "val_loss" —
# confirm in the model's validation_step.
early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", patience=5)

# Track the checkpoint with the lowest "val_loss" so the best weights can be
# reloaded for testing after training stops.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")

# accelerator="auto" uses a GPU when one is available and falls back to the
# CPU otherwise; hard-coding "gpu" would make the Trainer fail on CPU-only machines.
trainer = L.Trainer(
    callbacks=[early_stop_callback, checkpoint_callback],
    max_epochs=50,
    accelerator="auto",
)

# Train the model
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Snapshot the fit-time metrics now: each Trainer run (fit/validate/test)
# resets trainer.callback_metrics, so after trainer.test() only the test
# metrics would remain and the train/val values below would print "N/A".
fit_metrics = dict(trainer.callback_metrics)

# Evaluate on the test set using the checkpoint with the best val_loss.
# Without ckpt_path, trainer.test() uses the in-memory weights from the last
# epoch, which can be several validation checks past the best one when early
# stopping ends training.
test_loader = DataLoader(test_dataset, batch_size=256, num_workers=8, shuffle=False)
test_results = trainer.test(model=model, dataloaders=test_loader, ckpt_path="best")


def _as_number(value):
    """Unwrap a 0-d tensor into a plain Python number for readable printing."""
    return value.item() if isinstance(value, torch.Tensor) else value


# Print the metrics
print("Training and Validation Metrics:")
print(f"Training Loss: {_as_number(fit_metrics.get('train_loss', 'N/A'))}")
print(f"Validation Loss: {_as_number(fit_metrics.get('val_loss', 'N/A'))}")
print(f"Validation Accuracy: {_as_number(fit_metrics.get('val_accuracy', 'N/A'))}")

# trainer.test returns one metrics dict per test dataloader.
print("Test Metrics:")
for result in test_results:
    print(f"Test Accuracy: {result.get('test_accuracy', 'N/A')}")
    print(f"Test Loss: {result.get('test_loss', 'N/A')}")
