In [1]:
from PIL import Image
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
device

device(type='cuda', index=0)

In [3]:
class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, image_files, text, processor, num_labels):
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        text = self.text
        image = self.image_files[idx]['img']
        label = self.image_files[idx]['label']
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")

        # remove batch dimension
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        targets = torch.zeros(self.num_labels)
        targets[label] = 1
        encoding["labels"] = targets

        return encoding

In [4]:
!pip install datasets
from datasets import load_dataset

cache_dir='./cache'
#datasets = load_dataset('Maysee/tiny-imagenet', cache_dir=cache_dir)
datasets = load_dataset('cifar10', cache_dir=cache_dir)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 KB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess
  Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.9/132.9 KB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dill<0.3.7,>=0.3.0
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 KB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting xxhash
  Downloading xxhash-3.2.0-cp39-cp39-ma

Downloading builder script:   0%|          | 0.00/3.61k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.66k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/5.00k [00:00<?, ?B/s]

Downloading and preparing dataset cifar10/plain_text to /content/cache/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4...


Downloading data:   0%|          | 0.00/170M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]



Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset cifar10 downloaded and prepared to /content/cache/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m79.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m75.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, transformers
Successfully installed tokenizers-0.13.2 transformers-4.27.4


In [6]:
from transformers import ViltProcessor, BertTokenizer

label_list = datasets["train"].features["label"].names
num_labels = len(label_list)

config = BertTokenizer.from_pretrained("bert-base-uncased",padding="max_length",max_length=4)
config.id2label = {str(i): label for i, label in enumerate(label_list)}
config.label2id = {label: str(i) for i, label in enumerate(label_list)}
config.num_labels = num_labels

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

