In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification, AutoConfig, T5ForConditionalGeneration
from transformers import AutoTokenizer
from sklearn.metrics import confusion_matrix

import math
import os
import json
import datasets
from matplotlib import pyplot as plt
from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

In [2]:
from transformers import AutoModelForSequenceClassification
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')
from modeling_rmt import RMTEncoderForSequenceClassification
from modeling_rmt_enc_dec import RMTEncoderDecoderForConditionalGeneration

## RMT

In [3]:
# set_params_kwargs['backbone_cls'] = transformers.T5ForConditionalGeneration
# set_params_kwargs['backbone_cls']

In [4]:
# Name of the pretrained backbone; used for both the model weights and the tokenizer.
model_name = 't5-base'

# Directory of the fine-tuning run whose best checkpoint is analysed in this notebook.
experiment_path = "/home/bulatov/bulatov/runs/finetune/debug/qmsum/t5-base/lr2e-04_linear_adamw_wd1e-03_501-1024_mem0_bs8_iters3200_sl_regular/run_2"

cpt_path = os.path.join(experiment_path, "model_best.pth")
config_path = os.path.join(experiment_path, "config.json")
# Load the checkpoint onto the CPU; the model is moved to `device` later.
cpt = torch.load(cpt_path, map_location='cpu')

# `d` holds the training configuration saved next to the checkpoint.
with open(config_path, 'r') as f:
    d = json.load(f)

# Build backbone and tokenizer from the same `model_name` so they cannot drift apart.
base_model = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
rmt = RMTEncoderDecoderForConditionalGeneration(base_model=base_model)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keep only the config entries whose names appear among set_params' variable names.
# NOTE(review): co_varnames also lists local variables of set_params, not only its
# parameters — confirm no config key collides with a local name.
set_params_kwargs = {k: v for k, v in d.items() if k in rmt.set_params.__code__.co_varnames}
set_params_kwargs['backbone_cls'] = None  # transformers.T5ForConditionalGeneration
set_params_kwargs['segment_ordering'] = 'regular'
set_params_kwargs['inter_layer_memory'] = False
set_params_kwargs['tokenizer'] = tokenizer

rmt.set_params(**set_params_kwargs)
rmt.load_state_dict(cpt['model_state_dict'])



<All keys matched successfully>

### Baseline

In [5]:
# model_name = 't5-base'

# experiment_path = "/home/bulatov/bulatov/runs/finetune/debug/qmsum/t5-base/lr2e-04_linear_adamw_wd1e-03_512-1024_bs8_iters3200_bl/run_1/"
# cpt_path = os.path.join(experiment_path, "model_best.pth")
# config_path = os.path.join(experiment_path, "config.json")
# cpt = torch.load(cpt_path, map_location='cpu')

# with open(config_path, 'r') as f:
#     d = json.load(f)

# baseline = T5ForConditionalGeneration.from_pretrained(model_name, num_labels=3)
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# baseline.load_state_dict(cpt['model_state_dict'])

In [6]:
class Holder:
    """Bare attribute container with no fields of its own.

    The notebook attaches its settings (sequence lengths, input prefix) to it.
    """

In [7]:
# Input length comes from the saved training config; target length and batch size are set here.
input_seq_len = d['input_seq_len']
target_seq_len = 1024
batch_size = 2

# Settings object read by collate_fn. Use an instance so the values are not
# written onto the Holder class itself.
args = Holder()
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''


# Device index 2 is hard-coded; adjust for the machine at hand.
device = torch.device(2)

global_attention_first_token = False  # should be True for LED
# Tokenizer options shared by inputs and targets in collate_fn.
encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}
# generate_kwargs = {'max_length': args.target_seq_len, 'min_length': args.target_seq_len}
generate_kwargs = {}

def collate_fn(batch):
    """Turn a list of dataset rows into one tokenized batch.

    Reads `args`, `tokenizer` and `encode_plus_kwargs` from the notebook namespace.
    Each row must provide 'id', 'input' and either 'outputs' (a list of references,
    of which the first is tokenized as the label) or 'output' (a single reference).

    Returns:
        The tokenizer output for the inputs, extended with 'labels' (target token
        ids with pad positions set to -100), 'id' and 'target_text'.

    Raises:
        RuntimeError: if the tokenizer produced a 'global_attention_mask'.
    """
    has_multiple_refs = 'outputs' in batch[0]

    # cut too long strings because they may slow down tokenization
    input_char_limit = args.input_seq_len * 10
    target_char_limit = args.target_seq_len * 10
    sources = [row['input'][:input_char_limit] for row in batch]
    if has_multiple_refs:
        # if we have more than 1 label per example (only in valid) take only one of them
        # to compute loss on valid
        references = [row['outputs'] for row in batch]
        targets = [refs[0][:target_char_limit] for refs in references]
    else:
        references = [row['output'] for row in batch]
        targets = [ref[:target_char_limit] for ref in references]
    if args.input_prefix:
        sources = [args.input_prefix + src for src in sources]

    features = tokenizer.batch_encode_plus(
        sources, max_length=args.input_seq_len, return_tensors='pt', **encode_plus_kwargs
    )
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer.batch_encode_plus(
            targets, max_length=args.target_seq_len, return_tensors='pt', **encode_plus_kwargs
        ).input_ids
    # padded label positions are replaced by -100
    target_ids[target_ids == tokenizer.pad_token_id] = -100

    features['labels'] = target_ids
    features['id'] = [row['id'] for row in batch]
    features['target_text'] = references
    if 'global_attention_mask' in features:
        raise RuntimeError('What global attention mask for Longformer and LongformerEncoder-Decoder should be?')
    return features

