In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torchview import draw_graph
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import time
import os
from tqdm import tqdm
import sys
import matplotlib.pyplot as plt
import numpy as np
from torchviz import make_dot
import timm
# Configuration: seed PyTorch's global RNG so that head initialisation, data
# augmentation (random crop / flip) and DataLoader shuffling are reproducible
# under Restart & Run All. Note this does not make cuDNN kernels deterministic.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU generator and every CUDA device's generator

# Set the device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
# Load CIFAR-10 and upsample every 32x32 image to 384x384, the input size of the
# vit_base_patch16_384 model created in the next cell.
# NOTE(review): the std triple (0.2023, 0.1994, 0.2010) is a widely copied value;
# the commonly quoted per-channel CIFAR-10 std is about (0.2470, 0.2435, 0.2616),
# and pretrained timm ViT checkpoints are usually paired with their own
# normalisation (see the model's pretrained_cfg) — confirm which one is intended.
transform_train = transforms.Compose(
    [
        # Augment at native resolution: random 32x32 crop from a 4px zero-padded image...
        transforms.RandomCrop(size=32, padding=4),
        # ...then upsample the PIL image to the ViT's 384px input.
        transforms.Resize(size=384),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    ]
)

# Evaluation pipeline: no augmentation.
# NOTE(review): here Resize runs on the tensor (after ToTensor) while training
# resizes the PIL image; the two interpolation paths can differ slightly.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=384),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    ]
)

# Training split: shuffled each epoch; 50 000 images / 50 = 1000 batches.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=50, shuffle=True)

# Test split (also used as the validation set during training): fixed order, 200 batches.
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=50, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Start from timm's pretrained ViT-B/16 (384x384 input) and replace its
# classification head with a freshly initialised 10-way linear layer for CIFAR-10.
net = timm.create_model("vit_base_patch16_384", pretrained=True)
net.head = nn.Linear(in_features=net.head.in_features, out_features=10)

criterion = nn.CrossEntropyLoss()
# NOTE(review): the recorded run shows loss ~5.3 (above ln(10) ~= 2.30 for uniform
# guessing) and chance-level accuracy after 16 batches, which suggests this LR may
# be too high for fine-tuning the full pretrained backbone — confirm.
optimizer = optim.Adam(params=net.parameters(), lr=1e-4 * 6)

# Parameter counts (every parameter is trainable here: nothing is frozen).
total_params = sum(param.numel() for param in net.parameters())
total_trainable_params = sum(param.numel() for param in net.parameters() if param.requires_grad)
print(f'{total_params:,} total parameters.')
print(f'{total_trainable_params:,} total trainable parameters.')

# Move everything to the selected device.
net = net.to(device)
criterion = criterion.to(device)

86,098,186 total parameters.
86,098,186 total trainable parameters.


In [None]:
def train(epoch):
    """Run one training epoch over ``trainloader``.

    Relies on notebook globals: ``net``, ``trainloader``, ``criterion``,
    ``optimizer``, ``device``, ``use_amp`` and ``scaler``. The last two are
    defined in the training-driver cell before this function is first called.

    Args:
        epoch: epoch index, used only in the progress-bar label.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the loss averaged over batches, and the
        accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    # Size the bar from the loader instead of hard-coding 1000 (= 50000 / 50).
    # The context manager closes the bar even if the epoch is interrupted.
    with tqdm(total=len(trainloader), unit="batch", leave=False, desc=f"Epoch {epoch} Train",
              position=0, file=sys.stdout) as pb:
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            # Clear gradients before the forward pass. Otherwise gradients left
            # over from an interrupted previous run would be applied on the first step.
            optimizer.zero_grad()
            # Mixed-precision forward pass on the model's device type. The
            # GradScaler scales the loss so small fp16 gradients do not underflow.
            with torch.amp.autocast(device.type, enabled=use_amp):
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            num_batches = batch_idx + 1
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pb.postfix = '; Loss: %.3f | Acc: %.3f%% (%d/%d)'%(train_loss/num_batches, 100.*correct/total, correct, total)
            pb.update(1)
    if num_batches == 0:
        raise ValueError("trainloader yielded no batches")
    return train_loss / num_batches, correct / total

def test(epoch):
    """Evaluate ``net`` on ``testloader`` without gradients or autocast.

    Relies on notebook globals ``net``, ``testloader``, ``criterion`` and
    ``device``. No checkpoint is saved.

    Args:
        epoch: epoch index, used only in the progress-bar label.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the loss averaged over batches, and the
        accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    # Size the bar from the loader instead of hard-coding 200 (= 10000 / 50).
    # The context manager closes the bar even if evaluation is interrupted.
    with torch.no_grad(), tqdm(total=len(testloader), unit="batch", leave=False,
                               desc=f"Epoch {epoch} Test ", position=0, file=sys.stdout) as pb:
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            num_batches = batch_idx + 1
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # Show accuracy as a percentage to match the "%" in the label.
            # Previously the raw fraction was printed here.
            pb.postfix = '; Loss: %.3f | Acc: %.3f%% (%d/%d)'%(test_loss/num_batches, 100.*correct/total, correct, total)
            pb.update(1)
    if num_batches == 0:
        raise ValueError("testloader yielded no batches")
    return test_loss / num_batches, correct / total

