In [1]:
# OPTIONAL: load IPython's "autoreload" extension so that edited modules are
# picked up without restarting the kernel.
%load_ext autoreload

# OPTIONAL: mode 2 reloads all imported modules before each cell runs, so edits
# to project code imported by this notebook take effect immediately.
%autoreload 2

In [4]:
from transformers import BertTokenizer, LxmertTokenizer
from data import ImageTextClassificationDataset, collate_fn_batch_visualbert, collate_fn_batch_lxmert, collate_fn_batch_visualbert_semi_supervised, collate_fn_batch_lxmert_semi_supervised
from matplotlib import pyplot as plt
import pandas as pd
from functools import partial
import torch

In [5]:
# import argparse

In [6]:
# Notebook stand-in for the commented-out argparse setup
# (`argparse.ArgumentParser(description='train')` / `parser.parse_args()`):
# a pandas Series offers the same attribute-style access, e.g. `args.batch_size`.
#
# The Series is built with a single constructor call. The previous version
# grew it with repeated `Series.append`, which was deprecated in pandas 1.4 and
# removed in pandas 2.0, and copied the whole Series on every call.
# `dtype=object` keeps every value as the plain Python object written below.
parser = pd.Series(
    {
        "img_feature_path": "../data/features/visualgenome/",
        "train_csv_path": "../data/splits/random/memotion_train.csv",
        "val_csv_path": "../data/splits/random/memotion_val.csv",
        "model_type": "visualbert",
        "model_path": "uclanlp/visualbert-vqa-coco-pre",
        "learning_rate": 2e-5,
        "epoch": 100,
        "eval_step": 100,
        "batch_size": 64,
        "amp": True,
        "output_dir": "./tmp",
        "checkpoint_step": 1000,
        "random_seed": 42,
        "resume_training": False,
        "semi_supervised": False,
    },
    dtype=object,
)

# Later cells read the settings through `args`; `parser` is kept as an alias so
# both names remain available.
args = parser
args

  


img_feature_path              ../data/features/visualgenome/
train_csv_path      ../data/splits/random/memotion_train.csv
val_csv_path          ../data/splits/random/memotion_val.csv
model_type                                        visualbert
model_path                   uclanlp/visualbert-vqa-coco-pre
learning_rate                                        0.00002
epoch                                                    100
eval_step                                                100
batch_size                                                64
amp                                                     True
output_dir                                             ./tmp
checkpoint_step                                         1000
random_seed                                               42
resume_training                                        False
semi_supervised                                        False
dtype: object

In [7]:
# Copy the two settings that the following cells also read into plain names.
model_type = args.model_type
img_feature_path = args.img_feature_path

# Build the dataset in debug mode, pointing it at the metadata JSON stored next
# to the extracted image features.
# NOTE(review): this reads `args.val_csv_path` even though the result is named
# `dataset_train` and is created with mode='train'. An earlier, commented-out
# variant used `args.train_csv_path` -- confirm the validation split is intended.
dataset_train = ImageTextClassificationDataset(
    img_feature_path,
    args.val_csv_path,
    model_type=model_type,
    mode='train',
    debug=True,
    metadata_path='../data/features/visualgenome/train_images/metadata.json',
)


In [8]:
# Load the tokenizer that matches the selected model family.
# Only the tokenizer is created in this cell; the model/config loading that used
# to sit alongside it (and a ViLT branch using a processor instead of a
# tokenizer) is disabled here.
TOKENIZER_SOURCES = {
    "visualbert": (BertTokenizer, "bert-base-uncased"),
    "lxmert": (LxmertTokenizer, "unc-nlp/lxmert-base-uncased"),
}

# Fail loudly on an unsupported model type. Without this check `tokenizer`
# would stay undefined (or keep a stale value from an earlier run) and the
# error would only surface later, in a cell far from the cause.
if model_type not in TOKENIZER_SOURCES:
    raise ValueError(
        f"Unsupported model_type {model_type!r}; "
        f"expected one of {sorted(TOKENIZER_SOURCES)}"
    )

tokenizer_cls, tokenizer_name = TOKENIZER_SOURCES[model_type]
tokenizer = tokenizer_cls.from_pretrained(tokenizer_name)

