In [13]:
from PIL import Image
import torch

# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
class ImageDataset(torch.utils.data.Dataset):
    """Dataset that pairs each image with one fixed text prompt and a one-hot label.

    Args:
        image_files: Indexable collection of records; each record must provide
            an ``'img'`` entry (an object exposing ``.mode`` and ``.convert``,
            such as a PIL image) and a ``'label'`` entry (integer class index).
        text: Prompt passed to the processor together with every image.
        processor: Callable invoked as
            ``processor(image, text, padding=..., truncation=..., return_tensors="pt")``
            that returns a mutable mapping of tensors carrying a leading batch
            dimension of size 1.
        num_labels: Length of the one-hot target vector.
    """

    def __init__(self, image_files, text, processor, num_labels):
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        """Return the number of image records."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the processor encoding for record ``idx`` plus a ``'labels'`` entry.

        Raises:
            IndexError: if the record's label is outside ``[-num_labels, num_labels)``.
        """
        # Fetch the record once; indexing the underlying dataset may decode the image.
        record = self.image_files[idx]
        image = record['img']
        label = record['label']
        if image.mode != "RGB":
            image = image.convert("RGB")

        encoding = self.processor(image, self.text, padding="max_length", truncation=True, return_tensors="pt")

        # Remove only the leading batch dimension. A bare squeeze() would also
        # drop any other axis that happens to have size 1 (e.g. a length-1
        # sequence), silently changing the tensor rank.
        for key, value in list(encoding.items()):
            encoding[key] = value.squeeze(0)

        # One-hot float target vector.
        targets = torch.zeros(self.num_labels)
        targets[label] = 1
        encoding["labels"] = targets

        return encoding

In [3]:
!pip install datasets
from datasets import load_dataset

cache_dir='./cache'
#datasets = load_dataset('Maysee/tiny-imagenet', cache_dir=cache_dir)
datasets = load_dataset('cifar10', cache_dir=cache_dir)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 KB[0m [31m35.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess
  Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.9/132.9 KB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub<1.0.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
Collecting aiohttp
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m72.3 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3

Downloading builder script:   0%|          | 0.00/3.61k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.66k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/5.00k [00:00<?, ?B/s]

Downloading and preparing dataset cifar10/plain_text to /content/cache/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4...


Downloading data:   0%|          | 0.00/170M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]



Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset cifar10 downloaded and prepared to /content/cache/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m92.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m81.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, transformers
Successfully installed tokenizers-0.13.2 transformers-4.27.4


In [8]:
from transformers import ViltProcessor, BertTokenizer

# Class names are read from the label feature of the training split.
label_list = datasets["train"].features["label"].names
num_labels = len(label_list)

# NOTE(review): despite its name, `config` is a BertTokenizer instance. The label
# maps are attached to it as plain attributes and nothing later in this notebook
# reads them - confirm whether this object is still needed.
config = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=4)
config.id2label = dict((str(index), name) for index, name in enumerate(label_list))
config.label2id = dict((name, str(index)) for index, name in enumerate(label_list))
config.num_labels = num_labels

# The ViLT processor handles both image preprocessing and prompt tokenisation.
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Both splits share the same prompt, processor and label count.
train_dataset, test_dataset = (
    ImageDataset(
        image_files=datasets[split_name],
        text="What is this image?",
        processor=processor,
        num_labels=num_labels,
    )
    for split_name in ("train", "test")
)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/251 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/320 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [9]:
from torch.utils.data import DataLoader
import torchvision.transforms as T

# Every batch of images is resized to this fixed spatial size after padding.
# NOTE(review): 224x224 is presumably the input size the image encoder expects - confirm.
transform = T.Resize((224,224))


def collate_fn(batch):
    """Merge a list of ImageDataset items into a single batch dictionary.

    Args:
        batch: List of mappings, each providing 'input_ids', 'pixel_values',
            'attention_mask', 'token_type_ids' and 'labels' tensors.

    Returns:
        dict with the stacked text tensors and labels, the padded-then-resized
        'pixel_values', and a 'pixel_mask' resized to the same spatial size.
    """
    # Pad the images to a common size and build the matching pixel mask.
    # NOTE(review): `feature_extractor` / `pad_and_create_pixel_mask` look like
    # legacy transformers API - confirm availability before upgrading the library.
    encoding = processor.feature_extractor.pad_and_create_pixel_mask(
        [item['pixel_values'] for item in batch], return_tensors="pt"
    )
    pixel_values = transform(encoding['pixel_values'])

    # Resize the mask along with the images so both keep the same height and
    # width. Nearest-neighbour interpolation keeps the mask values unchanged;
    # the temporary channel axis is what interpolate() expects for 2-D resizing.
    pixel_mask = encoding['pixel_mask']
    resized_mask = torch.nn.functional.interpolate(
        pixel_mask.unsqueeze(1).float(),
        size=tuple(pixel_values.shape[-2:]),
        mode="nearest",
    ).squeeze(1).to(pixel_mask.dtype)

    return {
        'input_ids': torch.stack([item['input_ids'] for item in batch]),
        'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
        'token_type_ids': torch.stack([item['token_type_ids'] for item in batch]),
        'pixel_values': pixel_values,
        'pixel_mask': resized_mask,
        'labels': torch.stack([item['labels'] for item in batch]),
    }

