In [1]:
import random
import torchvision
import torch
import matplotlib.pyplot as plt
import numpy as np
import time
from torch.nn import ReLU, Conv2d, BatchNorm2d, Sequential, AdaptiveAvgPool2d, Linear, MaxPool2d, Flatten, CrossEntropyLoss
try:
    import pytorch_lightning as pl
except:
    !pip install pytorch-lightning
    import pytorch_lightning as pl
try:
    import adabelief_pytorch
except:
    !pip install adabelief_pytorch==0.2.0
    time.sleep(1)
    import adabelief_pytorch
try:
    from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
except:
    !pip install pytorch_metric_learning
    time.sleep(1)
    from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

In [2]:
import ssl
# SECURITY NOTE: this replaces the default HTTPS context factory with the
# unverified one, i.e. TLS certificate verification is disabled for every
# HTTPS request made through the default context in this process.
# NOTE(review): presumably a workaround for certificate errors while
# downloading the dataset / pretrained weights — confirm it is still needed.
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
!nvidia-smi

Tue Feb 21 16:58:27 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 517.48       Driver Version: 517.48       CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
|  0%   34C    P2    25W / 200W |    622MiB /  3072MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
class Embedder(torch.nn.Module):
    """ResNet-18 backbone whose final fully connected layer is replaced by a
    linear projection to ``embedding_dim`` outputs.

    Args:
        embedding_dim: size of the produced embedding vectors (default 128).
    """

    def __init__(self, embedding_dim = 128):
        super(Embedder, self).__init__()
        self.__embedding_dim = embedding_dim
        self.__build_resnet()

    def __build_resnet(self):
        """Create the pretrained ResNet-18 and swap its ``fc`` head."""
        self.__inner = torchvision.models.resnet18(weights = torchvision.models.ResNet18_Weights.DEFAULT)
        # Read the head's input width from the layer being replaced instead of
        # hard-coding 512, so the projection always matches the backbone.
        self.__inner.fc = Linear(self.__inner.fc.in_features, self.__embedding_dim)

    @property
    def embedding_dim(self):
        """Dimensionality of the embeddings returned by ``forward``."""
        return self.__embedding_dim

    def forward(self, x):
        """Run the backbone on ``x`` and return the output of the replaced head."""
        return self.__inner(x)

In [5]:
class TripletLoss(torch.nn.Module):
    """Margin-based triplet loss on Euclidean distances.

    Computes ``mean(max(0, ||a - p|| - ||a - n|| + margin))`` over the batch.

    Args:
        margin: required gap between the anchor-negative and anchor-positive
            distances (default 1.0).
    """

    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, batch):
        """Return the scalar triplet loss for ``batch``.

        ``batch`` is indexed as ``batch[i][0]`` for i in 0..2, giving the
        anchor, positive and negative tensors respectively; the norm is taken
        over ``dim=1``.
        """
        anchor, positive, negative = batch[0][0], batch[1][0], batch[2][0]
        # `torch.linalg.nor` does not exist; the L2 norm lives in `torch.linalg.norm`.
        dist_ap = torch.linalg.norm(anchor - positive, dim=1)
        dist_an = torch.linalg.norm(anchor - negative, dim=1)
        # Hinge: only triplets violating the margin contribute to the loss.
        loss = torch.clamp(dist_ap - dist_an + self.margin, min=0)
        loss = torch.mean(loss)
        return loss

