In [4]:
import pandas as pd
import pickle
import re
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [5]:
with open('../groundhog/data/threads_with_metas_3ut_aug_full.pkl', 'rb') as f:
    threads = pickle.load(f)

In [6]:
with open('../groundhog/data/threads_with_metas_3ut_aug_new.pkl', 'rb') as f:
    threads_new, _ = pickle.load(f)

In [7]:
len(threads), len(threads_new)

(39803, 22608)

In [8]:
threads += threads_new

In [9]:
sep_token = '</s>'

In [10]:
utt_set = set()
speakers_set = set()

In [11]:
def get_discourse_tokens(discourse_list):
    return [f'<u{discourse_list[1]+1}>', f'<to:u{discourse_list[0]+1}>', f'<{discourse_list[2]}>']


def get_aug_value(ut, speaker='<s1>'):
    return ' '.join([speaker] + get_discourse_tokens(ut['discourse']) + [ut['text']])


def get_aug_value_nodis(ut, speaker='<s1>'):
    return ' '.join([speaker] + get_discourse_tokens(ut['discourse'])[:-1] + [ut['text']])

def preproc_text(text, utt_set, speakers_set):
    utt_set |= set(re.findall(r'<u\d+>', text))
    speakers_set |= set(re.findall(r'<s\d+>', text))
    if type(text) == str:
        res = re.sub(r'\s+', ' ', str(text))
        if len(res.strip()) == 0:
            return 'unk'
        return res.strip()
    return 'unk'


def get_dialogue_instances(threads, utt_set, speakers_set):
    utter_covered = set() # кажду реплику генерим только один раз
    
    result = []
    for thr in tqdm(threads):
        try:
            speakers = {}

            for i, ut in enumerate(thr['dialogue']):
                speaker = ut['speaker']
                if speaker not in speakers:
                    speakers[speaker] = '<s' + str(len(speakers) + 1) + '>'

                if i >= 2:
                    if thr['id'] + '_' + ut['id'] not in utter_covered:
                        utter_covered.add(thr['id'] + '_' + ut['id'])

                        utter_dict = {
                            'thread_id': thr['id'],
                            'id': thr['id'] + '_' + ut['id'],
                            'history': f' {sep_token} '.join([get_aug_value_nodis(ut_his, speakers[ut_his['speaker']]) for
                                                             ut_his in thr['dialogue'][:i]] + [speakers[ut['speaker']]]),
                            'history_aug': f' {sep_token} '.join([get_aug_value(ut_his, speakers[ut_his['speaker']]) for
                                                                  ut_his in thr['dialogue'][:i]] +
                                                                 [' '.join(get_aug_value(ut, speakers[ut['speaker']]).split()[:3])]),
                            'response': ut['text'],
                            'response_aug': ' '.join(get_aug_value(ut, speakers[ut['speaker']]).split()[3:]),
                        }

                        for k in utter_dict:
                            try:
                                utter_dict[k] = preproc_text(utter_dict[k], utt_set, speakers_set)
                            except:
                                utter_dict[k] = 'unk'

                        if len(utter_dict['response']) > 3:
                            result.append(utter_dict)
        except:
            continue
                    
    return pd.DataFrame(result)

In [12]:
dialogue_df = get_dialogue_instances(threads, utt_set, speakers_set)

100%|██████████| 62411/62411 [00:22<00:00, 2715.81it/s]


In [13]:
dialogue_df.shape

(111202, 6)

In [14]:
# split by thread ids
train_threads, val_threads = train_test_split(list(dialogue_df['thread_id'].unique()), test_size=0.1, random_state=575)
train_df = dialogue_df[dialogue_df.thread_id.isin(train_threads)]
val_df = dialogue_df[dialogue_df.thread_id.isin(val_threads)]

In [15]:
train_df.shape, val_df.shape

((100621, 6), (10581, 6))

In [16]:
train_df.head()

