In [19]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('..')

from tqdm import tqdm
from network import ResNet
from dataset import InputPipeLineBuilder

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
# --- Training hyperparameters ---
num_epochs = 100
batch_size = 256
lr = 0.05
weight_decay = 0.001

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- Model, objective and optimizer ---
model = ResNet(head_input_dim=512).to(device)
loss_fn = nn.CrossEntropyLoss()

# Only the head's parameters are handed to the optimizer, so the optimizer
# never updates anything under `model.feature_extractor`.
optimizer = torch.optim.AdamW(
    model.head.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

def lr_lambda(epoch):
    """Return the multiplicative LR factor for ``epoch`` (step decay).

    The factor is 1.0 while ``epoch < 30``, 0.1 while ``epoch < 60``,
    0.01 while ``epoch < 90`` and 0.001 from then on.
    """
    # (exclusive upper bound, factor) pairs, checked in ascending order.
    schedule = ((30, 1.0), (60, 0.1), (90, 0.01))
    for upper_bound, factor in schedule:
        if epoch < upper_bound:
            return factor
    return 0.001
    
# Scale the optimizer's base learning rate by `lr_lambda(epoch)` on every scheduler step.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [21]:
# Freeze the backbone so autograd skips its parameters.
# `requires_grad_` is a method: the former `layer.requires_grad_ = False` only
# shadowed it with a bool attribute and left every parameter's `requires_grad`
# flag untouched. Going through the parameters (instead of the module-level
# method) also keeps this cell working on a model object whose module method
# was already shadowed by an earlier run of the old assignment.
for param in model.feature_extractor.parameters():
    param.requires_grad_(False)
    
# Build the input pipeline once and derive the train and test loaders from it.
input_pipeline_builder = InputPipeLineBuilder(dataset='cifar100', batch_size=batch_size)
train_dataloader, test_dataloader = (
    input_pipeline_builder.get_dataloader(subset=split) for split in ('train', 'test')
)

In [None]:
# Train the head only. The backbone is switched to eval mode once, up front;
# the head is put back into train mode at the start of every epoch.
model.feature_extractor.eval()
for epoch in range(num_epochs):
    model.head.train()
    losses = []

    for train_x, train_y in tqdm(train_dataloader):
        inputs = train_x.to(device)
        targets = train_y.to(device)

        # Forward pass and loss.
        loss = loss_fn(model(inputs), targets)

        # Backward pass, parameter update, then clear gradients for the next step.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Keep the scalar loss as a Python float for the epoch average.
        losses.append(loss.item())

    # One scheduler step per epoch.
    scheduler.step()
    print(f"avg loss at epoch: {epoch+1}/{num_epochs}: {sum(losses) / len(losses):.4f}")

In [None]:
# Persist only the head's weights; nothing from the backbone is written here.
head_state = model.head.state_dict()
torch.save(head_state, './resnet_18_head.pth')

In [None]:
# Top-1 accuracy on the test split.
rcorrect, total = 0, 0

# Put the *whole* model in eval mode, not just the head. Previously only
# `model.head.eval()` was called here, so the backbone's mode depended on the
# training cell having been run earlier in the same kernel.
model.eval()

# No gradients are needed for evaluation; keep the entire loop under no_grad.
with torch.no_grad():
    for batch_x, batch_y in tqdm(test_dataloader):
        logits = model(batch_x.to(device))
        preds = torch.argmax(logits, dim=1)

        # Count predictions that match the labels in this batch.
        rcorrect += torch.sum(preds == batch_y.to(device)).item()
        total += batch_x.shape[0]

# NOTE: `epoch` and `num_epochs` are defined by the training/config cells above.
print(f"\n\tavg accuracy at epoch: {epoch+1}/{num_epochs}: {rcorrect/total:.4f}")