## Experiment Config

In [4]:
DO_TOKENIZATION=True
TINY_DATA_LOAD=True # If True, load tiny data for testing
TRAIN_MODEL=False

DATA_PREFIX='expt_1'
# DATA_PREFIX='tiny_sample'
MODEL_EXPT_NAME="multitok_model_1"

TRAIN_DATA_PATH=f'train-model/{DATA_PREFIX}'
MODEL_PATH=f'train-model/{DATA_PREFIX}/{MODEL_EXPT_NAME}'

config = {
    "learning_rate": 1e-3,
    "eval_interval": 300,
    "max_iters": 60000, 
    "H": 32, # per head dimension size
    "B": 64, # batch size
    "T": 256, # Sequence length
    "C": 256, # model size
    "feedforward_factor": 3,
    "n_heads": 8,
    "dropout": 0.0,
    "l2_penalty": 0.0,
    "n_layers": 12,
    "vocab_size": 2**13,
    # "git_hash": os.popen("git rev-parse HEAD").read().strip()
}

for k,v in config.items():
    locals()[k] = v

In [5]:
!pip install lingua 
!pip install multi-tokenizer
!pip install ipython-unittest


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


## Imports

In [6]:
from datasets import load_dataset
from itertools import chain
import tokenizers
    
import sys

from tqdm import tqdm

import glob
import os
import os.path
import random
import re
import torch

import torch.nn as nn
from torch.nn import functional as F
from itertools import islice
from typing import List

%load_ext ipython_unittest

from multi_tokenizer import LanguageDetector
from lingua import DetectionResult, Language


device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
# assert device == 'cuda', "This notebook is not optimized for CPU"
if device == 'cpu' and TRAIN_MODEL:
    print("Warning: Training on CPU will be slow")
    assert False

if 'train-model' not in sys.path: sys.path.append('train-model')
import data_split


The ipython_unittest extension is already loaded. To reload it, use:
  %reload_ext ipython_unittest


## Tokenization preprocessing

In [7]:
END_TOKEN = '[END]'

def delimited_stories_iterator(text):
    start = 0
    last_match_end = text.find(END_TOKEN)
    while last_match_end != -1:
        yield text[start:last_match_end + len(END_TOKEN)]
        start = last_match_end + len(END_TOKEN)
        last_match_end = text.find(END_TOKEN, start)


In [8]:
%%unittest_testcase
def test_delimited_stories_iterator(self):
    test_input = """[PROMPT] español [USER] Un niñ [END]
[PROMPT] español [USER] Hola [END]"""
    delim_stories = list(delimited_stories_iterator(test_input))
    self.assertEqual('[PROMPT] español [USER] Un niñ [END]', delim_stories[0])
    self.assertEqual('\n[PROMPT] español [USER] Hola [END]', delim_stories[1])
    self.assertEqual(test_input, ''.join(delim_stories))



Success

.
----------------------------------------------------------------------
Ran 1 test in 0.000s

OK


<unittest.runner.TextTestResult run=1 errors=0 failures=0>

In [9]:
class FilePartitionIterator:
    def __init__(self, file_obj, delimiter):
        self.delimiter = delimiter
        self.file = file_obj
        self.buffer = ""

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            index = self.buffer.find(self.delimiter)
            if index != -1:
                partition = self.buffer[:index + len(self.delimiter)]
                self.buffer = self.buffer[index + len(self.delimiter):]
                return partition

            chunk = self.file.read(4096)  # Read in chunks of 4KB
            if not chunk:
                if self.buffer:
                    partition = self.buffer
                    self.buffer = ""
                    return partition
                self.file.close()
                raise StopIteration

            self.buffer += chunk


