In [1]:
import os
import cv2
import sys
import random

import json
import h5py
import itertools
import numpy as np
from PIL import Image
import argparse, pickle

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
import torchvision.models
from torch.utils.data import Dataset as torchDataset
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data.dataloader import default_collate

# caption libraries
import evaluation
import collections
from data.example import Example
from data.utils import nostdout
from data.field import ImageDetectionsField, TextField, RawField
from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory

# graph libraries
import utils.io as io
import matplotlib.pyplot as plt
from utils.g_vis_img import *
from models.graph_su import *
from evaluation.graph_eval import *

# feature extractor
from models.feature_extractor import *

os.environ["CUDA_VISIBLE_DEVICES"]="1"

# Random seeds
seed = 27
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Dataset Constants

In [2]:
class SurgicalSceneConstants():
    '''
    Surgical Scene constants
    '''
    def __init__( self):
        self.instrument_classes = ('kidney', 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
                                'monopolar_curved_scissors', 'ultrasound_probe', 'suction', 'clip_applier',
                                'stapler', 'maryland_dissector', 'spatulated_monopolar_cautery')
        self.action_classes = ( 'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation', 
                                'Tool_Manipulation', 'Cutting', 'Cauterization', 
                                'Suction', 'Looping', 'Suturing', 'Clipping', 'Staple', 
                                'Ultrasound_Sensing')

# Cross-entropy loss with label smoothing

In [3]:
class CELossWithLS(torch.nn.Module):
    '''
    label smoothing cross-entropy loss for captioning
    '''
    def __init__(self, classes=None, smoothing=0.1, gamma=3.0, isCos=True, ignore_index=-1):
        super(CELossWithLS, self).__init__()
        self.complement = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = torch.nn.LogSoftmax(dim=1)
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        with torch.no_grad():
            oh_labels = F.one_hot(target.to(torch.int64), num_classes = self.cls).permute(0,1,2).contiguous()
            smoothen_ohlabel = oh_labels * self.complement + self.smoothing / self.cls

        logs = self.log_softmax(logits[target!=self.ignore_index])
        pt = torch.exp(logs)
        return -torch.sum((1-pt).pow(self.gamma)*logs * smoothen_ohlabel[target!=self.ignore_index], dim=1).mean()


# Dataloader

In [4]:
class DataLoader(TorchDataLoader):
    def __init__(self, dataset, *args, **kwargs):
        super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)

class Dataset(object):
    def __init__(self, examples, fields, gsu_const):
        self.examples = examples
        self.fields = dict(fields)
        
        self.file_dir = gsu_const['file_dir']
        self.img_dir = gsu_const['img_dir']
        self.dataconst = gsu_const['dataconst']
        self.feature_extractor = gsu_const['feature_extractor']
        self.word2vec = h5py.File(gsu_const['w2v_loc'], 'r')
        
    # word2vec
    def _get_word2vec(self,node_ids):
        word2vec = np.empty((0,300))
        for node_id in node_ids:
            vec = self.word2vec[self.dataconst.instrument_classes[node_id]]
            word2vec = np.vstack((word2vec, vec))
        return word2vec

    def __getitem__(self, i):
        example = self.examples[i]
        frame_path = getattr(example, 'image')
        frame_path = frame_path.split("/")
        
        _img_loc = os.path.join(self.file_dir, frame_path[0],self.img_dir,frame_path[3].split("_")[0]+'.png')
        frame_data = h5py.File(os.path.join(self.file_dir, frame_path[0],'vsgat',self.feature_extractor, frame_path[3].split("_")[0]+'_features.hdf5'), 'r')    
        
        #print(_img_loc)
        
        # caption data
        cp_data = []
        for field_name, field in self.fields.items():
            if field_name == 'image' and field == None:
                cp_data.append(np.zeros((6,512), dtype = np.float32))
            else:
                cp_data.append(field.preprocess(getattr(example, field_name)))   
        if len(cp_data) == 1: cp_data = cp_data[0]
        
        # graph data
        gsu_data = {}
        gsu_data['img_name'] = frame_data['img_name'].value[:] + '.jpg'
        gsu_data['img_loc'] = _img_loc
        gsu_data['node_num'] = frame_data['node_num'].value
        gsu_data['roi_labels'] = frame_data['classes'][:]
        gsu_data['det_boxes'] = frame_data['boxes'][:]
        gsu_data['edge_labels'] = frame_data['edge_labels'][:]
        gsu_data['edge_num'] = gsu_data['edge_labels'].shape[0]
        if self.fields['image'] == None:
            gsu_data['features'] = np.zeros((gsu_data['node_num'],512), dtype = np.float32)
        else:
            gsu_data['features'] = frame_data['node_features'][:]
        gsu_data['spatial_feat'] = frame_data['spatial_features'][:]
        gsu_data['word2vec'] = self._get_word2vec(gsu_data['roi_labels'])
        
        data = {}
        data['cp_data'] = cp_data
        data['gsu_data'] = gsu_data
        return data

    def __len__(self):
        return len(self.examples)

    def __getattr__(self, attr):
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)
                
    def collate_fn(self):
        def collate(batch):
            gsu_batch_data = {}
            gsu_batch_data['img_name'] = []
            gsu_batch_data['img_loc'] = []
            gsu_batch_data['node_num'] = []
            gsu_batch_data['roi_labels'] = []
            gsu_batch_data['det_boxes'] = []
            gsu_batch_data['edge_labels'] = []
            gsu_batch_data['edge_num'] = []
            gsu_batch_data['features'] = []
            gsu_batch_data['spatial_feat'] = []
            gsu_batch_data['word2vec'] = []

            for data in batch:
                gsu_batch_data['img_name'].append(data['gsu_data']['img_name'])
                gsu_batch_data['img_loc'].append(data['gsu_data']['img_loc'])
                gsu_batch_data['node_num'].append(data['gsu_data']['node_num'])
                gsu_batch_data['roi_labels'].append(data['gsu_data']['roi_labels'])
                gsu_batch_data['det_boxes'].append(data['gsu_data']['det_boxes'])
                gsu_batch_data['edge_labels'].append(data['gsu_data']['edge_labels'])
                gsu_batch_data['edge_num'].append(data['gsu_data']['edge_num'])
                gsu_batch_data['features'].append(data['gsu_data']['features'])
                gsu_batch_data['spatial_feat'].append(data['gsu_data']['spatial_feat'])
                gsu_batch_data['word2vec'].append(data['gsu_data']['word2vec'])

            gsu_batch_data['edge_labels'] = torch.FloatTensor(np.concatenate(gsu_batch_data['edge_labels'], axis=0))
            gsu_batch_data['features'] = torch.FloatTensor(np.concatenate(gsu_batch_data['features'], axis=0))
            gsu_batch_data['spatial_feat'] = torch.FloatTensor(np.concatenate(gsu_batch_data['spatial_feat'], axis=0))
            gsu_batch_data['word2vec'] = torch.FloatTensor(np.concatenate(gsu_batch_data['word2vec'], axis=0))
            
            cp_batch_data = []
            tensors = []
            
            for data in batch: cp_batch_data.append(data['cp_data'])
            if len(self.fields) == 1: cp_batch_data = [cp_batch_data, ]
            else: cp_batch_data = list(zip(*cp_batch_data))

            for field, data in zip(self.fields.values(), cp_batch_data):
                if field == None: tensor = default_collate(data)
                else: tensor = field.process(data)
                if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
                    tensors.extend(tensor)
                else: tensors.append(tensor)

            if len(tensors) > 1:cp_batch_data = tensors
            else: cp_batch_data = tensors[0]
            
            batch_data = {}
            batch_data['gsu'] = gsu_batch_data
            batch_data['cp'] = cp_batch_data
            
            return(batch_data)

        return collate
    
class PairedDataset(Dataset):
    def __init__(self, examples, fields, gsu_const):
        assert ('image' in fields)
        assert ('text' in fields)
        super(PairedDataset, self).__init__(examples, fields, gsu_const)
        self.image_field = self.fields['image']
        if self.image_field == None: print('no pre-extracted image featured')
        self.text_field = self.fields['text']
        
    def image_dictionary(self, fields=None):
        if not fields:
            fields = self.fields
        dataset = Dataset(self.examples, fields, gsu_const)
        #dataset = DictionaryDataset(self.examples, fields, gsu_const, 'image')
        return dataset
        
class COCO(PairedDataset):
    def __init__(self, image_field, text_field, gsu_const, img_root, ann_root, id_root=None):
        # setting training and val root
        roots = {}
        roots['train'] = { 'img': img_root, 'cap': os.path.join(ann_root, 'captions_train.json')}
        roots['val'] = {'img': img_root, 'cap': os.path.join(ann_root, 'captions_val.json')}

        # Getting the id: planning to remove this in future
        if id_root is not None:
            ids = {}
            ids['train'] = json.load(open(os.path.join(id_root, 'WithCaption_id_path_train.json'), 'r'))
            ids['val'] = json.load(open(os.path.join(id_root, 'WithCaption_id_path_val.json'), 'r'))   
        else: ids = None
        
        with nostdout():
            self.train_examples, self.val_examples = self.get_samples(roots, ids)
        examples = self.train_examples + self.val_examples
        super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field}, gsu_const)   

    @property
    def splits(self):
        train_split = PairedDataset(self.train_examples, self.fields, gsu_const) 
        val_split = PairedDataset(self.val_examples, self.fields, gsu_const)
        return train_split, val_split

    @classmethod
    def get_samples(cls, roots, ids_dataset=None):
        train_samples = []
        val_samples = []
   
        for split in ['train', 'val']:
            anns = json.load(open(roots[split]['cap'], 'r'))
            if ids_dataset is not None: ids = ids_dataset[split]
                
            for index in range(len(ids)):              
                id_path = ids[index]
                caption = anns[index]['caption']
                example = Example.fromdict({'image': os.path.join('', id_path), 'text': caption})
                if split == 'train': train_samples.append(example)
                elif split == 'val': val_samples.append(example)
                    
        return train_samples, val_samples

# Arguments, dataloader

In [5]:
# arguments
device = torch.device('cuda')
parser = argparse.ArgumentParser(description='Incremental domain adaptation for surgical report generation')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--workers', type=int, default=0)

# caption
parser.add_argument('--exp_name', type=str, default='m2_transformer')
parser.add_argument('--m', type=int, default=40)   
parser.add_argument('--cp_cbs', type=str, default='True')
parser.add_argument('--cp_cbs_filter', default='LOG', type=str) # Potential choice: 'gau' and 'LOG'
parser.add_argument('--cp_kernel_sizex', default=3, type=int)
parser.add_argument('--cp_kernel_sizey', default=1, type=int)
parser.add_argument('--cp_decay_epoch', default=2, type=int) 
parser.add_argument('--cp_std_factor', default=0.9, type=float)