Unnamed: 0,thread_id,id,history,history_aug,response,response_aug
0,t3_zm90dt,t3_zm90dt_j0agl5x,<s1> <u1> <to:u1> I don’t think Henry is going...,<s1> <u1> <to:u1> <init> I don’t think Henry i...,Yup. They burnt that bridge into ashes at this...,<answer> Yup. They burnt that bridge into ashe...
1,t3_zm90dt,t3_zm90dt_j0bhhx4,<s1> <u1> <to:u1> I don’t think Henry is going...,<s1> <u1> <to:u1> <init> I don’t think Henry i...,"What's james gonna say, that he completely fuc...","<question> What's james gonna say, that he com..."
2,t3_zm90dt,t3_zm90dt_j0c0y3j,<s1> <u1> <to:u1> I don’t think Henry is going...,<s1> <u1> <to:u1> <init> I don’t think Henry i...,The Rock did that. It was up to Gunn to make t...,<answer> The Rock did that. It was up to Gunn ...
3,t3_zm90dt,t3_zm90dt_j0zisqb,<s1> <u1> <to:u1> I don’t think Henry is going...,<s1> <u1> <to:u1> <init> I don’t think Henry i...,So many people putting so much power into The ...,<elaboration> So many people putting so much p...
4,t3_zm90dt,t3_zm90dt_j0zj4qg,<s1> <u1> <to:u1> I don’t think Henry is going...,<s1> <u1> <to:u1> <init> I don’t think Henry i...,And he didn’t even turn in a movie that made t...,<elaboration> And he didn’t even turn in a mov...


In [17]:
train_df['history_aug'].values[0]

"<s1> <u1> <to:u1> <init> I don’t think Henry is going to play a different DC character. His instagram post reads like he’s done with DC films, not that he’s going to be in something else just not as Superman. </s> <s2> <u2> <to:u1> <question> At this point hasn't WB burnt the bridge? Toying with the character for years he finally gets welcomed back and now he gets kicked out again?? </s> <s1> <u3> <to:u2>"

In [18]:
train_df.to_csv('data/train_structure_reddit.csv', sep='\t', index=False)
val_df.to_csv('data/val_structure_reddit.csv', sep='\t', index=False)

In [30]:
additional_special_tokens = ['<negativereaction>',
     '<other>',
     '<appreciation>',
     '<unk>',
     '<elaboration>',
     '<answer>',
     '<question>',
     '<humor>',
     '<announcement>',
     '<agreement>',
     '<disagreement>']

In [31]:
max_s, max_u = (40, 41)
for s in range(1, max_s+1):
    additional_special_tokens.append('<s' + str(s) + '>')
for u in range(1, max_u+1):
    additional_special_tokens.append('<u' + str(u) + '>')
for u in range(1, max_u+1):
    additional_special_tokens.append('<to:u' + str(u) + '>')

In [32]:
special_tokens_dict = {'additional_special_tokens': additional_special_tokens,
                         'bos_token': '<s>',
                         'eos_token': '</s>',
                         'unk_token': '<unk>',
                         'sep_token': '</s>',
                         'pad_token': '<pad>',
                         'cls_token': '<s>',
                         'mask_token': '<mask>'}

with open('data/special_tokens_map_reddit.pkl', 'wb') as f:
    pickle.dump(special_tokens_dict, f)

In [38]:
import numpy as np
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name_or_path = "facebook/bart-base"

tokenizer = BartTokenizer.from_pretrained(model_name_or_path)
model =  BartForConditionalGeneration.from_pretrained(model_name_or_path).to(device) # to check load

In [39]:
with open('data/special_tokens_map_reddit.pkl', 'rb') as f:
    special_tokens_dict = pickle.load(f)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [40]:
num_added_toks + tokenizer.vocab_size # COPY TO WEIGHTS IN MODELING.PY

50397

In [41]:
tokenizer.decode([50265, 50266, 50267, 50268, 50269, 50270, 50271, 50272, 50273, 50274])

'<negativereaction> <other> <appreciation> <elaboration> <answer> <question> <humor> <announcement> <agreement> <disagreement>'

