In [3]:
import logging
import os
from abc import abstractmethod

import cv2
import torch
from tqdm.auto import tqdm

In [2]:
class BaseTester(object):
    def __init__(self, model, criterion, metric_ftns, args):
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args.n_gpu)
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns

        self.epochs = self.args.epochs
        self.save_dir = self.args.save_dir

        self._load_checkpoint(args.load)

    @abstractmethod
    def test(self):
        raise NotImplementedError

    @abstractmethod
    def plot(self):
        raise NotImplementedError

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(
                    n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _load_checkpoint(self, load_path):
        load_path = str(load_path)
        self.logger.info("Loading checkpoint: {} ...".format(load_path))
        checkpoint = torch.load(load_path)
        self.model.load_state_dict(checkpoint['state_dict'])

In [4]:
class Tester(BaseTester):
    def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
        super(Tester, self).__init__(model, criterion, metric_ftns, args)
        self.test_dataloader = test_dataloader

    def test(self):
        self.logger.info('Start to evaluate in the test set.')
        log = dict()
        self.model.eval()
        with torch.no_grad():
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)
            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})
            print(log)
        return test_res, test_gts

In [5]:
args = {
'image_dir': "../../datasets/IUX_DATA/",
"ann_path" : "../multiImageData.json",
"dataset_name": "multi-image",
"num_workers": 10,
"batch_size": 32,
"max_seq_length":60,
"threshold":3,
"visual_extractor":"resnet101",
"visual_extractor_pretrained":True,
"d_model":512,
"--d_ff":512,
"--d_vf":2048,
"--num_heads":8,
"--num_layers":3,
"--dropout":0.1,
"--logit_layers":1,
"--bos_idx":0,
"--eos_idx":0,
"--pad_idx":0,
"--use_bn":0, 
"--drop_prob_lm":0.5,
"rm_num_slots":3,
"rm_num_heads":8,
"rm_d_model":512,
"sample_method":"beam_search",
"beam_size":3,
"temperature":1.0,
"sample_n":1,
"group_size":1,
"output_logsoftmax":1,
"decoding_constraint":0,
"block_trigrams":1,

"n_gpu":1,
"epochs":100,
"save_dir":"results/iux-ray",
"record_dir":"records/",
"save_period":1,
"monitor_mode":"max",
"monitor_metric":"BLEU_4",

"optim":"Adam",
"lr_ve":5e-5,
"lr_ed":1e-4,
"weight_decay":5e-5,
"amsgrad":True,
"lr_scheduler":"StepLR",
"step_size":50,
"gamma":0.1,
"seed":2022,
"resume":"",
"load":"/home/sweta/scratch/828-Project/R2Gen/results/augmented/current_checkpoint.pth"


}

In [11]:
from dotmap import DotMap
args = DotMap(args)

In [14]:
import numpy as np

In [15]:
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)

In [19]:
from modules.tokenizers import Tokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.loss import compute_loss
from models.r2gen import R2GenModel

In [21]:
tokenizer = Tokenizer(args)

# create data loader
test_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)

# build model architecture
model = R2GenModel(args, tokenizer)

# get function handles of loss and metrics
criterion = compute_loss
metrics = compute_scores

# build trainer and start to train
tester = Tester(model, criterion, metrics, args, test_dataloader)
tester.test()

TypeError: new(): argument 'size' must be tuple of ints, but found element of type DotMap at pos 2