In [1]:
!pip install transformers==4.12.5
!pip install pandas==1.3.4
!pip install tqdm==4.62.3
!pip install nltk==3.6.5



In [2]:
import pandas as pd
from transformers import BertTokenizerFast, GPT2TokenizerFast, BertModel, GPTNeoForCausalLM
from joblib import Parallel, delayed
import numpy as np
from itertools import chain
from tqdm import tqdm
import nltk

In [3]:
# Sentence-splitting models used by nltk.sent_tokenize further down.
nltk.download(info_or_id="punkt")

[nltk_data] Downloading package punkt to /home/jupyter/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
# Hugging Face checkpoints: a small BERT encoder for retrieval, GPT-Neo for generation.
GENERATOR_GPTNEO_MODEL = "EleutherAI/gpt-neo-1.3B"
RETRIEVER_BERT_MODEL = "huawei-noah/TinyBERT_General_4L_312D"

In [5]:
retriever_tokenizer = BertTokenizerFast.from_pretrained(RETRIEVER_BERT_MODEL)
generator_tokenizer = GPT2TokenizerFast.from_pretrained(GENERATOR_GPTNEO_MODEL)

# Marker tokens used as prefixes for the different kinds of text each model sees.
retriever_tokenizer.add_tokens(["[STORY]", "[EXTRA]", "[RETRIEVE]"])
generator_tokenizer.add_tokens(["[TAGS]", "[INIT]", "[PROMPT]", "[TEXT]", "[INPUT]", "[OUTPUT]"])

6

In [6]:
# Pretrained weights. The TinyBERT checkpoint also stores weights (cls.*, fit_denses.*)
# that BertModel does not use, hence the "weights were not used" warning.
retriever, generator = (
    BertModel.from_pretrained(RETRIEVER_BERT_MODEL),
    GPTNeoForCausalLM.from_pretrained(GENERATOR_GPTNEO_MODEL),
)

Some weights of the model checkpoint at huawei-noah/TinyBERT_General_4L_312D were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'fit_denses.1.bias', 'cls.predictions.transform.LayerNorm.weight', 'fit_denses.3.bias', 'cls.predictions.bias', 'fit_denses.2.weight', 'cls.seq_relationship.bias', 'fit_denses.1.weight', 'cls.seq_relationship.weight', 'fit_denses.4.weight', 'fit_denses.4.bias', 'fit_denses.3.weight', 'fit_denses.2.bias', 'fit_denses.0.weight', 'cls.predictions.transform.dense.weight', 'fit_denses.0.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing

In [7]:
def sync_tokens(tokenizer, model):
    """Pad the vocabulary to a multiple of 8 and resize the model's embeddings to match.

    When ``len(tokenizer)`` is not a multiple of 8, filler tokens ``[DUMB0]``, ``[DUMB1]``, ...
    are appended until it is. The model's token embeddings are then resized to the
    (possibly new) tokenizer length.

    Returns:
        The same ``(tokenizer, model)`` objects, modified in place.
    """
    remainder = len(tokenizer) % 8
    if remainder:
        filler_tokens = [f"[DUMB{k}]" for k in range(8 - remainder)]
        tokenizer.add_tokens(filler_tokens)
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model

In [8]:
# Pad both vocabularies (now including the marker tokens) and resize embeddings to match.
retriever_tokenizer, retriever = sync_tokens(tokenizer=retriever_tokenizer, model=retriever)
generator_tokenizer, generator = sync_tokens(tokenizer=generator_tokenizer, model=generator)

In [9]:
def parallel_apply(df, func, chunk_size, process_count):
    """Apply ``func`` to consecutive row chunks of ``df`` in parallel and flatten the results.

    ``df`` is cut into ``ceil(len(df) / chunk_size)`` positional slices. Each slice is passed
    to ``func`` through joblib (``process_count`` is joblib's ``n_jobs``, so -1 uses all
    cores). ``func`` must return an iterable per chunk; the per-chunk results are
    concatenated in chunk order into one list.
    """
    chunk_count = int(np.ceil(len(df) / chunk_size))
    # The progress bar advances as tasks are handed to joblib, not as they finish.
    tasks = (
        delayed(func)(df.iloc[start : start + chunk_size])
        for start in tqdm(range(0, chunk_count * chunk_size, chunk_size))
    )
    per_chunk_results = Parallel(n_jobs=process_count)(tasks)
    return list(chain.from_iterable(per_chunk_results))

In [10]:
def encode_input_ids(input_ids):
    """Pack a sequence of token ids into bytes as native-endian int32 values."""
    as_int32 = np.array(input_ids, dtype=np.int32)
    return as_int32.tobytes()

def decode_input_ids(buffer):
    """Inverse of encode_input_ids: view the bytes as an int32 array (read-only for ``bytes``)."""
    return np.frombuffer(buffer, np.int32)

def apply_tokenizer(texts, tokenizer):
    """Tokenize ``texts`` in one call and return each row's ids packed by encode_input_ids."""
    encoded = tokenizer(list(texts))
    return list(map(encode_input_ids, encoded["input_ids"]))

# Reading data