# graph
parser.add_argument('--gsu_cbs',        type=bool, default=True)
parser.add_argument('--gsu_feat', type=str,  default='resnet18_09_SC_CBS')
parser.add_argument('--gsu_w2v_loc', type=str,  default='datasets/surgicalscene_word2vec.hdf5')

# feature_extractor
parser.add_argument('--fe_use_cbs',            type=bool,      default=True,        help='use CBS')
parser.add_argument('--fe_std',                type=float,     default=1.0,         help='')
parser.add_argument('--fe_std_factor',         type=float,     default=0.9,         help='')
parser.add_argument('--fe_cbs_epoch',          type=int,       default=5,           help='')
parser.add_argument('--fe_kernel_size',        type=int,       default=3,           help='')
parser.add_argument('--fe_fil1',               type=str,       default='LOG',       help='gau, LOG')
parser.add_argument('--fe_fil2',               type=str,       default='gau',       help='gau, LOG')
parser.add_argument('--fe_fil3',               type=str,       default='gau',       help='gau, LOG')
parser.add_argument('--fe_num_classes',        type=int,       default=11,           help='11')
parser.add_argument('--fe_use_SC',             type=bool,      default=True,       help='use SuperCon')

# checkpoints and file dirs
print('Training check for DA_ECBS_ResNet18_09_SC_ECBS')

parser.add_argument('--gsu_img_dir', type=str,  default='left_frames')
parser.add_argument('--gsu_file_dir', type=str,  default='datasets/instruments18/')

parser.add_argument('--cp_features_path', type=str, default='datasets/instruments18/') 
parser.add_argument('--cp_annotation_folder', type=str, default='datasets/annotations_new/annotations_SD_inc')

# checkpoints

args = parser.parse_args(args=[])
print(args)

# graph scene understanding constants
gsu_const = {}
gsu_const['file_dir'] = args.gsu_file_dir
gsu_const['img_dir'] = args.gsu_img_dir
gsu_const['dataconst'] = SurgicalSceneConstants()
gsu_const['feature_extractor'] = args.gsu_feat
gsu_const['w2v_loc'] =args.gsu_w2v_loc


# Pipeline for image regions and text
#image_field = ImageDetectionsField(detections_path=args.cp_features_path, max_detections=6, load_in_tmp=False)  
image_field = None
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', remove_punctuation=True, nopoints=False)

# Create the dataset 
dataset = COCO(image_field, text_field, gsu_const, args.cp_features_path, args.cp_annotation_folder, args.cp_annotation_folder)
train_dataset, val_dataset = dataset.splits   
print('train:', len(train_dataset))
print('val:', len(val_dataset))
    
# caption data
if not os.path.isfile('datasets/vocab_%s.pkl' % args.exp_name):
    print("Building vocabulary")
    text_field.build_vocab(train_dataset, val_dataset, min_freq=2)  
    pickle.dump(text_field.vocab, open('datasets/vocab_%s.pkl' % args.exp_name, 'wb'))
else:
    text_field.vocab = pickle.load(open('datasets/vocab_%s.pkl' % args.exp_name, 'rb'))

print('vocabulary size is:', len(text_field.vocab))
print(text_field.vocab.stoi)

# dataset
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()})
dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size) # for caption with word GT class number

Training check for DA_ECBS_ResNet18_09_SC_ECBS
Namespace(batch_size=32, cp_annotation_folder='datasets/annotations_new/annotations_SD_inc', cp_cbs='True', cp_cbs_filter='LOG', cp_decay_epoch=2, cp_features_path='datasets/instruments18/', cp_kernel_sizex=3, cp_kernel_sizey=1, cp_std_factor=0.9, exp_name='m2_transformer', fe_cbs_epoch=5, fe_fil1='LOG', fe_fil2='gau', fe_fil3='gau', fe_kernel_size=3, fe_num_classes=11, fe_std=1.0, fe_std_factor=0.9, fe_use_SC=True, fe_use_cbs=True, gsu_cbs=True, gsu_feat='resnet18_09_SC_CBS', gsu_file_dir='datasets/instruments18/', gsu_img_dir='left_frames', gsu_w2v_loc='datasets/surgicalscene_word2vec.hdf5', m=40, workers=0)
no pre-extracted image featured
no pre-extracted image featured
no pre-extracted image featured
train: 1560
val: 447
vocabulary size is: 41
defaultdict(<function _default_unk_index at 0x7f68d8ba08c8>, {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, 'is': 4, 'tissue': 5, 'forceps': 6, 'monopolar': 7, 'curved': 8, 'scissors': 9, 'bipo

# Feature Extractor

In [6]:
# net model
if args.fe_use_SC: feature_network = SupConResNet(args=args)
else: feature_network = ResNet18(args)

# CBS
if args.fe_use_cbs:
    if args.fe_use_SC: feature_network.encoder.get_new_kernels(0)
    else: feature_network.get_new_kernels(0)

# gpu
num_gpu = torch.cuda.device_count()
if num_gpu > 0:
    device_ids = np.arange(num_gpu).tolist()    
    if args.fe_use_SC:
        feature_network.encoder = torch.nn.DataParallel(feature_network.encoder)
        feature_network = feature_network.cuda()
    else:
        feature_network = nn.DataParallel(feature_network, device_ids=device_ids).cuda()

# Caption model

In [7]:
if args.cp_cbs == 'True':
    from models.transformer import MemoryAugmentedEncoder_CBS
    print("MemoryAugmentedEncoder_CBS")
    encoder = MemoryAugmentedEncoder_CBS(3, 0, attention_module=ScaledDotProductAttentionMemory, attention_module_kwargs={'m': args.m})
else:
    print("MemoryAugmentedEncoder")
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory, attention_module_kwargs={'m': args.m}) 

decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
caption_model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

if args.cp_cbs == 'True':
    caption_model.encoder.get_new_kernels(0, args.cp_kernel_sizex, args.cp_kernel_sizey, args.cp_decay_epoch, args.cp_std_factor, args.cp_cbs_filter) 

MemoryAugmentedEncoder_CBS


# Graph Model

In [8]:
graph_su_model = AGRNN(bias= True, bn=False, dropout=0.3, multi_attn=False, layer=1, diff_edge=False, use_cbs = args.gsu_cbs)
if args.gsu_cbs:
    graph_su_model.grnn1.gnn.apply_h_h_edge.get_new_kernels(0)

# Load Pre-trained_weights independent weights

In [9]:
# fe_modelpath ='feature_extractor/checkpoint/incremental/inc_ResNet18_SC_CBS_0_012345678.pkl'
# gsu_checkpoint = 'checkpoints/g_checkpoints/da_ecbs_resnet18_09_SC_eCBS/da_ecbs_resnet18_09_SC_eCBS/epoch_train/checkpoint_D1230_epoch.pth'
# cp_checkpoint = 'checkpoints/IDA_MICCAI2021_checkpoints/SD_base_LOG/'

# # feature network
# feature_network.load_state_dict(torch.load(args.fe_modelpath))

# # caption
# pretrained_model = torch.load(args.cp_checkpoint+('%s_best.pth' % args.exp_name))
# caption_model.load_state_dict(pretrained_model['state_dict']) 

# # graph
# pretrained_model = torch.load(args.gsu_checkpoint)
# graph_su_model.load_state_dict(pretrained_model['state_dict'])

# feature extraction layers

In [10]:
# extract the encoder layer
if args.fe_use_SC:
    feature_network = feature_network.encoder
else:
    if args.fe_use_cbs: feature_network = nn.Sequential(*list(feature_network.module.children())[:-2])
    else: feature_network = nn.Sequential(*list(feature_network.module.children())[:-1])

feature_network = feature_network.cuda()

# MTL Model (Graph Scene Understanding and Captioning)

In [11]:
class mtl_model(nn.Module):
    '''
    Multi-task model : Graph Scene Understanding and Captioning
    '''
    def __init__(self, feature_extractor, graph, caption):
        super(mtl_model, self).__init__()
        self.feature_extractor = feature_extractor
        self.graph_su = graph
        self.caption = caption
        self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    
    def forward(self, img_dir, det_boxes_all, caps_gt, node_num, features, spatial_feat, word2vec, roi_labels, val = False, text_field = None):               
        
        gsu_node_feat = None
        cp_node_feat = None
        for index, img_loc in  enumerate(img_dir):
            #print(img_loc)
            _img = Image.open(img_loc).convert('RGB')
            _img = np.array(_img)
            
            img_stack = None
            for idx, bndbox in enumerate(det_boxes_all[index]):        
                roi = np.array(bndbox).astype(int)
                roi_image = _img[roi[1]:roi[3] + 1, roi[0]:roi[2] + 1, :]
                roi_image = self.transform(cv2.resize(roi_image, (224, 224), interpolation=cv2.INTER_LINEAR))
                roi_image = torch.autograd.Variable(roi_image.unsqueeze(0))
                # stack nodes images per image
                if img_stack is None: img_stack = roi_image
                else: img_stack = torch.cat((img_stack, roi_image))
            
            # send the stack to feature extractor
            img_stack = img_stack.cuda()
            feature = feature_network(img_stack)
            feature = feature.view(feature.size(0), -1)        
            
            if gsu_node_feat == None: gsu_node_feat = feature
            else: gsu_node_feat = torch.cat((gsu_node_feat,feature))
            
            feature = torch.unsqueeze(torch.cat((feature,torch.zeros((6-len(feature)),512).cuda())),0)
            if cp_node_feat == None: cp_node_feat = feature
            else: cp_node_feat = torch.cat((cp_node_feat,feature))
                
        #print('feat_size',gsu_node_feat.shape)
        #print('cp_feat', cp_node_feat.shape)
    
        if val == True:
            caption_output, _ = self.caption.beam_search(cp_node_feat, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
        else:
            caption_output = self.caption(cp_node_feat, caps_gt)
        interaction = self.graph_su(node_num, gsu_node_feat, spatial_feat, word2vec, roi_labels)
        #print(interaction)
        #print(interaction.shape)
        return interaction, caption_output

# Combined_model

In [12]:
model = mtl_model(feature_network, graph_su_model, caption_model)
model = model.to(device)

# Evaluation

In [13]:
import itertools


def eval_mtl(epoch, model, dataloader, text_field):
    
    
    #model.eval()
    gen = {}
    gts = {}

    # graph
    # criterion and scheduler
    g_criterion = nn.MultiLabelSoftMarginLoss()                   
    g_edge_count = 0
    g_total_acc = 0.0
    g_total_loss = 0.0
    g_logits_list = []
    g_labels_list = []
    
    for it, data in tqdm(enumerate(iter(dataloader))):
            
        graph_data = data['gsu']
        cp_data = data['cp']
            
        # graph
        img_name = graph_data['img_name']
        img_loc = graph_data['img_loc']
        node_num = graph_data['node_num']
        roi_labels = graph_data['roi_labels']
        det_boxes = graph_data['det_boxes']
        edge_labels = graph_data['edge_labels']
        edge_num = graph_data['edge_num']
        features = graph_data['features']
        spatial_feat = graph_data['spatial_feat']
        word2vec = graph_data['word2vec']
        features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)         
        
        # caption
        _, caps_gt = cp_data
            
        with torch.no_grad():
    
            g_output, caption_out = model(img_loc, det_boxes, caps_gt, node_num, features, spatial_feat, word2vec, roi_labels, val = True, text_field = text_field)
        
            g_logits_list.append(g_output)
            g_labels_list.append(edge_labels)
            # loss and accuracy
            g_loss = g_criterion(g_output, edge_labels.float())
            g_acc = np.sum(np.equal(np.argmax(g_output.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))
            
        # accumulate loss and accuracy of the batch
        g_total_loss += g_loss.item() * edge_labels.shape[0]
        g_total_acc  += g_acc
        g_edge_count += edge_labels.shape[0]
        
        caps_gen = text_field.decode(caption_out, join_words=False)
        
        for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
            gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
            gen['%d_%d' % (it, i)] = [gen_i, ]    
            gts['%d_%d' % (it, i)] = [gts_i,]
        
    #graph evaluation
    g_total_acc = g_total_acc / g_edge_count
    g_total_loss = g_total_loss / len(dataloader)

    g_logits_all = torch.cat(g_logits_list).cuda()
    g_labels_all = torch.cat(g_labels_list).cuda()
    g_logits_all = F.softmax(g_logits_all, dim=1)
    g_map_value, g_ece, g_sce, g_tace, g_brier, g_uce = calibration_metrics(g_logits_all, g_labels_all, 'test')
    
    # caption evaluation
    #if not os.path.exists('results/c_results/predict_caption'):
    #    os.makedirs('results/c_results/predict_caption')
    #json.dump(gen, open('results/c_results/predict_caption/predict_caption_val.json', 'w'))

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)

    scores, _ = evaluation.compute_scores(gts, gen)
    print('Graph : {acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f}' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce.item()) )
    print(print("Caption Scores :", scores))
    return #scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce.item()