In [8]:
# Load the SCROLLS task and build shuffled loaders for the train and validation splits.
task_name = 'qmsum'
dataset = datasets.load_dataset('tau/scrolls', task_name)
train_dataset = dataset['train']

train_sampler = RandomSampler(train_dataset,)
# Extra DataLoader options, reused by the evaluation cells below.
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

valid_dataset = dataset['validation']
valid_sampler = RandomSampler(valid_dataset)
# The validation loader must draw indices from its own sampler: train_sampler
# yields indices of the (larger) train split.
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler,
                                collate_fn=collate_fn, **kwargs)

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/qmsum/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
# Number of rows in the train and validation splits, counted via the 'pid' column.
len(train_dataset['pid']), len(valid_dataset['pid'])

(1257, 272)

### predictions from all segments

In [10]:
def generate(self, input_ids, return_all_outputs=False, **kwargs):
    """Run `self.model.generate` segment by segment, carrying memory between segments.

    This is a free function that takes the RMT wrapper as `self`; in this notebook it
    is called as `generate(rmt, ...)`.

    Args:
        self: object providing `set_memory`, `pad_and_segment`, `embeddings`, `model`,
            `segment_ordering`, `bptt_depth`, `drop_empty_segments`, `empty`,
            `num_mem_tokens` and `bos_token`.
        input_ids: token ids handed to `self.pad_and_segment`.
        return_all_outputs: if True, also return the list of per-segment generations.
        **kwargs: `min_length` / `max_length` are popped and passed to every
            `self.model.generate` call; everything else is forwarded to the model
            calls, with `attention_mask` replaced by the per-segment mask.

    Returns:
        `out`, or `(out, outputs)` when `return_all_outputs` is True, where `outputs`
        holds one generation result per processed segment.

    Raises:
        ValueError: if `self.segment_ordering` is not one of the handled values.

    NOTE(review): `out` is only a generation result when the final segment is
    processed. If the final segment is skipped by the `continue` below, `out` is
    whatever an earlier iteration left (a forward output) or is unbound.
    """
    memory = self.set_memory()
    # memory slots start after the BOS position when a BOS token is configured
    mem_start_ind = 1 if self.bos_token is not None else 0
    min_length, max_length = None, None
    if 'min_length' in kwargs:
        min_length = kwargs.pop('min_length')
    if 'max_length' in kwargs:
        max_length = kwargs.pop('max_length')

    segmented = self.pad_and_segment(input_ids)
    # regroup into one tuple per segment; each tuple is unpacked into three tensors below
    segmented = list(zip(*segmented))

    if self.segment_ordering in {'regular', 'last_memory_only'}:
        pass
    elif self.segment_ordering == 'reversed':
        segmented = segmented[::-1]
    elif self.segment_ordering == 'bidirectional':
        # forward pass over all segments, then back again without repeating the last one
        segmented = segmented + segmented[::-1][1:]
    elif self.segment_ordering == 'repeat_first':
        segmented = segmented + segmented[:1]
    else:
        raise ValueError(f'Unknown segment ordering: {self.segment_ordering}')

    # one generation result per processed segment
    outputs = []
    for seg_num, segment_data in enumerate(segmented):
        # note: rebinding `input_ids` here shadows the function argument
        input_ids, attention_mask, token_type_ids = segment_data
        if memory.ndim == 2:
            # expand the initial memory to one copy per batch row
            memory = memory.repeat(input_ids.shape[0], 1, 1)
        if (self.bptt_depth > -1) and (len(segmented) - seg_num > self.bptt_depth): 
            memory = memory.detach()

        seg_kwargs = dict(**kwargs)
        if self.drop_empty_segments:
            # keep only the rows whose segment differs from the `self.empty` pattern
            non_empty_mask = [not torch.equal(input_ids[i], self.empty) for i in range(len(input_ids))]
            if sum(non_empty_mask) == 0:
                continue
            input_ids = input_ids[non_empty_mask]
            attention_mask = attention_mask[non_empty_mask]
            token_type_ids = token_type_ids[non_empty_mask]

            inputs_embeds = self.embeddings(input_ids)
            # overwrite the embeddings at the memory positions with the current memory
            inputs_embeds[:, mem_start_ind:mem_start_ind+self.num_mem_tokens] = memory[non_empty_mask]

        else:
            inputs_embeds = self.embeddings(input_ids)
            # overwrite the embeddings at the memory positions with the current memory
            inputs_embeds[:, mem_start_ind:mem_start_ind+self.num_mem_tokens] = memory

        seg_kwargs['inputs_embeds'] = inputs_embeds
        seg_kwargs['attention_mask'] = attention_mask
        if seg_num < len(segmented)-1:
            # all-zero labels; presumably only there so that forward() can run the decoder — TODO confirm
            labels = torch.zeros(inputs_embeds.shape[0], inputs_embeds.shape[1], device=inputs_embeds.device, dtype=input_ids.dtype)
            gen_out = self.model.generate(**seg_kwargs, output_hidden_states=True, min_length=min_length, max_length=max_length)
            outputs.append(gen_out)
            # separate forward pass to obtain encoder hidden states for the memory update
            out = self.model.forward(**seg_kwargs, output_hidden_states=True, labels=labels)
            # new memory = last encoder hidden states at the memory positions
            if self.drop_empty_segments:
                memory[non_empty_mask] = out.encoder_hidden_states[-1][:, mem_start_ind:mem_start_ind+self.num_mem_tokens]
            else:
                memory = out.encoder_hidden_states[-1][:, mem_start_ind:mem_start_ind+self.num_mem_tokens]
        else:
            # last segment: only generate, no memory update needed
            out = self.model.generate(**seg_kwargs, output_hidden_states=True, min_length=min_length, max_length=max_length)
            outputs.append(out)

    if return_all_outputs:
        return out, outputs
    
    return out


