In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# !tar -xvzf /content/drive/MyDrive/Sejong.tgz
! unzip /content/drive/MyDrive/airi_dataset.zip

In [4]:
!pip install wandb 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.12.21-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 28.9 MB/s 
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
[K     |████████████████████████████████| 181 kB 71.5 MB/s 
[?25hCollecting setproctitle
  Downloading setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.6.0-py2.py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 74.6 MB/s 
[?25hCollecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tensorflow as tf

from torch.autograd import Variable
from torchvision.models import resnet18

# from keras.preprocessing import image

import PIL
import torch
import torchvision

import matplotlib.pyplot as plt
from sklearn import metrics
import wandb
import tqdm
import json

RANDOM_SEED = 777

torch.manual_seed(RANDOM_SEED)

with open('/content/airi_dataset/metainfo.json') as fin:
    metainfo = json.load(fin)

num_to_class = metainfo['num_to_class']

train_size = metainfo['data']['train_size']
test_size = metainfo['data']['test_size']
val_size = metainfo['data']['val_size']

classes_count = len(num_to_class)

In [6]:
from PIL import Image
from sklearn import (manifold, datasets, decomposition, ensemble,
                     discriminant_analysis, random_projection)
import torchvision.transforms.functional as Function
from IPython.display import display
from time import time
from matplotlib import offsetbox
from sklearn.neighbors import DistanceMetric
%matplotlib inline

## NetVLAD

In [7]:
class NetVLAD(nn.Module):
    """NetVLAD layer implementation"""

    def __init__(self, num_clusters=6, dim=128, alpha=100.0,
                 normalize_input=True):
        """
        Args:
            num_clusters : int
                The number of clusters
            dim : int
                Dimension of descriptors
            alpha : float
                Parameter of initialization. Larger value is harder assignment.
            normalize_input : bool
                If true, descriptor-wise L2 normalization is applied to input.
        """
        super(NetVLAD, self).__init__()
        self.num_clusters = num_clusters
        self.dim = dim
        self.alpha = alpha
        self.normalize_input = normalize_input
        self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True)
        self.centroids = nn.Parameter(torch.rand(num_clusters, dim))
        self._init_params()

    def _init_params(self):
        self.conv.weight = nn.Parameter(
            (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1)
        )
        self.conv.bias = nn.Parameter(
            - self.alpha * self.centroids.norm(dim=1)
        )

    def forward(self, x):
        N, C = x.shape[:2]

        if self.normalize_input:
            x = F.normalize(x, p=2, dim=1)  # across descriptor dim

        # soft-assignment
        soft_assign = self.conv(x).view(N, self.num_clusters, -1)
        soft_assign = F.softmax(soft_assign, dim=1)

        x_flatten = x.view(N, C, -1)
        
        # calculate residuals to each clusters
        residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \
            self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
        residual *= soft_assign.unsqueeze(2)
        vlad = residual.sum(dim=-1)

        vlad = F.normalize(vlad, p=2, dim=2)  # intra-normalization
        vlad = vlad.view(x.size(0), -1)  # flatten
        vlad = F.normalize(vlad, p=2, dim=1)  # L2 normalize

        return vlad

In [8]:
class EmbedNet(nn.Module):
    def __init__(self, base_model, net_vlad):
        super(EmbedNet, self).__init__()
        self.base_model = base_model
        self.net_vlad = net_vlad

    def forward(self, x):
        x = self.base_model(x)
        embedded_x = self.net_vlad(x)
        return embedded_x
      
class TripletNet(nn.Module):
    def __init__(self, embed_net):
        super(TripletNet, self).__init__()
        self.embed_net = embed_net

    def forward(self, a, p, n):
        embedded_a = self.embed_net(a)
        embedded_p = self.embed_net(p)
        embedded_n = self.embed_net(n)
        return embedded_a, embedded_p, embedded_n

    def feature_extract(self, x):
        return self.embed_net(x)

In [9]:
class HardTripletLoss(nn.Module):
    """Hard/Hardest Triplet Loss
    (pytorch implementation of https://omoindrot.github.io/triplet-loss)
    For each anchor, we get the hardest positive and hardest negative to form a triplet.
    """
    def __init__(self, margin=0.1, hardest=False, squared=False):
        """
        Args:
            margin: margin for triplet loss
            hardest: If true, loss is considered only hardest triplets.
            squared: If true, output is the pairwise squared euclidean distance matrix.
                If false, output is the pairwise euclidean distance matrix.
        """
        super(HardTripletLoss, self).__init__()
        self.margin = margin
        self.hardest = hardest
        self.squared = squared

    def forward(self, embeddings, labels):
        """
        Args:
            labels: labels of the batch, of size (batch_size,)
            embeddings: tensor of shape (batch_size, embed_dim)
        Returns:
            triplet_loss: scalar tensor containing the triplet loss
        """
        pairwise_dist = _pairwise_distance(embeddings, squared=self.squared)

        if self.hardest:
            # Get the hardest positive pairs
            mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
            valid_positive_dist = pairwise_dist * mask_anchor_positive
            hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True)

            # Get the hardest negative pairs
            mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
            max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)
            anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (
                    1.0 - mask_anchor_negative)
            hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True)

            # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
            triplet_loss = F.relu(hardest_positive_dist - hardest_negative_dist + 0.1)
            triplet_loss = torch.mean(triplet_loss)
        else:
            anc_pos_dist = pairwise_dist.unsqueeze(dim=2)
            anc_neg_dist = pairwise_dist.unsqueeze(dim=1)

            # Compute a 3D tensor of size (batch_size, batch_size, batch_size)
            # triplet_loss[i, j, k] will contain the triplet loss of anc=i, pos=j, neg=k
            # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
            # and the 2nd (batch_size, 1, batch_size)
            loss = anc_pos_dist - anc_neg_dist + self.margin

            mask = _get_triplet_mask(labels).float()
            triplet_loss = loss * mask

            # Remove negative losses (i.e. the easy triplets)
            triplet_loss = F.relu(triplet_loss)

            # Count number of hard triplets (where triplet_loss > 0)
            hard_triplets = torch.gt(triplet_loss, 1e-16).float()
            num_hard_triplets = torch.sum(hard_triplets)

            triplet_loss = torch.sum(triplet_loss) / (num_hard_triplets + 1e-16)

        return triplet_loss


