In [1]:
!pip install transformers==4.12.5
!pip install pandas==1.3.4
!pip install tqdm==4.62.3
!pip install nltk==3.6.5



In [2]:
import pandas as pd
from transformers import BertTokenizerFast, GPT2TokenizerFast, BertModel, GPTNeoForCausalLM
from joblib import Parallel, delayed
import numpy as np
from itertools import chain
from tqdm import tqdm
import nltk

In [3]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/jupyter/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
RETRIEVER_BERT_MODEL = "huawei-noah/TinyBERT_General_4L_312D"
GENERATOR_GPTNEO_MODEL = "EleutherAI/gpt-neo-1.3B"

In [5]:
retriever_tokenizer = BertTokenizerFast.from_pretrained(RETRIEVER_BERT_MODEL)
retriever_tokenizer.add_tokens(["[STORY]", "[EXTRA]", "[RETRIEVE]"])

generator_tokenizer = GPT2TokenizerFast.from_pretrained(GENERATOR_GPTNEO_MODEL)
generator_tokenizer.add_tokens(["[TAGS]", "[INIT]", "[PROMPT]", "[TEXT]", "[INPUT]", "[OUTPUT]"])

6

In [6]:
retriever = BertModel.from_pretrained(RETRIEVER_BERT_MODEL)
generator = GPTNeoForCausalLM.from_pretrained(GENERATOR_GPTNEO_MODEL)

