In [1]:
import os
import sys
import logging
import itertools
import pdb
import numpy as np
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from transformers import BertConfig, BertTokenizer, BertModel
from transformers import ViltProcessor, ViltModel, ViltConfig
from transformers import BertTokenizerFast
from transformers import logging as transformers_logging

# Notebook-wide logging: one module logger, and a root handler configured for
# timestamped INFO-and-above records.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_LOG_DATEFMT = '%m/%d/%Y %H:%M:%S'

logger = logging.getLogger(__name__)
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

# Turn the transformers library's own logging down to ERROR level.
transformers_logging.set_verbosity_error()

ModuleNotFoundError: No module named 'torch'

In [None]:
from cl_evaluation.evaluate_cl_algorithm import forward_transfer_eval, catastrophic_forgetting_eval
from configs.model_configs import model_configs
from configs.task_configs import task_configs, SUPPORTED_VL_TASKS
from configs.adapter_configs import ADAPTER_MAP
from utils.seed_utils import set_seed

class Args:
    """Settings bundle passed around as ``args`` (e.g. to the dataloader builder).

    The defaults reproduce the notebook's original hard-coded configuration.
    Pass keyword arguments to override individual settings (a different data
    root, batch size, ...) without editing the class.

    Args:
        num_workers: value stored as ``self.num_workers`` (presumably the
            DataLoader worker count -- TODO confirm against the builder).
        batch_size: value stored as ``self.batch_size``.
        mcl_data_dir: root directory that task data sub-directories are joined onto.
        pretrained_model_name: name/path used to load the ViLT processor and model.
    """

    def __init__(self, num_workers=0, batch_size=32,
                 mcl_data_dir='/data/datasets/MCL/',
                 pretrained_model_name='dandelin/vilt-b32-mlm'):
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.mcl_data_dir = mcl_data_dir
        self.pretrained_model_name = pretrained_model_name


args = Args()

In [None]:
# Load the ViLT processor and backbone model named by the notebook settings.
vilt_processor, vilt = (
    ViltProcessor.from_pretrained(args.pretrained_model_name),
    ViltModel.from_pretrained(args.pretrained_model_name),
)

In [None]:
from modeling.vilt_modeling import ViltContinualLearner, ViltEncoderWrapper

# Task order handed to the continual learner. NOTE(review): presumably this must
# match the order the loaded checkpoint was trained with -- confirm.
ordered_cl_tasks = ['vqa', 'nlvr2', 'snli-ve']
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ViLT-specific settings from the project's model config table.
model_config = model_configs['vilt']
encoder_dim = model_config['encoder_dim']
visual_mode = model_config['visual_mode']
batch2inputs_converter = model_config['batch2inputs_converter']

encoder = ViltEncoderWrapper(vilt_processor, vilt, device)
model = ViltContinualLearner(ordered_cl_tasks, encoder, encoder_dim, task_configs)
# Trailing ';' suppresses the (very long) module repr in the cell output.
model.to(device);

In [None]:
# Judging by the path, this is the checkpoint saved after task1 (nlvr2) of the
# vqa -> nlvr2 -> snli-ve sequential fine-tuning run.
ckpt_path = os.path.join('/data/experiments/MCL', 'vilt-sequential_ft-task0_vqa-task1_nlvr2-task2_snli-ve',
                         'checkpoints', 'task1_nlvr2', 'model')
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`,
# so the checkpoint also loads on CPU-only machines.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict)

In [None]:
from data.visionlanguage_datasets.nlvr2_dataset import build_nlvr2_dataloader

# Validation dataloader for NLVR2, rooted at the task's configured data sub-directory.
nlvr_config = task_configs['nlvr2']
data_dir = os.path.join(args.mcl_data_dir, nlvr_config['data_dir'])
val_loader_kwargs = dict(
    args=args,
    data_dir=data_dir,
    split='val',
    visual_mode=visual_mode,
)
val_dataloader = build_nlvr2_dataloader(**val_loader_kwargs)
# Accuracy of the task1 checkpoint on the NLVR2 validation split.
model.eval()
num_correct = 0
num_seen = 0
progress = tqdm(val_dataloader, desc='Evaluating on NLVR2 val set')
with torch.no_grad():
    for batch in progress:
        inputs = batch2inputs_converter(batch)
        output = model(task_key='nlvr2', **inputs)
        # NOTE(review): element 1 of the output is assumed to be the task logits -- confirm.
        logits = output[1]

        batch_correct = (logits.argmax(-1).cpu() == batch['labels'])
        num_correct += batch_correct.sum().item()
        num_seen += batch_correct.shape[0]
        progress.set_postfix({'score': num_correct / num_seen})

# Normalise by the number of examples actually scored (not len(dataset)), so the
# final number agrees with the running score shown on the progress bar.
assert num_seen > 0, 'NLVR2 validation dataloader yielded no examples'
eval_score = num_correct / num_seen * 100.0
eval_score