<a href="https://colab.research.google.com/github/Quillbolt/colabnotebook/blob/main/ptlt_trying.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from typing import Tuple

import torch
from torch import nn, Tensor


def convert_label_to_similarity(normed_feature: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
    """Split all pairwise similarities into same-label and different-label sets.

    Args:
        normed_feature: 2-D tensor with one feature row per sample. The name
            suggests the rows are already normalised; this function neither
            normalises nor checks them, it only takes row dot products.
        label: 1-D tensor holding one label per feature row.

    Returns:
        A ``(sp, sn)`` tuple of 1-D tensors: the similarities of every
        unordered pair ``i < j`` whose labels are equal, and of every such
        pair whose labels differ, both in row-major pair order.
    """
    pairwise_sim = torch.matmul(normed_feature, normed_feature.transpose(0, 1))
    same_label = label[:, None] == label[None, :]

    # Keep only the strict upper triangle so each unordered pair is counted
    # once and a sample is never paired with itself.
    pos_mask = torch.triu(same_label, diagonal=1)
    neg_mask = torch.triu(~same_label, diagonal=1)

    flat_sim = pairwise_sim.reshape(-1)
    return flat_sim[pos_mask.reshape(-1)], flat_sim[neg_mask.reshape(-1)]


class CircleLoss(nn.Module):
    """Circle-loss style criterion over positive and negative similarity scores.

    Args:
        m: Margin. It sets the weighting offsets (``1 + m`` for positives,
            ``m`` for negatives) and the decision offsets (``1 - m`` for
            positives, ``m`` for negatives).
        gamma: Scale factor multiplied into both sets of logits.
    """

    def __init__(self, m: float, gamma: float) -> None:
        super().__init__()
        self.m = m
        self.gamma = gamma
        self.soft_plus = nn.Softplus()

    def forward(self, sp: Tensor, sn: Tensor) -> Tensor:
        """Return the scalar loss for positive scores ``sp`` and negative scores ``sn``.

        Both inputs are reduced along dimension 0 with ``logsumexp``.
        """
        margin = self.m
        scale = self.gamma

        # The weighting factors are built from detached scores, so autograd
        # treats them as constants and gradients flow only through sp / sn.
        weight_p = (- sp.detach() + 1 + margin).clamp(min=0.)
        weight_n = (sn.detach() + margin).clamp(min=0.)

        offset_p = 1 - margin
        offset_n = margin

        logit_p = - weight_p * (sp - offset_p) * scale
        logit_n = weight_n * (sn - offset_n) * scale

        lse_n = logit_n.logsumexp(dim=0)
        lse_p = logit_p.logsumexp(dim=0)
        return self.soft_plus(lse_n + lse_p)

In [None]:
    # Smoke test: run the loss once on random, row-normalised features with
    # random labels drawn from 10 classes.
    raw = torch.rand(256, 64, requires_grad=True)
    feat = nn.functional.normalize(raw, dim=1)
    lbl = torch.randint(0, 10, (256,))

    criterion = CircleLoss(m=0.25, gamma=256)

    inp_sp, inp_sn = convert_label_to_similarity(feat, lbl)
    circle_loss = criterion(inp_sp, inp_sn)

    print(circle_loss)

tensor(209.7686)


In [None]:
import torch
from torch import nn, Tensor
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm
from torchvision import models

In [None]:
def get_loader(is_train: bool, batch_size: int) -> DataLoader:
    """Build a ``DataLoader`` over the MNIST train or test split.

    Args:
        is_train: Selects the training split when true, the test split
            otherwise. Also decides whether the loader shuffles.
        batch_size: Number of samples per batch.

    Returns:
        A loader yielding batches of ``ToTensor``-converted images with their
        labels. The dataset is stored under ``./data`` and downloaded when
        missing.
    """
    mnist = MNIST(root="./data", train=is_train, transform=ToTensor(), download=True)
    # Shuffle only the training split; the test split keeps dataset order.
    return DataLoader(mnist, batch_size=batch_size, shuffle=is_train)


In [None]:
class Model(nn.Module):
    """Small convolutional embedding network for single-channel images.

    Three ``Conv2d -> MaxPool2d -> ReLU`` stages (8, 16 and 32 output
    channels) are followed by spatial mean pooling and normalisation of each
    sample's 32-dimensional feature vector.
    """

    def __init__(self) -> None:
        super().__init__()
        # (out_channels, kernel_size) for each stage, in order.
        stage_specs = ((8, 5), (16, 5), (32, 3))

        layers = []
        in_channels = 1
        for out_channels, kernel_size in stage_specs:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
            layers.append(nn.MaxPool2d(kernel_size=2))
            layers.append(nn.ReLU())
            in_channels = out_channels

        # Layer positions inside the Sequential determine the state_dict keys.
        self.feature_extractor = nn.Sequential(*layers)

    def forward(self, inp: Tensor) -> Tensor:
        """Return one normalised feature row per input image."""
        feature_map = self.feature_extractor(inp)
        # Average over the two spatial dimensions, leaving (batch, channels).
        pooled = feature_map.mean(dim=[2, 3])
        return nn.functional.normalize(pooled, dim=1)

In [None]:
# Manual training loop followed by a pairwise verification check on the MNIST
# test split.
model = Model()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
train_loader = get_loader(is_train=True, batch_size=64)
# batch_size=2: every validation batch is one pair of images to compare.
val_loader = get_loader(is_train=False, batch_size=2)
criterion = CircleLoss(m=0.25, gamma=80)

# if resume and os.path.exists("resume.state"):
#     model.load_state_dict(torch.load("resume.state"))
# else:
model.train()
for epoch in range(20):
    for img, label in tqdm(train_loader):
        optimizer.zero_grad()
        # The autograd graph has to be recorded during the forward pass.
        # Flipping `requires_grad` on the finished loss (as this loop used to
        # do) does not connect it to the model parameters, so backward()
        # would leave them without gradients. enable_grad() guarantees
        # recording even if gradient tracking was switched off earlier in
        # the session.
        with torch.enable_grad():
            pred = model(img)
            loss = criterion(*convert_label_to_similarity(pred, label))
            loss.backward()
        optimizer.step()
# torch.save(model.state_dict(), "resume.state")

tp = 0
fn = 0
fp = 0
# A pair is predicted "same class" when the dot product of the two embeddings
# exceeds this value.
thresh = 0.75
model.eval()
with torch.no_grad():  # inference only: no graph needed
    for img, label in val_loader:
        if label.shape[0] < 2:
            # An incomplete final batch has no pair to compare.
            continue
        pred = model(img)
        gt_label = bool(label[0] == label[1])
        pred_label = bool(torch.sum(pred[0] * pred[1]) > thresh)
        if gt_label and pred_label:
            tp += 1
        elif gt_label and not pred_label:
            fn += 1
        elif not gt_label and pred_label:
            fp += 1

# Report nan instead of raising ZeroDivisionError when a metric is undefined
# (no positive pairs seen, or no pair predicted positive).
recall = tp / (tp + fn) if (tp + fn) else float("nan")
precision = tp / (tp + fp) if (tp + fp) else float("nan")
print("Recall: {:.4f}".format(recall))
print("Precision: {:.4f}".format(precision))


100%|██████████| 938/938 [00:11<00:00, 78.19it/s]
100%|██████████| 938/938 [00:11<00:00, 78.83it/s]
100%|██████████| 938/938 [00:11<00:00, 78.75it/s]
100%|██████████| 938/938 [00:11<00:00, 79.19it/s]
100%|██████████| 938/938 [00:12<00:00, 77.94it/s]
100%|██████████| 938/938 [00:12<00:00, 77.66it/s]
100%|██████████| 938/938 [00:11<00:00, 78.21it/s]
100%|██████████| 938/938 [00:11<00:00, 78.66it/s]
100%|██████████| 938/938 [00:12<00:00, 77.26it/s]
100%|██████████| 938/938 [00:11<00:00, 78.36it/s]
100%|██████████| 938/938 [00:12<00:00, 78.00it/s]
100%|██████████| 938/938 [00:12<00:00, 78.10it/s]
100%|██████████| 938/938 [00:12<00:00, 77.58it/s]
100%|██████████| 938/938 [00:12<00:00, 77.66it/s]
100%|██████████| 938/938 [00:12<00:00, 77.09it/s]
100%|██████████| 938/938 [00:12<00:00, 77.58it/s]
100%|██████████| 938/938 [00:12<00:00, 77.33it/s]
100%|██████████| 938/938 [00:11<00:00, 79.66it/s]
100%|██████████| 938/938 [00:12<00:00, 77.83it/s]
100%|██████████| 938/938 [00:12<00:00, 76.28it/s]


Recall: 1.0000
Precision: 0.0810


In [None]:
# Install the packages imported by the following cells (`wandb` and
# `pytorch_lightning`). The leading `!` runs each line as a shell command.
!pip install wandb
!pip install pytorch-lightning

Collecting wandb
[?25l  Downloading https://files.pythonhosted.org/packages/ca/5e/9df94df3bfee51b92b54a5e6fa277d6e1fcdf1f27b1872214b98f55ec0f7/wandb-0.10.12-py2.py3-none-any.whl (1.8MB)
[K     |████████████████████████████████| 1.8MB 8.5MB/s 
Collecting subprocess32>=3.5.3
[?25l  Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)
[K     |████████████████████████████████| 102kB 11.5MB/s 
[?25hCollecting docker-pycreds>=0.4.0
  Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl
Collecting shortuuid>=0.5.0
  Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl
Collecting watchdog>=0.8.3
[?25l  Downloading https://files.pythonhosted.org/packages/e6/76/39d123d37908a772b6a281d85fbb

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
# Log in to Weights & Biases before creating the logger.
wandb.login()
# Logger passed to the Lightning Trainer below; configured with the project
# name "MNIST".
wandb_logger = WandbLogger(project="MNIST")


[34m[1mwandb[0m: Currently logged in as: [33mquillbolt[0m (use `wandb login --relogin` to force relogin)


In [None]:
class LitClassifier(pl.LightningModule):
    """Lightning wrapper that trains an embedding backbone with a pairwise loss.

    Args:
        backbone: Module mapping an image batch to one embedding row per
            sample.
        criterion: Optional loss callable taking ``(sp, sn)`` as returned by
            ``convert_label_to_similarity``. When omitted, the module-level
            name ``criterion`` is looked up each time a loss is computed,
            which is what this class did before the parameter existed.
    """

    def __init__(self, backbone, criterion=None):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = backbone
        # None means "fall back to the module-level `criterion`".
        self.loss_fn = criterion

    def forward(self, x):
        # use forward for inference/predictions
        embedding = self.backbone(x)
        return embedding

    def _pair_loss(self, batch):
        """Compute the loss for one ``(images, labels)`` batch."""
        x, y = batch
        embedding = self.backbone(x)
        loss_fn = self.loss_fn if self.loss_fn is not None else criterion
        return loss_fn(*convert_label_to_similarity(embedding, y))

    def training_step(self, batch, batch_idx):
        loss = self._pair_loss(batch)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._pair_loss(batch)
        self.log('valid_loss', loss, on_epoch=True)

    def test_step(self, batch, batch_idx):
        loss = self._pair_loss(batch)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        """Return the SGD optimizer over all parameters of this module."""
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)

In [None]:
# Recreate the MNIST loaders for the Lightning run, with the same settings as
# the manual training loop above.
train_loader = get_loader(is_train=True, batch_size=64)
# NOTE(review): with batch_size=2 each validation batch contains a single
# sample pair, so the loss in `validation_step` is computed from at most one
# similarity value per batch - confirm this is intended for a loss metric.
val_loader = get_loader(is_train=False, batch_size=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
 # Wrap a fresh CNN backbone in the Lightning module.
 model = LitClassifier(Model())
 # Module-level loss object; LitClassifier's step methods read the global
 # name `criterion` when they compute the loss.
 criterion = CircleLoss(m=0.25, gamma=80)

In [None]:
 # Trainer that logs to Weights & Biases and stops after 20 epochs.
 # NOTE(review): the recorded output of this cell reports
 # "GPU available: True, used: False", i.e. this configuration did not run on
 # a GPU. `auto_select_gpus` presumably only takes effect together with a
 # `gpus` setting - confirm against the installed Lightning version.
 trainer = pl.Trainer(logger=wandb_logger, auto_select_gpus=True,                # does not by itself request a GPU (see note above)
    max_epochs=20)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores


In [None]:
 # Fit the Lightning module, passing `train_loader` for training and
 # `val_loader` for validation.
 trainer.fit(model, train_loader, val_loader)


  | Name     | Type  | Params
-----------------------------------
0 | backbone | Model | 8.1 K 
-----------------------------------
8.1 K     Trainable params
0         Non-trainable params
8.1 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[1;30;43mStreaming output truncated to the last 5000 lines.[0m


Buffered data was truncated after reaching the output size limit.