In [1]:
import sys
import torch

In [2]:
GAZETA_PATH = '../data/gazeta_jsonl'

In [3]:
import json

In [4]:
def read_gazeta_records(file_name, shuffle=False, sort_by_date=True):
    """Load Gazeta records from a JSON Lines file.

    Args:
        file_name: path to a ``.jsonl`` file with one JSON object per line.
        shuffle: shuffle the loaded records in place with the global
            ``random`` RNG.
        sort_by_date: sort the records by their ``"date"`` field.
            With both flags False the file order is kept.

    Returns:
        A list of dicts, one per non-blank line of the file.

    Raises:
        ValueError: if both ``shuffle`` and ``sort_by_date`` are set, since
            shuffling would throw the sort away.
    """
    # Local import: in this notebook `random` is only imported much later,
    # so `shuffle=True` used to fail with NameError on a fresh kernel.
    import random

    if shuffle and sort_by_date:
        raise ValueError("shuffle and sort_by_date are mutually exclusive")
    # The corpus is Russian text; do not rely on the platform default encoding.
    with open(file_name, "r", encoding="utf-8") as r:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing in json.loads.
        records = [json.loads(line) for line in r if line.strip()]
    if sort_by_date:
        records.sort(key=lambda x: x["date"])
    if shuffle:
        random.shuffle(records)
    return records

In [5]:
import os

In [6]:
dataset_files = {
    'train': os.path.join(GAZETA_PATH,'gazeta_train.jsonl'),
    'val': os.path.join(GAZETA_PATH,'gazeta_val.jsonl'),
    'test': os.path.join(GAZETA_PATH, 'gazeta_test.jsonl')
}

In [7]:
records = {
    split: read_gazeta_records(path) for split, path in dataset_files.items()
}

In [8]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [9]:
model_name_or_path = "sberbank-ai/rugpt3medium_based_on_gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
# model = GPT2LMHeadModel.from_pretrained(model_name_or_path)

In [10]:
import importlib

In [11]:
import gpt_summarizer_dataset

In [12]:
importlib.reload(gpt_summarizer_dataset)

<module 'gpt_summarizer_dataset' from '/home/ivan/Programming/ML/Summarization/Samsung/gpt_summarizer_dataset.py'>

In [13]:
from gpt_summarizer_dataset import GPTHeadlineDataset

In [14]:
import pickle

In [15]:
# with open('gpt_training_dataset.pkl', 'rb') as f:
#     train_dataset = pickle.load(f)

In [16]:
# with open('gpt_val_dataset.pkl', 'rb') as f:
#     val_dataset = pickle.load(f)

In [17]:
train_dataset = GPTHeadlineDataset(
    tokenizer,
    summaries=[r['summary'] for r in records['train']],
    contents=[r['text'] for r in records['train']],
    max_input_length=601,
    max_summary_length=163
)

In [18]:
val_dataset = GPTHeadlineDataset(
    tokenizer,
    summaries=[r['summary'] for r in records['val']],
    contents=[r['text'] for r in records['val']],
    max_input_length=601,
    max_summary_length=163
)
# 

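Хак для быстрой загрузки датасета: сохраняем токенизированные `train_dataset` и `val_dataset` в pickle-файлы. При следующих запусках их можно загрузить закомментированными ячейками выше вместо повторной токенизации. Загружайте через `pickle` только собственные файлы — распаковка недоверенных pickle-данных небезопасна.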

In [19]:

with open('gpt_training_dataset.pkl', 'wb') as of:
    pickle.dump(train_dataset, of)
    
with open('gpt_val_dataset.pkl', 'wb') as of:
    pickle.dump(val_dataset, of)

In [20]:
import pandas as pd

In [21]:
pd.Series(train_dataset.summary_lengths).describe()

count    52400.000000
mean        63.571126
std         16.982161
min         17.000000
25%         51.000000
50%         63.000000
75%         75.000000
max        123.000000
dtype: float64

In [22]:
pd.Series(train_dataset.content_lengths).describe()

count    52400.000000
mean       954.682214
std        273.714437
min         48.000000
25%        780.000000
50%        901.000000
75%       1088.000000
max       2244.000000
dtype: float64

In [24]:
# Hyperparameters passed to GPTSummarizerPL; the comma after
# summary_loss_weight was missing, which made this cell a SyntaxError.
hparams = dict(
    learning_rate=6e-5,
    warmup_steps=1000,
    linear_decay_steps=0,
    content_loss_weight=1,
    summary_loss_weight=1,
    pretrained_model_path=model_name_or_path,
)

In [25]:
import GPTSummarizer

In [26]:
importlib.reload(GPTSummarizer)

<module 'GPTSummarizer' from '/home/ivan/Programming/ML/Summarization/Samsung/GPTSummarizer.py'>

In [27]:
from GPTSummarizer import GPTSummarizerPL

In [28]:
model = GPTSummarizerPL(hparams)

In [29]:
def get_model_device(model):
    """Return the device holding the first parameter of ``model``.

    Raises StopIteration if the model has no parameters.
    """
    param_iter = iter(model.parameters())
    first_param = next(param_iter)
    return first_param.device

In [30]:
get_model_device(model)

device(type='cpu')

In [31]:
def generate_headline(model, text, max_input_length, max_output_length, **generate_args):
    """Generate a summary continuation for ``text`` and return the decoded output.

    The prompt is ``<s>`` + the text's tokens (truncated to
    ``max_input_length``) + ``</s>``; generation continues after it.
    Presumably this mirrors the ``<s> text </s> summary </s>`` layout used in
    training — TODO confirm against GPTHeadlineDataset.

    Args:
        model: module exposing a Hugging Face causal LM as ``model.gpt``.
        text: source article text.
        max_input_length: maximum number of text tokens kept in the prompt.
        max_output_length: maximum number of tokens generated after the prompt.
        **generate_args: extra keyword arguments forwarded to ``model.gpt.generate``.

    Returns:
        The decoded string of the full output sequence, prompt included.
    """
    vocab = tokenizer.get_vocab()
    bos_token_id = vocab['<s>']
    eos_token_id = vocab['</s>']
    pad_token_id = vocab['<pad>']
    prompt_ids = [bos_token_id] + tokenizer.encode(text)[:max_input_length] + [eos_token_id]
    prompt = torch.tensor(prompt_ids, device=get_model_device(model)).view(1, -1)
    encoded_output = model.gpt.generate(
        prompt,
        bos_token_id=bos_token_id,
        # `generate` expects the singular `eos_token_id`; the former
        # `eos_token_ids=[...]` keyword is not recognised by current
        # transformers, so generation did not stop at </s>.
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        # Budget relative to the actual prompt length so that short texts get
        # at most max_output_length new tokens (identical for full-length input).
        max_length=prompt.size(1) + max_output_length,
        **generate_args
    )
    return tokenizer.decode(encoded_output[0])

In [32]:
from torch.utils.data import DataLoader

In [33]:
batch_size=2
n_workers=1
data_loaders = {
    "train": DataLoader(train_dataset, 
                        batch_size=batch_size, num_workers=n_workers,
                        shuffle=True,
                        collate_fn=train_dataset.collate),
    "val": DataLoader(val_dataset, 
                        batch_size=batch_size, 
                        num_workers=n_workers,
                        shuffle=False,
                        collate_fn=val_dataset.collate),
    
}

In [34]:

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

In [35]:
# model.train()

In [36]:
import random

In [37]:
class SamplingCallback(Callback):
    """Print a generated summary for a random validation text during training.

    A sample is printed at the end of every validation run and, during
    training, once per optimizer step that is a multiple of ``every_n_steps``.

    Args:
        texts: candidate source texts to sample from. When None, the
            ``'text'`` field of the notebook's ``records['val']`` is used
            (looked up at sampling time).
        every_n_steps: training-time sampling period in optimizer steps.
        max_input_length: maximum number of source tokens fed to the model.
        max_output_length: maximum number of tokens to generate.
    """

    def __init__(self, texts=None, every_n_steps=1000,
                 max_input_length=600, max_output_length=100):
        super().__init__()
        self.texts = texts
        self.every_n_steps = every_n_steps
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length
        # trainer.global_step only advances on optimizer steps, so with
        # gradient accumulation several consecutive batches see the same step.
        # Remember the last sampled step so each one is printed only once.
        self._last_sampled_step = None

    def _pick_text(self):
        """Return one random source text from ``self.texts`` or ``records['val']``."""
        if self.texts is not None:
            return random.choice(self.texts)
        return random.choice(records['val'])['text']

    def _print_sample(self, module):
        """Generate and print a summary with ``module`` temporarily in eval mode."""
        text = self._pick_text()
        was_training = module.training
        module.train(False)
        try:
            with torch.no_grad():
                print(generate_headline(module, text,
                                        self.max_input_length,
                                        self.max_output_length))
        finally:
            # Restore the previous mode even if generation raises.
            module.train(was_training)

    def on_validation_end(self, trainer, module):
        self._print_sample(module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                           dataloader_idx=0):
        # dataloader_idx defaults to 0 because newer Lightning versions no
        # longer pass it to this hook.
        step = trainer.global_step
        if step % self.every_n_steps == 0 and step != self._last_sampled_step:
            self._last_sampled_step = step
            self._print_sample(pl_module)
        


In [38]:

early_stop_callback = EarlyStopping(
   monitor='avg_val_loss',
   min_delta=0.00,
   patience=1,
   verbose=False,
   mode='min'
)


In [40]:
os.makedirs('gpt_checkpoint_gazeta3', exist_ok=True)

In [41]:
checkpoint = ModelCheckpoint("gpt_checkpoint_gazeta3",monitor='avg_val_loss', mode='min', save_top_k=1)

In [42]:
import pytorch_lightning as pl

In [43]:
trainer = pl.Trainer(gpus=[0],max_epochs=5, accumulate_grad_batches=8,
                     callbacks=[checkpoint, early_stop_callback, SamplingCallback()], fast_dev_run=False)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_dataloader=data_loaders['train'], val_dataloaders=data_loaders['val'])

In [65]:
# Call state_dict() to save the weights; the bare bound method would pickle the
# whole GPT module object instead of a loadable state dict.
torch.save({'model': model.gpt.state_dict(), 'hparams': model.hparams}, 'gpt_ckpt_after_epoch_2')

In [None]:
model = model.eval()

In [57]:
get_model_device(model)

device(type='cpu')

In [58]:
model=model.cuda()

In [59]:
[1,2,4,5].index(4)

2

In [60]:
def extract_summary(model, text, max_input_length, max_output_length, **generate_args):
    """Generate a summary for ``text`` and return only the generated summary text.

    The prompt is ``<s>`` + the text's tokens (truncated to
    ``max_input_length``) + ``</s>``. Only the tokens generated after the
    prompt are kept, up to (not including) the first ``</s>``.

    Args:
        model: module exposing a Hugging Face causal LM as ``model.gpt``.
        text: source article text.
        max_input_length: maximum number of text tokens kept in the prompt.
        max_output_length: maximum number of tokens generated after the prompt.
        **generate_args: extra keyword arguments forwarded to ``model.gpt.generate``.

    Returns:
        The decoded summary string (empty if nothing was generated).
    """
    vocab = tokenizer.get_vocab()
    bos_token_id = vocab['<s>']
    eos_token_id = vocab['</s>']
    pad_token_id = vocab['<pad>']
    prompt_ids = [bos_token_id] + tokenizer.encode(text)[:max_input_length] + [eos_token_id]
    prompt = torch.tensor(prompt_ids, device=get_model_device(model)).view(1, -1)
    prompt_length = prompt.size(1)
    encoded_output = model.gpt.generate(
        prompt,
        bos_token_id=bos_token_id,
        # `generate` expects the singular `eos_token_id`; the former
        # `eos_token_ids=[...]` keyword is not recognised by current
        # transformers, so generation did not stop at </s>.
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        # At most max_output_length new tokens regardless of prompt length.
        max_length=prompt_length + max_output_length,
        **generate_args
    )

    # The causal LM output starts with the prompt, so the summary begins right
    # after it; no need to search for the prompt's closing </s>.
    summary_ids = encoded_output[0, prompt_length:].tolist()
    if eos_token_id in summary_ids:
        summary_ids = summary_ids[:summary_ids.index(eos_token_id)]
    return tokenizer.decode(summary_ids)

In [61]:
extract_summary

<function __main__.extract_summary(model, text, max_input_length, max_output_length, **generate_args)>

In [64]:
with torch.no_grad():
    rand_index = random.randrange(len(records['val']))
    print(rand_index)
    text = records['val'][rand_index]['text']
    print(text)
    print("-----------")
    ref = records['val'][rand_index]['summary']
    print(ref)
    print(extract_summary(model,text, 600,128))

3206
В Гатчине, городе Ленинградской области, крещение годовалого ребенка в Мариенбургской церкви закончилось скандалом после размещенного в сети видео. На кадрах видно, как священник пытается насильно окунуть ребенка в купель, в которую тот не помещается. Малыш кричит, плачет и вырывается из рук батюшки, на что служитель церкви не обращает внимания. «Он делал все с болью для ребенка, видел, что он взрослый, что его в такую маленькую купель не погрузить, поливать с головы надо. Но решил делать по-своему. Малыш кричал, вырывался. Я испугалась, подбежала, начала забирать. Сама чуть не загорелась, так как платком потушила свечи у купели», — рассказала «Фонтанке» мать ребенка Анастасия. Она отметила, что вместо того, чтобы отдать ей ребенка, батюшка посоветовал ей не лезть не в свое дело. «У меня были случаи, когда дети бились головой и справляли нужду в руках», — поделился опытом священник. В результате крещения малыш, по словам его матери, получил травмы. «У ребенка царапины на шее и на 