# Per-epoch history, consumed by the plotting cell below.
list_train_loss = []
list_train_acc = []
list_val_loss = []
list_val_acc = []
list_lr = []

# Training length. It is also the cosine schedule's period (T_max), so the LR
# follows one half-cosine from its initial value toward eta_min=0 over the run.
num_epochs = 50

# Respect the device chosen in the first cell instead of hard-coding .cuda().
net.to(device)
# Mixed precision (fp16 autocast + gradient scaling) only on CUDA; off on CPU.
use_amp = device.type == "cuda"
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    start = time.time()
    # Read the LR actually used for this epoch *before* the scheduler advances it.
    # Previously the post-step value (the next epoch's LR) was logged and printed.
    epoch_lr = optimizer.param_groups[0]["lr"]
    train_loss, train_acc = train(epoch)
    val_loss, val_acc = test(epoch)

    scheduler.step()  # step cosine scheduling once per epoch

    list_lr.append(epoch_lr)
    list_train_loss.append(train_loss)
    list_train_acc.append(train_acc)
    list_val_loss.append(val_loss)
    list_val_acc.append(val_acc)
    content = f'Epoch {epoch}, lr: {epoch_lr:.7f}, train loss: {train_loss:.5f}, acc: {100*train_acc:.3f}%, val loss: {val_loss:.5f}, acc: {100*val_acc:.2f}%, time: {time.time()-start:.5f}s'
    print(content)

Epoch 0 Train:   2%|▏         | 16/1000 [00:44<49:35,  3.02s/batch, ; Loss: 5.306 | Acc: 10.500% (84/800)]

In [None]:
# Final evaluation with the trained weights.
# NOTE(review): `testloader` is the same CIFAR-10 test split used as "val" in the
# training loop, so this repeats the last epoch's validation pass rather than
# measuring a held-out test set. `epoch` is the loop variable left over from the
# training cell; it only affects the progress-bar label.
test_loss, test_acc = test(epoch=epoch)

In [None]:
def _plot_history(curves, ylabel, title, legend=None):
    """Plot one or more per-epoch series on a fresh figure and show it.

    Args:
        curves: sequence of per-epoch value sequences, drawn as one line each.
        ylabel: y-axis label.
        title: figure title, so the plot stands alone when skimmed.
        legend: optional legend labels, one per curve.
    """
    fig, ax = plt.subplots()
    for values in curves:
        ax.plot(values)
    ax.set(xlabel='Epoch', ylabel=ylabel, title=title)
    if legend is not None:
        ax.legend(legend)
    plt.show()
    plt.close(fig)  # release the figure so repeated runs do not accumulate open figures


# Best value of each metric across epochs. Each may come from a different epoch.
best_train_acc = max(list_train_acc)
best_val_acc = max(list_val_acc)
best_train_loss = min(list_train_loss)
best_val_loss = min(list_val_loss)

_plot_history([list_train_loss, list_val_loss], 'Loss', 'Cross-entropy loss per epoch', legend=['train', 'val'])
# The accuracy lists hold correct/total fractions in [0, 1].
_plot_history([list_train_acc, list_val_acc], 'Accuracy (fraction)', 'Top-1 accuracy per epoch', legend=['train', 'val'])
_plot_history([list_lr], 'Learning rate', 'Learning-rate schedule')

print(f'Best train loss: {best_train_loss:.5f}, acc: {100*best_train_acc:.2f}%, Best val loss: {best_val_loss:.5f}, acc: {100*best_val_acc:.2f}%, test loss: {test_loss:.5f}, acc: {100*test_acc:.2f}%')