In [10]:
%%unittest_testcase
import io
def test_file_partition_iterator(self):
    test_input = """[PROMPT] español [USER] Un niñ [END]
[PROMPT] español [USER] Hola [END]"""
    test_input_filelike = io.StringIO(test_input)
    delim_stories = list(FilePartitionIterator(test_input_filelike, END_TOKEN))
    self.assertEqual('[PROMPT] español [USER] Un niñ [END]', delim_stories[0])
    self.assertEqual('\n[PROMPT] español [USER] Hola [END]', delim_stories[1])
    self.assertEqual(test_input, ''.join(delim_stories))



Success

.
----------------------------------------------------------------------
Ran 1 test in 0.000s

OK


<unittest.runner.TextTestResult run=1 errors=0 failures=0>

In [11]:
END_TOKEN = '[END]'
PROMPT_TOKEN = '[PROMPT]'


class HackyCheatingLangDetector:
    def __init__(self, languages: List[Language]) -> None:
        pass


    def _detect_single_story(self, story_text) -> List[DetectionResult]:
        # story_text = story_text.strip()
        # assert story_text.startswith(PROMPT_TOKEN), f"Unexpected start: {story_text[:20]}"
        task = story_text.split(' ', 2)[1]
        assert task in ['english', 'español', 'translate']
        paragraph_splitter = re.compile('[^\n]+')
        if task == 'english':
            return [DetectionResult(0, len(story_text), 0, Language.ENGLISH)]
        elif task == 'español':
            return [DetectionResult(0, len(story_text), 0, Language.SPANISH)]
        elif task == 'translate':
            paragraph_starts = [m.span()[0] for m in re.finditer(paragraph_splitter, story_text)]
            ret = []
            for paragraph_no, paragraph_start in enumerate(paragraph_starts):
                next_paragraph_start = paragraph_starts[paragraph_no + 1] if paragraph_no + 1 < len(paragraph_starts) else len(story_text)+1
                lang = Language.ENGLISH if paragraph_no % 2 == 0 else Language.SPANISH
                ret.append(DetectionResult(paragraph_start, next_paragraph_start, 0, lang))
            return ret

  
    def split_n_detect(self, text: str, sep: str = " ") -> List[DetectionResult]:
        """Split Text and Detect Language."""

        def merge_results(
            results: List[List[DetectionResult]],
        ) -> List[DetectionResult]:
            """Merge Results. If consecutive words are detected as the same language, merge them."""
            merged_results: list[DetectionResult] = []
            for result in results:
                if not merged_results:
                    merged_results.extend(result)
                else:
                    for detection in result:
                        last_result = merged_results[-1]
                        if detection.language == last_result.language:
                            merged_results[-1] = DetectionResult(
                                language=last_result.language,
                                start_index=last_result.start_index,
                                end_index=last_result.end_index
                                + detection.end_index
                                - detection.start_index,
                                word_count=last_result.word_count
                                + detection.word_count,
                            )
                        else:
                            merged_results.append(
                                DetectionResult(
                                    language=detection.language,
                                    start_index=last_result.end_index,
                                    end_index=last_result.end_index
                                    + detection.end_index
                                    - detection.start_index,
                                    word_count=detection.word_count,
                                )
                            )
            return merged_results


        results = []
        # Detect the language of each part
        for story_text in delimited_stories_iterator(text):
            results.append(self._detect_single_story(story_text))
        
        # Merge consecutive results with the same language
        merged_results = merge_results(results)
        return merged_results


In [12]:
%%unittest_testcase
def test_hacky_cheating_language_detector(self):
    test_input = """[PROMPT] english [USER] Hi [END]
[PROMPT] español [USER] a child [END]
[PROMPT] translate [USER] It was night
Era de noche
Hi
Hola [END]"""
    hcld = HackyCheatingLangDetector([])
    detections = hcld.split_n_detect(test_input, END_TOKEN)
    self.assertEqual(6, len(detections))
    expected_detections_with_langs = [
        (Language.ENGLISH, '[PROMPT] english [USER] Hi [END]'),
        (Language.SPANISH, '\n[PROMPT] español [USER] a child [END]'),
        (Language.ENGLISH, '\n[PROMPT] translate [USER] It was night'),
        (Language.SPANISH, '\nEra de noche'),
        (Language.ENGLISH, '\nHi'),
        (Language.SPANISH, '\nHola [END]')
    ]
    for detection, (expected_lang, expected_string) in zip(
        detections, expected_detections_with_langs):
        self.assertEqual(expected_lang, detection.language)
        self.assertEqual(expected_string, test_input[detection.start_index: detection.end_index])



