In [1]:
import augly.image as imaugs
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoModel, AutoFeatureExtractor
from utils.disc21 import DISC21Definition

In [2]:
# Checkpoint used for both the preprocessing statistics and the backbone.
model_ckpt = "google/vit-large-patch16-224"
# The extractor supplies image_mean / image_std for the normalisation step
# in the transform chains defined later in the notebook.
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
# AutoModel resolves to ViTModel here (see the load warning below): the
# classifier head of the checkpoint is dropped and the pooler weights are
# newly initialised, so only `last_hidden_state` carries pretrained features.
model = AutoModel.from_pretrained(model_ckpt)

Some weights of the model checkpoint at google/vit-large-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Side length of the square crop fed to the model.
CROP_SIZE = 224
# Resize target for the shorter image edge, using the usual 256/224 ratio.
RESIZE_SIZE = int((256 / 224) * CROP_SIZE)


def _preprocessing_steps():
    """Return fresh resize -> center-crop -> tensor -> normalise transforms.

    Both chains below end with exactly these steps; building them in one
    place keeps the anchor and the augmented views preprocessed identically.
    """
    return [
        # An int size rescales the *shorter* edge to RESIZE_SIZE and keeps the
        # aspect ratio (the image is not forced to a square); the center crop
        # then produces the final CROP_SIZE x CROP_SIZE input.
        transforms.Resize(RESIZE_SIZE),
        transforms.CenterCrop(CROP_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]


# Deterministic preprocessing used for anchors and for gallery/query images.
transformation_chain = transforms.Compose(_preprocessing_steps())

# Random AugLy edits followed by the same preprocessing as above; used to
# build the positive (and negative) views during training.
augmentation_chain = transforms.Compose(
    [
        imaugs.Brightness(factor=2.0),
        imaugs.RandomRotation(),
        # Exactly one photometric / geometric distortion per image.
        imaugs.OneOf([
            imaugs.RandomAspectRatio(),
            imaugs.RandomBlur(),
            imaugs.RandomBrightness(),
            imaugs.RandomNoise(),
            imaugs.RandomPixelization(),
        ]),
        # One overlay, applied to half of the images.
        imaugs.OneOf([
            imaugs.OverlayEmoji(),
            imaugs.OverlayStripes(),
            imaugs.OverlayText(),
        ], p=0.5),
    ]
    + _preprocessing_steps()
)

In [4]:
from torch.utils.data import Dataset
import os.path as osp
from glob import glob
from PIL import Image
import random


class DISC21(Dataset):
    """DISC21 images served as training pairs or as single gallery/query images.

    Args:
        df: Dataset definition exposing ``train``, ``gallery`` and ``query``
            sequences of ``(full_path, name)`` pairs.
        subset: ``'train'`` or ``'gallery'``; any other value selects the
            query split.
        transform: Callable applied to the anchor (and to gallery/query
            images). ``None`` leaves the loaded PIL image untouched.
        augmentations: Callable applied to positives and negatives. Falls back
            to ``transform`` when ``None``.

    Items:
        * train: ``(anchor, positive, index, name)``
        * gallery / query: ``(image, name)``
    """

    def __init__(self, df, subset='train', transform=None, augmentations=None):
        self.is_train = subset == 'train'
        self.is_gallery = subset == 'gallery'
        self.transform = transform
        # Without dedicated augmentations the positive/negative views are
        # produced by the plain transform, so they still come out as tensors.
        self.augmentations = transform if augmentations is None else augmentations

        if self.is_train:
            self.images = df.train
        elif self.is_gallery:
            self.images = df.gallery
        else:
            self.images = df.query

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _load_image(path):
        """Load ``path`` as an RGB PIL image and release the file handle.

        ``Image.open`` is lazy and keeps the file open; converting inside the
        context manager forces the pixel data to be read before the file is
        closed. Converting to RGB keeps grayscale/CMYK/RGBA files compatible
        with a 3-channel normalisation step.
        """
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index: int):
        full_name, name = self.images[index]
        anchor_img = self._load_image(full_name)

        if not self.is_train:
            if self.transform is not None:
                anchor_img = self.transform(anchor_img)
            return anchor_img, name

        # The positive view starts from the very same source image.
        positive_img = anchor_img
        if self.transform is not None:
            anchor_img = self.transform(anchor_img)
        # Gate on the augmentations themselves (not on `transform`) so a
        # dataset built with augmentations only still augments its positives.
        if self.augmentations is not None:
            positive_img = self.augmentations(positive_img)

        return anchor_img, positive_img, index, name

    def get_negatives(self, positive_indexes, num_negatives: int = 2):
        """Sample augmented images whose indexes are not in ``positive_indexes``.

        Args:
            positive_indexes: 1-D iterable of dataset indexes to exclude
                (list, NumPy array or tensor of integers).
            num_negatives: Number of distinct negatives to draw.

        Returns:
            Tensor of shape ``(num_negatives, ...)`` stacking the augmented
            negatives; the augmentations must therefore return tensors.

        Raises:
            ValueError: If ``num_negatives`` is not between 1 and the number
                of available (non-excluded) images.
        """
        # A set makes each membership test O(1); the plain ints also make
        # NumPy / tensor scalars comparable with the range() indexes.
        excluded = {int(i) for i in positive_indexes}
        candidates = [i for i in range(len(self)) if i not in excluded]

        # Fail before touching the disk when the request cannot be satisfied.
        if not 0 < num_negatives <= len(candidates):
            raise ValueError(
                f'Cannot sample {num_negatives} negatives from '
                f'{len(candidates)} candidate images'
            )

        negative_imgs = []
        for i in random.sample(candidates, num_negatives):
            full_name, _ = self.images[i]
            negative_img = self._load_image(full_name)
            if self.augmentations is not None:
                negative_img = self.augmentations(negative_img)
            negative_imgs.append(negative_img)

        return torch.stack(negative_imgs)

