In [1]:
from transformers import ViltForQuestionAnswering, ViltConfig
from PIL import Image
import torch

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")


In [2]:
class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, image_files, text, processor, num_labels):
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        text = self.text
        image = self.image_files[idx]['image']
        label = self.image_files[idx]['label']
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")

        # remove batch dimension
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        targets = torch.zeros(self.num_labels)
        targets[label] = 1
        encoding["labels"] = targets

        return encoding

In [3]:
from datasets import load_dataset

cache_dir='./cache'
datasets = load_dataset('Maysee/tiny-imagenet', cache_dir=cache_dir)

Found cached dataset parquet (/home/dongwon/dl-project/./cache/Maysee___parquet/Maysee--tiny-imagenet-35af7c46a941f08e/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
from transformers import ViltProcessor

datasets = load_dataset('Maysee/tiny-imagenet', cache_dir=cache_dir)
label_list = datasets["train"].features["label"].names
num_labels = len(label_list)

config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa", cache_dir=cache_dir)
config.id2label = {str(i): label for i, label in enumerate(label_list)}
config.label2id = {label: i for i, label in enumerate(label_list)}
config.num_labels = num_labels

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

train_dataset = ImageDataset(image_files=datasets["train"], text="", processor=processor, num_labels=num_labels)
test_dataset = ImageDataset(image_files=datasets["valid"], text="", processor=processor, num_labels=num_labels)

Found cached dataset parquet (/home/dongwon/dl-project/./cache/Maysee___parquet/Maysee--tiny-imagenet-35af7c46a941f08e/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

2023-03-27 03:24:56.674358: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-27 03:24:57.433970: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib:/lib:/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/

In [5]:
train_dataset[0].keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])

In [6]:
processor.decode(train_dataset[0]['input_ids'])

'[CLS] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [7]:
label = torch.nonzero(train_dataset[0]['labels']).squeeze().tolist()
config.id2label[label]

'n01443537'

In [8]:
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-mlm", config=config)
model.to(device)

Some weights of the model checkpoint at dandelin/vilt-b32-mlm were not used when initializing ViltForQuestionAnswering: ['mlm_score.transform.LayerNorm.bias', 'mlm_score.transform.dense.bias', 'mlm_score.bias', 'mlm_score.transform.dense.weight', 'mlm_score.decoder.weight', 'mlm_score.transform.LayerNorm.weight']
- This IS expected if you are initializing ViltForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViltForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViltForQuestionAnswering were not initialized from the model checkpoint at dandelin/vilt-b32-mlm and are newly initialized: ['classifier.3.weight', 'classifier.0.bias', 'classifier.0.wei

ViltForQuestionAnswering(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0): ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_f

In [9]:
from torch.utils.data import DataLoader

def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  pixel_values = [item['pixel_values'] for item in batch]
  attention_mask = [item['attention_mask'] for item in batch]
  token_type_ids = [item['token_type_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  # create padded pixel values and corresponding pixel mask
  encoding = processor.feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")

  # create new batch
  batch = {}
  batch['input_ids'] = torch.stack(input_ids)
  batch['attention_mask'] = torch.stack(attention_mask)
  batch['token_type_ids'] = torch.stack(token_type_ids)
  batch['pixel_values'] = encoding['pixel_values']
  batch['pixel_mask'] = encoding['pixel_mask']
  batch['labels'] = torch.stack(labels)

  return batch

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True)
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True)

In [10]:
from tqdm.notebook import tqdm

@torch.no_grad()
def evaluate(model, device, test_dataloader):
    losses = []  # List of scalar tensors
    correct = 0
    total = 0
    for batch in tqdm(test_dataloader):
        # adapt batch to model
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)
        target = torch.argmax(batch['labels'], dim=1)
        correct += torch.sum(preds==target).item()
        total += target.size(0)
        losses.append(outputs.loss)
    stacked_losses = torch.stack(losses)  # (num_batches, ) 
    total_avg_loss = stacked_losses.mean()  # (num test examples, ) -> scalar
    total_avg_acc = (100 * correct / total) / len(test_dataloader)

    print("Correct: " + str(correct), "Total: " + str(total))
    print("Average test loss: " + str(total_avg_loss.item()))
    print("Average val acc: " + str(total_avg_acc))

    return total_avg_loss.item(), total_avg_acc

