In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import tqdm
from sklearn.metrics import accuracy_score
import numpy as np

from datasets import load_dataset, Image

from transformers import BlipForQuestionAnswering, BlipProcessor, BlipConfig, BlipModel

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, config, image_files, text, processor, num_labels):
        self.config = config
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        text = self.text
        image = self.image_files[idx]['img']
        label_id = self.image_files[idx]['fine_label']
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        # encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        encoding = self.processor(image, text, return_tensors="pt")
        label_encoding = self.processor(text=self.config.id2label[str(label_id)], padding="max_length", return_tensors="pt").input_ids
        label_encoding = label_encoding.squeeze()

        # remove batch dimension
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        #targets = torch.zeros(self.num_labels)
        #targets[label_id] = 1
        encoding["labels"] = label_encoding

        return encoding

In [4]:
dataset = load_dataset('cifar100', cache_dir='./cache')

Found cached dataset cifar100 (/home/dongwon/adapt-VL-models-to-vision-only-tasks/Nish/BLIP Experiments/./cache/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142)


  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
label_list = dataset["train"].features["fine_label"].names

# replace label 'cra' to 'crab'
label_list.remove('cra')
label_list.append('crab')
num_labels = len(label_list)

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

max_label_token_length = 0
for label in label_list:
    labels = processor(text=label, return_tensors="pt").input_ids
    max_label_token_length = max(max_label_token_length, len(labels[0]))
    print(label, labels)
print(max_label_token_length)

config = BlipConfig.from_pretrained("Salesforce/blip-vqa-base")
config.id2label = {str(i): label for i, label in enumerate(label_list)}
config.label2id = {label: str(i) for i, label in enumerate(label_list)}
config.num_labels = num_labels
#config.max_length = max_label_token_length
config.text_config.max_length = max_label_token_length # for CLS and SEP tokens


train_dataset = ImageDataset(config, image_files=dataset["train"], text="", processor=processor, num_labels=num_labels)
test_dataset = ImageDataset(config, image_files=dataset["test"], text="", processor=processor, num_labels=num_labels)

2023-04-25 17:25:44.081324: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-25 17:25:44.770062: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib:/lib:/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/