# Load pretrained MTL model

In [17]:
mtl_checkpoint = 'checkpoints/mtl_train/UDA_Graph/checkpoint_'
    
for i in range(1,101):
    # MTL
    pretrained_model = torch.load(mtl_checkpoint+str(i)+'_epoch.pth')
    model.load_state_dict(pretrained_model['state_dict'])
    print('epoch:',i)
    eval_mtl(i, model, dict_dataloader_val, text_field)

0it [00:00, ?it/s]

epoch: 1


14it [00:54,  3.91s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.476314 map: 0.167071 loss: 28.179157, ece:0.137035, sce:0.069644, tace:0.067284, brier:0.774015, uce:0.078060}
Caption Scores : {'BLEU': array([0.5122, 0.4437, 0.3825, 0.3289]), 'METEOR': 0.3125, 'ROUGE': 0.4989, 'CIDEr': 1.9248}
None
epoch: 2


14it [00:59,  4.24s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.486649 map: 0.156777 loss: 27.757163, ece:0.152564, sce:0.068916, tace:0.065765, brier:0.769532, uce:0.094871}
Caption Scores : {'BLEU': array([0.5138, 0.4461, 0.3848, 0.3306]), 'METEOR': 0.3199, 'ROUGE': 0.5076, 'CIDEr': 1.9691}
None
epoch: 3


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.486649 map: 0.175080 loss: 26.501456, ece:0.146385, sce:0.065478, tace:0.064190, brier:0.763749, uce:0.095934}
Caption Scores : {'BLEU': array([0.5204, 0.4525, 0.3928, 0.3414]), 'METEOR': 0.3197, 'ROUGE': 0.5115, 'CIDEr': 2.0484}
None
epoch: 4


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.472007 map: 0.155323 loss: 27.610705, ece:0.136772, sce:0.067714, tace:0.064562, brier:0.772718, uce:0.074595}
Caption Scores : {'BLEU': array([0.5246, 0.4582, 0.399 , 0.3455]), 'METEOR': 0.3219, 'ROUGE': 0.5219, 'CIDEr': 2.1398}
None
epoch: 5


14it [00:59,  4.27s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.470284 map: 0.161713 loss: 27.404731, ece:0.133776, sce:0.067251, tace:0.064868, brier:0.770434, uce:0.068247}
Caption Scores : {'BLEU': array([0.5164, 0.4507, 0.3943, 0.3446]), 'METEOR': 0.3168, 'ROUGE': 0.5114, 'CIDEr': 2.0096}
None
epoch: 6


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.478898 map: 0.154205 loss: 27.654642, ece:0.142235, sce:0.068359, tace:0.066402, brier:0.777649, uce:0.102241}
Caption Scores : {'BLEU': array([0.5231, 0.4581, 0.4013, 0.35  ]), 'METEOR': 0.321, 'ROUGE': 0.5214, 'CIDEr': 2.1532}
None
epoch: 7


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.471146 map: 0.168287 loss: 28.281314, ece:0.133516, sce:0.067352, tace:0.064555, brier:0.782789, uce:0.078967}
Caption Scores : {'BLEU': array([0.5149, 0.4491, 0.3912, 0.3396]), 'METEOR': 0.3133, 'ROUGE': 0.5089, 'CIDEr': 2.0381}
None
epoch: 8


14it [00:59,  4.27s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.462532 map: 0.146341 loss: 28.212362, ece:0.143192, sce:0.069602, tace:0.066611, brier:0.790933, uce:0.099511}
Caption Scores : {'BLEU': array([0.5251, 0.4609, 0.402 , 0.3477]), 'METEOR': 0.3234, 'ROUGE': 0.5193, 'CIDEr': 2.0083}
None
epoch: 9


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.472868 map: 0.161199 loss: 27.583495, ece:0.143687, sce:0.069365, tace:0.066874, brier:0.780167, uce:0.078997}
Caption Scores : {'BLEU': array([0.5206, 0.4534, 0.3979, 0.3498]), 'METEOR': 0.317, 'ROUGE': 0.5163, 'CIDEr': 2.0793}
None
epoch: 10


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.475452 map: 0.151724 loss: 28.013663, ece:0.142745, sce:0.070542, tace:0.067403, brier:0.778992, uce:0.097119}
Caption Scores : {'BLEU': array([0.5027, 0.4315, 0.3674, 0.311 ]), 'METEOR': 0.3103, 'ROUGE': 0.4933, 'CIDEr': 1.7115}
None
epoch: 11


14it [00:59,  4.24s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.466839 map: 0.156411 loss: 27.652854, ece:0.138508, sce:0.068178, tace:0.065408, brier:0.777167, uce:0.091012}
Caption Scores : {'BLEU': array([0.5204, 0.452 , 0.3906, 0.3377]), 'METEOR': 0.3187, 'ROUGE': 0.5075, 'CIDEr': 1.9962}
None
epoch: 12


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.491817 map: 0.166447 loss: 27.279981, ece:0.151121, sce:0.067494, tace:0.065772, brier:0.765175, uce:0.096554}
Caption Scores : {'BLEU': array([0.5121, 0.447 , 0.3892, 0.3381]), 'METEOR': 0.3143, 'ROUGE': 0.5039, 'CIDEr': 2.018}
None
epoch: 13


14it [00:59,  4.24s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.470284 map: 0.154466 loss: 28.512246, ece:0.148659, sce:0.070101, tace:0.067150, brier:0.786733, uce:0.093417}
Caption Scores : {'BLEU': array([0.5174, 0.4477, 0.3861, 0.3313]), 'METEOR': 0.3137, 'ROUGE': 0.5067, 'CIDEr': 1.9971}
None
epoch: 14


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.476314 map: 0.190765 loss: 27.256116, ece:0.135754, sce:0.068289, tace:0.065389, brier:0.768848, uce:0.081579}
Caption Scores : {'BLEU': array([0.506 , 0.4404, 0.3823, 0.3316]), 'METEOR': 0.3122, 'ROUGE': 0.4995, 'CIDEr': 1.8955}
None
epoch: 15


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.475452 map: 0.159101 loss: 27.588386, ece:0.141066, sce:0.066851, tace:0.064583, brier:0.776563, uce:0.076438}
Caption Scores : {'BLEU': array([0.5195, 0.4526, 0.3924, 0.3397]), 'METEOR': 0.3152, 'ROUGE': 0.5069, 'CIDEr': 2.0496}
None
epoch: 16


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.452196 map: 0.155264 loss: 27.601474, ece:0.116856, sce:0.068240, tace:0.065564, brier:0.779863, uce:0.063491}
Caption Scores : {'BLEU': array([0.5073, 0.4414, 0.3817, 0.3287]), 'METEOR': 0.315, 'ROUGE': 0.506, 'CIDEr': 1.9614}
None
epoch: 17


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.475452 map: 0.163259 loss: 27.529664, ece:0.145052, sce:0.067153, tace:0.065420, brier:0.776810, uce:0.089243}
Caption Scores : {'BLEU': array([0.5137, 0.4474, 0.3866, 0.3329]), 'METEOR': 0.3205, 'ROUGE': 0.5083, 'CIDEr': 1.8656}
None
epoch: 18


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.473730 map: 0.156907 loss: 28.231769, ece:0.142588, sce:0.070051, tace:0.066853, brier:0.787002, uce:0.082210}
Caption Scores : {'BLEU': array([0.5142, 0.4458, 0.3859, 0.3326]), 'METEOR': 0.3152, 'ROUGE': 0.5065, 'CIDEr': 2.0385}
None
epoch: 19


