In [1]:
import torch
import torch.nn.functional as F

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms as T, datasets

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

from tqdm.notebook import tqdm
from sklearn.manifold import TSNE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


# Preparing dataset

In [2]:
train_dataset = datasets.MNIST(root='./sample_data', train=True, transform=T.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./sample_data', train=False, transform=T.ToTensor(), download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./sample_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]


Extracting ./sample_data/MNIST/raw/train-images-idx3-ubyte.gz to ./sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./sample_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 488kB/s]


Extracting ./sample_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./sample_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]


Extracting ./sample_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./sample_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.03MB/s]

Extracting ./sample_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./sample_data/MNIST/raw






# Additive Angular Margin Penalty

In [None]:
class AdditiveAngularMarginPenalty(nn.Module):
    """
        Insightface implementation : https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/losses.py
        ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    """
    def __init__(self, s=64.0, margin=0.5):
        super(AdditiveAngularMarginPenalty, self).__init__()
        self.s = s
        self.margin = margin

        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]

        with torch.no_grad():
            target_logit.arccos_()
            logits.arccos_()
            final_target_logit = target_logit + self.margin
            logits[index, labels[index].view(-1)] = final_target_logit
            logits.cos_()
        logits = logits * self.s
        return logits

In [17]:
import torch
import math
import torch.nn as nn

class AdditiveAngularMarginPenalty(nn.Module):
    def __init__(self, s=64.0, margin=0.5):
        super(AdditiveAngularMarginPenalty, self).__init__()
        self.s = s
        self.margin = margin

        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]

        target_logit = torch.clamp(target_logit, -1.0, 1.0)

        # Apply margin
        with torch.no_grad():
            target_logit = torch.arccos(target_logit)
            target_logit = target_logit + self.margin
            target_logit = torch.cos(target_logit)

        logits[index, labels[index].view(-1)] = target_logit

        # Scale logits
        logits = logits * self.s
        return logits


# Example CNN model

In [18]:
class ToyMNISTModel(nn.Module):
  def __init__(self):
    super(ToyMNISTModel, self).__init__()

    self.conv1 = nn.Conv2d(1, 32, 5)
    self.conv2 = nn.Conv2d(32, 32, 5)
    self.conv3 = nn.Conv2d(32, 64, 5)
    self.dropout = nn.Dropout(0.25)
    self.fc1 = nn.Linear(3*3*64, 256)
    self.fc2 = nn.Linear(256, 10)
    self.angular_margin_penalty = AdditiveAngularMarginPenalty(10, 10)
    self.relu = nn.ReLU(inplace=True)
    self.maxpooling = nn.MaxPool2d(2, 2)

  def forward(self, x, label=None):
    # CNN part
    x = self.relu(self.conv1(x))
    x = self.dropout(x)
    x = self.relu(self.maxpooling(self.conv2(x)))
    x = self.dropout(x)
    x = self.relu(self.maxpooling(self.conv3(x)))
    x = self.dropout(x)

    # fully connected part
    x = x.view(x.size(0), -1)    # (batch_size, 3*3*64)
    x = self.relu(self.fc1(x))
    x = self.fc2(x)

    if label is not None:
      # angular margin penalty part
      logits = self.angular_margin_penalty(x, label)
    else:
      logits = x

    return logits

In [19]:
model = ToyMNISTModel()
model.to(device)

