In [None]:
import random

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import resnet18
from torchvision.transforms import v2, ToTensor
from tqdm import tqdm

In [None]:
# Training-time augmentation applied to every CIFAR-10 training image:
# random horizontal flip and a random rotation of up to +/-10 degrees, then conversion
# to a float32 tensor. `v2.ToTensor` is deprecated; `ToImage` + `ToDtype(scale=True)`
# is its documented replacement and rescales pixel values from [0, 255] to [0, 1].
transforms = v2.Compose([
    v2.RandomHorizontalFlip(),
    v2.RandomRotation(10),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

In [None]:
# CIFAR-10 must already be on disk under DATA_ROOT (download=False).
DATA_ROOT = "./nodes/data"

# The augmentation pipeline is passed at construction time instead of being
# assigned to `.transform` afterwards.
train_dataset = datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=False,
    transform=transforms,
)

# Evaluation images are only converted to float32 tensors in [0, 1] (no augmentation),
# using the same v2 API as the training pipeline.
test_dataset = datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    download=False,
    transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
)

In [None]:
batch_size = 32
epochs = 20
learning_rate = 1e-3
weight_decay = 1e-4

# Seed Python's `random` and torch (torch.manual_seed covers all devices) so weight
# initialisation, data shuffling and random augmentation repeat across Restart & Run All.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Split the CIFAR-10 test split in half: the first half is used for validation,
# the second half is held out as the final test set.
# `test_dataset` is rebound to a Subset below, so unwrap it first; otherwise re-running
# this cell would split an already-halved Subset again.
full_test_dataset = (
    test_dataset.dataset
    if isinstance(test_dataset, torch.utils.data.Subset)
    else test_dataset
)
val_size = len(full_test_dataset) // 2
val_dataset = torch.utils.data.Subset(full_test_dataset, range(val_size))
test_dataset = torch.utils.data.Subset(
    full_test_dataset, range(val_size, len(full_test_dataset))
)

# Reshuffle the training set every epoch so mini-batches differ between epochs.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The whole validation half is served as a single batch.
val_dataloader = DataLoader(val_dataset, batch_size=val_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
class Classifier(nn.Module):
    """Randomly initialised ResNet-18 adapted to small (e.g. 32x32 CIFAR-10) images.

    The stock stem's strided convolution and max-pool downsample small images very
    aggressively, so the first conv is replaced by a 3x3 stride-1 conv and the
    max-pool is removed.

    Args:
        num_classes: number of output logits produced by the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # `weights=None` is the non-deprecated spelling of `pretrained=False`;
        # passing num_classes builds the final Linear layer with the right width directly.
        self.resnet = resnet18(weights=None, num_classes=num_classes)
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # A conv swapped in after construction misses the initialisation torchvision's
        # ResNet applies to its own Conv2d layers; re-apply it for consistency.
        nn.init.kaiming_normal_(self.resnet.conv1.weight, mode="fan_out", nonlinearity="relu")
        self.resnet.maxpool = nn.Identity()

    def forward(self, x):
        """Return one logit per class for each image in `x` (3-channel image batch)."""
        return self.resnet(x)

In [None]:
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fresh 10-class classifier, moved onto the selected device.
model = Classifier(num_classes=10)
model = model.to(device)

# Adam with L2 weight decay; ReduceLROnPlateau lowers the learning rate once the
# monitored quantity has not decreased for `patience` epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_fn = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)

In [None]:
def train():
    """Run one training epoch over `train_dataloader` and return the mean training loss.

    Uses the notebook-level `model`, `optimizer`, `loss_fn` and `device`. The result is a
    per-sample mean: each batch's loss is weighted by its size, so a short final batch
    does not skew the epoch average (this assumes `loss_fn` uses mean reduction, as
    nn.CrossEntropyLoss does by default).

    Returns:
        float: mean loss over all training samples seen this epoch.

    Raises:
        ValueError: if `train_dataloader` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    # Iterating the DataLoader directly lets tqdm read its length and show a real total.
    for X, y in tqdm(train_dataloader, desc="train", leave=False):
        X, y = X.to(device), y.to(device)
        # Clear any gradients left over (e.g. from an interrupted run) before backprop.
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()

        batch_n = y.size(0)
        total_loss += loss.item() * batch_n
        n_samples += batch_n

    if n_samples == 0:
        raise ValueError("train_dataloader yielded no batches")
    return total_loss / n_samples

In [None]:
def validate():
    """Return the mean loss of `model` over every batch of `val_dataloader`.

    Uses the notebook-level `model`, `loss_fn` and `device`; runs in eval mode without
    gradient tracking. Batch losses are weighted by batch size, so the result is the
    per-sample mean over the whole validation split regardless of the loader's batch size
    (assumes `loss_fn` uses mean reduction, as nn.CrossEntropyLoss does by default).

    Raises:
        ValueError: if `val_dataloader` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)
            loss = loss_fn(model(X), y)
            batch_n = y.size(0)
            total_loss += loss.item() * batch_n
            n_samples += batch_n

    if n_samples == 0:
        raise ValueError("val_dataloader yielded no batches")
    return total_loss / n_samples

In [None]:
# Per-epoch history of the mean training loss and the validation loss.
train_losses, val_losses = [], []

for e in range(epochs):
    # One training epoch, then one validation pass (tuple items evaluate left to right).
    loss, val_loss = train(), validate()
    train_losses.append(loss)
    val_losses.append(val_loss)
    # The plateau scheduler watches the validation loss.
    scheduler.step(val_loss)

In [None]:
# Learning curves: training loss is measured on augmented images in train mode,
# validation loss on un-augmented images in eval mode, so the curves are not strictly
# comparable in absolute value.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label="train")
ax.plot(range(1, len(val_losses) + 1), val_losses, label="validation")
ax.set(
    title="ResNet-18 on CIFAR-10: loss per epoch",
    xlabel="epoch",
    ylabel="cross-entropy loss",
)
ax.legend()
plt.show()