In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from cnn import CNN
from train import train
from datasets import load_dataset

# Run configuration (device, dataset, checkpoint paths, epochs)

In [None]:
# Pick the compute device in priority order: CUDA, then Apple MPS, then CPU.
# torch.backends.mps is looked up with getattr because PyTorch builds that
# predate the MPS backend do not expose that attribute at all.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"

DATASET = "CIFAR10" # MNIST, CIFAR10
# Checkpoint to load instead of training; None means "train from scratch".
PATH_TO_READY_MODEL = None
# Where the trained weights are written after training.
PATH_TO_SAVE_MODEL = "model.pth"
# Passed to train() as num_epochs.
EPOCHS = 100
# Seed for PyTorch's global RNG so torch-driven randomness repeats between runs.
# Other RNGs (random, numpy) are not seeded here.
SEED = 42

torch.manual_seed(SEED)

# Load dataset

In [None]:
# load_dataset is expected to hand back exactly two iterables: the training
# loader first, the test loader second.
loaders = load_dataset(DATASET)
train_loader, test_loader = loaders

# Setup model

In [None]:
NUM_CLASSES = 10
LEARNING_RATE = 1e-4

# Read the channel count from one real training batch instead of hard-coding 3,
# so changing DATASET does not require editing this cell.
# Assumes each batch is (images, labels) with images laid out as
# (batch, channels, height, width) — TODO confirm against load_dataset.
# Note: this fetches one batch up front, which costs a little time.
sample_images, sample_labels = next(iter(train_loader))
in_channels = sample_images.shape[1]

model = CNN(in_channels=in_channels, num_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss().to(device)

# Train or load model

In [None]:
# Either train from scratch and save the weights, or restore them from an
# existing checkpoint when PATH_TO_READY_MODEL is set.
if PATH_TO_READY_MODEL is None:
    train(
        model,
        optimizer,
        criterion,
        train_loader,
        test_loader,
        device=device,
        num_epochs=EPOCHS,
    )
    # Fall back to the default file name when no save path is configured.
    save_path = "model.pth" if PATH_TO_SAVE_MODEL is None else PATH_TO_SAVE_MODEL
    torch.save(model.state_dict(), save_path)
else:
    # map_location places the loaded tensors on the selected device.
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    state_dict = torch.load(PATH_TO_READY_MODEL, map_location=device)
    model.load_state_dict(state_dict)

# Show result on test data

In [None]:
# Measure top-1 accuracy of the model over every batch in test_loader.
model.eval()  # evaluation mode (matters for layers such as dropout/batch norm, if any)

correct = 0
total = 0

with torch.no_grad():  # no gradients are needed for evaluation
    for images, real_labels in test_loader:
        images = images.to(device)
        real_labels = real_labels.to(device)

        outputs = model(images)
        # Index of the highest score per sample. argmax also avoids assigning
        # to `_`, which would shadow IPython's last-output variable.
        predicted_labels = outputs.argmax(dim=1)

        total += real_labels.size(0)
        correct += (predicted_labels == real_labels).sum().item()

if total == 0:
    # An empty loader would otherwise raise ZeroDivisionError below.
    print('Accuracy on test dataset: n/a (test_loader yielded no samples)')
else:
    print(f'Accuracy on test dataset: {100 * correct / total:.2f}%')