In [1]:
import augly.image as imaugs
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoModel, AutoFeatureExtractor
from utils.disc21 import DISC21Definition, DISC21

In [2]:
model_ckpt = "google/vit-large-patch16-224"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

Some weights of the model checkpoint at google/vit-large-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
transformation_chain = transforms.Compose(
    [
        # We first resize the input image to 256x256, and then we take center crop.
        transforms.Resize(int((256 / 224) * 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)

augmentation_chain = transforms.Compose(
    [
        imaugs.Brightness(factor=2.0),
        imaugs.RandomRotation(),
        imaugs.OneOf([
            imaugs.RandomAspectRatio(),
            imaugs.RandomBlur(),
            imaugs.RandomBrightness(),
            imaugs.RandomNoise(),
            imaugs.RandomPixelization(),
        ]),
        # We first resize the input image to 256x256, and then we take center crop.
        transforms.Resize(int((256 / 224) * 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)

In [4]:
train_df = DISC21Definition('/scratch/lustre/home/auma4493/images/DISC21')
train_ds = DISC21(train_df, subset='train', transform=transformation_chain, augmentations=augmentation_chain)

DISC21Definition dataset loaded
  subset   | # ids | # images
  ---------------------------
  train    | 100000 |   100000
  gallery  | 100000 |   100000
  query    |  10000 |    10000


In [5]:
embedding_dims = 2
batch_size = 16
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

In [6]:
print(torch.cuda.is_available())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

True
cuda:0


In [7]:
model.to(device)

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-23): 24 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=1024, out_features=4096, bias=True)
          (intermediate_act_fn): GELUA

In [8]:
epoch_count = 10  # for now
lr = 1e-5  # could use a scheduler
optimizer = optim.Adam(model.parameters(), lr=lr, )
loss_func = torch.nn.TripletMarginLoss()

In [9]:
model.train()
for epoch in tqdm(range(epoch_count), desc="Epochs"):
    running_loss = []
    for step, (anchor_img, positive_img, negative_img, anchor_label) in enumerate(
            tqdm(train_loader, desc="Training", leave=False)):
        
        anchor_img = anchor_img.to(device)
        positive_img = positive_img.to(device)
        negative_img = negative_img.to(device)

        anchor_out = model(anchor_img).last_hidden_state        
        positive_out = model(positive_img).last_hidden_state
        negative_out = model(negative_img).last_hidden_state
        
        print(anchor_out)

        loss = loss_func(anchor_out, positive_out, negative_out)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss.append(loss.cpu().detach().numpy())
        break
    print("Epoch: {}/{} - Loss: {:.4f}".format(epoch + 1, epoch_count, np.mean(running_loss)))
    torch.save({"model_state_dict": model.state_dict(),
                "optimzier_state_dict": optimizer.state_dict()
                }, f"vit_checkpoints/trained_model_{epoch + 1}_{epoch_count}.pth")

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/6250 [00:00<?, ?it/s]

tensor([[[ 1.1643,  0.2702,  0.3025,  ..., -0.0909,  0.6494,  1.3829],
         [ 0.9080,  0.4086,  0.6673,  ..., -1.0627,  0.5136, -0.0792],
         [ 1.2989, -0.0561,  0.9638,  ..., -0.5098,  0.4882, -0.4838],
         ...,
         [ 0.5335,  0.4966, -0.3425,  ...,  0.2432, -0.8150, -0.4187],
         [ 0.5823,  0.4559, -0.8236,  ...,  0.1264, -1.0074, -0.0441],
         [ 0.6399,  0.7584, -1.0133,  ...,  0.2388, -0.8157, -0.3999]],

        [[-0.5691,  1.4298,  0.6242,  ..., -0.2342,  0.1633,  0.7874],
         [ 0.3702, -0.0602,  0.7627,  ...,  0.1139, -0.6421,  0.2637],
         [ 0.4417, -0.1337,  1.1589,  ..., -0.1497, -0.6205,  0.4328],
         ...,
         [ 1.5070,  0.2695,  0.6197,  ..., -0.1580,  0.3269,  0.6150],
         [ 1.6072, -0.0199, -0.1424,  ...,  0.2055, -0.1020,  0.4769],
         [-0.2224,  0.5415, -0.6262,  ...,  0.2407, -1.3835, -1.3767]],

        [[-0.5470, -1.1812, -0.1814,  ...,  1.6395,  1.4345,  1.4156],
         [-1.3079, -0.1415,  0.9537,  ..., -0

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fbacb6eee50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


Training:   0%|          | 0/6250 [00:00<?, ?it/s]


KeyboardInterrupt