Success

.
----------------------------------------------------------------------
Ran 1 test in 0.001s

OK


<unittest.runner.TextTestResult run=1 errors=0 failures=0>

In [13]:
special_tokens = ['[PROMPT]', '[USER]', '[END]', '[PAD]', '[BEGIN-EN]', '[BEGIN-ES]']
ENGLISH, SPANISH = 0, 1
lang_to_token_map = {Language.ENGLISH: '[BEGIN-EN]', Language.SPANISH: '[BEGIN-ES]'}
lingua_lang_to_dense_lang_id = {Language.ENGLISH: ENGLISH, Language.SPANISH: SPANISH}
supported_languages = [Language.ENGLISH, Language.SPANISH]
monolingual_token_file_names = {l: f'{MODEL_PATH}/train_tokenizer_{l}.txt' for l in supported_languages}

os.makedirs(MODEL_PATH, exist_ok=True)

def detections_to_lang_identified_text(detections: List[DetectionResult], text: str) -> str:
    ret = []
    for detection in detections:
        ret.append(f"{lang_to_token_map[detection.language]}{text[detection.start_index: detection.end_index]}")
    return ''.join(ret)


if DO_TOKENIZATION:
    lang_detector = HackyCheatingLangDetector(supported_languages)

    os.makedirs(MODEL_PATH, exist_ok=True)
    outputs_split_by_lang = {}
    for l in supported_languages:        
        outputs_split_by_lang[l] = open(monolingual_token_file_names[l], 'w')
    segmented_output = open(f'{MODEL_PATH}/train_segmented.txt', 'w')

    for single_story in FilePartitionIterator(open(f'{TRAIN_DATA_PATH}/train.txt', 'r'), END_TOKEN):
        detections = lang_detector.split_n_detect(single_story, END_TOKEN)              
        for detection in detections: 
            monolingual_fragment = single_story[detection.start_index: detection.end_index]
            outputs_split_by_lang[detection.language].write(monolingual_fragment + '\n')
        segmented_output.write(detections_to_lang_identified_text(detections, single_story))  

    for l in supported_languages:
        outputs_split_by_lang[l].close()
    segmented_output.close()

        

In [14]:
tokenizers_by_dense_lang_id = []

if DO_TOKENIZATION:
    for l in supported_languages:
        tokenizer = tokenizers.ByteLevelBPETokenizer()
        tokenizer.train(
            files=[monolingual_token_file_names[l]], 
            vocab_size=vocab_size, 
            min_frequency=2,
            special_tokens=special_tokens
        )
        tokenizer.add_special_tokens(special_tokens)
        tokenizer.save_model(MODEL_PATH, f'tiny-stories-{l}-bpe')
        tokenizers_by_dense_lang_id.append(tokenizer)
else:
    for l in supported_languages:
        tokenizer = tokenizers.ByteLevelBPETokenizer(
            f'{MODEL_PATH}/tiny-stories-{l}-bpe-vocab.json', 
            f'{MODEL_PATH}/tiny-stories-{l}-bpe-merges.txt'
        )
        tokenizer.add_special_tokens(special_tokens)
        tokenizers_by_dense_lang_id.append(tokenizer)










In [15]:
PROMPT_TOK_ID = tokenizers_by_dense_lang_id[0].token_to_id('[PROMPT]')
USER_TOK_ID = tokenizers_by_dense_lang_id[0].token_to_id('[USER]')
END_TOK_ID = tokenizers_by_dense_lang_id[0].token_to_id('[END]')
PAD_TOK_ID = tokenizers_by_dense_lang_id[0].token_to_id('[PAD]')
BEGIN_EN_TOK_ID = tokenizers_by_dense_lang_id[0].token_to_id('[BEGIN-EN]')
BEGIN_ES_TOK_ID = tokenizers_by_dense_lang_id[0].token_to_id('[BEGIN-ES]')

