In [1]:
from torchvision.datasets import CIFAR10
from torchvision import transforms as T
from torch.utils.data import DataLoader
from finetune.models.resnet import resnet50
import torch
# Run on the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Instantiate the project's ResNet-50 with its default arguments and move it to `device`.
# NOTE(review): the default output width of finetune's resnet50 is not visible here --
# confirm it matches the 10 CIFAR-10 classes.
model = resnet50()
model = model.to(device)

def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps shared by the train and test pipelines.

    Normalize is configured with mean 0.5 and std 0.5 on each of the three channels,
    i.e. it computes (x - 0.5) / 0.5 per channel.
    """
    return [
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]


# Training pipeline: random 32-pixel crop with padding=4 and a random horizontal flip,
# followed by the shared tensor conversion and normalization.
_train_augmentations = [
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
]
train_transform = T.Compose(_train_augmentations + _tensor_and_normalize())

# Test pipeline: no augmentation, only tensor conversion and normalization.
test_transform = T.Compose(_tensor_and_normalize())

In [2]:
# Both CIFAR-10 splits use ./data as their root and are created with download=True.
_cifar_kwargs = dict(root='./data', download=True)
train_dataset = CIFAR10(train=True, transform=train_transform, **_cifar_kwargs)
test_dataset = CIFAR10(train=False, transform=test_transform, **_cifar_kwargs)

# Batches of 128 for both splits; only the training loader shuffles its samples.
_batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=_batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=_batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# CosineAnnealingLR is defined in torch.optim.lr_scheduler; it is not exported from
# torch.optim, so importing it from there fails on a fresh kernel.
from torch.optim.lr_scheduler import CosineAnnealingLR

# Number of passes over the training set. The same value is used as the scheduler's
# T_max so the cosine schedule spans exactly the whole run.
num_epochs = 100

# SGD with momentum and weight decay over every model parameter.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

for _ in range(num_epochs):
    # ---- training: one pass over train_loader with a cross-entropy objective ----
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

    # ---- evaluation: top-1 accuracy over test_loader ----
    model.eval()
    test_preds = []
    test_labels = []
    # No gradients are needed for evaluation; disabling autograd avoids building a
    # graph for every forward pass and so saves memory and time.
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Predicted class = index of the largest logit along dim 1.
            test_preds.append(logits.argmax(dim=1).cpu().numpy())
            test_labels.append(y.cpu().numpy())
    test_preds = np.concatenate(test_preds)
    test_labels = np.concatenate(test_labels)
    acc = (test_preds == test_labels).mean()

    # Advance the learning-rate schedule once per epoch, after the optimizer updates.
    scheduler.step()
    print(acc)

0.2205


KeyboardInterrupt: 