In [6]:
class SphericalClassifier(torch.nn.Linear):
    """Bias-free linear head operating on cosine similarities with a margin.

    Inputs and class weights are L2-normalized, so the linear product yields
    cosines. The cosine of the target class is then penalised according to
    ``margin_type`` and everything is multiplied by ``scale``:

    * ``'CosFace'``    - ``cos(theta) - margin``
    * ``'SphereFace'`` - ``cos(margin * theta)``
    * ``'ArcFace'``    - ``cos(theta + margin)``

    The returned logits are meant to be fed to a softmax cross-entropy loss.

    Args:
        in_features: embedding dimensionality.
        out_features: number of classes.
        scale: multiplier applied to the (modified) cosines.
        margin: margin value, interpreted according to ``margin_type``.
        margin_type: one of ``'CosFace'``, ``'SphereFace'``, ``'ArcFace'``.

    Raises:
        ValueError: if ``margin_type`` is not a known margin type.
    """

    def __CosFace(scale, margin):
        """Build the CosFace transform (additive cosine margin)."""
        def wrapped_loss(cosines, target):
            one_hot = torch.zeros_like(cosines)
            one_hot.scatter_(1, target.view(-1, 1).long(), 1)
            logits = scale * (cosines - margin * one_hot)
            return logits
        return wrapped_loss

    def __SphereFace(scale, margin):
        """Build the SphereFace transform (multiplicative angular margin)."""
        def wrapped_loss(cosines, target):
            one_hot = torch.zeros_like(cosines)
            one_hot.scatter_(1, target.view(-1, 1).long(), 1)

            # Cosine of the target class for every sample, shape (N, 1).
            req_cosines = one_hot * cosines
            req_cosines = torch.sum(req_cosines, dim=1, keepdim=True)
            # Keep arccos inside its domain: rounding can push cosines slightly
            # outside [-1, 1] (NaN) and its gradient is infinite at +/-1.
            req_cosines = torch.clamp(req_cosines, min=-1.0 + 1e-7, max=1.0 - 1e-7)
            req_cosines = torch.cos(margin * torch.arccos(req_cosines))
            logits = scale * cosines.scatter(dim=1,
                                             index=target.view(-1, 1).long(),
                                             src=req_cosines)
            return logits
        return wrapped_loss

    def __ArcFace(scale, margin):
        """Build the ArcFace transform (additive angular margin)."""
        def wrapped_loss(cosines, target):
            one_hot = torch.zeros_like(cosines)
            one_hot.scatter_(1, target.view(-1, 1).long(), 1)

            # Cosine of the target class for every sample, shape (N, 1).
            req_cosines = one_hot * cosines
            req_cosines = torch.sum(req_cosines, dim=1, keepdim=True)
            # Same numerical guard as in SphereFace before taking arccos.
            req_cosines = torch.clamp(req_cosines, min=-1.0 + 1e-7, max=1.0 - 1e-7)
            req_cosines = torch.cos(margin + torch.arccos(req_cosines))
            logits = scale * cosines.scatter(dim=1,
                                             index=target.view(-1, 1).long(),
                                             src=req_cosines)
            return logits
        return wrapped_loss

    # Registry of margin factories; each takes (scale, margin) and returns a
    # callable (cosines, target) -> logits.
    __margin_types = {'CosFace' : __CosFace,
                      'SphereFace' : __SphereFace,
                      'ArcFace' : __ArcFace}

    def __init__(self, in_features, out_features, scale = 64, margin=0.35, margin_type = 'CosFace'):
        super(SphericalClassifier, self).__init__(in_features, out_features, bias=False)
        self._scale = scale
        self._margin = margin
        self._margin_type = margin_type

        self._modified_softmax = self.__make_margin_loss()

    def __make_margin_loss(self):
        """Look up the configured margin type and build its logit transform."""
        if not self._margin_type in SphericalClassifier.__margin_types:
            raise ValueError('There is no such type - {}'.format(self._margin_type))
        else:
            return SphericalClassifier.__margin_types[self._margin_type](self._scale, self._margin)

    def forward(self, data, target):
        """Return margin-adjusted, scaled logits.

        Args:
            data: embeddings; normalized along ``dim=1`` before the product.
            target: integer class index per sample; it selects which cosine
                receives the margin.
        """
        cosines = torch.nn.functional.linear(torch.nn.functional.normalize(data),
                                  torch.nn.functional.normalize(self.weight))
        logits = self._modified_softmax(cosines, target)
        return logits

In [7]:
# Registry mapping the 'criterion.type' config string to a loss class.
CRITERIONS = {'CE' : torch.nn.CrossEntropyLoss,
            'TripletLoss' : TripletLoss}

# Registry mapping the 'classifier.type' config string to a classifier class.
CLASSIFIERS = {'linear' : torch.nn.Linear,
              'SphericalClassifier' : SphericalClassifier}