LANG_BEGIN_TOK_IDS = [BEGIN_EN_TOK_ID, BEGIN_ES_TOK_ID]

In [16]:
lang_splitter_re = re.compile(r'\[BEGIN-(EN|ES)\]')
lang_code_to_lang_dense_id = {'EN': ENGLISH, 'ES': SPANISH}

def multi_tokenize_delimited_story(text):
    assert text.startswith('[BEGIN-EN]') or text.startswith('[BEGIN-ES]')
    # assert text.endswith(END_TOKEN)

    # This splits text into sequences of [lang_code_1, text_1, lang_code_2, text_2, ...]
    # where lang_code_i is either 'EN' or 'ES'
    per_lang_splits = [p for p in lang_splitter_re.split(text) if p.strip()]
    assert len(per_lang_splits) % 2 == 0
    lang_dense_id_per_text = [lang_code_to_lang_dense_id[code] for code in per_lang_splits[::2]]
    monolingual_texts = per_lang_splits[1::2]

    token_ids, lang_ids = [], []
    for lang_dense_id, monolingual_text in zip(lang_dense_id_per_text, monolingual_texts):
        tokenizer = tokenizers_by_dense_lang_id[lang_dense_id]
        
        these_tokens = [LANG_BEGIN_TOK_IDS[lang_dense_id]] + tokenizer.encode(monolingual_text).ids 
        token_ids.extend(these_tokens)
        lang_ids.extend([lang_dense_id] * len(these_tokens))
    
    return torch.tensor(token_ids, dtype=torch.short), torch.tensor(lang_ids, dtype=torch.short)
test_input = (
    """[BEGIN-EN] [PROMPT] translate [USER] One day, a little boy saw a cake. It was very big and had many strawberries.
[BEGIN-ES] Un día, un niño pequeño vio un pastel. Era muy grande y tenía muchas fresas.
[BEGIN-EN] Slightly more text [END]""")
multi_tokenize_delimited_story(test_input)


(tensor([   4,  226,    0, 1457,  226,    1,  509,  362,   17,  264,  418,  512,
          432,  264, 1120,   19,  421,  287,  406,  426,  270,  369,  795, 4445,
           19,  204,    5,  383,  442,   17,  288,  325,  432,  428,  288,  637,
           19,  470,  355,  416,  311,  410,  874, 1838,   19,  204,    4,  304,
         2794,  412,  710,  948,  972,  226,    2], dtype=torch.int16),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
         0, 0, 0, 0, 0, 0, 0], dtype=torch.int16))

In [17]:
def multi_detokenize(tokens, langs):
    ret = ""
    for token, lang in zip(tokens, langs):
        ret += tokenizers_by_dense_lang_id[lang].decode([token], skip_special_tokens=False)
    return ret

print(multi_detokenize(*multi_tokenize_delimited_story(test_input)))

[BEGIN-EN] [PROMPT] translate [USER] One day, a little boy saw a cake. It was very big and had many strawberries.
[BEGIN-ES] Un día, un niño pequeño vio un pastel. Era muy grande y tenía muchas fresas.
[BEGIN-EN] Slightly more text [END]


In [18]:
if DO_TOKENIZATION:
    output_buf = []
    num_outputs = 0
    for story in tqdm(FilePartitionIterator(open(f'{MODEL_PATH}/train_segmented.txt', 'r'), END_TOKEN)):
        output_buf.append(multi_tokenize_delimited_story(story))

        if len(output_buf) > 500_000:
            torch.save(output_buf, f'{MODEL_PATH}/tokenized-{num_outputs}.pt')
            num_outputs += 1
            if num_outputs == 1:
                torch.save(output_buf[:1000], f'{MODEL_PATH}/tiny-shard.pt')
            output_buf = []
           

    if output_buf:
        torch.save(output_buf, f'{MODEL_PATH}/tokenized-{num_outputs}.pt')
        num_outputs += 1
        output_buf = []

