In [1]:
import torch
from torch import nn, optim
from tqdm import tqdm
from torchvision import transforms
from transformers import BertTokenizer
from functools import partial
from main import config
from dataset import MultiModalDataset
from torch.utils.data import DataLoader
from torchmultimodal.models.flava.model import flava_model_for_classification

import json
%load_ext autoreload
%autoreload 2

In [2]:
# Build the FLAVA classification model; config['num_labels'] is passed as
# its only argument.
model = flava_model_for_classification(
    config['num_labels'],
)

# AdamW over all of the model's parameters, with the learning rate taken
# from config['learning_rate'].
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=config['learning_rate'],
)

In [3]:
def transform(tokenizer, input, max_length=512):
    """Collate a list of dataset samples into a single batch dict.

    Samples are read positionally: ``sample[1]`` must expose
    ``.convert("RGB")`` (the image), ``sample[2]`` is the text and
    ``sample[3]`` is the label. The earlier note on this function listed
    the sample layout as ``guid, img, text, label``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, return_tensors='pt',
            padding="max_length", truncation=True, max_length=max_length)``.
            Its result is merged into the batch with ``dict.update``.
        input: Sequence of samples to collate. (The name shadows the
            builtin but is kept so existing callers keep working.)
        max_length: Length every text is padded *and truncated* to.
            Defaults to 512, the value that used to be hard-coded.

    Returns:
        dict containing ``"image"`` (the stacked per-sample image tensors),
        every key produced by the tokenizer, and ``"answers"`` (a
        ``torch.long`` tensor built from the labels).
    """
    # Every image is converted to RGB, turned into a tensor and resized to
    # 224x224 so the per-sample tensors can be stacked.
    image_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize([224, 224])]
    )
    images = torch.stack(
        [image_transform(sample[1].convert("RGB")) for sample in input]
    )

    texts = [sample[2] for sample in input]
    # truncation=True keeps every row at exactly ``max_length`` tokens.
    # With padding alone, a text longer than ``max_length`` yields a longer
    # row than its neighbours and the batch cannot be turned into a tensor.
    tokenized = tokenizer(
        texts,
        return_tensors='pt',
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    batch = {"image": images}
    batch.update(tokenized)
    batch["answers"] = torch.tensor(
        [sample[3] for sample in input], dtype=torch.long
    )
    return batch

# NOTE(review): padding/max_length=64 are handed to from_pretrained here,
# while transform() passes its own padding and max_length (512) when it
# calls the tokenizer — confirm which of the two lengths is intended.
tokenizer_options = {"padding": "max_length", "max_length": 64}
tokenizer = BertTokenizer.from_pretrained(
    "/root/.cache/huggingface/hub/models--bert-base-uncased",
    **tokenizer_options,
)

# Rebind ``transform`` to a version with the tokenizer pre-filled, so it
# can be called with just the list of samples (e.g. as a collate_fn).
transform = partial(transform, tokenizer)


In [4]:
def get_dataloader(config, path_key, transform, shuffle=True):
    """Build a DataLoader for the split whose JSON index is at ``config[path_key]``.

    The JSON file is read as a list of records; each record's ``'guid'`` and
    ``'label'`` entries are collected and handed to ``MultiModalDataset``
    together with ``config['data_path']`` (used for both ``imgs_dir`` and
    ``text_dir``).

    Args:
        config: Mapping providing ``config[path_key]``, ``'data_path'`` and
            ``'batch_size'``; it is also forwarded to the dataset.
        path_key: Key in ``config`` holding the path of the JSON index file.
        transform: Collate function passed to the DataLoader as ``collate_fn``.
        shuffle: Whether the loader shuffles samples. Defaults to ``True``
            (the previous hard-coded behaviour); pass ``False`` for
            evaluation splits.

    Returns:
        torch.utils.data.DataLoader over the constructed dataset.
    """
    # Use a context manager so the index file is closed deterministically
    # instead of leaking the handle until garbage collection.
    with open(config[path_key], "r", encoding="utf-8") as index_file:
        records = json.load(index_file)

    guids = [record['guid'] for record in records]
    labels = [record['label'] for record in records]

    dataset = MultiModalDataset(imgs_dir=config['data_path'],
                                text_dir=config['data_path'],
                                labels=labels, guids=guids, config=config)

    return DataLoader(dataset, batch_size=config['batch_size'],
                      shuffle=shuffle, collate_fn=transform)

In [5]:
# Loader for the split whose index path is stored under config["train_path"];
# `transform` is used as the collate function.
train_data_loader = get_dataloader(
    config=config,
    path_key="train_path",
    transform=transform,
)

In [6]:
# Set training mode explicitly so the loop does not depend on whatever mode
# an earlier cell may have left the model in.
model.train()

for epoch in tqdm(range(config['num_epochs'])):

    # A second bar over the batches: with a bar over epochs only, progress
    # stays at 0% until a whole epoch has finished.
    for idx, batch in enumerate(tqdm(train_data_loader, desc=f"epoch {epoch}", leave=False)):
        optimizer.zero_grad()
        out = model(text=batch['input_ids'], image=batch['image'], labels=batch['answers'])
        loss = out.loss

        # Batch accuracy: share of samples whose arg-max logit equals the label.
        predictions = out.logits.argmax(dim=-1)
        acc = torch.eq(predictions, batch['answers']).float().mean()

        # tqdm.write emits the same text as print() but keeps the progress
        # bars intact instead of interleaving with them.
        tqdm.write(str(predictions))
        tqdm.write(str(batch['answers']))
        tqdm.write(f"acc: {acc:.2%}")

        loss.backward()
        optimizer.step()
        tqdm.write(f"Loss at step {idx} = {loss}")

  0%|          | 0/8 [00:00<?, ?it/s]



tensor([0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 2, 2, 0, 0, 0, 2, 0, 0, 2,
        0, 0, 2, 0, 0, 0, 2, 0])
tensor([0, 2, 1, 0, 2, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 2, 2, 2, 2, 2, 1, 2,
        0, 2, 2, 2, 2, 0, 0, 2])
acc: 37.50%
Loss at step 0 = 1.0916640758514404
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
tensor([1, 2, 2, 1, 2, 0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 0, 2, 2, 0, 1, 1, 2, 1, 2,
        2, 0, 2, 2, 2, 2, 2, 2])
acc: 59.38%
Loss at step 1 = 0.9485380053520203
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
tensor([0, 1, 0, 0, 2, 2, 2, 1, 2, 2, 2, 2, 0, 2, 2, 1, 0, 2, 1, 0, 2, 0, 2, 2,
        2, 2, 1, 2, 2, 0, 1, 0])
acc: 53.12%
Loss at step 2 = 1.961264967918396
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
tensor([1, 2, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2

  0%|          | 0/8 [36:59<?, ?it/s]


KeyboardInterrupt: 