In [None]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import random_split
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from time import time
from train_and_test_classification import seed_all, train_test_classifier
from models.LeNet import LeNet

In [None]:
# Seed the random number generators through the project helper, using its default arguments.
# NOTE(review): presumably this covers Python / NumPy / PyTorch RNGs — confirm in
# train_and_test_classification.seed_all.
seed_all()

In [None]:
# Output layout: <project_root>/runs/lenet_mnist/<unix-timestamp>/
project_root = Path("/project_root")
run_folder = project_root.joinpath("runs", "lenet_mnist")
run_folder.mkdir(parents=True, exist_ok=True)

# Each execution gets its own sub-folder named after the current Unix time in
# whole seconds. exist_ok=False makes a second run started within the same
# second fail loudly instead of silently sharing a folder.
current_run_folder = run_folder / str(int(time()))
current_run_folder.mkdir(exist_ok=False)

In [None]:
# Smoke test on a constant all-ones input (not random): a throwaway LeNet
# instance must map a (1, 1, 28, 28) batch to an output of shape (1, 10).
test_network = LeNet()
print(test_network)

X = torch.ones(1, 1, 28, 28)
with torch.no_grad():  # no autograd bookkeeping is needed for a shape check
    y = test_network(X)

assert tuple(y.shape) == (1, 10)
print(y)

In [None]:
# Load the MNIST training set from the current directory (downloading it if
# needed); ToTensor() is applied to every image.
full_train_data = MNIST(root=".", train=True, download=True, transform=ToTensor())

# 80/20 train/validation split. The validation size is computed as the
# remainder so the two lengths always sum to len(full_train_data);
# truncating both fractions independently with int() can drop a sample and
# make random_split raise ValueError for dataset sizes that do not divide evenly.
num_train_samples = int(0.8 * len(full_train_data))
num_val_samples = len(full_train_data) - num_train_samples
train_data, val_data = random_split(full_train_data, (num_train_samples, num_val_samples))

# Held-out MNIST test set, loaded with the same transform.
test_data = MNIST(root=".", train=False, download=True, transform=ToTensor())

In [None]:
# TensorBoard writer for this run; event files go to <current_run_folder>/logs.
logger = SummaryWriter(current_run_folder/"logs")

# Train on the GPU when CUDA is usable, otherwise fall back to the CPU.
model = LeNet()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Record the network graph, traced with an all-ones (1, 1, 28, 28) input that
# lives on the same device as the model.
logger.add_graph(model, torch.ones(1, 1, 28, 28).to(device))

# Checkpoints are stored next to the logs; exist_ok=False guards against
# re-using a folder left behind by an earlier execution of this cell.
checkpoint_folder = current_run_folder/"checkpoints"
checkpoint_folder.mkdir(exist_ok=False)

try:
    training_result = train_test_classifier(model=model,
                                            train_data=train_data,
                                            val_data=val_data,
                                            test_data=test_data,
                                            batch_size=16,
                                            num_epochs=5,
                                            loss_function=CrossEntropyLoss(),
                                            optimizer=Adam(model.parameters(), lr=0.001),
                                            logger=logger,
                                            device=device,
                                            checkpoint_folder=checkpoint_folder,
                                            early_stopping_epochs=None)
finally:
    # Push buffered events to disk even if training raised; flush() leaves the
    # writer usable for any later cell that keeps logging.
    logger.flush()

# Keep the helper's return value as the cell's last expression so the notebook
# still displays it (nothing is shown if it is None).
training_result