### get predictions

In [11]:
import pandas as pd
def evaluate(output, sample):
    """Compare argmax predictions with the labels of the non-empty rows.

    Args:
        output: mapping with 'logits' (argmax is taken over dim 1) and
            'non_empty_mask' (selects the label rows that have logits).
        sample: mapping with 'labels'.

    Returns:
        Boolean tensor, True where the predicted class equals the label.
    """
    labels = sample['labels']
    # read logits from the `output` argument, not from a leftover notebook global
    logits = output['logits']
    preds = torch.argmax(logits, dim=1)
    correct_mask = preds == labels[output['non_empty_mask']]
    return correct_mask

In [12]:
import shutil
from huggingface_hub import hf_hub_download

def download_metric():
    """Fetch the SCROLLS metric script from the hub and copy it under a new file name.

    Returns:
        Path of the copy. It is built by concatenating the directory of the
        downloaded file and the file name with '.' replaced by '_' plus '.py';
        no path separator is inserted between the two parts.
    """
    cached_script = hf_hub_download(repo_id="datasets/tau/scrolls", filename="metrics/scrolls.py")
    parent_dir, file_name = os.path.split(cached_script)
    renamed_script = parent_dir + file_name.replace(".", "_") + ".py"
    shutil.copy(cached_script, renamed_script)
    return renamed_script


# Download the metric script and load it for the current task; metrics_fn uses `scrolls_metric`.
# NOTE(review): `datasets.load_metric` is not available in recent `datasets` releases — confirm the pinned version.
scrolls_metric_path = download_metric()
scrolls_metric = datasets.load_metric(scrolls_metric_path, task_name, keep_in_memory=True)

def metrics_fn(labels, generation_outputs, verbose=True):
    """Decode generated token ids and score them against reference texts.

    Uses the notebook-level `tokenizer` and `scrolls_metric`.

    Args:
        labels: references, one entry per prediction. Entries that are not lists
            are wrapped into one-element lists before scoring. May be None, in
            which case nothing is scored.
        generation_outputs: token id sequences decoded with `tokenizer.batch_decode`.
        verbose: if True, print every reference / prediction pair.

    Returns:
        The value returned by `scrolls_metric.compute`, or None when there are no
        references to compare against.
    """
    y = labels
    p = tokenizer.batch_decode(generation_outputs, skip_special_tokens=True)
    if y is None or p is None:
        # nothing to score; previously this path failed with an unbound `result`
        return None
    if verbose:
        for i in range(len(y)):
            print(f'y: {y[i]}')
            print(f'p: {p[i]}')
            # print(f'p ids: {generation_outputs[i]}')
            print('-' * 50)
    # todo: do we need to better clean P to remove tokens after eos? not remove special tokens only
    if len(y) > 0 and not isinstance(y[0], list):
        # the metric is called with a list of references per prediction
        y = [[_y] for _y in y]
    return scrolls_metric.compute(predictions=p, references=y)

In [13]:
# Show the columns and row count of the training split.
train_dataset

Dataset({
    features: ['id', 'pid', 'input', 'output'],
    num_rows: 1257
})

In [14]:
# Lookup tables keyed by example id: reference text ('output') per split ...
id2label_train = {ex_id: ref for ex_id, ref in zip(train_dataset['id'], train_dataset['output'])}
id2label_valid = {ex_id: ref for ex_id, ref in zip(valid_dataset['id'], valid_dataset['output'])}

# ... and source text ('input') per split.
id2text_train = {ex_id: text for ex_id, text in zip(train_dataset['id'], train_dataset['input'])}
id2text_valid = {ex_id: text for ex_id, text in zip(valid_dataset['id'], valid_dataset['input'])}

### baseline


In [15]:
# %%time
# it = 0
# max_it =  3000

# baseline.to(device=device)
# sampler = RandomSampler(train_dataset)
# dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
#                                 collate_fn=collate_fn, **kwargs)
# # sampler = RandomSampler(valid_dataset)
# # dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=sampler,
# #                                 collate_fn=collate_fn, **kwargs)


# it = 0
    
# res_df = pd.DataFrame()
# gen = iter(dataloader)
# for sample in gen:
#     ids, target_text, labels = sample.pop('id'), sample.pop('target_text'), sample.pop('labels')
#     for key in sample:
#         sample[key] = sample[key].to(device)
    
#     # out, outputs = generate(rmt, return_all_outputs=True, **sample)  
#     out = baseline.generate(**sample)  

#     res_dict = {'ids': ids}

#     f1 = metrics_fn(target_text, generation_outputs=out, verbose=False)['f1']
#     res_dict[f'f1_seg_{i}'] = f1

