In [1]:
import augly.image as imaugs
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoModel, AutoFeatureExtractor
from utils.disc21 import DISC21Definition, DISC21

In [2]:
model_ckpt = "google/vit-large-patch16-224"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

Some weights of the model checkpoint at google/vit-large-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
transformation_chain = transforms.Compose(
    [
        # We first resize the input image to 256x256, and then we take center crop.
        transforms.Resize(int((256 / 224) * 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)

augmentation_chain = transforms.Compose(
    [
        imaugs.Brightness(factor=2.0),
        imaugs.RandomRotation(),
        imaugs.OneOf([
            imaugs.RandomAspectRatio(),
            imaugs.RandomBlur(),
            imaugs.RandomBrightness(),
            imaugs.RandomNoise(),
            imaugs.RandomPixelization(),
        ]),
        imaugs.OneOf([
            imaugs.OverlayEmoji(),
            imaugs.OverlayStripes(),
            imaugs.OverlayText(),
        ], p=0.5),
        # We first resize the input image to 256x256, and then we take center crop.
        transforms.Resize(int((256 / 224) * 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)

In [7]:
train_df = DISC21Definition('/media/augustinas/T7/DISC2021/SmallData/images/')
train_ds = DISC21(train_df, subset='train', transform=transformation_chain, augmentations=augmentation_chain)

DISC21Definition dataset loaded
  subset   | # ids | # images
  ---------------------------
  train    | 100000 |   100000
  gallery  | 100000 |   100000
  query    |  10000 |    10000


In [8]:
batch_size = 4
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

In [9]:
print(torch.cuda.is_available())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

True
cuda:0


In [10]:
model.to(device)

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=1024, out_features=4096, bias=True)
          (intermediate_act_fn): GELUActivatio

In [12]:
epoch_count = 1  #for now
lr = 1e-5  # could use a scheduler
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_func = torch.nn.TripletMarginLoss()

In [13]:
model.train()

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=1024, out_features=4096, bias=True)
          (intermediate_act_fn): GELUActivatio

In [14]:
anchor_img, positive_img, index, anchor_label = next(iter(train_loader))

In [24]:
pos_negatives = train_ds.get_negatives(index.numpy())

In [25]:
anchor_img = anchor_img.to(device)
positive_img = positive_img.to(device)
negative_img = pos_negatives.to(device)

In [26]:
anchor_out = model(anchor_img).last_hidden_state
positive_out = model(positive_img).last_hidden_state
negative_out = model(negative_img).last_hidden_state

In [27]:
anchor_out.shape, positive_out.shape, negative_out.shape

(torch.Size([4, 197, 1024]),
 torch.Size([4, 197, 1024]),
 torch.Size([2, 197, 1024]))

In [28]:
with torch.no_grad():
    pdist = torch.nn.PairwiseDistance(p=2)
    pos_matrix = pdist(torch.flatten(anchor_out, start_dim=1), torch.flatten(positive_out, start_dim=1))
    print(pos_matrix)

tensor([475.8610, 478.4731, 486.5186, 478.5147], device='cuda:0')


In [29]:
with torch.no_grad():
    neg_matrix = torch.cdist(torch.flatten(anchor_out, start_dim=1), torch.flatten(negative_out, start_dim=1))

In [42]:
negative_out = negative_out[torch.argmin(neg_matrix, dim=1)]

In [43]:
anchor_out.shape, positive_out.shape, negative_out.shape

(torch.Size([4, 197, 1024]),
 torch.Size([4, 197, 1024]),
 torch.Size([4, 197, 1024]))

In [44]:
loss = loss_func(anchor_out, positive_out, negative_out)

In [45]:
loss

tensor(4.3611, device='cuda:0', grad_fn=<MeanBackward0>)

In [46]:
loss.backward()
optimizer.step()
optimizer.zero_grad()

In [47]:
loss.cpu().detach().numpy()

array(4.3611217, dtype=float32)