In [1]:
%pylab inline
import sys
sys.dont_write_bytecode = True
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

Populating the interactive namespace from numpy and matplotlib


In [2]:
from torchvision import models

In [3]:
import torch
from tqdm import tqdm
from torch import nn, optim
from torch.optim import lr_scheduler
model_ft = models.densenet121(pretrained=True)
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Linear(num_ftrs, 2)

In [4]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} (avg: {avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
    
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            res.append(correct_k * (100.0 / batch_size))
        return res

In [5]:
class BaseTrainer(object):
    """
    """
    def __init__(self, cfg):
        self.train_loader, self.valid_loader = cfg["loader"]
        self.batch_size = cfg["batch_size"]
        self.model = cfg["model"]
        self.device = cfg["device"]
        self.optimizer = cfg["optimizer"]
        self.criterion = cfg["criterion"]
        
        self.acc_met = AverageMeter("acc")
        self.valid_acc_met = AverageMeter("valid_acc")
        
    def train_epoch(self, epoch):
        
        self.model.train()
        
        for self.idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.train_step(data, target)
        
        self.valid()
    
    def train_step(self, data, target):
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.criterion(output, target)
        loss.backward()
        
        self.optimizer.step()
        self.acc_met.update(accuracy(output, target)[0].item())
        
        if (self.idx + 1) % 200 == 0:
            print(loss.item())
            print(self.acc_met)
            
            
    def valid(self):
        self.model.eval()
        self.valid_acc_met.reset()
        with torch.no_grad():
            for data, target in tqdm(self.valid_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                
                self.valid_acc_met.update(accuracy(output, target)[0].item())
        
        print(self.valid_acc_met)

In [6]:
device = torch.device("cuda:0")
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.00001, momentum=0.99)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [7]:
config = {}
from data_io import get_dataloader
train_lst, valid_lst, train_dataset, valid_dataset, train_loader, valid_loader = get_dataloader(config)

Got 20000 items in train mode.
Got 5000 items in valid mode.


In [8]:

config["loader"] = (train_loader, valid_loader)
config["model"] = model_ft
config["device"] = device
config["criterion"] = criterion
config["optimizer"] = optimizer_ft
config["batch_size"] = 32

In [9]:
trainer = BaseTrainer(config)

In [12]:
trainer.train_epoch(2)

0.31749749183654785
acc 84.375000 (avg: 75.429459)


100%|██████████| 79/79 [00:21<00:00,  3.71it/s]

valid_acc 100.000000 (avg: 97.982595)





In [11]:
trainer.valid()

100%|██████████| 79/79 [00:07<00:00, 11.09it/s]

valid_acc 100.000000 (avg: 96.736551)



