In [188]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import tqdm
from sklearn.metrics import accuracy_score
import numpy as np
from numpy import dot
from numpy.linalg import norm

from datasets import load_dataset, Image

from transformers import BlipForQuestionAnswering, BlipProcessor, BlipConfig, BlipModel

In [4]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [191]:
class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, config, image_files, text, processor, num_labels):
        self.config = config
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        text = self.text
        image = self.image_files[idx]['img']
        label_id = self.image_files[idx]['fine_label']
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        # encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        encoding = self.processor(image, text, return_tensors="pt")
        label_encoding = self.processor(text=self.config.id2label[str(label_id)], padding="max_length", return_tensors="pt").input_ids
        label_encoding = label_encoding.squeeze()

        # remove batch dimension
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        #targets = torch.zeros(self.num_labels)
        #targets[label_id] = 1
        encoding["labels"] = label_encoding

        return encoding

In [6]:
dataset = load_dataset('cifar100', cache_dir='./cache')

Downloading builder script:   0%|          | 0.00/5.61k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/4.21k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.83k [00:00<?, ?B/s]

Downloading and preparing dataset cifar100/cifar100 to /content/cache/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142...


Downloading data:   0%|          | 0.00/169M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]



Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset cifar100 downloaded and prepared to /content/cache/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [193]:
label_list = dataset["train"].features["fine_label"].names

# replace label 'cra' to 'crab'
label_list.remove('cra')
label_list.append('crab')
num_labels = len(label_list)

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

max_label_token_length = 0
for label in label_list:
    labels = processor(text=label, return_tensors="pt").input_ids
    max_label_token_length = max(max_label_token_length, len(labels[0]))
    print(label, labels)
print(max_label_token_length)

config = BlipConfig.from_pretrained("Salesforce/blip-vqa-base")
config.id2label = {str(i): label for i, label in enumerate(label_list)}
config.label2id = {label: str(i) for i, label in enumerate(label_list)}
config.num_labels = num_labels
#config.max_length = max_label_token_length
config.text_config.max_length = max_label_token_length # for CLS and SEP tokens


train_dataset = ImageDataset(config, image_files=dataset["train"], text="", processor=processor, num_labels=num_labels)
test_dataset = ImageDataset(config, image_files=dataset["test"], text="", processor=processor, num_labels=num_labels)

