In [19]:
import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification, AutoConfig
from transformers import AutoTokenizer
from sklearn.metrics import confusion_matrix

import math
import os
import pandas as pd
import json
import datasets

from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

from torchviz import make_dot

In [2]:
from transformers import AutoModelForSequenceClassification
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')
from modeling_rmt import RMTEncoderForSequenceClassification
from modeling_rmt_enc_dec import RMTEncoderDecoderForConditionalGeneration

## RMT

In [3]:
# !ls /home/bulatov/bulatov/rmt_internship/finetune/contract_nli/bert-base-cased/lr1e-05_constant_with_warmup_adamw_wd1e-03_1452_mem25_sum_loss/run_3/model_best.pth
# !ls /home/bulatov/bulatov/rmt_internship/finetune/contract_nli/bert-base-cased/lr1e-05_linear_adamw_wd1e-03_968_mem25_sum_loss/run_3/

In [4]:
# cpt_path = "/cephfs/home/bulatov/bulatov/rmt_internship/finetune/contract_nli/t5-base/lr1e-05_linear_adamw_wd1e-03_972_mem25_sum_loss/run_1/model_best.pth"

model_name = 'bert-base-cased'

experiment_path = "/home/bulatov/bulatov/rmt_internship/finetune/contract_nli/bert-base-cased/lr1e-05_linear_adamw_wd1e-03_968_mem25_sum_loss/run_3/"
cpt_path = os.path.join(experiment_path, "model_best.pth")
config_path = os.path.join(experiment_path, "config.json")
cpt = torch.load(cpt_path, map_location='cpu')

with open(config_path, 'r') as f:
    d = json.load(f)

