In [1]:
from torchvision import datasets
from torchvision import transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.models import efficientnet_b0
from sklearn.metrics import accuracy_score
import torch
import os
from tqdm.notebook import tqdm
from torch import nn, optim 
import math
import imgaug.augmenters as iaa
from random import randint, sample

from PIL.Image import fromarray
import cv2
from scipy.spatial.distance import cosine
import pandas as pd
from sklearn.model_selection import train_test_split
from os.path import join
from torch import nn
import numpy as np
import json
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

# Acquire an XLA device handle at notebook level.
# NOTE(review): train_mnist() obtains its own device handle and no visible cell reads this
# module-level variable — confirm it is still needed.
device = xm.xla_device()



In [2]:
class AdaCos(nn.Module):
    """Cosine-similarity classification head with an optionally adaptive logit scale.

    The class weights are L2-normalised before the linear projection, so the
    raw logits are cosine similarities as long as the caller passes
    L2-normalised features (``Net`` normalises them before calling this head).
    The logits are multiplied by ``self.scale``, which starts at
    ``sqrt(2) * log(num_classes - 1)``.

    Args:
        feat_dim: Size of the incoming feature vectors.
        num_classes: Number of target classes (rows of the weight matrix).
        fixed_scale: When True the scale stays at its initial value. When
            False (default) the scale is re-estimated on every ``forward``
            call from the current batch statistics.

    Raises:
        ValueError: From ``math.log`` when ``num_classes <= 1``.
    """

    def __init__(self, feat_dim, num_classes, fixed_scale=False):
        super(AdaCos, self).__init__()
        self.fixed_scale = fixed_scale
        self.scale = math.sqrt(2) * math.log(num_classes - 1)
        self.W = nn.Parameter(torch.empty(num_classes, feat_dim, dtype=torch.float32))
        nn.init.xavier_uniform_(self.W)

    def forward(self, feats, labels):
        """Return scaled cosine logits, updating the scale first unless it is fixed.

        Args:
            feats: Feature batch of shape (batch, feat_dim).
            labels: Integer class index per sample; reshaped to (batch, 1) internally.

        Returns:
            Tensor of shape (batch, num_classes) equal to ``self.scale * cosine``.
        """
        logits = self.get_logits(feats)

        # The adaptive update must run when the scale is NOT fixed; the flag
        # name states the opposite of the branch it used to guard.
        if not self.fixed_scale:
            with torch.no_grad():
                target_idx = labels.view(-1, 1).long()
                theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
                one_hot = torch.zeros_like(logits)
                one_hot.scatter_(1, target_idx, 1)

                # Average (per sample) of the exponentiated non-target logits.
                B_avg = torch.where(one_hot < 1, torch.exp(self.scale * logits), torch.zeros_like(logits))
                B_avg = torch.sum(B_avg) / feats.size(0)

                # Median angle to the target class. gather() selects exactly the
                # entries the one-hot mask marks, but with a static output shape.
                theta_med = torch.median(theta.gather(1, target_idx))
                self.scale = torch.log(B_avg) / torch.cos(torch.clamp(theta_med, max=math.pi / 4))

        return self.scale * logits

    def get_logits(self, feats):
        """Return the unscaled logits: ``feats`` projected onto the normalised class weights."""
        W = F.normalize(self.W)
        return F.linear(feats, W)