#     preds = tokenizer.batch_decode(out, skip_special_tokens=True)
#     res_dict['preds'] = preds
#     res_dict['preds_tokens'] = str(out.cpu().numpy())

#     res_dict['target_text'] = target_text
#     res_dict['labels'] = str(labels.cpu().numpy())

#     res_df = res_df.append(pd.DataFrame(res_dict), ignore_index=True)


#     it += 1
#     if it > max_it:
#         break



In [16]:
# res_df.to_csv('tables/qasper-valid-t5-base.csv', index=False)
# res_df.to_csv('tables/qasper-train-t5-base.csv', index=False)


#### rmt

In [17]:
import os

# Key of the metric result that is recorded for every example in the loops below.
metric_name = 'rouge/geometric_mean'

# Set TOKENIZERS_PARALLELISM to "false" for this process
# (presumably read by the HF tokenizers library — confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})


In [18]:
%%time
it = 0
max_it =  2000

rmt.to(device=device)
rmt.drop_empty_segments = False
sampler = RandomSampler(train_dataset)
dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                                collate_fn=collate_fn, **kwargs)
# sampler = RandomSampler(valid_dataset)
# dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=sampler,
#                                 collate_fn=collate_fn, **kwargs)


it = 0
    
res_df = pd.DataFrame()
gen = iter(dataloader)
for sample in gen:
    ids, target_text, labels = sample.pop('id'), sample.pop('target_text'), sample.pop('labels')
    for key in sample:
        sample[key] = sample[key].to(device)
    
    out, outputs = generate(rmt, return_all_outputs=True, **sample)  
    # print(out.keys())

    res_dict = {'ids': ids}
    for i, o in enumerate(outputs):
        # metric = metrics_fn(target_text, generation_outputs=o, verbose=False)[metric_name]
        metrics = [metrics_fn([t], generation_outputs=[row], verbose=False)[metric_name] for t,row in zip(target_text, o)]
        res_dict[f'{metric_name}_seg_{i}'] = metrics
        
        preds = tokenizer.batch_decode(o, skip_special_tokens=True)
        res_dict[f'preds_seg_{i}'] = preds
        res_dict[f'preds_seg_{i}_tokens'] = [row for row in o.cpu().numpy()]
        
    res_dict['target_text'] = target_text
    res_dict['labels'] = [row for row in labels.cpu().numpy()]
    # print(res_dict)

    res_df = res_df.append(pd.DataFrame(res_dict), ignore_index=True)


    it += 1
    if (it % 100 == 0) or (it < 10): 
        print(it)
    if it > max_it:
        break



1
2
3
4
5
6
7
8
9
100
200
300
400
500
600
CPU times: user 5min 36s, sys: 43min 48s, total: 49min 24s
Wall time: 55min 57s


In [19]:
# Token ids generated for the first example of the last processed batch.
out[0]

tensor([   0,  100,   19,    8, 4509,   11,  336, 1338,   30,    8,  516,    5,
          86,   48, 1338,    6,    8,  563, 5172,    8], device='cuda:2')

In [20]:
# Per-example results collected on the train split.
res_df

Unnamed: 0,ids,rouge/geometric_mean_seg_0,preds_seg_0,preds_seg_0_tokens,target_text,labels
0,tr-sq-73,83.0052,Conferees agreed that the remote control could...,"[0, 11847, 49, 15, 15, 7, 4686, 24, 8, 4322, 6...",Conferees agreed that the remote control could...,"[11847, 49, 15, 15, 7, 4686, 24, 8, 4322, 610,..."
1,tr-gq-731,8.3643,The group discussed the future directions of t...,"[0, 37, 563, 5172, 8, 647, 7943, 13, 8, 516, 6...",An idea for future work was suggested during t...,"[389, 800, 21, 647, 161, 47, 5259, 383, 8, 719..."
2,tr-sq-648,53.7580,Grad D was working on filing his taxes. He exp...,"[0, 10771, 309, 47, 464, 30, 9479, 112, 5161, ...",Grad D was focused on filing his taxes. He exp...,"[10771, 309, 47, 2937, 30, 9479, 112, 5161, 5,..."
3,tr-sq-341,11.4004,The group agreed that the remote control shoul...,"[0, 37, 563, 4686, 24, 8, 4322, 610, 225, 36, ...",Project Manager suggested a can opener beneath...,"[2786, 3440, 5259, 3, 9, 54, 19717, 12759, 8, ..."
4,tr-sq-198,20.6617,The team members got acquainted with the use o...,"[0, 37, 372, 724, 530, 29740, 28, 8, 169, 13, ...","The manager introduced a new technical device,...","[37, 2743, 3665, 3, 9, 126, 2268, 1407, 6, 8, ..."
...,...,...,...,...,...,...
1252,tr-sq-591,54.3323,Industrial Designer thought that the vegetable...,"[0, 9199, 11296, 816, 24, 8, 12065, 11, 2728, ...",Industrial Designer thought that the vegetable...,"[9199, 11296, 816, 24, 8, 12065, 11, 2728, 380..."
1253,tr-sq-499,20.3754,The voice recognition could have two advantage...,"[0, 37, 2249, 5786, 228, 43, 192, 7648, 10, 34...",The voice recognition could deliver advantages...,"[37, 2249, 5786, 228, 2156, 7648, 12, 1722, 10..."
1254,tr-sq-1251,8.5119,The committee members discussed the various me...,"[0, 37, 4492, 724, 5172, 8, 796, 3629, 12, 361...","The group mentioned that during the pandemic, ...","[37, 563, 2799, 24, 383, 8, 2131, 221, 3113, 6..."
1255,tr-sq-253,16.4258,User Interface proposed to make a remote contr...,"[0, 6674, 25064, 4382, 12, 143, 3, 9, 4322, 61...",The remote control should be made dependable a...,"[37, 4322, 610, 225, 36, 263, 3, 23094, 11, 13..."