ToyMNISTModel(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (dropout): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=576, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
  (angular_margin_penalty): AdditiveAngularMarginPenalty()
  (relu): ReLU(inplace=True)
  (maxpooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

# Training

In [20]:
epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [22]:
model.train()

for e in range(epochs):
  print('epochs: ', e)
  for idx, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = model(images, labels)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

  with torch.no_grad():

    losses = []
    accs = 0.0

    for j, (val_images, val_labels) in enumerate(test_loader):
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)

      outputs = model(val_images, val_labels)
      loss = criterion(outputs, val_labels)

      pred = outputs.argmax(dim=1, keepdim=True)
      acc = pred.eq(val_labels.view_as(pred)).sum().item()
      accs += acc

      losses.append(loss.item())

    loss = np.mean(losses)
    acc = accs / len(test_dataset)
    print(f'{len(test_dataset), len(losses), losses[0: 5], accs}\n')
    print(f'validation loss: {loss}, validation acc: {acc}')

epochs:  0
(10000, 157, [1.4901159417490817e-08, 2.2351738238057806e-08, 2.2351738238057806e-08, 4.284081711602994e-08, 1.303851426825986e-08], 10000.0)

validation loss: 1.913660006679059e-08, validation acc: 1.0
epochs:  1
(10000, 157, [0.0, 0.0, 0.0, 1.8626450382086546e-09, 1.8626450382086546e-09], 10000.0)

validation loss: 1.6965492280668524e-09, validation acc: 1.0
epochs:  2
(10000, 157, [0.0, 0.0, 0.0, 3.7252898543727042e-09, 0.0], 10000.0)

validation loss: 3.2032748569730926e-10, validation acc: 1.0
epochs:  3
(10000, 157, [0.0, 0.0, 0.0, 0.0, 0.0], 10000.0)

validation loss: 1.1863980852227603e-10, validation acc: 1.0
epochs:  4
(10000, 157, [0.0, 0.0, 0.0, 0.0, 0.0], 10000.0)

validation loss: 1.1863981135086972e-11, validation acc: 1.0
epochs:  5
(10000, 157, [0.0, 0.0, 0.0, 0.0, 0.0], 10000.0)

validation loss: 0.0, validation acc: 1.0
epochs:  6
(10000, 157, [0.0, 0.0, 0.0, 0.0, 0.0], 10000.0)

validation loss: 1.1863981135086972e-11, validation acc: 1.0
epochs:  7
(1000

In [23]:
torch.save(model.state_dict(), 'model.pt')

In [24]:
activations = {}

def get_activation(name):
  def hook(model, input, output):
    activations[name] = output.detach()
  return hook

In [25]:
h1 = model.fc2.register_forward_hook(get_activation('fc2'))

# Image feature visualization with t-SNE

In [26]:
with torch.no_grad():

    image_features = []
    labels = []

    for j, (val_images, val_labels) in enumerate(test_loader):
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)

      outputs = model(val_images, val_labels)

      image_features.append(activations['fc2'].cpu().numpy())
      labels.append(val_labels.cpu().numpy())

    image_features = np.concatenate(image_features, axis=0)
    labels = np.concatenate(labels, axis=0)

tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
tsne_results = tsne.fit_transform(image_features)



[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 10000 samples in 0.013s...
[t-SNE] Computed neighbors for 10000 samples in 0.957s...
[t-SNE] Computed conditional probabilities for sample 1000 / 10000
[t-SNE] Computed conditional probabilities for sample 2000 / 10000
[t-SNE] Computed conditional probabilities for sample 3000 / 10000
[t-SNE] Computed conditional probabilities for sample 4000 / 10000
[t-SNE] Computed conditional probabilities for sample 5000 / 10000
[t-SNE] Computed conditional probabilities for sample 6000 / 10000
[t-SNE] Computed conditional probabilities for sample 7000 / 10000
[t-SNE] Computed conditional probabilities for sample 8000 / 10000
[t-SNE] Computed conditional probabilities for sample 9000 / 10000
[t-SNE] Computed conditional probabilities for sample 10000 / 10000
[t-SNE] Mean sigma: 0.126344
[t-SNE] KL divergence after 250 iterations with early exaggeration: 61.245537
[t-SNE] KL divergence after 300 iterations: 1.956563


In [27]:
fig = px.scatter(x=tsne_results[:, 0], y=tsne_results[:, 1], color=labels)
fig.show()

# Training model without an additive angular margin penalty (normal softmax)

In [None]:
model_softmax = ToyMNISTModel()
model_softmax.to(device)

ToyMNISTModel(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (dropout): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=576, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
  (angular_margin_penalty): AdditiveAngularMarginPenalty()
  (relu): ReLU(inplace=True)
  (maxpooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [None]:
epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_softmax.parameters(), lr=1e-4)

In [None]:
model_softmax.train()

for e in range(epochs):
  print('epochs: ', e)
  for idx, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    outputs = model_softmax(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

  with torch.no_grad():

    losses = []
    accs = 0.0

    for j, (val_images, val_labels) in enumerate(test_loader):
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)

      outputs = model_softmax(val_images)
      loss = criterion(outputs, val_labels)

      pred = outputs.argmax(dim=1, keepdim=True)
      acc = pred.eq(val_labels.view_as(pred)).sum().item()
      accs += acc

      losses.append(loss.item())

    loss = np.mean(losses)
    acc = accs / len(test_dataset)

    print(f'validation loss: {loss}, validation acc: {acc}')

epochs:  0
validation loss: 0.1745920093243669, validation acc: 0.9467
epochs:  1
validation loss: 0.1051772616886694, validation acc: 0.9667
epochs:  2
validation loss: 0.07700401577493472, validation acc: 0.9748
epochs:  3
validation loss: 0.0649395207177468, validation acc: 0.98
epochs:  4
validation loss: 0.05069209292685245, validation acc: 0.9827
epochs:  5
validation loss: 0.048470399874873504, validation acc: 0.9843
epochs:  6
validation loss: 0.04504495392446733, validation acc: 0.9851
epochs:  7
validation loss: 0.04120145650924581, validation acc: 0.9859
epochs:  8
validation loss: 0.0402954799745226, validation acc: 0.9878
epochs:  9
validation loss: 0.040852650085013285, validation acc: 0.9879


In [None]:
activations_softmax = {}

def get_activation_softmax(name):
  def hook(model, input, output):
    activations_softmax[name] = output.detach()
  return hook

In [None]:
h2 = model_softmax.fc2.register_forward_hook(get_activation_softmax('fc2'))

In [None]:
with torch.no_grad():

    image_features = []
    labels = []

    for j, (val_images, val_labels) in enumerate(test_loader):
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)

      outputs = model_softmax(val_images, val_labels)

      image_features.append(activations_softmax['fc2'].cpu().numpy())
      labels.append(val_labels.cpu().numpy())

    image_features = np.concatenate(image_features, axis=0)
    labels = np.concatenate(labels, axis=0)

tsne_ = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
tsne_results_ = tsne_.fit_transform(image_features)

[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 10000 samples in 0.016s...
[t-SNE] Computed neighbors for 10000 samples in 1.636s...
[t-SNE] Computed conditional probabilities for sample 1000 / 10000
[t-SNE] Computed conditional probabilities for sample 2000 / 10000
[t-SNE] Computed conditional probabilities for sample 3000 / 10000
[t-SNE] Computed conditional probabilities for sample 4000 / 10000
[t-SNE] Computed conditional probabilities for sample 5000 / 10000
[t-SNE] Computed conditional probabilities for sample 6000 / 10000
[t-SNE] Computed conditional probabilities for sample 7000 / 10000
[t-SNE] Computed conditional probabilities for sample 8000 / 10000
[t-SNE] Computed conditional probabilities for sample 9000 / 10000
[t-SNE] Computed conditional probabilities for sample 10000 / 10000
[t-SNE] Mean sigma: 2.889194
[t-SNE] KL divergence after 250 iterations with early exaggeration: 70.615456
[t-SNE] KL divergence after 300 iterations: 2.554014


# Image feature visualization without an additive angular margin penalty

In [None]:
fig = px.scatter(x=tsne_results_[:, 0], y=tsne_results_[:, 1], color=labels)
fig.show()