In [1]:
import timm
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.datasets import FashionMNIST
import os


import wandb

# Finish the currently active wandb run (if any) so that re-executing this
# cell does not keep logging into a run started by an earlier execution.
wandb.finish()
# Start a new wandb run in the 'cifar10_classification' project; wandb's local
# run files are written under ./wandb.
wandb.init(project='cifar10_classification', dir="./wandb")

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/model_mavericks/.netrc


In [2]:
# Record the experiment's hyperparameters on the active wandb run's config
# object, one attribute per setting.
config = wandb.config
for option, setting in {
    'learning_rate': 0.001,
    'epochs': 10,
    'batch_size': 64,
    'model_name': 'resnet18',
}.items():
    setattr(config, option, setting)


In [3]:
# Per-channel mean / std handed to Normalize for both splits (presumably the
# CIFAR-10 training-set statistics -- TODO confirm). Defined once so the train
# and test pipelines cannot drift apart.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.247, 0.243, 0.261)

# Training pipeline: random 32x32 crop after 4px padding and a random
# horizontal flip, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation pipeline: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# The batch size is read from the wandb config (set in the config cell) so the
# value recorded with the run is the value actually used by the loaders.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=2)

# The test split is iterated in a fixed order (shuffle=False).
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:07<00:00, 23389208.33it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
# Choose the compute device once: CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the architecture named in the wandb config (so the logged model_name
# matches the model that is trained), with pretrained weights and
# num_classes=10, then move it to the selected device.
model = timm.create_model(config.model_name, pretrained=True, num_classes=10)
model = model.to(device)


model.safetensors: 100%|██████████| 46.8M/46.8M [00:00<00:00, 67.8MB/s]


In [5]:
# Cross-entropy loss over the model's outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters. The learning rate is read from the wandb
# config so the value logged with the run is the one the optimizer uses.
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)


In [6]:
# Number of passes over the training set, taken from the wandb config so the
# logged hyperparameter and the actual loop length agree.
epochs = config.epochs
# Resolve the device once instead of re-evaluating it for every batch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses over this epoch.
    epoch_loss = running_loss / len(trainloader)
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")
    # Note: the logged "epoch" is 0-based, while the printed one is 1-based.
    wandb.log({"epoch": epoch, "loss": epoch_loss})

Epoch 1, Loss: 1.1466198234302003
Epoch 2, Loss: 0.7471613933515671
Epoch 3, Loss: 0.6411269292273485
Epoch 4, Loss: 0.5813521062931442
Epoch 5, Loss: 0.5339565413915898
Epoch 6, Loss: 0.5009867040550008
Epoch 7, Loss: 0.47106429137994565
Epoch 8, Loss: 0.44476710275158554
Epoch 9, Loss: 0.4185693253717764
Epoch 10, Loss: 0.4008919577045209


In [7]:
# Evaluate on the test loader: count how many argmax predictions match labels.
model.eval()
# Resolve the device once instead of re-evaluating it for every batch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the largest output along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Log the exact percentage; floor division here would truncate the tracked
# metric to a whole number (e.g. 83.9 -> 83).
wandb.log({"test_accuracy": 100 * correct / total})
# The printed summary keeps the original whole-percent format.
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total}%')


Accuracy of the network on the 10000 test images: 83%


In [8]:
# Ask wandb to hook into the model; log='all' requests both gradients and
# parameters.
# NOTE(review): this cell runs after the training loop and the evaluation have
# already finished, and no later cell runs the model again -- so these hooks
# likely never record anything. Consider calling this before training;
# confirm against the wandb.watch documentation.
wandb.watch(model, log='all')


[]

In [9]:
# Write only the model's state_dict to disk, then pass the same file to
# wandb.save so it is stored with the run's files.
checkpoint_path = 'model.pth'
torch.save(model.state_dict(), checkpoint_path)
wandb.save(checkpoint_path)


['./wandb/wandb/run-20231210_183241-tk3qbp86/files/model.pth']