14it [00:59,  4.28s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.458226 map: 0.148742 loss: 27.731014, ece:0.140376, sce:0.068049, tace:0.065674, brier:0.788592, uce:0.085365}
Caption Scores : {'BLEU': array([0.502 , 0.435 , 0.3734, 0.3184]), 'METEOR': 0.3094, 'ROUGE': 0.4943, 'CIDEr': 1.7966}
None
epoch: 20


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.462532 map: 0.152076 loss: 27.271131, ece:0.121447, sce:0.067797, tace:0.065598, brier:0.785781, uce:0.080130}
Caption Scores : {'BLEU': array([0.5044, 0.4396, 0.3809, 0.3282]), 'METEOR': 0.3122, 'ROUGE': 0.5004, 'CIDEr': 1.9385}
None
epoch: 21


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.465978 map: 0.160414 loss: 27.100308, ece:0.130701, sce:0.068751, tace:0.065981, brier:0.770628, uce:0.072993}
Caption Scores : {'BLEU': array([0.507 , 0.4394, 0.3784, 0.326 ]), 'METEOR': 0.3145, 'ROUGE': 0.4979, 'CIDEr': 1.8514}
None
epoch: 22


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.456503 map: 0.161552 loss: 27.921888, ece:0.117498, sce:0.069272, tace:0.065695, brier:0.786230, uce:0.061744}
Caption Scores : {'BLEU': array([0.5157, 0.4468, 0.386 , 0.3338]), 'METEOR': 0.3117, 'ROUGE': 0.5069, 'CIDEr': 2.0142}
None
epoch: 23


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.468562 map: 0.154403 loss: 28.079416, ece:0.144365, sce:0.068032, tace:0.065245, brier:0.785676, uce:0.088216}
Caption Scores : {'BLEU': array([0.5121, 0.4445, 0.3853, 0.3325]), 'METEOR': 0.3092, 'ROUGE': 0.4967, 'CIDEr': 1.9822}
None
epoch: 24


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.463394 map: 0.158080 loss: 27.417857, ece:0.148228, sce:0.069805, tace:0.066549, brier:0.783306, uce:0.086567}
Caption Scores : {'BLEU': array([0.5213, 0.4515, 0.3914, 0.3389]), 'METEOR': 0.3187, 'ROUGE': 0.516, 'CIDEr': 2.0527}
None
epoch: 25


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.467700 map: 0.163787 loss: 27.825143, ece:0.139112, sce:0.068592, tace:0.066253, brier:0.777871, uce:0.083195}
Caption Scores : {'BLEU': array([0.5239, 0.4577, 0.3993, 0.3457]), 'METEOR': 0.3206, 'ROUGE': 0.5207, 'CIDEr': 2.1045}
None
epoch: 26


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.478898 map: 0.165075 loss: 27.858911, ece:0.143017, sce:0.068454, tace:0.066404, brier:0.775295, uce:0.082953}
Caption Scores : {'BLEU': array([0.5144, 0.4451, 0.3844, 0.3295]), 'METEOR': 0.3164, 'ROUGE': 0.5094, 'CIDEr': 1.9331}
None
epoch: 27


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.489233 map: 0.172448 loss: 27.565765, ece:0.150574, sce:0.067835, tace:0.065705, brier:0.763062, uce:0.090956}
Caption Scores : {'BLEU': array([0.5209, 0.4555, 0.3965, 0.3454]), 'METEOR': 0.3231, 'ROUGE': 0.517, 'CIDEr': 2.1206}
None
epoch: 28


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.469423 map: 0.167733 loss: 27.774573, ece:0.127136, sce:0.068645, tace:0.066072, brier:0.776900, uce:0.084814}
Caption Scores : {'BLEU': array([0.4966, 0.4283, 0.3659, 0.3111]), 'METEOR': 0.3017, 'ROUGE': 0.4842, 'CIDEr': 1.8447}
None
epoch: 29


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.465978 map: 0.151997 loss: 27.821686, ece:0.141203, sce:0.068608, tace:0.067125, brier:0.786397, uce:0.100181}
Caption Scores : {'BLEU': array([0.5043, 0.4351, 0.3727, 0.3164]), 'METEOR': 0.3048, 'ROUGE': 0.4889, 'CIDEr': 1.8536}
None
epoch: 30


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.478898 map: 0.154363 loss: 27.904565, ece:0.145873, sce:0.069146, tace:0.065938, brier:0.784287, uce:0.085041}
Caption Scores : {'BLEU': array([0.5092, 0.4429, 0.3848, 0.3323]), 'METEOR': 0.311, 'ROUGE': 0.5043, 'CIDEr': 1.9587}
None
epoch: 31


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.472868 map: 0.158060 loss: 27.399487, ece:0.149619, sce:0.067596, tace:0.064839, brier:0.779904, uce:0.084856}
Caption Scores : {'BLEU': array([0.503 , 0.437 , 0.3763, 0.3215]), 'METEOR': 0.3126, 'ROUGE': 0.4979, 'CIDEr': 1.8754}
None
epoch: 32


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.471146 map: 0.159586 loss: 28.325932, ece:0.139327, sce:0.068602, tace:0.065977, brier:0.779724, uce:0.087802}
Caption Scores : {'BLEU': array([0.5228, 0.46  , 0.4053, 0.3569]), 'METEOR': 0.322, 'ROUGE': 0.5215, 'CIDEr': 2.1782}
None
epoch: 33


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.469423 map: 0.159959 loss: 27.832656, ece:0.127030, sce:0.068205, tace:0.065411, brier:0.776953, uce:0.068805}
Caption Scores : {'BLEU': array([0.5235, 0.4532, 0.3916, 0.3376]), 'METEOR': 0.3204, 'ROUGE': 0.5176, 'CIDEr': 2.0291}
None
epoch: 34


14it [00:59,  4.28s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.469423 map: 0.157360 loss: 28.403966, ece:0.129780, sce:0.067985, tace:0.065535, brier:0.780942, uce:0.092050}
Caption Scores : {'BLEU': array([0.5171, 0.4538, 0.397 , 0.3456]), 'METEOR': 0.3206, 'ROUGE': 0.5117, 'CIDEr': 2.1585}
None
epoch: 35


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.459948 map: 0.151637 loss: 28.267342, ece:0.126404, sce:0.067942, tace:0.065214, brier:0.786757, uce:0.069952}
Caption Scores : {'BLEU': array([0.5081, 0.4409, 0.3786, 0.3232]), 'METEOR': 0.3177, 'ROUGE': 0.5045, 'CIDEr': 1.9324}
None
epoch: 36


14it [00:59,  4.27s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.471146 map: 0.159477 loss: 28.211149, ece:0.134148, sce:0.069506, tace:0.067243, brier:0.777640, uce:0.072440}
Caption Scores : {'BLEU': array([0.5226, 0.4538, 0.3919, 0.3374]), 'METEOR': 0.32, 'ROUGE': 0.5157, 'CIDEr': 2.0981}
None
epoch: 37


14it [00:59,  4.27s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.496124 map: 0.157695 loss: 27.847798, ece:0.166336, sce:0.067771, tace:0.064106, brier:0.772193, uce:0.102657}
Caption Scores : {'BLEU': array([0.5126, 0.4466, 0.387 , 0.3349]), 'METEOR': 0.3157, 'ROUGE': 0.5056, 'CIDEr': 1.9631}
None
epoch: 38


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.482343 map: 0.155999 loss: 28.121025, ece:0.141898, sce:0.067984, tace:0.065608, brier:0.780430, uce:0.105720}
Caption Scores : {'BLEU': array([0.5093, 0.4431, 0.3849, 0.3325]), 'METEOR': 0.3094, 'ROUGE': 0.5032, 'CIDEr': 1.9892}
None
epoch: 39


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.471146 map: 0.155736 loss: 27.733265, ece:0.134225, sce:0.067081, tace:0.064141, brier:0.779660, uce:0.089893}
Caption Scores : {'BLEU': array([0.5184, 0.4539, 0.3964, 0.3441]), 'METEOR': 0.3186, 'ROUGE': 0.5135, 'CIDEr': 2.0823}
None
epoch: 40


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.477175 map: 0.185033 loss: 27.558920, ece:0.129589, sce:0.065555, tace:0.063226, brier:0.756354, uce:0.078920}
Caption Scores : {'BLEU': array([0.5109, 0.445 , 0.3872, 0.3374]), 'METEOR': 0.3125, 'ROUGE': 0.5036, 'CIDEr': 1.9937}
None
epoch: 41


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.473730 map: 0.159502 loss: 27.765431, ece:0.125425, sce:0.066035, tace:0.063809, brier:0.768042, uce:0.066754}
Caption Scores : {'BLEU': array([0.5159, 0.4473, 0.384 , 0.3276]), 'METEOR': 0.3175, 'ROUGE': 0.5061, 'CIDEr': 1.9976}
None
epoch: 42


14it [00:59,  4.27s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.481481 map: 0.163676 loss: 27.018346, ece:0.141194, sce:0.065378, tace:0.064190, brier:0.762452, uce:0.087307}
Caption Scores : {'BLEU': array([0.5132, 0.4476, 0.3889, 0.3367]), 'METEOR': 0.3169, 'ROUGE': 0.5089, 'CIDEr': 2.0382}
None
epoch: 43


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.483204 map: 0.151908 loss: 27.663118, ece:0.151118, sce:0.067672, tace:0.065094, brier:0.771428, uce:0.105148}
Caption Scores : {'BLEU': array([0.5273, 0.4652, 0.4076, 0.3569]), 'METEOR': 0.3335, 'ROUGE': 0.5237, 'CIDEr': 2.2141}
None
epoch: 44


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.469423 map: 0.158975 loss: 28.134241, ece:0.145878, sce:0.070368, tace:0.067750, brier:0.788067, uce:0.091256}
Caption Scores : {'BLEU': array([0.5177, 0.4538, 0.3975, 0.3486]), 'METEOR': 0.3226, 'ROUGE': 0.5177, 'CIDEr': 2.1107}
None
epoch: 45


14it [00:59,  4.26s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.459087 map: 0.154592 loss: 27.876214, ece:0.131478, sce:0.068101, tace:0.066099, brier:0.786158, uce:0.071481}
Caption Scores : {'BLEU': array([0.4941, 0.4272, 0.3649, 0.3086]), 'METEOR': 0.3065, 'ROUGE': 0.4835, 'CIDEr': 1.8695}
None
epoch: 46


