In [9]:
from torchvision import datasets
from torchvision import transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.models import efficientnet_b0
from sklearn.metrics import accuracy_score
import torch
import os
from tqdm.notebook import tqdm
from torch import nn, optim 
import math
import imgaug.augmenters as iaa
from random import randint, sample

from PIL.Image import fromarray
import cv2
from scipy.spatial.distance import cosine
import pandas as pd
from sklearn.model_selection import train_test_split
from os.path import join

# Use the CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# Notebook working directory; the data and model paths in later cells are joined onto it.
pwd = os.getcwd()

In [10]:
class AdaCos(nn.Module):
    """AdaCos head: scaled cosine-similarity logits between embeddings and class weights.

    Args:
        feat_dim: embedding size expected by ``forward``.
        num_classes: number of classes (rows of the weight matrix ``W``).
        fixed_scale: when True, the scale is re-estimated from every *training* batch
            with the adaptive AdaCos rule. When False, it stays at its initial value
            ``sqrt(2) * ln(num_classes - 1)``. NOTE(review): the name reads inverted
            relative to this behaviour. It is kept unchanged because callers and
            pickled checkpoints already rely on it.
    """

    def __init__(self, feat_dim, num_classes, fixed_scale=False):
        super(AdaCos, self).__init__()
        self.fixed_scale = fixed_scale
        # Initial scale; it becomes a 0-d tensor after the first adaptive update.
        self.scale = math.sqrt(2) * math.log(num_classes - 1)
        self.W = nn.Parameter(torch.FloatTensor(num_classes, feat_dim))
        nn.init.xavier_uniform_(self.W)

    def forward(self, feats, labels):
        """Return ``scale * (feats @ normalize(W).T)`` logits of shape (batch, num_classes).

        ``feats`` is not normalised here. Presumably callers pass L2-normalised
        embeddings (``Net.forward`` does), so the logits are cosines; TODO confirm for
        other callers. ``labels`` holds one class index per row and is only used to
        re-estimate the scale.
        """
        W = F.normalize(self.W)
        logits = F.linear(feats, W)

        # Only adapt while training: an evaluation pass must not change model state.
        # Previously, validation batches overwrote the scale used by the next epoch.
        if self.fixed_scale and self.training:
            with torch.no_grad():
                theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
                one_hot = torch.zeros_like(logits)
                one_hot.scatter_(1, labels.view(-1, 1).long(), 1)

                # Batch-mean of the summed exp(scale * logit) over non-target classes.
                B_avg = torch.where(one_hot < 1, torch.exp(self.scale * logits), torch.zeros_like(logits))
                B_avg = torch.sum(B_avg) / feats.size(0)

                # Median angle to the target class, capped at pi/4 in the denominator.
                theta_med = torch.median(theta[one_hot == 1])
                self.scale = torch.log(B_avg) / torch.cos(torch.min(math.pi/4 * torch.ones_like(theta_med), theta_med))

        return self.scale * logits