Some weights of the model checkpoint at huawei-noah/TinyBERT_General_4L_312D were not used when initializing BertModel: ['fit_denses.4.bias', 'fit_denses.1.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'fit_denses.3.weight', 'fit_denses.1.weight', 'fit_denses.2.weight', 'fit_denses.0.weight', 'fit_denses.2.bias', 'fit_denses.3.bias', 'fit_denses.4.weight', 'cls.predictions.transform.LayerNorm.bias', 'fit_denses.0.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing

In [7]:
def sync_tokens(tokenizer, model):
    if len(tokenizer) % 8 != 0:
        tokenizer.add_tokens([
            f"[DUMB{i}]"
            for i in range(8 - len(tokenizer) % 8)
        ])
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model

In [8]:
retriever_tokenizer, retriever = sync_tokens(retriever_tokenizer, retriever)
generator_tokenizer, generator = sync_tokens(generator_tokenizer, generator)

In [9]:
def parallel_apply(df, func, chunk_size, process_count):
    chunk_count = int(np.ceil(len(df) / chunk_size))
    return list(chain(*Parallel(n_jobs=process_count)(
        delayed(func)(df.iloc[i * chunk_size : (i + 1) * chunk_size])
        for i in tqdm(range(chunk_count))
    )))

In [10]:
def encode_input_ids(input_ids):
    return np.array(input_ids, dtype=np.int32).tobytes()

def decode_input_ids(buffer):
    return np.frombuffer(buffer, dtype=np.int32)

def apply_tokenizer(texts, tokenizer):
    return [
        encode_input_ids(row)
        for row in tokenizer(list(texts))["input_ids"]
    ]

# Reading data

In [11]:
df_sentences_guttenberg = pd.read_csv("data/guttenberg-sentences-sampled.csv")
df_sentences_story = pd.read_csv("data/cleaned/story-sentences.csv")
df_sentences_context_mapping = pd.read_csv("data/cleaned/story-context-sentence-mapping-numeric-id.csv")
df_story_content = pd.read_csv("data/cleaned/story-trees-numeric-id.csv")
df_stories_train = pd.read_csv("data/cleaned/stories-train.csv")
df_stories_test = pd.read_csv("data/cleaned/stories-test.csv")

In [12]:
id2children_count = df_story_content.groupby(["story_id", "parent_id"])["id"].nunique().to_dict()
df_story_content["children_count"] = df_story_content[["story_id", "id"]].apply(
    lambda row: id2children_count.get((row["story_id"], row["id"]), 0),
    axis=1
)
df_story_content.head()

Unnamed: 0,id,parent_id,input,output,story_id,children_count
0,0,-1,[ROOT],"The land of Kronnland is a mythical, wonderful...",12487,3
1,1,0,Start Danny's Campaign,Danny Blaze\nBackground :\nBorn in the summer ...,12487,1
2,2,1,Continue,With all the townsfolk transformed into mindle...,12487,2
3,3,2,Get back to Bren and warn him about the danger.,You run down the hill as Andrew's army regroup...,12487,0
4,4,2,Watch the battle from your hideout.,"Although worried, you stay in your hideout and...",12487,2


In [13]:
df_sentences_guttenberg["retriever_input_ids"] = parallel_apply(
    df_sentences_guttenberg,
    lambda df: apply_tokenizer("[EXTRA] " + df["text"].fillna(""), retriever_tokenizer),
    9192,
    -1
)
df_sentences_guttenberg["generator_input_ids"] = parallel_apply(
    df_sentences_guttenberg,
    lambda df: apply_tokenizer("[PROMPT] " + df["text"].fillna(""), generator_tokenizer),
    9192,
    -1
)

100%|██████████| 184/184 [00:28<00:00,  6.54it/s]
 59%|█████▊    | 108/184 [00:17<00:14,  5.40it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2103 > 2048). Running this sequence through the model will result in indexing errors
100%|██████████| 184/184 [00:30<00:00,  5.99it/s]


In [14]:
df_sentences_story["retriever_input_ids"] = parallel_apply(
    df_sentences_story,
    lambda df: apply_tokenizer("[STORY] " + df["text"].fillna(""), retriever_tokenizer),
    9192,
    -1
)
df_sentences_story["generator_input_ids"] = parallel_apply(
    df_sentences_story,
    lambda df: apply_tokenizer("[PROMPT] " + df["text"].fillna(""), generator_tokenizer),
    9192,
    -1
)

100%|██████████| 68/68 [00:04<00:00, 15.04it/s]
100%|██████████| 68/68 [00:05<00:00, 12.53it/s]


In [15]:
df_story_content["input_retriever_input_ids"] = parallel_apply(
    df_story_content,
    lambda df: apply_tokenizer("[RETRIEVE] " + df["input"].fillna(""), retriever_tokenizer),
    9192,
    -1
)
df_story_content["input_generator_input_ids"] = parallel_apply(
    df_story_content,
    lambda df: apply_tokenizer("[INPUT] " + df["input"].fillna(""), generator_tokenizer),
    9192,
    -1
)

100%|██████████| 7/7 [00:00<00:00, 6843.85it/s]
100%|██████████| 7/7 [00:00<00:00, 5956.61it/s]


In [16]:
df_story_content["output_generator_input_ids"] = parallel_apply(
    df_story_content,
    lambda df: apply_tokenizer("[OUTPUT] " + df["output"].fillna(""), generator_tokenizer),
    9192,
    -1
)
df_story_content["output_retriever_input_ids"] = parallel_apply(
    df_story_content,
    lambda df: apply_tokenizer("[RETRIEVE] " + df["output"].fillna(""), retriever_tokenizer),
    9192,
    -1
)

100%|██████████| 7/7 [00:00<00:00, 5920.57it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (2157 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2219 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2056 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2362 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2105 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is

In [17]:
df_sentences_guttenberg = df_sentences_guttenberg.loc[
    (df_sentences_guttenberg["retriever_input_ids"].apply(len) // 4) <= 256
]
df_sentences_guttenberg = df_sentences_guttenberg.loc[
    (df_sentences_guttenberg["generator_input_ids"].apply(len) // 4) <= 256
]
df_sentences_guttenberg = df_sentences_guttenberg.reset_index(drop=True)
df_sentences_guttenberg.head()

Unnamed: 0,cluster,text,retriever_input_ids,generator_input_ids
0,0,"""Gobryas is there?""",b'e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...,b'S\xc4\x00\x00n\x01\x00\x00&\x00\x00\x00\xa0\...
1,0,His name's Gonzago.,b'e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...,b'S\xc4\x00\x00_\t\x00\x00\x9e\x05\x00\x00R\x0...
2,0,"Goneril, gonəril.",b'e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...,b'S\xc4\x00\x00\x92\x01\x00\x00\x9b8\x00\x00Z\...
3,0,"In discussing the character of Hlestakov, the ...","b""e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...",b'S\xc4\x00\x00*\x02\x00\x00\x86+\x00\x00\x06\...
4,0,Gomalco Productions.,b'e\x00\x00\x00\x07\x04\x00\x00u\x11\x00\x00\t...,b'S\xc4\x00\x00\x92\x01\x00\x00\x80F\x00\x001\...


In [18]:
df_sentences_story = df_sentences_story.loc[
    (df_sentences_story["retriever_input_ids"].apply(len) // 4) <= 256
]
df_sentences_story = df_sentences_story.loc[
    (df_sentences_story["generator_input_ids"].apply(len) // 4) <= 256
]
df_sentences_story = df_sentences_story.reset_index(drop=True)
df_sentences_story.head()

Unnamed: 0,id,text,retriever_input_ids,generator_input_ids
0,0,"""Sorry, Soren.""",b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00n\x01\x00\x0018\x00\x00\x0b\x00...
1,3421,Are they alive?,b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00\x87\x10\x00\x00\xe4\x01\x00\x0...
2,3420,What DO you do?,b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00K\x07\x00\x00\xda \x00\x00Y\x01...
3,3419,"""This is yours.",b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00n\x01\x00\x00\xbc\x04\x00\x00>\...
4,3418,Leave the halls,b'e\x00\x00\x00\x07\x04\x00\x00\xa2\t\x00\x00\...,b'S\xc4\x00\x00&D\x00\x00\x06\x01\x00\x00\x1e_...


# Training

In [19]:
import torch
from itertools import chain
from sklearn.neighbors import NearestNeighbors
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim import Adam, SGD

In [20]:
EPS = 1e-5
RETRIEVER_INPUT_MAX_LENGTH = 256
RETRIEVER_LAST_OUTPUT_SENTENCES = 3
KNN_N_NEIGHBOURS = 4
MAX_RELEVANT_TOKENS = 512
LR = 1e-4
LM_LOSS_ABG = 10

In [21]:
def get_batch_embeddings(retriever, retriever_tokenizer, input_ids):
    def _prepare_ids(ids, max_length):
        ids = list(ids)
        if len(ids) < max_length:
            return ids + [retriever_tokenizer.pad_token_id] * (max_length - len(ids))
        else:
            return ids[:max_length]

    max_length = max([len(item) for item in input_ids])
    if max_length % 8 != 0:
        max_length += 8 - max_length % 8
    if max_length > RETRIEVER_INPUT_MAX_LENGTH:
        max_length = RETRIEVER_INPUT_MAX_LENGTH
    
    padded_input_ids = torch.LongTensor([
        _prepare_ids(item, max_length) for item in input_ids
    ])
    attention_mask = padded_input_ids != retriever_tokenizer.pad_token_id
    hidden_state = retriever(input_ids=padded_input_ids.to(retriever.device),
                             attention_mask=attention_mask.to(retriever.device),
                             output_hidden_states=True).last_hidden_state
    cls_embedding = hidden_state[:, 0, :]
    cls_embedding_norm = torch.sqrt( (cls_embedding ** 2).sum(dim=-1, keepdims=True) ) + EPS
    return cls_embedding / cls_embedding_norm

In [22]:
def get_embeddings(retriever, retriever_tokenizer, inputs, batch_size, verbose=False):
    inputs = list(inputs)
    df_sort = pd.DataFrame({
        "index": range(len(inputs)),
        "inputs": inputs,
        "length": [len(row) for row in inputs]
    })
    df_sort = df_sort.sort_values("length", ascending=False)

    embeddings = np.zeros([len(inputs), retriever.config.hidden_size], dtype=np.float16)

    batch_count = int(np.ceil(len(inputs) / batch_size))
    with torch.no_grad():
        iterable = range(batch_count)
        if verbose:
            iterable = tqdm(iterable)
        for i in iterable:
            batch_df_sort = df_sort.iloc[i * batch_size : (i + 1) * batch_size]
            batch_input_ids = batch_df_sort["inputs"].apply(decode_input_ids).tolist()
            batch_embeddings_torch = get_batch_embeddings(retriever, retriever_tokenizer, batch_input_ids)
            embeddings[batch_df_sort["index"].tolist()] = batch_embeddings_torch.detach().cpu().numpy()
    
    return embeddings

In [23]:
def get_cached_prompt_embeddings(retriever, generator, retriever_tokenizer, df_sentences_guttenberg, df_sentences_story):
    print("UPDATING EMBEDDINGS CACHE")
    retriever.eval()
    retriever.cuda()
    extra_embeddings = get_embeddings(retriever, retriever_tokenizer, df_sentences_guttenberg["retriever_input_ids"], 64, verbose=True)
    story_embeddings = get_embeddings(retriever, retriever_tokenizer, df_sentences_story["retriever_input_ids"], 64, verbose=True)
    return extra_embeddings, story_embeddings

In [24]:
from dataclasses import dataclass
from typing import List


@dataclass
class RetrieverInput:
    input_ids: np.ndarray
    story_sentences: np.ndarray


@dataclass
class StoryInputSample:
    generator_input_ids: np.ndarray
    generator_input_weights: np.ndarray
    retriever_inputs: List[RetrieverInput]


def get_rows(df_story_content, id, story_id):
    rows = []
    df_story_content = df_story_content.loc[df_story_content["story_id"] == story_id].set_index("id")
    while id != -1:
        try:
            row = df_story_content.loc[[id]].iloc[0]
        except:
            break
        rows.append(row)
        id = row["parent_id"]
    return rows[::-1]


def extract_story_inputs(rows, df_sentences_context_mapping):
    def _get_generator_inputs(row):
        if row["parent_id"] != -1:
            row_input = list(decode_input_ids(row["input_generator_input_ids"]))
        else:
            row_input = []
        row_output = list(decode_input_ids(row["output_generator_input_ids"]))
        row_content = row_input + row_output

        return row_content

    def _get_previous_sentences(row):
        mask = (df_sentences_context_mapping["story_id"] == row["story_id"]) & \
               (df_sentences_context_mapping["context_id"] == row["parent_id"])
        return np.array(sorted(df_sentences_context_mapping.loc[mask, "sentence_id"]))

    def _get_retriever_input_ids(rows):
        last_row = rows[-1]
        input_sentences_pairs = []
        for row in rows[-2:]:
            if row["parent_id"] != -1:
                input_ids = decode_input_ids(row["input_retriever_input_ids"])
                story_sentence_ids = _get_previous_sentences(row)
                input_sentences_pairs.append(RetrieverInput(input_ids, story_sentence_ids))
        return input_sentences_pairs

    generator_input_ids = []
    generator_children_counts = []
    for row in rows:
        row_content = _get_generator_inputs(row)
        if row["children_count"] == 0:
            generator_children_counts.append((len(row_content), 1))
        else:
            generator_children_counts.append((len(row_content), row["children_count"]))
        generator_input_ids += row_content
    generator_weights = []
    k = 1.0
    for i, (token_count, children_count) in enumerate(generator_children_counts[::-1]):
        if i > 0:
            k *= (1 / children_count)
        generator_weights += [k] * token_count
    generator_weights = generator_weights[::-1]
    retriever_inputs = _get_retriever_input_ids(rows)
    return StoryInputSample(generator_input_ids, generator_weights, retriever_inputs)

In [25]:
def story_description_encode(df_stories, story_id, generator_tokenizer):
    tags = "[TAGS] " + df_stories.loc[df_stories["id"] == story_id, "tags"].values[0]
    return generator_tokenizer.encode(tags)

In [26]:
def story_input(df_story_content, row_id, story_id, df_sentences_context_mapping, retriever_tokenizer):
    rows = get_rows(df_story_content, row_id, story_id)
    if len(rows) == 1:
        parent_id = rows[-1]["parent_id"]
    else:
        parent_id = rows[-2]["parent_id"]
    retriever_requests = [
        (
            decode_input_ids(rows[-1]["input_retriever_input_ids"]).tolist(),
            df_sentences_context_mapping.loc[
                (df_sentences_context_mapping["story_id"] == story_id) & \
                (df_sentences_context_mapping["context_id"] == parent_id),
                "sentence_id"
            ].tolist()
        )
    ]
    if len(rows) > 1:
        sentences_to_search = df_sentences_context_mapping.loc[
            (df_sentences_context_mapping["story_id"] == story_id) & \
            (df_sentences_context_mapping["context_id"] == parent_id),
            "sentence_id"
        ].tolist()
        query_sentences = nltk.sent_tokenize(rows[-2]["output"])[-RETRIEVER_LAST_OUTPUT_SENTENCES:]
        retriever_requests += [
            (retriever_tokenizer.encode(sent), sentences_to_search)
            for sent in query_sentences
        ]
    return extract_story_inputs(rows, df_sentences_context_mapping), retriever_requests

In [27]:
def dot_distance(x, y):
    return -(x * y).sum()

In [28]:
def get_cached_nearest_df(retriever_request_embeddings, cached_extra_nn, cached_story_nn, df_extra, df_story):
    retriever_request_embeddings_np = retriever_request_embeddings.detach().cpu().numpy()
    extra_indices = []
    extra_distances = []
    for distances, indices in zip(*cached_extra_nn.kneighbors(retriever_request_embeddings_np)):
        extra_indices += list(indices)
        extra_distances += list(distances)
    sub_df_extra = df_extra.iloc[extra_indices][["text", "retriever_input_ids", "generator_input_ids"]]
    sub_df_extra["distance"] = extra_distances

    if cached_story_nn is not None:
        story_indices = []
        story_distances = []
        for distances, indices in zip(*cached_story_nn.kneighbors(retriever_request_embeddings_np)):
            story_indices += list(indices)
            story_distances += list(distances)
        sub_df_story = df_story.iloc[story_indices][["text", "retriever_input_ids", "generator_input_ids"]]
        sub_df_story["distance"] = story_distances
    
    if cached_story_nn is not None:
        df = pd.concat([sub_df_extra, sub_df_story]).reset_index(drop=True)
    else:
        df = sub_df_extra.reset_index(drop=True)
    df = df.sort_values("distance")
    df = df.drop_duplicates("text")
    df = df.head( (RETRIEVER_LAST_OUTPUT_SENTENCES + 1) * KNN_N_NEIGHBOURS)

    return df

In [29]:
def cut_generator_input(input_ids, weights, generator, generator_tokenizer, tag_input_ids, nearest_input_ids):
    max_story_token_count = generator.config.max_position_embeddings - len(tag_input_ids) - len(nearest_input_ids)
    generator_input_ids = input_ids[-max_story_token_count:]
    generator_weights = weights[-max_story_token_count:]

    input_tid, = generator_tokenizer.convert_tokens_to_ids(["[INPUT]"])
    output_tid, = generator_tokenizer.convert_tokens_to_ids(["[OUTPUT]"])

    if generator_input_ids[0] not in {input_tid, output_tid}:
        if input_tid not in generator_input_ids:
            input_start = None
        else:
            input_start = list(generator_input_ids).index(input_tid)
        if output_tid not in generator_input_ids:
            output_start = None
        else:
            output_start = list(generator_input_ids).index(output_tid)
        if output_start is not None and input_start is not None:
            if input_start < output_start:
                start_token = output_tid
            else:
                start_token = input_tid
        elif output_start is not None:
            start_token = input_tid
        else:
            start_token = output_tid
        generator_input_ids[0] = start_token
    
    return generator_input_ids, generator_weights

In [30]:
def get_nn_input(cached_extra_nn, row, df_story_content, df_stories):
    generator_input, retriever_requests = story_input(df_story_content,
                                                      row["id"],
                                                      row["story_id"],
                                                      df_sentences_context_mapping,
                                                      retriever_tokenizer)
    story_sentence_ids = set(chain(*[sentences for _, sentences in retriever_requests]))
    story_sentence_mask = df_sentences_story["id"].isin(story_sentence_ids)
    story_sentence_count = story_sentence_mask.sum()

    if story_sentence_count > 0:
        cached_story_nn = NearestNeighbors(n_neighbors=min(KNN_N_NEIGHBOURS, int(story_sentence_mask.sum())),
                                        metric=dot_distance,
                                        n_jobs=-1,
                                        algorithm="brute")
        cached_story_nn.fit(cached_story_embeddings[story_sentence_mask])
    else:
        cached_story_nn = None

    retriever_request_embeddings = get_batch_embeddings(retriever, retriever_tokenizer, [
        input_ids
        for input_ids, _ in retriever_requests
    ])
    df_cached_nearest = get_cached_nearest_df(retriever_request_embeddings,
                                              cached_extra_nn,
                                              cached_story_nn,
                                              df_sentences_guttenberg,
                                              df_sentences_story.loc[story_sentence_mask])
    retriever_cached_relevant_embeddings = get_batch_embeddings(
        retriever,
        retriever_tokenizer,
        df_cached_nearest["retriever_input_ids"].apply(decode_input_ids)
    )
    retriever_distances = -retriever_request_embeddings.matmul(retriever_cached_relevant_embeddings.T)

    retriever_nearest_indices = retriever_distances.mean(dim=0).sort().indices.detach().cpu().numpy()
    df_nearest = df_cached_nearest.iloc[retriever_nearest_indices]
    nearest_samples_input_ids = df_nearest["generator_input_ids"].apply(decode_input_ids)    

    nearest_input_ids = np.array(list(chain(*nearest_samples_input_ids))[:MAX_RELEVANT_TOKENS-1])
    nearest_weights = np.zeros([len(nearest_input_ids)])
    
    tags_string = df_stories.loc[df_stories["id"] == row["story_id"], "tags"].values[0]
    if pd.isna(tags_string):
        tags_string = ""
    tag_input_ids = np.array(generator_tokenizer.encode("[TAGS] " + tags_string))
    tag_weights = np.zeros([len(tag_input_ids)])

    generator_input_ids, generator_weights = cut_generator_input(
        generator_input.generator_input_ids,
        generator_input.generator_input_weights,
        generator,
        generator_tokenizer,
        tag_input_ids,
        nearest_input_ids
    )

    input_ids = list(tag_input_ids) + list(nearest_input_ids) + list(generator_input_ids)
    weights = list(tag_weights) + list(nearest_weights) + list(generator_weights)

    return input_ids, weights, retriever_distances

In [31]:
def calc_lm_loss(input_ids, weights, logits):
    batch_size, _, class_count = logits.shape
    input_ids_shifted = input_ids[:, 1:].reshape([-1])
    logits_shifted = logits[:, :-1, :].reshape([-1, class_count])
    tokenwise_ce = F.cross_entropy(logits_shifted,
                                   input_ids_shifted,
                                   reduction="none")\
        .reshape([batch_size, -1])
    loss_samplewise = (tokenwise_ce * weights[:, 1:]).sum(dim=-1) / (weights[:, 1:].sum(dim=-1) + EPS)
    return loss_samplewise.mean()

In [32]:
def get_row_loss(cached_extra_nn, row, df_story_content_train, df_stories_train, lm_loss_avg):
    input_ids, weights, retriever_distances = get_nn_input(cached_extra_nn,
                                                           row,
                                                           df_story_content_train,
                                                           df_stories_train)
    generator_input_ids = torch.LongTensor([input_ids]).to(generator.device)
    generator_weights = torch.FloatTensor([weights]).to(generator.device)
    generator_output = generator(generator_input_ids).logits
    lm_loss = calc_lm_loss(generator_input_ids, generator_weights, generator_output)
    distance_loss = (lm_loss - lm_loss_avg) * (1-retriever_distances).mean()
    return lm_loss + distance_loss, lm_loss.item()

In [33]:
class ActivationRegularizationHook():
    def __init__(self, lambda_, norm):
        self.regularization = 0
        self.lambda_ = 0
        
    def norm(self, vec):
        return torch.sqrt((vec ** 2).mean(dim=-1, keepdim=True) + EPS) * (torch.numel(vec) ** 2)
        
    def __call__(self, module, input, output):
        try:
            if isinstance(output, torch.Tensor):
                vec = output
            elif isinstance(output, tuple) or isinstance(output, list):
                vec = output[0]
            self.regularization += (self.norm(vec) * self.lambda_).mean()
        except:
            print(f"ERROR ON {module.__class__}")
            raise

In [34]:
def train_row(regularizer_hook, optimizer, scaler, row, df_story_content_train, df_stories_train, cached_extra_nn, lm_loss_avg, lm_loss_values):
    regularizer_hook.regularization = 0.0
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss, lm_loss_value = get_row_loss(cached_extra_nn, row, df_story_content_train, df_stories_train, lm_loss_avg)
        loss += regularizer_hook.regularization
        lm_loss_values.append(lm_loss_value)
        lm_loss_avg = sum(lm_loss_values) / len(lm_loss_values)
        if len(lm_loss_values) == LM_LOSS_ABG:
            lm_loss_values = lm_loss_values[-LM_LOSS_ABG:]
    if pd.isna(loss.item()):
        raise ValueError("NAN LOSS")
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    
    return loss, lm_loss_value, lm_loss_values, lm_loss_avg

In [35]:
def validation_row(row, df_story_content_test, df_stories_test, cached_extra_nn):
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            loss, lm_loss_value = get_row_loss(cached_extra_nn, row, df_story_content_test, df_stories_test, lm_loss_avg)
            
    return loss, lm_loss_value

In [36]:
retriever = retriever.cuda()
generator = generator.cuda()

In [37]:
df_story_content_train = df_story_content.loc[df_story_content["story_id"].isin(df_stories_train["id"])]
df_story_content_train = df_story_content_train.sample(len(df_story_content_train), random_state=42)\
    .reset_index(drop=True)
    
df_story_content_test = df_story_content.loc[df_story_content["story_id"].isin(df_stories_test["id"])]
df_story_content_test = df_story_content_test.sample(len(df_story_content_test), random_state=42)\
    .reset_index(drop=True)

In [38]:
optimizable_params = list(generator.parameters()) + list(retriever.parameters())
optimizer = SGD(optimizable_params, lr=LR)
scaler = torch.cuda.amp.GradScaler()

In [39]:
regularizer_hook = ActivationRegularizationHook(lambda_=0.01,
                                                norm=2)
generator.lm_head.register_forward_hook(regularizer_hook)
generator.transformer.wte.register_forward_hook(regularizer_hook)
generator.transformer.wpe.register_forward_hook(regularizer_hook)
generator.transformer.ln_f.register_forward_hook(regularizer_hook)
for layer in generator.transformer.h:
    layer.register_forward_hook(regularizer_hook)
retriever.embeddings.register_forward_hook(regularizer_hook)
for layer in retriever.encoder.layer:
    layer.register_forward_hook(regularizer_hook)

In [40]:
UPDATE_CACHE_FREQ = 500
DEBUG_SHOW_LOSS_FREQ = 25
EPOCH_SIZE = 5000

In [None]:
i = -1
lm_loss_values = []
lm_loss_avg = 0.0
j = -1
best_loss = np.inf
loss_values_mean_container = []
lm_loss_values_mean_container = []
while True:
    i += 1
    j += 1
    if i == len(df_story_content_train):
        i = 0
    if i % UPDATE_CACHE_FREQ == 0:
        with torch.cuda.amp.autocast():
            cached_extra_embeddings, cached_story_embeddings = get_cached_prompt_embeddings(retriever,
                                                                                            generator,
                                                                                            retriever_tokenizer,
                                                                                            df_sentences_guttenberg,
                                                                                            df_sentences_story)
            cached_extra_nn = NearestNeighbors(n_neighbors=KNN_N_NEIGHBOURS, metric=dot_distance, n_jobs=-1, algorithm="ball_tree")
            cached_extra_nn.fit(cached_extra_embeddings)
            retriever.train()
            generator.train()
    row = df_story_content_train.iloc[i]
    loss, lm_loss_value, lm_loss_values, lm_loss_avg = train_row(regularizer_hook,
                                                                 optimizer,
                                                                 scaler,
                                                                 row,
                                                                 df_story_content_train,
                                                                 df_stories_train,
                                                                 cached_extra_nn,
                                                                 lm_loss_avg,
                                                                 lm_loss_values)
    loss_values_mean_container.append(loss.item())
    lm_loss_values_mean_container.append(lm_loss_value)
    if i % DEBUG_SHOW_LOSS_FREQ == 0:
        print(f"TRAIN {j} STEP {sum(loss_values_mean_container) / len(loss_values_mean_container)} LOSS {sum(lm_loss_values_mean_container) / len(lm_loss_values_mean_container)} LM LOSS")
        loss_values_mean_container = []
        lm_loss_values_mean_container = []
        
    if j > 0 and j % EPOCH_SIZE == 0:
        epoch = j // EPOCH_SIZE
        val_loss_values = []
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                for _, row in tqdm(df_story_content_test.iterrows(), total=len(df_story_content_test)):
                    _, lm_loss_value = validation_row(row,
                                                 df_story_content_test,
                                                 df_stories_test,
                                                 cached_extra_nn)
                    val_loss_values.append(lm_loss_value)
        val_loss = np.array(val_loss_values).mean()
        print(f"VALIDATION EPOCH {epoch} LOSS {val_loss}")
        with open("train-log-base.txt", "a") as log_target:
            log_target.write(f"VALIDATION EPOCH {epoch} LOSS {val_loss}\n")
        if val_loss < best_loss:
            best_loss = val_loss
            retriever.save_pretrained("checkpoint-retriever-base")
            generator.save_pretrained("checkpoint-generator-base")

UPDATING EMBEDDINGS CACHE


100%|██████████| 26361/26361 [03:06<00:00, 141.37it/s]
100%|██████████| 9749/9749 [01:07<00:00, 143.99it/s]


TRAIN 0 STEP 42.345367431640625 LOSS 14.614646911621094 LM LOSS
TRAIN 25 STEP 16.06875717163086 LOSS 15.365345611572266 LM LOSS
TRAIN 50 STEP 13.915200805664062 LOSS 14.780117263793946 LM LOSS
TRAIN 75 STEP 11.507936172485351 LOSS 13.691031913757325 LM LOSS
TRAIN 100 STEP 9.640334224700927 LOSS 12.760052604675293 LM LOSS
TRAIN 125 STEP 8.96842555999756 LOSS 12.252678833007813 LM LOSS
TRAIN 150 STEP 8.267576847076416 LOSS 11.785313148498535 LM LOSS
TRAIN 175 STEP 8.054997215270996 LOSS 11.525365791320802 LM LOSS
TRAIN 200 STEP 7.578642463684082 LOSS 11.193386383056641 LM LOSS
TRAIN 225 STEP 7.5246968269348145 LOSS 11.027882385253907 LM LOSS
TRAIN 250 STEP 6.761562652587891 LOSS 10.632028961181641 LM LOSS
TRAIN 275 STEP 6.578518047332763 LOSS 10.441281776428223 LM LOSS
TRAIN 300 STEP 6.508982830047607 LOSS 10.30075798034668 LM LOSS
TRAIN 325 STEP 6.4385263633728025 LOSS 10.174543571472167 LM LOSS
TRAIN 350 STEP 6.195426826477051 LOSS 9.99479637145996 LM LOSS
TRAIN 375 STEP 6.051346626281

100%|██████████| 26361/26361 [03:07<00:00, 140.84it/s]
100%|██████████| 9749/9749 [01:08<00:00, 142.31it/s]


TRAIN 500 STEP 6.291034011840821 LOSS 9.57885944366455 LM LOSS
TRAIN 525 STEP 6.274194030761719 LOSS 9.519052963256836 LM LOSS
TRAIN 550 STEP 5.922775917053222 LOSS 9.34669994354248 LM LOSS
TRAIN 575 STEP 5.751086921691894 LOSS 9.237149047851563 LM LOSS
TRAIN 600 STEP 5.324308567047119 LOSS 9.04067741394043 LM LOSS
TRAIN 625 STEP 5.761481189727784 LOSS 9.138587036132812 LM LOSS
TRAIN 650 STEP 5.462719058990478 LOSS 8.99299861907959 LM LOSS
TRAIN 675 STEP 5.722183761596679 LOSS 9.035183658599854 LM LOSS
TRAIN 700 STEP 5.3896791458129885 LOSS 8.881279277801514 LM LOSS
TRAIN 725 STEP 5.461398258209228 LOSS 8.863807106018067 LM LOSS
TRAIN 750 STEP 5.7200017929077145 LOSS 8.91253074645996 LM LOSS
TRAIN 775 STEP 5.883830299377442 LOSS 8.931662940979004 LM LOSS
TRAIN 800 STEP 6.097068729400635 LOSS 8.97315860748291 LM LOSS
TRAIN 825 STEP 5.500218677520752 LOSS 8.740656070709228 LM LOSS
TRAIN 850 STEP 5.954860324859619 LOSS 8.861541633605958 LM LOSS
TRAIN 875 STEP 5.984619178771973 LOSS 8.8439

100%|██████████| 26361/26361 [03:07<00:00, 140.71it/s]
100%|██████████| 9749/9749 [01:07<00:00, 143.73it/s]


TRAIN 1000 STEP 5.285780715942383 LOSS 8.470765705108642 LM LOSS
TRAIN 1025 STEP 6.2773100280761716 LOSS 8.779853324890137 LM LOSS
TRAIN 1050 STEP 5.82645616531372 LOSS 8.608991298675537 LM LOSS
TRAIN 1075 STEP 5.645462474822998 LOSS 8.524920749664307 LM LOSS
TRAIN 1100 STEP 6.0704961776733395 LOSS 8.645378112792969 LM LOSS
TRAIN 1125 STEP 5.490908527374268 LOSS 8.430514793395997 LM LOSS
TRAIN 1150 STEP 5.883071002960205 LOSS 8.538321266174316 LM LOSS
TRAIN 1175 STEP 5.200051879882812 LOSS 8.291884899139404 LM LOSS
TRAIN 1200 STEP 5.93543417930603 LOSS 8.515308647155761 LM LOSS
TRAIN 1225 STEP 5.342792148590088 LOSS 8.299985198974609 LM LOSS
TRAIN 1250 STEP 5.839415893554688 LOSS 8.443016452789307 LM LOSS
TRAIN 1275 STEP 5.722534065246582 LOSS 8.38952621459961 LM LOSS
TRAIN 1300 STEP 5.961051330566407 LOSS 8.448735485076904 LM LOSS
TRAIN 1325 STEP 5.400311203002929 LOSS 8.2477095413208 LM LOSS
TRAIN 1350 STEP 5.8868835830688475 LOSS 8.3947172164917 LM LOSS
TRAIN 1375 STEP 5.37730195999

100%|██████████| 26361/26361 [03:07<00:00, 140.71it/s]
100%|██████████| 9749/9749 [01:08<00:00, 142.68it/s]


TRAIN 1500 STEP 5.577840671539307 LOSS 8.199720764160157 LM LOSS
TRAIN 1525 STEP 6.2095866203308105 LOSS 8.399104442596435 LM LOSS
TRAIN 1550 STEP 5.35922986984253 LOSS 8.099901561737061 LM LOSS
TRAIN 1575 STEP 6.0139931869506835 LOSS 8.305445671081543 LM LOSS
TRAIN 1600 STEP 5.375048370361328 LOSS 8.077942905426026 LM LOSS
TRAIN 1625 STEP 5.428056964874267 LOSS 8.08217197418213 LM LOSS
TRAIN 1650 STEP 5.638217487335205 LOSS 8.140021381378174 LM LOSS
TRAIN 1675 STEP 5.6265151405334475 LOSS 8.122468643188476 LM LOSS
TRAIN 1700 STEP 5.654837245941162 LOSS 8.120761814117431 LM LOSS
TRAIN 1725 STEP 6.069045085906982 LOSS 8.248547630310059 LM LOSS
TRAIN 1750 STEP 5.297185153961181 LOSS 7.976441478729248 LM LOSS
TRAIN 1775 STEP 5.7987015914916995 LOSS 8.13309404373169 LM LOSS
TRAIN 1800 STEP 5.469475288391113 LOSS 8.012092876434327 LM LOSS
TRAIN 1825 STEP 5.920760183334351 LOSS 8.152511386871337 LM LOSS
TRAIN 1850 STEP 5.872453708648681 LOSS 8.125298080444336 LM LOSS
TRAIN 1875 STEP 5.702555

100%|██████████| 26361/26361 [03:07<00:00, 140.82it/s]
100%|██████████| 9749/9749 [01:07<00:00, 143.53it/s]


TRAIN 2000 STEP 6.0471262550354 LOSS 8.124854049682618 LM LOSS
TRAIN 2025 STEP 5.602674245834351 LOSS 7.967186489105225 LM LOSS
TRAIN 2050 STEP 5.713523178100586 LOSS 7.994241065979004 LM LOSS
TRAIN 2075 STEP 5.913788652420044 LOSS 8.05210865020752 LM LOSS
TRAIN 2100 STEP 5.676249561309814 LOSS 7.963091068267822 LM LOSS
TRAIN 2125 STEP 5.5727263450622555 LOSS 7.919491329193115 LM LOSS
TRAIN 2150 STEP 5.911755199432373 LOSS 8.025139331817627 LM LOSS
TRAIN 2175 STEP 5.794169063568115 LOSS 7.977925891876221 LM LOSS
TRAIN 2200 STEP 5.904608306884765 LOSS 8.005891647338867 LM LOSS
TRAIN 2225 STEP 5.88867073059082 LOSS 7.99313325881958 LM LOSS
TRAIN 2250 STEP 5.454353914260865 LOSS 7.838727798461914 LM LOSS
TRAIN 2275 STEP 5.529524459838867 LOSS 7.855011463165283 LM LOSS
TRAIN 2300 STEP 6.128035011291504 LOSS 8.047733554840088 LM LOSS
TRAIN 2325 STEP 6.261982870101929 LOSS 8.085675849914551 LM LOSS
TRAIN 2350 STEP 5.288256816864013 LOSS 7.752839984893799 LM LOSS
TRAIN 2375 STEP 6.00808429718

100%|██████████| 26361/26361 [03:07<00:00, 140.63it/s]
100%|██████████| 9749/9749 [01:07<00:00, 143.39it/s]


TRAIN 2500 STEP 5.605777702331543 LOSS 7.818160629272461 LM LOSS
TRAIN 2525 STEP 5.802825775146484 LOSS 7.875544300079346 LM LOSS
TRAIN 2550 STEP 5.81982988357544 LOSS 7.8761794090271 LM LOSS
TRAIN 2575 STEP 5.787337322235107 LOSS 7.856251449584961 LM LOSS
TRAIN 2600 STEP 5.666589126586914 LOSS 7.810858955383301 LM LOSS
TRAIN 2625 STEP 6.33527774810791 LOSS 8.027971019744873 LM LOSS
TRAIN 2650 STEP 5.992916698455811 LOSS 7.908023109436035 LM LOSS
TRAIN 2675 STEP 5.901295719146728 LOSS 7.869446640014648 LM LOSS
TRAIN 2700 STEP 5.824738416671753 LOSS 7.839159774780273 LM LOSS
TRAIN 2725 STEP 5.627282447814942 LOSS 7.766449642181397 LM LOSS
TRAIN 2750 STEP 6.448565063476562 LOSS 8.031916980743409 LM LOSS
TRAIN 2775 STEP 5.616409482955933 LOSS 7.751657581329345 LM LOSS
TRAIN 2800 STEP 5.843920402526855 LOSS 7.8221147346496585 LM LOSS
TRAIN 2825 STEP 5.624302625656128 LOSS 7.741888999938965 LM LOSS
TRAIN 2850 STEP 5.800096950531006 LOSS 7.795652198791504 LM LOSS
TRAIN 2875 STEP 5.9704151535

100%|██████████| 26361/26361 [03:07<00:00, 140.54it/s]
100%|██████████| 9749/9749 [01:08<00:00, 142.78it/s]


TRAIN 3000 STEP 5.836252498626709 LOSS 7.774411735534668 LM LOSS
TRAIN 3025 STEP 5.815690822601319 LOSS 7.7636525917053225 LM LOSS
TRAIN 3050 STEP 6.808391056060791 LOSS 8.091607704162598 LM LOSS
TRAIN 3075 STEP 6.022694244384765 LOSS 7.8242738151550295 LM LOSS
TRAIN 3100 STEP 6.0505428123474125 LOSS 7.829192428588867 LM LOSS
TRAIN 3125 STEP 6.17789363861084 LOSS 7.866553001403808 LM LOSS
TRAIN 3150 STEP 5.470148086547852 LOSS 7.6246271896362305 LM LOSS
TRAIN 3175 STEP 5.677520427703858 LOSS 7.6876483726501466 LM LOSS
TRAIN 3200 STEP 6.741142311096191 LOSS 8.040759315490723 LM LOSS
TRAIN 3225 STEP 6.190406360626221 LOSS 7.852140655517578 LM LOSS
TRAIN 3250 STEP 5.695697450637818 LOSS 7.681447696685791 LM LOSS
TRAIN 3275 STEP 6.165540733337402 LOSS 7.834061203002929 LM LOSS
TRAIN 3300 STEP 6.335865888595581 LOSS 7.887102870941162 LM LOSS
TRAIN 3325 STEP 6.2766912651062015 LOSS 7.8636489105224605 LM LOSS
TRAIN 3350 STEP 6.2557646179199216 LOSS 7.852438144683838 LM LOSS
TRAIN 3375 STEP 5.

100%|██████████| 26361/26361 [03:07<00:00, 140.57it/s]
100%|██████████| 9749/9749 [01:08<00:00, 143.32it/s]


TRAIN 3500 STEP 5.872775726318359 LOSS 7.697750511169434 LM LOSS
TRAIN 3525 STEP 6.1769644641876225 LOSS 7.795930080413818 LM LOSS
TRAIN 3550 STEP 6.075185642242432 LOSS 7.757721271514892 LM LOSS
TRAIN 3575 STEP 6.141312885284424 LOSS 7.776526069641113 LM LOSS
TRAIN 3600 STEP 5.514448394775391 LOSS 7.562037925720215 LM LOSS
TRAIN 3625 STEP 5.503444814682007 LOSS 7.553440208435059 LM LOSS
TRAIN 3650 STEP 6.308185176849365 LOSS 7.818919410705567 LM LOSS
TRAIN 3675 STEP 5.7828850746154785 LOSS 7.639037189483642 LM LOSS
TRAIN 3700 STEP 5.943308353424072 LOSS 7.6888906860351565 LM LOSS
TRAIN 3725 STEP 6.348103675842285 LOSS 7.820725154876709 LM LOSS
TRAIN 3750 STEP 5.6312578010559085 LOSS 7.576742343902588 LM LOSS
TRAIN 3775 STEP 6.057567405700683 LOSS 7.715998744964599 LM LOSS
TRAIN 3800 STEP 5.855609931945801 LOSS 7.6438471221923825 LM LOSS
TRAIN 3825 STEP 6.072942028045654 LOSS 7.712748985290528 LM LOSS
TRAIN 3850 STEP 6.088807258605957 LOSS 7.7146938514709475 LM LOSS
TRAIN 3875 STEP 5.9

100%|██████████| 26361/26361 [03:08<00:00, 139.85it/s]
100%|██████████| 9749/9749 [01:08<00:00, 141.85it/s]


TRAIN 4000 STEP 6.230025386810302 LOSS 7.741507434844971 LM LOSS
TRAIN 4025 STEP 5.932142677307129 LOSS 7.638251075744629 LM LOSS
TRAIN 4050 STEP 5.815835037231445 LOSS 7.595049324035645 LM LOSS
TRAIN 4075 STEP 5.667302093505859 LOSS 7.541551284790039 LM LOSS
TRAIN 4100 STEP 5.748593406677246 LOSS 7.565633811950684 LM LOSS
TRAIN 4125 STEP 6.308592090606689 LOSS 7.750182762145996 LM LOSS
TRAIN 4150 STEP 6.219777297973633 LOSS 7.717616062164307 LM LOSS
TRAIN 4175 STEP 6.1302375221252445 LOSS 7.6844588661193844 LM LOSS
TRAIN 4200 STEP 5.947985630035401 LOSS 7.619593544006348 LM LOSS
TRAIN 4225 STEP 6.043252964019775 LOSS 7.648396167755127 LM LOSS
TRAIN 4250 STEP 6.101083965301513 LOSS 7.664985809326172 LM LOSS
TRAIN 4275 STEP 6.0404543113708495 LOSS 7.641507110595703 LM LOSS
TRAIN 4300 STEP 6.029890098571777 LOSS 7.634766654968262 LM LOSS
TRAIN 4325 STEP 5.730754814147949 LOSS 7.530834197998047 LM LOSS
TRAIN 4350 STEP 5.825833368301391 LOSS 7.562921142578125 LM LOSS
TRAIN 4375 STEP 5.7179

100%|██████████| 26361/26361 [03:07<00:00, 140.90it/s]
100%|██████████| 9749/9749 [01:08<00:00, 142.97it/s]


TRAIN 4500 STEP 5.851972084045411 LOSS 7.549555015563965 LM LOSS
TRAIN 4525 STEP 6.159712429046631 LOSS 7.6484160995483395 LM LOSS
TRAIN 4550 STEP 5.973983030319214 LOSS 7.584162464141846 LM LOSS
TRAIN 4575 STEP 6.123812007904053 LOSS 7.631247577667236 LM LOSS
TRAIN 4600 STEP 5.7300255012512205 LOSS 7.496329689025879 LM LOSS
TRAIN 4625 STEP 5.803626546859741 LOSS 7.518239498138428 LM LOSS
TRAIN 4650 STEP 6.198034248352051 LOSS 7.647565383911132 LM LOSS
TRAIN 4675 STEP 6.298184604644775 LOSS 7.680878524780273 LM LOSS
TRAIN 4700 STEP 6.1827425861358645 LOSS 7.637058925628662 LM LOSS
TRAIN 4725 STEP 5.959965782165527 LOSS 7.5598844718933105 LM LOSS
TRAIN 4750 STEP 6.354040336608887 LOSS 7.689124603271484 LM LOSS
TRAIN 4775 STEP 5.5892191505432125 LOSS 7.4336646461486815 LM LOSS
TRAIN 4800 STEP 5.869840316772461 LOSS 7.521430320739746 LM LOSS
TRAIN 4825 STEP 6.369626770019531 LOSS 7.6858173370361325 LM LOSS
TRAIN 4850 STEP 5.824019155502319 LOSS 7.500734786987305 LM LOSS
TRAIN 4875 STEP 6.

100%|██████████| 26361/26361 [03:06<00:00, 141.09it/s]
100%|██████████| 9749/9749 [01:08<00:00, 142.70it/s]


TRAIN 5000 STEP 6.644211254119873 LOSS 7.761436042785644 LM LOSS


 81%|████████▏ | 7136/8772 [1:15:27<17:28,  1.56it/s]

In [42]:
retriever_tokenizer.save_pretrained("checkpoint-retriever-base")

('checkpoint-retriever-base/tokenizer_config.json',
 'checkpoint-retriever-base/special_tokens_map.json',
 'checkpoint-retriever-base/vocab.txt',
 'checkpoint-retriever-base/added_tokens.json',
 'checkpoint-retriever-base/tokenizer.json')

In [43]:
generator_tokenizer.save_pretrained("checkpoint-generator-base")

('checkpoint-generator-base/tokenizer_config.json',
 'checkpoint-generator-base/special_tokens_map.json',
 'checkpoint-generator-base/vocab.json',
 'checkpoint-generator-base/merges.txt',
 'checkpoint-generator-base/added_tokens.json',
 'checkpoint-generator-base/tokenizer.json')

In [44]:
!zip -r "base-models.zip" "checkpoint-generator-base/" "checkpoint-retriever-base/"

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  adding: checkpoint-generator-base/ (stored 0%)
  adding: checkpoint-generator-base/config.json (deflated 62%)
  adding: checkpoint-generator-base/special_tokens_map.json (deflated 52%)
  adding: checkpoint-generator-base/merges.txt (deflated 53%)
  adding: checkpoint-generator-base/tokenizer_config.json (deflated 63%)
  adding: checkpoint-generator-base/added_tokens.json (deflated 41%)
  adding: checkpoint-generator-base/tokenizer.json (deflated 59%)
  adding: checkpoint-generator-base/pytorch_model.bin (deflated 11%)
  adding: checkpoint-generator-base/vocab.json (deflated 59%)
  adding: checkpoint-retriever-base/ (stored 0%)
  adding: checkpoint-retriever-base/config.json (deflated 47%)
  adding: checkpo

In [1]:
!pip install httplib2==0.15.0
!pip install google-api-python-client==1.6

Collecting httplib2==0.15.0
  Downloading httplib2-0.15.0-py3-none-any.whl (94 kB)
     |████████████████████████████████| 94 kB 3.0 MB/s             
[?25hInstalling collected packages: httplib2
  Attempting uninstall: httplib2
    Found existing installation: httplib2 0.20.1
    Uninstalling httplib2-0.20.1:
      Successfully uninstalled httplib2-0.20.1
Successfully installed httplib2-0.15.0
Collecting google-api-python-client==1.6
  Downloading google_api_python_client-1.6.0-py2.py3-none-any.whl (52 kB)
     |████████████████████████████████| 52 kB 1.3 MB/s             
Installing collected packages: google-api-python-client
  Attempting uninstall: google-api-python-client
    Found existing installation: google-api-python-client 2.27.0
    Uninstalling google-api-python-client-2.27.0:
      Successfully uninstalled google-api-python-client-2.27.0
Successfully installed google-api-python-client-1.6.0


In [2]:
from gcloud import storage
from oauth2client.service_account import ServiceAccountCredentials
import os

In [3]:
client = storage.Client(project='realm-dungeon')
bucket = client.get_bucket('realm-dungeon-models')
blob = bucket.blob('base-models.zip')
blob.upload_from_filename('base-models.zip')