In [21]:
# Create the output folder first so the export does not fail on a fresh checkout.
os.makedirs('tables', exist_ok=True)
res_df.to_csv('tables/qmsum-train-rm-t5-501-0.csv', index=False)
# res_df.to_csv('tables/qmsum-valid-rm-t5-501-0.csv', index=False)

In [22]:
%%time
it = 0
max_it =  1000

rmt.to(device=device)
rmt.drop_empty_segments = False
# sampler = RandomSampler(train_dataset)
# dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
#                                 collate_fn=collate_fn, **kwargs)
sampler = RandomSampler(valid_dataset)
dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=sampler,
                                collate_fn=collate_fn, **kwargs)


it = 0
    
res_df = pd.DataFrame()
gen = iter(dataloader)
for sample in gen:
    ids, target_text, labels = sample.pop('id'), sample.pop('target_text'), sample.pop('labels')
    for key in sample:
        sample[key] = sample[key].to(device)
    
    out, outputs = generate(rmt, return_all_outputs=True, **sample)  

    res_dict = {'ids': ids}
    for i, o in enumerate(outputs):
        # metric = metrics_fn(target_text, generation_outputs=o, verbose=False)[metric_name]
        metrics = [metrics_fn([t], generation_outputs=[row], verbose=False)[metric_name] for t,row in zip(target_text, o)]
        res_dict[f'{metric_name}_seg_{i}'] = metrics
        
        preds = tokenizer.batch_decode(o, skip_special_tokens=True)
        res_dict[f'preds_seg_{i}'] = preds
        res_dict[f'preds_seg_{i}_tokens'] = [row for row in o.cpu().numpy()]
        
    res_dict['target_text'] = target_text
    res_dict['labels'] = [row for row in labels.cpu().numpy()]

    res_df = res_df.append(pd.DataFrame(res_dict), ignore_index=True)


    it += 1
    if it // 20 == 0:
        print(it)
    if it > max_it:
        break



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
CPU times: user 1min 10s, sys: 9min 22s, total: 10min 33s
Wall time: 11min 57s


In [23]:
# Per-example results collected on the validation split.
res_df

Unnamed: 0,ids,rouge/geometric_mean_seg_0,preds_seg_0,preds_seg_0_tokens,target_text,labels
0,va-sq-68,19.4028,The User Interface thought that the remote con...,"[0, 37, 6674, 25064, 816, 24, 8, 4322, 610, 13...","While discussing the financial issue, the User...","[818, 12104, 8, 981, 962, 6, 8, 6674, 25064, 4..."
1,va-sq-15,9.2796,The Project Manager thought that the buttons o...,"[0, 37, 2786, 3440, 816, 24, 8, 10634, 30, 8, ...",In terms of the great idea of making a multifu...,"[86, 1353, 13, 8, 248, 800, 13, 492, 3, 9, 124..."
2,va-gq-49,0.0000,The team decided to stick to the so-called con...,"[0, 37, 372, 1500, 12, 4372, 12, 8, 78, 18, 93...",The meeting discussed various topics such as d...,"[37, 1338, 5172, 796, 4064, 224, 38, 331, 1232..."
3,va-sq-50,7.6497,The group discussed which components to choose...,"[0, 37, 563, 5172, 84, 3379, 12, 854, 45, 11, ...",The industrial designer provided several optio...,"[37, 2913, 4378, 937, 633, 931, 6898, 21, 827,..."
4,va-sq-110,14.1300,The group was concerned that the University of...,"[0, 37, 563, 47, 4376, 24, 8, 636, 13, 1013, 1...",Kirsty Williams firstly made it clear that it ...,"[10976, 7, 17, 63, 6060, 166, 120, 263, 34, 96..."
...,...,...,...,...,...,...
267,va-sq-78,9.3486,B thought that the current computer navigation...,"[0, 272, 816, 24, 8, 750, 1218, 8789, 358, 47,...",B explained that the subject of the experiment...,"[272, 5243, 24, 8, 1426, 13, 8, 5016, 56, 166,..."
268,va-sq-64,22.3798,When Marketing showed a picture of the spin wh...,"[0, 366, 4329, 3217, 3, 9, 1554, 13, 8, 5404, ...",User Interface asked whether they could put an...,"[6674, 25064, 1380, 823, 79, 228, 474, 46, 115..."
269,va-sq-232,0.0000,"The first domain is a set of digits, the secon...","[0, 37, 166, 3303, 19, 3, 9, 356, 13, 3, 9206,...","Given the domains currently used (tourist, chi...","[9246, 8, 3303, 7, 1083, 261, 41, 17, 1211, 34..."
270,va-sq-267,16.5432,Project Manager thought that the manual design...,"[0, 2786, 3440, 816, 24, 8, 3354, 408, 47, 514...",Project Manager thought the beeping device was...,"[2786, 3440, 816, 8, 36, 15, 2462, 1407, 47, 5..."