In [11]:
import os

for param in model.parameters():
    param.requires_grad = False

for name, param in model.named_parameters():
    if 'classifier' in name or 'pooler' in name:
        param.requires_grad = True

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

checkpoint_dir = './checkpoints'
best_test_acc = -1
step = -1
model.train()
for epoch in range(1):  # loop over the dataset multiple times
   print(f"Epoch: {epoch}")
   for batch in tqdm(train_dataloader):
        step += 1
        # get the inputs; 
        batch = {k:v.to(device) for k,v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()

        # Evaluate
        if step != 0 and step % 20 == 0:
            model.eval()
            new_test_loss, new_test_acc = evaluate(model, device, test_dataloader)

            # save checkpoint with best test loss
            if new_test_acc > best_test_acc or best_test_acc < 0:
                if best_test_acc > 0:
                    os.remove(checkpoint_dir + '/'+ best_checkpoint_filename)
                best_checkpoint_filename = "best_model" + str(step) +".pt"
                torch.save(model.state_dict(), checkpoint_dir + '/' + best_checkpoint_filename)
                best_test_acc = new_test_acc

            model.train()
        

Epoch: 0


  0%|          | 0/196 [00:00<?, ?it/s]



Loss: 147.2014923095703
Loss: 139.98004150390625
Loss: 133.02212524414062
Loss: 126.48741149902344
Loss: 120.21019744873047
Loss: 114.32279205322266
Loss: 108.85628509521484
Loss: 103.52584075927734
Loss: 98.59742736816406
Loss: 93.79312133789062
Loss: 89.43826293945312
Loss: 84.97572326660156
Loss: 80.98971557617188
Loss: 77.19768524169922
Loss: 73.51335906982422
Loss: 69.97216033935547
Loss: 66.87934112548828
Loss: 63.72492980957031
Loss: 60.83019256591797
Loss: 58.08111572265625
Loss: 55.35917282104492


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 59 Total: 10000
Average test loss: 52.93903732299805
Average val acc: 0.0295
Loss: 53.00183868408203
Loss: 50.65126419067383
Loss: 48.451698303222656
Loss: 46.34100341796875
Loss: 44.51832580566406
Loss: 42.55536651611328
Loss: 40.86403274536133
Loss: 39.21149826049805
Loss: 37.70921325683594
Loss: 36.152984619140625
Loss: 34.865142822265625
Loss: 33.618507385253906
Loss: 32.32726287841797
Loss: 31.28366470336914
Loss: 30.14713478088379
Loss: 29.18904685974121
Loss: 28.186481475830078
Loss: 27.207284927368164
Loss: 26.405488967895508
Loss: 25.61708641052246


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 76 Total: 10000
Average test loss: 24.795358657836914
Average val acc: 0.038
Loss: 24.788639068603516
Loss: 24.038055419921875
Loss: 23.361705780029297
Loss: 22.713186264038086
Loss: 22.100696563720703
Loss: 21.47056770324707
Loss: 21.00933074951172
Loss: 20.425806045532227
Loss: 19.95071029663086
Loss: 19.37655258178711
Loss: 18.98367691040039
Loss: 18.53857421875
Loss: 18.109874725341797
Loss: 17.779251098632812
Loss: 17.323556900024414
Loss: 16.935157775878906
Loss: 16.63144874572754
Loss: 16.31053352355957
Loss: 15.967021942138672
Loss: 15.695990562438965


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 56 Total: 10000
Average test loss: 15.385418891906738
Average val acc: 0.028000000000000004
Loss: 15.422800064086914
Loss: 15.146252632141113
Loss: 14.873952865600586
Loss: 14.59813117980957
Loss: 14.312643051147461
Loss: 14.126028060913086
Loss: 13.892168998718262
Loss: 13.663530349731445
Loss: 13.448955535888672
Loss: 13.231642723083496
Loss: 13.101935386657715
Loss: 12.88783073425293
Loss: 12.722521781921387
Loss: 12.516234397888184
Loss: 12.371763229370117
Loss: 12.191215515136719
Loss: 12.065913200378418
Loss: 11.898890495300293
Loss: 11.78935432434082
Loss: 11.622722625732422


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 51 Total: 10000
Average test loss: 11.476115226745605
Average val acc: 0.025500000000000002
Loss: 11.49130916595459
Loss: 11.375856399536133
Loss: 11.230683326721191
Loss: 11.127729415893555
Loss: 11.028316497802734
Loss: 10.860284805297852
Loss: 10.759739875793457
Loss: 10.693625450134277
Loss: 10.605430603027344
Loss: 10.440827369689941
Loss: 10.361474990844727
Loss: 10.255295753479004
Loss: 10.174308776855469
Loss: 10.100034713745117
Loss: 9.993803024291992
Loss: 9.904325485229492
Loss: 9.841365814208984
Loss: 9.758225440979004
Loss: 9.662202835083008
Loss: 9.680790901184082


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 44 Total: 10000
Average test loss: 9.526552200317383
Average val acc: 0.022
Loss: 9.521499633789062
Loss: 9.466958045959473
Loss: 9.413127899169922
Loss: 9.317559242248535
Loss: 9.253666877746582
Loss: 9.178423881530762
Loss: 9.114645004272461
Loss: 9.057684898376465
Loss: 9.027060508728027
Loss: 8.955873489379883
Loss: 8.936420440673828
Loss: 8.818607330322266
Loss: 8.810651779174805
Loss: 8.73857593536377
Loss: 8.696014404296875
Loss: 8.637091636657715
Loss: 8.612071990966797
Loss: 8.563379287719727
Loss: 8.526222229003906
Loss: 8.460796356201172


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 51 Total: 10000
Average test loss: 8.422137260437012
Average val acc: 0.025500000000000002
Loss: 8.43097972869873
Loss: 8.400518417358398
Loss: 8.315138816833496
Loss: 8.313461303710938
Loss: 8.262186050415039
Loss: 8.221821784973145
Loss: 8.196314811706543
Loss: 8.177116394042969
Loss: 8.080995559692383
Loss: 8.094027519226074
Loss: 8.047892570495605
Loss: 8.009960174560547
Loss: 7.98164701461792
Loss: 7.962636470794678
Loss: 7.938411712646484
Loss: 7.872116565704346
Loss: 7.864894866943359
Loss: 7.823638439178467
Loss: 7.811345100402832
Loss: 7.7826409339904785


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 59 Total: 10000
Average test loss: 7.744096279144287
Average val acc: 0.0295
Loss: 7.754395961761475
Loss: 7.720660209655762
Loss: 7.685206890106201
Loss: 7.6867499351501465
Loss: 7.642859935760498
Loss: 7.605317115783691
Loss: 7.605081558227539
Loss: 7.564018726348877
Loss: 7.52272891998291
Loss: 7.532293319702148
Loss: 7.496265888214111
Loss: 7.47697639465332
Loss: 7.467223167419434
Loss: 7.415633678436279
Loss: 7.4150800704956055
Loss: 7.414430618286133
Loss: 7.390464782714844
Loss: 7.355417251586914
Loss: 7.3402099609375
Loss: 7.329250335693359


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 57 Total: 10000
Average test loss: 7.302845001220703
Average val acc: 0.028499999999999998
Loss: 7.319746017456055
Loss: 7.266488552093506
Loss: 7.253535747528076
Loss: 7.247821807861328
Loss: 7.216864109039307
Loss: 7.228365421295166
Loss: 7.235970497131348
Loss: 7.181740760803223
Loss: 7.174409866333008
Loss: 7.179937362670898
Loss: 7.147481918334961
Loss: 7.116377353668213
Loss: 7.102487087249756
Loss: 7.09318208694458
Loss: 7.069350719451904
Loss: 7.076513767242432
Loss: 7.044258117675781
Loss: 7.054025650024414
Loss: 7.074855804443359
Loss: 7.025371551513672


  0%|          | 0/20 [00:00<?, ?it/s]

Correct: 63 Total: 10000
Average test loss: 7.005491733551025
Average val acc: 0.0315
Loss: 7.029352188110352
Loss: 6.995093822479248
Loss: 6.969025135040283
Loss: 6.968323707580566
Loss: 6.967830657958984
Loss: 6.9497294425964355
Loss: 6.921489715576172
Loss: 6.920053005218506
Loss: 6.914517402648926
Loss: 6.902112007141113
Loss: 6.896595001220703
Loss: 6.854294300079346
Loss: 6.876458168029785
Loss: 6.846006870269775
Loss: 6.8603386878967285
