# Imports

In [1]:
%load_ext lab_black

In [2]:
import os

os.chdir("/home/ivanr/git/document_information_extraction/")

In [3]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from tabulate import tabulate
import nltk
from datetime import datetime

# Statics

In [4]:
MODEL_NAME = "facebook/bart-large-cnn"
TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: ", TORCH_DEVICE)
language = "english"

Device:  cuda


# Load model

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(TORCH_DEVICE)

In [6]:
print(model.config)
encoder_max_length = 1024
decoder_max_length = 64

BartConfig {
  "_name_or_path": "facebook/bart-large-cnn",
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "force_bos_token_to_be_generated": true,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "L

# Load data

In [7]:
from src.data.data_statics import (
    MIN_SEMANTIC_SIMILARITY,
    MIN_NOVELTY,
    MAX_NOVELTY,
    MAX_TOKENS_BODY,
)
import pandas as pd
from src.data.wikipedia.wiki_data_base import (
    retrieve_query,
    retrive_observations_from_ids,
)

RANDOM_SEED = 0
N_SAMPLE_TEXTS = 1000

QUERY_SUITABLE_ARTICLES = f"""
SELECT ar.*,
       nv.novelty_tokens,
       nv.novelty_bigrams,
       nv.novelty_trigrams,
       cs.semantic_similarity
       
FROM article_level_info ar
INNER JOIN wiki_article_novelty nv
    ON ar.pageid = nv.pageid
INNER JOIN wiki_article_cosine_similarity cs
    ON ar.pageid = cs.pageid
WHERE cs.semantic_similarity>={MIN_SEMANTIC_SIMILARITY}
    AND nv.novelty_tokens<={MAX_NOVELTY}
    AND nv.novelty_tokens>={MIN_NOVELTY}
    AND ar.body_word_count<={MAX_TOKENS_BODY}
"""
import pickle


characterisation_df = pd.DataFrame(
    retrieve_query(QUERY_SUITABLE_ARTICLES),
    columns=[
        "pageid",
        "title",
        "summary_word_count",
        "body_word_count",
        "novelty_tokens",
        "novelty_bigrams",
        "novelty_trigrams",
        "semantic_similarity",
    ],
)


pageids_to_evaluate = list(
    characterisation_df["pageid"].sample(n=N_SAMPLE_TEXTS, random_state=RANDOM_SEED)
)

ARTICLE_GENERATOR = retrive_observations_from_ids(pageids_to_evaluate)
ARTICLE_GENERATOR_TRAIN = ARTICLE_GENERATOR[:900]
ARTICLE_GENERATOR_TEST = ARTICLE_GENERATOR[900:]


def decode_row(article):
    summary = article[2]
    body = "".join(pickle.loads(article[3]))
    return summary, body


def convert_to_features(example_batch):
    # Tokenize contexts and questions (as pairs of inputs)

    input_ = []
    target_ = []
    for article in example_batch:
        summary, body = decode_row(article)
        input_.append(body)
        target_.append(summary)

    return {"document": input_, "summary": target_}

# Preprocess and tokenize

In [8]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    source, target = batch["document"], batch["summary"]
    source_tokenized = tokenizer(
        source, padding="max_length", truncation=True, max_length=max_source_length
    )
    target_tokenized = tokenizer(
        target, padding="max_length", truncation=True, max_length=max_target_length
    )

    batch = {k: v for k, v in source_tokenized.items()}
    # Ignore padding in the loss
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in l]
        for l in target_tokenized["input_ids"]
    ]
    return batch

In [9]:
train_data = convert_to_features(ARTICLE_GENERATOR_TRAIN)
train_data = batch_tokenize_preprocess(train_data, tokenizer, 1024, 124)

validation_data = convert_to_features(ARTICLE_GENERATOR_TEST)
validation_data = batch_tokenize_preprocess(validation_data, tokenizer, 1024, 124)

# Metrics

In [10]:
nltk.download("punkt", quiet=True)

metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract a few results from ROUGE
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

# Training

In [11]:
training_args = Seq2SeqTrainingArguments(
    output_dir="results",
    num_train_epochs=1,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=4,  # demo
    per_device_eval_batch_size=4,
    # learning_rate=3e-05,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=50,
    save_total_limit=3,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=validation_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [12]:
trainer.train()

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ivanr/.netrc


KeyError: 2