In [11]:
class Net(nn.Module):
    """EfficientNet-B0 embedding network with an AdaCos classification head.

    ``forward(x)`` returns non-negative, L2-normalised 384-d embeddings.
    ``forward(x, targets)`` returns the AdaCos logits computed from those embeddings.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Registration order (backbone, bn1, fc1, arc_face) is deliberate. It fixes the
        # parameter order and the order in which random initialisation happens.
        self.backbone = efficientnet_b0(pretrained=True)  # 1000-way pretrained classifier output
        self.bn1 = nn.BatchNorm1d(1000)
        self.fc1 = nn.Linear(1000, 384)
        self.arc_face = AdaCos(384, num_classes, fixed_scale=True)

    def forward(self, x, targets=None):
        embedding = F.normalize(F.relu(self.fc1(self.bn1(self.backbone(x)))))
        if targets is None:
            return embedding
        return self.arc_face(embedding, targets)

# Square network input resolution used by the resize augmenter.
input_size = (224, 224)

In [29]:
class Trainer():
    """Epoch-based train/validation driver with tqdm progress bars and optional file logging.

    Args:
        criterion: loss callable invoked as ``criterion(logits, labels)``.
        optimizer: torch optimizer over the model parameters. Its ``param_groups``
            learning rates are overwritten when :meth:`run` receives a ``schedule``.
        device: device each batch is moved to. When ``None``, the module-level
            ``device`` global is used instead (the previous behaviour).
        start_epoch: index of the first epoch executed by :meth:`run`. It is used in
            progress labels and per-epoch directory names, e.g. when resuming.
    """

    def __init__(self, criterion = None, optimizer = None, device = None, start_epoch=0):
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.start_epoch = start_epoch

    def _to_device(self, tensor):
        """Return ``tensor`` moved to ``self.device`` (or to the global ``device`` if unset)."""
        # The constructor's ``device`` argument used to be silently ignored.
        target = self.device if self.device is not None else device
        return tensor.to(target)

    @staticmethod
    def _append_log(log_file, it, batch_num, epoch_loss, epoch_acc):
        """Append the running-average loss/accuracy after batch ``it`` to ``log_file``."""
        with open(log_file, 'a') as f:
            f.write(f'B# {it+1}/{batch_num}, Loss: {round(float(epoch_loss/(it+1)), 4)}, Acc: {round(float(epoch_acc/(it+1)), 4)} \n')

    def accuracy(self, logits, labels):
        """Return the fraction of rows whose arg-max logit equals the label."""
        preds = torch.argmax(logits, dim=1).detach().cpu().numpy()
        # sklearn's signature is accuracy_score(y_true, y_pred). The metric is
        # symmetric, but keep the documented argument order.
        return accuracy_score(labels.detach().cpu().numpy(), preds)

    def train_batch_loop(self, model, train_loader, i, save_path=None):
        """Run one training epoch; return ``(mean batch loss, mean batch accuracy)``.

        ``i`` is the epoch index (shown 1-based in the progress bar). When ``save_path``
        is given, it is used as a string prefix, so it should end with a separator.
        Running averages are appended to ``save_path + 'train_log.txt'`` every 200
        batches. The whole model is periodically pickled to ``save_path + 'model_.pth'``.
        """
        epoch_loss = 0.0
        epoch_acc = 0.0
        pbar_train = tqdm(train_loader, desc="Epoch" + " [TRAIN] " + str(i+1))
        batch_num = len(pbar_train)
        for it, (images, labels) in enumerate(pbar_train):
            images = self._to_device(images)
            labels = self._to_device(labels)

            logits = model(images, labels)
            loss = self.criterion(logits, labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += self.accuracy(logits, labels)

            postfix = {'loss' : round(float(epoch_loss/(it+1)), 4), 'acc' : float(epoch_acc/(it+1))}
            pbar_train.set_postfix(postfix)

            if save_path is not None:
                if it % 200 == 199:
                    self._append_log(save_path + 'train_log.txt', it, batch_num, epoch_loss, epoch_acc)
                # NOTE(review): fires after batches 1000, 3000, 5000, ... The modulus and
                # remainder disagree (compare `% 200 == 199` above); confirm the intended period.
                if it % 2000 == 999:
                    torch.save(model, save_path + 'model_.pth')

        return epoch_loss / len(train_loader), epoch_acc / len(train_loader)

    def valid_batch_loop(self, model, valid_loader, i, save_path=None):
        """Evaluate one pass over ``valid_loader``; return ``(mean batch loss, mean batch accuracy)``.

        The caller sets the model mode (``run`` and ``run_eval`` call ``model.eval()``).
        When ``save_path`` is given (as a string prefix), running averages are appended
        to ``save_path + 'valid_log.txt'`` every 200 batches.
        """
        epoch_loss = 0.0
        epoch_acc = 0.0
        pbar_valid = tqdm(valid_loader, desc = "Epoch" + " [VALID] " + str(i+1))
        batch_num = len(pbar_valid)

        # No parameters are updated here, so skip building the autograd graph.
        with torch.no_grad():
            for it, (images, labels) in enumerate(pbar_valid):
                images = self._to_device(images)
                labels = self._to_device(labels)

                logits = model(images, labels)
                loss = self.criterion(logits, labels)

                epoch_loss += loss.item()
                epoch_acc += self.accuracy(logits, labels)

                postfix = {'loss' : round(float(epoch_loss/(it+1)), 4), 'acc' : float(epoch_acc/(it+1))}
                pbar_valid.set_postfix(postfix)

                if save_path is not None and it % 200 == 199:
                    self._append_log(save_path + 'valid_log.txt', it, batch_num, epoch_loss, epoch_acc)

        return epoch_loss / len(valid_loader), epoch_acc / len(valid_loader)

    def run(self, model, train_loader, valid_loader=None, schedule=None, epochs=1, save_path=None):
        """Train for ``epochs`` epochs, numbered from ``self.start_epoch``, and return ``model``.

        Args:
            model: module called as ``model(images, labels)``, returning logits.
            train_loader: iterable of ``(images, labels)`` batches.
            valid_loader: optional loader evaluated after every epoch.
            schedule: optional list of learning rates, one per epoch of this call.
                ``schedule[0]`` applies to epoch ``start_epoch``.
            epochs: number of epochs to run.
            save_path: optional directory. Each epoch gets an ``epoch_<i>/`` subfolder
                holding logs and a pickled ``model.pth``.

        Raises:
            ValueError: if ``schedule`` is given and its length differs from ``epochs``.
        """
        if schedule is not None and len(schedule) != epochs:
            raise ValueError('Schedule length must be equal to the number of epochs')

        # Test for None first: os.path.exists(None) raises TypeError.
        if save_path is not None:
            os.makedirs(save_path, exist_ok=True)

        # Previously only a single epoch was ever run, whatever `epochs` said.
        for offset in range(epochs):
            i = self.start_epoch + offset
            if save_path is not None:
                epoch_save_path = os.path.join(save_path, f'epoch_{i}/')
                os.makedirs(epoch_save_path, exist_ok=True)
            else:
                epoch_save_path = None

            if schedule is not None:
                # Index relative to this call, so resuming with start_epoch > 0 works.
                for g in self.optimizer.param_groups:
                    g['lr'] = schedule[offset]

            model.train()
            self.train_batch_loop(model, train_loader, i, save_path=epoch_save_path)

            if epoch_save_path is not None:
                torch.save(model, epoch_save_path + 'model.pth')

            if valid_loader is not None:
                model.eval()
                self.valid_batch_loop(model, valid_loader, i, save_path=epoch_save_path)

        return model

    def run_eval(self, model, data_lodaer):
        """Evaluate ``model`` once on ``data_lodaer``; return ``(mean loss, mean accuracy)``."""
        model.eval()
        avg_valid_loss, avg_valid_acc = self.valid_batch_loop(model, data_lodaer, 0)
        return avg_valid_loss, avg_valid_acc

In [30]:
class ImageDataset(Dataset):
    """Map-style dataset of ``(RGB image, target)`` pairs listed in a DataFrame.

    Args:
        csv: DataFrame with an ``image`` column (file names relative to ``img_folder``)
            and a ``Y`` column (targets). Rows are addressed by position, so frames with
            any index work, e.g. filtered frames or ones from ``train_test_split``.
        img_folder: directory containing the image files.
        transform: optional callable applied to the RGB image array read by OpenCV.
    """

    def __init__(self, csv, img_folder, transform=None):
        self.transform = transform
        self.img_folder = img_folder
        # Positional arrays: Series[index] is label-based and raises KeyError when the
        # frame's index is not 0..n-1.
        self.images = csv['image'].to_numpy()
        self.targets = csv['Y'].to_numpy()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        path = join(self.img_folder, self.images[index])
        bgr = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None instead of raising.
        if bgr is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        target = self.targets[index]

        if self.transform is not None:
            image = self.transform(image)

        return image, target

In [31]:
# Training metadata CSV and the image directory (folder name suggests pre-resized 256x256 images).
csv_path, img_data = (
    join(pwd, 'csv/train.csv'),
    join(pwd, '../train_images-256-256'),
)

In [33]:
data_csv = pd.read_csv(csv_path)

transforms_list = T.Compose([
    # The imgaug stage runs on the numpy RGB image produced by ImageDataset.
    iaa.Sequential([
        iaa.Sometimes(0.15, iaa.AddToSaturation((-10, 10))),                    # mild saturation jitter
        iaa.Sometimes(0.15, iaa.Crop(percent=(0.02, 0.05), keep_size=True)),    # small crop, resized back
        iaa.size.Resize(input_size, interpolation='cubic'),                     # bicubic resize to network input
    ]).augment_image,
    T.ToTensor(),  # HWC uint8 array -> CHW float tensor in [0, 1]
])

train_dataset = ImageDataset(csv=data_csv, img_folder=img_data, transform=transforms_list)

In [34]:
# --- Training configuration ---------------------------------------------------
batch_size = 64
start_epoch = 3        # index of the first epoch of this (resumed) run
num_epochs = 3
lr = 1e-4              # Adam learning rate
# Per-epoch learning rates for Trainer.run(schedule=...).
# NOTE(review): not passed to trainer.run in the training cell, so `lr` is used unchanged.
schedule = [1e-3, 7.5e-4, 5e-4]
# NOTE(review): the class count comes from 'individual_id' while ImageDataset reads targets
# from 'Y'. This assumes 'Y' encodes individual_id as 0..num_classes-1; confirm.
num_classes = data_csv['individual_id'].nunique()
save_path = join(pwd, '../models/arcface_fixed_64_bn_384')

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [35]:
# Resume from a full pickled nn.Module checkpoint. The fresh Net previously built here was
# immediately overwritten (and downloaded pretrained weights for nothing). The Net/AdaCos
# classes above must still be defined so the pickle can be resolved.
checkpoint_path = '/content/models/arcface_fixed_48_bn_384_epoch_2/model.pth'
# map_location lets a GPU-saved checkpoint load on a CPU-only machine; .to(device) places it.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
# full-module pickles. Pass weights_only=False there, and only for trusted files.
model = torch.load(checkpoint_path, map_location=device).to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): optimizer state is not restored, so Adam's moment estimates restart from zero.
optimizer = optim.Adam(model.parameters(), lr=lr)

trainer = Trainer(criterion=criterion,
                  optimizer=optimizer,
                  device=device,
                  start_epoch=start_epoch)

In [36]:
# run() returns the same model object; binding it avoids echoing the full module repr as cell output.
model = trainer.run(model, train_loader, epochs=num_epochs, save_path=save_path)

Epoch [TRAIN] 4:   0%|          | 0/798 [00:00<?, ?it/s]