In [5]:
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
DISC21_ROOT = '/media/augustinas/T7/DISC2021/SmallData/images/'

# Despite its name, `train_df` is the dataset definition object holding all
# three splits; the training view is selected through `subset`.
train_df = DISC21Definition(DISC21_ROOT)
train_ds = DISC21(
    train_df,
    subset='train',
    transform=transformation_chain,
    augmentations=augmentation_chain,
)

DISC21Definition dataset loaded
  subset   | # ids | # images
  ---------------------------
  train    | 100000 |   100000
  gallery  | 100000 |   100000
  query    |  10000 |    10000


In [6]:
# Small batch for interactive inspection of the pipeline.
batch_size = 4
# shuffle=False keeps the first batch at dataset indexes 0..3, which the
# exploration cells below rely on for reproducible output.
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [7]:
# Pull the first batch only: anchors, augmented positives, dataset indexes
# and file names, in the order returned by DISC21.__getitem__ for training.
anchor_img, positive_img, index, name = next(iter(train_loader))

In [8]:
# Sanity check: both image batches should be (batch_size, 3, 224, 224).
anchor_img.shape, positive_img.shape, index, name

(torch.Size([4, 3, 224, 224]),
 torch.Size([4, 3, 224, 224]),
 tensor([0, 1, 2, 3]),
 ('T000001.jpg', 'T000004.jpg', 'T000007.jpg', 'T000013.jpg'))

In [9]:
# The collated indexes arrive as a tensor; this is the NumPy form that is
# handed to get_negatives below.
print(index.numpy())

[0 1 2 3]


In [10]:
# Sample negatives from outside the current batch (default: 2 images), each
# passed through the augmentation chain.
negatives = train_ds.get_negatives(index.numpy())

In [11]:
# Expect (num_negatives, 3, 224, 224).
negatives.shape

torch.Size([2, 3, 224, 224])

In [12]:
# Encode every batch with the backbone and keep the per-token hidden states
# (the pooler output is not used; its weights are freshly initialised).
# NOTE(review): these forward passes run with autograd enabled while the
# distance cells below use torch.no_grad() -- confirm whether gradients are
# needed here; if not, wrapping these calls in torch.no_grad() saves memory.
anchor_out = model(anchor_img).last_hidden_state
positive_out = model(positive_img).last_hidden_state
negative_out = model(negatives).last_hidden_state

In [13]:
# Each output is (batch, tokens, hidden); the negatives batch has 2 rows.
anchor_out.shape, positive_out.shape, negative_out.shape

(torch.Size([4, 197, 1024]),
 torch.Size([4, 197, 1024]),
 torch.Size([2, 197, 1024]))

In [19]:
# Row-wise L2 distance between each anchor and its own positive, computed on
# the flattened token embeddings: one scalar per batch element.
pdist = torch.nn.PairwiseDistance(p=2)
with torch.no_grad():
    pos_matrix = pdist(torch.flatten(anchor_out, start_dim=1), torch.flatten(positive_out, start_dim=1))
# Kept at top level: an expression nested inside the `with` block is not the
# cell's last top-level statement, so the notebook would never display it.
pos_matrix

In [16]:
# One anchor-positive distance per batch element, i.e. shape (batch_size,).
pos_matrix.shape

torch.Size([4])

In [20]:
# Anchor-positive distances, for comparison with the negative distances below.
print(pos_matrix)

tensor([457.5093, 470.1355, 453.8908, 478.1834])


In [21]:
# All-pairs L2 distances: entry [i, j] compares anchor i with negative j.
with torch.no_grad():
    anchor_flat = anchor_out.flatten(start_dim=1)
    negative_flat = negative_out.flatten(start_dim=1)
    neg_matrix = torch.cdist(anchor_flat, negative_flat)

In [22]:
# Expect (batch_size, num_negatives).
neg_matrix.shape

torch.Size([4, 2])

