In [1]:
from convokit import Corpus, download
from tqdm import tqdm
import pandas as pd
import pickle
import re
from sklearn.model_selection import train_test_split

In [None]:
corpus = Corpus(filename=download("reddit-coarse-discourse-corpus"))

In [None]:
conversation_ids = corpus.get_conversation_ids()

In [None]:
corpus.get_conversations_dataframe().head()

In [None]:
len(conversation_ids)

In [None]:
utter_df = corpus.get_utterances_dataframe()

In [None]:
utter_df['text_len'] = utter_df.text.apply(len)

In [None]:
utter_df['text_len'].median()

In [11]:
utter_df.to_pickle('data/convokit/utter_df.pkl')

In [12]:
max_s, max_u = 0, 0
for conv_id in tqdm(conversation_ids):
    conv_data = utter_df[utter_df.conversation_id == conv_id]
    max_s = max(max_s, len(conv_data.speaker.unique()))
    max_u = max(max_u, conv_data.shape[0])

100%|██████████| 9483/9483 [02:22<00:00, 66.55it/s]


In [13]:
max_s, max_u

(40, 41)

In [2]:
with open('data/convokit/utter_df.pkl', 'rb') as f:
    utter_df = pickle.load(f)

In [3]:
def construct_raw_data(conv_df):
    speakers = {}
    utters = {}
    result = []
    for id, row in zip(conv_df.index, conv_df.values):
        utters[id] = 'u' + str(len(utters) + 1)
        speaker = row[0]
        if speaker not in speakers:
            speakers[speaker] = 's' + str(len(speakers) + 1)
        prev_id = row[1]
        item = [utters[id], speakers[speaker], 'unk', 'bos', row[3]]  # unk - none; bos - для начала
        if row[2] != None:
            item[2] = row[2]
        if prev_id != None:
            if prev_id not in utters:
                prev_id = conv_df.index[0]
            item[3] = utters[prev_id]
        for i in range(len(item) - 1):
            item[i] = '<' + item[i] + '>'
        result.append(item)
    return result


def construct_generation_examples(utter_df, conv_id):
    conv_data = utter_df[utter_df.conversation_id == conv_id][['speaker', 'reply_to', 'meta.majority_type', 'text']]
    train_examples = []
    preproc_utter = construct_raw_data(conv_data)
    for k in range(1, len(preproc_utter) - 1):
        cur_context = ' </s> '.join([' '.join(item) for item in preproc_utter[:k]])
        u_text = [el[-1] for el in preproc_utter[:k] if el[0] == preproc_utter[k][3]]
        u_text = u_text[0] if len(u_text) > 0 else ''
        cur_context += f' </s> {preproc_utter[k][1]} {preproc_utter[k][3]} {u_text}'
        cur_context = re.sub(r'\s+', ' ', cur_context)
        cur_utter = ' '.join([preproc_utter[k][2], preproc_utter[k][4]])
        train_examples.append([cur_context, cur_utter])
    return train_examples

In [4]:
conv_id = 't3_1yjwii' #'t3_1cduyx'
conv_data = utter_df[utter_df.conversation_id == conv_id][['speaker', 'reply_to', 'meta.majority_type', 'text']]

In [5]:
construct_raw_data(conv_data)