rmt = RMTEncoderForSequenceClassification.from_pretrained(model_name, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(model_name)

set_params_kwargs = {k:v for k,v in d.items() if k in rmt.set_params.__code__.co_varnames}
set_params_kwargs['segment_ordering'] = 'regular'
set_params_kwargs['inter_layer_memory'] = False
set_params_kwargs['tokenizer'] = tokenizer        

rmt.set_params(**set_params_kwargs)
rmt.load_state_dict(cpt['model_state_dict'])

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

<All keys matched successfully>

### Baseline

In [5]:
# model_name = 'bert-base-cased'

# experiment_path = "/home/bulatov/bulatov/runs/finetune/debug/contract_nli/bert-base-cased/lr1e-05_linear_adamw_wd1e-03_512_mem/run_1/"
# cpt_path = os.path.join(experiment_path, "model_best.pth")
# config_path = os.path.join(experiment_path, "config.json")
# cpt = torch.load(cpt_path, map_location='cpu')

# with open(config_path, 'r') as f:
#     d = json.load(f)

# baseline = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# baseline.load_state_dict(cpt['model_state_dict'])

### encoder. cnli

In [6]:
input_seq_len = d['input_seq_len']
target_seq_len = 2
batch_size = 2

device = torch.device(0)

encode_plus_kwargs = {'max_length': input_seq_len,
                              'truncation': True,
                              'padding': 'longest',
                              'pad_to_multiple_of': 1}
generate_kwargs = {}
labels_map = {'Contradiction': 0, 'Entailment': 1, 'Not mentioned': 2}
num_labels = len(labels_map)

def collate_fn(batch):
    # cut too long strings because they may slow down tokenization
    inputs = [b['input'][:input_seq_len * 10] for b in batch]
    labels = [b['output'][:target_seq_len * 10] for b in batch]
    features = tokenizer.batch_encode_plus(list(inputs), return_tensors='pt', **encode_plus_kwargs)
    labels = np.array([labels_map[t] for t in labels])
    features['labels'] = torch.from_numpy(labels)
    features['id'] = [b['id'] for b in batch]
    features['pid'] = [b['pid'] for b in batch]
    
    return features

In [7]:
dataset = datasets.load_dataset('tau/scrolls', 'contract_nli')
train_dataset = dataset['train']

train_sampler = RandomSampler(train_dataset,)
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

valid_dataset = dataset['validation']
valid_sampler = RandomSampler(valid_dataset)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/contract_nli/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
len(train_dataset['pid']), len(valid_dataset['pid'])

(7191, 1037)

### predictions from all segments

In [9]:
def __call__(self, input_ids, return_all_segments=False, **kwargs):
    memory = self.set_memory()
    segmented = self.pad_and_segment(input_ids)
    segmented = list(zip(*segmented))

    if self.segment_ordering in {'regular', 'last_memory_only'}:
        pass
    elif self.segment_ordering == 'reversed':
        segmented = segmented[::-1]
    elif self.segment_ordering == 'bidirectional':
        segmented = segmented + segmented[::-1][1:]
    elif self.segment_ordering == 'repeat_first':
        segmented = segmented + segmented[:1]
    else:
        raise ValueError(f'Unknown segment ordering: {self.segment_ordering}')

    self.memory_storage = {'num_mem_tokens': self.num_mem_tokens}
    outputs = []
    for seg_num, segment_data in enumerate(segmented):
        input_ids, attention_mask, token_type_ids = segment_data
        if memory.ndim == 2:
            memory = memory.repeat(input_ids.shape[0], 1, 1)
        if (self.bptt_depth > -1) and (len(segmented) - seg_num > self.bptt_depth): 
            memory = memory.detach()

        seg_kwargs = dict(**kwargs)
        if self.drop_empty_segments:
            non_empty_mask = [not torch.equal(input_ids[i], self.empty) for i in range(len(input_ids))]
            if sum(non_empty_mask) == 0:
                continue
            input_ids = input_ids[non_empty_mask]
            attention_mask = attention_mask[non_empty_mask]
            token_type_ids = token_type_ids[non_empty_mask]
            seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]

            inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids)
            inputs_embeds[:, 1:1+self.num_mem_tokens] = memory[non_empty_mask]
        else:
            inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids)
            inputs_embeds[:, 1:1+self.num_mem_tokens] = memory

        seg_kwargs['inputs_embeds'] = inputs_embeds
        seg_kwargs['attention_mask'] = attention_mask
        seg_kwargs['token_type_ids'] = token_type_ids

        out = self.model.forward(**seg_kwargs, output_hidden_states=True)
        outputs.append(out)

        if self.drop_empty_segments:
            memory[non_empty_mask] = out.hidden_states[-1][:, :self.num_mem_tokens]
            out['non_empty_mask'] = non_empty_mask
        else:
            memory = out.hidden_states[-1][:, :self.num_mem_tokens]

    for i, o in enumerate(outputs):
        out[f'loss_{i}'] = o['loss'].mean()

    if self.sum_loss:
        out['loss'] = torch.stack([o['loss'] for o in outputs]).sum(dim=-1)

    if return_all_segments:
        return out, outputs
    
    return out

### get predictions

In [10]:
import pandas as pd
def evaluate(output, sample):
    labels = sample['labels']
    logits = o['logits']
    preds = torch.argmax(logits, dim=1)
    correct_mask = preds == labels[output['non_empty_mask']]
    return correct_mask

In [11]:
train_dataset

Dataset({
    features: ['id', 'pid', 'input', 'output'],
    num_rows: 7191
})

In [12]:
id2label_train = dict(zip(train_dataset['id'], train_dataset['output']))
id2label_valid = dict(zip(valid_dataset['id'], valid_dataset['output']))

id2text_train = dict(zip(train_dataset['id'], train_dataset['input']))
id2text_valid = dict(zip(valid_dataset['id'], valid_dataset['input']))

In [20]:
sampler = RandomSampler(valid_dataset)
dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=sampler,
                                collate_fn=collate_fn, **kwargs)
gen = iter(dataloader)
sample = next(gen)

In [23]:
dataset['train']