## Train models

In [None]:
!CUDA_VISIBLE_DEVICES=0 python custom_bart_scripts_weights/run_summarization.py \
    --model_name_or_path="facebook/bart-base" \
    --train_file="data/train_structure_reddit.csv" \
    --validation_file="data/val_structure_reddit.csv" \
    --text_column="history_aug" \
    --summary_column="response_aug" \
    --max_source_length=1024 \
    --max_target_length=64 \
    --do_train \
    --do_eval \
    --per_device_train_batch_size=1 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --learning_rate=2e-5 \
    --class_weights=0. \
    --save_steps=80000 \
    --num_train_epochs=5 \
    --output_dir="checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_0_cp" \
    --overwrite_output_dir

Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
05/17/2023 11:18:50 - INFO - __main__ - Training/evaluation parameters Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
bf16=False,
bf16_full_eval=False,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_steps=None,
evaluation_strategy=IntervalStrategy.NO,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
generation_max_length=None,
generation_num_beams=None,
gradient_accumulation_steps=2,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_mo

[INFO|modeling_utils.py:1427] 2023-05-17 11:20:14,458 >> loading weights file https://huggingface.co/facebook/bart-base/resolve/main/pytorch_model.bin from cache at /home/aschernyavskiy/.cache/huggingface/transformers/486355ec722ef05fd480e999d4c763be56549ae930f6a3742ee721a5d2a05647.f2f355ad2775769afc60592b43a46d72ca548375e3a1d65f381a751e711cbadd
CUSTOM BART with class_weight=0.0
[INFO|modeling_utils.py:1694] 2023-05-17 11:20:22,223 >> All model checkpoint weights were used when initializing BartForConditionalGeneration.