14it [00:59,  4.27s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.461671 map: 0.155010 loss: 27.869080, ece:0.137422, sce:0.070225, tace:0.066929, brier:0.784054, uce:0.083245}
Caption Scores : {'BLEU': array([0.5191, 0.4559, 0.3996, 0.3485]), 'METEOR': 0.3196, 'ROUGE': 0.507, 'CIDEr': 2.0924}
None
epoch: 47


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.485788 map: 0.162728 loss: 27.963636, ece:0.164957, sce:0.068862, tace:0.066281, brier:0.780103, uce:0.107844}
Caption Scores : {'BLEU': array([0.4982, 0.4332, 0.3753, 0.3236]), 'METEOR': 0.3096, 'ROUGE': 0.4902, 'CIDEr': 1.8558}
None
epoch: 48


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.461671 map: 0.175541 loss: 27.639905, ece:0.127912, sce:0.067416, tace:0.065905, brier:0.777575, uce:0.077298}
Caption Scores : {'BLEU': array([0.513 , 0.4471, 0.3873, 0.3338]), 'METEOR': 0.3121, 'ROUGE': 0.5029, 'CIDEr': 1.9809}
None
epoch: 49


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.470284 map: 0.158603 loss: 27.268298, ece:0.127343, sce:0.065265, tace:0.062856, brier:0.773065, uce:0.062987}
Caption Scores : {'BLEU': array([0.5182, 0.4533, 0.3961, 0.3458]), 'METEOR': 0.3169, 'ROUGE': 0.5052, 'CIDEr': 2.132}
None
epoch: 50


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.489233 map: 0.166341 loss: 27.996416, ece:0.143653, sce:0.066586, tace:0.063726, brier:0.760362, uce:0.093684}
Caption Scores : {'BLEU': array([0.5205, 0.4572, 0.4003, 0.3498]), 'METEOR': 0.3215, 'ROUGE': 0.5133, 'CIDEr': 2.0858}
None
epoch: 51


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.470284 map: 0.162520 loss: 27.844714, ece:0.136403, sce:0.069170, tace:0.066066, brier:0.779631, uce:0.079813}
Caption Scores : {'BLEU': array([0.5172, 0.4529, 0.3953, 0.3439]), 'METEOR': 0.3159, 'ROUGE': 0.513, 'CIDEr': 2.1308}
None
epoch: 52


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.459087 map: 0.153322 loss: 27.697726, ece:0.123057, sce:0.066723, tace:0.064771, brier:0.783026, uce:0.079929}
Caption Scores : {'BLEU': array([0.5082, 0.4445, 0.387 , 0.3341]), 'METEOR': 0.3176, 'ROUGE': 0.5106, 'CIDEr': 2.0184}
None
epoch: 53


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.471146 map: 0.156090 loss: 27.563665, ece:0.141593, sce:0.067924, tace:0.065911, brier:0.785800, uce:0.094589}
Caption Scores : {'BLEU': array([0.5128, 0.4469, 0.386 , 0.3315]), 'METEOR': 0.3163, 'ROUGE': 0.5023, 'CIDEr': 2.0492}
None
epoch: 54


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.481481 map: 0.161987 loss: 27.323276, ece:0.154918, sce:0.069441, tace:0.067123, brier:0.778433, uce:0.094178}
Caption Scores : {'BLEU': array([0.5211, 0.453 , 0.393 , 0.3392]), 'METEOR': 0.3126, 'ROUGE': 0.5115, 'CIDEr': 2.0094}
None
epoch: 55


14it [01:00,  4.33s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.477175 map: 0.160058 loss: 28.361354, ece:0.134537, sce:0.068351, tace:0.065454, brier:0.776106, uce:0.079050}
Caption Scores : {'BLEU': array([0.5155, 0.4482, 0.3885, 0.3358]), 'METEOR': 0.3084, 'ROUGE': 0.5049, 'CIDEr': 2.0288}
None
epoch: 56


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.495263 map: 0.156385 loss: 28.014997, ece:0.155513, sce:0.067208, tace:0.064177, brier:0.768530, uce:0.100979}
Caption Scores : {'BLEU': array([0.5221, 0.4573, 0.3979, 0.3444]), 'METEOR': 0.3206, 'ROUGE': 0.5131, 'CIDEr': 2.0169}
None
epoch: 57


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.474591 map: 0.151325 loss: 28.176413, ece:0.163698, sce:0.069989, tace:0.067121, brier:0.792102, uce:0.108638}
Caption Scores : {'BLEU': array([0.5221, 0.4558, 0.3967, 0.3436]), 'METEOR': 0.3148, 'ROUGE': 0.5086, 'CIDEr': 2.0649}
None
epoch: 58


14it [01:00,  4.32s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.489233 map: 0.170092 loss: 27.936874, ece:0.171471, sce:0.067677, tace:0.063895, brier:0.779477, uce:0.108870}
Caption Scores : {'BLEU': array([0.503 , 0.4365, 0.3746, 0.319 ]), 'METEOR': 0.3091, 'ROUGE': 0.4905, 'CIDEr': 1.9325}
None
epoch: 59


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.488372 map: 0.169853 loss: 27.338727, ece:0.155803, sce:0.067563, tace:0.065148, brier:0.774240, uce:0.099865}
Caption Scores : {'BLEU': array([0.5058, 0.4388, 0.3781, 0.3242]), 'METEOR': 0.3125, 'ROUGE': 0.4964, 'CIDEr': 1.9763}
None
epoch: 60


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.487511 map: 0.158295 loss: 27.899434, ece:0.151296, sce:0.067423, tace:0.063979, brier:0.769494, uce:0.092537}
Caption Scores : {'BLEU': array([0.5081, 0.4407, 0.3787, 0.3234]), 'METEOR': 0.3141, 'ROUGE': 0.499, 'CIDEr': 1.9467}
None
epoch: 61


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.477175 map: 0.153995 loss: 28.097582, ece:0.139884, sce:0.067413, tace:0.065432, brier:0.777727, uce:0.096670}
Caption Scores : {'BLEU': array([0.5124, 0.4455, 0.3864, 0.3349]), 'METEOR': 0.3082, 'ROUGE': 0.4972, 'CIDEr': 2.0002}
None
epoch: 62


14it [01:00,  4.32s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.472007 map: 0.147833 loss: 28.484780, ece:0.150564, sce:0.067518, tace:0.064649, brier:0.789408, uce:0.084972}
Caption Scores : {'BLEU': array([0.5021, 0.4361, 0.3744, 0.3188]), 'METEOR': 0.3133, 'ROUGE': 0.4982, 'CIDEr': 1.9799}
None
epoch: 63


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.478898 map: 0.166362 loss: 28.490440, ece:0.155638, sce:0.068196, tace:0.064820, brier:0.777653, uce:0.088261}
Caption Scores : {'BLEU': array([0.5118, 0.4465, 0.3865, 0.3329]), 'METEOR': 0.3142, 'ROUGE': 0.5009, 'CIDEr': 1.9432}
None
epoch: 64


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.474591 map: 0.150165 loss: 28.178695, ece:0.145962, sce:0.068837, tace:0.065052, brier:0.788675, uce:0.110311}
Caption Scores : {'BLEU': array([0.521 , 0.4565, 0.3991, 0.3492]), 'METEOR': 0.3218, 'ROUGE': 0.5194, 'CIDEr': 2.1829}
None
epoch: 65


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.475452 map: 0.168358 loss: 27.785778, ece:0.131210, sce:0.067812, tace:0.065200, brier:0.780784, uce:0.090936}
Caption Scores : {'BLEU': array([0.5118, 0.4454, 0.3848, 0.3306]), 'METEOR': 0.3102, 'ROUGE': 0.4994, 'CIDEr': 1.967}
None
epoch: 66


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.482343 map: 0.154493 loss: 27.826824, ece:0.138064, sce:0.065585, tace:0.063046, brier:0.772112, uce:0.096503}
Caption Scores : {'BLEU': array([0.5138, 0.4447, 0.3849, 0.3327]), 'METEOR': 0.3142, 'ROUGE': 0.5065, 'CIDEr': 2.0797}
None
epoch: 67


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.463394 map: 0.157089 loss: 27.637297, ece:0.111330, sce:0.063544, tace:0.062200, brier:0.766050, uce:0.057147}
Caption Scores : {'BLEU': array([0.5031, 0.4381, 0.3768, 0.3211]), 'METEOR': 0.311, 'ROUGE': 0.4942, 'CIDEr': 1.9468}
None
epoch: 68


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.475452 map: 0.158506 loss: 28.136286, ece:0.127625, sce:0.066466, tace:0.063386, brier:0.763274, uce:0.074516}
Caption Scores : {'BLEU': array([0.5057, 0.4395, 0.3801, 0.3272]), 'METEOR': 0.308, 'ROUGE': 0.4927, 'CIDEr': 2.0027}
None
epoch: 69


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.465116 map: 0.155868 loss: 27.868735, ece:0.132254, sce:0.068173, tace:0.065006, brier:0.778994, uce:0.069285}
Caption Scores : {'BLEU': array([0.5111, 0.4472, 0.3885, 0.3354]), 'METEOR': 0.3154, 'ROUGE': 0.5034, 'CIDEr': 2.0122}
None
epoch: 70


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.478898 map: 0.159523 loss: 28.077375, ece:0.156408, sce:0.069133, tace:0.065817, brier:0.781803, uce:0.099647}
Caption Scores : {'BLEU': array([0.5033, 0.4397, 0.3807, 0.3272]), 'METEOR': 0.3138, 'ROUGE': 0.4968, 'CIDEr': 1.9627}
None
epoch: 71


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.458226 map: 0.153567 loss: 28.805834, ece:0.126042, sce:0.068318, tace:0.064661, brier:0.785719, uce:0.074273}
Caption Scores : {'BLEU': array([0.507 , 0.4425, 0.384 , 0.3318]), 'METEOR': 0.3091, 'ROUGE': 0.4962, 'CIDEr': 2.006}
None
epoch: 72


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.466839 map: 0.161121 loss: 28.510169, ece:0.129507, sce:0.067461, tace:0.064490, brier:0.777572, uce:0.069155}
Caption Scores : {'BLEU': array([0.4948, 0.4294, 0.3695, 0.3155]), 'METEOR': 0.3103, 'ROUGE': 0.4932, 'CIDEr': 1.8861}
None
epoch: 73


14it [01:00,  4.31s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.474591 map: 0.155047 loss: 28.451917, ece:0.150238, sce:0.068242, tace:0.064616, brier:0.782539, uce:0.093429}
Caption Scores : {'BLEU': array([0.5082, 0.4437, 0.3844, 0.332 ]), 'METEOR': 0.3152, 'ROUGE': 0.4969, 'CIDEr': 2.023}
None
epoch: 74


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.477175 map: 0.160296 loss: 27.936406, ece:0.140626, sce:0.065677, tace:0.064117, brier:0.775043, uce:0.094794}
Caption Scores : {'BLEU': array([0.4966, 0.4319, 0.3694, 0.3116]), 'METEOR': 0.3104, 'ROUGE': 0.489, 'CIDEr': 1.8135}
None
epoch: 75


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.491817 map: 0.151867 loss: 27.988405, ece:0.155822, sce:0.067728, tace:0.065053, brier:0.776583, uce:0.102775}
Caption Scores : {'BLEU': array([0.5142, 0.4487, 0.3894, 0.3357]), 'METEOR': 0.3163, 'ROUGE': 0.5095, 'CIDEr': 2.0952}
None
epoch: 76


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.485788 map: 0.154021 loss: 27.797798, ece:0.155834, sce:0.066735, tace:0.063444, brier:0.771924, uce:0.088335}
Caption Scores : {'BLEU': array([0.4981, 0.4298, 0.3672, 0.311 ]), 'METEOR': 0.3126, 'ROUGE': 0.4858, 'CIDEr': 1.8311}
None
epoch: 77


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.470284 map: 0.162861 loss: 28.145551, ece:0.120136, sce:0.067383, tace:0.064658, brier:0.770171, uce:0.063663}
Caption Scores : {'BLEU': array([0.5007, 0.4378, 0.379 , 0.3257]), 'METEOR': 0.3139, 'ROUGE': 0.4981, 'CIDEr': 1.9862}
None
epoch: 78