In [23]:
# Rows are anchors, columns are the sampled negatives.
print(neg_matrix)

tensor([[477.7953, 478.3789],
        [485.2651, 480.8302],
        [483.1764, 478.4398],
        [480.6655, 482.2917]])


In [31]:
# Raw triplet margin term d(anchor, positive) - d(anchor, negative) + margin
# for every (anchor, negative) pair. No clamp at zero is applied here, so
# already-satisfied triplets show up as negative values.
margin = 0.3
loss = pos_matrix.unsqueeze(1) - neg_matrix + margin

In [32]:
# Shape (batch_size, num_negatives); larger values mean harder triplets.
loss

tensor([[-19.9860, -20.5696],
        [-14.8297, -10.3948],
        [-28.9855, -24.2490],
        [ -2.1821,  -3.8082]])

In [35]:
# Per anchor, the largest margin term and the column (negative) producing it,
# i.e. the hardest negative for that anchor.
torch.max(loss, dim=1)

torch.return_types.max(
values=tensor([-19.9860, -10.3948, -24.2490,  -2.1821]),
indices=tensor([0, 1, 1, 0]))

In [37]:
# Same selection seen from the distance side: the closest negative per anchor.
# The indices agree with the torch.max(loss, dim=1) result above.
torch.min(neg_matrix, dim=1)

torch.return_types.min(
values=tensor([477.7953, 480.8302, 478.4398, 480.6655]),
indices=tensor([0, 1, 1, 0]))

In [46]:
# `bluh` holds the full (values, indices) result of the row-wise minimum:
# `.values` are the smallest anchor-negative distances, `.indices` the
# position of the corresponding negative within the negatives batch.
bluh = torch.min(neg_matrix, dim=1)
bluh

torch.return_types.min(
values=tensor([477.7953, 480.8302, 478.4398, 480.6655]),
indices=tensor([0, 1, 1, 0]))

In [43]:
# Gather, for every anchor, the embedding of its closest negative.
# torch.min(..., dim=1) returns a (values, indices) named tuple, so only the
# integer `indices` may be used to index the negatives; indexing with the
# whole result would pass the float `values` as an index and fail.
xx = negative_out[bluh.indices]

In [44]:
# One selected negative embedding per anchor: (batch_size, tokens, hidden).
xx.shape

torch.Size([4, 197, 1024])

In [45]:
# Displays the whole tensor; anchors that share the same closest negative
# show identical slices (e.g. rows 1 and 2 in the output below).
xx

tensor([[[ 0.4705,  0.4278, -1.5170,  ..., -0.5304,  1.4812,  0.6127],
         [ 0.5671,  0.1269,  0.6412,  ..., -0.5248, -0.1459,  0.3352],
         [ 0.6669,  0.9921,  0.2923,  ..., -0.1060, -0.0841,  0.1201],
         ...,
         [ 0.6194,  0.9290, -1.1486,  ..., -0.1816,  0.8863, -0.2848],
         [ 1.4237,  0.7164,  1.3291,  ...,  0.1410,  1.0244, -0.2277],
         [ 0.2695,  0.8971, -1.1627,  ...,  0.8704,  1.1855,  0.1758]],

        [[ 0.1253, -0.1360,  0.1263,  ..., -0.0542,  0.2329,  0.1086],
         [-0.1860, -0.3081,  1.1543,  ...,  0.1272,  0.6510,  1.6323],
         [ 0.4400,  0.8774,  1.0815,  ...,  0.3712,  0.2753,  0.4012],
         ...,
         [ 0.9806,  0.8069, -0.9632,  ..., -1.1620, -0.5529,  0.6621],
         [-0.5336,  0.5281, -1.9177,  ..., -0.5603,  1.0208,  0.8507],
         [ 0.6068,  0.2072,  0.6100,  ..., -1.4635,  0.1104,  0.6014]],

        [[ 0.1253, -0.1360,  0.1263,  ..., -0.0542,  0.2329,  0.1086],
         [-0.1860, -0.3081,  1.1543,  ...,  0

In [52]:
# Row-wise L2 distance between each anchor and the negative selected for it
# in `xx`; this should reproduce the per-row minima of `neg_matrix`.
pdist = torch.nn.PairwiseDistance(p=2)
with torch.no_grad():
    xxf_matrix = pdist(torch.flatten(anchor_out, start_dim=1), torch.flatten(xx, start_dim=1))
# Kept at top level: an expression nested inside the `with` block is not the
# cell's last top-level statement, so the notebook would never display it.
xxf_matrix

In [53]:
# One distance per anchor, i.e. shape (batch_size,).
xxf_matrix.shape

torch.Size([4])

In [54]:
# Values track the row minima of neg_matrix to about 0.01.
# NOTE(review): the small offset presumably comes from numerical differences
# between PairwiseDistance and cdist -- confirm if exact agreement matters.
print(xxf_matrix)

tensor([477.8062, 480.8405, 478.4513, 480.6765])