# Training batches are reshuffled on every pass over the data.
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True)
# Evaluation keeps the dataset order so results are reproducible between runs.
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=512, shuffle=False)

In [None]:
# Sanity check: show the container type that load_dataset returned.
type(datasets)

datasets.dataset_dict.DatasetDict

In [10]:
!pip install torchmultimodal-nightly

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmultimodal-nightly
  Downloading torchmultimodal_nightly-2023.3.29-py39-none-any.whl (128 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m128.6/128.6 KB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting iopath
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 KB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting DALL-E==0.1
  Downloading DALL_E-0.1-py3-none-any.whl (6.0 kB)
Collecting mypy
  Downloading mypy-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.1/12.1 MB[0m [31m45.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting blobfile
  Downloading blobfile-2.0.1-py3-none-any.whl (73 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [12]:
from torchmultimodal.models.flava.model import flava_model_for_classification

# One output class per label name of the training split.
label_feature = datasets["train"].features["label"]
vocab = label_feature.names
model = flava_model_for_classification(num_classes=len(vocab))


flava_for_pretraining_unified_text_encoder.pt: 1.43GB [00:12, 110MB/s]                            


RuntimeError: ignored

In [15]:
# Move the model onto `device`. As the last expression of the cell, the
# returned module is also echoed, which produces the printout below.
model.to(device)

FLAVAForClassification(
  (model): FLAVAModel(
    (image_encoder): ImageTransformer(
      (embeddings): ImageEmbeddings(
        (patch_embeddings): PatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): TransformerEncoder(
        (layer): ModuleList(
          (0): TransformerEncoderLayer(
            (attention): MultiHeadAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (output): Linear(in_features=768, out_features=768, bias=True)
              (attn): SelfAttention()
            )
            (attention_dropout): Dropout(p=0.0, inplace=False)
            (feedforward): MLP(
              (model): Sequential(
                (0): Linear(in_features=768, out_fe

In [17]:
from itertools import islice

from torch import nn
from torch.utils.data import DataLoader

BATCH_SIZE = 64
MAX_STEPS = 3  # optimisation steps to run per epoch
epochs = 1

# Rebuild the training loader with BATCH_SIZE. The loader created earlier uses
# 512 samples per batch, which ran out of GPU memory in this loop.
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters())

# Make sure training-mode layers (e.g. dropout) are active.
model.train()
for _ in range(epochs):
    # islice stops after exactly MAX_STEPS batches; the previous
    # `idx > MAX_STEPS-1` check only broke out after one extra step had run.
    for idx, batch in enumerate(islice(train_dataloader, MAX_STEPS)):
        optimizer.zero_grad()
        out = model(
            text=batch["input_ids"].to(device),
            image=batch["pixel_values"].to(device),
            labels=batch["labels"].to(device),
        )
        loss = out.loss
        loss.backward()
        optimizer.step()
        # .item() prints the plain number instead of the tensor repr.
        print(f"Loss at step {idx} = {loss.item()}")

OutOfMemoryError: ignored

In [None]:
# Disabled evaluation sketch, kept for reference only.
# NOTE(review): before re-enabling, reconcile it with the rest of the notebook:
#   - it builds the loader from the raw dataset and reads batch["input_ids"],
#     batch["img"] and batch["label"], whereas collate_fn emits "input_ids",
#     "pixel_values" and one-hot "labels";
#   - inputs are never moved to `device`;
#   - the loop breaks after 20 samples, so the `% 100` accuracy print never fires.
# test_dataloader = DataLoader(datasets['test'], batch_size= 1, shuffle = True)
# model.eval()
# dev_predictions = []
# dev_true_labels = []
# total_count = 0
# correct_count = 0
# with torch.no_grad():
#     for batch in test_dataloader:
#       total_count +=1
#       out = model(text = batch["input_ids"], image = batch["img"], labels = batch["label"])
#       pred = torch.argmax(out.logits)
#       dev_predictions.append(pred)
#       dev_true_labels.append(batch['label'])
#       if(pred == batch['label']):
#         correct_count +=1
#       if(total_count == 20):
#         break
#       if(total_count % 100 == 0):
#         print('Accuracy = ' + str(correct_count / total_count))