Dataset({
    features: ['id', 'pid', 'input', 'output'],
    num_rows: 7191
})

In [14]:
id, pid = sample.pop('id'), sample.pop('pid')
out = rmt(**sample)

In [15]:
loss = out['loss'].mean()

In [16]:
from torchviz import make_dot

In [17]:
make_dot(loss, params = dict(rmt.named_parameters()))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


ExecutableNotFound: failed to execute PosixPath('dot'), make sure the Graphviz executables are on your systems' PATH

<graphviz.graphs.Digraph at 0x7f7dcc80aa00>

In [18]:
1/0

ZeroDivisionError: division by zero

In [None]:
# res_df

In [None]:
# res_df.to_csv('tables/cnli-train-bert.csv', index=False)
# res_df.to_csv('tables/cnli-valid-bert.csv', index=False)

## Interpret

#### Train

In [None]:
rmt_df = pd.read_csv('tables/cnli-train-rm-bert-968-25.csv')
baseline_df = pd.read_csv('tables/cnli-train-bert.csv')

In [None]:
# rmt_df.pred_seg_0.value_counts(), rmt_df.pred_seg_1.value_counts()

In [None]:
rmt_df['correct_seg_0'] = rmt_df.pred_seg_0 == rmt_df.labels_seg_0
rmt_df['correct_seg_1'] = rmt_df.pred_seg_1 == rmt_df.labels_seg_1

baseline_df['correct_seg'] = baseline_df.pred_seg == baseline_df.labels_seg

rmt_correct_ids_seg_0 = set(rmt_df[rmt_df.correct_seg_0].ids)
rmt_correct_ids_seg_1 = set(rmt_df[rmt_df.correct_seg_1].ids)
bl_correct_ids = set(baseline_df[baseline_df.correct_seg].ids)

In [None]:
rmt_lose_seg_0 = bl_correct_ids.difference(rmt_correct_ids_seg_0)
rmt_lose_seg_1 = bl_correct_ids.difference(rmt_correct_ids_seg_1)
baseline_lose = rmt_correct_ids_seg_1.difference(bl_correct_ids)

In [None]:
print(f'rmt loses: {len(rmt_lose_seg_1)}, baseline loses: {len(baseline_lose)}')

In [None]:
# baseline 

accuracy = (baseline_df.pred_seg == baseline_df.labels_seg).mean()
conf_mat = confusion_matrix(y_pred=baseline_df.pred_seg, y_true=baseline_df.labels_seg, )
conf_mat_normalized = (conf_mat / conf_mat.sum(axis=1).reshape(-1, 1)).round(2)
print(f'Accuracy: {accuracy}\n\nConfusion matrix: \n{conf_mat}\nNormalized: \n{conf_mat_normalized}')

In [None]:
# rmt

df = rmt_df 
for l in [0, 1]:
    accuracy = (df[f'pred_seg_{l}'] == df[f'labels_seg_{l}']).mean()
    conf_mat = confusion_matrix(y_pred=df[f'pred_seg_{l}'], y_true=df[f'labels_seg_{l}'])
    conf_mat_normalized = (conf_mat / conf_mat.sum(axis=1).reshape(-1, 1)).round(2)
    print(f'\nSegment {l}\nAccuracy: {accuracy}\n\nConfusion matrix: \n{conf_mat}\nNormalized: \n{conf_mat_normalized}')

#### Valid

In [None]:
rmt_df = pd.read_csv('tables/cnli-valid-rm-bert-968-25.csv')
baseline_df = pd.read_csv('tables/cnli-valid-bert.csv')

In [None]:
rmt_df['correct_seg_0'] = rmt_df.pred_seg_0 == rmt_df.labels_seg_0
rmt_df['correct_seg_1'] = rmt_df.pred_seg_1 == rmt_df.labels_seg_1

baseline_df['correct_seg'] = baseline_df.pred_seg == baseline_df.labels_seg