def _pairwise_distance(x, squared=False, eps=1e-16):
    # Compute the 2D matrix of distances between all the embeddings.

    cor_mat = torch.matmul(x, x.t())
    norm_mat = cor_mat.diag()
    distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)
    distances = F.relu(distances)

    if not squared:
        mask = torch.eq(distances, 0.0).float()
        distances = distances + mask * eps
        distances = torch.sqrt(distances)
        distances = distances * (1.0 - mask)

    return distances


def _get_anchor_positive_triplet_mask(labels):
    # Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    indices_not_equal = torch.eye(labels.shape[0]).to(device).byte() ^ 1

    # Check if labels[i] == labels[j]
    labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)

    mask = indices_not_equal * labels_equal

    return mask


def _get_anchor_negative_triplet_mask(labels):
    # Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.

    # Check if labels[i] != labels[k]
    labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
    mask = labels_equal ^ 1

    return mask


def _get_triplet_mask(labels):
    """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
    A triplet (i, j, k) is valid if:
        - i, j, k are distinct
        - labels[i] == labels[j] and labels[i] != labels[k]
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Check that i, j and k are distinct
    indices_not_same = torch.eye(labels.shape[0]).to(device).byte() ^ 1
    i_not_equal_j = torch.unsqueeze(indices_not_same, 2)
    i_not_equal_k = torch.unsqueeze(indices_not_same, 1)
    j_not_equal_k = torch.unsqueeze(indices_not_same, 0)
    distinct_indices = i_not_equal_j * i_not_equal_k * j_not_equal_k

    # Check if labels[i] == labels[j] and labels[i] != labels[k]
    label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1))
    i_equal_j = torch.unsqueeze(label_equal, 2)
    i_equal_k = torch.unsqueeze(label_equal, 1)
    valid_labels = i_equal_j * (~i_equal_k)

    mask = distinct_indices * valid_labels   # Combine the two masks

    return mask

## Построение модели

In [10]:
# Discard layers at the end of base network
encoder = resnet18(pretrained=True)
base_model = nn.Sequential(
    encoder.conv1,
    encoder.bn1,
    encoder.relu,
    encoder.maxpool,
    encoder.layer1,
    encoder.layer2,
    encoder.layer3,
    encoder.layer4,
)
dim = list(base_model.parameters())[-1].shape[0]  # last channels (512)

# Define model for embedding
net_vlad = NetVLAD(num_clusters=classes_count, dim=dim, alpha=1.0)
model = EmbedNet(base_model, net_vlad).cuda()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [11]:
# Define loss

wandb_config = {"epochs": 200, 
                "batch_size": 32, 
                "learning_rate": 5e-3,
                "momentum": 0.9, 
                "margin": 0.1,
                "classes_count":classes_count}

criterion = HardTripletLoss(margin=wandb_config["margin"]).cuda()
optimizer = torch.optim.SGD(model.parameters(), 
                            lr=wandb_config["learning_rate"], 
                            momentum=wandb_config["momentum"])

## Загрузка данных

In [12]:
transforms_bef = torchvision.transforms.Compose([                       
    torchvision.transforms.Resize((128,128)),               
    torchvision.transforms.ToTensor(),
])

bef_train_imagenet_data = torchvision.datasets.ImageFolder('/content/airi_dataset/train', transform=transforms_bef)
bef_train_data_loader = torch.utils.data.DataLoader(bef_train_imagenet_data,
                                          train_size, #размер всего трейна
                                          shuffle=False,
                                          num_workers=0)

In [13]:
for bef_train_image,bef_train_label in bef_train_data_loader :
  bef_train_image = bef_train_image
  bef_train_label = bef_train_label

Dataloaders for test and train

In [14]:
transforms_train = torchvision.transforms.Compose([
    torchvision.transforms.ColorJitter(brightness=.5, hue=.3),
    torchvision.transforms.RandomGrayscale(),
    torchvision.transforms.RandomCrop((384, 384)),    
    torchvision.transforms.RandomRotation(degrees=(-20, 20), expand=True),                                     
    torchvision.transforms.Resize((128,128)),               
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

transforms_test = torchvision.transforms.Compose([                               
    torchvision.transforms.Resize((128,128)),               
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

train_imagenet_data = torchvision.datasets.ImageFolder('/content/airi_dataset/train', transform=transforms_train)
train_data_loader = torch.utils.data.DataLoader(train_imagenet_data,
                                          batch_size=wandb_config['batch_size'],
                                          shuffle=True,
                                          num_workers=0,
                                          drop_last=True)
test_imagenet_data = torchvision.datasets.ImageFolder('/content/airi_dataset/test', transform=transforms_test)
test_data_loader = torch.utils.data.DataLoader(test_imagenet_data,
                                          shuffle=False,
                                          batch_size = train_size,#размер теста
                                          num_workers=0)

Dataloader for validation

In [15]:
valid_imagenet_data = torchvision.datasets.ImageFolder('/content/airi_dataset/val', transform=transforms_test)
valid_data_loader = torch.utils.data.DataLoader(valid_imagenet_data,
                                          shuffle=False,
                                          batch_size = val_size,#размер валидационного
                                          num_workers=0)

## Обучение

In [26]:
wandb.init(project="NetVLAD-Model",config = wandb_config)
wandb.watch(model)

[]

In [27]:
def get_response(X_total, top_n=1):
    pairwise_dist_t = _pairwise_distance(X_total)
    pairwise_dist_n = pairwise_dist_t.cpu().detach().numpy()

    pairwise_dist_sort = np.sort(pairwise_dist_n[-1][:-1])
  
    response = []
    for ii in range(top_n):
        idx = np.where( pairwise_dist_n[-1] == pairwise_dist_sort[ii])
        index_in_base = idx[0][0]
        response.append(index_in_base)
    if top_n == 1:
        return response[0]
    return response # массив индексов из базы

In [28]:
def find_top_n_nearest(X_pred, Y_base, X_val, Y_val, top_n=1):
    Y_pred = []
    for q_i in range(val_size):
        # print(X_val.shape)
        image_query = X_val[q_i].view(1,-1)
        X_total = torch.cat([X_pred, image_query], dim=0)
        response = get_response(X_total, top_n=top_n)
        pred = Y_base[response]
        Y_pred.append(pred)
    return Y_pred

In [29]:
def validate(model, X_base, Y_base, valid_image, valid_label, metric_funcs, top_n=1):
    metrics = {}
    X_val = model(valid_image.cuda()).cpu().detach()
    Y_val = valid_label
    Y_pred = find_top_n_nearest(X_base, Y_base, X_val, Y_val, top_n=top_n)
    for (name, mf) in metric_funcs:
        metrics[name] = mf(Y_val, Y_pred)
    return metrics

In [30]:
import os

def save_model(model, epoch):
    if not os.path.exists('/content/checkpoints/'):
        os.mkdir('/content/checkpoints/')
    model_save_name = 'model_{:02d}_epoch.pt'.format(epoch)
    path = f"/content/checkpoints/{model_save_name}" 
    torch.save(model.state_dict(), path)

In [31]:
metric_funcs = [('accuracy', metrics.accuracy_score), 
                ('recall', (lambda y_true, y_pred: 
                            metrics.recall_score(y_true, y_pred,average='weighted'))),
                ('precision', (lambda y_true, y_pred: 
                            metrics.precision_score(y_true, y_pred,average='weighted')))]

In [None]:
pbar = tqdm.tqdm(total=wandb_config['epochs'])
for epoch in range(wandb_config['epochs']):
    # обучение
    loss = 0
    for batch_idx, (train_image,train_label) in enumerate(train_data_loader) :
        output_train = model(train_image.cuda())
        triplet_loss = criterion(output_train, train_label.cuda())
        optimizer.zero_grad()
        triplet_loss.backward(retain_graph=True)
        optimizer.step()
        loss += train_image.size(0) * triplet_loss
    wandb.log({"triplet_loss": loss/len(train_data_loader)})

    model.eval() 
    # валидация
    out_train_image = model(bef_train_image.cuda()).cpu().detach()
    X_base = out_train_image
    Y_base = bef_train_label
    for valid_image, valid_label in valid_data_loader: # тут батч вмещает всю валидационную выборку
        metrics_dict = validate(model, X_base, Y_base, valid_image, valid_label, metric_funcs)
        wandb.log(metrics_dict)
    model.train()

    if not (epoch+1)%10:
        save_model(model=model, epoch=epoch)
    pbar.update(1)
    pbar.set_description(f"Training {epoch+1} eposch...")
pbar.close()


Training 1 eposch...:   0%|          | 1/200 [05:37<18:39:49, 337.63s/it]
  _warn_prf(average, modifier, msg_start, len(result))

  0%|          | 1/200 [00:29<1:39:03, 29.87s/it][A
  _warn_prf(average, modifier, msg_start, len(result))

Training 1 eposch...:   1%|          | 2/200 [00:57<1:34:27, 28.62s/it][A
  _warn_prf(average, modifier, msg_start, len(result))

Training 2 eposch...:   2%|▏         | 3/200 [01:25<1:32:28, 28.17s/it][A
  _warn_prf(average, modifier, msg_start, len(result))

Training 3 eposch...:   2%|▏         | 4/200 [01:53<1:32:34, 28.34s/it][A
  _warn_prf(average, modifier, msg_start, len(result))

Training 4 eposch...:   2%|▎         | 5/200 [02:21<1:31:22, 28.11s/it][A
  _warn_prf(average, modifier, msg_start, len(result))

Training 5 eposch...:   3%|▎         | 6/200 [02:49<1:30:20, 27.94s/it][A
  _warn_prf(average, modifier, msg_start, len(result))

Training 6 eposch...:   4%|▎         | 7/200 [03:16<1:29:36, 27.86s/it][A
  _warn_prf(average, modifier,

## Визуализация Эмбедингов

In [23]:
!pip install umap-learn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting umap-learn
  Downloading umap-learn-0.5.3.tar.gz (88 kB)
[K     |████████████████████████████████| 88 kB 7.3 MB/s 
Collecting pynndescent>=0.5
  Downloading pynndescent-0.5.7.tar.gz (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 58.8 MB/s 
Building wheels for collected packages: umap-learn, pynndescent
  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Created wheel for umap-learn: filename=umap_learn-0.5.3-py3-none-any.whl size=82829 sha256=2bdc6a07aa8fed67be230dfd2b553ed3d0e7b16c8a57fff78f5d634974d7b4dc
  Stored in directory: /root/.cache/pip/wheels/b3/52/a5/1fd9e3e76a7ab34f134c07469cd6f16e27ef3a37aeff1fe821
  Building wheel for pynndescent (setup.py) ... [?25l[?25hdone
  Created wheel for pynndescent: filename=pynndescent-0.5.7-py3-none-any.whl size=54286 sha256=48a7d3f33fc46cd9f33d0737d92b18c454af03e60f7d6733c3e04c3c2386bd09
  Stored in directo

In [24]:
import seaborn as sns
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap

def plot_embeddings(X, y, clustering_alg, alg_name):
    classes = len(num_to_class)
    est = clustering_alg(n_components=2)

    X_transform = est.fit_transform(X)
    # print(X_transform.shape)
    # print(y.shape)
    sns.set(rc={'figure.figsize':(11.7,8.27)})
    df = pd.DataFrame(zip(X_transform[:, 0], X_transform[:, 1], 
                          list(map(lambda t: num_to_class[str(t)], y))), 
                      columns=['X_0', 'X_1', 'y'])
    df['ind'] = df.index
    # print(df)
    # print(df.dtypes)
    sc = sns.scatterplot(data=df,
                    x='X_0',
                    y='X_1',
                    hue='y',
                    palette=sns.color_palette("husl", classes),
                    legend="brief")
    sc.set(title=alg_name)


In [None]:
model.eval()

for X_train, y_train in bef_train_data_loader:
    X_pred = model(X_train.cuda()).cpu().detach().numpy()
    plot_embeddings(X_pred, y_train.numpy(), TSNE, 'TSNE')
    plt.show()
    plot_embeddings(X_pred, y_train.numpy(), PCA, 'PCA')
    plt.show()
    plot_embeddings(X_pred, y_train.numpy(), umap.UMAP, 'UMAP')