In [1]:
from PIL import Image
import torch

# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fix torch's global RNG seed so DataLoader shuffling and the freshly
# initialised classification head are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [2]:
class ImageDataset(torch.utils.data.Dataset):
    """Pair every image of a classification split with one fixed text prompt.

    Each item is the processor's encoding of ``(image, text)`` with the batch
    dimension removed, plus a one-hot float ``"labels"`` vector.

    Args:
        image_files: indexable split whose items are dicts with an ``'img'``
            image (must expose ``.mode`` / ``.convert``) and an integer ``'label'``.
        text: prompt string encoded together with every image.
        processor: callable ``processor(image, text, **kwargs)`` returning a
            dict-like of tensors with a leading batch dimension of size 1
            (e.g. a ``ViltProcessor`` called with ``return_tensors="pt"``).
        num_labels: number of classes, i.e. the length of the one-hot target.
    """

    def __init__(self, image_files, text, processor, num_labels):
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Index the split once: for lazily-decoded datasets every lookup decodes the image.
        example = self.image_files[idx]
        image = example['img']
        label = example['label']
        if image.mode != "RGB":
            image = image.convert("RGB")

        encoding = self.processor(image, self.text, padding="max_length", truncation=True, return_tensors="pt")

        # Drop ONLY the leading batch dimension added by return_tensors="pt".
        # A bare .squeeze() would also collapse any other size-1 axis (e.g. a
        # length-1 token sequence), giving items inconsistent shapes.
        for key, value in encoding.items():
            encoding[key] = value.squeeze(0)

        # One-hot target; downstream code recovers the class with argmax(dim=1).
        targets = torch.zeros(self.num_labels)
        targets[label] = 1
        encoding["labels"] = targets

        return encoding

In [3]:
!pip install datasets
from datasets import load_dataset

cache_dir='./cache'
#datasets = load_dataset('Maysee/tiny-imagenet', cache_dir=cache_dir)
datasets = load_dataset('cifar100', cache_dir=cache_dir)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiohttp
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m32.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess
  Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.9/132.9 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash
  Downloading xxhash-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.2/212.2 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting hugg

Downloading builder script:   0%|          | 0.00/5.61k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/4.21k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.83k [00:00<?, ?B/s]

Downloading and preparing dataset cifar100/cifar100 to /content/cache/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142...


Downloading data:   0%|          | 0.00/169M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]



Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset cifar100 downloaded and prepared to /content/cache/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m76.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m102.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, transformers
Successfully installed tokenizers-0.13.3 transformers-4.28.1


In [10]:
# Expose the fine-grained label column under the generic name "label".
# Guarded so re-running the cell does not fail once the column has already
# been renamed.
if "fine_label" in datasets["train"].column_names:
    datasets = datasets.rename_column("fine_label", "label")

In [11]:
# Class names taken from the "label" feature of the train split;
# position i holds the name of label id i.
train_features = datasets["train"].features
label_list = train_features["label"].names
print(label_list)

['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'cra', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 

In [26]:
# Text prompt paired with every image: a fixed instruction followed by every
# class name, quoted and comma-separated, ending with "?".
# NOTE(review): the ViLT tokenizer truncates to a short maximum length
# (40 tokens for the dandelin/vilt checkpoints, to be confirmed), so most of
# this class list is cut off before it reaches the model.
prompt = (
    "This is a classification problem from the cifar100 dataset. "
    "Classify the images amongst the classes"
)
if label_list:
    prompt += " " + ", ".join(f"'{label}'" for label in label_list)
prompt += '?'

In [34]:
from transformers import ViltProcessor, BertTokenizer

from types import SimpleNamespace

label_list = datasets["train"].features["label"].names
num_labels = len(label_list)

# Label metadata (id <-> name mappings and class count). A plain namespace is
# enough: nothing in this notebook uses `config` as a tokenizer. All text
# encoding goes through `processor` below, so there is no need to download a
# separate BertTokenizer. (Its max_length=len(prompt) was also a character
# count, not a token count.)
config = SimpleNamespace(
    id2label={str(i): label for i, label in enumerate(label_list)},
    label2id={label: str(i) for i, label in enumerate(label_list)},
    num_labels=num_labels,
)

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

train_dataset = ImageDataset(image_files=datasets["train"], text=prompt, processor=processor, num_labels=num_labels)
test_dataset = ImageDataset(image_files=datasets["test"], text=prompt, processor=processor, num_labels=num_labels)

In [35]:
from torch.utils.data import DataLoader
import torchvision.transforms as T
# Resize padded images to 224x224, the resolution fed to the FLAVA model below.
transform = T.Resize((224,224))
# Nearest-neighbour resize for the 0/1 pixel mask, so it stays binary and keeps
# the same spatial size as the resized pixel values.
mask_transform = T.Resize((224, 224), interpolation=T.InterpolationMode.NEAREST)


def collate_fn(batch):
    """Collate ImageDataset items (dicts of unbatched tensors) into one batch dict.

    Text fields and one-hot labels are stacked along a new batch axis. Images
    are padded to a common size by the ViLT image processor, which also builds
    the pixel mask. Both are then resized to 224x224.
    """
    pixel_values = [item['pixel_values'] for item in batch]
    # Pad images to a common size and create the matching pixel mask.
    encoding = processor.feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")

    return {
        'input_ids': torch.stack([item['input_ids'] for item in batch]),
        'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
        'token_type_ids': torch.stack([item['token_type_ids'] for item in batch]),
        'pixel_values': transform(encoding['pixel_values']),
        # Resized together with pixel_values so code passing the whole batch
        # (e.g. model(**batch)) sees consistent spatial shapes.
        'pixel_mask': mask_transform(encoding['pixel_mask']),
        'labels': torch.stack([item['labels'] for item in batch]),
    }


train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=64, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps evaluation runs comparable.
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=64, shuffle=False)

