In [2]:
from transformers import ViltForQuestionAnswering, ViltConfig
from PIL import Image
import torch

# Pick the compute device. The notebook prefers the second GPU (index 1), but
# only when that ordinal actually exists: on a single-GPU machine "cuda:1" is an
# invalid device, so fall back to GPU 0, and to the CPU when CUDA is unavailable.
_preferred_gpu_index = 1
if torch.cuda.is_available():
    _gpu_index = _preferred_gpu_index if torch.cuda.device_count() > _preferred_gpu_index else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")


In [3]:
class ImageDataset(torch.utils.data.Dataset):
    """Pairs every image of a labelled image collection with one fixed text prompt.

    Args:
        image_files: Indexable collection whose items are mappings with an
            ``'img'`` entry (an image exposing ``.mode`` and ``.convert``) and
            a ``'label'`` entry (integer class index).
        text: Prompt passed to the processor together with every image.
        processor: Callable invoked as
            ``processor(image, text, padding=..., truncation=..., return_tensors="pt")``
            that returns a mapping of tensors with a leading batch axis.
        num_labels: Number of classes; length of the one-hot target vector.
    """

    def __init__(self, image_files, text, processor, num_labels):
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        """Return the number of examples in the wrapped collection."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the processor encoding for example ``idx`` plus a one-hot ``labels`` tensor.

        Raises:
            IndexError: If the example's label is outside ``[-num_labels, num_labels)``.
        """
        # Fetch the record once; indexing the collection separately for the
        # image and for the label would load the same record twice.
        example = self.image_files[idx]
        image = example['img']
        label = example['label']
        if image.mode != "RGB":
            image = image.convert("RGB")

        encoding = self.processor(image, self.text, padding="max_length", truncation=True, return_tensors="pt")

        # Remove only the leading batch axis. A bare ``squeeze()`` would also
        # collapse any other size-1 axis (e.g. a length-1 sequence).
        for key in list(encoding.keys()):
            encoding[key] = encoding[key].squeeze(0)

        # One-hot float target over the label vocabulary.
        targets = torch.zeros(self.num_labels)
        targets[label] = 1
        encoding["labels"] = targets

        return encoding

In [4]:
from datasets import load_dataset

# Directory (relative to the notebook) reused as the cache for dataset and config files.
cache_dir='./cache'
# Alternative dataset kept for reference; swap it in for the CIFAR-10 line to use it:
#datasets = load_dataset('Maysee/tiny-imagenet', cache_dir=cache_dir)
# CIFAR-10; later cells read its "train" and "test" splits.
datasets = load_dataset('cifar10', cache_dir=cache_dir)

Found cached dataset cifar10 (/home/dongwon/adapt-VL-models-to-vision-only-tasks/./cache/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
from transformers import ViltProcessor

# Class names, in label-index order, taken from the training split's label feature.
label_list = datasets["train"].features["label"].names
num_labels = len(label_list)

# Reuse the VQA checkpoint's configuration, but swap in this dataset's label vocabulary.
# Ids are stored as strings here; the lookup cell further down indexes with str(label).
config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa", cache_dir=cache_dir)
config.id2label = dict(zip(map(str, range(num_labels)), label_list))
config.label2id = dict(zip(label_list, map(str, range(num_labels))))
config.num_labels = num_labels

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Both splits share the same fixed prompt, processor and label count.
train_dataset, test_dataset = (
    ImageDataset(
        image_files=datasets[split],
        text="What is this image?",
        processor=processor,
        num_labels=num_labels,
    )
    for split in ("train", "test")
)

2023-03-27 06:03:39.958038: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-27 06:03:40.719038: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib:/lib:/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/

In [5]:
# Sanity check: list the fields produced for a single training example.
train_dataset[0].keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])

In [6]:
# Sanity check: decode the token ids back to text to see the prompt and its padding.
processor.decode(train_dataset[0]['input_ids'])

'[CLS] what is this image? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [7]:
# Recover the class index from the one-hot target (position of its single non-zero entry)...
label = torch.nonzero(train_dataset[0]['labels']).squeeze().tolist()
# ...and map it to a class name; config.id2label is keyed by the stringified index.
config.id2label[str(label)]

'airplane'

