# In [None]:
# Import necessary libraries
import torch
import pytorch_lightning as L
from torchvision import transforms
from torchvision.datasets import Imagenette
from torch.utils.data import DataLoader
from all_conv_model import AllConvNet

# Define transformations.
# Imagenette's "160px" variant resizes the *shorter* side to 160px and keeps the aspect
# ratio, so Resize(128) alone yields images of differing widths that the default
# collate_fn cannot stack into a batch. CenterCrop(128) makes every sample 3x128x128.
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
])
# Separate names so training-only augmentation can be added to train_transforms
# without affecting validation/test preprocessing.
train_transforms = transform
test_transforms = transform

# Both splits live in the same downloaded archive, so a single root avoids
# downloading and storing it twice.
DATA_ROOT = "data/imagenette/"
SPLIT_SEED = 42


def _load_imagenette(split, tfm):
    """Return the Imagenette-160px ``split`` ("train" or "val") stored under DATA_ROOT.

    Downloads the archive on first use. torchvision's Imagenette raises RuntimeError
    when download=True and the extracted directory already exists (e.g. on a second
    call or a re-run of this cell), so on that error the existing copy is loaded with
    download=False; if the data is genuinely missing, that retry raises instead.
    """
    try:
        return Imagenette(DATA_ROOT, split=split, size="160px", download=True, transform=tfm)
    except RuntimeError:
        return Imagenette(DATA_ROOT, split=split, size="160px", download=False, transform=tfm)


# Load datasets.
# Hold out 10% of the official "train" split for validation. The validation indices
# get their own dataset instance: random_split returns Subsets sharing ONE underlying
# dataset, so assigning `.dataset.transform` on the val subset would also change the
# transform applied to training samples.
train_full = _load_imagenette("train", train_transforms)
train_full_eval = _load_imagenette("train", test_transforms)
val_set_size = int(len(train_full) * 0.1)
train_set_size = len(train_full) - val_set_size
# Seeded permutation so the train/val partition is identical across runs.
perm = torch.randperm(len(train_full), generator=torch.Generator().manual_seed(SPLIT_SEED)).tolist()
train_dataset = torch.utils.data.Subset(train_full, perm[:train_set_size])
val_dataset = torch.utils.data.Subset(train_full_eval, perm[train_set_size:])

# The official "val" split is kept untouched as the held-out test set.
test_dataset = _load_imagenette("val", test_transforms)

# Create DataLoaders (shuffle only the training data).
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)

# Initialize the model.
num_classes = 10
model = AllConvNet(num_classes=num_classes)

# Set up early stopping and checkpoint callbacks. Both monitor "val_loss", so
# AllConvNet is assumed to log a metric with exactly that name during validation
# (TODO confirm against the model's validation_step).
early_stop_callback = L.callbacks.EarlyStopping(monitor="val_loss", patience=5, verbose=True, mode="min")
# Keeps only the single checkpoint with the lowest val_loss.
checkpoint_callback = L.callbacks.ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

# Initialize the Trainer. accelerator="auto" uses a GPU when one is available and
# falls back to CPU otherwise, so the notebook also runs on CPU-only machines.
trainer = L.Trainer(
    callbacks=[early_stop_callback, checkpoint_callback],
    max_epochs=50,
    accelerator="auto",
)

# Train the model; early stopping may end training before max_epochs.
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

def _fmt_metric(metrics, key):
    """Return ``metrics[key]`` as a 4-decimal string, or "Not Available" if it is missing.

    Values may be tensors (trainer.callback_metrics) or plain floats (trainer.test
    results); float() accepts both single-element forms.
    """
    value = metrics.get(key)
    return "Not Available" if value is None else f"{float(value):.4f}"


# Snapshot the fit-time metrics before testing. Lightning clears
# trainer.callback_metrics when a new run such as trainer.test() starts, so reading it
# afterwards would only show test metrics and the train/val lines would always print
# "Not Available".
fit_metrics = dict(trainer.callback_metrics)

# Evaluate the model on the held-out test set. Use test_loader, not the validation
# loader that drove early stopping. ckpt_path="best" restores the checkpoint kept by
# ModelCheckpoint instead of the last-epoch weights.
test_results = trainer.test(model=model, dataloaders=test_loader, ckpt_path="best")
# trainer.test returns one metrics dict per test dataloader; only one is passed here.
test_metrics = test_results[0] if test_results else {}

# Print the metrics
print("Training and Validation Metrics:")
print(f"Training Loss: {_fmt_metric(fit_metrics, 'train_loss')}")
print(f"Validation Loss: {_fmt_metric(fit_metrics, 'val_loss')}")
print(f"Validation Accuracy: {_fmt_metric(fit_metrics, 'val_accuracy')}")

# Print test results
print(f"Test Loss: {_fmt_metric(test_metrics, 'test_loss')}")
print(f"Test Accuracy: {_fmt_metric(test_metrics, 'test_accuracy')}")
