## Colab Setup

In [1]:
import sys

# True when running on a Colab kernel (Colab has the `google.colab` module
# imported already; elsewhere it is absent from sys.modules).
is_in_colab = 'google.colab' in sys.modules

if is_in_colab:
    from google.colab import drive
    # Mount Google Drive so the project folder below is reachable.
    drive.mount('/content/drive')
    # Make the project's local packages (e.g. `utils`) importable.
    sys.path.insert(0,'/content/drive/MyDrive/nlp_question_answer_project')

    # Switch the working directory so relative paths such as "data/..." and
    # "requirements.txt" resolve inside the project folder.
    %cd /content/drive/MyDrive/nlp_question_answer_project/

In [2]:
%%capture
# Install the project's dependencies; %%capture suppresses pip's output.
!pip install -r requirements.txt

In [3]:
# Automatically re-import edited modules (e.g. utils.*) before each cell executes.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline.
%matplotlib inline

## Code

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer

from utils.dataset import load_datasets_by_language, save_dataset, load_dataset
from utils.lexical import get_lexical_features_from_dataset, create_token_to_id_mapping_from_token_sentences, Annotation_error


import pandas as pd

import torch
from torch import nn
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score

import numpy as np

from tqdm import tqdm

import datasets
import math

import gc

from typing import List, Set, Dict, Tuple, Optional

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Prefer a CUDA GPU when one is available, otherwise use the CPU.
# NOTE(review): in the visible cells only the classifier networks and their
# batches are moved to this device; the language models are never `.to(device)`-ed.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Raw splits with the original text columns (e.g. question_text,
# document_plaintext, annotations), nested as
# all_datasets_raw[language]["training" | "validation"].
path_to_training_set   = "data/raw_training_set.pkl"
path_to_validation_set = "data/raw_validation_set.pkl"
all_datasets_raw = load_datasets_by_language(path_to_training_set, path_to_validation_set)

# Preprocessed splits (the "_stanza" file names suggest Stanza annotations —
# TODO confirm); these feed get_lexical_features_from_dataset.
path_to_training_set   = "data/training_set_stanza.pkl"
path_to_validation_set = "data/validation_set_stanza.pkl"
all_datasets = load_datasets_by_language(path_to_training_set, path_to_validation_set)

Beam search: tends to produce repeating patterns.
No repeat n-gram: a remedy for the repeating patterns.
Sample from top k: leads to more variation; however, the sentences make less sense given the context.
Sample from top p (nucleus): again more variation, and the words possibly make more sense.

In [7]:
def transformers_generate_text(prompt: str, tokenizer, model, max_length: int = 50):
    """Print five continuations of `prompt` for each of four decoding strategies.

    Strategies, in print order:
      1. beam search with 5 beams,
      2. beam search with 5 beams and ``no_repeat_ngram_size=2``,
      3. top-k sampling (k=50),
      4. nucleus sampling (top_p=0.80) restricted to the top 50 tokens.

    Args:
        prompt: Text the model should continue.
        tokenizer: Hugging Face tokenizer matching `model`.
        model: Hugging Face causal LM exposing ``generate()`` and ``device``.
        max_length: Maximum total length (prompt + continuation) in tokens.

    Returns:
        None; every sequence is printed as ``"<i>: <S>text<E>"``.
    """
    # (header, strategy-specific generate() kwargs). `early_stopping` is not
    # passed to the sampling strategies: it only affects beam search.
    strategies = [
        ("----- Beam Search -----",
         dict(num_beams=5)),
        ("\n\n----- Beam Search + no repeat ngram=2 -----",
         dict(num_beams=5, no_repeat_ngram_size=2)),
        ("\n\n----- Top 50 words -----",
         dict(do_sample=True, top_k=50)),
        ("\n\n----- Intersection between Top 80% words + Top 50 words -----",
         dict(do_sample=True, top_p=0.80, top_k=50)),
    ]

    # Encode once and move the ids to the model's device so generation also
    # works after the model has been moved to a GPU.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    for header, decoding_kwargs in strategies:
        outputs = model.generate(input_ids, max_length=max_length, num_return_sequences=5, **decoding_kwargs)
        print(header)
        for i, output in enumerate(outputs):
            print("{}: <S>{}<E>".format(i, tokenizer.decode(output, skip_special_tokens=True)))

# English

## A. Pretrained

In [8]:
# Section A: pretrained English causal LM (DistilGPT-2) and its fast tokenizer.
model_checkpoint = "distilgpt2"
tokenizer_en = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
# EOS is reused as the padding id for generation (presumably because the
# tokenizer defines no pad token; later cells also set pad_token = eos_token).
model_en = AutoModelForCausalLM.from_pretrained(model_checkpoint, pad_token_id=tokenizer_en.eos_token_id)

## A. Finetune

