In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Register the %tensorboard magic that is used further down in the notebook.
%load_ext tensorboard

In [None]:
import datetime
import os

import torch
from torch import nn
from tqdm.auto import tqdm

from tensorboard_pytorch import TensorboardPyTorch

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [28]:
class Trainer(object):
    """Minimal train/evaluate loop for a classifier with TensorBoard logging.

    Args:
        model: ``nn.Module`` to train. It is moved to ``device`` on construction
            so that it lives on the same device as the batches.
        loaders: mapping with ``'train'`` and ``'test'`` entries. Each value must
            support ``len()`` and yield ``(inputs, targets)`` batches.
        criterion: loss callable invoked as ``criterion(y_pred, y_true)``.
        optim: optimizer exposing ``zero_grad()`` and ``step()``.
        scheduler: optional LR scheduler, stepped once per epoch.
        device: device that the model and every batch are moved to.
    """

    # Number of consecutive batches averaged into the Running* scalars.
    LOG_INTERVAL = 10

    def __init__(self, model, loaders, criterion, optim, scheduler=None, device='cuda'):
        # Batches are moved to `device` in run_epoch, so the model has to be
        # there too; `nn.Module.to` moves the parameters in place.
        self.model = model.to(device)
        self.loaders = loaders
        self.criterion = criterion
        self.optim = optim
        self.scheduler = scheduler
        self.n_logger = None  # neptune logger
        self.t_logger = None  # TensorBoard logger, created in at_exp_start
        self.device = device

    def run_exp(self, epoch_start, epoch_end, exp_name, save_interval, random_seed=42):
        """Run train + test passes for every epoch in ``[epoch_start, epoch_end)``.

        Args:
            epoch_start: first epoch index (inclusive).
            epoch_end: last epoch index (exclusive).
            exp_name: experiment name, used as a sub-directory of ``exps/``.
            save_interval: a checkpoint is written every ``save_interval``
                training batches.
            random_seed: seed passed to :meth:`manual_seed`.
        """
        save_path = self.at_exp_start(exp_name, random_seed)
        for epoch in tqdm(range(epoch_start, epoch_end)):
            self.model.train()
            self.run_epoch(epoch, save_path, save_interval, phase='train')
            self.model.eval()
            with torch.no_grad():
                self.run_epoch(epoch, save_path, save_interval, phase='test')
            if self.scheduler is not None:
                self.scheduler.step()

    def at_exp_start(self, exp_name, random_seed):
        """Seed the RNGs, create the experiment directories and the logger.

        The layout is ``<cwd>/exps/<exp_name>/<timestamp>/{checkpoints,tensorboard}``.

        Returns:
            Path of the ``checkpoints`` directory.
        """
        self.manual_seed(random_seed)
        date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        base_path = os.path.join(os.getcwd(), f'exps/{exp_name}/{date}')
        save_path = f'{base_path}/checkpoints'
        # exist_ok: two experiments started within the same second share a
        # timestamp and must not crash on the second makedirs call.
        os.makedirs(save_path, exist_ok=True)
        self.t_logger = TensorboardPyTorch(f'{base_path}/tensorboard', self.device)
        return save_path

    def run_epoch(self, epoch, save_path, save_interval, phase):
        """Iterate once over ``self.loaders[phase]``, logging loss and accuracy.

        Parameters are only updated when ``'train'`` is contained in ``phase``.
        Gradient tracking for evaluation is the caller's responsibility
        (``run_exp`` wraps the test pass in ``torch.no_grad()``).

        Args:
            epoch: current epoch index, used for the global step and file names.
            save_path: directory that receives the checkpoint files.
            save_interval: write a checkpoint every ``save_interval`` training
                batches.
            phase: key into ``self.loaders`` and suffix of the logged tags.
        """
        running_acc = 0.0
        running_loss = 0.0
        loader = self.loaders[phase]
        loader_size = len(loader)
        is_train = 'train' in phase
        for i, (x_true, y_true) in enumerate(loader):
            x_true, y_true = x_true.to(self.device), y_true.to(self.device)
            y_pred = self.model(x_true)
            loss = self.criterion(y_pred, y_true)
            if is_train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

            # Global step across epochs, 1-based.
            step = i + 1 + epoch * loader_size
            # The comparison yields a bool tensor; mean() needs a float dtype.
            predicted = torch.argmax(y_pred.detach(), dim=1)
            batch_acc = (predicted == y_true).float().mean().item()
            batch_loss = loss.item()
            self.t_logger.log_scalar(f'StepAcc/{phase}', round(batch_acc, 4), step)
            self.t_logger.log_scalar(f'StepLoss/{phase}', round(batch_loss, 4), step)
            running_acc += batch_acc
            running_loss += batch_loss
            # loggers
            if (i + 1) % self.LOG_INTERVAL == 0:
                tmp_acc = running_acc / self.LOG_INTERVAL
                tmp_loss = running_loss / self.LOG_INTERVAL
                self.t_logger.log_scalar(f'RunningAcc/{phase}', round(tmp_acc, 4), step)
                self.t_logger.log_scalar(f'RunningLoss/{phase}', round(tmp_loss, 4), step)
                running_acc = 0.0
                running_loss = 0.0

            # Weights do not change while evaluating, so checkpoints are only
            # written during training; a test-phase save would also overwrite
            # the training checkpoint that has the same epoch/step name.
            if is_train and (i + 1) % save_interval == 0:
                self.save_net(f'{save_path}/epoch_{epoch}_step_{i + 1}.pth')

    def save_net(self, path):
        """Serialize the model's ``state_dict`` to the file ``path``."""
        torch.save(self.model.state_dict(), path)

    def manual_seed(self, random_seed):
        """Seed the torch (CPU and CUDA) and NumPy random number generators."""
        import numpy as np
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed(random_seed)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False
        np.random.seed(random_seed)

