In [1]:
import augly.image as imaugs
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoModel, AutoFeatureExtractor
from utils.disc21 import DISC21Definition

In [2]:
# Checkpoint name shared by the feature extractor and the backbone.
model_ckpt = "google/vit-large-patch16-224"
# The extractor's `image_mean` / `image_std` are reused by the torchvision
# preprocessing chains defined further down.
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
# Per the load log below, AutoModel resolves to ViTModel for this checkpoint:
# the classifier weights are not used and the pooler weights are newly
# initialized (i.e. not pretrained).
model = AutoModel.from_pretrained(model_ckpt)

Some weights of the model checkpoint at google/vit-large-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Target sizes for the deterministic preprocessing.
RESIZE_SIZE = 256  # length of the shorter image side after resizing
CROP_SIZE = 224  # side of the square center crop

# Preprocessing shared by both chains, defined once so the two cannot drift.
# `Resize` with a single int scales the *shorter* side to that length while
# keeping the aspect ratio (it does not force a 256x256 image); the center
# crop then produces the final CROP_SIZE x CROP_SIZE input.
preprocessing_steps = [
    transforms.Resize(RESIZE_SIZE),
    transforms.CenterCrop(CROP_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
]

# Plain preprocessing only (no augmentation).
transformation_chain = transforms.Compose(preprocessing_steps)

# Augmentations run on the image first (before tensor conversion), followed by
# exactly the same preprocessing as `transformation_chain`.
augmentation_chain = transforms.Compose(
    [
        imaugs.Brightness(factor=2.0),
        imaugs.RandomRotation(),
        # One photometric / geometric distortion picked from this group.
        imaugs.OneOf([
            imaugs.RandomAspectRatio(),
            imaugs.RandomBlur(),
            imaugs.RandomBrightness(),
            imaugs.RandomNoise(),
            imaugs.RandomPixelization(),
        ]),
        # One overlay picked from this group; the group is configured with p=0.5.
        imaugs.OneOf([
            imaugs.OverlayEmoji(),
            imaugs.OverlayStripes(),
            imaugs.OverlayText(),
        ], p=0.5),
        *preprocessing_steps,
    ]
)

In [4]:
from torch.utils.data import Dataset
import os.path as osp
from glob import glob
from PIL import Image
import random

In [9]:
class DISC21(Dataset):
    """Dataset over one subset of a DISC21 definition.

    In training mode each item is an (anchor, positive, negatives, name)
    tuple: the positive is an augmented view of the anchor image and the
    negatives are other, randomly drawn images of the same subset. In
    gallery/query mode each item is an (image, name) pair.

    Args:
        df: Object exposing ``train``, ``gallery`` and ``query`` sequences of
            ``(full_path, name)`` pairs.
        subset: ``'train'`` or ``'gallery'``; any other value selects the
            query images.
        transform: Callable applied to the anchor image (and to gallery/query
            images). ``None`` leaves images untouched.
        augmentations: Callable applied to the positive and negative images in
            training mode. Defaults to ``transform`` when not given.
        num_negatives: Number of negative images drawn per training item.

    Raises:
        ValueError: If ``num_negatives`` is negative, or (on item access) if
            the subset does not contain enough other images to draw from.
    """

    def __init__(self, df, subset='train', transform=None, augmentations=None, num_negatives=3):
        if num_negatives < 0:
            raise ValueError("num_negatives must be non-negative")

        self.is_train = subset == 'train'
        self.is_gallery = subset == 'gallery'
        self.transform = transform
        self.augmentations = transform if augmentations is None else augmentations
        self.num_negatives = num_negatives

        if self.is_train:
            self.images = df.train
        elif self.is_gallery:
            self.images = df.gallery
        else:
            self.images = df.query

    def __len__(self):
        """Return the number of images in the selected subset."""
        return len(self.images)

    @staticmethod
    def _load_image(path):
        """Open ``path`` and return a fully loaded RGB copy of the image.

        Converting to RGB guarantees three channels for every sample, and the
        ``with`` block releases the file handle right away instead of leaving
        it open until garbage collection.
        """
        with Image.open(path) as img:
            return img.convert("RGB")

    def _sample_negative_indexes(self, index):
        """Draw ``num_negatives`` distinct indexes, none equal to ``index``."""
        total = len(self.images)
        if index < 0:
            # Normalize Python-style negative indexes so the anchor is still excluded.
            index += total

        available = total - 1
        if self.num_negatives > available:
            raise ValueError(
                f"Cannot sample {self.num_negatives} negatives from "
                f"{max(available, 0)} candidate images"
            )

        # Sample positions in a virtual list that skips `index`, then shift the
        # ones at or past the anchor by one. This avoids materializing a list
        # of every other index on each access.
        picks = random.sample(range(available), self.num_negatives)
        return [i if i < index else i + 1 for i in picks]

    def __getitem__(self, index):
        """Return the training tuple or the (image, name) pair for ``index``."""
        full_name, name = self.images[index]
        anchor_img = self._load_image(full_name)

        if not self.is_train:
            if self.transform is not None:
                anchor_img = self.transform(anchor_img)
            return anchor_img, name

        # The positive starts as the very same image; it only differs from the
        # anchor through the augmentations applied below.
        positive_img = anchor_img
        negative_imgs = [
            self._load_image(self.images[i][0])
            for i in self._sample_negative_indexes(index)
        ]

        if self.transform is not None:
            anchor_img = self.transform(anchor_img)
        # Applied independently of `transform`, so augmentations supplied
        # without a transform are no longer silently skipped.
        if self.augmentations is not None:
            positive_img = self.augmentations(positive_img)
            negative_imgs = [self.augmentations(img) for img in negative_imgs]

        return anchor_img, positive_img, negative_imgs, name

In [10]:
# Dataset definition covering the train / gallery / query subsets (see the
# summary printed below).
# NOTE(review): absolute path to an external drive — machine-specific; consider
# reading it from a configurable data root.
train_df = DISC21Definition('/media/augustinas/T7/DISC2021/SmallData/images/')
# Training subset: anchors go through `transformation_chain`, positives and
# negatives through `augmentation_chain`.
train_ds = DISC21(train_df, subset='train', transform=transformation_chain, augmentations=augmentation_chain)

DISC21Definition dataset loaded
  subset   | # ids | # images
  ---------------------------
  train    | 100000 |   100000
  gallery  | 100000 |   100000
  query    |  10000 |    10000


In [11]:
# A single anchor per batch keeps this mining experiment small.
batch_size = 1
# NOTE(review): no random seed is set in the cells shown, so the shuffled order
# (and the dataset's negative sampling) will differ between runs.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

In [12]:
# Pull one batch to experiment with; `negative_imgs` is iterated element by
# element in a later cell.
anchor_img, positive_img, negative_imgs, name = next(iter(train_loader))

In [57]:
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

In [83]:
# Triplet miner configured for "semihard" triplets with a margin of 0.2; it is
# called later with the stacked embeddings and their labels.
mining_func = miners.TripletMarginMiner(
    margin=0.2, type_of_triplets="semihard"
)

In [84]:
# Forward the anchor and positive batches through the backbone and keep the
# `last_hidden_state` output (printed below with shape [1, 197, 1024]).
anchor_out = model(anchor_img).last_hidden_state
positive_out = model(positive_img).last_hidden_state

In [61]:
# Run every negative batch through the backbone and concatenate the results
# along the batch axis. Indexing the list with `[0]` kept only the first
# negative's output and silently discarded the others.
negative_out = torch.cat(
    [model(negative_img).last_hidden_state for negative_img in negative_imgs]
)

In [62]:
# Sanity-check the shape of the anchor output.
print(anchor_out.shape)

torch.Size([1, 197, 1024])


In [63]:
# Collapse everything after the batch axis so the anchor becomes one vector per image.
af = anchor_out.flatten(start_dim=1)
print(af.shape)

torch.Size([1, 201728])


In [64]:
# Sanity-check the shape of the positive output.
print(positive_out.shape)

torch.Size([1, 197, 1024])


In [65]:
# Collapse everything after the batch axis so the positive becomes one vector per image.
pf = positive_out.flatten(start_dim=1)
print(pf.shape)

torch.Size([1, 201728])


In [66]:
# Sanity-check the shape of the negative outputs.
print(negative_out.shape)

torch.Size([3, 197, 1024])


In [67]:
# Collapse everything after the first axis so each negative becomes one vector.
nf = negative_out.flatten(start_dim=1)
print(nf.shape)

torch.Size([3, 201728])


In [68]:
# Stack the anchor, positive and negative vectors row-wise into one embedding matrix.
embeddings = torch.cat((af, pf, nf), dim=0)

In [69]:
# Sanity-check the shape of the stacked embedding matrix.
print(embeddings.shape)

torch.Size([5, 201728])


In [79]:
# Anchor and positive share label 0; every row of `negative_out` gets its own
# label (1, 2, ...), matching the row order used when stacking the embeddings.
labels = torch.cat(
    (
        torch.tensor([0, 0]),
        torch.tensor(list(range(1, negative_out.shape[0] + 1))),
    )
)

In [80]:
# Inspect the label vector that accompanies the stacked embeddings.
print(labels)

tensor([0, 0, 1, 2, 3])


In [81]:
# Mine triplets from the stacked embeddings using the labels built above.
indices_tuple = mining_func(embeddings, labels)

In [82]:
# Inspect the miner's result. In the recorded run all three index tensors are
# empty, i.e. no triplet was selected from this batch.
print(indices_tuple)

(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