In [3]:
class GeM(nn.Module):
    """Generalised-mean pooling over the last two dimensions with a learnable exponent.

    Args:
        p: Initial value of the exponent, stored as a one-element parameter.
        eps: Lower clamp applied to the input before it is raised to ``p``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` using the module's own exponent and epsilon."""
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Clamp, raise to ``p``, average over the full spatial window, take the ``1/p`` root."""
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        pooled = F.avg_pool2d(powered, window)
        return pooled.pow(1. / p)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return '{}(p={:.4f}, eps={})'.format(self.__class__.__name__, exponent, str(self.eps))

In [11]:
class Net(nn.Module):
    """EfficientNet-B0 embedding network with a GeM pool and an AdaCos head.

    Pipeline: backbone -> GeM pooling -> BatchNorm1d(1280) -> Dropout(0.2)
    -> Linear(1280, 512) -> ReLU -> L2 normalisation -> optional AdaCos head.

    Args:
        num_classes: Number of classes for the AdaCos head.
    """

    def __init__(self, num_classes):
        super(Net, self).__init__()

        # All children of the pretrained model except the last two are kept.
        # NOTE(review): presumably the dropped children are torchvision's
        # average pool and classifier — confirm for the installed version.
        self.backbone = torch.nn.Sequential(*(list(efficientnet_b0(pretrained=True).children())[:-2]))
        self.gem_pool = GeM()
        self.bn1 = nn.BatchNorm1d(1280)
        self.fc1 = nn.Linear(1280, 512)
        self.dropout = nn.Dropout(0.2)

        self.arc_face = AdaCos(512, num_classes)

    def _embed(self, x):
        """Map an image batch to L2-normalised 512-d embeddings of shape (batch, 512)."""
        # flatten(1) removes the pooled 1x1 spatial dims while always keeping
        # the batch dimension, so batches of any size (including 1) stay 2-D.
        x = self.gem_pool(self.backbone(x)).flatten(1)
        x = F.relu(self.fc1(self.dropout(self.bn1(x))))
        return F.normalize(x)

    def forward(self, x, targets = None):
        """Return AdaCos logits when ``targets`` is given, otherwise the embeddings."""
        x = self._embed(x)

        if targets is not None:
            return self.arc_face(x, targets)

        return x

    def get_logits(self, x):
        """Return the unscaled AdaCos logits for an image batch of any size."""
        return self.arc_face.get_logits(self._embed(x))
        
# Size every image is resized to at the end of the augmentation pipeline;
# the same tuple is copied into FLAGS['input_size'].
input_size = (224, 224)

In [None]:
# Absolute locations of the training metadata CSV and the image folder.
# The same two paths are repeated in FLAGS['data_csv_path'] / FLAGS['images_path'].
csv_path = join('/content/happywhale/data/train.csv')
img_data = join('/content/train_images-256-256')
# NOTE(review): ImageDataset reads the 'image' and 'Y' columns from this frame —
# confirm the CSV on disk already contains the 'Y' label column.
data_csv = pd.read_csv(csv_path)

In [None]:
class ImageDataset(Dataset):
  """Dataset yielding ``(image, target)`` pairs described by a DataFrame.

  Args:
    csv: DataFrame with an 'image' column (file names relative to
      ``img_folder``) and a 'Y' column (targets).
    img_folder: Directory that contains the image files.
    transform: Optional callable applied to the RGB image array.
  """

  def __init__(self, csv, img_folder, transform=None):
    self.transform = transform
    self.img_folder = img_folder

    # Drop the frame's own index so that position i is always row i, even
    # when `csv` is a filtered or shuffled subset of a larger frame.
    self.images = csv['image'].reset_index(drop=True)
    self.targets = csv['Y'].reset_index(drop=True)

  def __len__(self):
    return len(self.images)

  def __getitem__(self, index):
    """Load the image at position ``index`` as RGB and return it with its target.

    Raises:
      FileNotFoundError: If OpenCV cannot read the image file.
    """
    path = join(self.img_folder, self.images.iloc[index])
    bgr = cv2.imread(path)
    if bgr is None:
      # cv2.imread signals failure by returning None instead of raising.
      raise FileNotFoundError('Could not read image: {}'.format(path))
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    target = self.targets.iloc[index]

    if self.transform is not None:
      image = self.transform(image)

    return image, target


# Training-time preprocessing: imgaug augmentation followed by tensor conversion.
transforms_list = T.Compose([             
    # Outer imgaug pipeline: random augmentations first, then a resize.
    iaa.Sequential([
        # Inner group of augmenters, applied with random_order=True.
        iaa.Sequential([
        iaa.Sometimes(0.3, iaa.AverageBlur(k=(3,3))),
        iaa.Sometimes(0.3, iaa.MotionBlur(k=(3,5))),
        iaa.Add((-10, 10), per_channel=0.5),
        iaa.Multiply((0.9, 1.1), per_channel=0.5),
        iaa.Sometimes(0.3, iaa.Affine(
            scale={'x': (0.9,1.1), 'y': (0.9,1.1)},
            translate_percent={'x': (-0.05,0.05), 'y': (-0.05,0.05)},
            shear=(-10,10),
            rotate=(-10,10)
            )),
        iaa.Sometimes(0.3, iaa.Grayscale(alpha=(0.8,1.0))),
        ], random_order=True),
        # Resize last so every sample ends up at input_size.
        iaa.size.Resize(input_size, interpolation='cubic')
    # The bound augment_image method is what T.Compose calls on each image.
    ]).augment_image,     
    T.ToTensor()
])

# NOTE(review): train_mnist() builds its own dataset and transforms from FLAGS;
# no visible cell uses this module-level dataset.
train_dataset = ImageDataset(data_csv,
                             img_data,
                             transform=transforms_list)

In [35]:
# Count distinct individual ids; the displayed value (15587) is what the notebook uses as num_classes.
data_csv['individual_id'].nunique()

15587

In [6]:
# Notebook-level hyperparameters.
# NOTE(review): batch_size, num_epochs, lr, num_workers and num_cores repeat the values
# stored in FLAGS — keep the two in sync or derive one from the other.
batch_size = 32
start_epoch = 0
num_epochs = 10
lr = 0.0001
# schedule = [0.001, 0.00075, 0.0005]
# Matches the number of distinct individual_id values shown earlier in the notebook.
num_classes = 15587
# save_path = join(pwd, '../models/renet_50')
# The lr_* values are not referenced by any visible cell.
# NOTE(review): presumably meant for a ramp-up / decay learning-rate schedule — TODO confirm.
lr_start   = 0.000001
lr_max     = 0.000005 * batch_size
lr_min     = 0.000001
lr_ramp_ep = 4
lr_sus_ep  = 0
lr_decay   = 0.9

num_workers = 1
num_cores = 8

In [7]:
# Training configuration; the dict is handed to each spawned process through xmp.spawn.
FLAGS = {
    'batch_size': 32,
    'num_workers': 1,
    'learning_rate': 0.0001,
    'num_epochs': 10,
    'num_cores': 8,
    'log_steps': 10,
    'data_csv_path': '/content/happywhale/data/train.csv',
    'images_path': '/content/train_images-256-256',
    'input_size': input_size,
}

In [8]:
import time

In [12]:
# Serial executor; train_mnist() runs its dataset construction through it.
SERIAL_EXEC = xmp.MpSerialExecutor()



# Only instantiate model weights once in memory.
# NOTE(review): 15587 repeats the `num_classes` constant from the hyperparameter cell.
WRAPPED_MODEL = xmp.MpModelWrapper(Net(num_classes=15587))

def train_mnist():
  """Train the wrapped model on this process's XLA device.

  Reads every setting from the module-level ``FLAGS`` dict: builds the
  augmented dataset, shards it with a ``DistributedSampler``, and runs
  ``FLAGS['num_epochs']`` epochs of Adam + cross-entropy training.

  Returns:
    Tuple ``(accuracy, data, pred, target)``. No evaluation loop is
    implemented, so these are the placeholders ``(0.0, None, None, None)``.
  """
  torch.manual_seed(1)

  def get_dataset():
    """Read the training CSV and build the augmented image dataset."""
    data_csv = pd.read_csv(FLAGS['data_csv_path'])

    class ImageDataset(Dataset):
      """Yields (image, target) pairs from the 'image' and 'Y' columns of ``csv``."""

      def __init__(self, csv, img_folder, transform=None):
        self.transform = transform
        self.img_folder = img_folder

        self.images = csv['image']
        self.targets = csv['Y']

      def __len__(self):
        return len(self.images)

      def __getitem__(self, index):
        path = join(self.img_folder, self.images[index])
        bgr = cv2.imread(path)
        if bgr is None:
          # cv2.imread signals failure by returning None instead of raising.
          raise FileNotFoundError('Could not read image: {}'.format(path))
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        target = self.targets[index]

        if self.transform is not None:
          image = self.transform(image)

        return image, target

    # imgaug augmentations (inner group in random order), then a resize to
    # FLAGS['input_size'], then conversion to a tensor.
    transforms_list = T.Compose([
        iaa.Sequential([
            iaa.Sequential([
            iaa.Sometimes(0.3, iaa.AverageBlur(k=(3,3))),
            iaa.Sometimes(0.3, iaa.MotionBlur(k=(3,5))),
            iaa.Add((-10, 10), per_channel=0.5),
            iaa.Multiply((0.9, 1.1), per_channel=0.5),
            iaa.Sometimes(0.3, iaa.Affine(
                scale={'x': (0.9,1.1), 'y': (0.9,1.1)},
                translate_percent={'x': (-0.05,0.05), 'y': (-0.05,0.05)},
                shear=(-10,10),
                rotate=(-10,10)
                )),
            iaa.Sometimes(0.3, iaa.Grayscale(alpha=(0.8,1.0))),
            ], random_order=True),
            iaa.size.Resize(FLAGS['input_size'], interpolation='cubic')
        ]).augment_image,
        T.ToTensor()
    ])

    return ImageDataset(data_csv,
                        FLAGS['images_path'],
                        transform=transforms_list)

  # Build the dataset through the serial executor so processes do it one at a time.
  train_dataset = SERIAL_EXEC.run(get_dataset)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  # Scale learning rate to world size
  lr = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.CrossEntropyLoss()

  def train_loop_fn(loader, epoch):
    """Run one training epoch over ``loader`` (the per-device parallel loader)."""
    tracker = xm.RateTracker()
    model.train()
    # Iterate the loader that was passed in: the ParallelLoader uploads batches
    # to the device in the background, so no manual .to(device) is needed.
    # `total` is given explicitly so the bar does not depend on the wrapper's len().
    pbar_train = tqdm(loader, total=len(train_loader),
                      desc="Epoch" + " [TRAIN] " + str(epoch))
    for it, (images, labels) in enumerate(pbar_train):
      optimizer.zero_grad()
      logits = model(images, labels)
      loss = criterion(logits, labels)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])
      if (it + 1) % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), it, loss.item(), tracker.rate(),
            tracker.global_rate(), time.asctime()), flush=True)

  # TODO: no evaluation loop is implemented; the values below are returned
  # unchanged as placeholders so the caller's 4-tuple unpacking keeps working.
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    # Re-seed the sampler so each epoch sees a different shuffle.
    train_sampler.set_epoch(epoch)
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device), epoch)
    xm.master_print("Finished training epoch {}".format(epoch))

  return accuracy, data, pred, target

In [13]:
# Start training processes
def _mp_fn(rank, flags):
  """Per-process entry point handed to xmp.spawn.

  Args:
    rank: First positional argument supplied by the spawner; not used in the body.
    flags: Configuration dict that replaces the module-level FLAGS in this process.
  """
  global FLAGS
  FLAGS = flags
  # Set the default tensor type before the training function creates any tensors.
  torch.set_default_tensor_type('torch.FloatTensor')
  # train_mnist() has no evaluation loop, so these four values are placeholders.
  accuracy, data, pred, target = train_mnist()
  # if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    # plot_results(data.cpu(), pred.cpu(), target.cpu())

# Launch training with the 'fork' start method, passing FLAGS to each process.
# NOTE(review): nprocs=1 starts a single process although FLAGS['num_cores'] is 8 —
# confirm whether single-core training is intended.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=1,
          start_method='fork')

Epoch [TRAIN] 1:   0%|          | 0/1594 [00:00<?, ?it/s]