# Experiment configuration consumed by CIFAR10Module.
config = {
    # Size of the embedding produced by the Embedder.
    'embedding_dim' : 64,
    'scheduler' : {
        # Scheduler class; instantiated as type(optimizer, **params).
        'type' : torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
        # True -> scheduler interval is 'step' (per batch), otherwise 'epoch'.
        'step_every_batch' : True,
        # NOTE(review): with T_0 = 1 and per-batch stepping the cosine cycle
        # appears to restart on every step — confirm this is intended.
        'params' : {
            'eta_min' : 2e-4,
            'T_0' : 1
        }
    },
    'optimization' : {
        # Optimizer class; instantiated as optimizer(parameters, **params).
        'optimizer' : adabelief_pytorch.AdaBelief,
        # eps / weight_decouple / rectify are set explicitly; they match the
        # adabelief-pytorch 0.0.5 row of the change-log table the library
        # prints when the optimizer is created.
        'params' : {
            'lr' : 1e-3,
            'betas' : (0.9, 0.999),
            'eps' : 1e-8,
            'weight_decay' : 5e-4,
            'weight_decouple' : False,
            'rectify' : False,
            'fixed_decay' : False,
            'amsgrad' : False
        }
    },
    'dataset_params' : {
        # Batch size used by the train, validation and test dataloaders.
        'batch_size' : 256
    },
    'criterion' : {
        # Key into CRITERIONS; 'params' are passed as keyword arguments.
        'type' : 'CE',
        'params' : {}
    },
    'classifier': {
        # Key into CLASSIFIERS; 'params' are passed as keyword arguments after
        # (embedding_dim, 10), so they must be accepted by the selected class.
        'type' : 'SphericalClassifier',
        'params' : {
            'scale' : 64,
            'margin' : 0.35,
            'margin_type' : 'CosFace'
        }
    }
}

