In [58]:
import pickle
from typing_model.models.BERT_models import ConcatenatedContextBERTTyper
from torch.nn import Sigmoid
import torch

In [16]:
path_to_base_model = 'checkpoints/BBS_figer-v0.ckpt'
base_model_aux_path = 'dataloaders/Bert_Baseline_figer_train_auxiliary_variables.pkl'
base_model_dev_loader_path = 'dataloaders/Bert_Baseline_figer_dev.pkl'

In [17]:
path_to_ft_model = 'checkpoints/FT_BBN_into_FIGER-v0.ckpt'
ft_model_aux_path = 'dataloaders/Bert_Baseline_figer_train_auxiliary_variables.pkl'
ft_model_dev_loader_path = 'dataloaders/Bert_Baseline_figer_dev.pkl'

In [24]:
path_to_frozen_model = 'checkpoints/BBS_bbn-v1.ckpt'
fro_model_aux_path = 'dataloaders/Bert_Baseline_bbn_train_auxiliary_variables.pkl'
fro_model_dev_loader_path = 'dataloaders/Bert_Baseline_bbn_dev.pkl'

### Load aux variables

In [18]:
with open(base_model_aux_path, 'rb') as inp:
    bm_id2label, bm_label2id, bm_class_number = pickle.load(inp)

In [19]:
with open(ft_model_aux_path, 'rb') as inp:
    ft_id2label, ft_label2id, ft_class_number = pickle.load(inp)

In [25]:
with open(fro_model_aux_path, 'rb') as inp:
    fro_id2label, fro_label2id, fro_class_number = pickle.load(inp)

### Load models & dataloaders

In [28]:
base_model = ConcatenatedContextBERTTyper.load_from_checkpoint(path_to_base_model,
                                                               id2label = bm_id2label, 
                                                               label2id = bm_label2id, 
                                                               classes = bm_class_number,
                                                              weights = None).cuda()

In [29]:
ft_model = ConcatenatedContextBERTTyper.load_from_checkpoint(path_to_ft_model,
                                                               id2label = ft_id2label, 
                                                               label2id = ft_label2id, 
                                                               classes = ft_class_number,
                                                              weights = None).cuda()

In [30]:
fro_model = ConcatenatedContextBERTTyper.load_from_checkpoint(path_to_frozen_model,
                                                               id2label = fro_id2label, 
                                                               label2id = fro_label2id, 
                                                               classes = fro_class_number,
                                                              weights = None).cuda()

In [32]:
with open(base_model_dev_loader_path, 'rb') as inp:
    dataloader = pickle.load(inp)

## Generate Predictions

In [33]:
sig = Sigmoid().cuda() 

In [52]:
mini_batch_size = 10
admit_void_prediction = False

In [56]:
def get_model_pred(model, id2label, mentions, contexts):
    all_preds = []
    all_labels = []
    avg_predictions = 0
    
    mentions = mentions.cuda()
    contexts = contexts.cuda()
    pred = sig(model(mentions, contexts))

    pred = pred.detach().cpu()

    mask = pred > .5
    batch_preds = []
    for i, m in enumerate(mask):
        ex_preds = []   
        pred_ids =  m.nonzero()

        if len(pred_ids) == 0:
            if admit_void_prediction:
                void_counter += 1
            else:
                pred_ids = [torch.argmax(pred[i])]
        avg_predictions += len(pred_ids)

        for p in pred_ids:
            ex_preds.append(id2label[p.item()])
        batch_preds.append(ex_preds)
    all_preds.extend(batch_preds)
    
    avg_predictions = avg_predictions/mini_batch_size
    return all_preds, avg_predictions

In [41]:
dt = iter(dataloader)

In [43]:
mention, context, label = next(dt)

In [117]:
i = 0

In [118]:
i += mini_batch_size

In [119]:
base_preds, base_avg_preds = get_model_pred(base_model, bm_id2label, 
                                             mention[i: i + mini_batch_size], 
                                             context[i: i + mini_batch_size])
ft_preds, ft_avg_preds = get_model_pred(ft_model, ft_id2label, 
                                       mention[i: i + mini_batch_size],
                                       context[i: i + mini_batch_size])
fro_preds, fro_avg_preds = get_model_pred(fro_model, fro_id2label, 
                                       mention[i: i + mini_batch_size],
                                       context[i: i + mini_batch_size])

In [120]:
mask = label == 1
batch_labels = []
for m in mask:
    ex_labels = []
    labels_ids = m.nonzero()
    for l in labels_ids:
        ex_labels.append(bm_id2label[l.item()])
    batch_labels.append(ex_labels)
batch_labels

[['/person/athlete', '/person'],
 ['/person/author', '/person/artist', '/person/musician', '/person'],
 ['/person/artist',
  '/location/cemetery',
  '/location/country',
  '/location',
  '/government_agency',
  '/language',
  '/title',
  '/person'],
 ['/art/film', '/art'],
 ['/location'],
 ['/person/actor', '/person'],
 ['/language'],
 ['/location/city', '/location'],
 ['/location/city', '/location'],
 ['/person/artist', '/person/musician', '/person'],
 ['/person/athlete', '/person'],
 ['/location/city', '/location'],
 ['/location'],
 ['/person/author',
  '/person/athlete',
  '/person/actor',
  '/person/artist',
  '/person'],
 ['/organization'],
 ['/location/country', '/location'],
 ['/building', '/location'],
 ['/location'],
 ['/location/body_of_water', '/location'],
 ['/people/ethnicity',
  '/people',
  '/person/artist',
  '/location/country',
  '/location',
  '/person'],
 ['/location'],
 ['/building', '/location', '/building/sports_facility'],
 ['/location/city', '/location'],
 ['/o

In [123]:
def print_result(p1, p2, p3, label):
    for j in range(mini_batch_size):
        print('|{:#^121}|'.format(''))
        base_sentence = '|{:^25}|{:^95}|'
        print(base_sentence.format('sentence_idx', i + j + 1))        
        print(base_sentence.format('true_labels', ','.join(label[j])))
        print(base_sentence.format('base_model', ','.join(p1[j])))
        print(base_sentence.format('ft_model', ','.join(p2[j])))
        print(base_sentence.format('frozen_model', ','.join(p3[j])))
    print('|{:#^121}|'.format(''))

In [124]:
print_result(base_preds, ft_preds, fro_preds, batch_labels)

|#########################################################################################################################|
|      sentence_idx       |                                              11                                               |
|       true_labels       |                                    /person/athlete,/person                                    |
|       base_model        |                                            /person                                            |
|        ft_model         |                                    /person/athlete,/person                                    |
|      frozen_model       |                                            /PERSON                                            |
|#########################################################################################################################|
|      sentence_idx       |                                              12                                               |
|       