[['<u1>',
  '<s1>',
  '<announcement>',
  '<bos>',
  '[Universal Rules](http://www.reddit.com/r/uhccourtroom/wiki/banguidelines)\n\n[Universal Ban List](https://docs.google.com/spreadsheet/ccc?key=0AjACyg1Jc3_GdEhqWU5PTEVHZDVLYWphd2JfaEZXd2c#gid=0)\n\n**Notes:**\n\n---\n\n**IP:** 31.3.251.181\n\n**Version:** 1.7.4\n\n**Whitelist off:** 15 minutes before start\n\n**Game start:** 14:00 UTC\n\n**Game length:** 1.5 hours\n\n**Endgame:** Meetup at 0,0\n\n**Player Slots:** 30\n\n**Gamemode/Scenario:** FFA\n\n**PvP/iPvP:** 2nd day\n\n**Stealing:** Yes\n\n**Stalking:** No\n\n**Towering:** No\n\n**GHeads:** Yes\n\n**Absorption:** Yes\n\n**Nether:** On\n\n**Allies:** No\n\n**Additional Info:**\n\n---\n\n**Server Info:**\n\n**Ram:** 2.5GB\n\n**Located in:** London\n\n**Slots:** 30'],
 ['<u2>', '<s2>', '<unk>', '<u1>', "I'll be there^h^y^p^e"],
 ['<u3>', '<s3>', '<unk>', '<u1>', 'finally a game with nether!'],
 ['<u4>',
  '<s4>',
  '<elaboration>',
  '<u1>',
  'I dont like doing solos but ill give

In [6]:
conv_data

Unnamed: 0_level_0,speaker,reply_to,meta.majority_type,text
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
t3_1yjwii,TheQyet,,announcement,[Universal Rules](http://www.reddit.com/r/uhcc...
t1_cflroww,Azye,t3_1yjwii,,I'll be there^h^y^p^e
t1_cfls9du,Andibadia,t3_1yjwii,,finally a game with nether!
t1_cflrv4g,Djydoesmc,t3_1yjwii,elaboration,I dont like doing solos but ill give it a try ...
t1_cfls3gt,DarkAngelKing2,t3_1yjwii,question,Is pigmen spawning enabled?
t1_cfls487,TheQyet,t1_cfls3gt,answer,Pigmen will spawn in the nether :)
t1_cfls3j7,Enderdoood,t3_1yjwii,question,When whitelist is off?
t1_cfls6j9,Djydoesmc,t3_1yjwii,question,allies allowed?
t1_cfls937,TheQyet,t1_cfls6j9,answer,No allies
t1_cflsa6t,JoeStar1000,t3_1yjwii,appreciation,I'll be there too :D Thanks in advance for hos...


In [7]:
train_examples = construct_generation_examples(utter_df, conv_id)

In [8]:
train_examples

[['<u1> <s1> <announcement> <bos> [Universal Rules](http://www.reddit.com/r/uhccourtroom/wiki/banguidelines) [Universal Ban List](https://docs.google.com/spreadsheet/ccc?key=0AjACyg1Jc3_GdEhqWU5PTEVHZDVLYWphd2JfaEZXd2c#gid=0) **Notes:** --- **IP:** 31.3.251.181 **Version:** 1.7.4 **Whitelist off:** 15 minutes before start **Game start:** 14:00 UTC **Game length:** 1.5 hours **Endgame:** Meetup at 0,0 **Player Slots:** 30 **Gamemode/Scenario:** FFA **PvP/iPvP:** 2nd day **Stealing:** Yes **Stalking:** No **Towering:** No **GHeads:** Yes **Absorption:** Yes **Nether:** On **Allies:** No **Additional Info:** --- **Server Info:** **Ram:** 2.5GB **Located in:** London **Slots:** 30 </s> <s2> <u1> [Universal Rules](http://www.reddit.com/r/uhccourtroom/wiki/banguidelines) [Universal Ban List](https://docs.google.com/spreadsheet/ccc?key=0AjACyg1Jc3_GdEhqWU5PTEVHZDVLYWphd2JfaEZXd2c#gid=0) **Notes:** --- **IP:** 31.3.251.181 **Version:** 1.7.4 **Whitelist off:** 15 minutes before start **Game st

In [5]:
conversation_ids = list(set(utter_df.conversation_id.values))

In [6]:
train_ids, test_ids = train_test_split(conversation_ids, test_size=0.15, random_state=5757)

In [8]:
def remove_discourse_tokens(text):
    for item in ['<negativereaction>',
     '<other>',
     '<appreciation>',
     '<unk>',
     '<elaboration>',
     '<answer>',
     '<question>',
     '<humor>',
     '<announcement>',
     '<agreement>',
     '<disagreement>']:
        text = text.replace(item, '')
    return text


def construct_bart_input(data, save_path, col1='document', col2='summary', drop_rels_context=False, drop_rels_utter=False,
                         repl_random=True, repl_prob=0.15):
    df = pd.DataFrame()
    
    if drop_rels_context:
        df[col1] = [re.sub(r'\s+', ' ', remove_discourse_tokens(el[0])) for el in data]
    else:
        df[col1] = [re.sub(r'\s+', ' ', el[0]) for el in data]
    if drop_rels_utter:
        df[col2] = [re.sub(r'\s+', ' ', remove_discourse_tokens(el[1])) for el in data]
    else:
        df[col2] = [re.sub(r'\s+', ' ', el[1]) for el in data]
    df.to_csv(save_path, sep='\t', index=False)

In [44]:
X_train = []
for conv_id in tqdm(train_ids):
    X_train.extend(construct_generation_examples(utter_df, conv_id))

100%|██████████| 8060/8060 [01:15<00:00, 107.42it/s]


In [45]:
X_test = []
for conv_id in tqdm(test_ids):
    X_test.extend(construct_generation_examples(utter_df, conv_id))

100%|██████████| 1423/1423 [00:14<00:00, 96.35it/s] 


In [18]:
construct_bart_input(X_train, 'data/train_structure_convokit.csv', col1='context', col2='structure',
                     drop_rels_context=False, drop_rels_utter=False)

In [None]:
construct_bart_input(X_test, 'data/val_structure_convokit.csv', col1='context', col2='structure',
                     drop_rels_context=False, drop_rels_utter=False)

In [21]:
construct_bart_input(X_train, 'data/train_structure_convokit_norelut.csv', col1='context', col2='structure',
                     drop_rels_context=False, drop_rels_utter=True)

In [22]:
construct_bart_input(X_test, 'data/val_structure_convokit_norelut.csv', col1='context', col2='structure',
                    drop_rels_context=False, drop_rels_utter=True)

In [23]:
construct_bart_input(X_train, 'data/train_structure_convokit_norels.csv', col1='context', col2='structure',
                     drop_rels_context=True, drop_rels_utter=True)

In [24]:
construct_bart_input(X_test, 'data/val_structure_convokit_norels.csv', col1='context', col2='structure',
                    drop_rels_context=True, drop_rels_utter=True)

## Calculate source & target lengths

In [None]:
import numpy as np
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name_or_path = "facebook/bart-base"

tokenizer = BartTokenizer.from_pretrained(model_name_or_path)
model =  BartForConditionalGeneration.from_pretrained(model_name_or_path).to(device) # to check load

In [None]:
additional_special_tokens = ['<negativereaction>',
     '<other>',
     '<appreciation>',
     '<unk>',
     '<elaboration>',
     '<answer>',
     '<question>',
     '<humor>',
     '<announcement>',
     '<agreement>',
     '<disagreement>']

In [None]:
max_s, max_u = (40, 41)
for s in range(1, max_s+1):
    additional_special_tokens.append('<s' + str(s) + '>')
for u in range(1, max_u+1):
    additional_special_tokens.append('<u' + str(u) + '>')

In [89]:
special_tokens_dict = {'additional_special_tokens': additional_special_tokens,
                         'bos_token': '<s>',
                         'eos_token': '</s>',
                         'unk_token': '<unk>',
                         'sep_token': '</s>',
                         'pad_token': '<pad>',
                         'cls_token': '<s>',
                         'mask_token': '<mask>'}

with open('data/special_tokens_map_convokit.pkl', 'wb') as f:
    pickle.dump(special_tokens_dict, f)

In [90]:
with open('data/special_tokens_map_convokit.pkl', 'rb') as f:
    special_tokens_dict = pickle.load(f)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [91]:
tokenizer.vocab_size

50265

In [95]:
num_added_toks + tokenizer.vocab_size # COPY TO WEIGHTS IN MODELING.PY

50356

In [96]:
tokenizer.decode([50265, 50266, 50267, 50268, 50269, 50270, 50271, 50272, 50273, 50274])

'<negativereaction> <other> <appreciation> <elaboration> <answer> <question> <humor> <announcement> <agreement> <disagreement>'

In [None]:
num_tokens_text = []
num_tokens_summ = []
for record in tqdm(X_train):
    num_tokens_text.append(len(tokenizer.encode(record[0])))
    num_tokens_summ.append(len(tokenizer.encode(record[1])))

In [85]:
np.mean(num_tokens_text), np.median(num_tokens_text), np.quantile(num_tokens_text, 0.75)

(887.4763411564917, 575.0, 1166.0)

In [86]:
np.mean(num_tokens_summ), np.median(num_tokens_summ), np.quantile(num_tokens_summ, 0.75)

(57.10955272385188, 32.0, 64.0)

## Train Structure Generator

In [None]:
# change special tokens map path in run_summarization.py
!CUDA_VISIBLE_DEVICES=0 python custom_bart_scripts_weights/run_summarization.py \
    --model_name_or_path="facebook/bart-base" \
    --train_file="data/train_structure_convokit.csv" \
    --validation_file="data/val_structure_convokit.csv" \
    --text_column="context" \
    --summary_column="structure" \
    --max_source_length=1024 \
    --max_target_length=64 \
    --do_train \
    --do_eval \
    --per_device_train_batch_size=1 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --learning_rate=2e-5 \
    --class_weights=30 \
    --save_steps=80000 \
    --num_train_epochs=5 \
    --output_dir="checkpoints/structure_custom_bart_convokit_bs_1_2_lr_2e5_ep_5_w_30" \
    --overwrite_output_dir

Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
06/18/2022 15:56:24 - INFO - __main__ - Training/evaluation parameters Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
bf16=False,
bf16_full_eval=False,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_steps=None,
evaluation_strategy=IntervalStrategy.NO,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
generation_max_length=None,
generation_num_beams=None,
gradient_accumulation_steps=2,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_mo

[INFO|modeling_utils.py:1427] 2022-06-18 15:56:31,863 >> loading weights file https://huggingface.co/facebook/bart-base/resolve/main/pytorch_model.bin from cache at /home/aschernyavskiy/.cache/huggingface/transformers/486355ec722ef05fd480e999d4c763be56549ae930f6a3742ee721a5d2a05647.f2f355ad2775769afc60592b43a46d72ca548375e3a1d65f381a751e711cbadd
CUSTOM BART with class_weight=30.0
[INFO|modeling_utils.py:1694] 2022-06-18 15:56:51,171 >> All model checkpoint weights were used when initializing BartForConditionalGeneration.

[INFO|modeling_utils.py:1703] 2022-06-18 15:56:51,171 >> All the weights of BartForConditionalGeneration were initialized from the model checkpoint at facebook/bart-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use BartForConditionalGeneration for predictions without further training.
[INFO|tokenization_utils_base.py:888] 2022-06-18 15:56:51,191 >> Assigning ['<negativereaction>', '<other>', '<appreciation>', '<u

[INFO|trainer.py:1244] 2022-06-18 15:56:57,589 >> ***** Running training *****
[INFO|trainer.py:1245] 2022-06-18 15:56:57,589 >>   Num examples = 81984
[INFO|trainer.py:1246] 2022-06-18 15:56:57,589 >>   Num Epochs = 5
[INFO|trainer.py:1247] 2022-06-18 15:56:57,589 >>   Instantaneous batch size per device = 1
[INFO|trainer.py:1248] 2022-06-18 15:56:57,589 >>   Total train batch size (w. parallel, distributed & accumulation) = 2
[INFO|trainer.py:1249] 2022-06-18 15:56:57,590 >>   Gradient Accumulation steps = 2
[INFO|trainer.py:1250] 2022-06-18 15:56:57,590 >>   Total optimization steps = 204960
{'loss': 3.5173, 'learning_rate': 1.9951209992193602e-05, 'epoch': 0.01}        
{'loss': 3.1564, 'learning_rate': 1.99024199843872e-05, 'epoch': 0.02}          
{'loss': 3.1401, 'learning_rate': 1.9853629976580797e-05, 'epoch': 0.04}        
{'loss': 3.0509, 'learning_rate': 1.9804839968774398e-05, 'epoch': 0.05}        
{'loss': 3.1103, 'learning_rate': 1.9756049960967995e-05, 'epoch': 0.06}  

{'loss': 2.4582, 'learning_rate': 1.175448868071819e-05, 'epoch': 2.06}         
{'loss': 2.4348, 'learning_rate': 1.170569867291179e-05, 'epoch': 2.07}         
{'loss': 2.3748, 'learning_rate': 1.1656908665105387e-05, 'epoch': 2.09}        
{'loss': 2.4667, 'learning_rate': 1.1608118657298986e-05, 'epoch': 2.1}         
{'loss': 2.457, 'learning_rate': 1.1559328649492585e-05, 'epoch': 2.11}         
{'loss': 2.4596, 'learning_rate': 1.1510538641686184e-05, 'epoch': 2.12}        
{'loss': 2.4409, 'learning_rate': 1.1461748633879783e-05, 'epoch': 2.13}        
{'loss': 2.4717, 'learning_rate': 1.141295862607338e-05, 'epoch': 2.15}         
{'loss': 2.3646, 'learning_rate': 1.136416861826698e-05, 'epoch': 2.16}         
{'loss': 2.5379, 'learning_rate': 1.1315378610460579e-05, 'epoch': 2.17}        
{'loss': 2.4362, 'learning_rate': 1.1266588602654178e-05, 'epoch': 2.18}        
{'loss': 2.415, 'learning_rate': 1.1217798594847775e-05, 'epoch': 2.2}          
{'loss': 2.4194, 'learning_r

{'loss': 2.2312, 'learning_rate': 2.6307572209211556e-06, 'epoch': 4.34}        
{'loss': 2.2005, 'learning_rate': 2.5819672131147543e-06, 'epoch': 4.35}        
{'loss': 2.1928, 'learning_rate': 2.533177205308353e-06, 'epoch': 4.37}         
{'loss': 2.2766, 'learning_rate': 2.4843871975019516e-06, 'epoch': 4.38}        
{'loss': 2.2493, 'learning_rate': 2.4355971896955503e-06, 'epoch': 4.39}        
{'loss': 2.2079, 'learning_rate': 2.386807181889149e-06, 'epoch': 4.4}          
{'loss': 2.2284, 'learning_rate': 2.338017174082748e-06, 'epoch': 4.42}         
{'loss': 2.2553, 'learning_rate': 2.2892271662763467e-06, 'epoch': 4.43}        
{'loss': 2.2131, 'learning_rate': 2.2404371584699454e-06, 'epoch': 4.44}        
{'loss': 2.2212, 'learning_rate': 2.191647150663544e-06, 'epoch': 4.45}         
{'loss': 2.2558, 'learning_rate': 2.1428571428571427e-06, 'epoch': 4.46}        
{'loss': 2.2216, 'learning_rate': 2.094067135050742e-06, 'epoch': 4.48}         
{'loss': 2.2236, 'learning_r

## Test model

In [1]:
import pandas as pd
import pickle
import re
import string
from tqdm import tqdm

In [2]:
import warnings
warnings.filterwarnings("ignore")

import torch
import numpy as np
from transformers import BartForConditionalGeneration, BartTokenizer
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#model_name_or_path = 'checkpoints/structure_custom_bart_convokit_bs_1_2_lr_2e5_ep_5_noisy_0.5'

In [3]:
model_name_or_path = 'checkpoints/structure_custom_bart_convokit_bs_1_2_lr_2e5_ep_5_w_100_v2'

In [None]:
tokenizer = BartTokenizer.from_pretrained(model_name_or_path)
model = BartForConditionalGeneration.from_pretrained(model_name_or_path).train(False).to(device)

In [None]:
def generate_top(text, num_beams=4,  max_source_len=1024, max_target_length=64, top_k=50, top_p=1):
    inputs = tokenizer([text], max_length=max_source_len, return_tensors="pt", truncation=True, padding = False).to(device)
    summary_ids = model.generate(inputs["input_ids"], do_sample=True,num_beams=num_beams,
                                 max_length=max_target_length, top_k=top_k, top_p=top_p)
    pred = tokenizer.batch_decode(summary_ids, clean_up_tokenization_spaces=False)[0]
    pred = re.sub(r'\s+', ' ', pred).replace('</s>', '').replace('<s>', '').strip()
    return pred

In [None]:
test_data = pd.read_csv("data/val_structure_convokit.csv", sep='\t')

In [None]:
X_test = test_data['context'].values
y_test = test_data['structure'].values

In [None]:
k = 13
X_test[k], y_test[k]

In [None]:
generate_top(X_test[19], top_k=50, num_beams=1)

In [None]:
preds = []
for i, text in tqdm(enumerate(X_test), total=len(X_test)):
    try:
        preds.append([text, generate_top(text, top_k=50, num_beams=1)])
    except:
        print(i)
        preds.append([text, 'err'])
        continue

In [None]:
with open('predictions/{}.pkl'.format(model_name_or_path.replace('checkpoints/', '')), 'wb') as f:
    pickle.dump([X_test, preds], f)