train_dataset = ImageDataset(image_files=datasets["train"], text="What is this image?", processor=processor, num_labels=num_labels)
test_dataset = ImageDataset(image_files=datasets["test"], text="What is this image?", processor=processor, num_labels=num_labels)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/251 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/320 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [7]:
from torch.utils.data import DataLoader
import torchvision.transforms as T
transform = T.Resize((224,224))
def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  pixel_values = [item['pixel_values'] for item in batch]
  attention_mask = [item['attention_mask'] for item in batch]
  token_type_ids = [item['token_type_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  # create padded pixel values and corresponding pixel mask
  encoding = processor.feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")

  # create new batch
  batch = {}
  batch['input_ids'] = torch.stack(input_ids)
  batch['attention_mask'] = torch.stack(attention_mask)
  batch['token_type_ids'] = torch.stack(token_type_ids)
  batch['pixel_values'] = transform(encoding['pixel_values'])
  batch['pixel_mask'] = encoding['pixel_mask']
  batch['labels'] = torch.stack(labels)

  return batch

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=64, shuffle=True)

In [None]:
type(datasets)

datasets.dataset_dict.DatasetDict

In [8]:
!pip install torchmultimodal-nightly

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmultimodal-nightly
  Downloading torchmultimodal_nightly-2023.3.29-py39-none-any.whl (128 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m128.6/128.6 KB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting iopath
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 KB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting DALL-E==0.1
  Downloading DALL_E-0.1-py3-none-any.whl (6.0 kB)
Collecting mypy
  Downloading mypy-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.1/12.1 MB[0m [31m52.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting blobfile
  Downloading blobfile-2.0.1-py3-none-any.whl (73 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [9]:
from torchmultimodal.models.flava.model import flava_model_for_classification
# vocab = datasets["train"].features["label"].names
model_1 = flava_model_for_classification(num_classes=10)

flava_for_pretraining_unified_text_encoder.pt: 1.43GB [00:35, 40.7MB/s]                            


In [10]:
model_1

FLAVAForClassification(
  (model): FLAVAModel(
    (image_encoder): ImageTransformer(
      (embeddings): ImageEmbeddings(
        (patch_embeddings): PatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): TransformerEncoder(
        (layer): ModuleList(
          (0): TransformerEncoderLayer(
            (attention): MultiHeadAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (output): Linear(in_features=768, out_features=768, bias=True)
              (attn): SelfAttention()
            )
            (attention_dropout): Dropout(p=0.0, inplace=False)
            (feedforward): MLP(
              (model): Sequential(
                (0): Linear(in_features=768, out_fe

In [11]:
from torch import nn
from torch.utils.data import DataLoader

model_1.to(device)
optimizer = torch.optim.AdamW(model_1.parameters())

for param in model_1.parameters():
    param.requires_grad = False

for name, param in model_1.named_parameters():
    if 'classifier' in name:
        param.requires_grad = True

epochs = 1
a = 0
b = 0
for _ in range(epochs):
  for idx, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    out = model_1(text = batch["input_ids"].to(device), image = batch["pixel_values"].to(device), labels = batch["labels"].to(device))
    loss = out.loss
    preds = torch.argmax(out['logits'], dim=1)
    target = torch.argmax(batch['labels'].to(device), dim=1)
    correct = torch.sum(preds==target).item()
    acc = (correct * 100) / target.size(0)
    loss.backward()
    optimizer.step()
    print(f"Loss at step {idx} = {loss}  Accuracy = {acc}")




Loss at step 0 = 2.3039920330047607  Accuracy = 7.8125
Loss at step 1 = 2.2975077629089355  Accuracy = 18.75
Loss at step 2 = 2.290396213531494  Accuracy = 26.5625
Loss at step 3 = 2.272592544555664  Accuracy = 23.4375
Loss at step 4 = 2.276644229888916  Accuracy = 23.4375
Loss at step 5 = 2.258100748062134  Accuracy = 25.0
Loss at step 6 = 2.286461114883423  Accuracy = 15.625
Loss at step 7 = 2.27683162689209  Accuracy = 10.9375
Loss at step 8 = 2.249321937561035  Accuracy = 17.1875
Loss at step 9 = 2.2512266635894775  Accuracy = 9.375
Loss at step 10 = 2.207352876663208  Accuracy = 29.6875
Loss at step 11 = 2.2554972171783447  Accuracy = 18.75
Loss at step 12 = 2.2115511894226074  Accuracy = 25.0
Loss at step 13 = 2.2036685943603516  Accuracy = 28.125
Loss at step 14 = 2.2005858421325684  Accuracy = 25.0
Loss at step 15 = 2.199254035949707  Accuracy = 29.6875
Loss at step 16 = 2.1873340606689453  Accuracy = 32.8125
Loss at step 17 = 2.179293632507324  Accuracy = 37.5
Loss at step 18 

In [12]:
correct_test = 0
total_test = 0
for idx, batch in enumerate(test_dataloader):
    out = model_1(text = batch["input_ids"].to(device), image = batch["pixel_values"].to(device), labels = batch["labels"].to(device))
    loss = out.loss
    preds = torch.argmax(out['logits'], dim=1)
    target = torch.argmax(batch['labels'].to(device), dim=1)
    batch_acc = torch.sum(preds==target).item() * 100 / target.size(0)
    print(f"Test Accuracy for batch = {batch_acc}")
    correct_test += torch.sum(preds==target).item()
    total_test += target.size(0)



Test Accuracy for batch = 93.75
Test Accuracy for batch = 92.1875
Test Accuracy for batch = 93.75
Test Accuracy for batch = 93.75
Test Accuracy for batch = 93.75
Test Accuracy for batch = 93.75
Test Accuracy for batch = 90.625
Test Accuracy for batch = 93.75
Test Accuracy for batch = 93.75
Test Accuracy for batch = 93.75
Test Accuracy for batch = 95.3125
Test Accuracy for batch = 89.0625
Test Accuracy for batch = 95.3125
Test Accuracy for batch = 95.3125
Test Accuracy for batch = 93.75
Test Accuracy for batch = 95.3125
Test Accuracy for batch = 96.875
Test Accuracy for batch = 95.3125
Test Accuracy for batch = 93.75
Test Accuracy for batch = 90.625
Test Accuracy for batch = 93.75
Test Accuracy for batch = 85.9375
Test Accuracy for batch = 93.75
Test Accuracy for batch = 96.875
Test Accuracy for batch = 90.625
Test Accuracy for batch = 95.3125
Test Accuracy for batch = 90.625
Test Accuracy for batch = 92.1875
Test Accuracy for batch = 90.625
Test Accuracy for batch = 92.1875
Test Accura

In [13]:
print('Training accuracy: ' + str(correct_test * 100/total_test))

Training accuracy: 92.74