In [36]:
!pip install torchmultimodal-nightly

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [37]:
from torchmultimodal.models.flava.model import flava_model_for_classification
# FLAVA with a classification head sized to the dataset's label set
# (num_labels is derived from the dataset rather than hard-coded).
model_1 = flava_model_for_classification(num_classes=num_labels)

In [38]:
from tqdm.notebook import tqdm

@torch.no_grad()
def evaluate(model, device, test_dataloader, forward=None):
    """Compute the average loss and accuracy of ``model`` over ``test_dataloader``.

    The model is put in eval mode for the pass (disabling dropout). Its
    previous train/eval mode is restored afterwards, even if an error occurs.

    Args:
        model: ``torch.nn.Module`` whose outputs expose ``.logits`` with classes
            on dim 1 and a scalar ``.loss``.
        device: device every tensor in a batch is moved to.
        test_dataloader: iterable of dicts of tensors that include a one-hot
            ``'labels'`` entry.
        forward: optional ``forward(model, batch) -> outputs``. Defaults to
            ``model(**batch)``. Pass e.g.
            ``lambda m, b: m(text=b['input_ids'], image=b['pixel_values'], labels=b['labels'])``
            for models with a different call signature.

    Returns:
        ``(average loss per example, accuracy in percent)`` as Python floats.

    Raises:
        ValueError: if the dataloader yields no examples.
    """
    if forward is None:
        def forward(m, batch):
            return m(**batch)

    was_training = model.training
    model.eval()
    loss_sum = 0.0  # running sum of (batch loss * batch size)
    correct = 0
    total = 0
    try:
        for batch in tqdm(test_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = forward(model, batch)
            preds = torch.argmax(outputs.logits, dim=1)
            target = torch.argmax(batch['labels'], dim=1)  # labels are one-hot
            batch_size = target.size(0)
            correct += torch.sum(preds == target).item()
            total += batch_size
            # Weight by batch size so a smaller final batch does not skew the
            # mean. Assumes outputs.loss is a per-batch mean, to be confirmed for the model used.
            loss_sum += outputs.loss.item() * batch_size
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("test_dataloader yielded no examples")

    total_avg_loss = loss_sum / total
    total_avg_acc = (100 * correct) / total
    print(f"Correct: {correct}/Total: {total}")
    print(f"Average val loss: {total_avg_loss}")
    print(f"Average val acc: {total_avg_acc}")

    return total_avg_loss, total_avg_acc

In [65]:
from torch import nn
BATCH_SIZE = 64  # NOTE(review): not referenced below; the DataLoaders hard-code batch_size=64
MAX_STEPS = 3  # NOTE(review): not referenced below; the loop runs full epochs
from torch.utils.data import DataLoader

model_1.to(device)

# Linear probe: freeze every parameter except those of the classification head.
for name, param in model_1.named_parameters():
    param.requires_grad = 'classifier' in name

# Build the optimizer after freezing so it only holds the trainable head
# parameters. AdamW raises on an empty list, so a name filter that matched
# nothing is caught immediately.
optimizer = torch.optim.AdamW([p for p in model_1.parameters() if p.requires_grad])

model_1.train()
epochs = 1
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        out = model_1(text=batch["input_ids"].to(device), image=batch["pixel_values"].to(device), labels=labels)
        loss = out.loss
        loss.backward()
        optimizer.step()

        # Accuracy on this batch; labels are one-hot, so argmax recovers the class id.
        preds = torch.argmax(out['logits'], dim=1)
        target = torch.argmax(labels, dim=1)
        correct = torch.sum(preds == target).item()
        acc = (correct * 100) / target.size(0)
        print(f"Loss at step {idx} = {loss}  Accuracy = {acc}")


Loss at step 0 = 4.55089807510376  Accuracy = 3.125
Loss at step 1 = 4.572542190551758  Accuracy = 3.125
Loss at step 2 = 4.5596699714660645  Accuracy = 1.5625
Loss at step 3 = 4.560544013977051  Accuracy = 4.6875
Loss at step 4 = 4.599034309387207  Accuracy = 6.25
Loss at step 5 = 4.564797401428223  Accuracy = 1.5625
Loss at step 6 = 4.553944110870361  Accuracy = 6.25
Loss at step 7 = 4.559261322021484  Accuracy = 1.5625
Loss at step 8 = 4.532289505004883  Accuracy = 6.25
Loss at step 9 = 4.540643215179443  Accuracy = 3.125
Loss at step 10 = 4.477451801300049  Accuracy = 7.8125
Loss at step 11 = 4.543468475341797  Accuracy = 3.125
Loss at step 12 = 4.561739921569824  Accuracy = 0.0
Loss at step 13 = 4.517580986022949  Accuracy = 7.8125
Loss at step 14 = 4.528321266174316  Accuracy = 1.5625
Loss at step 15 = 4.519814491271973  Accuracy = 4.6875
Loss at step 16 = 4.488577842712402  Accuracy = 6.25
Loss at step 17 = 4.510899066925049  Accuracy = 6.25
Loss at step 18 = 4.500510215759277  

KeyboardInterrupt: ignored

In [62]:
# Logit of class 50 for the first example of the last forward pass.
out['logits'][0, 50]

tensor(-0.0786, device='cuda:0', grad_fn=<SelectBackward0>)

In [None]:
# Inspect the raw output object of the last forward pass of model_1.
out

In [52]:
# Class indices (argmax of the one-hot labels) of the last processed batch.
target

tensor([50, 27, 81, 97,  5, 85,  6, 86, 75, 21, 96, 90,  9, 78, 80, 91, 27, 30,
        53, 83,  5, 82,  3, 45, 36, 94, 25, 95, 37, 29, 73, 88, 89, 74, 73, 95,
         9,  7, 42, 83, 80, 22, 33, 11, 95,  2, 67, 99, 83, 10, 14, 99, 69, 42,
         2, 29, 61, 36, 93, 38, 93, 98, 99, 31], device='cuda:0')

In [None]:
 # Number of examples in the last processed batch.
 target.size(0)

64

In [None]:
# Predicted class per example for the last forward pass.
# NOTE(review): the saved output below has 16 entries and does not match the
# 64-example `target` batch shown earlier; it appears to come from a different run.
out['logits'].argmax(dim=1)

tensor([4, 0, 4, 6, 4, 8, 5, 1, 5, 3, 0, 3, 5, 7, 5, 2], device='cuda:0')

In [None]:
# Number of batches in the test DataLoader (the last batch may be partial).
len(test_dataloader)

157

In [None]:
# Accuracy of model_1 over the whole test split.
# Runs without autograd (no graph is built for the trainable head) and in eval
# mode; the model's previous train/eval mode is restored afterwards.
correct = 0
total = 0
was_training = model_1.training
model_1.eval()
try:
    with torch.no_grad():
        for idx, batch in enumerate(test_dataloader):
            labels = batch["labels"].to(device)
            out = model_1(text=batch["input_ids"].to(device), image=batch["pixel_values"].to(device), labels=labels)
            loss = out.loss
            preds = torch.argmax(out['logits'], dim=1)
            target = torch.argmax(labels, dim=1)  # labels are one-hot
            correct += torch.sum(preds == target).item()
            total += target.size(0)
finally:
    model_1.train(was_training)

In [None]:
# Test-set accuracy as a fraction in [0, 1] (not a percentage, unlike `acc`).
# NOTE(review): the saved output likely comes from a different run; the saved
# `target` batch below contains only ids 0-9, suggesting a 10-class setup.
correct/total

0.9273

In [None]:
# Class indices of the last batch seen by the most recently executed loop.
target

tensor([7, 1, 0, 8, 4, 5, 9, 8, 6, 0, 3, 4, 0, 6, 3, 1, 8, 1, 9, 1, 9, 4, 7, 8,
        0, 5, 7, 4, 4, 7, 1, 7, 4, 9, 0, 6, 4, 0, 5, 3, 0, 5, 7, 1, 3, 1, 6, 4,
        8, 1, 8, 5, 5, 5, 4, 6, 1, 1, 9, 6, 1, 9, 7, 5], device='cuda:0')

In [None]:
# Accuracy (percent) computed on the last training batch.
acc

12.5

In [None]:
# test_dataloader = DataLoader(datasets['test'], batch_size= 1, shuffle = True)
# model.eval()
# dev_predictions = []
# dev_true_labels = []
# total_count = 0
# correct_count = 0
# with torch.no_grad():
#     for batch in test_dataloader:
#       total_count +=1
#       out = model(text = batch["input_ids"], image = batch["img"], labels = batch["label"])
#       pred = torch.argmax(out.logits)
#       dev_predictions.append(pred)
#       dev_true_labels.append(batch['label'])
#       if(pred == batch['label']):
#         correct_count +=1
#       if(total_count == 20):
#         break
#       if(total_count % 100 == 0):
#         print('Accuracy = ' + str(correct_count / total_count))