In [11]:
# Raw inputs; paths are relative to the notebook's working directory.
(
    df_sentences_guttenberg,
    df_sentences_story,
    df_sentences_context_mapping,
    df_story_content,
    df_stories_train,
    df_stories_test,
) = [
    pd.read_csv(path)
    for path in (
        "data/guttenberg-sentences-sampled.csv",
        "data/cleaned/story-sentences.csv",
        "data/cleaned/story-context-sentence-mapping-numeric-id.csv",
        "data/cleaned/story-trees-numeric-id.csv",
        "data/cleaned/stories-train.csv",
        "data/cleaned/stories-test.csv",
    )
]

In [12]:
# Number of distinct children of each node, keyed by (story_id, parent_id).
id2children_count = (
    df_story_content
    .groupby(["story_id", "parent_id"])["id"]
    .nunique()
    .to_dict()
)
# Look each node up with its own id in the parent position; nodes without children get 0.
df_story_content["children_count"] = [
    id2children_count.get(key, 0)
    for key in zip(df_story_content["story_id"], df_story_content["id"])
]
df_story_content.head()

Unnamed: 0,id,parent_id,input,output,story_id,children_count
0,0,-1,[ROOT],"The land of Kronnland is a mythical, wonderful...",12487,3
1,1,0,Start Danny's Campaign,Danny Blaze\nBackground :\nBorn in the summer ...,12487,1
2,2,1,Continue,With all the townsfolk transformed into mindle...,12487,2
3,3,2,Get back to Bren and warn him about the danger.,You run down the hill as Andrew's army regroup...,12487,0
4,4,2,Watch the battle from your hideout.,"Although worried, you stay in your hideout and...",12487,2


In [13]:
# Encode each Gutenberg sentence for both models: "[EXTRA]"-prefixed for the retriever,
# "[PROMPT]"-prefixed for the generator. Chunks of 9192 rows, all cores (-1).
for _column, _prefix, _tokenizer in (
    ("retriever_input_ids", "[EXTRA] ", retriever_tokenizer),
    ("generator_input_ids", "[PROMPT] ", generator_tokenizer),
):
    df_sentences_guttenberg[_column] = parallel_apply(
        df_sentences_guttenberg,
        # Defaults bind this iteration's prefix/tokenizer into the function sent to workers.
        lambda df, prefix=_prefix, tokenizer=_tokenizer: apply_tokenizer(
            prefix + df["text"].fillna(""), tokenizer
        ),
        9192,
        -1,
    )

100%|██████████| 184/184 [00:28<00:00,  6.55it/s]
 59%|█████▊    | 108/184 [00:16<00:13,  5.45it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2103 > 2048). Running this sequence through the model will result in indexing errors
100%|██████████| 184/184 [00:30<00:00,  6.05it/s]


In [14]:
# Encode each story sentence for both models: "[STORY]"-prefixed for the retriever,
# "[PROMPT]"-prefixed for the generator.
for _column, _prefix, _tokenizer in (
    ("retriever_input_ids", "[STORY] ", retriever_tokenizer),
    ("generator_input_ids", "[PROMPT] ", generator_tokenizer),
):
    df_sentences_story[_column] = parallel_apply(
        df_sentences_story,
        # Defaults bind this iteration's prefix/tokenizer into the function sent to workers.
        lambda df, prefix=_prefix, tokenizer=_tokenizer: apply_tokenizer(
            prefix + df["text"].fillna(""), tokenizer
        ),
        9192,
        -1,
    )

100%|██████████| 68/68 [00:04<00:00, 15.67it/s]
100%|██████████| 68/68 [00:05<00:00, 12.50it/s]


In [15]:
# Encode each node's `input` text: "[RETRIEVE]"-prefixed for the retriever,
# "[INPUT]"-prefixed for the generator.
for _column, _prefix, _tokenizer in (
    ("input_retriever_input_ids", "[RETRIEVE] ", retriever_tokenizer),
    ("input_generator_input_ids", "[INPUT] ", generator_tokenizer),
):
    df_story_content[_column] = parallel_apply(
        df_story_content,
        # Defaults bind this iteration's prefix/tokenizer into the function sent to workers.
        lambda df, prefix=_prefix, tokenizer=_tokenizer: apply_tokenizer(
            prefix + df["input"].fillna(""), tokenizer
        ),
        9192,
        -1,
    )

100%|██████████| 7/7 [00:00<00:00, 5315.97it/s]
100%|██████████| 7/7 [00:00<00:00, 5924.16it/s]


In [16]:
# Encode each node's `output` text: "[OUTPUT]"-prefixed for the generator,
# "[RETRIEVE]"-prefixed for the retriever.
for _column, _prefix, _tokenizer in (
    ("output_generator_input_ids", "[OUTPUT] ", generator_tokenizer),
    ("output_retriever_input_ids", "[RETRIEVE] ", retriever_tokenizer),
):
    df_story_content[_column] = parallel_apply(
        df_story_content,
        # Defaults bind this iteration's prefix/tokenizer into the function sent to workers.
        lambda df, prefix=_prefix, tokenizer=_tokenizer: apply_tokenizer(
            prefix + df["output"].fillna(""), tokenizer
        ),
        9192,
        -1,
    )

