In [1]:
!ls -l ../interactive_lm/wiki/models/ 

total 28
drwxrwxr-x 5 shapkin shapkin 4096 Feb  4 09:30 't5-base stage 2 p(comment, x_t+1 | x_t, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  4 09:22 't5-base stage 2 p(comment, x_t+1 | x_t, doc)_new'
drwxrwxr-x 5 shapkin shapkin 4096 Jan 30 07:51 't5-small p(comment | x_t, x_t+1, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  6 10:32 't5-small stage 2 _ 2 losses _ full p(comment, x_t+1 | x_t, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  5 21:28 't5-small stage 2 _ 2 losses p(comment, x_t+1 | x_t, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  7 10:04 't5-small stage 2 p(comment, x_t+1 | x_t, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  4 09:10 't5-small stage 2 p(comment, x_t+1 | x_t, doc)_new'


In [2]:
import os
os.environ['TOKENIZERS_PARALLELISM']='true'
os.environ['CUDA_VISIBLE_DEVICES']='2'

In [3]:
import re
import os
import torch
import json
import numpy as np
import pandas as pd
import seaborn as sns
import transformers
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

from transformers import T5Tokenizer, T5TokenizerFast, T5ForConditionalGeneration
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from typing import Callable, Union, Tuple
from tqdm.notebook import tqdm
from collections import Counter
from torch import nn
from catalyst import dl
from catalyst.callbacks.periodic_loader import PeriodicLoaderCallback
from langdetect import detect
from easse.sari import corpus_sari
from rouge import Rouge 

import sys
sys.path.append('../interactive_lm/wiki')
from utils.dataset_utils import extract_com8text_from_tgt, extract_text8docs_from_src
from utils.dataset_utils import ExplainDataset, get_tgt, get_src, COM_SEP, TEXT_SEP_SRC, TEXT_SEP_TGT, DOCS_SEP
from utils.metrics_utils import PeerEditMetricsCallback
from utils.config import Config


DOCS_DIR = 'data'
PAGES_DIR = 'data'

In [4]:
CONFIG = Config()
CONFIG.seed = 1337
CONFIG.beam_size = 1

In [5]:
import random

random.seed(CONFIG.seed)
os.environ['PYTHONHASHSEED'] = str(CONFIG.seed)
np.random.seed(CONFIG.seed)
torch.manual_seed(CONFIG.seed)
torch.cuda.manual_seed(CONFIG.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

## Data preparing

In [6]:
train = pd.read_csv(r'train_fruit.csv')

In [7]:
train.head()

Unnamed: 0.1,Unnamed: 0,old_texts,new_texts,doc_texts,titles
0,0,Daniel R. Lucey is a senior scholar with the O...,"Daniel R. Lucey is an American physician, rese...",DOC0: 2016 Angola and DR Congo yellow fever ou...,Public Health Emergency of International Concern
1,1,Rob Daviau is a prolific American game designe...,Rob Daviau is an American game designer known ...,DOC0: Return to Dark Tower Development In addi...,Spiel des Jahres
2,2,North Down is a parliamentary constituency in ...,North Down is a parliamentary constituency in ...,DOC0: List of parliamentary constituencies in ...,Liberalism in the United Kingdom
3,3,"Chaldean Catholics (; ), also known as Chaldea...","Chaldean Catholics () (), also known as Chalde...",DOC0: Nusaybin History - Modern history Syrian...,Ibrahim Pasha Milli
4,4,Prevalence of tobacco use is reported by the W...,Prevalence of tobacco use is reported by the W...,"DOC0: Sustainable Development Goal 3 Targets, ...",Heated tobacco product


In [8]:
train['comment'] = ''
train = train.rename(columns={"old_texts": "old_text", "new_texts": "new_text", "doc_texts": "docs_processed"})

In [9]:
train['old_text'] = train['old_text'].apply(str)
train['new_text'] = train['new_text'].apply(str)


train.head()

Unnamed: 0.1,Unnamed: 0,old_text,new_text,docs_processed,titles,comment
0,0,Daniel R. Lucey is a senior scholar with the O...,"Daniel R. Lucey is an American physician, rese...",DOC0: 2016 Angola and DR Congo yellow fever ou...,Public Health Emergency of International Concern,
1,1,Rob Daviau is a prolific American game designe...,Rob Daviau is an American game designer known ...,DOC0: Return to Dark Tower Development In addi...,Spiel des Jahres,
2,2,North Down is a parliamentary constituency in ...,North Down is a parliamentary constituency in ...,DOC0: List of parliamentary constituencies in ...,Liberalism in the United Kingdom,
3,3,"Chaldean Catholics (; ), also known as Chaldea...","Chaldean Catholics () (), also known as Chalde...",DOC0: Nusaybin History - Modern history Syrian...,Ibrahim Pasha Milli,
4,4,Prevalence of tobacco use is reported by the W...,Prevalence of tobacco use is reported by the W...,"DOC0: Sustainable Development Goal 3 Targets, ...",Heated tobacco product,


In [10]:
CONFIG.src_max_len = 1024
CONFIG.tgt_max_len = 512
CONFIG.pretrained = 't5-small'
CONFIG.pattern_path = './models/t5-small p(comment | x_t, x_t+1, doc)'
CONFIG.batch_size = 4

tokenizer = T5TokenizerFast.from_pretrained(CONFIG.pretrained, model_max_length=CONFIG.src_max_len)

**Make dataset**

In [11]:
ds_train = ExplainDataset(train, tokenizer, CONFIG, text_to_lower=True, comment_to_lower=True)
#ds_val = ExplainDataset(val1, tokenizer, CONFIG, text_to_lower=True, comment_to_lower=True)
#ds_val_full = EditDataset(val, tokenizer, CONFIG, text_to_lower=True, comment_to_lower=True)

**Len distribution**

In [12]:
idx_num = 100
src_text = tokenizer.decode(ds_train[idx_num][0]['input_ids'], skip_special_tokens=True)
tgt_text = tokenizer.decode(ds_train[idx_num][1]['input_ids'], skip_special_tokens=True)

In [13]:
print(f'{src_text}\n\n{tgt_text}')

wycombe wanderers football club is a professional association football club based in the town of high wycombe, buckinghamshire, england. the team play in league one, the third tier of english football. the club plays at adams park, which is situated on the western outskirts of high wycombe, and traditionally play in quartered shirts of navy (oxford blue) and pale blue (cambridge blue). the club's nicknames are "the chairboys" and "the blues". the current manager of the club is gareth ainsworth, who was appointed as player/manager following a period during which he served as caretaker manager, after gary waddock was relieved of his duties following a 1–0 defeat at home to wimbledon on 22 september 2012. ainsworth retired from playing at the end of the 2012–13 season. he is assisted by richard dobson. the club was awarded the family club of the year award twice in a row in 2006–07 and 2007–08. this is the only time that the award has been given to the same club in consecutive seasons. th

## Model training

In [14]:
class EditModel(nn.Module):
    def __init__(self, 
                 pretrained: transformers.modeling_utils.PreTrainedModel, 
                 config: Config):
        super(EditModel, self).__init__()
        self.pretrained = pretrained
        

    def forward(self, 
                x: Tuple[torch.Tensor, torch.Tensor]):
        src, tgt = x
        
        tgt[tgt == 0] == -100
        
        loss = self.pretrained(
            input_ids = src,
            attention_mask = (src != 0).float(),
            labels=tgt,
        ).loss
        return loss
    
    
class Criterion(nn.Module):
    def __init__(self):
        super(Criterion, self).__init__()
        
    def forward(self, pred, tgt):
        return pred

In [15]:
CONFIG.device = 'cuda'
CONFIG.pattern_path = '../interactive_lm/wiki/models/t5-small p(comment | x_t, x_t+1, doc)'

In [16]:
model_edit = EditModel(T5ForConditionalGeneration.from_pretrained(CONFIG.pretrained), CONFIG)
model_edit.load_state_dict(
    torch.load(f'{CONFIG.pattern_path}/checkpoints/model.best.pth', 
               map_location=CONFIG.device))
model_edit = model_edit.pretrained
model_edit.to(CONFIG.device)
model_edit.eval()
print('Success')

Success


In [17]:
from IPython.display import HTML as html_print

def colored(s, color='black'):
    if color == 'green':
        return f'\033[93m {s}'
    if color == 'yellow':
        return f'\033[92m {s}'
    if color == 'red':
        return f'\033[91m {s}'
    if color == 'blue':
        return f'\033[94m {s}'
    if color == 'pink':
        return f'\033[95m {s}'
    if color == 'bold':
        return f'\x1B[1mText {s} \x1b[0m'
    return f'\033[39m {s}'

device = CONFIG.device
CONFIG.beam_size = 3
idx_ = np.random.choice(len(ds_train), 5)

with torch.no_grad():
    for i in idx_:
        src_, tgt_ = ds_train[i]

        generated = model_edit.generate(torch.tensor(src_['input_ids']).view(1,-1).to(device), 
                                                      num_beams=CONFIG.beam_size, 
#                                                       pad_token_id=tokenizer.pad_token_id, 
#                                                       bos_token_id=tokenizer.bos_token_id, 
#                                                       eos_token_id=tokenizer.eos_token_id,
                                                          num_return_sequences=1,
                                                     max_length=512)
        generated = generated.cpu()

        src_text = tokenizer.decode(src_['input_ids'], skip_special_tokens=True)
        tgt_text = tokenizer.decode(tgt_['input_ids'], skip_special_tokens=True)
        
        print(colored(f'\n\n---------- QUERY {i} ----------', 'red'))
        print(colored(f'SRC:\n', 'pink') + colored(src_text) + '\n', sep='')
        print(colored(f'TGT:\n', 'pink') + colored(tgt_text) + '\n', sep='')

        for j in range(1):
            to_gen = generated[j]
            gen_text = tokenizer.decode(to_gen, skip_special_tokens=True)
            print(colored(f'GEN:', 'blue') + colored(gen_text) + '\n', sep='')
    
        #diff = ds_train.iloc[i]['diff']
        #print(colored(f'Tgt diff:\n', 'bold') + colored(diff) + '\n', sep='')

[91m 

---------- QUERY 3223 ----------
[95m SRC:
[39m ''the head'' began as mini-series that originally ran under the title ''mtv's oddities'', on mtv between 1994 and 1996. it has begun airing on mtv2 in august 2009. it was released on dvd on december 15, 2009. TEXT_SEP ''the head'' is an american adult animated television series created by eric fogel for mtv. it originated as a science-fiction mini-series that aired under the ''mtv's oddities'' label between 1994 and 1996, and was followed by ''the maxx''. the was released on dvd on december 15, 2009. DOCS_SEP doc0: list of adult animated television series 1990s - united states - table-0-8 [header] [col] title [col] genre [col] seasons/episodes [col] show creator(s) [col] original release [col] network [col] studio [col] age rating [col] status [row] [col] ''the head'' [col] • action • adventure [col] 2 seasons, 14 episodes [col] eric fogel [col] september 1, 1994 – march 1, 1996 [col] mtv [col] mtv animation [col] tv-14 [col] en

[91m 

---------- QUERY 85159 ----------
[95m SRC:
[39m ''the special london bridge special'' is a 1972 musical variety special. it was made to celebrate the acquisition of the london bridge in lake havasu city, arizona. it was filmed in lake havasu following the opening of the london bridge. it was produced, directed and choreographed by david winters and it starred tom jones, and jennifer o'neill. other guests included the carpenters, kirk douglas, jonathan winters, hermione gingold, lorne greene, chief dan george, charlton heston, george kirby, michael landon, terry-thomas, engelbert humperdinck, elliott gould, merle park, and rudolf nureyev. TEXT_SEP ''the special london bridge special'' is a 1972 musical variety television special. it was made to celebrate the acquisition of the london bridge in lake havasu city, arizona. it was filmed in lake havasu following the opening of the london bridge. it was produced, directed and choreographed by david winters and it starred tom jones

In [None]:
generated_comments = []

for i in tqdm(range(len(ds_train))):
    src_, tgt_ = ds_train[i]

    generated = model_edit.generate(torch.tensor(src_['input_ids']).view(1,-1).to(device), 
                                                  num_beams=CONFIG.beam_size, 
#                                                       pad_token_id=tokenizer.pad_token_id, 
#                                                       bos_token_id=tokenizer.bos_token_id, 
#                                                       eos_token_id=tokenizer.eos_token_id,
                                                      num_return_sequences=1,
                                                 max_length=512)
    generated = generated.cpu()

    src_text = tokenizer.decode(src_['input_ids'], skip_special_tokens=True)
    tgt_text = tokenizer.decode(tgt_['input_ids'], skip_special_tokens=True)

    #print(colored(f'\n\n---------- QUERY {i} ----------', 'red'))
    #print(colored(f'SRC:\n', 'pink') + colored(src_text) + '\n', sep='')
    #print(colored(f'TGT:\n', 'pink') + colored(tgt_text) + '\n', sep='')

    for j in range(1):
        to_gen = generated[j]
        gen_text = tokenizer.decode(to_gen, skip_special_tokens=True)
        # print(colored(f'GEN:', 'blue') + colored(gen_text) + '\n', sep='')
        generated_comments.append(gen_text)

  0%|          | 0/114515 [00:00<?, ?it/s]

In [None]:
ds_train['generated_comment'] = generated_comments

In [None]:
ds_train.to_csv('train_fruit_full.csv')

In [20]:
ds_train.head()

AttributeError: 'ExplainDataset' object has no attribute 'head'

In [21]:
train['generated_comment'] = generated_comments
train.to_csv('train_fruit_full.csv')

In [22]:
train.head()

Unnamed: 0.1,Unnamed: 0,old_text,new_text,docs_processed,titles,comment,generated_comment
0,0,Daniel R. Lucey is a senior scholar with the O...,"Daniel R. Lucey is an American physician, rese...",DOC0: 2016 Angola and DR Congo yellow fever ou...,Public Health Emergency of International Concern,,"copyedits, wikify"
1,1,Rob Daviau is a prolific American game designe...,Rob Daviau is an American game designer known ...,DOC0: Return to Dark Tower Development In addi...,Spiel des Jahres,,daviau is not credited as the designer of over...
2,2,North Down is a parliamentary constituency in ...,North Down is a parliamentary constituency in ...,DOC0: List of parliamentary constituencies in ...,Liberalism in the United Kingdom,,update
3,3,"Chaldean Catholics (; ), also known as Chaldea...","Chaldean Catholics () (), also known as Chalde...",DOC0: Nusaybin History - Modern history Syrian...,Ibrahim Pasha Milli,,reworded intro a bit
4,4,Prevalence of tobacco use is reported by the W...,Prevalence of tobacco use is reported by the W...,"DOC0: Sustainable Development Goal 3 Targets, ...",Heated tobacco product,,/* prevalence */ reword who reference


In [23]:
train.shape

(114515, 7)

In [4]:
import pandas as pd

a = pd.read_csv('train_fruit_full.csv')

In [5]:
a.head()

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,old_text,new_text,docs_processed,titles,comment,generated_comment
0,0,0,Daniel R. Lucey is a senior scholar with the O...,"Daniel R. Lucey is an American physician, rese...",DOC0: 2016 Angola and DR Congo yellow fever ou...,Public Health Emergency of International Concern,,"copyedits, wikify"
1,1,1,Rob Daviau is a prolific American game designe...,Rob Daviau is an American game designer known ...,DOC0: Return to Dark Tower Development In addi...,Spiel des Jahres,,daviau is not credited as the designer of over...
2,2,2,North Down is a parliamentary constituency in ...,North Down is a parliamentary constituency in ...,DOC0: List of parliamentary constituencies in ...,Liberalism in the United Kingdom,,update
3,3,3,"Chaldean Catholics (; ), also known as Chaldea...","Chaldean Catholics () (), also known as Chalde...",DOC0: Nusaybin History - Modern history Syrian...,Ibrahim Pasha Milli,,reworded intro a bit
4,4,4,Prevalence of tobacco use is reported by the W...,Prevalence of tobacco use is reported by the W...,"DOC0: Sustainable Development Goal 3 Targets, ...",Heated tobacco product,,/* prevalence */ reword who reference


In [11]:
import difflib

from IPython.display import HTML as html_print

def colored(s, color='black'):
    if color == 'green':
        return f'\033[93m {s}'
    if color == 'yellow':
        return f'\033[92m {s}'
    if color == 'red':
        return f'\033[91m {s}'
    if color == 'blue':
        return f'\033[94m {s}'
    if color == 'pink':
        return f'\033[95m {s}'
    if color == 'bold':
        return f'\x1B[1m {s} \x1b[0m'
    return f'\033[39m {s}'

def diff_print(src_text, tgt_text):
    src_text_tok, tgt_text_tok = src_text.split(), tgt_text.split()
    matcher = difflib.SequenceMatcher(a=src_text_tok, b=tgt_text_tok)
    
    sti = []
    cur_idx = 0
    for match in matcher.get_matching_blocks():
        if match.size != 0:
            if match.b == cur_idx:
                sti.append((match.b, match.b + match.size, 'same'))
                cur_idx = match.b + match.size
            else:
                sti.append((cur_idx, match.b, 'diff'))
                sti.append((match.b, match.b + match.size, 'same'))
                cur_idx = match.b + match.size
    if cur_idx < len(tgt_text_tok) - 1:
        sti.append((cur_idx, len(tgt_text_tok), 'diff'))
    
    new_str = ''
    for (st_id, en_id, idxs_type) in sti:
        cur_txt = ' '.join(tgt_text_tok[st_id:en_id])
        if idxs_type != 'same':
            new_str += colored(cur_txt, color='bold') + ' '
        else:
            new_str += cur_txt + ' '
    return new_str

In [12]:
import numpy as np

idx_ = np.random.choice(len(a), 20)

for i in idx_:
    src_txt = a.iloc[i]['old_text']
    tgt_txt = a.iloc[i]['new_text']
    tgt_comment = a.iloc[i]['generated_comment']
    docs_processed = a.iloc[i]['docs_processed'].lower()
    src_txt = diff_print(tgt_txt, src_txt)
    tgt_txt = diff_print(src_txt, tgt_txt)

    print(colored(f'\n\n---------- QUERY {i} ----------', 'red'))
    print(colored(f'X_t:\n', 'pink') + colored(src_txt) + '\n', sep='')
    print(colored(f'X_t+1:\n', 'pink') + colored(tgt_txt) + '\n', sep='')
    print(colored(f'Comment:', 'yellow') + colored(tgt_comment) + '\n', sep='')

    doc_str = '\n'.join(docs_processed.split('doc'))
    print(colored(f'Docs:\n', 'yellow') + colored(doc_str) + '\n', sep='')

[91m 

---------- QUERY 76138 ----------
[95m X_t:
[39m Drumstick is the brand name, owned by Nestlé [1m since 1991, [0m for a variety of frozen dessert-filled ice cream cones sold in the United States, Australia, Canada, Malaysia, and other countries across the world. The original product was invented by I.C. Parker of the Drumstick Company of Fort Worth, Texas, in 1928. 

[95m X_t+1:
[39m Drumstick is the brand name, owned by [1m Froneri, a joint venture between [0m Nestlé [1m and PAI Partners, [0m for a variety of frozen dessert-filled ice cream cones sold in the United States, Australia, Canada, Malaysia, and other countries across the world. The original product was invented by I.C. Parker of the Drumstick Company of Fort Worth, Texas, in 1928. 

[92m Comment:[39m changed from 1991 to froneri, a joint venture between nestlé and pai partners.

[92m Docs:
[39m 
0: dreyer's introduction in 2020, froneri, the joint venture between nestlé and pai partners, agreed to take