In [None]:
# folder with text files
folder_with_text = 'harry potter/'

# tokenizer vocabulary size
vocab_size=4096

# each line in a file will be converted to tokens and only first output_tokens will be used
output_tokens=256

# when training to restore skipped tokens, what is min and max amount
# of skipped token per line
min_skip_tokens=2
max_skip_tokens=8

# one token embedding
emb_dim=256
# one token embedding expansion in a model
internal_dim=1024
# how many transformer attention blocks to include in model
attnetion_layers=10
# how many heads for attention to use
heads=16

# where to save model checkpoints
checkpoint_path = 'runs/harry-potter-linear'

# number of epochs to train
num_epochs=500

# Create tokenizer

In [2]:
# for our dataset train tokenizer
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers
import os

# Initialize a tokenizer with a WordPiece model
tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))

# Set up normalization and pre-tokenization
tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False,clean_text=False)
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

# [UNK] Represents any word or token that is not found in the model's vocabulary

# [CLS] (Classification Token): Inserted at the beginning of every 
# input sequence. In classification tasks, the final hidden state corresponding 
# to this token is used as the aggregate sequence representation

# [SEP] (Separator Token): Used to separate distinct sentences or segments within the input.
# [PAD] (Padding Token): Used to pad input sequences to a uniform length, ensuring that batches of data have consistent dimensions.

# [MASK] (Masking Token): Employed during the pre-training phase of models like BERT for masked language modeling. 
# Certain tokens in the input are replaced with [MASK], 
# and the model is trained to predict the original token, enabling it to learn bidirectional representations.

# [END] (ending token): end of text token

# Initialize a trainer with desired parameters
trainer = trainers.WordPieceTrainer(
    vocab_size=vocab_size,
    min_frequency=4,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]","[END]",'[START]']
)

tokenizer_save_dir = os.path.join(folder_with_text,"wordpiece_tokenizer.json")
txt_files = [folder_with_text+v for v in os.listdir(folder_with_text) if v.endswith(".txt")]
txt_lines = "\n".join([open(v).read() for v in txt_files])

# Train and save the tokenizer on your corpus 
tokenizer.train(txt_files, trainer)
tokenizer.save(tokenizer_save_dir)

tokenizer = Tokenizer.from_file(tokenizer_save_dir)
vocab_size=tokenizer.get_vocab_size()

# Example usage
sample_text = "Harry, we are in trouble!"
encoded = tokenizer.encode(sample_text)
print("Average token length",len(sample_text)/len(encoded.tokens))
print("Encoded IDs:", encoded.ids)
print("Encoded Tokens:", encoded.tokens)

# Decode back to text
decoded = tokenizer.decode(encoded.ids)
print("Decoded Text:", decoded)




Average token length 3.5714285714285716
Encoded IDs: [242, 14, 356, 508, 236, 1912, 7]
Encoded Tokens: ['Harry', ',', 'we', 'are', 'in', 'trouble', '!']
Decoded Text: Harry , we are in trouble !


In [3]:
# analyze dataset and tokenizer
import numpy as np
txt_split = txt_lines.split("\n")
txt_split = [v for v in txt_split if len(v)>30]
lengths = [len(v) for v in txt_split]
lengths=np.array(lengths)
longest = lengths.argsort()[-100:]

longest_texts = [txt_split[v] for v in longest]
longest_texts_emb_token_length = [len(v)/len(tokenizer.encode(v).tokens) for v in longest_texts]
mean_token_length = np.array(longest_texts_emb_token_length).mean()

print("Text length analysis")
print("text lines\t",len(txt_split))
print("line chars mean\t",lengths.mean().round(3))
print("line chars std\t",lengths.std().round(3))
print("0.05 quantile\t",np.quantile(lengths,0.05))
print("0.95 quantile\t",np.quantile(lengths,0.95))
# mean amount of chars per token
print('mean token len\t',mean_token_length)

Text length analysis
text lines	 34517
line chars mean	 177.687
line chars std	 240.073
0.05 quantile	 39.0
0.95 quantile	 473.0
mean token len	 3.930851187856247


# Define Dataset

In [4]:
import random
from kemsekov_torch.train import split_dataset
import torch

class TokenDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, text_lines,pad_token = '[PAD]',mask_token = '[MASK]',start_token='[START]',end_token='[END]', output_tokens = 1024,min_skip_tokens=2,max_skip_tokens = 10):
        super().__init__()
        self.text = text_lines
        self.pad_token = tokenizer.encode(pad_token).ids[0]
        self.mask_token = tokenizer.encode(mask_token).ids[0]
        self.start_token = tokenizer.encode(start_token).ids[0]
        self.end_token = tokenizer.encode(end_token).ids[0]
        self.output_tokens=output_tokens
        self.tokenizer = tokenizer
        self.min_skip_tokens=min_skip_tokens
        self.max_skip_tokens=max_skip_tokens
    
    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        text = self.text[index]
        ids_orig = tokenizer.encode(text).ids
        ids=torch.tensor(([self.start_token]+ids_orig+[self.pad_token]*(self.output_tokens))[:self.output_tokens-1]+[self.end_token])
        
        to_skip = random.randint(self.min_skip_tokens,self.max_skip_tokens)
        to_skip = min(to_skip,len(ids_orig)//4)
        
        return self.neighbor_middle_token_pair(ids,to_skip)
    
    def neighbor_middle_token_pair(self,tokens,skip_tokens=4):
        aval_id = torch.where(tokens!=self.pad_token)[0]
        
        skipped = tokens.clone()
        ind = torch.rand(aval_id.shape).argsort()[:skip_tokens]
        
        skipped[aval_id[ind]]=self.mask_token
        return skipped,tokens

dataset = TokenDataset(
    tokenizer,
    txt_split,
    pad_token='[PAD]',
    mask_token='[MASK]',
    output_tokens=output_tokens,
    min_skip_tokens=min_skip_tokens,
    max_skip_tokens=max_skip_tokens
)

train_dataset,test_dataset,train_loader, test_loader = split_dataset(dataset,test_size=0.02,batch_size=16,random_state=None)
len(train_dataset),len(test_dataset)

(33826, 691)

In [5]:
%env TOKENIZERS_PARALLELISM=True
import random
ind = random.randint(0,len(train_dataset)-1)
skip,neigh = dataset[ind]

i = random.randint(0,len(neigh)-1)
print("skipped at",torch.where(neigh==dataset.mask_token)[0])
neigh_text = tokenizer.decode(neigh.tolist(),skip_special_tokens=False).replace(" ##","").replace('[PAD]','')
skip_text = tokenizer.decode(skip.tolist(),skip_special_tokens=False).replace(" ##","").replace('[PAD]','')

print(neigh_text)
print(skip_text)

env: TOKENIZERS_PARALLELISM=True
skipped at tensor([], dtype=torch.int64)
[START] The idea that Dumbledore valued his opinion this highly made Harry feel even more deeply ashamed that he had failed in the task of retrieving the Horcrux memory , and he shifted guiltily in his seat as Dumbledore raised the first of the two bottles to the light and examined it .                                                                                                                                                                                             [END]
[MASK] The idea that [MASK] valued his opinion this highly made Harry [MASK] even [MASK] deeply ashamed that he had failed in the task of retrieving the Horcrux memory , [MASK] he shifted guiltily in his seat as Dumbledore raised the first of the [MASK] bottles to the [MASK] and examined [MASK] .                                                                                                                                                   

# Define Model

In [6]:
from kemsekov_torch.residual import ResidualBlock, Residual
from kemsekov_torch.attention import LinearSelfAttentionBlock, TransformerSelfAttentionBlock
from kemsekov_torch.common_modules import *
from kemsekov_torch.positional_emb import ConcatPositionalEmbeddingPermute
from kemsekov_torch.rotary_emb import RotaryEmbInplace
import torch
import torch.nn as nn

class Embedding(nn.Module):
    """
    Module for token to embedding vector learning
    """
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size

        # Initialize weights and bias
        self.weight = nn.Parameter(torch.Tensor(vocab_size, embedding_size))
        self.bias = nn.Parameter(torch.Tensor(embedding_size))

        self.reset_parameters()

    #normal init
    def reset_parameters(self):
        # Initialize weights with a normal distribution
        std = 1.0 / (self.vocab_size**0.5)
        nn.init.normal_(self.weight, mean=0.0, std=std)
        # Initialize bias to zeros
        nn.init.zeros_(self.bias)
        
    def forward(self, input):
        # Input is expected to be a tensor of indices
        output = torch.nn.functional.embedding(input, self.weight) + self.bias
        return output.transpose(-1,-2)
class SimpleTransformer(torch.nn.Module):
    def __init__(self,emb_dim=256,internal_dim=512, attnetion_layers = 10,heads=16):
        super().__init__()
        
        attn_common = dict(
            input_dim=internal_dim,
            mlp_dim=internal_dim*4,
            heads=heads,
            dropout=0.1,
        )
        
        self.emb = Embedding(vocab_size,emb_dim)
        self.attention = torch.nn.Sequential(
            # expand dimensions
            ResidualBlock(
                emb_dim,
                internal_dim,
                kernel_size=1,
                dimensions=1
            ),
            
            Residual([
                # use RoPE as first module
                RotaryEmbInplace(internal_dim),
                # then stack linear SA blocks
                FlattenSpatialDimensions([
                    LinearSelfAttentionBlock(**attn_common)
                    # TransformerSelfAttentionBlock(attn_common['input_dim'],attn_common['heads'],attn_common['mlp_dim'],batch_first=True)
                    for i in range(attnetion_layers)
                ])
            ])
        )
        # map emb dim to token
        self.fc = torch.nn.Linear(internal_dim,vocab_size)
    def forward(self,x : torch.Tensor):
        x = self.emb(x)
        x = self.attention(x)
        x=x.transpose(-1,-2)
        return self.fc(x)

model = SimpleTransformer(
    emb_dim=emb_dim,
    internal_dim=internal_dim,
    attnetion_layers=attnetion_layers,
    heads=heads
)

print(f"model {sum([v.numel() for v in model.parameters()])//1000//1000}M parameters")
model(neigh[None,:]).shape,neigh.shape,skip.shape

model 145M parameters


(torch.Size([1, 256, 4096]), torch.Size([256]), torch.Size([256]))

# Training

In [None]:
from kemsekov_torch.train import train
from torchmetrics.classification import MulticlassF1Score

CE = torch.nn.CrossEntropyLoss()

f1__ = MulticlassF1Score(vocab_size)
def f1(x,y):
    return f1__(x.detach().cpu(),y.detach().cpu())

def compute_loss_and_metric(model,batch):
    skip,neigh = batch[0],batch[1]

    neigh_pred = model(skip)
    neigh_pred=neigh_pred.view(-1,neigh_pred.shape[-1])
    neigh = neigh.view(-1)
    
    general_loss = CE(neigh_pred,neigh)
    
    skip_ind = skip.view(-1)==dataset.mask_token
    skip_pred = neigh_pred[skip_ind]
    skip_true = neigh[skip_ind]
    skipped_tokens_loss = CE(skip_pred,skip_true)
    
    # use two losses, to enforce model not only reconstruct back original
    # input sequence, but also pay same amount of attention to skipped tokens
    return general_loss+skipped_tokens_loss,{
        'f1': f1(neigh_pred,neigh),
        'f1 skip': f1(skip_pred,skip_true),
        'general_loss':general_loss,
        'skipped_tokens_loss':skipped_tokens_loss
    }


optim = torch.optim.AdamW(model.parameters(),0.001,betas=(0.9, 0.95))
sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim,len(train_loader)*10)