14it [00:59,  4.28s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.479759 map: 0.158037 loss: 28.206415, ece:0.132474, sce:0.066679, tace:0.063631, brier:0.771283, uce:0.083468}
Caption Scores : {'BLEU': array([0.5121, 0.4458, 0.3871, 0.3344]), 'METEOR': 0.3136, 'ROUGE': 0.5029, 'CIDEr': 1.9662}
None
epoch: 79


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.484927 map: 0.157325 loss: 27.562903, ece:0.140461, sce:0.066069, tace:0.063089, brier:0.770654, uce:0.087034}
Caption Scores : {'BLEU': array([0.511, 0.446, 0.386, 0.331]), 'METEOR': 0.3163, 'ROUGE': 0.5074, 'CIDEr': 1.9955}
None
epoch: 80


14it [00:59,  4.28s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.475452 map: 0.154522 loss: 28.339488, ece:0.143809, sce:0.066618, tace:0.064491, brier:0.783127, uce:0.083567}
Caption Scores : {'BLEU': array([0.5033, 0.4404, 0.3816, 0.3279]), 'METEOR': 0.3189, 'ROUGE': 0.4978, 'CIDEr': 1.9528}
None
epoch: 81


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.496985 map: 0.157761 loss: 28.901511, ece:0.156375, sce:0.068272, tace:0.064809, brier:0.780714, uce:0.106154}
Caption Scores : {'BLEU': array([0.5021, 0.4343, 0.3735, 0.3179]), 'METEOR': 0.3036, 'ROUGE': 0.4851, 'CIDEr': 1.8532}
None
epoch: 82


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.461671 map: 0.151970 loss: 28.784708, ece:0.125442, sce:0.067355, tace:0.064184, brier:0.782482, uce:0.071811}
Caption Scores : {'BLEU': array([0.5105, 0.4436, 0.3828, 0.3278]), 'METEOR': 0.3145, 'ROUGE': 0.5006, 'CIDEr': 1.9624}
None
epoch: 83


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.472868 map: 0.191945 loss: 28.016662, ece:0.118019, sce:0.066828, tace:0.064953, brier:0.765007, uce:0.064627}
Caption Scores : {'BLEU': array([0.5063, 0.4433, 0.3836, 0.3294]), 'METEOR': 0.3161, 'ROUGE': 0.4983, 'CIDEr': 1.9628}
None
epoch: 84


14it [01:00,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.484927 map: 0.170650 loss: 28.170049, ece:0.158450, sce:0.068224, tace:0.065901, brier:0.781936, uce:0.113052}
Caption Scores : {'BLEU': array([0.5068, 0.4434, 0.3839, 0.3312]), 'METEOR': 0.3187, 'ROUGE': 0.5051, 'CIDEr': 2.02}
None
epoch: 85


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.457364 map: 0.151540 loss: 28.379122, ece:0.126135, sce:0.066900, tace:0.064400, brier:0.788025, uce:0.085710}
Caption Scores : {'BLEU': array([0.5139, 0.449 , 0.3885, 0.3338]), 'METEOR': 0.3201, 'ROUGE': 0.5083, 'CIDEr': 2.1431}
None
epoch: 86


14it [00:59,  4.29s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.459948 map: 0.157518 loss: 28.700226, ece:0.119106, sce:0.069019, tace:0.065574, brier:0.778824, uce:0.069224}
Caption Scores : {'BLEU': array([0.5119, 0.4464, 0.3888, 0.3367]), 'METEOR': 0.3105, 'ROUGE': 0.5064, 'CIDEr': 2.0978}
None
epoch: 87


14it [01:00,  4.30s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.466839 map: 0.159043 loss: 27.808630, ece:0.122892, sce:0.066923, tace:0.064663, brier:0.767688, uce:0.064948}
Caption Scores : {'BLEU': array([0.5062, 0.4376, 0.3752, 0.3183]), 'METEOR': 0.3108, 'ROUGE': 0.4966, 'CIDEr': 1.9315}
None
epoch: 88


14it [00:59,  4.28s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.475452 map: 0.153139 loss: 28.245610, ece:0.144356, sce:0.069103, tace:0.066553, brier:0.785388, uce:0.088291}
Caption Scores : {'BLEU': array([0.5009, 0.4357, 0.3736, 0.3173]), 'METEOR': 0.3102, 'ROUGE': 0.4924, 'CIDEr': 1.9764}
None
epoch: 89


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.478898 map: 0.158707 loss: 28.157845, ece:0.153627, sce:0.069240, tace:0.065947, brier:0.782651, uce:0.094725}
Caption Scores : {'BLEU': array([0.5078, 0.4384, 0.3734, 0.3156]), 'METEOR': 0.3065, 'ROUGE': 0.491, 'CIDEr': 1.9692}
None
epoch: 90


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.484927 map: 0.156078 loss: 28.347409, ece:0.139187, sce:0.066211, tace:0.063879, brier:0.775394, uce:0.078035}
Caption Scores : {'BLEU': array([0.5049, 0.439 , 0.3782, 0.3239]), 'METEOR': 0.3152, 'ROUGE': 0.5003, 'CIDEr': 1.9649}
None
epoch: 91


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.476314 map: 0.159315 loss: 28.265827, ece:0.134997, sce:0.069043, tace:0.066700, brier:0.781043, uce:0.067522}
Caption Scores : {'BLEU': array([0.5076, 0.4453, 0.3852, 0.3301]), 'METEOR': 0.324, 'ROUGE': 0.5018, 'CIDEr': 2.04}
None
epoch: 92


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.472007 map: 0.156505 loss: 27.877077, ece:0.128464, sce:0.067540, tace:0.064923, brier:0.780857, uce:0.078057}
Caption Scores : {'BLEU': array([0.5015, 0.4347, 0.3721, 0.3157]), 'METEOR': 0.3135, 'ROUGE': 0.4939, 'CIDEr': 1.8581}
None
epoch: 93


14it [00:59,  4.24s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.465116 map: 0.152796 loss: 28.248096, ece:0.138018, sce:0.067335, tace:0.064401, brier:0.785753, uce:0.076357}
Caption Scores : {'BLEU': array([0.511 , 0.4469, 0.3864, 0.3308]), 'METEOR': 0.3125, 'ROUGE': 0.5024, 'CIDEr': 2.0567}
None
epoch: 94


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.462532 map: 0.154121 loss: 28.057335, ece:0.125934, sce:0.068354, tace:0.066356, brier:0.786799, uce:0.067982}
Caption Scores : {'BLEU': array([0.496 , 0.432 , 0.3724, 0.3168]), 'METEOR': 0.3114, 'ROUGE': 0.4929, 'CIDEr': 1.9651}
None
epoch: 95


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.467700 map: 0.158923 loss: 28.100692, ece:0.128543, sce:0.065672, tace:0.063161, brier:0.776586, uce:0.068390}
Caption Scores : {'BLEU': array([0.4905, 0.4275, 0.3688, 0.316 ]), 'METEOR': 0.3097, 'ROUGE': 0.4839, 'CIDEr': 1.8464}
None
epoch: 96


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.477175 map: 0.150910 loss: 28.118444, ece:0.152182, sce:0.067190, tace:0.064514, brier:0.777853, uce:0.101698}
Caption Scores : {'BLEU': array([0.5108, 0.4454, 0.3845, 0.3297]), 'METEOR': 0.3155, 'ROUGE': 0.501, 'CIDEr': 2.0599}
None
epoch: 97


14it [00:59,  4.25s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.471146 map: 0.148322 loss: 29.549693, ece:0.140275, sce:0.069306, tace:0.066734, brier:0.795721, uce:0.097232}
Caption Scores : {'BLEU': array([0.5075, 0.44  , 0.3782, 0.3229]), 'METEOR': 0.3133, 'ROUGE': 0.4969, 'CIDEr': 1.9659}
None
epoch: 98


14it [00:59,  4.27s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.475452 map: 0.155219 loss: 28.790698, ece:0.131780, sce:0.067217, tace:0.064464, brier:0.781218, uce:0.080023}
Caption Scores : {'BLEU': array([0.506 , 0.4399, 0.3798, 0.3262]), 'METEOR': 0.3094, 'ROUGE': 0.497, 'CIDEr': 2.0055}
None
epoch: 99


14it [00:59,  4.24s/it]
0it [00:00, ?it/s]

Graph : {acc: 0.488372 map: 0.158030 loss: 28.781845, ece:0.145541, sce:0.066064, tace:0.063622, brier:0.773853, uce:0.094215}
Caption Scores : {'BLEU': array([0.5081, 0.4429, 0.3835, 0.3296]), 'METEOR': 0.3122, 'ROUGE': 0.4957, 'CIDEr': 2.088}
None
epoch: 100


14it [00:59,  4.26s/it]


Graph : {acc: 0.470284 map: 0.154727 loss: 28.153866, ece:0.127032, sce:0.066442, tace:0.064375, brier:0.776224, uce:0.072268}
Caption Scores : {'BLEU': array([0.5002, 0.4343, 0.3742, 0.3201]), 'METEOR': 0.3073, 'ROUGE': 0.4888, 'CIDEr': 1.8817}
None


In [17]:
mtl_checkpoint = 'checkpoints/mtl_train/UDA/checkpoint_'

pretrained_model = torch.load(mtl_checkpoint+str(1)+'_epoch.pth')
model.load_state_dict(pretrained_model['state_dict'])
print('epoch:',1)
eval_mtl(1, model, dict_dataloader_val, text_field)
#eval_mtl(1, model, train_dataloader, text_field)

0it [00:00, ?it/s]

epoch: 1


1it [00:00,  2.08it/s]

torch.Size([5, 13])


2it [00:00,  2.12it/s]

torch.Size([6, 13])


3it [00:01,  2.15it/s]

torch.Size([6, 13])


4it [00:01,  2.19it/s]

torch.Size([6, 13])


5it [00:02,  2.16it/s]

torch.Size([5, 13])


6it [00:02,  2.20it/s]

torch.Size([6, 13])


7it [00:03,  2.22it/s]

torch.Size([7, 13])


8it [00:03,  2.24it/s]

torch.Size([6, 13])


8it [00:03,  2.02it/s]


KeyboardInterrupt: 

# Need to work on evaluation and TS/CDA-TS

