In [3]:
import os
import torch,torchvision
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from tqdm.notebook import tqdm

In [2]:
def top_n_accuracy(outputs,labels,n=1):
    """Return the fraction of rows whose true label is among the ``n`` highest scores.

    Args:
        outputs: 2-D tensor of scores with one row per observation and one
            column per class, i.e. shape ``(num_obs, num_labels)``.
        labels: one integer class index per row of ``outputs`` (tensor or
            sequence convertible with ``torch.as_tensor``).
        n: how many of the highest-scoring classes count as a hit. Values
            larger than ``num_labels`` are clamped, so every row is a hit.
            Values <= 0 give no hits.

    Returns:
        Python float in ``[0, 1]``; ``0.0`` for an empty batch.
    """
    num_obs,num_labels=outputs.shape
    if num_obs==0:
        return 0.0
    # Clamp k into [0, num_labels]. The previous argsort-slice approach wrapped
    # around (negative start index) when n exceeded the number of classes.
    k=max(0,min(n,num_labels))
    top_idx=torch.topk(outputs,k,dim=1).indices                    # (num_obs, k)
    labels=torch.as_tensor(labels,device=outputs.device).reshape(-1,1)
    hits=(top_idx==labels).any(dim=1)                              # (num_obs,)
    return hits.sum().item()/num_obs

## Training

In [2]:
def train(model,device,num_epochs,train_loader,valid_loader,criterion,optimizer,scheduler=None,save=0,path=None,save_every=10):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Each epoch:
      1. runs one optimisation pass over ``train_loader`` with the model in
         training mode;
      2. evaluates on ``valid_loader`` with the model in eval mode and under
         ``torch.no_grad()``;
      3. steps ``scheduler`` (if given);
      4. logs the mean training loss and the top-1 / top-5 validation accuracy;
      5. optionally writes checkpoints (see ``save``).

    Args:
        model: the ``nn.Module`` being trained (updated in place).
        device: device each batch is moved to. Presumably the model is
            already on this device; this function does not move it.
        num_epochs: number of passes over ``train_loader``.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        valid_loader: iterable of ``(inputs, labels)`` validation batches.
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: optimiser over the model's parameters.
        scheduler: optional LR scheduler, stepped once per epoch with no arguments.
        save: checkpoint level.
            ``0`` saves nothing.
            ``>= 1`` saves the state dict to ``<path>/best`` whenever
            validation top-1 accuracy improves.
            ``>= 2`` additionally saves ``<path>/e<epoch>`` every
            ``save_every`` epochs.
        path: checkpoint directory; required when ``save > 0``.
        save_every: positive interval, in epochs, for periodic checkpoints
            (only used when ``save > 1``).

    Raises:
        ValueError: if ``save > 0`` and ``path`` is None (checked before training).
    """
    if save>0 and path is None:
        raise ValueError('path must be given when save > 0')
    best_acc=0
    for epoch in range(num_epochs):
        # --- train ---
        # Re-enter training mode every epoch, because validation below switches to eval mode.
        model.train()
        train_loss=0.0
        n_train=0
        for inputs,labels in tqdm(train_loader,leave=False):
            inputs,labels=inputs.to(device),labels.to(device)
            outputs=model(inputs)
            loss=criterion(outputs,labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not skew the
            # epoch mean (assumes criterion uses mean reduction -- TODO confirm).
            batch_size=labels.size(0)
            train_loss+=loss.item()*batch_size
            n_train+=batch_size
        train_loss/=max(n_train,1)
        # --- validation ---
        # Inference-mode behaviour for layers such as dropout / batch-norm.
        model.eval()
        acc_top1=acc_top5=0.0
        n_valid=0
        with torch.no_grad():
            for inputs,labels in valid_loader:
                inputs,labels=inputs.to(device),labels.to(device)
                outputs=model(inputs)
                batch_size=labels.size(0)
                # Sample-weighted, so the result is the accuracy over the whole set.
                acc_top1+=top_n_accuracy(outputs,labels)*batch_size
                acc_top5+=top_n_accuracy(outputs,labels,n=5)*batch_size
                n_valid+=batch_size
        acc_top1/=max(n_valid,1)
        acc_top5/=max(n_valid,1)
        if scheduler is not None:
            scheduler.step()
        tqdm.write(f'Epoch {epoch+1:>2} \t train_loss: {train_loss:>7.5f} \t top1_acc: {acc_top1*100:>5.2f}% \t top5_acc: {acc_top5*100:>5.2f}%')
        if save>0:
            if save>1 and (epoch+1)%save_every==0:
                torch.save(model.state_dict(),os.path.join(path,'e'+str(epoch+1)))
            if best_acc<acc_top1:
                best_acc=acc_top1
                torch.save(model.state_dict(),os.path.join(path,'best'))

## Test

In [4]:
def test(model,device,test_loader,criterion):
    """Evaluate ``model`` on ``test_loader``, print the metrics and return them.

    The model is put in eval mode and gradients are disabled. Loss and
    accuracies are weighted by batch size, so they are averages over all test
    samples rather than averages of per-batch values.

    Args:
        model: the ``nn.Module`` to evaluate (left in eval mode afterwards).
        device: device each batch is moved to.
        test_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss, called as ``criterion(outputs, labels)``. Assumed to
            use mean reduction -- TODO confirm.

    Returns:
        ``(test_loss, acc_top1, acc_top5)`` as floats, with accuracies in
        ``[0, 1]``. All three are ``0.0`` if ``test_loader`` yields no samples.
    """
    model.eval()
    test_loss=0.0
    acc_top1=acc_top5=0.0
    n_seen=0
    with torch.no_grad():
        for inputs,labels in test_loader:
            inputs,labels=inputs.to(device),labels.to(device)
            outputs=model(inputs)
            batch_size=labels.size(0)
            test_loss+=criterion(outputs,labels).item()*batch_size
            acc_top1+=top_n_accuracy(outputs,labels)*batch_size
            acc_top5+=top_n_accuracy(outputs,labels,n=5)*batch_size
            n_seen+=batch_size
    # Guard against an empty loader instead of raising ZeroDivisionError.
    denom=max(n_seen,1)
    test_loss/=denom
    acc_top1/=denom
    acc_top5/=denom
    print(f'loss: {test_loss:>7.5f} \t top1_acc: {acc_top1*100:>5.2f}% \t top5_acc: {acc_top5*100:>5.2f}%')
    return test_loss,acc_top1,acc_top5