In [24]:
# Create the output folder first so the export does not fail on a fresh checkout.
os.makedirs('tables', exist_ok=True)
# res_df.to_csv('tables/qmsum-train-rm-t5-501-0.csv', index=False)
res_df.to_csv('tables/qmsum-valid-rm-t5-501-0.csv', index=False)

In [25]:
# sample['input_ids'].shape

In [26]:
# Results of the last processed batch only (`res_dict` is overwritten on every iteration).
pd.DataFrame(res_dict)

Unnamed: 0,ids,rouge/geometric_mean_seg_0,preds_seg_0,preds_seg_0_tokens,target_text,labels
0,va-sq-267,16.5432,Project Manager thought that the manual design...,"[0, 2786, 3440, 816, 24, 8, 3354, 408, 47, 514...",Project Manager thought the beeping device was...,"[2786, 3440, 816, 8, 36, 15, 2462, 1407, 47, 5..."
1,va-sq-135,11.7844,The group thought that Dave had left IBM and t...,"[0, 37, 563, 816, 24, 8545, 141, 646, 11045, 1...",The group needed Dave Johnson to install the s...,"[37, 563, 906, 8545, 5891, 12, 2438, 8, 128, 8..."


In [27]:
# Full validation results table, shown again for comparison with the last batch.
res_df

Unnamed: 0,ids,rouge/geometric_mean_seg_0,preds_seg_0,preds_seg_0_tokens,target_text,labels
0,va-sq-68,19.4028,The User Interface thought that the remote con...,"[0, 37, 6674, 25064, 816, 24, 8, 4322, 610, 13...","While discussing the financial issue, the User...","[818, 12104, 8, 981, 962, 6, 8, 6674, 25064, 4..."
1,va-sq-15,9.2796,The Project Manager thought that the buttons o...,"[0, 37, 2786, 3440, 816, 24, 8, 10634, 30, 8, ...",In terms of the great idea of making a multifu...,"[86, 1353, 13, 8, 248, 800, 13, 492, 3, 9, 124..."
2,va-gq-49,0.0000,The team decided to stick to the so-called con...,"[0, 37, 372, 1500, 12, 4372, 12, 8, 78, 18, 93...",The meeting discussed various topics such as d...,"[37, 1338, 5172, 796, 4064, 224, 38, 331, 1232..."
3,va-sq-50,7.6497,The group discussed which components to choose...,"[0, 37, 563, 5172, 84, 3379, 12, 854, 45, 11, ...",The industrial designer provided several optio...,"[37, 2913, 4378, 937, 633, 931, 6898, 21, 827,..."
4,va-sq-110,14.1300,The group was concerned that the University of...,"[0, 37, 563, 47, 4376, 24, 8, 636, 13, 1013, 1...",Kirsty Williams firstly made it clear that it ...,"[10976, 7, 17, 63, 6060, 166, 120, 263, 34, 96..."
...,...,...,...,...,...,...
267,va-sq-78,9.3486,B thought that the current computer navigation...,"[0, 272, 816, 24, 8, 750, 1218, 8789, 358, 47,...",B explained that the subject of the experiment...,"[272, 5243, 24, 8, 1426, 13, 8, 5016, 56, 166,..."
268,va-sq-64,22.3798,When Marketing showed a picture of the spin wh...,"[0, 366, 4329, 3217, 3, 9, 1554, 13, 8, 5404, ...",User Interface asked whether they could put an...,"[6674, 25064, 1380, 823, 79, 228, 474, 46, 115..."
269,va-sq-232,0.0000,"The first domain is a set of digits, the secon...","[0, 37, 166, 3303, 19, 3, 9, 356, 13, 3, 9206,...","Given the domains currently used (tourist, chi...","[9246, 8, 3303, 7, 1083, 261, 41, 17, 1211, 34..."
270,va-sq-267,16.5432,Project Manager thought that the manual design...,"[0, 2786, 3440, 816, 24, 8, 3354, 408, 47, 514...",Project Manager thought the beeping device was...,"[2786, 3440, 816, 8, 36, 15, 2462, 1407, 47, 5..."


## Interpret

#### Train

In [28]:
# Load the saved per-example RMT results, order them by example id and attach the source text.
# NOTE(review): these paths point at 'qasper' tables, while id2text_train is built from the
# qmsum split loaded above; the KeyError recorded in this cell's output is consistent with
# that mismatch — confirm which tables were intended.
rmt_df = pd.read_csv('tables/qasper-train-rm-t5-1002-10.csv').sort_values(['ids']).reset_index()
rmt_df['input'] = rmt_df.ids.apply(lambda x: id2text_train[x])

# Same for the baseline results.
baseline_df = pd.read_csv('tables/qasper-train-t5-base.csv').sort_values(['ids']).reset_index()
baseline_df['input'] = baseline_df.ids.apply(lambda x: id2text_train[x])
# Expose the baseline score under f1_seg_1 — presumably to line up with the RMT columns; TODO confirm.
baseline_df['f1_seg_1'] = baseline_df.f1_seg_2

KeyError: '0038b073b7cca847033177024f9719c971692042'

In [None]:
# Score series to compare pairwise: RMT after segment 2, RMT after segment 1, and the baseline.
f1_cols = {
    'rm-t5-seg-2': rmt_df.f1_seg_1,
    'rmt-t5-seg-1': rmt_df.f1_seg_0,
    't5': baseline_df.f1_seg_1,
}
names = list(f1_cols)