394635it [03:28, 1893.44it/s]


KeyboardInterrupt: 

## Data loading

In [26]:
def load_sharded_story(shard_no):
    return torch.load(f'{MODEL_PATH}/tokenized-{shard_no}.pt')

if TRAIN_MODEL or TINY_DATA_LOAD:
    # load the tokenized stories in parallel using threads
    # this is faster than loading them sequentially

    if not TINY_DATA_LOAD:
        num_shards = len(glob.glob(f'{MODEL_PATH}/tokenized-*'))
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor() as pool:
            stories = list(tqdm(pool.map(load_sharded_story, range(num_shards)), total=num_shards))
    else:
        stories = [torch.load(f'{MODEL_PATH}/tiny-shard.pt')]
    print("Loaded", sum(len(story) for story in stories), "stories")

    all_stories = []
    for story in stories:
        all_stories.extend(story)
    random.seed(1337)
    random.shuffle(all_stories)

    print("length of dataset in stories: ", len(all_stories))
    print("length of stories in tokens", sum(len(story[0]) for story in all_stories))

    num_stories_to_check = min(1_000_000, len(all_stories))
    num_long = sum(len(story[0]) > T for story in all_stories[:num_stories_to_check])
    print(
        f"# stories longer than {T} : {num_long} out of {num_stories_to_check}, {num_long/num_stories_to_check:.2%}")

    n = int(0.9*len(all_stories))

    train_data = all_stories[:n]  # use prepared splits instead, 
    val_data = all_stories[n:]  # segregate validation data by task

Loaded 1000 stories
length of dataset in stories:  1000
length of stories in tokens 195074
# stories longer than 256 : 92 out of 1000, 9.20%


In [None]:
def get_batch(split):
    # remove y_land_idxs from return value
    data = train_data if split == 'train' else val_data
    # ix = torch.randint(0, len(data), (B,)) 
    ix = range(B) # HACK! for now, just use the first B stories
    
    x = torch.full((B, T), PAD_TOK_ID, dtype=torch.long)
    y = torch.full((B, T), PAD_TOK_ID, dtype=torch.long)
    x_lang_idxs = torch.full((B, T), 0, dtype=torch.long)
    y_lang_idxs = torch.full((B, T), 0, dtype=torch.long)

    for sequence_index, random_story_index in enumerate(ix):
        story_tokens = data[random_story_index][0].long()[:T - 1]
        story_lang_idxs = data[random_story_index][1].long()[:T - 1]

        story_length = story_tokens.shape[0]
        assert story_lang_idxs.shape == story_lang_idxs.shape
        
        x[sequence_index][0:story_length-1] = story_tokens[0:story_length-1]
        y[sequence_index][0:story_length-1] = story_tokens[1:story_length]
        x_lang_idxs[sequence_index][0: story_length-1] = story_lang_idxs[0:story_length-1]
        y_lang_idxs[sequence_index][0: story_length-1] = story_lang_idxs[1:story_length]

    return x, y, x_lang_idxs, y_lang_idxs

if TRAIN_MODEL:
    xb, yb, x_lang_idxs, y_lang_idxs = get_batch('train')


    print(xb[0][:5])
    print(yb[0][:5])
    print(x_lang_idxs[0][:5])
    print(y_lang_idxs[0][:5])
    language = x_lang_idxs[0,0]
    print(language)
    # print(decode(xb[0].tolist(), language.item())) # the zero on xb is the first story. 



## Model defintions