apple tensor([[ 101, 6207,  102]])
aquarium_fish tensor([[  101, 18257,  1035,  3869,   102]])
baby tensor([[ 101, 3336,  102]])
bear tensor([[ 101, 4562,  102]])
beaver tensor([[  101, 13570,   102]])
bed tensor([[ 101, 2793,  102]])
bee tensor([[  101, 10506,   102]])
beetle tensor([[ 101, 7813,  102]])
bicycle tensor([[  101, 10165,   102]])
bottle tensor([[ 101, 5835,  102]])
bowl tensor([[ 101, 4605,  102]])
boy tensor([[ 101, 2879,  102]])
bridge tensor([[ 101, 2958,  102]])
bus tensor([[ 101, 3902,  102]])
butterfly tensor([[ 101, 9112,  102]])
camel tensor([[  101, 19130,   102]])
can tensor([[ 101, 2064,  102]])
castle tensor([[ 101, 3317,  102]])
caterpillar tensor([[  101, 23488,  8197, 17305,   102]])
cattle tensor([[ 101, 7125,  102]])
chair tensor([[ 101, 3242,  102]])
chimpanzee tensor([[  101,  9610,  8737,  2319, 23940,   102]])
clock tensor([[ 101, 5119,  102]])
cloud tensor([[ 101, 6112,  102]])
cockroach tensor([[  101, 10338,  3217,  6776,   102]])
couch tensor([[ 

In [194]:
total_labels = ['beaver', 'dolphin', 'otter', 'seal', 'whale', 'aquarium fish', 'flatfish', 'ray', 'shark', 'trout', 'orchids', 'poppies', 'roses', 'sunflowers', 'tulips', 'bottles', 'bowls', 'cans', 'cups', 'plates', 'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers', 'clock', 'computer keyboard', 'lamp', 'telephone', 'television', 'bed', 'chair', 'couch', 'table', 'wardrobe', 'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach', 'bear', 'leopard', 'lion', 'tiger', 'wolf', 'bridge', 'castle', 'house', 'road', 'skyscraper', 'cloud', 'forest', 'mountain', 'plain', 'sea', 'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo', 'fox', 'porcupine', 'possum', 'raccoon', 'skunk', 'crab', 'lobster', 'snail', 'spider', 'worm', 'baby', 'boy', 'girl', 'man', 'woman', 'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle', 'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel', 'maple', 'oak', 'palm', 'pine', 'willow', 'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train', 'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor']
real_labels = []
for label in dataset["train"].features["fine_label"].names:
    real_labels.append(label)

In [195]:
embeddings_dict = {}
with open("/content/glove.6B.txt", 'r',  encoding='utf-8') as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = np.asarray(values[1:], "float32")
        embeddings_dict[word] = vector

In [196]:
label_embeddings = []
for i in real_labels:
  if '_' not in i:
    label_embeddings.append(embeddings_dict[i])
  else:
    B = np.zeros((100,))
    for j in i.split('_'):
      B += embeddings_dict[j]
    label_embeddings.append(B)
label_embeddings = np.array(label_embeddings)

In [197]:
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", config=config)
model = model.to(device)

In [198]:
def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  pixel_values = [item['pixel_values'] for item in batch]
  attention_mask = [item['attention_mask'] for item in batch]
  # token_type_ids = [item['token_type_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  # create padded pixel values and corresponding pixel mask
  # encoding = processor.feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")

  # create new batch
  batch = {}
  batch['input_ids'] = torch.stack(input_ids)
  batch['attention_mask'] = torch.stack(attention_mask)
  # batch['token_type_ids'] = torch.stack(token_type_ids)
  # batch['pixel_values'] = encoding['pixel_values']
  # batch['pixel_mask'] = encoding['pixel_mask']
  batch['pixel_values'] = torch.stack(pixel_values)
  batch['labels'] = torch.stack(labels)
  # batch['labels'] = torch.Tensor(labels).type(torch.LongTensor).unsqueeze(1)

  return batch

In [199]:
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=32, shuffle=True)
val_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=32, shuffle=True)

In [200]:
for name, param in model.named_parameters():
    if 'text_decoder.cls' in name:
        print(name)

text_decoder.cls.predictions.bias
text_decoder.cls.predictions.transform.dense.weight
text_decoder.cls.predictions.transform.dense.bias
text_decoder.cls.predictions.transform.LayerNorm.weight
text_decoder.cls.predictions.transform.LayerNorm.bias
text_decoder.cls.predictions.decoder.weight


In [201]:
for name, param in model.named_parameters():
    if 'text_decoder.cls' in name:
        print(name)
        param.requires_grad = True
    else:
        param.requires_grad = False

text_decoder.cls.predictions.bias
text_decoder.cls.predictions.transform.dense.weight
text_decoder.cls.predictions.transform.dense.bias
text_decoder.cls.predictions.transform.LayerNorm.weight
text_decoder.cls.predictions.transform.LayerNorm.bias
text_decoder.cls.predictions.decoder.weight


In [202]:
model = model.to(device)

In [203]:
optimizer = optim.Adam(model.parameters(), lr=6e-4)
num_epochs = 30

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.25)

best_params = None
best_val_accuracy = -1

for epoch in range(1):
    
    model.train()
    total_loss = 0.0

    step = 0
    for batch in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        batch = {k:v.to(device) for k,v in batch.items()}

        optimizer.zero_grad()
        
        # import ipdb; ipdb.set_trace()
        outputs = model(**batch) 
        loss = outputs.loss
        loss.backward()
        optimizer.step()


        total_loss += loss
        step += 1
            
    scheduler.step()

    

Epoch 1/30: 100%|██████████| 1563/1563 [1:24:18<00:00,  3.24s/it]


In [1]:
model.eval()

val_predictions = []
val_labels_eval = []
total_acc = 0.0
with torch.no_grad():

    for batch in val_dataloader:
        batch = {k:v.to(device) for k,v in batch.items()}
      
        outputs = model.generate(**batch)
        outputs = outputs[:,0:max_label_token_length]
        #print('outputs:',processor.decode(outputs[0], skip_special_tokens=True), outputs[0])
        labels = batch['labels'][:,0:max_label_token_length]
        #print('labels:',processor.decode(labels[0], skip_special_tokens=True), labels[0])
        outputs = processor.batch_decode(outputs, skip_special_tokens=True)
        labels = processor.batch_decode(labels, skip_special_tokens=True)

        # count how many are correct
        correct = 0
        for i in range(len(labels)):
            output = outputs[i]
            label = labels[i]
            output = output.replace(" ", "")
            label = label.replace(" ", "")

            if('_' not in output):
              if output in embeddings_dict.keys():
                A = embeddings_dict[output]
            else:
              A = np.zeros((100,))
              for j in output.split('_'):
                if j in embeddings_dict.keys():
                  A += embeddings_dict[j]
            
            v = dot(label_embeddings, A)/(norm(A)*norm(label_embeddings, axis = 1))
            if(real_labels[np.argmax(v)] == label):
                print(output[:len(label)], label)
                correct += 1
        acc = (correct * 100) / len(labels)
        print(acc)

        total_acc += acc
    print('total acc:',total_acc / len(val_dataloader))

NameError: ignored