# Cell (row, column) counts the examples on which `column` scores strictly higher than `row`.
win_counts = pd.DataFrame(
    {col: {row: (f1_cols[col] > f1_cols[row]).sum() for row in names} for col in names},
    index=names,
    columns=names,
)

# Turn the counts into fractions of all RMT rows.
comp_df = (win_counts / rmt_df.shape[0]).astype(float).round(2)
comp_df

In [None]:
# rmt_df.pred_seg_0.value_counts(), rmt_df.pred_seg_1.value_counts()

In [None]:
# Overlay the score histograms of the baseline and of both RMT segment outputs.
for frame, column in ((baseline_df, 'f1_seg_2'), (rmt_df, 'f1_seg_0'), (rmt_df, 'f1_seg_1')):
    getattr(frame, column).hist()
plt.legend(['baseline', 'seg 0', 'seg 1'])

In [None]:
# Fraction of examples whose f1_seg_0 score is exactly zero.
(rmt_df.f1_seg_0 == 0).mean()

In [None]:
# Mean per-example difference between the f1_seg_1 and f1_seg_0 scores.
(rmt_df.f1_seg_1 - rmt_df.f1_seg_0).mean()

In [None]:
# How often each target text occurs among examples with a positive f1_seg_1 score.
rmt_df[rmt_df.f1_seg_1 >  0].target_text.value_counts()

In [None]:
# How often each target text occurs among examples with a positive f1_seg_0 score.
rmt_df[rmt_df.f1_seg_0 >  0].target_text.value_counts()

#### Valid

In [None]:
# Load the saved validation results for RMT, order them by example id and attach the source text.
# NOTE(review): the paths point at 'qasper' tables while id2text_valid is built from the qmsum
# split loaded above — confirm which tables were intended.
rmt_df = pd.read_csv('tables/qasper-valid-rm-t5-1002-10.csv').sort_values(['ids']).reset_index()
rmt_df['input'] = rmt_df.ids.apply(lambda x: id2text_valid[x])

# Same for the baseline results.
baseline_df = pd.read_csv('tables/qasper-valid-t5-base.csv').sort_values(['ids']).reset_index()
baseline_df['input'] = baseline_df.ids.apply(lambda x: id2text_valid[x])
# Expose the baseline score under f1_seg_1 — presumably to line up with the RMT columns; TODO confirm.
baseline_df['f1_seg_1'] = baseline_df['f1_seg_2']

In [None]:
# Overlay the score histograms of the baseline and of both RMT segment outputs.
for frame, column in ((baseline_df, 'f1_seg_2'), (rmt_df, 'f1_seg_0'), (rmt_df, 'f1_seg_1')):
    getattr(frame, column).hist()
plt.legend(['baseline', 'seg 0', 'seg 1'])

In [None]:
# Count the examples where f1_seg_1 is below both f1_seg_0 and the baseline's f1_seg_1.
# The comparison with baseline_df is row-wise, so it assumes both frames list the
# same examples in the same order — TODO confirm.
mask = (rmt_df.f1_seg_1 < rmt_df.f1_seg_0) & (rmt_df.f1_seg_1 < baseline_df.f1_seg_1)
mask.sum()

In [None]:
# Mean score difference relative to f1_seg_0: first for f1_seg_1, then for the baseline.
(rmt_df.f1_seg_1 - rmt_df.f1_seg_0).mean(), (baseline_df.f1_seg_1 - rmt_df.f1_seg_0).mean()

In [None]:
# Mean scores: RMT f1_seg_0, RMT f1_seg_1, baseline f1_seg_1.
(rmt_df.f1_seg_0).mean(), (rmt_df.f1_seg_1).mean(), (baseline_df.f1_seg_1).mean()

In [None]:
# rmt_df[(rmt_df.f1_seg_1 - rmt_df.f1_seg_0) < 0]

In [None]:
# Frequency of each preds_seg_1 text among examples where f1_seg_1 exceeds f1_seg_0.
rmt_df[(rmt_df.f1_seg_1 - rmt_df.f1_seg_0) > 0].preds_seg_1.value_counts()

In [None]:
n_segments = 2
def split(text, n_segments=n_segments):
    """Cut a text into its leading question part plus `n_segments` token chunks.

    Uses the notebook-level `tokenizer` and `encode_plus_kwargs`.

    Args:
        text: the string to split.
        n_segments: number of token chunks to produce.

    Returns:
        A list whose first element is the part of `text` before the first '?',
        followed by the decoded text of each token chunk.
    """
    premise = text.split('?')[0]
    encoded = tokenizer.encode(text, **encode_plus_kwargs, add_special_tokens=False)
    # array_split allows token counts that are not a multiple of n_segments,
    # where np.split would raise ValueError; results are identical when divisible
    segments = np.array_split(np.array(encoded), n_segments)
    texts = [tokenizer.decode(s) for s in segments]

    return [premise] + texts

In [None]:
# Row-index threshold at which the text inspection loops below stop printing.
max_i = 1

### analysis table

In [None]:
# Score series to compare pairwise: RMT after segment 2, RMT after segment 1, and the baseline.
f1_cols = {
    'rm-t5-seg-2': rmt_df.f1_seg_1,
    'rmt-t5-seg-1': rmt_df.f1_seg_0,
    't5': baseline_df.f1_seg_1,
}
names = list(f1_cols)