In [9]:
def chunkify_transformer_tokens(tokens, chunk_size: int = 128):
    """Concatenate tokenized texts and re-split them into fixed-size chunks.

    Every column of `tokens` (e.g. ``input_ids``, ``attention_mask``) is
    flattened across all texts and cut into consecutive pieces of exactly
    `chunk_size` tokens; the trailing remainder shorter than `chunk_size` is
    dropped.

    Args:
        tokens: Mapping from column name to a list of per-text token lists
            (e.g. the output of a Hugging Face tokenizer on a list of strings).
            Must contain an ``"input_ids"`` column.
        chunk_size: Number of tokens per chunk.

    Returns:
        dict mapping each input column to its list of chunks, plus ``"label"``
        and ``"labels"`` columns holding independent copies of the
        ``input_ids`` chunks (the language-modeling targets).
    """
    from itertools import chain

    # Flatten each column in linear time; `sum(lists, [])` is quadratic.
    concatenated = {k: list(chain.from_iterable(tokens[k])) for k in tokens.keys()}

    # Drop the last chunk if it's smaller than chunk_size
    total_length = len(concatenated["input_ids"])
    total_length = (total_length // chunk_size) * chunk_size

    # Split by chunks of chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated.items()
    }

    # Copy each chunk so the label columns never share list objects with
    # input_ids (a shallow `.copy()` of the outer list would).
    result["label"] = [list(chunk) for chunk in result["input_ids"]]
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]

    return result


In [10]:
def prepare_text_for_finetuning(text: List[str], tokenizer=None, chunk_size: int = 128) -> "datasets.Dataset":
    """Tokenize documents and pack them into fixed-size causal-LM training chunks.

    Args:
        text: List of document strings.
        tokenizer: Hugging Face tokenizer to use. Defaults to the notebook-level
            ``tokenizer_en`` (the previous hard-wired behaviour).
        chunk_size: Tokens per training example (see chunkify_transformer_tokens).

    Returns:
        A ``datasets.Dataset`` with the tokenizer's columns plus ``label`` and
        ``labels``, one row per chunk.
    """
    if tokenizer is None:
        tokenizer = tokenizer_en
    tokens = tokenizer(text)
    token_chunks = chunkify_transformer_tokens(tokens, chunk_size=chunk_size)
    return datasets.Dataset.from_dict(token_chunks)

In [11]:
# Build causal-LM fine-tuning sets from the raw English document texts.
# The "sequence length is longer than the specified maximum" warning printed
# below is expected: documents are re-split into 128-token chunks afterwards.
document = all_datasets_raw["en"]["training"]["document_plaintext"].tolist()
train_set = prepare_text_for_finetuning(document)

document = all_datasets_raw["en"]["validation"]["document_plaintext"].tolist()
eval_set = prepare_text_for_finetuning(document)

Token indices sequence length is longer than the specified maximum sequence length for this model (1173 > 1024). Running this sequence through the model will result in indexing errors


In [None]:
# Fine-tuning configuration: checkpoints go to the first (output_dir) argument,
# evaluation runs at the end of every epoch.
# NOTE(review): the saved output of the next cell reports "Num Epochs = 5",
# so it was produced with a different num_train_epochs than the 4 set here.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; check against the pinned version in requirements.txt.
training_args = TrainingArguments(
    "lm-finetuned/dump/distilgpt2_en",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=4,
)

trainer = Trainer(
    model=model_en,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=eval_set
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [None]:
# Perplexity = exp(eval_loss), measured before and after fine-tuning.
# The saved output shows this run was interrupted (KeyboardInterrupt) during training.
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

trainer.train()

eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

***** Running Evaluation *****
  Num examples = 1032
  Batch size = 8


***** Running training *****
  Num examples = 7586
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 4745


Perplexity: 50.34


Epoch,Training Loss,Validation Loss


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt



In [None]:
# Save only the model; the tokenizer is reloaded from "distilgpt2" when this
# checkpoint is used in the "Finetuned" section.
model_en.save_pretrained("lm-finetuned/distilgpt2-en-finetuned")

## B.

In [None]:
# Section B: reload the pretrained DistilGPT-2, using EOS as the pad token.
tokenizer_checkpoint = "distilgpt2"
model_checkpoint = tokenizer_checkpoint

tokenizer_en = AutoTokenizer.from_pretrained(tokenizer_checkpoint, use_fast=True)
tokenizer_en.pad_token = tokenizer_en.eos_token
model_en = AutoModelForCausalLM.from_pretrained(model_checkpoint, pad_token_id=tokenizer_en.eos_token_id)

In [None]:
# Compare decoding strategies on the pretrained English model.
transformers_generate_text("Grasshoppers are", tokenizer_en, model_en)

## Finetuned

In [None]:
# Load the fine-tuned weights; the tokenizer was not fine-tuned, so it still
# comes from the base "distilgpt2" checkpoint.
tokenizer_checkpoint = "distilgpt2"
model_checkpoint = "lm-finetuned/distilgpt2-en-finetuned"

tokenizer_en = AutoTokenizer.from_pretrained(tokenizer_checkpoint, use_fast=True)
tokenizer_en.pad_token = tokenizer_en.eos_token
model_en = AutoModelForCausalLM.from_pretrained(model_checkpoint, pad_token_id=tokenizer_en.eos_token_id)

In [None]:
# Same prompt on the fine-tuned model, for comparison with the pretrained one.
transformers_generate_text("Grasshoppers are", tokenizer_en, model_en)

## (C.)

## D.

We can use the `Trainer` here to compute the perplexity.
Alternatively, follow https://huggingface.co/docs/transformers/perplexity

In [17]:
def compute_perplexity(lm_model, dataset):
    """Return the perplexity of `lm_model` on `dataset`.

    A throw-away Trainer (no explicit TrainingArguments) is used only for its
    evaluation loop; the perplexity is exp of the reported ``eval_loss``.
    `dataset` must carry the columns the model's forward needs, including
    labels (e.g. the output of prepare_text_for_finetuning).
    """
    evaluator = Trainer(model=lm_model, eval_dataset=dataset)
    mean_eval_loss = evaluator.evaluate()['eval_loss']
    return math.exp(mean_eval_loss)

## E.

In [7]:
# Section E: language model whose hidden states are used as classifier features.
# Swap in the commented checkpoint to embed with the fine-tuned weights instead.
tokenizer_checkpoint = "distilgpt2"
#model_checkpoint = "lm-finetuned/distilgpt2-en-finetuned"
model_checkpoint = "distilgpt2"

tokenizer_en = AutoTokenizer.from_pretrained(tokenizer_checkpoint, use_fast=True)
tokenizer_en.pad_token = tokenizer_en.eos_token
model_en = AutoModelForCausalLM.from_pretrained(model_checkpoint, pad_token_id=tokenizer_en.eos_token_id)

Downloading:   0%|          | 0.00/762 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/353M [00:00<?, ?B/s]

In [8]:
class Binary_QA_Domain_Network(nn.Module):
    """Two-class MLP over a concatenated (question, document) embedding pair.

    Layers: Linear(2*input_dim -> 128) -> ReLU -> Dropout
            -> Linear(128 -> 64) -> ReLU -> Dropout -> Linear(64 -> 2).
    Returns raw logits (no softmax), as expected by nn.CrossEntropyLoss.
    """

    def __init__(self, input_dim, dropout: float = .2):
        """
        Args:
            input_dim: Size of each of the question and document embeddings.
            dropout: Dropout probability applied after both hidden layers.
        """
        super().__init__()

        self.hidden1 = nn.Linear(2*input_dim, 128)
        self.hidden2 = nn.Linear(128, 64)
        self.linear_out = nn.Linear(64, 2)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, question, document, lexical_feature=None):
        """Compute 2-class logits.

        Args:
            question: (..., input_dim) question embedding(s).
            document: (..., input_dim) document embedding(s), same leading shape.
            lexical_feature: Accepted so datasets/callers that provide lexical
                features still work, but currently ignored. It is optional
                because the training/evaluation loops call the model with two
                arguments only.

        Returns:
            (..., 2) logits.
        """
        x = torch.cat((question, document), -1)
        x = self.dropout(self.relu(self.hidden1(x)))
        x = self.dropout(self.relu(self.hidden2(x)))
        return self.linear_out(x)