In [9]:
# Choose the batch collate function from the training regime (semi-supervised
# or not) and the model family, then bind the tokenizer to it.
# Each entry is (collate function, extra keyword arguments to bind).
# NOTE(review): only the supervised VisualBERT entry binds debug=True. The cell
# that unpacks a batch expects four items, which presumably depends on that
# flag -- confirm before switching to another entry.
COLLATE_CHOICES = {
    (True, "visualbert"): (collate_fn_batch_visualbert_semi_supervised, {}),
    (True, "lxmert"): (collate_fn_batch_lxmert_semi_supervised, {}),
    (False, "visualbert"): (collate_fn_batch_visualbert, {"debug": True}),
    (False, "lxmert"): (collate_fn_batch_lxmert, {}),
}

# bool() normalises the flag so the lookup works whether the config holds a
# Python bool or a NumPy bool.
collate_key = (bool(args.semi_supervised), model_type)

# Fail loudly instead of leaving `collate_fn_batch` undefined (or stale from an
# earlier run) when the model type is not supported.
if collate_key not in COLLATE_CHOICES:
    raise ValueError(
        f"No collate function for model_type {model_type!r} "
        f"(semi_supervised={collate_key[0]})"
    )

collate_fn, collate_extra_kwargs = COLLATE_CHOICES[collate_key]
collate_fn_batch = partial(collate_fn, tokenizer=tokenizer, **collate_extra_kwargs)

In [10]:
# Seed the shuffling from the configured `random_seed`, so that re-running this
# cell and the next one yields the same first batch. Previously the seed was
# configured but never used, so the inspected batch changed on every run.
shuffle_generator = torch.Generator()
shuffle_generator.manual_seed(int(args.random_seed))

# Wrap the dataset in a shuffling DataLoader that batches with the collate
# function selected above and loads with three worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    collate_fn=collate_fn_batch,
    # int() guards against a NumPy integer coming out of the pandas config,
    # since DataLoader expects a built-in int here.
    batch_size=int(args.batch_size),
    shuffle=True,
    num_workers=3,
    generator=shuffle_generator,
)

In [11]:
# Pull a single batch from the loader and split it into its four parts:
# tokenizer output, image features, labels and per-example metadata.
first_batch = next(iter(train_loader))
batch_toks, batch_img_features, batch_labels, batch_metadata = first_batch

In [12]:
# Decode every caption in the batch back to word-piece tokens, dropping special
# tokens, and print one token list per example.
for token_id_row in batch_toks['input_ids'].tolist():
    caption_tokens = tokenizer.convert_ids_to_tokens(
        token_id_row,
        skip_special_tokens=True,
    )
    print(caption_tokens)

['me', 'at', '3a', '##m', 'watching', 'info', '##mer', '##cial', '##s', '1844', '##9', '##41', '##7', 'mozambique', 'pigeon', 'blood', 'red', 'hub', '##y', 'di', '##omo', '##nd', 'ring', 'gr', '##sg', '##iano', 'heat', '512', '##ct', 'center']
['can', 'u', 'imagine', 'how', 'hot', 'id', 'be', 'if', 'i', 'ate', 'right', 'and', 'took', 'care', 'of', 'my', 'body', 'im', 'not', 'gonna', 'do', 'it', 'but', 'can', 'u', 'imagine']
['when', 'its', '2', 'am', 'but', 'you', 'can', '##t', 'stop', 'thinking', 'about', 'that', 'time', 'you', 'tripped', 'during', 'the', '4th', 'grade', 'spelling', 'bee']
['sorting', 'by', 'new', 'and', 'only', 'seeing', 'rep', '##ost', '##s', 'anime', '##s', '95', 'puck', '##you', 'fuck', 'you', 'too']
['people', 'hearing', 'im', 'a', 'cop', 'people', 'realizing', 'im', 'just', 'a', 'mall', 'cop', 'mg', '##fi', '##pc', '##om']
['me', 'pulls', 'out', 'a', 'pack', 'of', 'gum', 'the', 'entire', 'fucking', 'class']
['me', 'reading', 'all', 'the', 'kind', 'comments', 're

In [13]:
# Display the label tensor of the sampled batch; in the recorded output each
# row holds four 0/1 entries.
# NOTE(review): the meaning of the four columns is not visible here -- confirm
# against the dataset class.
batch_labels

tensor([[1, 0, 1, 0],
        [1, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 0, 1],
        [1, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 0],
        [1, 0, 1, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 1, 0],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
        [0, 1, 0, 0],
        [1, 1, 1, 0],
        [0, 0, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 1, 0, 1],
        [0, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
        [1

In [14]:
# Check the shape of the batch's image features; the recorded output is
# torch.Size([64, 64, 2048]).
# NOTE(review): presumably (batch, regions per image, feature dim) -- confirm
# against the collate function.
batch_img_features.shape

torch.Size([64, 64, 2048])