In [None]:
''' 1.0 mutlitask model base evaluation ==========================================================================='''
scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model, device, dict_dataloader_val, text_field)
print('Initial Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("Initial Caption scores :", scores)

# Evaluation Matrix : Graph (Loss, Acc, ECE), Caption (Brier, Cider)

In [6]:
def evaluate_metrics(model, device, dataloader, text_field, g_temp = 1.5, c_temp = None):
    import itertools
    
    model.caption.decoder.caption_ts = c_temp
    
    model.eval()
    gen = {}
    gts = {}

    # graph
    # criterion and scheduler
    g_criterion = nn.MultiLabelSoftMarginLoss()                   
    g_edge_count = 0
    g_total_acc = 0.0
    g_total_loss = 0.0
    g_logits_list = []
    g_labels_list = []

    #print(model.caption.beam_search)
    
    with tqdm(desc='evaluation', unit='it', total=len(dataloader)) as pbar:
        for it, (images, caps_gt, graph_data) in enumerate(iter(dataloader)):
            
            # graph
            img_name = graph_data['img_name']
            img_loc = graph_data['img_loc']
            node_num = graph_data['node_num']
            roi_labels = graph_data['roi_labels']
            det_boxes = graph_data['det_boxes']
            edge_labels = graph_data['edge_labels']
            edge_num = graph_data['edge_num']
            features = graph_data['features']
            spatial_feat = graph_data['spatial_feat']
            word2vec = graph_data['word2vec']
            features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
        
            # caption
            images = images.to(device)

            with torch.no_grad():
                # caption
                caption_out, _ = model.caption.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
                #print(caption_out)
                # graph
                g_output = model.graph_su(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
                g_output = g_output/g_temp
                g_logits_list.append(g_output)
                g_labels_list.append(edge_labels)
                # loss and accuracy
                g_loss = g_criterion(g_output, edge_labels.float())
                g_acc = np.sum(np.equal(np.argmax(g_output.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))
            
            # accumulate loss and accuracy of the batch
            g_total_loss += g_loss.item() * edge_labels.shape[0]
            g_total_acc  += g_acc
            g_edge_count += edge_labels.shape[0]
            
            caps_gen = text_field.decode(caption_out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
                gen['%d_%d' % (it, i)] = [gen_i, ]    
                gts['%d_%d' % (it, i)] = gts_i
            pbar.update()
    
    #graph loss
    g_logits_all = torch.cat(g_logits_list).cuda()
    g_labels_all = torch.cat(g_labels_list).cuda()
    g_total_acc = g_total_acc / g_edge_count
    g_total_loss = g_total_loss / len(dataloader)

    g_logits_all = F.softmax(g_logits_all, dim=1)
    g_map_value, g_ece, g_sce, g_tace, g_brier, g_uce = calibration_metrics(g_logits_all, g_labels_all, 'test')
    print('acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce.item()) )


    if not os.path.exists('results/c_results/predict_caption'):
        os.makedirs('results/c_results/predict_caption')
    json.dump(gen, open('results/c_results/predict_caption/predict_caption_val.json', 'w'))

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen)
    return scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce.item()

# Temperature Scaling : Model Object

In [8]:
class ModelWithTemperature(nn.Module):
    '''
    Temperature scaling model for model
    '''
    def __init__(self, model):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.use_ts = False
        self.graph_su_temperature = nn.Parameter(torch.ones(1) * 1.5)
        self.caption_temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, detections, captions, batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec):
        
        # if self.use_ts:
            # g_su_temp = self.graph_su_temperature.unsqueeze(1).expand(interaction.size(0), interaction.size(1))
            # self.model.caption.decoder.caption_ts = self.caption_temperature
        
        caption_output, interaction = self.model(detections, captions, batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec)
        
        # if self.use_ts:
            # interaction = interaction / g_su_temp

        return caption_output, interaction

    def graph_su_set_temperature(self, valid_loader):
        '''
        Finding optimal temperature scale for graph scene understanding task
        '''
        self.cuda()
        g_logits_list = []
        g_labels_list = []
        g_criterion = nn.MultiLabelSoftMarginLoss()
        
        with torch.no_grad():
            for it, (images, caps_gt, graph_data) in enumerate(iter(valid_loader)):
                # graph
                img_name = graph_data['img_name']
                img_loc = graph_data['img_loc']
                node_num = graph_data['node_num']
                roi_labels = graph_data['roi_labels']
                det_boxes = graph_data['det_boxes']
                edge_labels = graph_data['edge_labels']
                edge_num = graph_data['edge_num']
                features = graph_data['features']
                spatial_feat = graph_data['spatial_feat']
                word2vec = graph_data['word2vec']
                features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
                
                #caption_out = self.model.caption(images, caps_gt)
                g_output = self.model.graph_su(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
                    
                g_logits_list.append(g_output)
                g_labels_list.append(edge_labels)
                
            g_logits = torch.cat(g_logits_list).cuda()
            g_labels = torch.cat(g_labels_list).cuda()

        #init_temp = self.graph_su_temperature.clone()
        optimizer = optim.LBFGS([self.graph_su_temperature], lr=0.01, max_iter=50)

        def eval():
            g_su_temp = self.graph_su_temperature.unsqueeze(1).expand(g_logits.size(0), g_logits.size(1))
            g_logit_out = F.softmax(g_logits/g_su_temp, dim=1)
            g_loss = g_criterion(g_logit_out, g_labels.float())
            g_loss.backward()
            return g_loss
        
        optimizer.step(eval)
        return

    def caption_set_temperature(self, valid_loader):
        '''
        Finding optimal temperature scale for caption task
        '''
        self.cuda()
        c_logits_list = None
        c_labels_list = None
        
        with torch.no_grad():
            for it, (images, caps_gt) in enumerate(iter(valid_loader)):    
                images, caps_gt = images.to(device), caps_gt.to(device)
                caption_out = self.model.caption(images, caps_gt)

                if c_logits_list is not None:
                    c_logits_list = torch.cat([c_logits_list, caption_out], 1)
                    c_labels_list = torch.cat([c_labels_list, caps_gt],1)
                else:
                    c_logits_list = caption_out
                    c_labels_list = caps_gt

            c_logits = c_logits_list.cuda()
            c_labels = c_labels_list.cuda()
        
        init_temp = self.caption_temperature.clone()
        optimizer = optim.LBFGS([self.caption_temperature], lr=0.01, max_iter=50)

        def eval():
            caption_temp = self.caption_temperature.unsqueeze(1).expand(c_logits.size(1), c_logits.size(0))
            c_criterion = CELossWithLS(classes=len(text_field.vocab), smoothing=0.1, gamma=0.0, isCos=False, ignore_index=text_field.vocab.stoi['<pad>'])
            c_base = c_logits/caption_temp
            c_loss = c_criterion(c_base[:, :-1].contiguous(), c_labels[:, 1:].contiguous())
            c_loss.backward()
            return c_loss
        
        optimizer.step(eval)
        return

# Temperature Scaling : ECE Loss

In [9]:
class _ECELoss(nn.Module):
    '''
    Expected Calibration Error for Calibration
    '''
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        if labels.size(1) == 13:
            gt_labels = torch.argmax(labels, dim=1)
        else: 
            logits = logits.squeeze()
            gt_labels = labels.squeeze()

        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(gt_labels)
        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        return ece

# Temperature Scaling : Evaluation

In [10]:
def ts_evaluation(model_ts, g_c_val_dataloader, c_val_dataloader, g_temp=1.5, c_temp=1):
    '''
    '''
    model.eval()

    g_acc = 0.0
    c_acc = 0.0
    g_loss = 0.0
    c_loss = 0.0
    g_logits = []
    g_labels = []
    c_logits = []
    c_labels = []
    
    g_logits_list = []
    g_labels_list = []
    c_logits_list = None
    c_labels_list = None

    with torch.no_grad():
        for it, (images, caps_gt, graph_data) in enumerate(iter(g_c_val_dataloader)):
            # graph
            img_name = graph_data['img_name']
            img_loc = graph_data['img_loc']
            node_num = graph_data['node_num']
            roi_labels = graph_data['roi_labels']
            det_boxes = graph_data['det_boxes']
            edge_labels = graph_data['edge_labels']
            edge_num = graph_data['edge_num']
            features = graph_data['features']
            spatial_feat = graph_data['spatial_feat']
            word2vec = graph_data['word2vec']
            features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
        
            g_output = model_ts.model.graph_su(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
            
            g_output = g_output / g_temp
            g_logits_list.append(g_output)
            g_labels_list.append(edge_labels)
                
    g_logits = torch.cat(g_logits_list).cuda()
    g_labels = torch.cat(g_labels_list).cuda()

    g_logits = F.softmax(g_logits, dim=1)
    g_criterion = nn.MultiLabelSoftMarginLoss()
    g_loss = g_criterion(g_logits, g_labels.float())
    g_acc = np.sum(np.equal(np.argmax(g_logits.cpu().data.numpy(), axis=-1), np.argmax(g_labels.cpu().data.numpy(), axis=-1))) / g_labels.size(0)

    with torch.no_grad():
        for it, (images, caps_gt) in enumerate(iter(c_val_dataloader)):    
            images, caps_gt = images.to(device), caps_gt.to(device)
            caption_out = model_ts.model.caption(images, caps_gt)
            caption_out = caption_out/c_temp
            if c_logits_list is not None:
                c_logits_list = torch.cat([c_logits_list, caption_out], 1)
                c_labels_list = torch.cat([c_labels_list, caps_gt],1)
            else:
                c_logits_list = caption_out
                c_labels_list = caps_gt

    c_logits = c_logits_list.cuda()
    c_labels = c_labels_list.cuda()
    
    c_criterion = CELossWithLS(classes=len(text_field.vocab), smoothing=0.1, gamma=0.0, isCos=False, ignore_index=text_field.vocab.stoi['<pad>'])
    c_loss = c_criterion(c_logits[:, :-1].contiguous(), c_labels[:, 1:].contiguous())
    c_acc = np.sum(np.equal(np.argmax(c_logits.cpu().data.numpy(), axis=-1), c_labels.cpu().data.numpy())) / c_labels.size(1)
    
    return (g_loss.item(), c_loss.item(), g_acc, c_acc, g_logits, g_labels, c_logits, c_labels)

# 2.0: Temperature Scaling

In [11]:
model_ts = ModelWithTemperature(model)
ece_criterion = _ECELoss().to(device)
print('Initial Graph SU Temperature:%.4f'%model_ts.graph_su_temperature.item())
print('Initial Caption Temperature:%0.4f'%model_ts.caption_temperature.item())

Initial Graph SU Temperature:1.5000
Initial Caption Temperature:1.5000


# 2.1: Temperature Scaling : Find Optimal value

In [12]:
model_ts.graph_su_set_temperature(dict_dataloader_val)
model_ts.caption_set_temperature(dataloader_val)
print('-----------------------------------------------------------------------')
print('Optimal Graph SU Temperature:%.4f'%model_ts.graph_su_temperature.item())
print('Optimal Caption Temperature:%0.4f'%model_ts.caption_temperature.item())

-----------------------------------------------------------------------
Optimal Graph SU Temperature:1.3954
Optimal Caption Temperature:4.8063


# 2.2 Temperature Scaling : Without TS

In [13]:
g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val)
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('Before TS: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('Before TS: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field)
print('Before TS: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("Before TS: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
Before TS: G_loss:0.710, G_acc:0.571, G_ece:0.46277
Before TS: C_loss:2.724, C_acc:0.022, C_ece:0.84285


evaluation: 100%|██████████| 447/447 [01:56<00:00,  3.85it/s]


acc: 0.571059 map: 0.321686 loss: 0.509024, ece:0.219497, sce:0.047773, tace:0.049699, brier:0.649657, uce:0.283210
Before TS: Graph SU: acc: 0.571059 map: 0.321686 loss: 0.509024, ece:0.219497, sce:0.047773, tace:0.049699, brier:0.649657, uce:0.283210
Before TS: Caption scores : {'BLEU': array([0.5498, 0.4714, 0.4238, 0.3801]), 'METEOR': 0.2861, 'ROUGE': 0.57, 'CIDEr': 2.7487}


# 2.3 Temperature Scaling : With TS

In [14]:
g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val, g_temp=model_ts.graph_su_temperature.item(), c_temp=model_ts.caption_temperature.item())
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('Optimal TS: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('Optimal TS: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field, g_temp = model_ts.graph_su_temperature.item(), c_temp = model_ts.caption_temperature.item())
print('Optimal TS: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("Optimal TS: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
Optimal TS: G_loss:0.709, G_acc:0.571, G_ece:0.46025
Optimal TS: C_loss:1.894, C_acc:0.022, C_ece:0.29004


evaluation: 100%|██████████| 447/447 [01:57<00:00,  3.80it/s]


acc: 0.571059 map: 0.315094 loss: 0.487805, ece:0.218120, sce:0.045566, tace:0.046718, brier:0.644169, uce:0.260813
Optimal TS: Graph SU: acc: 0.571059 map: 0.315094 loss: 0.487805, ece:0.218120, sce:0.045566, tace:0.046718, brier:0.644169, uce:0.260813
Optimal TS: Caption scores : {'BLEU': array([0.3623, 0.3231, 0.3032, 0.2864]), 'METEOR': 0.2151, 'ROUGE': 0.4628, 'CIDEr': 2.3378}


# 2.4 Temperature Scaling : With CDA-TS

In [15]:
g_cls_freq = torch.zeros(13)
for i in torch.argmax(g_labels, dim=1): g_cls_freq[i] += 1
g_cls_freq_norm = g_cls_freq/torch.max(g_cls_freq)
g_temp = model_ts.graph_su_temperature.item() + g_cls_freq_norm*0.1
#g_temp = g_temp.to(device)
model_ts.graph_su_temperature = nn.Parameter(g_temp.to(device))

c_cls_freq = torch.zeros(41)
for i in c_labels.squeeze(): c_cls_freq[i] += 1
c_cls_freq_norm = c_cls_freq/torch.max(c_cls_freq)
c_temp = model_ts.caption_temperature.item() + c_cls_freq_norm*0.1
#c_temp = c_temp.to(device)
model_ts.caption_temperature = nn.Parameter(c_temp.to(device))

g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('CDA-TS: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('CDA-TS: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
print('CDA-TS: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("CDA-TS: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
CDA-TS: G_loss:0.709, G_acc:0.574, G_ece:0.46348
CDA-TS: C_loss:3.262, C_acc:0.022, C_ece:0.02099


evaluation: 100%|██████████| 447/447 [01:56<00:00,  3.85it/s]


acc: 0.573643 map: 0.319211 loss: 0.489725, ece:0.217880, sce:0.045660, tace:0.046676, brier:0.643576, uce:0.267471
CDA-TS: Graph SU: acc: 0.573643 map: 0.319211 loss: 0.489725, ece:0.217880, sce:0.045660, tace:0.046676, brier:0.643576, uce:0.267471
CDA-TS: Caption scores : {'BLEU': array([0.3623, 0.3231, 0.3032, 0.2864]), 'METEOR': 0.2151, 'ROUGE': 0.4628, 'CIDEr': 2.3378}


# Confidence Aware distribution

In [16]:
def conf_dist(model_ts, g_c_val_dataloader, c_val_dataloader):

    g_conf_list = np.zeros(13)
    c_conf_list = np.zeros(41)

    with torch.no_grad():
        for it, (images, caps_gt, graph_data) in enumerate(iter(g_c_val_dataloader)):
            # graph
            img_name = graph_data['img_name']
            img_loc = graph_data['img_loc']
            node_num = graph_data['node_num']
            roi_labels = graph_data['roi_labels']
            det_boxes = graph_data['det_boxes']
            edge_labels = graph_data['edge_labels']
            edge_num = graph_data['edge_num']
            features = graph_data['features']
            spatial_feat = graph_data['spatial_feat']
            word2vec = graph_data['word2vec']
            features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
        
            g_output = model_ts.model.graph_su(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
            
            g_output = F.softmax(g_output,1)
            g_confidences, g_predictions = torch.max(g_output, 1)
            g_accuracies = g_predictions.eq(torch.argmax(edge_labels, dim=1))
            for i in torch.argmax(edge_labels, dim=1).unique():
                g_conf_list[i] += g_confidences[torch.argmax(edge_labels, dim=1)==i].sum()

    
    with torch.no_grad():
        for it, (images, caps_gt) in enumerate(iter(c_val_dataloader)):    
            images, caps_gt = images.to(device), caps_gt.to(device)
            c_output = model_ts.model.caption(images, caps_gt)
            c_output = c_output.squeeze()
            c_output = F.softmax(c_output,1)
            c_confidences, c_predictions = torch.max(c_output, 1)
            c_accuracies = c_predictions.eq(caps_gt.squeeze())
            for i in caps_gt.squeeze().unique():
                c_conf_list[i] += c_confidences[caps_gt.squeeze()==i].sum()

    return g_conf_list, c_conf_list

# 2.5 Temperature Scaling : With CCA-TS

In [17]:
g_conf_list, c_conf_list = conf_dist(model_ts, dict_dataloader_val, dataloader_val)

plt.figure(1)
plt.title('Graph SU class distribution')
plt.bar(np.arange(len(g_cls_freq)),g_cls_freq)
plt.savefig('graph_class_dist.png')

plt.figure(2)
plt.title('Graph SU confidence score distribution')
plt.bar(np.arange(len(g_cls_freq)),g_conf_list/12)
plt.savefig('graph_conf_dist.png')
    
plt.figure(3)
plt.title('Caption Class Distribution')
plt.bar(np.arange(len(c_cls_freq)),c_cls_freq)
plt.savefig('caption_cls_dist.png')
    
plt.figure(4)
plt.title('Caption Confidence distribution')
plt.bar(np.arange(len(c_cls_freq)),c_conf_list/41)
plt.savefig('caption_conf_dist.png')

# CCA-TS calculation V2

In [18]:
g_cls_freq = g_conf_list/13
g_cls_freq_norm = g_cls_freq/np.max(g_cls_freq)
g_cls_freq_norm = torch.tensor(g_cls_freq_norm).float()
g_temp = model_ts.graph_su_temperature.cpu() + g_cls_freq_norm*0.1
# g_temp = g_temp.to(device)
model_ts.graph_su_temperature = nn.Parameter(g_temp.to(device))
    

c_cls_freq = c_conf_list/41#train_dataset.get_cls_num_list()
c_cls_freq_norm = c_cls_freq/np.max(c_cls_freq)
c_cls_freq_norm = torch.tensor(c_cls_freq_norm).float()
c_temp =  model_ts.caption_temperature.cpu() + c_cls_freq_norm*0.1
# c_temp = c_temp.to(device)
model_ts.caption_temperature = nn.Parameter(c_temp.to(device))
    
g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('CCA-TS-V1: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('CCA-TS-V1: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
print('CCA-TS-V1: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("CCA-TS-V1: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
CCA-TS-V1: G_loss:0.709, G_acc:0.575, G_ece:0.46582
CCA-TS-V1: C_loss:3.272, C_acc:0.022, C_ece:0.02043


evaluation: 100%|██████████| 447/447 [01:56<00:00,  3.85it/s]


acc: 0.575366 map: 0.320200 loss: 0.491720, ece:0.216305, sce:0.045280, tace:0.046805, brier:0.643167, uce:0.266481
CCA-TS-V1: Graph SU: acc: 0.575366 map: 0.320200 loss: 0.491720, ece:0.216305, sce:0.045280, tace:0.046805, brier:0.643167, uce:0.266481
CCA-TS-V1: Caption scores : {'BLEU': array([0.3623, 0.3231, 0.3032, 0.2864]), 'METEOR': 0.2151, 'ROUGE': 0.4628, 'CIDEr': 2.3378}


# CCA-TS calculation V2

In [19]:
g_cls_freq = g_conf_list/13
g_cls_freq_norm = g_cls_freq/np.max(g_cls_freq)
g_cls_freq_norm = torch.tensor(g_cls_freq_norm).float()
g_temp = model_ts.graph_su_temperature.cpu() + (g_cls_freq_norm.exp()-1.0)*.1
# g_temp = g_temp.to(device)
model_ts.graph_su_temperature.i = nn.Parameter(g_temp.to(device))
    
c_cls_freq = c_conf_list/41#train_dataset.get_cls_num_list()
c_cls_freq_norm = c_cls_freq/np.max(c_cls_freq)
c_cls_freq_norm = torch.tensor(c_cls_freq_norm).float()
c_temp =  model_ts.caption_temperature.cpu() + (c_cls_freq_norm.exp()-1.0)*.1
# c_temp = c_temp.to(device)
model_ts.caption_temperature = nn.Parameter(c_temp.to(device))
    
g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('CCA-TS-V2: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('CCA-TS-V2: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
print('CCA-TS-V2: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("CCA-TS-V2: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
CCA-TS-V2: G_loss:0.709, G_acc:0.575, G_ece:0.46582
CCA-TS-V2: C_loss:3.283, C_acc:0.022, C_ece:0.01962


evaluation: 100%|██████████| 447/447 [01:55<00:00,  3.87it/s]


acc: 0.575366 map: 0.320200 loss: 0.491720, ece:0.216305, sce:0.045280, tace:0.046805, brier:0.643167, uce:0.266481
CCA-TS-V2: Graph SU: acc: 0.575366 map: 0.320200 loss: 0.491720, ece:0.216305, sce:0.045280, tace:0.046805, brier:0.643167, uce:0.266481
CCA-TS-V2: Caption scores : {'BLEU': array([0.3623, 0.3231, 0.3032, 0.2864]), 'METEOR': 0.2151, 'ROUGE': 0.4628, 'CIDEr': 2.3378}