In [None]:
class Head(nn.Module):
    '''One Head of self-attention'''
    def __init__(self, H):
        super().__init__()
        self.query = nn.Linear(C, H, bias=False)
        self.key = nn.Linear(C, H, bias=False)
        self.value = nn.Linear(C, H, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        query_vectors = self.query(x)
        key_vectors = self.key(x)


        # Attention masking(so we can't look into the past):
        tril = self.tril
        wei = torch.zeros(T, T) 
        wei = wei.masked_fill(tril == 0, float('-inf')) # set the upper triangular to -inf

        # multiply the two to get the attention weights
        attention_pattern = query_vectors @ key_vectors.transpose(-2, -1) # T, T
        attention_pattern = attention_pattern / (H ** 0.5) # scale the attention pattern for numerical stability
        attention_weights = F.softmax(attention_pattern + wei, dim=-1) # T, T (the row dimension is the query)
        attention_weights = self.dropout(attention_weights)

        value_vectors = self.value(x) # the direction we should go in the embedding space for each token (ie more blue) T, H

        # apply the attention weights to the value vectors
        context = attention_weights @ value_vectors # T, H

        # project back into original space from value space
        return context

x = torch.randn(B,T,C)
head = Head(H)

In [None]:
class MultiHeadAttention(nn.Module):
    '''Multiple heads of self-attention'''
    def __init__(self, H, C, n_heads): # H is head embedding space size, n_heads is number of heads
        super().__init__()
        self.heads = nn.ModuleList([Head(H) for _ in range(n_heads)])
        self.combine_heads = nn.Linear(H*n_heads, C)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.combine_heads(x)  # T, C
        return self.dropout(x)

In [None]:
head = MultiHeadAttention(H, C, n_heads)
head.heads[0].forward(x).shape

torch.Size([64, 256, 32])

In [None]:
class FeedForward(nn.Module):
    '''Feed-forward neural network'''
    def __init__(self, C):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(C, C * feedforward_factor),
            nn.ReLU(),
            nn.Linear(C * feedforward_factor, C),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [None]:
class LayerNorm(nn.Module):
    '''Layer normalization'''
    def __init__(self, C, use_affine=True):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(C)) if use_affine else None
        self.beta = nn.Parameter(torch.zeros(C)) if use_affine else None
        
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        if self.gamma is not None and self.beta is not None:
            return self.gamma * (x - mean) / (std + 1e-6) + self.beta
        else:
            return (x - mean) / (std + 1e-6)