In [8]:
# Load the "dandelin/vilt-b32-mlm" weights into a question-answering model built from `config`.
# Per the loader's warning in this cell's output, the checkpoint's mlm_score.* weights are
# unused and the classifier.* weights are newly initialized.
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-mlm", config=config)
# Move the model to the selected device; as the cell's last expression this also displays it.
model.to(device)

Some weights of the model checkpoint at dandelin/vilt-b32-mlm were not used when initializing ViltForQuestionAnswering: ['mlm_score.decoder.weight', 'mlm_score.transform.LayerNorm.bias', 'mlm_score.bias', 'mlm_score.transform.dense.weight', 'mlm_score.transform.dense.bias', 'mlm_score.transform.LayerNorm.weight']
- This IS expected if you are initializing ViltForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViltForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViltForQuestionAnswering were not initialized from the model checkpoint at dandelin/vilt-b32-mlm and are newly initialized: ['classifier.1.weight', 'classifier.0.bias', 'classifier.1.bia

ViltForQuestionAnswering(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0): ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_f

In [9]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Merge a list of per-example encodings into one batch dictionary.

    Text fields and labels are stacked along a new leading axis. The image
    tensors are handed to the processor's feature extractor, which returns the
    batched ``pixel_values`` together with a matching ``pixel_mask``.

    Args:
        batch: List of mappings, each with ``input_ids``, ``attention_mask``,
            ``token_type_ids``, ``pixel_values`` and ``labels`` entries.

    Returns:
        dict with keys ``input_ids``, ``attention_mask``, ``token_type_ids``,
        ``pixel_values``, ``pixel_mask`` and ``labels``.
    """
    def _stack(field):
        # Stack one field across all examples of the batch.
        return torch.stack([example[field] for example in batch])

    # Pad the images and build the corresponding pixel mask.
    images = [example['pixel_values'] for example in batch]
    padded = processor.feature_extractor.pad_and_create_pixel_mask(images, return_tensors="pt")

    return {
        'input_ids': _stack('input_ids'),
        'attention_mask': _stack('attention_mask'),
        'token_type_ids': _stack('token_type_ids'),
        'pixel_values': padded['pixel_values'],
        'pixel_mask': padded['pixel_mask'],
        'labels': _stack('labels'),
    }

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps batch composition
# (and therefore any per-batch statistics) identical from one run to the next.
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=512, shuffle=False)

In [10]:
from tqdm.notebook import tqdm

@torch.no_grad()
def evaluate(model, device, test_dataloader):
    """Run the model over a dataloader and report average loss and top-1 accuracy.

    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards.

    Args:
        model: Module called as ``model(**batch)`` whose output exposes
            ``.logits`` and ``.loss``.
        device: Device every batch tensor is moved to before the forward pass.
        test_dataloader: Iterable of dict batches of tensors; each batch must
            contain a one-hot ``'labels'`` entry.

    Returns:
        Tuple ``(average_loss, accuracy_percent)`` of Python floats.

    Raises:
        ValueError: If the dataloader yields no examples.
    """
    was_training = model.training
    model.eval()

    loss_sum = 0.0  # sum over batches of (batch loss * batch size)
    correct = 0
    total = 0
    try:
        for batch in tqdm(test_dataloader):
            # adapt batch to model
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            preds = torch.argmax(outputs.logits, dim=1)
            target = torch.argmax(batch['labels'], dim=1)
            batch_size = target.size(0)
            correct += torch.sum(preds == target).item()
            total += batch_size
            # Weight each batch loss by its size so a short final batch does not
            # skew the average (assumes outputs.loss is averaged over the batch
            # -- TODO confirm for the model in use).
            loss_sum += outputs.loss.item() * batch_size
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("evaluate() received a dataloader that yields no examples")

    total_avg_loss = loss_sum / total
    total_avg_acc = (100 * correct) / total
    print("Correct: " + str(correct) + "/" + "Total: "  +str(total))
    print("Average val loss: " + str(total_avg_loss))
    print("Average val acc: " + str(total_avg_acc))

    return total_avg_loss, total_avg_acc

In [11]:
import os

# Fine-tune only the task head: every parameter is frozen except those whose
# name contains 'classifier' or 'pooler'.
for name, param in model.named_parameters():
    param.requires_grad = 'classifier' in name or 'pooler' in name

# Hand the optimizer only the parameters that will actually receive gradients.
learning_rate = 5e-5
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

checkpoint_dir = './checkpoints'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

num_epochs = 1    # passes over the training set
max_patience = 3  # evaluations without improvement tolerated before stopping

best_test_acc = -1
best_checkpoint_filename = None  # no checkpoint has been written yet
step = -1
patience = 0
model.train()
for epoch in range(num_epochs):  # loop over the dataset multiple times
    print(f"Train Epoch: {epoch}")
    for batch in tqdm(train_dataloader):
        step += 1
        # get the inputs;
        batch = {k:v.to(device) for k,v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        logits = outputs.logits
        # Batch-level training accuracy, for progress logging only.
        preds = torch.argmax(logits, dim=1)
        target = torch.argmax(batch['labels'], dim=1)
        correct = torch.sum(preds==target).item()
        acc = (correct * 100) / target.size(0)
        print(f"Step {step} - loss: {loss.item()} , train acc: {acc}")
        loss.backward()
        optimizer.step()
    print(f"Evaluate Epoch: {epoch}")
    # Evaluate
    model.eval()
    new_test_loss, new_test_acc = evaluate(model, device, test_dataloader)
    # Keep a single checkpoint: the one with the best accuracy so far.
    # NOTE(review): selection uses the test split; a separate validation split
    # would keep the test set out of model selection.
    if new_test_acc > best_test_acc:
        patience = 0
        # Remove the previous best file, tracked by name rather than by the
        # accuracy value so a first accuracy of exactly 0 is handled too.
        if best_checkpoint_filename is not None:
            os.remove(os.path.join(checkpoint_dir, best_checkpoint_filename))
        best_checkpoint_filename = "best_model" + str(epoch) +".pt"
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, best_checkpoint_filename))
        best_test_acc = new_test_acc
    else:
        patience += 1
        if patience > max_patience:
            print("Early stopping at epoch "+str(epoch))
            break

    model.train()
        

Epoch: 0


  0%|          | 0/98 [00:00<?, ?it/s]



Step 0 - loss: 7.722498893737793 , train acc: 10.3515625
Step 1 - loss: 6.568172454833984 , train acc: 11.9140625
Step 2 - loss: 5.658674240112305 , train acc: 10.15625
Step 3 - loss: 4.917052268981934 , train acc: 9.375
Step 4 - loss: 4.3404059410095215 , train acc: 12.109375
Step 5 - loss: 3.94334077835083 , train acc: 9.5703125
Step 6 - loss: 3.655184268951416 , train acc: 8.7890625
Step 7 - loss: 3.471733570098877 , train acc: 8.203125
Step 8 - loss: 3.333069324493408 , train acc: 10.7421875
Step 9 - loss: 3.2806954383850098 , train acc: 7.6171875
Step 10 - loss: 3.2515311241149902 , train acc: 7.03125
Step 11 - loss: 3.2474637031555176 , train acc: 10.9375
Step 12 - loss: 3.25655198097229 , train acc: 14.0625
Step 13 - loss: 3.2829089164733887 , train acc: 12.890625
Step 14 - loss: 3.2949612140655518 , train acc: 21.875
Step 15 - loss: 3.299503803253174 , train acc: 27.34375
Step 16 - loss: 3.3203606605529785 , train acc: 19.7265625
Step 17 - loss: 3.315800428390503 , train acc: 2

  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation
Correct: 2325/10000
Average test loss: 3.3026680946350098
Average val acc: 23.25
Step 21 - loss: 3.2923758029937744 , train acc: 25.1953125
Step 22 - loss: 3.2844159603118896 , train acc: 22.8515625
Step 23 - loss: 3.275674343109131 , train acc: 23.046875
Step 24 - loss: 3.2568325996398926 , train acc: 23.828125
Step 25 - loss: 3.222835063934326 , train acc: 25.78125
Step 26 - loss: 3.2088444232940674 , train acc: 23.046875
Step 27 - loss: 3.17765474319458 , train acc: 28.7109375
Step 28 - loss: 3.151585578918457 , train acc: 30.6640625
Step 29 - loss: 3.146512985229492 , train acc: 33.203125
Step 30 - loss: 3.1423802375793457 , train acc: 30.2734375
Step 31 - loss: 3.112980365753174 , train acc: 32.03125
Step 32 - loss: 3.1082727909088135 , train acc: 29.4921875
Step 33 - loss: 3.096137523651123 , train acc: 32.03125
Step 34 - loss: 3.069925546646118 , train acc: 33.0078125
Step 35 - loss: 3.0752644538879395 , train acc: 34.5703125
Step 36 - loss: 3.0733325481414795 , train

  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation
Correct: 4105/10000
Average test loss: 3.0527985095977783
Average val acc: 41.05
Step 41 - loss: 3.042055130004883 , train acc: 42.1875
Step 42 - loss: 3.0532565116882324 , train acc: 43.75
Step 43 - loss: 3.035207748413086 , train acc: 48.2421875
Step 44 - loss: 3.0223097801208496 , train acc: 50.1953125
Step 45 - loss: 3.0154385566711426 , train acc: 46.6796875
Step 46 - loss: 2.995030641555786 , train acc: 51.5625
Step 47 - loss: 2.986182928085327 , train acc: 50.1953125
Step 48 - loss: 2.9879543781280518 , train acc: 43.359375
Step 49 - loss: 2.9793858528137207 , train acc: 46.09375
Step 50 - loss: 2.964776039123535 , train acc: 43.9453125
Step 51 - loss: 2.9463791847229004 , train acc: 44.3359375
Step 52 - loss: 2.9468319416046143 , train acc: 43.5546875
Step 53 - loss: 2.915792942047119 , train acc: 47.265625
Step 54 - loss: 2.9348626136779785 , train acc: 42.3828125
Step 55 - loss: 2.921506881713867 , train acc: 42.96875
Step 56 - loss: 2.882458209991455 , train acc: 

  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation
Correct: 5028/10000
Average test loss: 2.8446121215820312
Average val acc: 50.28
Step 61 - loss: 2.8420562744140625 , train acc: 52.734375
Step 62 - loss: 2.8346524238586426 , train acc: 54.1015625
Step 63 - loss: 2.829590320587158 , train acc: 52.34375
Step 64 - loss: 2.800499439239502 , train acc: 57.421875
Step 65 - loss: 2.7799997329711914 , train acc: 57.03125
Step 66 - loss: 2.7839555740356445 , train acc: 53.125
Step 67 - loss: 2.7465381622314453 , train acc: 54.4921875
Step 68 - loss: 2.7432570457458496 , train acc: 53.3203125
Step 69 - loss: 2.7231647968292236 , train acc: 54.6875
Step 70 - loss: 2.7134037017822266 , train acc: 58.0078125
Step 71 - loss: 2.687244415283203 , train acc: 56.4453125
Step 72 - loss: 2.672985553741455 , train acc: 59.5703125
Step 73 - loss: 2.6662657260894775 , train acc: 57.421875
Step 74 - loss: 2.641737699508667 , train acc: 53.90625
Step 75 - loss: 2.6504006385803223 , train acc: 55.46875
Step 76 - loss: 2.6408982276916504 , train acc

  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation
Correct: 6342/10000
Average test loss: 2.52933669090271
Average val acc: 63.42
Step 81 - loss: 2.517951011657715 , train acc: 64.84375
Step 82 - loss: 2.533416509628296 , train acc: 61.328125
Step 83 - loss: 2.510601043701172 , train acc: 62.5
Step 84 - loss: 2.500447988510132 , train acc: 62.6953125
Step 85 - loss: 2.456552028656006 , train acc: 63.8671875
Step 86 - loss: 2.4415698051452637 , train acc: 64.0625
Step 87 - loss: 2.3730521202087402 , train acc: 67.1875
Step 88 - loss: 2.417891025543213 , train acc: 64.0625
Step 89 - loss: 2.4039628505706787 , train acc: 63.28125
Step 90 - loss: 2.3526782989501953 , train acc: 66.2109375
Step 91 - loss: 2.3465821743011475 , train acc: 62.5
Step 92 - loss: 2.374142646789551 , train acc: 60.15625
Step 93 - loss: 2.297637939453125 , train acc: 67.3828125
Step 94 - loss: 2.303889036178589 , train acc: 64.0625
Step 95 - loss: 2.262664794921875 , train acc: 70.703125
Step 96 - loss: 2.227717876434326 , train acc: 71.09375
Step 97 - l