In [9]:
# Classifier for 768-dimensional embeddings (presumably the language model's
# hidden size) — displaying `net` prints its layers.
net = Binary_QA_Domain_Network(768)
net

Binary_QA_Domain_Network(
  (hidden1): Linear(in_features=1536, out_features=128, bias=True)
  (hidden2): Linear(in_features=128, out_features=64, bias=True)
  (linear_out): Linear(in_features=64, out_features=2, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [10]:
# Smoke test on a single unbatched pair. The lexical-feature argument is passed
# explicitly as None (the network ignores it). The net is in training mode, so
# dropout makes the exact logits vary between runs.
net(torch.zeros(768), torch.ones(768), None)

tensor([ 0.0195, -0.1124], grad_fn=<AddBackward0>)

In [7]:
class IsQuestionAnsweredDatasetLM(Dataset):
    """Map-style dataset over precomputed LM embeddings, lexical features and labels.

    Item ``i`` is ``(question_embedding, document_embedding, lexical, label)``.
    """

    def __init__(self, dataset_question_embeddings: pd.DataFrame, dataset_document_embeddings: pd.DataFrame, dataset_lexical: np.ndarray, labels: np.ndarray):
        """
        Args:
            dataset_question_embeddings: One row of embedding values per example.
            dataset_document_embeddings: One row of embedding values per example.
            dataset_lexical: Lexical features indexed by column, i.e. assumed to
                be laid out as (n_features, n_examples) — TODO confirm against
                get_lexical_features_from_dataset (if it returns a sparse
                matrix, the default DataLoader collate will not accept it).
            labels: One binary label per example.
        """
        self.dataset_question_embeddings = dataset_question_embeddings
        self.dataset_document_embeddings = dataset_document_embeddings
        self.dataset_lexical = dataset_lexical

        self.labels = labels

        # Convert the frames to arrays once: DataFrame.values may build a fresh
        # array on every call, which per item would be wasteful.
        self._question_array = dataset_question_embeddings.to_numpy()
        self._document_array = dataset_document_embeddings.to_numpy()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        question = self._question_array[idx]
        document = self._document_array[idx]
        lexical = self.dataset_lexical[:, idx]

        label = self.labels[idx]

        return question, document, lexical, label

In [8]:
def get_labels_from_raw_dataset(dataset: pd.DataFrame):
    """Return one int32 label per row: 0 if the question is unanswered, else 1.

    A row is unanswered when the first entry of its annotation's
    ``"answer_start"`` list is -1.
    """
    annotations = dataset['annotations']
    answered_flags = (int(entry["answer_start"][0] != -1) for entry in annotations)
    return np.fromiter(answered_flags, dtype=np.int32, count=len(annotations))

In [9]:
def apply_mean_pooling(model_output, attention_mask):
    """Average token states over the sequence axis, skipping padded positions.

    Args:
        model_output: Token states, presumably (batch, seq_len, hidden).
        attention_mask: Matching (batch, seq_len) mask, 1 = real token, 0 = padding.

    Returns:
        (batch, hidden) mean over unmasked tokens; rows with no real tokens
        come out as zeros because the divisor is clamped to a tiny positive value.
    """
    weights = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    weighted_sum = (model_output * weights).sum(dim=1)
    return weighted_sum / token_count

In [10]:
def get_lm_last_hidden_state(model_output, attention_mask):
    """Return the hidden state of the last real (non-padding) token of each row.

    Args:
        model_output: (batch, seq_len, hidden) token states.
        attention_mask: (batch, seq_len) 0/1 mask. Assumes right padding (real
            tokens first, then pads) — TODO confirm for tokenizers configured
            with padding_side="left".

    Returns:
        (batch, hidden) tensor.
    """
    # Per-row token counts. Summing over the whole mask (the previous
    # behaviour) is only correct for a batch of size 1.
    lengths = attention_mask.sum(dim=-1).long().to(model_output.device)
    # Guard fully-masked rows against index -1.
    last_index = (lengths - 1).clamp(min=0)
    rows = torch.arange(model_output.size(0), device=model_output.device)
    return model_output[rows, last_index]

In [11]:
def compute_lm_last_hidden_states(text, lm_tokenizer, lm_model, max_length: int = 1024):
    """Encode `text` and return the final layer's hidden states of `lm_model`.

    Args:
        text: A string or a list of strings (padded to the longest entry; the
            tokenizer must define a pad token whenever padding is needed).
        lm_tokenizer: Hugging Face tokenizer for `lm_model`.
        lm_model: Hugging Face model that returns ``hidden_states`` when called
            with ``output_hidden_states=True``; must expose ``.device``.
        max_length: Truncation length in tokens (1024 by default, the previous
            hard-coded value).

    Returns:
        ``(hidden_states, attention_mask)``: the last layer's hidden states and
        the tokenizer's attention mask, both on the CPU.
    """
    tokens = lm_tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    attention_mask = tokens["attention_mask"]
    model_device = lm_model.device

    with torch.no_grad():
        # Pass the mask so padded positions are ignored, and place the inputs
        # on the model's device so a GPU-resident model also works.
        lm_output = lm_model(
            tokens["input_ids"].to(model_device),
            attention_mask=attention_mask.to(model_device),
            output_hidden_states=True,
        )

    lm_hidden_states_last_decoder = lm_output["hidden_states"][-1].detach().cpu()

    return lm_hidden_states_last_decoder, attention_mask

In [12]:
def preprocess_with_language_model(lm_tokenizer, lm_model, dataset_column):
    """Embed each text in `dataset_column` with `lm_model`, one text at a time.

    Two sentence embeddings are taken from the last decoder layer per text:
    the attention-masked mean over tokens, and the state of the last real token.

    Args:
        lm_tokenizer: Hugging Face tokenizer for `lm_model`.
        lm_model: Hugging Face model (see compute_lm_last_hidden_states).
        dataset_column: pandas Series of texts.

    Returns:
        dict with ``"mean_pooling"`` and ``"last_state"``, each a float32 array
        of shape (len(dataset_column), hidden_size).
    """
    n_obs = dataset_column.shape[0]
    config = lm_model.config
    # GPT-2 style configs name the width `n_embd`; most other Hugging Face
    # configs call it `hidden_size`.
    emb_dims = getattr(config, "n_embd", None) or config.hidden_size

    mean_pooling = np.zeros((n_obs, emb_dims), dtype=np.float32)
    last_state = np.zeros((n_obs, emb_dims), dtype=np.float32)
    for i in tqdm(range(n_obs)):
        element = dataset_column.iloc[i]

        hidden_states, attention_mask = compute_lm_last_hidden_states(element, lm_tokenizer, lm_model)

        # A single text is encoded per call, so the batch axis has length 1.
        mean_pooling[i] = apply_mean_pooling(hidden_states, attention_mask)[0].cpu().numpy()
        last_state[i] = get_lm_last_hidden_state(hidden_states, attention_mask)[0].cpu().numpy()

        gc.collect()

    return {"mean_pooling": mean_pooling, "last_state": last_state}

In [13]:
def preprocess_dataset_with_language_model_then_save(lm_tokenizer, lm_model, dataset: pd.DataFrame, language: str, dataset_type: str):
    """Embed the question and document columns of `dataset` and pickle the results.

    The question column is processed first, then the document column. For each
    column, two files are written with save_dataset:
        data/lm/{language}/{column}_mean_pooling_{dataset_type}.pkl
        data/lm/{language}/{column}_last_state_{dataset_type}.pkl
    A completion message is printed after each column.
    """
    columns_and_messages = (
        ("question_text", "\nCompleted preprocessing question column with language model!"),
        ("document_plaintext", "\nCompleted preprocessing document column with language model!"),
    )

    for column_name, done_message in columns_and_messages:
        embeddings = preprocess_with_language_model(lm_tokenizer, lm_model, dataset[column_name])
        for pooling_name in ("mean_pooling", "last_state"):
            save_dataset(
                pd.DataFrame(embeddings[pooling_name]),
                f"data/lm/{language}/{column_name}_{pooling_name}_{dataset_type}.pkl",
            )
        print(done_message)
    

In [None]:
# Embed the English training split and save the embeddings under data/lm/en/.
preprocess_dataset_with_language_model_then_save(tokenizer_en, model_en, all_datasets_raw["en"]["training"], "en", "training")

In [None]:
# Embed the English validation split and save the embeddings under data/lm/en/.
preprocess_dataset_with_language_model_then_save(tokenizer_en, model_en, all_datasets_raw["en"]["validation"], "en", "validation")

In [18]:
def evaluate_inplace(model: nn.Module, valid_dl: DataLoader, device=None):
    """Return the classification accuracy of `model` over all batches of `valid_dl`.

    `model` is switched to eval mode (and left that way); gradients are disabled.

    Args:
        model: Classifier called as ``model(questions, documents)`` that returns
            (batch, n_classes) logits.
        valid_dl: DataLoader yielding ``(questions, documents, targets)`` or
            ``(questions, documents, lexical, targets)``; the label is always last.
        device: Device for the inputs. Defaults to the device of the model's
            first parameter.

    Returns:
        Accuracy as a float in [0, 1], or nan if `valid_dl` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    predictions_all = []
    targets_all = []

    with torch.no_grad():
        for batch in valid_dl:
            questions, documents, *_, targets = batch

            # Cast like the training loop so float64 inputs match float32 weights.
            questions = questions.float().to(device)
            documents = documents.float().to(device)

            logits = model(questions, documents)

            predictions_all.append(logits.max(dim=1).indices.cpu().numpy())
            targets_all.append(targets.long().cpu().numpy())

    if not targets_all:
        return float("nan")

    predictions = np.concatenate(predictions_all)
    targets = np.concatenate(targets_all)
    return float(np.mean(predictions == targets))

In [19]:
def load_lm_dataset_as_torch_dataset(language_code: str, dataset_partition: str, all_datasets_raw: Dict[str, Dict[str, pd.DataFrame]], all_datasets: Dict[str, Dict[str, pd.DataFrame]], embedding_type: str = "last_state"):
    """Build an IsQuestionAnsweredDatasetLM for one language and partition.

    The dataset combines three inputs:
      - pickled LM embeddings written by
        preprocess_dataset_with_language_model_then_save,
      - binary labels derived from the raw annotations,
      - lexical features computed from the preprocessed split.

    Example:
        load_lm_dataset_as_torch_dataset("en", "validation", all_datasets_raw, all_datasets)

    Args:
        language_code: "en", "fi" or "ja".
        dataset_partition: "training" or "validation".
        all_datasets_raw: Nested dict language -> partition -> raw DataFrame
            (needs "annotations" and "question_text" columns).
        all_datasets: Same nesting, holding the preprocessed splits passed to
            get_lexical_features_from_dataset.
        embedding_type: Which saved embedding variant to load, "last_state"
            (default, the previous behaviour) or "mean_pooling".

    Raises:
        ValueError: If `embedding_type` is not one of the two saved variants.
    """
    if embedding_type not in ("last_state", "mean_pooling"):
        raise ValueError(f"embedding_type must be 'last_state' or 'mean_pooling', got {embedding_type!r}")

    question_data = load_dataset(f"data/lm/{language_code}/question_text_{embedding_type}_{dataset_partition}.pkl")
    document_data = load_dataset(f"data/lm/{language_code}/document_plaintext_{embedding_type}_{dataset_partition}.pkl")
    labels = get_labels_from_raw_dataset(all_datasets_raw[language_code][dataset_partition])

    # The vocabulary is always built from the *training* questions so both
    # partitions share the same token ids.
    # NOTE(review): it is built from the raw question_text, while the lexical
    # features come from the preprocessed split — confirm this is intended.
    token_to_id = create_token_to_id_mapping_from_token_sentences(all_datasets_raw[language_code]["training"]["question_text"])
    lexical_data = get_lexical_features_from_dataset(all_datasets[language_code][dataset_partition], token_to_id)

    return IsQuestionAnsweredDatasetLM(question_data, document_data, lexical_data, labels)

In [20]:
def train(model: nn.Module, lm_tokenizer, lm_model, training_dataset, validation_dataset, n_epochs = 15, batch_size = 16, weight_decay=1e-7, lr: float = 0.01):
    """Train a question/document classifier on precomputed embeddings.

    Each epoch does one shuffled pass of Adam + cross-entropy over
    `training_dataset`, then measures accuracy on `validation_dataset` with
    evaluate_inplace and prints a one-line summary.

    Args:
        model: Classifier called as ``model(questions, documents)`` that returns
            (batch, 2) logits; expected to already be on `device`.
        lm_tokenizer, lm_model: Unused here; kept so the signature mirrors train2.
        training_dataset, validation_dataset: Datasets whose items are
            ``(question, document, label)`` or ``(question, document, lexical, label)``.
        n_epochs: Number of passes over the training data.
        batch_size: Mini-batch size for both loaders.
        weight_decay: Adam weight decay (L2 penalty).
        lr: Adam learning rate (defaults to the previously hard-coded 0.01).

    Returns:
        ``(train_losses, train_accuracies, test_accuracies)``, one entry per epoch.
        Train loss is the mean of the per-batch losses; train accuracy counts
        every training example seen in the epoch.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Load dataset
    train_dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

    # store improvement per epoch
    train_losses = []
    train_accuracies = []
    test_accuracies = []

    for epoch in range(n_epochs):

        ### Training
        model.train()

        loss_epoch = []
        n_correct = 0
        n_seen = 0

        batch_pbar = tqdm(train_dataloader)
        for batch in batch_pbar:
            # The label is always last; lexical features, when the dataset
            # yields them, are not consumed by the model.
            questions, documents, *_, targets = batch

            questions = questions.float().to(device)
            documents = documents.float().to(device)
            targets = targets.long().to(device)

            # training
            outputs = model(questions, documents)
            loss = loss_fn(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # prediction
            predictions = outputs.detach().max(1)[1]
            batch_correct = (predictions == targets).sum().item()
            batch_size_actual = targets.size(0)
            batch_accuracy = batch_correct / batch_size_actual
            n_correct += batch_correct
            n_seen += batch_size_actual

            batch_loss = loss.item()
            loss_epoch.append(batch_loss)

            batch_pbar.set_description(f"epoch={epoch+1}/{n_epochs} | loss={batch_loss:.2f}, accuracy={batch_accuracy:.2f}")

        train_loss = np.mean(loss_epoch)
        train_losses.append(train_loss)

        # Example-weighted, so a smaller final batch does not skew the epoch accuracy.
        train_acc = n_correct / n_seen if n_seen else float("nan")
        train_accuracies.append(train_acc)

        ### Evaluation
        test_acc = evaluate_inplace(model, val_dataloader)
        test_accuracies.append(test_acc)

        print(f"epoch={epoch+1}/{n_epochs} | loss={train_loss:.2f}, train_accuracy={train_acc:.2f}, test_accuracy={test_acc:.2f}")

    print("Finished training.")

    return train_losses, train_accuracies, test_accuracies

In [21]:
# Train the classifier on the saved last-token ("last_state") embeddings of the English splits.
training_dataset = load_lm_dataset_as_torch_dataset("en", "training", all_datasets_raw, all_datasets)
validation_dataset = load_lm_dataset_as_torch_dataset("en", "validation", all_datasets_raw, all_datasets)

net = Binary_QA_Domain_Network(768)
net.to(device)

# weight_decay=0 disables Adam's weight decay for this run.
train(net, tokenizer_en, model_en, training_dataset, validation_dataset, weight_decay=0)

epoch=1/15 | loss=0.58, accuracy=0.85: 100%|██████████| 462/462 [02:09<00:00,  3.57it/s]


epoch=1/15 | loss=0.72, train_accuracy=0.57, test_accuracy=0.67


epoch=2/15 | loss=0.64, accuracy=0.62: 100%|██████████| 462/462 [01:58<00:00,  3.89it/s]


epoch=2/15 | loss=0.64, train_accuracy=0.64, test_accuracy=0.74


epoch=3/15 | loss=0.51, accuracy=0.85: 100%|██████████| 462/462 [01:57<00:00,  3.92it/s]


epoch=3/15 | loss=0.61, train_accuracy=0.68, test_accuracy=0.74


epoch=4/15 | loss=0.42, accuracy=0.92: 100%|██████████| 462/462 [01:57<00:00,  3.94it/s]


epoch=4/15 | loss=0.58, train_accuracy=0.71, test_accuracy=0.73


epoch=5/15 | loss=0.88, accuracy=0.46: 100%|██████████| 462/462 [01:56<00:00,  3.98it/s]


epoch=5/15 | loss=0.59, train_accuracy=0.70, test_accuracy=0.67


epoch=6/15 | loss=0.68, accuracy=0.62: 100%|██████████| 462/462 [01:57<00:00,  3.93it/s]


epoch=6/15 | loss=0.59, train_accuracy=0.70, test_accuracy=0.56


epoch=7/15 | loss=0.51, accuracy=0.69: 100%|██████████| 462/462 [01:57<00:00,  3.95it/s]


epoch=7/15 | loss=0.59, train_accuracy=0.69, test_accuracy=0.68


epoch=8/15 | loss=0.55, accuracy=0.69: 100%|██████████| 462/462 [01:56<00:00,  3.95it/s]


epoch=8/15 | loss=0.59, train_accuracy=0.70, test_accuracy=0.75


epoch=9/15 | loss=0.62, accuracy=0.62: 100%|██████████| 462/462 [01:55<00:00,  3.99it/s]


epoch=9/15 | loss=0.57, train_accuracy=0.72, test_accuracy=0.73


epoch=10/15 | loss=0.60, accuracy=0.62: 100%|██████████| 462/462 [01:56<00:00,  3.96it/s]


epoch=10/15 | loss=0.58, train_accuracy=0.71, test_accuracy=0.75


epoch=11/15 | loss=0.38, accuracy=0.92: 100%|██████████| 462/462 [01:57<00:00,  3.94it/s]


epoch=11/15 | loss=0.59, train_accuracy=0.70, test_accuracy=0.76


epoch=12/15 | loss=0.54, accuracy=0.85: 100%|██████████| 462/462 [01:55<00:00,  4.00it/s]


epoch=12/15 | loss=0.58, train_accuracy=0.70, test_accuracy=0.75


epoch=13/15 | loss=0.64, accuracy=0.77: 100%|██████████| 462/462 [01:56<00:00,  3.98it/s]


epoch=13/15 | loss=0.60, train_accuracy=0.68, test_accuracy=0.73


epoch=14/15 | loss=0.69, accuracy=0.54: 100%|██████████| 462/462 [01:56<00:00,  3.98it/s]


epoch=14/15 | loss=0.58, train_accuracy=0.71, test_accuracy=0.75


epoch=15/15 | loss=0.81, accuracy=0.69: 100%|██████████| 462/462 [01:56<00:00,  3.97it/s]

epoch=15/15 | loss=0.58, train_accuracy=0.70, test_accuracy=0.77
Finished training.





([0.7189220337769686,
  0.6357856839250177,
  0.610446746279667,
  0.5800585482930725,
  0.5887034567661615,
  0.5869458607393941,
  0.5933819540264287,
  0.5880106286311046,
  0.5705619444707771,
  0.5805476021457028,
  0.585503682300642,
  0.5795421604212229,
  0.5966437379003087,
  0.5776350487878312,
  0.5822583298275481],
 [0.5687957875457875,
  0.639183732933733,
  0.6755328005328004,
  0.7085726773226773,
  0.700809607059607,
  0.696407758907759,
  0.6914335664335665,
  0.7022560772560772,
  0.7160235597735598,
  0.7114239926739927,
  0.7031614219114218,
  0.7021832334332334,
  0.6842948717948718,
  0.7086871461871462,
  0.704285298035298],
 [0.6747474747474748,
  0.7383838383838384,
  0.7404040404040404,
  0.7343434343434343,
  0.6686868686868687,
  0.5616161616161616,
  0.6797979797979798,
  0.7484848484848485,
  0.7292929292929293,
  0.7474747474747475,
  0.7626262626262627,
  0.7525252525252525,
  0.7272727272727273,
  0.7545454545454545,
  0.7676767676767676])

## Experimental

In [None]:
class IsQuestionAnsweredDataset2(Dataset):
    """Map-style dataset of raw ``(question, document, label)`` triples.

    Text is returned unprocessed so the caller can embed it on the fly (see train2).
    NOTE(review): question and document are read by position (columns 0 and 4
    of `dataset`); confirm this matches question_text / document_plaintext in
    the raw DataFrame.
    """

    def __init__(self, dataset: pd.DataFrame):
        """`dataset` must also carry the "annotations" column used for labels."""
        self.dataset = dataset
        self.labels = get_labels_from_raw_dataset(dataset)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Read the two cells directly; `dataset.values[idx]` built an array of
        # the whole frame on every access, making each epoch quadratic.
        question = self.dataset.iloc[idx, 0]
        document = self.dataset.iloc[idx, 4]

        label = self.labels[idx]

        return question, document, label

In [None]:
def train2(model: nn.Module, lm_tokenizer, lm_model, train_data, validation_data, n_epochs = 5, batch_size = 16, weight_decay=1e-6):
    """Experimental: train `model` on LM embeddings computed on the fly per batch.

    Unlike `train`, which reads precomputed embeddings, each batch of raw
    question/document text is embedded with `lm_model` before being fed to
    `model`. The embedding is the mean pooling over the last decoder layer.
    `lm_model` itself is not updated: its forward runs under torch.no_grad and
    only `model`'s parameters are given to the optimizer.

    Args:
        model: Classifier called as ``model(questions, documents)``; expected to
            already be on `device`.
        lm_tokenizer: Tokenizer for `lm_model`; needs a pad token because
            batches of texts are padded.
        lm_model: Hugging Face language model used for the embeddings.
        train_data, validation_data: Raw DataFrames accepted by
            IsQuestionAnsweredDataset2.
        n_epochs, batch_size, weight_decay: Training hyper-parameters.

    Returns:
        ``(train_losses, train_accuracies, test_accuracies)``, one entry per epoch.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=weight_decay)

    # Load dataset
    train_dataloader = DataLoader(IsQuestionAnsweredDataset2(train_data), batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(IsQuestionAnsweredDataset2(validation_data), batch_size=batch_size, shuffle=False)

    # store improvement per epoch
    train_losses = []
    train_accuracies = []
    test_accuracies = []

    for epoch in range(n_epochs):

        ### Training
        model.train()

        loss_epoch = []
        accuracy_epoch = []

        batch_pbar = tqdm(train_dataloader)
        for questions, documents, targets in batch_pbar:
            questions_input, documents_input = _train2_embed_batch(questions, documents, lm_tokenizer, lm_model)
            targets = targets.long().to(device)

            # training
            outputs = model(questions_input, documents_input)
            loss = loss_fn(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # prediction
            predictions = outputs.detach().cpu().max(1)[1]
            accuracy = (predictions == targets.cpu()).float().mean().item()

            loss_epoch.append(loss.item())
            accuracy_epoch.append(accuracy)

            batch_pbar.set_description(f"epoch={epoch+1}/{n_epochs} | loss={loss.item():.2f}, accuracy={accuracy:.2f}")

        train_loss = np.mean(loss_epoch)
        train_losses.append(train_loss)

        train_acc = np.mean(accuracy_epoch)
        train_accuracies.append(train_acc)

        ### Evaluation (previously disabled and hard-coded to 0)
        test_acc = _train2_evaluate(model, val_dataloader, lm_tokenizer, lm_model)
        test_accuracies.append(test_acc)

        print(f"epoch={epoch+1}/{n_epochs} | loss={train_loss:.2f}, train_accuracy={train_acc:.2f}, test_accuracy={test_acc:.2f}")

    print("Finished training.")

    return train_losses, train_accuracies, test_accuracies


def _train2_embed_batch(questions, documents, lm_tokenizer, lm_model):
    """Mean-pooled LM embeddings for a batch of question and document strings, on `device`."""
    q_states, q_mask = compute_lm_last_hidden_states(list(questions), lm_tokenizer, lm_model)
    d_states, d_mask = compute_lm_last_hidden_states(list(documents), lm_tokenizer, lm_model)
    return (apply_mean_pooling(q_states, q_mask).to(device),
            apply_mean_pooling(d_states, d_mask).to(device))


def _train2_evaluate(model, val_dataloader, lm_tokenizer, lm_model):
    """Accuracy of `model` on `val_dataloader`, embedding each batch on the fly (nan if empty)."""
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for questions, documents, targets in val_dataloader:
            questions_input, documents_input = _train2_embed_batch(questions, documents, lm_tokenizer, lm_model)
            predictions = model(questions_input, documents_input).max(1)[1].cpu()
            n_correct += (predictions == targets.long()).sum().item()
            n_seen += len(targets)
    return n_correct / n_seen if n_seen else float("nan")

In [None]:
net = Binary_QA_Domain_Network(768)
net.to(device)

# The raw English splits (which carry the "annotations" column needed for the
# labels) live in all_datasets_raw under "training"/"validation". The previous
# call used the undefined name `dataset` with non-existent keys.
train2(net, tokenizer_en, model_en, all_datasets_raw["en"]["training"], all_datasets_raw["en"]["validation"])

# Finnish

## A.

In [15]:
# Finnish GPT-2 checkpoint; its EOS id is used as the padding id for generation.
tokenizer_fi = AutoTokenizer.from_pretrained("Finnish-NLP/gpt2-finnish")
model_fi = AutoModelForCausalLM.from_pretrained("Finnish-NLP/gpt2-finnish", pad_token_id=tokenizer_fi.eos_token_id)

## B.

In [None]:
# Decoding-strategy comparison for Finnish (prompt: "whales are").
transformers_generate_text("valaat ovat", tokenizer_fi, model_fi)

----- Beam Search -----
0: <S>valaat ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta,<E>
1: <S>valaat ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta.<E>
2: <S>valaat ovat olleet käytössä jo yli kymmenen vuoden ajan, ja ne ovat olleet käytössä jo yli kymmenen vuoden ajan, ja ne ovat olleet käytössä jo yli kymmenen vuoden ajan, ja ne ovat olleet käytössä jo yli kymmenen vuoden ajan, ja ne ovat olleet käytössä jo<E>
3: <S>valaat ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käytössä jo yli kymmenen vuotta, ja ne ovat olleet käyt

## D.

In [16]:
# Embed the Finnish training split and save the embeddings under data/lm/fi/.
preprocess_dataset_with_language_model_then_save(tokenizer_fi, model_fi, all_datasets_raw["fi"]["training"], "fi", "training")

100%|██████████████████████████████████████████████████████████████████████████| 13701/13701 [1:08:36<00:00,  3.33it/s]



Completed preprocessing question column with language model!


100%|██████████████████████████████████████████████████████████████████████████| 13701/13701 [1:57:05<00:00,  1.95it/s]



Completed preprocessing document column with language model!


In [29]:
# Embed the Finnish validation split and save the embeddings under data/lm/fi/.
preprocess_dataset_with_language_model_then_save(tokenizer_fi, model_fi, all_datasets_raw["fi"]["validation"], "fi", "validation")

100%|██████████████████████████████████████████████████████████████████████████████| 1686/1686 [07:48<00:00,  3.60it/s]



Completed preprocessing question column with language model!


100%|██████████████████████████████████████████████████████████████████████████████| 1686/1686 [15:12<00:00,  1.85it/s]


Completed preprocessing document column with language model!





# Japanese

## A.

In [13]:
# Japanese GPT-2 (rinna, medium); its EOS id is used as the padding id for generation.
tokenizer_ja = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium")
model_ja = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium", pad_token_id=tokenizer_ja.eos_token_id)

## B.

In [None]:
# Decoding-strategy comparison for Japanese (prompt: "whales ...").
transformers_generate_text("クジラは", tokenizer_ja, model_ja)

## D.

In [26]:
# Embed the Japanese training split and save the embeddings under data/lm/ja/.
preprocess_dataset_with_language_model_then_save(tokenizer_ja, model_ja, all_datasets_raw["ja"]["training"], "ja", "training")

100%|██████████████████████████████████████████████████████████████████████████████| 8778/8778 [53:48<00:00,  2.72it/s]



Completed preprocessing question column with language model!


100%|████████████████████████████████████████████████████████████████████████████| 8778/8778 [4:25:20<00:00,  1.81s/it]



Completed preprocessing document column with language model!


In [25]:
# Embed the Japanese validation split and save the embeddings under data/lm/ja/.
preprocess_dataset_with_language_model_then_save(tokenizer_ja, model_ja, all_datasets_raw["ja"]["validation"], "ja", "validation")

100%|██████████████████████████████████████████████████████████████████████████████| 1036/1036 [06:33<00:00,  2.63it/s]



Completed preprocessing question column with language model!


100%|██████████████████████████████████████████████████████████████████████████████| 1036/1036 [17:46<00:00,  1.03s/it]


Completed preprocessing document column with language model!



