In [None]:
import import_ipynb
import argparse, os
from glob import glob
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from model_load import *
from dataloader import *
from torchsummary import summary as summary_
import math

In [None]:
# Switch off Jedi-based tab completion for the IPython completer in this kernel.
%config Completer.use_jedi = False

In [None]:
# Build the CIFAR-10 train/validation loaders that are later handed to `train`.
# NOTE(review): the positional arguments look like (image size, batch size) — confirm
# against the `cifar10` helper (presumably provided by the `dataloader` import).
# The CIFAR-100 variant is kept below for reference but is disabled.
#cifar100_train, cifar100_valid = cifar100(224, 32, 4)
cifar10_train, cifar10_valid = cifar10(224, 128, workers = 4)

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The target class receives probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is split evenly over the other ``classes - 1`` classes.

    Args:
        classes: Number of classes (width of the prediction along ``dim``).
        smoothing: Probability mass taken away from the target class.
        dim: Axis over which the log-softmax and the per-sample sum are taken.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy of logits ``pred`` for class indices ``target``."""
        log_probs = F.log_softmax(pred, dim=self.dim)

        # The smoothed targets are constants: build them outside the autograd graph.
        with torch.no_grad():
            off_target = self.smoothing / (self.cls - 1)
            smoothed = torch.zeros_like(log_probs).fill_(off_target)
            # Overwrite the true-class column of every row with the confidence.
            smoothed.scatter_(1, target.data.unsqueeze(1), self.confidence)

        per_sample = (-smoothed * log_probs).sum(dim=self.dim)
        return per_sample.mean()

In [None]:
class ElasticLoss(nn.Module):
    """Angular-margin classification loss with a randomly sampled margin.

    Embeddings and the learnable class weights are L2-normalised, their cosine
    similarity is converted to an angle, a per-sample margin drawn from
    ``N(m, std)`` (clamped to ``[m - std, m + std]``) is added to the angle of
    the target class, and the rescaled cosines are fed to cross-entropy.

    Args:
        n_class: Number of classes (rows of the weight matrix).
        s: Scale applied to the cosine logits.
        m: Mean of the angular margin.
        std: Standard deviation of the angular margin.
        plus: If True, the sorted margins are re-indexed by the ranking of the
            target-class cosines instead of being assigned in sampling order.
        in_features: Width of the incoming embeddings. Defaults to 192, the
            value that used to be hard-coded.

    Labels equal to ``-1`` are treated as "ignore": they get no margin and are
    skipped by the cross-entropy.
    """

    def __init__(self, n_class=10, s=30.0, m=0.2, std=0.1, plus=False, in_features=192):
        super(ElasticLoss, self).__init__()
        self.in_features = in_features
        self.n_class = n_class
        self.s = s
        self.m = m
        self.weight = torch.nn.Parameter(torch.FloatTensor(n_class, in_features), requires_grad=True)
        nn.init.xavier_normal_(self.weight, gain=1)
        self.std = std
        self.plus = plus
        # forward() filters on `label != -1`, so the loss must skip those rows too;
        # the default ignore_index (-100) would raise on a -1 target.
        self.ce = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, embeddings, label):
        """Return the scalar loss for ``embeddings`` (N, in_features) and 1-D class ids ``label`` (N,)."""
        cos_theta = F.linear(F.normalize(embeddings), F.normalize(self.weight))
        # Keep strictly inside (-1, 1): d/dx acos(x) is infinite at +/-1 and
        # would poison the gradients with inf/NaN.
        eps = 1e-7
        cos_theta = cos_theta.clamp(-1 + eps, 1 - eps)

        # Rows that carry a real label; their targets are reused below.
        index = torch.where(label != -1)[0]
        target = label[index]

        m_hot = torch.zeros(index.size()[0], cos_theta.size()[1], device=cos_theta.device)
        margin = torch.normal(
            mean=self.m, std=self.std, size=(index.size()[0], 1), device=cos_theta.device
        ).clamp(self.m - self.std, self.m + self.std)  # Fast converge

        if self.plus:
            with torch.no_grad():
                # Fixed: index with the filtered targets so the row and column
                # indices always have the same length, even with -1 labels.
                distmat = cos_theta[index, target].detach().clone()
                _, cos_rank = torch.sort(distmat, dim=0, descending=True)
                margin, _ = torch.sort(margin, dim=0)
            m_hot.scatter_(1, target[:, None], margin[cos_rank])
        else:
            m_hot.scatter_(1, target[:, None], margin)

        # cos -> angle, add the margin on the target class, angle -> scaled cos.
        cos_theta.acos_()
        cos_theta[index] += m_hot
        cos_theta.cos_().mul_(self.s)

        loss = self.ce(cos_theta, label)

        return loss

In [None]:
def train(model_name, save_name, train_loader, valid_loader, epochs, loss_n, lr, n_class=10):
    """Train a classifier, validate after every epoch and checkpoint the best weights.

    Args:
        model_name: The ``nn.Module`` to train (the parameter name is historical;
            the notebook passes the model object here). If anything other than
            an ``nn.Module`` is given, the notebook-global ``model`` is used,
            which is what this function did unconditionally before.
        save_name: Checkpoint stem; weights are written to ``./<save_name>.pth``.
        train_loader: Iterable of ``(img, label)`` training batches.
        valid_loader: Iterable of ``(img, label)`` validation batches; must
            expose ``.dataset`` (used as the accuracy denominator).
        epochs: Number of epochs to run.
        loss_n: ``'ce'`` for cross-entropy, ``'ls'`` for label smoothing (0.1),
            anything else for ``ElasticLoss``.
        lr: Initial Adam learning rate.
        n_class: Class count handed to the label-smoothing / elastic losses.

    Returns:
        ``(tt, vv)``: per-epoch mean training loss and mean validation loss,
        both as lists of Python floats.

    For ``'ce'``/``'ls'`` the checkpoint with the best validation accuracy is
    kept; otherwise the one with the lowest validation loss.
    """
    # Fixed: `is_available` must be *called*; the bare function object is always truthy.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if loss_n == 'ce':
        criterion = nn.CrossEntropyLoss().to(device)
    elif loss_n == 'ls':
        criterion = LabelSmoothingLoss(n_class, smoothing=0.1).to(device)
    else:
        # NOTE(review): ElasticLoss receives the raw model output, so the model's
        # output width must match the loss's weight width — confirm for the model used.
        criterion = ElasticLoss(n_class).to(device)

    # Use the module that was actually passed in rather than hidden notebook state.
    if isinstance(model_name, nn.Module):
        net = model_name
    else:
        net = model  # legacy fallback: notebook-global model
    net.to(device)

    # A criterion may own learnable parameters (ElasticLoss.weight); optimise them too.
    trainable = list(net.parameters()) + list(criterion.parameters())
    optimizer = torch.optim.Adam(trainable, lr=lr)
    # `verbose` is deprecated in recent PyTorch; the LR is printed every epoch below anyway.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
    best_loss = float('inf')
    best_acc = 0.0
    tt = []
    vv = []
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        p_lr = optimizer.param_groups[0]['lr']
        for img, label in tqdm(train_loader):
            img, label = img.to(device), label.to(device)
            optimizer.zero_grad()

            y_pred = net(img)
            loss = criterion(y_pred, label)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'epoch:{epoch+1}, lr:{p_lr:.6f},Train loss:{running_loss/len(train_loader): .4f}')

        net.eval()
        valid_loss = 0.0
        correct = 0

        with torch.no_grad():
            for img, label in valid_loader:
                img, label = img.to(device), label.to(device)
                y_pred = net(img)
                # .item() keeps the history as plain floats instead of device tensors.
                valid_loss += criterion(y_pred, label).item()
                pred = y_pred.argmax(dim = 1, keepdim = True)
                correct += pred.eq(label.view_as(pred)).sum().item()
        valid_acc = 100*correct/len(valid_loader.dataset)
        v_loss = valid_loss/len(valid_loader)
        print(f'valid_loss : {v_loss:.4f}, ACC : {valid_acc:.4f}')

        # Step the plateau scheduler on the loss just measured. Stepping before the
        # first validation (with a dummy 0) would set an unbeatable baseline.
        scheduler.step(v_loss)

        if loss_n in ['ce', 'ls']:
            if best_acc < valid_acc:
                best_acc = valid_acc
                torch.save(net.state_dict(), os.path.join('.', save_name+'.pth'))
                print('Model saved')
        else:
            if best_loss > v_loss:
                best_loss = v_loss
                torch.save(net.state_dict(), os.path.join('.', save_name+'.pth'))
                print('Model saved')

        tt.append(running_loss/len(train_loader))
        vv.append(v_loss)
    return tt, vv

In [None]:
# Instantiate the network that the following cells summarise and train.
# NOTE(review): the arguments are presumably (architecture name, pretrained flag,
# number of classes) — confirm against `model_load`.
model = model_load('efficientnet_b0', False, 10)

In [None]:
# Print a layer-by-layer summary for a 3x224x224 input.
# Use the GPU only when one is present: torchsummary defaults to device="cuda",
# and an unconditional `model.cuda()` crashes on CPU-only machines.
_summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary_(model.to(_summary_device), (3, 224, 224), device=_summary_device)

In [None]:
# Train for 50 epochs with cross-entropy ('ce') at an initial learning rate of 1e-3;
# the best weights are written to ./efficientnet_b0.pth.
# Note: the model object itself is passed as `train`'s first argument (`model_name`).
# The two returned lists hold the per-epoch training and validation losses.
train_loss_ce, valid_loss_ce = train(model, 'efficientnet_b0', cifar10_train, cifar10_valid, 50, 'ce', 1e-3)