In [None]:
# Re-run of the cross-evaluation: one of the reviews pointed out that duplicates have to be removed.
# The code was adapted accordingly.

In [None]:
import setGPU
from transformers import BertTokenizer, BertForPreTraining
from nimble_pytorch.transformers import NimbleBert
from utils.generators import TextDatasetFineTuning, simple_collate_fn
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import pickle
def lr_lambda(epoch):
    """Return the constant multiplicative factor 0.97.

    Args:
        epoch: Accepted for signature compatibility; the value is ignored.

    Returns:
        float: Always 0.97.

    NOTE(review): not called in the visible part of this notebook; presumably it
    has to exist so that pickled checkpoints referencing it can be loaded -- confirm.
    """
    factor = 0.97
    return factor

In [None]:
## Selected models
# Checkpoint files to evaluate, listed in pairs per training set: for each of
# RP-Crowd-2, RP-Crowd-3, RP-Mod and derstandard there is an "a0.9" and an "a0.5"
# variant. NOTE(review): "a" is presumably a loss-weighting hyperparameter used
# during training -- confirm against the training notebooks.
# Paths are relative to the directory this notebook is run from.
MODEL_PATHS = [
    # Trained on RP-Crowd-2
    '../../Training/Language Models/trained_models/gbert-base-double-head-unfrozen-RP-Crowd-2-folds-a0.9_best.pt',
    '../../Training/Language Models/trained_models/gbert-base-double-head-unfrozen-RP-Crowd-2-folds-a0.5_best.pt',

    # Trained on RP-Crowd-3
    '../../Training/Language Models/trained_models/gbert-base-double-head-unfrozen-RP-Crowd-3-folds-a0.9_best.pt',
    '../../Training/Language Models/trained_models/gbert-base-double-head-unfrozen-RP-Crowd-3-folds-a0.5_best.pt',

    # Trained on RP-Mod
    '../../Training/Language Models/trained_models/gbert-base-double-head-unfrozen-RP-Mod-folds-a0.9_best.pt',
    '../../Training/Language Models/trained_models/gbert-base-double-head-unfrozen-RP-Mod-folds-a0.5_best.pt',

    # Trained on derstandard
    '../../Training/Language Models/trained_models/gbert-base-double-head-unfrozen-derstandard-folds-a0.9_best.pt',
    '../../Training/Language Models/trained_models/gbert-base-double-head-unfrozen-derstandard-folds-a0.5_best.pt'
]
## Load tokenizer
# The 'deepset/gbert-base' tokenizer is shared by all evaluation datasets built below.
tokenizer = BertTokenizer.from_pretrained('deepset/gbert-base')

In [None]:
## Evaluation datasets as (name, dataset) tuples.
# Every dataset is built with the same settings (fold column 'ten_folds', folds 8 and 9,
# MLM inputs included with masking probability 0, padded/truncated to 512 tokens);
# only the display name and the CSV file differ.
DATASETS = [
    (
        label,
        TextDatasetFineTuning(
            csv_path,
            tokenizer,
            'ten_folds',
            [8,9],
            include_mlm=True,
            mlm_probability=0.,
            padding='max_length',
            truncation=True,
            max_length=512,
        ),
    )
    for label, csv_path in [
        ('RP-Crowd-2', '../../Dataset/Text-Data/Cross Evaluation Without Duplicates/evaluate_on_RP-Crowd-2-folds.csv'),
        ('RP-Crowd-3', '../../Dataset/Text-Data/Cross Evaluation Without Duplicates/evaluate_on_RP-Crowd-3-folds.csv'),
        ('RP-Mod', '../../Dataset/Text-Data/Cross Evaluation Without Duplicates/evaluate_on_RP-Mod-folds.csv'),
        ('DerStandard', '../../Dataset/Text-Data/derstandard-folds.csv'),
    ]
]

In [None]:
## Cross-evaluation: run every selected model on every dataset and pickle the raw outputs.
## For each (model, dataset) pair one file is written containing
##   'probs'     - exp() of the model's 'seq_relationship_logits' output, as nested lists
##   'decisions' - argmax over the last axis of those values
##   'y_trues'   - the 'next_sentence_label' entries exactly as yielded by the dataset
for path in MODEL_PATHS:
    # NOTE(review): torch.load unpickles a complete Python object and can therefore
    # execute arbitrary code -- only load checkpoints from a trusted source.
    model = torch.load(path, map_location='cuda')
    # Replace 'unfreezed' with 'unfrozen' in the model name. The replacement is
    # idempotent, so it is done once per model instead of once per dataset.
    # The name is used below in the log line and in the result file name.
    model.name = model.name.replace('unfreezed', 'unfrozen')

    # Iterate over datasets
    for name, dataset in DATASETS:
        print(f'Evaluate {model.name} on {name}!')
        probs = []
        decisions = []
        y_trues = []

        # Pure inference: disable autograd so no computation graph is recorded
        # for the forward passes (results are unchanged, memory use is lower).
        with torch.no_grad():
            # Iterate over the items yielded by the dataset and collect the outputs
            for inp in dataset:
                y_trues.append(inp['next_sentence_label'])
                # Move only the tensors the model consumes onto the GPU
                x = {
                    key: inp[key].cuda()
                    for key in ('input_ids', 'token_type_ids', 'attention_mask')
                }

                # NOTE(review): exp() is applied to 'seq_relationship_logits'; presumably
                # predict() returns log-probabilities under that key -- confirm.
                prob = model.predict(x)['seq_relationship_logits'].cpu().exp()
                decision = prob.argmax(-1)

                probs.extend(prob.tolist())
                decisions.extend(decision.tolist())
                del prob, decision, x

        ## Save Pickle
        with open(f'../Cross-Evaluation-Results/renewed_cross_eval_{model.name}_on_{name}.pickle', 'wb') as f:
            pickle.dump({'probs': probs, 'decisions': decisions, 'y_trues': y_trues}, f)

    # Drop the reference to this model before the next checkpoint is loaded
    del model