100%|██████████| 7/7 [00:00<00:00, 6036.21it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (2157 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2219 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2105 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2362 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2056 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is

In [17]:
# Keep sentences whose encodings fit in 256 tokens for BOTH models
# (ids are stored as 4-byte int32, so byte length // 4 is the token count).
guttenberg_fits = (
    (df_sentences_guttenberg["retriever_input_ids"].apply(len) // 4 <= 256)
    & (df_sentences_guttenberg["generator_input_ids"].apply(len) // 4 <= 256)
)
df_sentences_guttenberg = df_sentences_guttenberg.loc[guttenberg_fits].reset_index(drop=True)
df_sentences_guttenberg.head()

Unnamed: 0,cluster,text,retriever_input_ids,generator_input_ids
0,0,"""Gobryas is there?""",b'e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...,b'S\xc4\x00\x00n\x01\x00\x00&\x00\x00\x00\xa0\...
1,0,His name's Gonzago.,b'e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...,b'S\xc4\x00\x00_\t\x00\x00\x9e\x05\x00\x00R\x0...
2,0,"Goneril, gonəril.",b'e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...,b'S\xc4\x00\x00\x92\x01\x00\x00\x9b8\x00\x00Z\...
3,0,"In discussing the character of Hlestakov, the ...","b""e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...",b'S\xc4\x00\x00*\x02\x00\x00\x86+\x00\x00\x06\...
4,0,Gomalco Productions.,b'e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...,b'S\xc4\x00\x00\x92\x01\x00\x00\x80F\x00\x001\...


In [18]:
# Keep sentences whose encodings fit in 256 tokens for BOTH models
# (ids are stored as 4-byte int32, so byte length // 4 is the token count).
story_fits = (
    (df_sentences_story["retriever_input_ids"].apply(len) // 4 <= 256)
    & (df_sentences_story["generator_input_ids"].apply(len) // 4 <= 256)
)
df_sentences_story = df_sentences_story.loc[story_fits].reset_index(drop=True)
df_sentences_story.head()

Unnamed: 0,id,text,retriever_input_ids,generator_input_ids
0,0,"""Sorry, Soren.""",b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00n\x01\x00\x0018\x00\x00\x0b\x00...
1,3421,Are they alive?,b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00\x87\x10\x00\x00\xe4\x01\x00\x0...
2,3420,What DO you do?,b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00K\x07\x00\x00\xda \x00\x00Y\x01...
3,3419,"""This is yours.",b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00n\x01\x00\x00\xbc\x04\x00\x00>\...
4,3418,Leave the halls,b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00&D\x00\x00\x06\x01\x00\x00\x1e_...


# Training

In [19]:
import torch
from itertools import chain
from sklearn.neighbors import NearestNeighbors
import torch.nn.functional as F
from torch.optim import Adam

In [20]:
EPS = 1e-4  # added to normalisation denominators (embedding norm, loss weight sum)
LR = 1e-4  # Adam learning rate

# Retrieval settings
RETRIEVER_INPUT_MAX_LENGTH = 256  # retriever inputs are padded/truncated to at most this many tokens
RETRIEVER_LAST_OUTPUT_SENTENCES = 3  # trailing sentences of the previous output used as extra queries
KNN_N_NEIGHBOURS = 4  # neighbours returned per query by the kNN indices
MAX_RELEVANT_TOKENS = 512  # at most MAX_RELEVANT_TOKENS - 1 retrieved tokens are prepended

In [21]:
def get_batch_embeddings(retriever, retriever_tokenizer, input_ids):
    """Embed a batch of token-id sequences with the retriever.

    Sequences are right-padded with ``retriever_tokenizer.pad_token_id`` to the longest
    length in the batch, rounded up to a multiple of 8 and capped at
    ``RETRIEVER_INPUT_MAX_LENGTH``. The attention mask excludes pad positions.

    Args:
        retriever: encoder called with ``input_ids``/``attention_mask``. Its output must
            expose ``last_hidden_state``, and the model must expose ``device``.
        retriever_tokenizer: provides ``pad_token_id``.
        input_ids: non-empty sequence of token-id sequences (lists or 1-D arrays).

    Returns:
        Tensor ``(len(input_ids), hidden_size)``: each sequence's first-position hidden
        state (the [CLS] slot for BERT-style inputs) divided by ``its L2 norm + EPS``.
        Gradients are tracked unless the caller disables them.
    """
    pad_id = retriever_tokenizer.pad_token_id

    max_length = max(len(item) for item in input_ids)
    if max_length % 8 != 0:
        max_length += 8 - max_length % 8
    max_length = min(max_length, RETRIEVER_INPUT_MAX_LENGTH)

    def _prepare_ids(ids):
        ids = list(ids)
        if len(ids) <= max_length:
            return ids + [pad_id] * (max_length - len(ids))
        # Too long: keep the first max_length - 1 tokens plus the last one, so the closing
        # special token (presumably the tokenizer's [SEP]) survives truncation.
        # A plain ids[:max_length] silently dropped it.
        return ids[: max_length - 1] + ids[-1:]

    padded_input_ids = torch.LongTensor([_prepare_ids(item) for item in input_ids])
    attention_mask = padded_input_ids != pad_id
    # Only the last layer is needed; requesting every hidden state only kept extra tensors alive.
    hidden_state = retriever(
        input_ids=padded_input_ids.to(retriever.device),
        attention_mask=attention_mask.to(retriever.device),
    ).last_hidden_state
    cls_embedding = hidden_state[:, 0, :]
    cls_embedding_norm = torch.sqrt((cls_embedding ** 2).sum(dim=-1, keepdim=True)) + EPS
    return cls_embedding / cls_embedding_norm

In [22]:
def get_embeddings(retriever, retriever_tokenizer, inputs, batch_size, verbose=False):
    """Embed byte-encoded retriever inputs in batches and return them in input order.

    Args:
        retriever: encoder passed to get_batch_embeddings. Its ``config.hidden_size`` sets
            the embedding width.
        retriever_tokenizer: tokenizer passed to get_batch_embeddings.
        inputs: iterable of int32 token-id byte strings (as produced by encode_input_ids).
        batch_size: number of sequences per forward pass.
        verbose: show a tqdm progress bar over batches.

    Returns:
        float16 array of shape ``(len(inputs), hidden_size)``; row ``k`` embeds ``inputs[k]``.
    """
    encoded = list(inputs)
    # Longest inputs first, so each batch holds similar lengths and pads less.
    by_length = pd.DataFrame({
        "index": range(len(encoded)),
        "inputs": encoded,
        "length": [len(buffer) for buffer in encoded],
    }).sort_values("length", ascending=False)

    result = np.zeros([len(encoded), retriever.config.hidden_size], dtype=np.float16)
    batch_count = int(np.ceil(len(encoded) / batch_size))
    batch_starts = range(0, batch_count * batch_size, batch_size)

    with torch.no_grad():
        for start in (tqdm(batch_starts) if verbose else batch_starts):
            batch = by_length.iloc[start : start + batch_size]
            batch_ids = [decode_input_ids(buffer) for buffer in batch["inputs"]]
            vectors = get_batch_embeddings(retriever, retriever_tokenizer, batch_ids)
            # Scatter rows back to their original positions.
            result[batch["index"].tolist()] = vectors.detach().cpu().numpy()

    return result

In [23]:
def get_cached_prompt_embeddings(retriever, generator, retriever_tokenizer, df_sentences_guttenberg, df_sentences_story):
    """Re-embed every retrieval candidate with the current retriever.

    Puts ``retriever`` in eval mode and moves it to CUDA. Switching it back to train mode
    is left to the caller. ``generator`` is accepted but not used.

    Returns:
        ``(extra_embeddings, story_embeddings)`` from get_embeddings, row-aligned with
        ``df_sentences_guttenberg`` and ``df_sentences_story`` respectively.
    """
    print("UPDATING EMBEDDINGS CACHE")
    retriever.eval()
    retriever.cuda()
    return tuple(
        get_embeddings(retriever, retriever_tokenizer, frame["retriever_input_ids"], 64, verbose=True)
        for frame in (df_sentences_guttenberg, df_sentences_story)
    )

In [24]:
from dataclasses import dataclass
from typing import List


@dataclass
class RetrieverInput:
    """One retrieval query and the story sentences it may retrieve."""

    # Token ids decoded from a stored retriever encoding (int32 numpy array).
    input_ids: np.ndarray
    # Sorted sentence ids (numpy array) that are candidates for this query.
    story_sentences: np.ndarray


@dataclass
class StoryInputSample:
    """Generator training sample built by extract_story_inputs."""

    # Concatenated token ids of the root-to-node path. A Python list (elements may be
    # numpy integer scalars), not an ndarray as previously annotated.
    generator_input_ids: List[int]
    # Per-token loss weights, same length as generator_input_ids (a Python list).
    generator_input_weights: List[float]
    retriever_inputs: List[RetrieverInput]


def get_rows(df_story_content, id, story_id):
    """Return the nodes on the path from the story root down to node ``id``, root first.

    Follows ``parent_id`` links among the rows whose ``story_id`` matches, stopping at
    ``parent_id == -1``. If a referenced node is missing, or the links loop back on
    themselves, the walk stops and the partial path collected so far is returned. An
    unknown ``id`` therefore gives ``[]``.

    Each element is a row Series whose ``name`` is the node id.
    """
    story_nodes = df_story_content.loc[df_story_content["story_id"] == story_id].set_index("id")
    rows = []
    visited = set()
    # Explicit membership test instead of a bare ``except:``, which also swallowed
    # unrelated errors (including KeyboardInterrupt).
    while id != -1 and id in story_nodes.index and id not in visited:
        visited.add(id)
        # ``[[id]]`` always yields a frame, even for duplicated ids; the first match is used.
        row = story_nodes.loc[[id]].iloc[0]
        rows.append(row)
        id = row["parent_id"]
    return rows[::-1]


def extract_story_inputs(rows, df_sentences_context_mapping):
    """Turn a root-to-node path into a generator sample plus retriever queries.

    Args:
        rows: list of node rows, root first (e.g. from get_rows). Each row must provide
            ``parent_id``, ``story_id``, ``children_count`` and the encoded
            ``input_generator_input_ids`` / ``output_generator_input_ids`` /
            ``input_retriever_input_ids`` columns.
        df_sentences_context_mapping: frame with ``story_id``, ``context_id`` and
            ``sentence_id`` columns.

    Returns:
        StoryInputSample with these fields:

        - ``generator_input_ids``: for each node on the path, its input tokens (skipped for
          ``parent_id == -1``) followed by its output tokens, all concatenated.
        - ``generator_input_weights``: every token of node ``j`` gets weight
          ``prod(1 / c_m)`` over nodes ``m`` from ``j`` to the end of the path, where
          ``c_m`` is ``children_count`` and 0 counts as 1.
        - ``retriever_inputs``: one RetrieverInput for each of the last two nodes whose
          ``parent_id`` is not -1.

        An empty path yields an empty sample (previously an unused ``rows[-1]`` lookup
        raised IndexError).
    """
    def _get_generator_inputs(row):
        # Node input tokens (absent for a node whose parent_id is -1) followed by its output tokens.
        if row["parent_id"] != -1:
            row_input = list(decode_input_ids(row["input_generator_input_ids"]))
        else:
            row_input = []
        row_output = list(decode_input_ids(row["output_generator_input_ids"]))
        return row_input + row_output

    def _get_previous_sentences(row):
        # Sorted sentence ids mapped to this node's parent context within the same story.
        mask = (df_sentences_context_mapping["story_id"] == row["story_id"]) & \
               (df_sentences_context_mapping["context_id"] == row["parent_id"])
        return np.array(sorted(df_sentences_context_mapping.loc[mask, "sentence_id"]))

    def _get_retriever_input_ids(rows):
        input_sentences_pairs = []
        for row in rows[-2:]:
            if row["parent_id"] != -1:
                input_ids = decode_input_ids(row["input_retriever_input_ids"])
                story_sentence_ids = _get_previous_sentences(row)
                input_sentences_pairs.append(RetrieverInput(input_ids, story_sentence_ids))
        return input_sentences_pairs

    generator_input_ids = []
    generator_children_counts = []
    for row in rows:
        row_content = _get_generator_inputs(row)
        # Nodes with no children count as 1 so they do not zero out the weight product.
        if row["children_count"] == 0:
            generator_children_counts.append((len(row_content), 1))
        else:
            generator_children_counts.append((len(row_content), row["children_count"]))
        generator_input_ids += row_content

    # Walk from the last node back to the root, accumulating the product of 1 / children_count.
    generator_weights = []
    k = 1.0
    for token_count, children_count in generator_children_counts[::-1]:
        k *= (1 / children_count)
        generator_weights += [k] * token_count
    generator_weights = generator_weights[::-1]

    retriever_inputs = _get_retriever_input_ids(rows)
    return StoryInputSample(generator_input_ids, generator_weights, retriever_inputs)

In [25]:
def story_description_encode(df_stories, story_id, generator_tokenizer):
    """Encode a story's tags as ``"[TAGS] <tags>"`` with the generator tokenizer.

    Missing (NaN) tags encode as ``"[TAGS] "`` alone, the same way get_nn_input builds
    its tag prefix; previously the string concatenation raised TypeError.

    Raises:
        IndexError: if ``story_id`` does not occur in ``df_stories["id"]``.
    """
    tags = df_stories.loc[df_stories["id"] == story_id, "tags"].values[0]
    if pd.isna(tags):
        tags = ""
    return generator_tokenizer.encode("[TAGS] " + tags)

In [26]:
def story_input(df_story_content, row_id, story_id, df_sentences_context_mapping, retriever_tokenizer):
    """Build the generator sample and the retriever queries for node ``row_id``.

    Returns:
        ``(sample, retriever_requests)``. ``sample`` comes from extract_story_inputs.
        ``retriever_requests`` is a list of ``(query_token_ids, candidate_sentence_ids)``
        pairs:

        - the node's stored retriever input encoding;
        - for paths longer than one node, one query per trailing sentence (at most
          ``RETRIEVER_LAST_OUTPUT_SENTENCES``) of the parent's ``output``.

        All queries share one candidate list: the sentence ids that the mapping assigns
        to ``parent_id`` (see below) in this story. A missing (NaN) parent output yields
        no sentence queries instead of crashing ``nltk.sent_tokenize``.

    Raises:
        IndexError: if get_rows finds no node (empty path).
    """
    rows = get_rows(df_story_content, row_id, story_id)
    # Context whose sentences may be retrieved: the node's own parent_id for a one-node
    # path, otherwise the parent's parent_id.
    # NOTE(review): presumably the grandparent context is used so the parent's own output
    # sentences (used as queries below) are not retrieved verbatim; confirm.
    if len(rows) == 1:
        parent_id = rows[-1]["parent_id"]
    else:
        parent_id = rows[-2]["parent_id"]

    # Computed once and shared by every query (it used to be filtered twice).
    sentences_to_search = df_sentences_context_mapping.loc[
        (df_sentences_context_mapping["story_id"] == story_id) & \
        (df_sentences_context_mapping["context_id"] == parent_id),
        "sentence_id"
    ].tolist()

    retriever_requests = [
        (decode_input_ids(rows[-1]["input_retriever_input_ids"]).tolist(), sentences_to_search)
    ]
    if len(rows) > 1:
        previous_output = rows[-2]["output"]
        if pd.isna(previous_output):
            previous_output = ""
        query_sentences = nltk.sent_tokenize(previous_output)[-RETRIEVER_LAST_OUTPUT_SENTENCES:]
        # NOTE(review): these queries are encoded without the "[RETRIEVE] " prefix that the
        # stored retriever encodings of inputs/outputs carry; confirm this is intended.
        retriever_requests += [
            (retriever_tokenizer.encode(sent), sentences_to_search)
            for sent in query_sentences
        ]
    return extract_story_inputs(rows, df_sentences_context_mapping), retriever_requests

In [27]:
def dot_distance(x, y):
    """Negative inner product of two vectors, so that "nearest" means "most similar".

    This is not a true metric: it can be negative and does not satisfy the triangle
    inequality.
    """
    products = x * y
    return -products.sum()

In [28]:
def get_cached_nearest_df(retriever_request_embeddings, cached_extra_nn, cached_story_nn, df_extra, df_story):
    """Gather kNN candidates for every query from the cached indices and keep the closest.

    Args:
        retriever_request_embeddings: tensor of query embeddings, one row per query.
        cached_extra_nn: fitted kNN index whose row positions match ``df_extra``.
        cached_story_nn: fitted kNN index whose row positions match ``df_story``, or
            None to skip story candidates.
        df_extra, df_story: frames with ``text``, ``retriever_input_ids`` and
            ``generator_input_ids`` columns.

    Returns:
        Frame with those three columns plus ``distance``. It is sorted by ascending
        distance, de-duplicated on ``text`` (the closest copy is kept) and truncated to
        ``(RETRIEVER_LAST_OUTPUT_SENTENCES + 1) * KNN_N_NEIGHBOURS`` rows.
    """
    queries = retriever_request_embeddings.detach().cpu().numpy()
    columns = ["text", "retriever_input_ids", "generator_input_ids"]

    def _candidates(nn, df):
        # kneighbors returns (distances, indices), one row per query; flatten query by query.
        distances, indices = nn.kneighbors(queries)
        # .assign builds a new frame instead of writing into a chained iloc/column
        # selection, which raised SettingWithCopyWarning.
        return df.iloc[np.ravel(indices)][columns].assign(distance=np.ravel(distances))

    parts = [_candidates(cached_extra_nn, df_extra)]
    if cached_story_nn is not None:
        parts.append(_candidates(cached_story_nn, df_story))

    df = pd.concat(parts).reset_index(drop=True)
    df = df.sort_values("distance").drop_duplicates("text")
    return df.head((RETRIEVER_LAST_OUTPUT_SENTENCES + 1) * KNN_N_NEIGHBOURS)

In [29]:
def cut_generator_input(input_ids, weights, generator, generator_tokenizer, tag_input_ids, nearest_input_ids):
    """Keep the tail of the story tokens that fits in the generator's context window.

    The budget is ``generator.config.max_position_embeddings`` minus the lengths of the tag
    and retrieved-token prefixes.

    If the kept tail does not start with an ``[INPUT]``/``[OUTPUT]`` marker, its first
    token is overwritten with the marker of the section the cut landed in. That section
    is inferred from the first marker that does appear in the tail, so the model still
    sees a section header.

    Returns:
        ``(ids, weights)`` as new lists of equal length; the caller's ``input_ids`` is not
        modified. Both lists are empty when no budget remains or ``input_ids`` is empty.
    """
    budget = generator.config.max_position_embeddings - len(tag_input_ids) - len(nearest_input_ids)
    if budget <= 0:
        # input_ids[-0:] would return the whole sequence, so handle "no room" explicitly.
        return [], []
    generator_input_ids = list(input_ids[-budget:])
    generator_weights = list(weights[-budget:])
    if not generator_input_ids:
        return generator_input_ids, generator_weights

    input_tid, = generator_tokenizer.convert_tokens_to_ids(["[INPUT]"])
    output_tid, = generator_tokenizer.convert_tokens_to_ids(["[OUTPUT]"])

    if generator_input_ids[0] not in (input_tid, output_tid):
        input_start = generator_input_ids.index(input_tid) if input_tid in generator_input_ids else None
        output_start = generator_input_ids.index(output_tid) if output_tid in generator_input_ids else None
        if input_start is not None and output_start is not None:
            # The first marker in the tail opens the section after the one that was cut.
            start_token = output_tid if input_start < output_start else input_tid
        elif output_start is not None:
            start_token = input_tid
        else:
            # Only [INPUT] markers, or no marker at all: assume we are inside an output.
            start_token = output_tid
        generator_input_ids[0] = start_token

    return generator_input_ids, generator_weights

In [30]:
def get_nn_input(cached_extra_nn, row, df_story_content, df_stories):
    """Build the generator input for one story node, with retrieved sentences prepended.

    The returned sequence is laid out as:

    - the ``[TAGS] <tags>`` tokens;
    - up to ``MAX_RELEVANT_TOKENS - 1`` tokens of retrieved sentences;
    - the tail of the root-to-node story text, trimmed by cut_generator_input.

    Tag and retrieved tokens get weight 0, so only story tokens carry loss weight.

    Besides its arguments this reads notebook globals: ``df_sentences_context_mapping``,
    ``df_sentences_story``, ``df_sentences_guttenberg``, ``cached_story_embeddings``,
    ``retriever``, ``retriever_tokenizer``, ``generator`` and ``generator_tokenizer``.

    Args:
        cached_extra_nn: fitted kNN index over the cached Gutenberg sentence embeddings.
        row: node row with ``id`` and ``story_id``.
        df_story_content: story-node frame passed through to story_input.
        df_stories: frame with ``id`` and ``tags``. If the story is absent, the
            ``.values[0]`` lookup raises IndexError.

    Returns:
        ``(input_ids, weights, retriever_distances)``. The first two are equal-length
        lists. ``retriever_distances`` is the (n_queries, n_candidates) tensor of negative
        dot products between freshly computed query and candidate embeddings.
    """
    generator_input, retriever_requests = story_input(df_story_content,
                                                      row["id"],
                                                      row["story_id"],
                                                      df_sentences_context_mapping,
                                                      retriever_tokenizer)
    # Union of the candidate sentence ids of all queries. Only sentences that are still
    # present in df_sentences_story (i.e. passed the length filter) can be matched.
    story_sentence_ids = set(chain(*[sentences for _, sentences in retriever_requests]))
    story_sentence_mask = df_sentences_story["id"].isin(story_sentence_ids)
    story_sentence_count = story_sentence_mask.sum()

    if story_sentence_count > 0:
        # Per-call brute-force index over this story's cached sentence embeddings.
        # Assumes cached_story_embeddings rows align with df_sentences_story rows (both
        # are built from the same frame in get_cached_prompt_embeddings).
        cached_story_nn = NearestNeighbors(n_neighbors=min(KNN_N_NEIGHBOURS, int(story_sentence_mask.sum())),
                                        metric=dot_distance,
                                        n_jobs=-1,
                                        algorithm="brute")
        cached_story_nn.fit(cached_story_embeddings[story_sentence_mask])
    else:
        cached_story_nn = None

    # Query embeddings from the current retriever. Gradients are tracked, and in the
    # training loop the retriever is in train mode, so dropout is active here.
    retriever_request_embeddings = get_batch_embeddings(retriever, retriever_tokenizer, [
        input_ids
        for input_ids, _ in retriever_requests
    ])
    # Shortlist candidates via the cached (possibly stale, refreshed every 500 steps) indices.
    df_cached_nearest = get_cached_nearest_df(retriever_request_embeddings,
                                              cached_extra_nn,
                                              cached_story_nn,
                                              df_sentences_guttenberg,
                                              df_sentences_story.loc[story_sentence_mask])
    # Re-embed the shortlist with the current retriever instead of reusing cached vectors.
    retriever_cached_relevant_embeddings = get_batch_embeddings(
        retriever,
        retriever_tokenizer,
        df_cached_nearest["retriever_input_ids"].apply(decode_input_ids)
    )
    retriever_distances = -retriever_request_embeddings.matmul(retriever_cached_relevant_embeddings.T)

    # Rank candidates by their mean distance over all queries, most similar first.
    # The sort is not differentiable: only the ranking reaches the generator input.
    # NOTE(review): unless a caller uses retriever_distances in its loss, the retriever
    # receives no gradient from this sample.
    retriever_nearest_indices = retriever_distances.mean(dim=0).sort().indices.detach().cpu().numpy()
    df_nearest = df_cached_nearest.iloc[retriever_nearest_indices]
    nearest_samples_input_ids = df_nearest["generator_input_ids"].apply(decode_input_ids)    

    # Concatenate the retrieved sentences and cap at MAX_RELEVANT_TOKENS - 1 tokens
    # (may cut a sentence mid-way). Zero weight: excluded from the loss.
    nearest_input_ids = np.array(list(chain(*nearest_samples_input_ids))[:MAX_RELEVANT_TOKENS-1])
    nearest_weights = np.zeros([len(nearest_input_ids)])
    
    # Missing (NaN) tags become an empty tag string. Zero weight: excluded from the loss.
    tags_string = df_stories.loc[df_stories["id"] == row["story_id"], "tags"].values[0]
    if pd.isna(tags_string):
        tags_string = ""
    tag_input_ids = np.array(generator_tokenizer.encode("[TAGS] " + tags_string))
    tag_weights = np.zeros([len(tag_input_ids)])

    # Trim the story so tags + retrieved + story fit the generator's position limit.
    generator_input_ids, generator_weights = cut_generator_input(
        generator_input.generator_input_ids,
        generator_input.generator_input_weights,
        generator,
        generator_tokenizer,
        tag_input_ids,
        nearest_input_ids
    )

    input_ids = list(tag_input_ids) + list(nearest_input_ids) + list(generator_input_ids)
    weights = list(tag_weights) + list(nearest_weights) + list(generator_weights)

    return input_ids, weights, retriever_distances

In [31]:
def calc_loss(input_ids, weights, logits):
    """Weighted next-token cross-entropy, averaged per sample and then over the batch.

    Args:
        input_ids: (batch, seq) token ids; positions 1.. are the prediction targets.
        weights: (batch, seq) per-token weights aligned with ``input_ids``.
        logits: (batch, seq, vocab) model outputs; position t predicts token t + 1.

    Returns:
        Scalar: mean over samples of ``sum(w_t * CE_t) / (sum(w_t) + EPS)``, for t >= 1.
    """
    batch_size, _, vocab_size = logits.shape
    targets = input_ids[:, 1:]
    target_weights = weights[:, 1:]
    predictions = logits[:, :-1, :]
    per_token = F.cross_entropy(
        predictions.reshape(-1, vocab_size),
        targets.reshape(-1),
        reduction="none",
    ).reshape(batch_size, -1)
    per_sample = (per_token * target_weights).sum(dim=-1) / (target_weights.sum(dim=-1) + EPS)
    return per_sample.mean()

In [32]:
def get_row_loss(cached_extra_nn, row, df_story_content_train, df_stories_train):
    """Weighted LM loss of the generator on one story node, with retrieved context prepended.

    Uses the notebook-level ``generator`` and, through get_nn_input, further globals.

    NOTE(review): the retriever distances from get_nn_input are discarded, and retrieved
    text enters only as token ids (chosen by a non-differentiable sort). This loss
    therefore sends no gradient to the retriever, even though its parameters are given
    to the optimizer.
    """
    input_ids, weights, _retriever_distances = get_nn_input(
        cached_extra_nn, row, df_story_content_train, df_stories_train
    )
    device = generator.device
    ids_batch = torch.LongTensor([input_ids]).to(device)
    weights_batch = torch.FloatTensor([weights]).to(device)
    logits = generator(ids_batch).logits
    return calc_loss(ids_batch, weights_batch, logits)

In [33]:
# Move both models to the default CUDA device.
retriever, generator = retriever.cuda(), generator.cuda()

In [34]:
# Split story nodes by story membership, then shuffle each split with a fixed seed.
df_story_content_train = (
    df_story_content.loc[df_story_content["story_id"].isin(df_stories_train["id"])]
    .pipe(lambda df: df.sample(len(df), random_state=42))
    .reset_index(drop=True)
)

df_story_content_test = (
    df_story_content.loc[df_story_content["story_id"].isin(df_stories_test["id"])]
    .pipe(lambda df: df.sample(len(df), random_state=42))
    .reset_index(drop=True)
)

In [35]:
# Generator and retriever parameters share one Adam optimizer. The GradScaler does loss
# scaling for the autocast (mixed-precision) training loop.
optimizable_params = [*generator.parameters(), *retriever.parameters()]
optimizer = Adam(params=optimizable_params, lr=LR)
scaler = torch.cuda.amp.GradScaler()

In [36]:
# Endless training loop over the shuffled training nodes. Every 500 steps, all candidate
# sentences are re-embedded and the Gutenberg kNN index is rebuilt.
# TODO(review): the run below hit CUDA OOM around step 11 with up to 2048-token inputs on
# the 1.3B generator; gradient checkpointing or a shorter context may be needed (unverified).
i = -1
while True:
    i += 1
    if i == len(df_story_content_train):
        i = 0

    # Free the gradient tensors instead of zero-filling them. This lowers the memory held
    # during the next forward pass; Adam skips parameters whose grad is None.
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        if i % 500 == 0:
            cached_extra_embeddings, cached_story_embeddings = get_cached_prompt_embeddings(retriever,
                                                                                            generator,
                                                                                            retriever_tokenizer,
                                                                                            df_sentences_guttenberg,
                                                                                            df_sentences_story)
            # dot_distance is not a true metric (it goes negative and breaks the triangle
            # inequality). A BallTree can therefore prune away the real nearest neighbours.
            # Brute force scores every candidate, the same choice as the per-story index in
            # get_nn_input.
            # NOTE(review): brute force with a Python metric is slow on a large index.
            cached_extra_nn = NearestNeighbors(n_neighbors=KNN_N_NEIGHBOURS, metric=dot_distance, n_jobs=-1, algorithm="brute")
            cached_extra_nn.fit(cached_extra_embeddings)
            retriever.train()
            generator.train()
        row = df_story_content_train.iloc[i]
        loss = get_row_loss(cached_extra_nn, row, df_story_content_train, df_stories_train)
    print(i, loss.item())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

UPDATING EMBEDDINGS CACHE


100%|██████████| 26361/26361 [02:43<00:00, 161.31it/s]
100%|██████████| 9749/9749 [00:59<00:00, 164.52it/s]


0 17.497169494628906
1 15.879802703857422
2 14.699311256408691
3 16.381174087524414
4 15.230399131774902
5 15.891133308410645
6 16.060945510864258
7 15.771090507507324
8 14.704065322875977
9 14.49145221710205
10 17.262022018432617
11 8.21269416809082


RuntimeError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 39.59 GiB total capacity; 37.66 GiB already allocated; 20.94 MiB free; 37.88 GiB reserved in total by PyTorch)

In [None]:
# `generator_input_ids` only exists inside get_row_loss, so evaluating it here raised
# NameError. Instead, rebuild the input of the row the training loop stopped on (index `i`)
# and report its length.
with torch.no_grad():
    failed_input_ids, _, _ = get_nn_input(
        cached_extra_nn, df_story_content_train.iloc[i], df_story_content_train, df_stories_train
    )
len(failed_input_ids)