In [1]:
from PIL import Image
import torch

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [2]:
class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, image_files, text, processor, num_labels):
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        text = self.text
        image = self.image_files[idx]['img']
        label = self.image_files[idx]['label']
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")

        # remove batch dimension
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        targets = torch.zeros(self.num_labels)
        targets[label] = 1
        encoding["labels"] = targets

        return encoding

In [4]:
from datasets import load_dataset

cache_dir='./cache'
#datasets = load_dataset('Maysee/tiny-imagenet', cache_dir=cache_dir)
datasets = load_dataset('cifar10', cache_dir=cache_dir)

Downloading builder script:   0%|          | 0.00/3.61k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.66k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/5.00k [00:00<?, ?B/s]

Downloading and preparing dataset cifar10/plain_text to C:/Users/Maheshvar/Deep Learning/Project/cache/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4...


Downloading data:   0%|          | 0.00/170M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]



Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset cifar10 downloaded and prepared to C:/Users/Maheshvar/Deep Learning/Project/cache/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [27]:
from transformers import FlavaProcessor, FlavaConfig, FlavaProcessor, FlavaMultimodalModel

label_list = datasets["train"].features["label"].names
num_labels = len(label_list)

config = FlavaConfig.from_pretrained("facebook/flava-full", cache_dir=cache_dir)
config.id2label = {str(i): label for i, label in enumerate(label_list)}
config.label2id = {label: str(i) for i, label in enumerate(label_list)}
config.num_labels = num_labels

processor = FlavaProcessor.from_pretrained("facebook/flava-full")

train_dataset = ImageDataset(image_files=datasets["train"], text="What is this image?", processor=processor, num_labels=num_labels)
test_dataset = ImageDataset(image_files=datasets["test"], text="What is this image?", processor=processor, num_labels=num_labels)

In [28]:
processor.decode(train_dataset[0]['input_ids'])

'[CLS] what is this image? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [29]:
label = torch.nonzero(train_dataset[0]['labels']).squeeze().tolist()
config.id2label[str(label)]

'airplane'

In [30]:
from transformers import FlavaMultimodalModel
model = FlavaMultimodalModel.from_pretrained("facebook/flava-full", config=config)
model.to(device)

AttributeError: 'FlavaConfig' object has no attribute 'use_cls_token'

In [23]:
from torch.utils.data import DataLoader

def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  pixel_values = [item['pixel_values'] for item in batch]
  attention_mask = [item['attention_mask'] for item in batch]
  token_type_ids = [item['token_type_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  # create padded pixel values and corresponding pixel mask
  encoding = processor.feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")

  # create new batch
  batch = {}
  batch['input_ids'] = torch.stack(input_ids)
  batch['attention_mask'] = torch.stack(attention_mask)
  batch['token_type_ids'] = torch.stack(token_type_ids)
  batch['pixel_values'] = encoding['pixel_values']
  batch['pixel_mask'] = encoding['pixel_mask']
  batch['labels'] = torch.stack(labels)

  return batch

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True)
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True)

In [24]:
from tqdm.notebook import tqdm

@torch.no_grad()
def evaluate(model, device, test_dataloader):
    losses = []  # List of scalar tensors
    correct = 0
    total = 0
    for batch in tqdm(test_dataloader):
        # adapt batch to model
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)
        target = torch.argmax(batch['labels'], dim=1)
        correct += torch.sum(preds==target).item()
        total += target.size(0)
        losses.append(outputs.loss)
    stacked_losses = torch.stack(losses)  # (num_batches, ) 
    total_avg_loss = stacked_losses.mean()  # (num test examples, ) -> scalar
    total_avg_acc = (100 * correct) / total
    print("Correct: " + str(correct) + "/" + "Total: "  +str(total))
    print("Average val loss: " + str(total_avg_loss.item()))
    print("Average val acc: " + str(total_avg_acc))

    return total_avg_loss.item(), total_avg_acc

In [25]:
import os

for param in model.parameters():
    param.requires_grad = False

for name, param in model.named_parameters():
    if 'classifier' in name or 'pooler' in name:
        param.requires_grad = True

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

checkpoint_dir = './checkpoints'
best_test_acc = -1
step = -1
patience = 0
model.train()
for epoch in range(1):  # loop over the dataset multiple times
    print(f"Train Epoch: {epoch}")
    for batch in tqdm(train_dataloader):
        step += 1
        # get the inputs; 
        batch = {k:v.to(device) for k,v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)
        target = torch.argmax(batch['labels'], dim=1)
        correct = torch.sum(preds==target).item()
        acc = (correct * 100) / target.size(0)
        print(f"Step {step} - loss: {loss.item()} , train acc: {acc}")
        loss.backward()
        optimizer.step()
    print(f"Evaluate Epoch: {epoch}")
    # Evaluate
    model.eval()
    new_test_loss, new_test_acc = evaluate(model, device, test_dataloader)
    # save checkpoint with best test loss
    if new_test_acc > best_test_acc or best_test_acc < 0:
        patience = 0
        if best_test_acc > 0:
            os.remove(checkpoint_dir + '/'+ best_checkpoint_filename)
        best_checkpoint_filename = "best_model" + str(epoch) +".pt"
        torch.save(model.state_dict(), checkpoint_dir + '/' + best_checkpoint_filename)
        best_test_acc = new_test_acc
    else:
        patience += 1
        if patience > 3:
            print("Early stopping at epoch "+str(epoch))
            break

    model.train()

Train Epoch: 0


  0%|          | 0/98 [00:00<?, ?it/s]

TypeError: forward() got an unexpected keyword argument 'input_ids'