# Cell (row, column) counts the examples on which `column` scores strictly higher than `row`.
win_counts = pd.DataFrame(
    {col: {row: (f1_cols[col] > f1_cols[row]).sum() for row in names} for col in names},
    index=names,
    columns=names,
)

# Turn the counts into fractions of all RMT rows.
comp_df = (win_counts / rmt_df.shape[0]).astype(float).round(2)
comp_df

In [None]:
# Show the pairwise comparison table again.
comp_df

### look at texts

In [None]:
# Share of each preds_seg_1 text among the examples where f1_seg_1 beats f1_seg_0.
rmt_df[rmt_df.f1_seg_1 > rmt_df.f1_seg_0].preds_seg_1.value_counts() / sum(rmt_df.f1_seg_1 > rmt_df.f1_seg_0)

In [None]:
# NOTE(review): scratch arithmetic with unexplained constants — document or remove.
300/17

In [None]:
# max_i = 10
print(f'Segment 1 better than segment 0: {sum((rmt_df.f1_seg_1 - rmt_df.f1_seg_0) > 0)}\n\n')
slice = rmt_df[(rmt_df.f1_seg_1 - rmt_df.f1_seg_0) > 0].reset_index()
for i, row in slice.iterrows():
    inp = row['input']
    spl = '\n\n'.join(split(row['input']))
    print(f"{spl}\nseg_0: {row['preds_seg_0']},\nseg_1: {row['preds_seg_1']}\ntarget: {row['target_text']}\n\n")
          
    if i > max_i:
        break

In [None]:
# Share of each preds_seg_0 text among the examples where f1_seg_0 beats f1_seg_1.
rmt_df[rmt_df.f1_seg_1 < rmt_df.f1_seg_0].preds_seg_0.value_counts() / sum(rmt_df.f1_seg_1 < rmt_df.f1_seg_0)

In [None]:
# max_i = 10
print(f'Segment 0 better than segment 1: {sum((rmt_df.f1_seg_1 - rmt_df.f1_seg_0) < 0)}\n\n')
slice = rmt_df[(rmt_df.f1_seg_1 - rmt_df.f1_seg_0) < 0].reset_index()
for i, row in slice.iterrows():
    inp = row['input']
    spl = '\n\n'.join(split(row['input']))
    print(f"{spl}\nseg_0: {row['preds_seg_0']},\nseg_1: {row['preds_seg_1']}\ntarget: {row['target_text']}\n\n")
          
    if i > max_i:
        break

In [None]:
# Share of each baseline prediction text among the examples where the baseline's
# f1_seg_1 beats the RMT f1_seg_1 (row-wise comparison of the two frames).
# rmt_df[baseline_df.f1_seg_1 > rmt_df.f1_seg_0].preds_seg_1.value_counts() / sum(baseline_df.f1_seg_1 > rmt_df.f1_seg_0)
baseline_df[baseline_df.f1_seg_1 > rmt_df.f1_seg_1].preds.value_counts() / sum(baseline_df.f1_seg_1 > rmt_df.f1_seg_1)

In [None]:
# max_i = 10
mask = (rmt_df.f1_seg_0 - baseline_df.f1_seg_1) > 0
print(f'Segment 0 better than baseline: {sum(mask)}\n\n')
slice = rmt_df[mask].reset_index()
slice_bl = baseline_df[mask].reset_index()
for i, row in slice.iterrows():
    inp = row['input']
    spl = '\n\n'.join(split(row['input']))
    # bl_pred = baseline_df[baseline_df.ids == row['ids']].preds.iloc[0]
    bl_pred = slice_bl.iloc[i].preds
    # print(slice.iloc[i])
    # print(slice_bl.iloc[i])
    print(f"{spl}\nseg_0: {row['preds_seg_0']},\nseg_1: {row['preds_seg_1']}\nbaseline: {bl_pred}\ntarget: {row['target_text']}\n\n")
          
    if i > max_i:
        break

In [None]:
max_i = 5
mask = (rmt_df.f1_seg_0 - baseline_df.f1_seg_1) < 0
print(f'Segment 0 worse than baseline: {sum(mask)}\n\n')
slice = rmt_df[mask].reset_index()
for i, row in slice.iterrows():
    inp = row['input']
    spl = '\n\n'.join(split(row['input']))
    bl_pred = baseline_df[baseline_df.ids == row['ids']].preds.iloc[0]
    # bl_pred = baseline_df[baseline_df.ids == row['ids']].preds
    print(f"{spl}\nseg_0: {row['preds_seg_0']},\nseg_1: {row['preds_seg_1']}\nbaseline: {bl_pred}\ntarget: {row['target_text']}\n\n")
          
    if i > max_i:
        break

In [None]:
max_i = 20
mask = (rmt_df.f1_seg_1 - baseline_df.f1_seg_1) > 0
print(f'Segment 1 better than baseline: {sum(mask)}\n\n')
slice = rmt_df[mask].reset_index()
for i, row in slice.iterrows():
    inp = row['input']
    spl = '\n\n'.join(split(row['input']))
    bl_pred = baseline_df[baseline_df.ids == row['ids']].preds.iloc[0]
    print(f"{spl}\nseg_0: {row['preds_seg_0']},\nseg_1: {row['preds_seg_1']}\nbaseline: {bl_pred}\ntarget: {row['target_text']}\n\n")
          
    if i > max_i:
        break