In [1]:
!ls -l ../models/ 

total 28
drwxrwxr-x 5 shapkin shapkin 4096 Feb  4 09:30 't5-base stage 2 p(comment, x_t+1 | x_t, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  4 09:22 't5-base stage 2 p(comment, x_t+1 | x_t, doc)_new'
drwxrwxr-x 5 shapkin shapkin 4096 Jan 30 07:51 't5-small p(comment | x_t, x_t+1, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  6 10:32 't5-small stage 2 _ 2 losses _ full p(comment, x_t+1 | x_t, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  5 21:28 't5-small stage 2 _ 2 losses p(comment, x_t+1 | x_t, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb 16 10:34 't5-small stage 2 p(comment, x_t+1 | x_t, doc)'
drwxrwxr-x 5 shapkin shapkin 4096 Feb  4 09:10 't5-small stage 2 p(comment, x_t+1 | x_t, doc)_new'


In [2]:
import os
os.environ['TOKENIZERS_PARALLELISM']='true'
os.environ['CUDA_VISIBLE_DEVICES']='3'

In [3]:
import sys
sys.path.append('..')

In [4]:
import re
import os
import torch
import json
import numpy as np
import pandas as pd
import seaborn as sns
import transformers
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

from transformers import T5Tokenizer, T5TokenizerFast, T5ForConditionalGeneration
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from typing import Callable, Union, Tuple
from tqdm.notebook import tqdm
from collections import Counter
from torch import nn
from catalyst import dl
from catalyst.callbacks.periodic_loader import PeriodicLoaderCallback
from langdetect import detect
from easse.sari import corpus_sari
from rouge import Rouge 

from utils.dataset_utils import extract_com8text_from_tgt, extract_text8docs_from_src
from utils.dataset_utils import EditDataset, get_tgt, get_src, COM_SEP, TEXT_SEP_SRC, TEXT_SEP_TGT, DOCS_SEP
from utils.metrics_utils import PeerEditMetricsCallback
from utils.config import Config


DOCS_DIR = 'data'
PAGES_DIR = 'data'

In [5]:
CONFIG = Config()
CONFIG.seed = 1337
CONFIG.beam_size = 1

In [6]:
import random

random.seed(CONFIG.seed)
os.environ['PYTHONHASHSEED'] = str(CONFIG.seed)
np.random.seed(CONFIG.seed)
torch.manual_seed(CONFIG.seed)
torch.cuda.manual_seed(CONFIG.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

## Data preparing

In [7]:
mp = json.load(open(r"../data/column_mapper.json"))

train = pd.read_json(r'../data/new_train.json')
train.set_axis(mp.values(), axis='columns', inplace=True)

test = pd.read_json(r'../data/new_test.json')
test.set_axis(mp.values(), axis='columns', inplace=True)

val = pd.read_json(r'../data/new_val.json')
val.set_axis(mp.values(), axis='columns', inplace=True)
val1 = val.iloc[:600]

In [8]:
train.head()

Unnamed: 0,obj_id,old_text,new_text,comment,docs,diff,title,search_queries,counter_found_docs,section_name,is_good,docs_processed
0,1005,"In rural regions of Germany, especially the Ha...","In rural regions of Germany, especially the Ha...",/* Germany */ grammar,"Apr 30, 2020 — Depending on whom you ask, May ...","this opportunity to party,",May Day,"[May Day Germany this opportunity to party,]",[27],Germany,True,"DOC0: Apr 30, 2020 — Depending on whom you ask..."
1,55,"""This is the new WikiPedia!"" \n-HomePage, the ...","""This is the new WikiPedia!"" \n-HomePage, the ...",Added subpage,English: Results of the 1929 New York City ald...,\n*/New York City Board of Aldermen,John M Wolfson,[John M Wolfson \n*/New York City Board of Ald...,"[10, 4]",,True,DOC0: English: Results of the 1929 New York Ci...
2,2329,"Ares (Ancient Greek: , Μodern Greek: Άρης ) i...","Ares (Ancient Greek: , Μodern Greek: Άρης ) i...",repaired link to 'masculinity' ~~~~,"Apr 7, 2014 — Let's explore the essence of mas...","masculinity, integrity, and personal courage.",Ares,"[Ares masculinity, integrity, and personal cou...",[30],,True,"DOC0: Apr 7, 2014 — Let's explore the essence ..."
3,17206,Peter Velhorn (born 24 November 1932) is a Ger...,Peter Velhorn (24 November 1932 – 20 July 2016...,"Passed away 2016, look at German article",Peter Velhorn (24 November 1932 – 20 July 2016...,– 20 July 2016,Peter Velhorn,[Peter Velhorn – 20 July 2016],[23],,True,DOC0: Peter Velhorn (24 November 1932 – 20 Jul...
4,47,"Michael Palin was educated at Birkdale School,...","Michael Palin was educated at Birkdale School,...",/* Early career */ added info,"Michael Palin, Nightingale House, in Clapham, ...",", Graeme Garden, Bill Oddie and Jonathan Lynn",Michael Palin,"[Michael Palin Early career , Graeme Garden, ...",[29],Early career,True,"DOC0: Michael Palin, Nightingale House, in Cla..."


In [9]:
CONFIG.src_max_len = 512
CONFIG.tgt_max_len = 512
CONFIG.pretrained = 't5-small'
CONFIG.pattern_path = '../models/t5-small stage 2 _ 2 losses _ full p(comment, x_t+1 | x_t, doc)'
CONFIG.batch_size = 4

tokenizer = T5Tokenizer.from_pretrained(CONFIG.pretrained, model_max_length=CONFIG.src_max_len)

**Make dataset**

In [10]:
#ds_train = EditDataset(train, tokenizer, CONFIG, text_to_lower=True, comment_to_lower=True)
ds_val = EditDataset(val1, tokenizer, CONFIG, text_to_lower=True, comment_to_lower=True)
#ds_val_full = EditDataset(val, tokenizer, CONFIG, text_to_lower=True, comment_to_lower=True)

## Model training

In [11]:
class EditModel(nn.Module):
    def __init__(self, 
                 pretrained: transformers.modeling_utils.PreTrainedModel, 
                 config: Config):
        super(EditModel, self).__init__()
        self.pretrained = pretrained
        

    def forward(self, 
                x: Tuple[torch.Tensor, torch.Tensor]):
        src, tgt = x
        
        tgt[tgt == 0] == -100
        
        loss = self.pretrained(
            input_ids = src,
            attention_mask = (src != 0).float(),
            labels=tgt,
        ).loss
        return loss
    
    
class Criterion(nn.Module):
    def __init__(self):
        super(Criterion, self).__init__()
        
    def forward(self, pred, tgt):
        return pred

In [12]:
CONFIG.device = 'cuda'

In [13]:
model_edit = EditModel(T5ForConditionalGeneration.from_pretrained(CONFIG.pretrained), CONFIG)
model_edit.load_state_dict(
    torch.load(f'{CONFIG.pattern_path}/checkpoints/model.best.pth', 
               map_location=CONFIG.device))
model_edit = model_edit.pretrained
model_edit.to(CONFIG.device)
model_edit.eval()
print('Success')

Success


In [14]:
from utils.print_diff_utils import diff_print, colored

In [15]:
device = CONFIG.device
CONFIG.beam_size = 3
idx_ = np.random.choice(len(ds_val), 20)

with torch.no_grad():
    for i in idx_:
        src_, tgt_ = ds_val[i]

        generated = model_edit.generate(torch.tensor(src_['input_ids']).view(1,-1).to(device), 
                                                      num_beams=CONFIG.beam_size, 
#                                                       pad_token_id=tokenizer.pad_token_id, 
#                                                       bos_token_id=tokenizer.bos_token_id, 
#                                                       eos_token_id=tokenizer.eos_token_id,
                                                          num_return_sequences=1,
                                                     max_length=512)
        generated = generated.cpu()

        src_text = tokenizer.decode(src_['input_ids'], skip_special_tokens=True)
        tgt_text = tokenizer.decode(tgt_['input_ids'], skip_special_tokens=True)
        
        tgt_comment, tgt_txt = extract_com8text_from_tgt(tgt_text)
        src_txt, _ = extract_text8docs_from_src(src_text)

        src_txt = diff_print(tgt_txt, src_txt)
        tgt_txt = diff_print(src_txt, tgt_txt)
        
        print(colored(f'\n\n---------- QUERY {i} ----------', 'red'))
        print(colored(f'X_t:\n', 'pink') + colored(src_txt) + '\n', sep='')
        print(colored(f'X_t+1:\n', 'pink') + colored(tgt_txt) + '\n', sep='')
        print(colored(f'Comment:', 'yellow') + colored(tgt_comment) + '\n', sep='')
        

        for j in range(1):
            to_gen = generated[j]
            gen_text = tokenizer.decode(to_gen, skip_special_tokens=True)
            gen_comment, gen_txt = extract_com8text_from_tgt(gen_text)
            
            gen_txt = diff_print(src_txt, gen_txt)
            print(colored(f'Gen Comment:', 'blue') + colored(gen_comment) + '\n', sep='')
            print(colored(f'gen X_t+1:\n', 'pink') + colored(gen_txt) + '\n', sep='')
            
        diff = val1.iloc[i]['diff']
        print(colored(f'Tgt diff:\n', 'bold') + colored(diff) + '\n', sep='')
        
        doc_str = '\n'.join(_.split('doc'))
        print(colored(f'Docs:\n', 'yellow') + colored(doc_str) + '\n', sep='')

[91m 

---------- QUERY 151 ----------
[95m X_t:
[39m a superset of the general midi standard, added several proprietary extensions. the most notable addition was the ability to address multiple banks of programs (instrument sounds) by using an additional pair of bank select controllers to specify up to 16384 'variation' sounds (cc#0 is bank select msb, and cc#32 is bank select lsb). other most notable features were 9 drum kits with 14 additional drum sounds each, control change messages for controlling the send level of sound effect blocks (cc#91-94), entering additional parameters (cc#98-101), portamento, sostenuto, soft pedal (cc#65-67), and model-specific sysex messages for setting various parameters of the synth engine. gs was introduced with the roland sound canvas line, which was also roland's first general midi synth module. 

[95m X_t+1:
[39m a superset of the general midi standard, added several proprietary extensions. the most notable addition was the ability to address

[91m 

---------- QUERY 167 ----------
[95m X_t:
[39m 01. end of the world 02. love for air 03. normal people 04. everything 05. do it for yourself 06. honest to god 07. that feeling and the sound 08. dive 09. better friend 10. all the time 11. 365 12. crying wolf 13. wreckage in the rubble, ross leighton [1m - [0m vocals, guitar, songwriter, production * greg walkinshaw [1m - [0m vocals, drums, songwriter, production * marc strain [1m - [0m vocals, bass, songwriter, production 

[95m X_t+1:
[39m 01. end of the world 02. love for air 03. normal people 04. everything 05. do it for yourself 06. honest to god 07. that feeling and the sound 08. dive 09. better friend 10. all the time 11. 365 12. crying wolf 13. wreckage in the rubble, ross leighton [1m – [0m vocals, guitar, songwriter, production * greg walkinshaw [1m – [0m vocals, drums, songwriter, production * marc strain [1m – [0m vocals, bass, songwriter, production 

[92m Comment:[39m COM_SEP fixed [[mos:dash|dashe

[91m 

---------- QUERY 84 ----------
[95m X_t:
[39m , the village has a total number of 28 houses and the population of 198 of which [1m include [0m 99 are males while 99 are females. according to the report published by census india in 2011, out of the total population of the village 0 people are from schedule caste and the village does not have any schedule tribe population so far. 

[95m X_t+1:
[39m , the village has a total number of 28 houses and the population of 198 of which 99 are males while 99 are females. according to the report published by census india in 2011, out of the total population of the village 0 people are from schedule caste and the village does not have any schedule tribe population so far. 

[92m Comment:[39m COM_SEP clean up, replaced: of which include  of which

[94m Gen Comment:[39m COM_SEP /* demography */clean up, replaced: of which include  of which using [[project:awb|awb]]

[95m gen X_t+1:
[39m the village has a total number of 28 houses 

[91m 

---------- QUERY 438 ----------
[95m X_t:
[39m bahati is a constituency of the national assembly of zambia.bahati national assembly of zambia it covers the northern part of mansa and a rural area to the north of the city in luapula province. 

[95m X_t+1:
[39m bahati is a constituency of the national assembly of zambia.bahati national assembly of zambia it covers the northern part of mansa and a rural area to the north of the city in [1m mansa district of [0m luapula province. 

[92m Comment:[39m COM_SEP suggested change in display

[94m Gen Comment:[39m COM_SEP /* top */clean up, added [[cat:o|orphan]] tag using [[project:awb|awb]]

[95m gen X_t+1:
[39m bahati is a constituency of the national assembly of zambia.bahati national assembly of zambia it covers the northern part of mansa and a rural area to the north of the city in luapula province. 

[1m Tgt diff:
 [0m[39m  Mansa District of

[92m Docs:
[39m 
0: bahati is a constituency of the national assembly of

[91m 

---------- QUERY 93 ----------
[95m X_t:
[39m stoke city football club (known as stoke football club until 1925) is a football club from stoke-on-trent in [1m england. [0m the club is reputedly the second-oldest football league club in the world, after notts county f.c., and claims to have been formed in 1863 (disputed by some, who claim it to be 1868). the club’s nickname is the potters and its home kit consists of a red & white vertical-striped shirt with white shorts. the club is managed by johan boskamp. it plays in the football league championship and is one of the twelve founder-members of the football league. 

[95m X_t+1:
[39m stoke city football club (known as stoke football club until 1925) is a football club from stoke-on-trent in [1m england (the other league club in the city being port vale f.c.). [0m the club is reputedly the second-oldest football league club in the world, after notts county f.c., and claims to have been formed in 1863 (disputed by some, 

[91m 

---------- QUERY 203 ----------
[95m X_t:
[39m amy hart redford (born october 22, 1970) is an american actress, director and producer. she is the daughter of academy award-winning film director and actor robert redford and his former wife lola van wagenen.the new york times 

[95m X_t+1:
[39m amy hart redford (born october 22, 1970) is an american actress, director and producer. she is the daughter of academy award-winning film director and actor robert redford and his former wife lola van wagenen.the new york times [1m she is the sister of writer/producer james redford. james redford at imdb [0m 

[92m Comment:[39m COM_SEP add in link to brother

[94m Gen Comment:[39m COM_SEP reworded lead

[95m gen X_t+1:
[39m amy hart redford (born october 22, 1970) is an american actress, [1m director, [0m and producer. she is the daughter of [1m an [0m academy award-winning film director and actor robert redford and his former wife lola van [1m wagenen. she is the daughter