rmt_correct_ids_seg_0 = set(rmt_df[rmt_df.correct_seg_0].ids)
rmt_correct_ids_seg_1 = set(rmt_df[rmt_df.correct_seg_1].ids)
bl_correct_ids = set(baseline_df[baseline_df.correct_seg].ids)

rmt_lose_seg_0 = bl_correct_ids.difference(rmt_correct_ids_seg_0)
rmt_lose_seg_1 = bl_correct_ids.difference(rmt_correct_ids_seg_1)
baseline_lose = rmt_correct_ids_seg_1.difference(bl_correct_ids)
baseline_lose_seg_0 = rmt_correct_ids_seg_0.difference(bl_correct_ids)

print(f'rmt loses: seg_0 - {len(rmt_lose_seg_0)}, seg_1 - {len(rmt_lose_seg_1)}, baseline loses: {len(baseline_lose)} (to seg 0: {len(baseline_lose_seg_0)})')

In [None]:
# baseline 

accuracy = (baseline_df.pred_seg == baseline_df.labels_seg).mean()
conf_mat = confusion_matrix(y_pred=baseline_df.pred_seg, y_true=baseline_df.labels_seg, )
conf_mat_normalized = (conf_mat / conf_mat.sum(axis=1).reshape(-1, 1)).round(2)
print(f'Accuracy: {accuracy}\n\nConfusion matrix: \n{conf_mat}\nNormalized: \n{conf_mat_normalized}')

In [None]:
# rmt

df = rmt_df 
for l in [0, 1]:
    accuracy = (df[f'pred_seg_{l}'] == df[f'labels_seg_{l}']).mean()
    conf_mat = confusion_matrix(y_pred=df[f'pred_seg_{l}'], y_true=df[f'labels_seg_{l}'])
    conf_mat_normalized = (conf_mat / conf_mat.sum(axis=1).reshape(-1, 1)).round(2)
    print(f'\nSegment {l}\nAccuracy: {accuracy}\n\nConfusion matrix: \n{conf_mat}\nNormalized: \n{conf_mat_normalized}')

In [None]:
pd.DataFrame(conf_mat, index=labels_map.keys(), columns=labels_map.keys())

In [None]:
# rmt

df = rmt_df[rmt_df.ids.isin(rmt_lose_seg_1)]
for l in [0, ]:
    accuracy = (df[f'pred_seg_{l}'] == df[f'labels_seg_{l}']).mean()
    conf_mat = confusion_matrix(y_pred=df[f'pred_seg_{l}'], y_true=df[f'labels_seg_{l}'])
    conf_mat_normalized = (conf_mat / conf_mat.sum(axis=1).reshape(-1, 1)).round(2)
    print(f'\nSegment {l}\nAccuracy: {accuracy}\n\nConfusion matrix: \n{conf_mat}\nNormalized: \n{conf_mat_normalized}')

In [None]:
pd.DataFrame(conf_mat, index=labels_map.keys(), columns=labels_map.keys())

In [None]:
# rmt

df = rmt_df[rmt_df.ids.isin(rmt_lose_seg_0)]
for l in [0, 1]:
    accuracy = (df[f'pred_seg_{l}'] == df[f'labels_seg_{l}']).mean().round(2)
    conf_mat = confusion_matrix(y_pred=df[f'pred_seg_{l}'], y_true=df[f'labels_seg_{l}'])
    conf_mat_normalized = (conf_mat / conf_mat.sum(axis=1).reshape(-1, 1)).round(2)
    print(f'\nSegment {l}\nAccuracy: {accuracy}\n\nConfusion matrix: \n{conf_mat}\nNormalized: \n{conf_mat_normalized}')

In [None]:
# baseline 
df = baseline_df[baseline_df.ids.isin(baseline_lose)]
accuracy = (df.pred_seg == df.labels_seg).mean()
conf_mat = confusion_matrix(y_pred=df.pred_seg, y_true=df.labels_seg, )
conf_mat_normalized = (conf_mat / conf_mat.sum(axis=1).reshape(-1, 1)).round(2)
print(f'Accuracy: {accuracy}\n\nConfusion matrix: \n{conf_mat}\nNormalized: \n{conf_mat_normalized}')