In [None]:
class Block(nn.Module):
    '''Transformer block'''
    def __init__(self, H, C, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(H, C, n_heads)
        self.ff = FeedForward(C)
        self.norm1 = LayerNorm(C, use_affine=True)
        self.norm2 = LayerNorm(C, use_affine=True)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

In [None]:
class MultiTokenizerGPT(nn.Module):
    def __init__(self, n_layers, n_languages):
        super().__init__()
        self.token_embedding_tables = [nn.Embedding(vocab_size, C) for _ in range(n_languages)]
        self.position_embedding_table = nn.Embedding(T, C)
        self.lm_heads = nn.ModuleList([nn.Linear(C, vocab_size) for _ in range(n_languages)])
        self.layers = nn.ModuleList([Block(H, C, n_heads) for _ in range(n_layers)])
        self.block = nn.ModuleList([Block(H, C, n_heads)])
    
    # token_lang_idx has same shape as token_ids, encodes langs
    def forward(self, token_ids, token_lang_idxs, targets=None):
        B, T = token_ids.shape

        pos_emb = self.position_embedding_table(torch.arange(T))

        token_emb = torch.zeros(B*T, C)

        token_ids_flat = token_ids.view(-1)

        unique_lang_idxs = torch.unique(token_lang_idxs)

        # List of tensors, where each tensor contains the indexes in the flat tokens list of all tokens of a language.
        # there are as many tensors as there are languages present in the batch. Total length of all tensors = # tokens
        lang_tok_mapping = [torch.where(token_lang_idxs.view(-1) == idx) for idx in unique_lang_idxs]

        for lang_id, monolingual_tokens_idx in zip(unique_lang_idxs, lang_tok_mapping):
            lang_embedding_table = self.token_embedding_tables[lang_id]
            monolingual_tok_emb = lang_embedding_table(token_ids_flat[monolingual_tokens_idx])
            token_emb[monolingual_tokens_idx] = monolingual_tok_emb

        x = token_emb.view(B, T, C) + pos_emb # token identities and positions contained

        for layer in self.layers:
            x = layer(x)

        x_flat = x.view(B*T, C)
        logits = torch.zeros(B*T, vocab_size)

        for lang_id, monolingual_tokens_idx in zip(unique_lang_idxs, lang_tok_mapping):
            lm_head = self.lm_heads[lang_id]
            new_logits = lm_head(x_flat[monolingual_tokens_idx])
            logits[monolingual_tokens_idx] = new_logits

        logits = logits.view(B,T,vocab_size)

        batch_dim, sequence_dim, embedding_dim = logits.size()

        if targets is None:
            return logits, None
        else:
            logits_loss_view = logits.view(-1, vocab_size) 
            targets_loss_view = targets.view(-1)
            loss = F.cross_entropy(logits_loss_view, targets_loss_view)
            return logits, loss

        return idx
    def prompt_model(self, prompt, max_new_tokens, language="spanish", temperature=0.5):
        token_lang_idxs = torch.full([1, T], language_id)

        autoregressive_seq = [START_TOK, language_token]

        autoregressive_seq += encode(prompt, language_id)

        for _ in range(max_new_tokens):
            prediction_index = len(autoregressive_seq)-1

            model_input = torch.tensor(autoregressive_seq)
            
            while model_input.shape[0] < T:
                pad_token = torch.tensor([PAD_TOK], dtype=torch.long)
                model_input = torch.cat((model_input, pad_token), dim=0)

            model_input = model_input.unsqueeze(0)

            logits, loss = model.forward(
                token_ids=model_input,
                token_lang_idxs=token_lang_idxs
                )
            prediction_token = logits[:, prediction_index, :] / temperature
            probabilities = F.softmax(prediction_token, dim=-1)
            next_token = torch.multinomial(probabilities, num_samples=1)
            next_token = next_token.item()
            if(next_token == END_TOK):
                break

            autoregressive_seq.append(next_token)
        # get the autoregressive sequence
        return decode(autoregressive_seq, language_id)


model = MultiTokenizerGPT(n_layers, 2)
# get the number of parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("number of parameters in the model (millions): ", count_parameters(model) /1e6)

number of parameters in the model (millions):  12.825856


In [None]:
if TRAIN_MODEL:
    xb, yb, x_lang_idxs, y_lang_idxs = get_batch('train')

    logits, loss = model(
        token_ids = xb,
        token_lang_idxs = x_lang_idxs,
        targets = yb
    )
    print(logits.shape)
    print(loss)

    # test_idx = torch.zeros(1, T).long()
    # model.forward(idx=test_idx)
    # decode(model.generate(idx=test_idx, max_new_tokens=100)[0].tolist())

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
eval_iters = 10
eval_interval = 300
chars_per_token=3.9   # HACK!  need to compute chars per token for fair comparision


@torch.no_grad()
def estimate_loss(is_last=False):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        real_iters = eval_iters
        if is_last and split == 'val':  # increase last eval to mitigate noise
            real_iters *= 10 
        losses = torch.zeros(real_iters)
        for k in range(real_iters):
            # X, Y = get_batch(split)
            xb, yb, x_lang_idxs, y_lang_idxs = get_batch(split)

            logits, loss = model(
                token_ids = xb,
                token_lang_idxs = x_lang_idxs,
                targets = yb
            )
            losses[k] = loss.item()
        out[split] = losses.mean() / chars_per_token
    model.train()
    return out
    
if TRAIN_MODEL:
    estimate_loss()

In [None]:
if TRAIN_MODEL:
    dump_model_interval = 1000

    for steps in tqdm(range(0, max_iters)):
        xb, yb, x_lang_idxs, y_lang_idxs = get_batch('train')
        # loss
        logits, loss = model(
            token_ids = xb,
            token_lang_idxs = x_lang_idxs,
            targets = yb
        )
        optimizer.zero_grad(set_to_none=True)

        loss.backward()
        optimizer.step()
        if steps % eval_interval == 0:
            losses = estimate_loss()
            # wandb.log({"train": losses['train'].item(), "val": losses['val'].item(), "l2":l2})
            print({"train": losses['train'].item(), "val": losses['val'].item()})
        if steps % dump_model_interval == 0 and steps > 0:
            model_no = steps // dump_model_interval
            torch.save(model.state_dict(), f'{MODEL_PATH}/tiny-stories-model-{model_no}.pt')

    losses = estimate_loss(is_last=True)

In [None]:
# torch.save(model.state_dict(), f'{MODEL_PATH}/overfit-batch-1-model.pt')


## Manual prompts

In [None]:
if not TRAIN_MODEL:
    # model_file = f'{MODEL_PATH}/tiny-stories-model-18.pt'
    model_file = f'{MODEL_PATH}/overfit-batch-1-model.pt'
    print('on cpu lazy hack activate')
    model.load_state_dict(torch.load(model_file,  map_location=torch.device('cpu')))


on cpu lazy hack activate


In [None]:
def prompt_model(model, tokens, lang_idxs, max_new_tokens, temperature=0.5):
    autoregressive_toks = tokens.tolist()
    autoregressive_langs = lang_idxs.tolist()

    assert len(autoregressive_toks) == len(autoregressive_langs)

    for _ in range(max_new_tokens):
        prediction_index = len(autoregressive_toks)-1

        input_tokens = torch.tensor(autoregressive_toks)
        input_langs = torch.tensor(autoregressive_langs)
        while input_tokens.shape[0] < T:
            input_tokens = torch.cat((input_tokens, torch.tensor([PAD_TOK_ID])))
            input_langs = torch.cat((input_langs, torch.tensor([autoregressive_langs[-1]])))

        input_tokens = input_tokens.unsqueeze(0)
        input_langs = input_langs.unsqueeze(0)

        logits, ignored_loss = model.forward(token_ids=input_tokens, token_lang_idxs=input_langs)
        temp_scaled_logits = logits[:, prediction_index, :] / temperature
        probabilities = F.softmax(temp_scaled_logits, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1).item()

        if next_token in LANG_BEGIN_TOK_IDS:
            print('switching langs')
            autoregressive_langs.append(LANG_BEGIN_TOK_IDS.index(next_token))
        else:
            autoregressive_langs.append(autoregressive_langs[-1])

        autoregressive_toks.append(next_token)
        if next_token == END_TOK_ID:
            break
    return autoregressive_toks, autoregressive_langs
    # get the autoregressive sequence
    # return decode(autoregressive_seq, language_id)

# generate prompt -> hacky split -> multi encod
prompted = (data_split.write_translation_story_tinyprompt_strs('Lily fue', 'Lily went to the park.'))
lang_segmented = detections_to_lang_identified_text(HackyCheatingLangDetector([]).split_n_detect(prompted, END_TOKEN), prompted)
lang_segmented = lang_segmented.removesuffix('[END]')

x, lang_x = multi_tokenize_delimited_story(lang_segmented)
print(multi_detokenize(x, lang_x))
comp_tokens, comp_langs = prompt_model(model, x, lang_x, 10, temperature=.01)
multi_detokenize(comp_tokens, comp_langs)

[BEGIN-EN][PROMPT] translate [USER] Lily went to the park.
[BEGIN-ES]Lily fue 


'[BEGIN-EN][PROMPT] translate [USER] Lily went to the park.\n[BEGIN-ES]Lily fue  "¡ comerla. y hacer dijo!" jugar el y'

In [None]:
xb, xl, yb, yl = get_batch('train')
print(xb.shape)
print(xb[0][:5])


torch.Size([64, 256])
tensor([  5,   0, 389, 226,   1])