apple tensor([[ 101, 6207,  102]])
aquarium_fish tensor([[  101, 18257,  1035,  3869,   102]])
baby tensor([[ 101, 3336,  102]])
bear tensor([[ 101, 4562,  102]])
beaver tensor([[  101, 13570,   102]])
bed tensor([[ 101, 2793,  102]])
bee tensor([[  101, 10506,   102]])
beetle tensor([[ 101, 7813,  102]])
bicycle tensor([[  101, 10165,   102]])
bottle tensor([[ 101, 5835,  102]])
bowl tensor([[ 101, 4605,  102]])
boy tensor([[ 101, 2879,  102]])
bridge tensor([[ 101, 2958,  102]])
bus tensor([[ 101, 3902,  102]])
butterfly tensor([[ 101, 9112,  102]])
camel tensor([[  101, 19130,   102]])
can tensor([[ 101, 2064,  102]])
castle tensor([[ 101, 3317,  102]])
caterpillar tensor([[  101, 23488,  8197, 17305,   102]])
cattle tensor([[ 101, 7125,  102]])
chair tensor([[ 101, 3242,  102]])
chimpanzee tensor([[  101,  9610,  8737,  2319, 23940,   102]])
clock tensor([[ 101, 5119,  102]])
cloud tensor([[ 101, 6112,  102]])
cockroach tensor([[  101, 10338,  3217,  6776,   102]])
couch tensor([[ 

In [6]:
total_labels = ['beaver', 'dolphin', 'otter', 'seal', 'whale', 'aquarium fish', 'flatfish', 'ray', 'shark', 'trout', 'orchids', 'poppies', 'roses', 'sunflowers', 'tulips', 'bottles', 'bowls', 'cans', 'cups', 'plates', 'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers', 'clock', 'computer keyboard', 'lamp', 'telephone', 'television', 'bed', 'chair', 'couch', 'table', 'wardrobe', 'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach', 'bear', 'leopard', 'lion', 'tiger', 'wolf', 'bridge', 'castle', 'house', 'road', 'skyscraper', 'cloud', 'forest', 'mountain', 'plain', 'sea', 'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo', 'fox', 'porcupine', 'possum', 'raccoon', 'skunk', 'crab', 'lobster', 'snail', 'spider', 'worm', 'baby', 'boy', 'girl', 'man', 'woman', 'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle', 'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel', 'maple', 'oak', 'palm', 'pine', 'willow', 'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train', 'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor']
real_labels = set()
for label in dataset["train"].features["fine_label"].names:
    real_labels.add(label)


In [7]:
print(real_labels)

{'tulip', 'willow_tree', 'man', 'tiger', 'trout', 'road', 'shrew', 'beaver', 'skyscraper', 'otter', 'chimpanzee', 'possum', 'bridge', 'apple', 'baby', 'mountain', 'bottle', 'turtle', 'porcupine', 'lamp', 'sea', 'clock', 'bee', 'poppy', 'maple_tree', 'plate', 'leopard', 'squirrel', 'bed', 'fox', 'bus', 'couch', 'oak_tree', 'kangaroo', 'flatfish', 'train', 'cattle', 'elephant', 'tractor', 'crocodile', 'tank', 'beetle', 'aquarium_fish', 'house', 'skunk', 'lawn_mower', 'motorcycle', 'hamster', 'mushroom', 'television', 'plain', 'camel', 'pine_tree', 'raccoon', 'cloud', 'lion', 'seal', 'dolphin', 'orchid', 'cup', 'chair', 'sweet_pepper', 'orange', 'snake', 'woman', 'bowl', 'shark', 'boy', 'sunflower', 'bicycle', 'cockroach', 'palm_tree', 'lobster', 'worm', 'table', 'streetcar', 'pickup_truck', 'spider', 'bear', 'whale', 'telephone', 'rocket', 'castle', 'can', 'rose', 'lizard', 'ray', 'caterpillar', 'keyboard', 'forest', 'dinosaur', 'pear', 'rabbit', 'wardrobe', 'butterfly', 'mouse', 'snail'

In [8]:
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", config=config)
model = model.to(device)

In [9]:
def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  pixel_values = [item['pixel_values'] for item in batch]
  attention_mask = [item['attention_mask'] for item in batch]
  # token_type_ids = [item['token_type_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  # create padded pixel values and corresponding pixel mask
  # encoding = processor.feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")

  # create new batch
  batch = {}
  batch['input_ids'] = torch.stack(input_ids)
  batch['attention_mask'] = torch.stack(attention_mask)
  # batch['token_type_ids'] = torch.stack(token_type_ids)
  # batch['pixel_values'] = encoding['pixel_values']
  # batch['pixel_mask'] = encoding['pixel_mask']
  batch['pixel_values'] = torch.stack(pixel_values)
  batch['labels'] = torch.stack(labels)
  # batch['labels'] = torch.Tensor(labels).type(torch.LongTensor).unsqueeze(1)

  return batch

In [10]:
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=32, shuffle=True)
val_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=32, shuffle=True)

In [11]:
for name, param in model.named_parameters():
    if 'text_decoder.cls' in name:
        print(name)

text_decoder.cls.predictions.bias
text_decoder.cls.predictions.transform.dense.weight
text_decoder.cls.predictions.transform.dense.bias
text_decoder.cls.predictions.transform.LayerNorm.weight
text_decoder.cls.predictions.transform.LayerNorm.bias
text_decoder.cls.predictions.decoder.weight


In [12]:
for name, param in model.named_parameters():
    if 'text_decoder.cls' in name:
        print(name)
        param.requires_grad = True
    else:
        param.requires_grad = False

text_decoder.cls.predictions.bias
text_decoder.cls.predictions.transform.dense.weight
text_decoder.cls.predictions.transform.dense.bias
text_decoder.cls.predictions.transform.LayerNorm.weight
text_decoder.cls.predictions.transform.LayerNorm.bias
text_decoder.cls.predictions.decoder.weight


In [13]:
model = model.to(device)

In [25]:
#criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=6e-4)
num_epochs = 30

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.25)

best_params = None
best_val_accuracy = -1

for epoch in range(1):
    
    model.train()
    total_loss = 0.0

    step = 0
    """
    for batch in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        batch = {k:v.to(device) for k,v in batch.items()}

        optimizer.zero_grad()
        
        # import ipdb; ipdb.set_trace()
        outputs = model(**batch) 
        loss = outputs.loss
        loss.backward()
        optimizer.step()


        total_loss += loss
        step += 1
            
    scheduler.step()
    """
    
  

    # Evaluate the model on the validation set
    model.eval()
    val_predictions = []
    val_labels_eval = []
    total_acc = 0.0
    with torch.no_grad():

        for batch in val_dataloader:
            batch = {k:v.to(device) for k,v in batch.items()}
            
            outputs = model.generate(**batch)
            outputs = outputs[:,0:max_label_token_length]
            #print('outputs:',processor.decode(outputs[0], skip_special_tokens=True), outputs[0])
            labels = batch['labels'][:,0:max_label_token_length]
            #print('labels:',processor.decode(labels[0], skip_special_tokens=True), labels[0])
            outputs = processor.batch_decode(outputs, skip_special_tokens=True)
            labels = processor.batch_decode(labels, skip_special_tokens=True)

            # count how many are correct
            correct = 0
            for i in range(len(labels)):
                output = outputs[i]
                label = labels[i]
                output = output.replace(" ", "")
                label = label.replace(" ", "")
                output = output.replace("_", "")
                label = label.replace("_", "")
                if output[:len(label)] == label:
                    print(output[:len(label)], label)
                    correct += 1
            acc = (correct * 100) / len(labels)
            print(acc)

            total_acc += acc
        print('total acc:',total_acc / len(val_dataloader))
            
            #val_predictions.extend(preds.cpu().numpy())
            #val_labels_eval.extend(labels.cpu().numpy())



        #val_labels_idx = [np.argmax(tensor) for tensor in val_labels_eval]
        #val_accuracy = accuracy_score(val_labels_idx, val_predictions)
        
        
        #print(f"Epoch {epoch + 1}/{num_epochs},  Validation Accuracy: {val_accuracy:.4f}")
        # print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

   

keyboard keyboard
spider spider
poppy poppy
chair chair
camel camel
telephone telephone
dinosaur dinosaur
crab crab
bear bear
dinosaur dinosaur
chair chair
woman woman
poppy poppy
telephone telephone
43.75
mapletree mapletree
television television
bowl bowl
pickuptruck pickuptruck
seal seal
cloud cloud
plate plate
pear pear
skyscraper skyscraper
plate plate
turtle turtle
pickuptruck pickuptruck
bus bus
cattle cattle
sweetpepper sweetpepper
orchid orchid
rabbit rabbit
apple apple
table table
59.375
can can
lobster lobster
can can
seal seal
tank tank
clock clock
skyscraper skyscraper
mouse mouse
snail snail
kangaroo kangaroo
oaktree oaktree
pinetree pinetree
trout trout
cloud cloud
43.75
leopard leopard
baby baby
apple apple
clock clock
whale whale
mapletree mapletree
bowl bowl
woman woman
rabbit rabbit
wardrobe wardrobe
31.25
castle castle
turtle turtle
pear pear
train train
sea sea
pickuptruck pickuptruck
table table
crab crab
sweetpepper sweetpepper
tank tank
clock clock
bicycle bicyc

KeyboardInterrupt: 

: 