[INFO|modeling_utils.py:1703] 2023-05-17 11:20:22,223 >> All the weights of BartForConditionalGeneration were initialized from the model checkpoint at facebook/bart-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use BartForConditionalGeneration for predictions without further training.
[INFO|tokenization_utils_base.py:888] 2023-05-17 11:20:22,264 >> Assigning ['<negativereaction>', '<other>', '<appreciation>', '<un

Running tokenizer on validation dataset:   0%|           | 0/11 [00:00<?, ?ba/s]05/17/2023 11:20:28 - INFO - datasets.arrow_dataset - Caching processed dataset at /home/aschernyavskiy/.cache/huggingface/datasets/csv/default-1543fcdae0d91a45/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e/cache-af6bc30ff5c33de1.arrow
Running tokenizer on validation dataset: 100%|██| 11/11 [00:21<00:00,  1.93s/ba]
[INFO|trainer.py:1244] 2023-05-17 11:20:56,931 >> ***** Running training *****
[INFO|trainer.py:1245] 2023-05-17 11:20:56,932 >>   Num examples = 100621
[INFO|trainer.py:1246] 2023-05-17 11:20:56,932 >>   Num Epochs = 5
[INFO|trainer.py:1247] 2023-05-17 11:20:56,932 >>   Instantaneous batch size per device = 1
[INFO|trainer.py:1248] 2023-05-17 11:20:56,932 >>   Total train batch size (w. parallel, distributed & accumulation) = 2
[INFO|trainer.py:1249] 2023-05-17 11:20:56,932 >>   Gradient Accumulation steps = 2
[INFO|trainer.py:1250] 2023-05-17 11:20:56,932 >>   Total opt

{'loss': 3.4611, 'learning_rate': 1.3321407274895648e-05, 'epoch': 1.67}        
{'loss': 3.4284, 'learning_rate': 1.3281653746770027e-05, 'epoch': 1.68}        
{'loss': 3.4582, 'learning_rate': 1.3241900218644406e-05, 'epoch': 1.69}        
{'loss': 3.4344, 'learning_rate': 1.3202146690518785e-05, 'epoch': 1.7}         
{'loss': 3.4175, 'learning_rate': 1.3162393162393164e-05, 'epoch': 1.71}        
{'loss': 3.3978, 'learning_rate': 1.3122639634267543e-05, 'epoch': 1.72}        
{'loss': 3.4509, 'learning_rate': 1.308288610614192e-05, 'epoch': 1.73}         
{'loss': 3.4422, 'learning_rate': 1.30431325780163e-05, 'epoch': 1.74}          
{'loss': 3.4426, 'learning_rate': 1.3003379049890678e-05, 'epoch': 1.75}        
{'loss': 3.4426, 'learning_rate': 1.2963625521765058e-05, 'epoch': 1.76}        
{'loss': 3.4182, 'learning_rate': 1.2923871993639437e-05, 'epoch': 1.77}        
{'loss': 3.4613, 'learning_rate': 1.2884118465513816e-05, 'epoch': 1.78}        
{'loss': 3.4216, 'learning_r

In [None]:
!CUDA_VISIBLE_DEVICES=0 python custom_bart_scripts_weights/run_summarization.py \
    --model_name_or_path="facebook/bart-base" \
    --train_file="data/train_structure_reddit.csv" \
    --validation_file="data/val_structure_reddit.csv" \
    --text_column="history_aug" \
    --summary_column="response_aug" \
    --max_source_length=1024 \
    --max_target_length=64 \
    --do_train \
    --do_eval \
    --per_device_train_batch_size=1 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --learning_rate=2e-5 \
    --class_weights=100. \
    --save_steps=80000 \
    --num_train_epochs=5 \
    --output_dir="checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_100" \
    --overwrite_output_dir

Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
05/15/2023 11:55:32 - INFO - __main__ - Training/evaluation parameters Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
bf16=False,
bf16_full_eval=False,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_steps=None,
evaluation_strategy=IntervalStrategy.NO,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
generation_max_length=None,
generation_num_beams=None,
gradient_accumulation_steps=2,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_mo

[INFO|modeling_utils.py:1427] 2023-05-15 11:56:56,370 >> loading weights file https://huggingface.co/facebook/bart-base/resolve/main/pytorch_model.bin from cache at /home/aschernyavskiy/.cache/huggingface/transformers/486355ec722ef05fd480e999d4c763be56549ae930f6a3742ee721a5d2a05647.f2f355ad2775769afc60592b43a46d72ca548375e3a1d65f381a751e711cbadd
CUSTOM BART with class_weight=100.0
[INFO|modeling_utils.py:1694] 2023-05-15 11:57:02,066 >> All model checkpoint weights were used when initializing BartForConditionalGeneration.

[INFO|modeling_utils.py:1703] 2023-05-15 11:57:02,066 >> All the weights of BartForConditionalGeneration were initialized from the model checkpoint at facebook/bart-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use BartForConditionalGeneration for predictions without further training.
[INFO|tokenization_utils_base.py:888] 2023-05-15 11:57:02,069 >> Assigning ['<negativereaction>', '<other>', '<appreciation>', '<

Running tokenizer on train dataset:   0%|               | 0/101 [00:00<?, ?ba/s]05/15/2023 11:57:06 - INFO - datasets.arrow_dataset - Caching processed dataset at /home/aschernyavskiy/.cache/huggingface/datasets/csv/default-1543fcdae0d91a45/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e/cache-7a30ee22fc6c6edf.arrow
Running tokenizer on train dataset: 100%|█████| 101/101 [03:05<00:00,  1.83s/ba]
Running tokenizer on validation dataset:   0%|           | 0/11 [00:00<?, ?ba/s]05/15/2023 12:00:12 - INFO - datasets.arrow_dataset - Caching processed dataset at /home/aschernyavskiy/.cache/huggingface/datasets/csv/default-1543fcdae0d91a45/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e/cache-7933951e04031a88.arrow
Running tokenizer on validation dataset: 100%|██| 11/11 [00:19<00:00,  1.77s/ba]
[INFO|trainer.py:1244] 2023-05-15 12:00:30,878 >> ***** Running training *****
[INFO|trainer.py:1245] 2023-05-15 12:00:30,881 >>   Num examples = 100621
[IN

{'loss': 2.1015, 'learning_rate': 1.3400914331146891e-05, 'epoch': 1.65}        
{'loss': 2.0688, 'learning_rate': 1.3361160803021268e-05, 'epoch': 1.66}        
{'loss': 2.1091, 'learning_rate': 1.3321407274895648e-05, 'epoch': 1.67}        
{'loss': 2.0806, 'learning_rate': 1.3281653746770027e-05, 'epoch': 1.68}        
{'loss': 2.1158, 'learning_rate': 1.3241900218644406e-05, 'epoch': 1.69}        
{'loss': 2.123, 'learning_rate': 1.3202146690518785e-05, 'epoch': 1.7}          
{'loss': 2.1195, 'learning_rate': 1.3162393162393164e-05, 'epoch': 1.71}        
{'loss': 2.064, 'learning_rate': 1.3122639634267543e-05, 'epoch': 1.72}         
{'loss': 2.1337, 'learning_rate': 1.308288610614192e-05, 'epoch': 1.73}         
{'loss': 2.1004, 'learning_rate': 1.30431325780163e-05, 'epoch': 1.74}          
{'loss': 2.0676, 'learning_rate': 1.3003379049890678e-05, 'epoch': 1.75}        
{'loss': 2.0996, 'learning_rate': 1.2963625521765058e-05, 'epoch': 1.76}        
{'loss': 2.1454, 'learning_r

In [None]:
!CUDA_VISIBLE_DEVICES=0 python custom_bart_scripts_weights/run_summarization.py \
    --model_name_or_path="facebook/bart-base" \
    --train_file="data/train_structure_reddit.csv" \
    --validation_file="data/val_structure_reddit.csv" \
    --text_column="history_aug" \
    --summary_column="response" \
    --max_source_length=1024 \
    --max_target_length=64 \
    --do_train \
    --do_eval \
    --per_device_train_batch_size=1 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --learning_rate=2e-5 \
    --class_weights=0. \
    --save_steps=80000 \
    --num_train_epochs=5 \
    --output_dir="checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norelut" \
    --overwrite_output_dir

In [None]:
!CUDA_VISIBLE_DEVICES=0 python custom_bart_scripts_weights/run_summarization.py \
    --model_name_or_path="facebook/bart-base" \
    --train_file="data/train_structure_reddit.csv" \
    --validation_file="data/val_structure_reddit.csv" \
    --text_column="history" \
    --summary_column="response" \
    --max_source_length=1024 \
    --max_target_length=64 \
    --do_train \
    --do_eval \
    --per_device_train_batch_size=1 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --learning_rate=2e-5 \
    --class_weights=0. \
    --save_steps=80000 \
    --num_train_epochs=5 \
    --output_dir="checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norels" \
    --overwrite_output_dir

## Test model

In [8]:
import pandas as pd
import pickle
import re
import string
from tqdm import tqdm

In [9]:
import warnings
warnings.filterwarnings("ignore")

import torch
import numpy as np
from transformers import BartForConditionalGeneration, BartTokenizer
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [10]:
#model_name_or_path = 'checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_100'
model_name_or_path = 'checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_0_cp'

In [11]:
tokenizer = BartTokenizer.from_pretrained(model_name_or_path)
model = BartForConditionalGeneration.from_pretrained(model_name_or_path).train(False).to(device)

In [12]:
def generate_top(text, model, tokenizer, num_beams=4,  max_source_len=1024, max_target_length=64, top_k=50, top_p=1):
    inputs = tokenizer([text], max_length=max_source_len, return_tensors="pt", truncation=True, padding = False).to(device)
    summary_ids = model.generate(inputs["input_ids"], do_sample=True,num_beams=num_beams,
                                 max_length=max_target_length, top_k=top_k, top_p=top_p)
    pred = tokenizer.batch_decode(summary_ids, clean_up_tokenization_spaces=False)[0]
    pred = re.sub(r'\s+', ' ', pred).replace('</s>', '').replace('<s>', '').strip()
    return pred

In [13]:
test_data = pd.read_csv("data/val_structure_reddit.csv", sep='\t')

In [14]:
X_test = test_data['history_aug'].values
y_test = test_data['response_aug'].values

In [15]:
k = 200
X_test[k], y_test[k]

("<s1> <u1> <to:u1> <init> People act like DC doesn’t make good animated movies, I’d love it if they at least gave us a conclusion, heck it might be even work better animated. </s> <s2> <u2> <to:u1> <unk> I'd watch the hell out of an animated wrap up of the Snyderverse. </s> <s3> <u3> <to:u1>",
 "<elaboration> TBH Superman/Batman: Apocalypse is a better film than most of the live action stuff DC has put out. Animated directed by Snyder would be an interesting way to go hell I'm still hoping the do Batman 89 animated with Michael Keaton and Michelle Pfieffer reprising their roles.")

In [16]:
generate_top(X_test[k], model, tokenizer, num_beams=1)

'<unk>We know they are done. But it will have been a very confusing 2 seasons.'

## Predict

In [None]:
import pandas as pd
import pickle
import re
import string
from tqdm import tqdm

In [None]:
import warnings
warnings.filterwarnings("ignore")

import torch
import numpy as np
from transformers import BartForConditionalGeneration, BartTokenizer
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#model_name_or_path = 'checkpoints/structure_custom_bart_convokit_bs_1_2_lr_2e5_ep_5_noisy_0.5'

In [None]:
test_data = pd.read_csv("data/val_structure_reddit.csv", sep='\t')

In [None]:
def generate_top(text, model, tokenizer, num_beams=4,  max_source_len=1024, max_target_length=64, top_k=50, top_p=1):
    inputs = tokenizer([text], max_length=max_source_len, return_tensors="pt", truncation=True, padding = False).to(device)
    summary_ids = model.generate(inputs["input_ids"], do_sample=True,num_beams=num_beams,
                                 max_length=max_target_length, top_k=top_k, top_p=top_p)
    pred = tokenizer.batch_decode(summary_ids, clean_up_tokenization_spaces=False)[0]
    pred = re.sub(r'\s+', ' ', pred).replace('</s>', '').replace('<s>', '').strip()
    return pred

In [34]:
name2col = {
    "checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_0_cp/checkpoint-160000": "history_aug",
    "checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_100": "history_aug",
    "checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norelut": "history_aug",
    "checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norels": "history",
}

In [None]:
for model_name_or_path in [
    "checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_0_cp",
    #"checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_100",
    #"checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norelut",
    #"checkpoint/structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norels",
]:
    tokenizer = BartTokenizer.from_pretrained(model_name_or_path)
    model = BartForConditionalGeneration.from_pretrained(model_name_or_path).train(False).to(device)
    
    X_test = test_data[name2col[model_name_or_path]].values
    #y_test = test_data['structure'].values
    
    preds = []
    for i, text in tqdm(enumerate(X_test), total=len(X_test)):
        #try:
        preds.append([text, generate_top(text, model, tokenizer, top_k=50, num_beams=1)])
#         except:
#             print(i)
#             preds.append([text, 'err'])
#             continue
            
    with open('predictions/{}.pkl'.format(model_name_or_path.replace('checkpoint/', '').replace('/checkpoint-', '-')), 'wb') as f:
        pickle.dump([X_test, preds], f)

## Calculate metrics

In [1]:
import os
import numpy as np
import pandas as pd
import pickle
from rouge import Rouge
import string

In [2]:
from nltk.translate.bleu_score import sentence_bleu
from nltk import word_tokenize

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
def calc_accuracy(preds):
    y_true = np.array([p[0] for p in preds])
    y_pred = np.array([p[1] for p in preds])
    return np.mean(y_true == y_pred)

In [36]:
results_paths = [
    "structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_100.pkl",
    "structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norelut.pkl",
    "structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norels.pkl",
]

In [37]:
name2col = {
    "structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_100.pkl": "history_aug",
    "structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norelut.pkl": "history_aug",
    "structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norels.pkl": "history",
}

In [38]:
test_data = pd.read_csv("data/val_structure_reddit.csv", sep='\t')
y_test = test_data['response_aug'].values

for res_path in results_paths:
    print(res_path)
    X_test = test_data[name2col[res_path]].values
    
    with open('predictions/' + res_path, 'rb') as f:
        _, preds = pickle.load(f)
    
    for i in range(len(preds)):
        if preds[i][1].startswith('<unk>') and preds[i][1][4] != ' ':
            preds[i][1] = '<unk> ' + preds[i][1][5:]
    
    if res_path.endswith('norelut.pkl') or res_path.endswith('norels.pkl'):
        for i in range(len(preds)):
            preds[i][1] = '<unk> ' + preds[i][1]
            
    print('No errors:', len([p for p in preds if p[1] != 'err']))
    
    relations = []
    cnt_err = 0
    for i in range(len(preds)):
        if preds[i][1] != 'err':
            pred_rel = preds[i][1].split(' ', 1)[0]
            relation = y_test[i].split(' ', 1)[0]
            relations.append([relation, pred_rel])
            
    print('Accuracy:', round(calc_accuracy(relations), 3))
    
    rouge = Rouge()
    hyps, refs = [], []
    for i in range(len(preds)):
        #try:
        if len(preds[i][1].split(' ', 1)) > 1:
            hyps.append(preds[i][1].split(' ', 1)[1])
        else:
            hyps.append('')
            
        if len(y_test[i].split(' ', 1)) > 1:
            refs.append(y_test[i].split(' ', 1)[1])
        else:
            refs.append('')
        #except:
        #    continue
    
    gen_ref = zip(hyps, refs)
    gen_ref = [_ for _ in gen_ref if not all(j in string.punctuation for j in _[1]) and not all(j in string.punctuation for j in _[0])]
    gens, refs  = zip(*gen_ref)
    
    rouge_res = rouge.get_scores(gens, refs, avg=True, ignore_empty=False)
    print()
    print('ROUGE-1:', round(100 * rouge_res['rouge-1']['f'], 2))
    print('ROUGE-2:', round(100 * rouge_res['rouge-2']['f'], 2))
    print('ROUGE-L:', round(100 * rouge_res['rouge-l']['f'], 2))
    
    mean_bleu = 0
    for gen, ref in zip(gens, refs):
        mean_bleu += sentence_bleu([word_tokenize(ref)], word_tokenize(gen), weights=[1,0,0,0])
    mean_bleu /= len(gens)
    print()
    print('BLEU-1:', round(100 * mean_bleu, 2))
    
    mean_bleu = 0
    for gen, ref in zip(gens, refs):
        mean_bleu += sentence_bleu([word_tokenize(ref)], word_tokenize(gen), weights=[1,1,0,0])
    mean_bleu /= len(gens)
    print('BLEU-2:', round(100 * mean_bleu, 2))
    
    print('\n' + '-'*50 + '\n')

structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_100.pkl
No errors: 10581
Accuracy: 0.441

ROUGE-1: 8.89
ROUGE-2: 0.58
ROUGE-L: 7.96

BLEU-1: 8.11
BLEU-2: 0.17

--------------------------------------------------

structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norelut.pkl
No errors: 10581
Accuracy: 0.155

ROUGE-1: 7.76
ROUGE-2: 0.54
ROUGE-L: 7.07

BLEU-1: 6.8
BLEU-2: 0.2

--------------------------------------------------

structure_custom_bart_reddit_bs_1_2_lr_2e5_ep_5_w_norels.pkl
No errors: 10581
Accuracy: 0.155

ROUGE-1: 7.45
ROUGE-2: 0.51
ROUGE-L: 6.77

BLEU-1: 6.47
BLEU-2: 0.17

--------------------------------------------------