_ = train(
    model,
    train_loader,
    test_loader,
    compute_loss_and_metric,
    checkpoint_path,
    f'{checkpoint_path}/last',
    gradient_clipping_max_norm=1,
    accelerate_args={
        # 'gradient_accumulation_steps':8,
        'mixed_precision':'bf16',
        'dynamo_backend':'inductor'
    },
    save_on_metric_improve=['f1 skip'],
    optimizer=optim,
    num_epochs=num_epochs,
    scheduler=sch,
    checkpoints_count=2,
)

Total model parameters 145.57 M
loaded training state from runs/harry-potter-linear/last/state
trying to capture model architecture...
Saved model architecture at runs/harry-potter-linear/model.pt. You can torch.load it and update it's weights with checkpoint

Epoch 472/500


train 0:   2%|▏         | 43/2115 [00:15<05:37,  6.15it/s, f1=0.9008, f1 skip=0.2561, general_loss=0.0676, loss=4.0390, skipped_tokens_loss=3.9714] 

# Eval model

In [None]:
from kemsekov_torch.train import load_checkpoint, load_last_checkpoint

# load model
model = torch.jit.load(f"{checkpoint_path}/model.pt")
model = load_checkpoint(model,checkpoint_path,-1).cpu().eval().half()

loading runs/harry-potter-linear/checkpoints/epoch-400/state


In [None]:
d = test_dataset.dataset
ind = random.randint(0,len(d)-1)
skipped_t,true_t = d[ind]

pred = model(skipped_t[None,:])[0]
pred_tokens = pred.softmax(-1).argmax(-1)

true_text = tokenizer.decode(true_t.tolist(),skip_special_tokens=False).replace(" ##","").replace('[PAD]','')
skip_text = tokenizer.decode(skipped_t.tolist(),skip_special_tokens=False).replace(" ##","").replace('[PAD]','')
pred_text = tokenizer.decode(pred_tokens.tolist(),skip_special_tokens=False).replace(" ##","").replace('[PAD]','')

print(true_text)
print(skip_text)
print(pred_text)

f1(pred,true_t)

[START] It took Harry a few moments to realize what McLaggen was talking about .                                                                                                                                                                                                                                                 [END]
[START] [MASK] [MASK] Harry [MASK] few moments to realize what McLaggen was talking about .                                                                                                                                                                                                                                                 [END]
[START] It was Harry a few moments to realize what McLaggen was talking about .                                                                                                                                                                                                                                                 [END]


tensor(0.9216)