In [None]:
pd.DataFrame(conf_mat, index=labels_map.keys(), columns=labels_map.keys())

In [None]:
# print('What label model distinguishes worse than its rival?\n RMT-seg-1, RMT-seg-0, baseline')
# rmt_df[rmt_df.ids.isin(rmt_lose_seg_1)].labels_seg_1.value_counts()/rmt_df.labels_seg_1.value_counts(), rmt_df[rmt_df.ids.isin(rmt_lose_seg_0)].labels_seg_0.value_counts()/rmt_df.labels_seg_0.value_counts(), baseline_df[baseline_df.ids.isin(baseline_lose)].labels_seg.value_counts()/baseline_df.labels_seg.value_counts()

In [None]:
# tokenizer.encode(text, **encode_plus_kwargs)
n_segments = 2
def split(text, n_segments=n_segments):
    premise = text.split('.')[0]
    encoded = tokenizer.encode(text, **encode_plus_kwargs, add_special_tokens=False)
    segments = np.split(np.array(encoded), n_segments)    
    texts = [tokenizer.decode(s) for s in segments]
    
    
    return [premise] + texts

In [None]:
i = 10

idx = list(rmt_lose_seg_1)[i]
label = id2label_valid[idx] 
rmt_preds = ', '.join(rmt_df[rmt_df.ids == idx][['pred_seg_0', 'pred_seg_1']].astype(str).values[0])
baseline_pred = baseline_df[baseline_df.ids == idx].pred_seg.values[0]
print(f'idx: {idx}\nlabel: {label, labels_map[label]}\nrmt predictions: {rmt_preds}\nbaseline_prediction: {baseline_pred}\n\n')

text = id2text_valid[idx]
texts = split(text)

print('\n\n\n'.join(texts))

In [None]:
i = 4

idx = list(baseline_lose)[i]
label = id2label_valid[idx] 
rmt_preds = ', '.join(rmt_df[rmt_df.ids == idx][['pred_seg_0', 'pred_seg_1']].astype(str).values[0])
baseline_pred = baseline_df[baseline_df.ids == idx].pred_seg.values[0]
print(f'idx: {idx}\nlabel: {label, labels_map[label]}\nrmt predictions: {rmt_preds}\nbaseline_prediction: {baseline_pred}\n\n')

text = id2text_valid[idx]
texts = split(text)

print('\n\n\n'.join(texts))

### selective

In [None]:


i = 2

ids = rmt_df[(rmt_df.pred_seg_0 == rmt_df.labels_seg_0) & ( rmt_df.pred_seg_1 != rmt_df.labels_seg_1)].ids.values
idx = ids[i]
label = id2label_valid[idx] 
rmt_preds = ', '.join(rmt_df[rmt_df.ids == idx][['pred_seg_0', 'pred_seg_1']].astype(str).values[0])
baseline_pred = baseline_df[baseline_df.ids == idx].pred_seg.values[0]
print(f'idx: {idx}\nlabel: {label, labels_map[label]}\nrmt predictions: {rmt_preds}\nbaseline_prediction: {baseline_pred}\n\n')

text = id2text_valid[idx]
texts = split(text)

print('\n\n\n'.join(texts))

In [None]:
rmt_df[(rmt_df.pred_seg_0 == rmt_df.labels_seg_0) & ( rmt_df.pred_seg_1 != rmt_df.labels_seg_1)].labels_seg_1.value_counts()

mmaybe finetune only memory weights
use mem outputs to decode previous segment with a separate decoder
orr pool answer to task not from cls but from mem tokens
check if  gradient flows through memory
decode using conccatenated memory, not last segment
train baseline without question on qasper

how do we improve remembering using memory