In [8]:
class CIFAR10Module(pl.LightningModule):
    """LightningModule that trains an ``Embedder`` on CIFAR-10 and evaluates
    the embeddings with retrieval metrics.

    The module is driven by a ``config`` dict with the keys
    ``embedding_dim``, ``scheduler``, ``optimization``, ``dataset_params``,
    ``criterion`` and ``classifier``. Training uses either a classifier head
    with cross-entropy or a triplet loss on the raw embeddings; validation
    and test compute precision@1 and MAP@R on the CIFAR-10 test split.
    """

    def __init__(self, config):
        super().__init__()
        self._config = config
        self.embedder = Embedder(config['embedding_dim'])
        self.classifier = self.get_classifier()
        self.criterion = self.get_criterion()
        self.recall_calculator = AccuracyCalculator(
            include=("precision_at_1", "mean_average_precision_at_r"), k="max_bin_count"
        )

    def train_dataset(self):
        """Return the CIFAR-10 training split with AutoAugment + ToTensor."""
        transform = torchvision.transforms.Compose(
            [torchvision.transforms.AutoAugment(torchvision.transforms.AutoAugmentPolicy.CIFAR10),
            torchvision.transforms.ToTensor()]
        )
        CIFAR_trainset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=transform
        )
        return CIFAR_trainset

    def train_dataloader(self):
        """Build the training DataLoader.

        Uses a class-balanced batch sampler when
        ``dataset_params['use_balanced_sampler']`` is truthy, otherwise a
        plain shuffled loader.
        """
        dataset = self.train_dataset()
        params = self._config["dataset_params"]
        if params.get("use_balanced_sampler"):
            # ShuffledClassBalancedBatchSampler is not defined in the visible
            # part of the notebook; it must be provided elsewhere.
            sampler = ShuffledClassBalancedBatchSampler(
                dataset, params["batch_size"], params["samples_per_class"]
            )
            return torch.utils.data.DataLoader(
                dataset, batch_sampler=sampler
            )
        else:
            return torch.utils.data.DataLoader(
                dataset, params["batch_size"], shuffle=True, num_workers=4
            )

    def test_dataset(self):
        """Return the CIFAR-10 test split (ToTensor only, no augmentation)."""
        CIFAR_testset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True,
            transform=torchvision.transforms.ToTensor()
        )
        return CIFAR_testset

    def test_dataloader(self):
        """Build the (unshuffled) DataLoader over the test split."""
        dataset = self.test_dataset()
        params = self._config["dataset_params"]
        return torch.utils.data.DataLoader(
            dataset, batch_size = params["batch_size"], num_workers=4
        )

    def val_dataloader(self):
        """Validation reuses the test dataloader."""
        return self.test_dataloader()

    def configure_optimizers(self):
        """Instantiate the optimizer and LR scheduler described in the config."""
        optimizer_params = self._config["optimization"]['params']
        optimizer_class = self._config['optimization']['optimizer']
        optimizer = optimizer_class(self.parameters(), **optimizer_params)

        scheduler_class = self._config["scheduler"]['type']
        scheduler_params = self._config["scheduler"]['params']
        scheduler = {
            'scheduler' : scheduler_class(optimizer, **scheduler_params),
            'interval' : 'step' if self._config['scheduler']['step_every_batch'] else 'epoch'
        }
        return {'optimizer' : optimizer,
                'lr_scheduler' : scheduler}

    def get_criterion(self):
        """Create the loss selected by ``config['criterion']``."""
        criterion_type = self._config["criterion"]['type']
        params = self._config['criterion']["params"]
        return CRITERIONS[criterion_type](**params)

    def get_classifier(self):
        """Create the 10-class head selected by ``config['classifier']``."""
        classifier_type = self._config['classifier']["type"]
        params = self._config['classifier']["params"]
        return CLASSIFIERS[classifier_type](self._config['embedding_dim'], 10, **params)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Raises:
            TypeError: if the configured criterion is neither
                ``CrossEntropyLoss`` nor ``TripletLoss``.
        """
        images, labels = batch
        embeddings = self.embedder(images)
        if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
            if isinstance(self.classifier, SphericalClassifier):
                # Margin-based heads need the labels to know which logit to penalise.
                logits = self.classifier(embeddings, labels)
            else:
                logits = self.classifier(embeddings)
            loss = self.criterion(logits, labels)
        elif isinstance(self.criterion, TripletLoss):
            # generate_triplets is not defined in the visible part of the
            # notebook; it must be provided elsewhere.
            triplets = generate_triplets((embeddings, labels))
            loss = self.criterion(triplets)
        else:
            # Fail loudly instead of hitting an unbound `loss` below.
            raise TypeError(
                'Unsupported criterion type - {}'.format(type(self.criterion).__name__)
            )
        self.log("loss", loss)
        return {"loss": loss}

    def _retrieval_metrics(self, outputs, normalize_3d=False):
        """Stack per-batch outputs and compute the retrieval metrics.

        Args:
            outputs: list of dicts with CPU tensors under ``"embeddings"``
                and ``"labels"`` (as returned by ``test_step``).
            normalize_3d: when True, 3-dimensional embeddings are
                L2-normalized before the metrics are computed.

        Returns:
            The dict returned by ``AccuracyCalculator.get_accuracy``.
        """
        embeddings = np.vstack([b["embeddings"].numpy() for b in outputs])
        labels = np.hstack([b["labels"].numpy() for b in outputs])
        if normalize_3d and embeddings.shape[1] == 3:
            embeddings = embeddings / np.sqrt((embeddings ** 2).sum(-1))[..., np.newaxis]
        # Query and reference are the same set, so every sample's nearest
        # neighbour is itself; ref_includes_query=True tells the calculator to
        # skip that self-match, otherwise precision@1 is trivially inflated.
        return self.recall_calculator.get_accuracy(
            embeddings, labels,
            embeddings, labels,
            ref_includes_query=True
        )

    def test_step(self, batch, batch_idx):
        """Embed one batch and hand embeddings/labels (on CPU) to the epoch-end hook."""
        images, labels = batch
        embeddings = self.embedder(images)
        return {"embeddings": embeddings.cpu(), "labels": labels.cpu()}

    # NOTE(review): the *_epoch_end(outputs) hooks belong to the
    # pytorch_lightning 1.x API — confirm the installed version supports them.
    def test_epoch_end(self, outputs) -> None:
        """Log precision@1 and MAP@R over the whole test set."""
        metrics = self._retrieval_metrics(outputs, normalize_3d=True)
        self.log("r_at_one", metrics["precision_at_1"])
        self.log("map_at_r", metrics["mean_average_precision_at_r"])

    def validation_step(self, batch, batch_idx):
        """Validation shares the test-step logic."""
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs) -> None:
        """Log precision@1 and MAP@R over the whole validation set."""
        metrics = self._retrieval_metrics(outputs)
        self.log("val_r_at_one", metrics["precision_at_1"])
        self.log("val_map_at_r", metrics["mean_average_precision_at_r"])

In [9]:
# Build the LightningModule (embedder, classifier head, criterion) from `config`.
module = CIFAR10Module(config)

In [None]:
# TensorBoard logs are written under ./logs/ce.
logger = pl.loggers.TensorBoardLogger("./logs", name='ce')
# Train on GPU for 20 epochs, logging metrics every 10 steps.
trainer = pl.Trainer(
    accelerator="gpu",
    logger=logger,
    log_every_n_steps=10,
    max_epochs=20
)
trainer.fit(module)
# NOTE(review): test() is called without a model argument — confirm which
# weights (last fitted vs. checkpoint) the installed Lightning version evaluates.
trainer.test()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | embedder   | Embedder            | 11.2 M
1 | classifier | SphericalClassifier | 640   
2 | criterion  | CrossEntropyLoss    | 0     
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.840    Total estimated model params size (MB)


[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[0m


Sanity Checking: 0it [00:00, ?it/s]

Files already downloaded and verified