In [None]:
class Resnet50_ft(nn.Module):
    """Wrap a pretrained backbone and replace its ``fc`` head for fine-tuning.

    Args:
        ptr_model: backbone module whose ``fc`` attribute is replaced in place.
        num_classes: number of outputs of the new head (defaults to 28).
    """

    def __init__(self, ptr_model, num_classes=28):
        super().__init__()
        self.ptr_model = ptr_model
        # Size the new head from the layer it replaces; fall back to 2048 when
        # the backbone has no `fc.in_features` to read.
        in_features = getattr(getattr(ptr_model, 'fc', None), 'in_features', 2048)
        self.ptr_model.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return the backbone's output for ``x`` (produced by the new head)."""
        return self.ptr_model(x)
        

In [26]:
# create model
from torchvision import models

# Instantiate the architecture with pretrained=False, then restore the weights
# from the local checkpoint file.
ptr_model = models.resnet50(pretrained=False)
checkpoint_state = torch.load('resnet50-0676ba61.pth')
ptr_model.load_state_dict(checkpoint_state)

# Freeze every parameter whose name does not contain 'fc'.
for param_name, weight in ptr_model.named_parameters():
    if 'fc' in param_name:
        continue
    weight.requires_grad = False

In [29]:
# create loaders
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# NOTE(review): train_root, test_root, train_transform, test_transform and
# batch_size are not defined in any cell visible here — confirm they are set
# before this cell runs on a fresh kernel.
train_dataset = ImageFolder(train_root, train_transform)
test_dataset = ImageFolder(test_root, test_transform)

# Settings shared by both loaders.
loader_kwargs = dict(batch_size=batch_size, pin_memory=True, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
# Evaluation does not benefit from shuffling; a fixed order keeps the per-step
# test metrics comparable from one epoch to the next.
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

loaders = {
    'train': train_loader,
    'test': test_loader
}

In [None]:
# NOTE(review): `Resnet50_ft` is defined in this notebook but never applied, so
# the network trained here keeps the `fc` layer loaded from the checkpoint —
# confirm whether the model should be wrapped before the optimizer is built.
criterion = nn.CrossEntropyLoss()
# Trainer moves every batch to `device`, so the model has to live there too.
# Do this before the optimizer is created from the model's parameters.
ptr_model = ptr_model.to(device)
# Only parameters that were left trainable are handed to the optimizer.
trainable_params = [p for p in ptr_model.parameters() if p.requires_grad]
optim = torch.optim.AdamW(trainable_params, lr=1e-2)
trainer = Trainer(ptr_model, loaders, criterion, optim, device=device)

In [None]:
%tensorboard --logdir=data

In [None]:
# run experiment: epochs 0..4, checkpoint every 5 training batches
trainer.run_exp(
    epoch_start=0,
    epoch_end=5,
    exp_name='resnet50',
    save_interval=5,
)