In [33]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('..')

from tqdm import tqdm
from network import ViT
from dataset import InputPipeLineBuilder

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [34]:
# Training configuration for a head-only fine-tune: only model.head's parameters
# are passed to the optimizer below.
num_epochs = 50
batch_size = 64

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): 0.05 is unusually high for AdamW (typical values are ~1e-3 or
# lower); confirm this is intentional before relying on the results.
lr = 0.05
weight_decay = 0.001

# Seed torch's global RNG before building the model so that any random
# initialization (presumably the head) and any torch-RNG-driven data shuffling
# are reproducible under Restart & Run All. torch.manual_seed also seeds CUDA.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): 768 is presumably the backbone's embedding width — confirm
# against the ViT definition in network.py.
model = ViT(head_input_dim=768).to(device)
loss_fn = nn.CrossEntropyLoss()
# Only the head's parameters are given to AdamW, so optimizer.step() never
# updates model.feature_extractor.
optimizer = torch.optim.AdamW(model.head.parameters(), lr=lr, weight_decay=weight_decay)

def lr_lambda(epoch):
    """Return the multiplicative learning-rate factor used by LambdaLR.

    Step decay: the factor drops by 10x at epochs 20, 30 and 40.

    Args:
        epoch: the scheduler's step count (the scheduler is stepped once per
            training epoch in this notebook).

    Returns:
        1.0 below epoch 20, 0.1 for [20, 30), 0.01 for [30, 40), 0.001 from 40 on.
    """
    # (exclusive upper bound, factor) pairs, checked in ascending order; the
    # first bound the epoch falls under wins, otherwise use the final factor.
    decay_steps = ((20, 1.0), (30, 0.1), (40, 0.01))
    return next((factor for upper, factor in decay_steps if epoch < upper), 0.001)
    
# Each group's lr becomes base_lr * lr_lambda(step); the training loop steps
# this scheduler once per epoch.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [35]:
# Freeze the backbone: nn.Module.requires_grad_(False) recursively sets
# requires_grad=False on every parameter of the module, so backward() no longer
# computes (or accumulates) gradients for it. Note that assigning
# `module.requires_grad_ = False` would merely shadow that method with an
# attribute and freeze nothing.
model.feature_extractor.requires_grad_(False)
assert not any(p.requires_grad for p in model.feature_extractor.parameters()), \
    "feature_extractor still has trainable parameters"
    
input_pipeline_builder = InputPipeLineBuilder(batch_size=batch_size, dataset='tiny_imagenet')

# Build the train and test loaders (in that order) from the same pipeline config.
# NOTE(review): test_dataloader is not consumed by any visible cell; consider
# adding an evaluation pass.
train_dataloader, test_dataloader = (
    input_pipeline_builder.get_dataloader(subset=split) for split in ('train', 'test')
)

In [None]:
for epoch in range(num_epochs):
    model.train()
    # NOTE(review): model.train() also switches the frozen feature_extractor to
    # training mode; if it contains dropout/batch-norm layers, consider calling
    # model.feature_extractor.eval() here so extracted features are deterministic.
    running_loss = 0.0
    num_batches = 0
    for train_x, train_y in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{num_epochs}"):
        logits = model(train_x.to(device))
        loss = loss_fn(logits, train_y.to(device))

        # Clear grads before backward so stale grads (e.g. from an interrupted
        # earlier run of this cell) never leak into the first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() already copies the scalar to the host; no explicit .cpu() needed.
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("train_dataloader yielded no batches")

    # LambdaLR is stepped once per epoch, so lr_lambda receives the epoch index.
    scheduler.step()
    print(f"avg loss at epoch: {epoch+1}/{num_epochs}: {running_loss / num_batches:.4f}")

In [None]:
# Persist only the head's weights; the feature extractor's weights are not included.
head_weights_path = 'vit_head.pth'
head_state = model.head.state_dict()
torch.save(head_state, head_weights_path)