In [1]:
##  /apps/pytorch/2.0.1/bin/python
## /orange/h.azad/s.saini/.env/bin/activate


from torchsummary import summary
from model import ModelArgs, Transformer
import torch
import torch.nn.functional as F
from tokenizer import Tokenizer
import re
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
import subprocess as sp
import os
import random
import pandas as pd

In [2]:

def get_sents_gen(corpus_filename,bsz=32,max_sent_len=10):
    """Yield shuffled batches of short word windows sampled from a text file.

    The file's lines are joined with single spaces. Every offset immediately
    after a whitespace character is a candidate window start; the starts are
    visited in a random order (drawn from ``np.random``) and grouped into
    batches of ``bsz``. Each window runs from its start up to the start that
    lies ``max_sent_len`` candidates later, clamped to the last candidate.

    Because starts are offsets *after* whitespace, the first word of the file
    is never a window start and text after the last whitespace is never used.

    Args:
        corpus_filename: Path of the text corpus to read.
        bsz: Number of windows per yielded batch (the last batch may be smaller).
        max_sent_len: Window length measured in whitespace-delimited words.

    Yields:
        list[str]: One batch of text windows.
    """
    with open(corpus_filename) as handle:
        text = ' '.join(handle.readlines())

    # Offsets immediately after each whitespace character (i.e. word starts).
    starts = np.array([match.end() for match in re.finditer(r"\s", text)])
    n_starts = len(starts)
    last = n_starts - 1

    order = np.arange(0, n_starts)
    np.random.shuffle(order)

    for offset in range(0, n_starts, bsz):
        chosen = order[offset:min(offset + bsz, n_starts)]
        yield [text[starts[k]:starts[min(k + max_sent_len, last)]] for k in chosen]


def _iter_window_batches(text, bsz, max_sent_len):
    """Yield batches of word windows from ``text`` at shuffled word starts.

    Window starts are offsets just after a whitespace character; each window
    ends at the start ``max_sent_len`` candidates later, clamped to the last one.
    """
    starts = np.array([m.end() for m in re.finditer(r"\s", text)])
    n_starts = len(starts)
    order = np.arange(n_starts)
    np.random.shuffle(order)
    for offset in range(0, n_starts, bsz):
        yield [text[starts[k]:starts[min(k + max_sent_len, n_starts - 1)]]
               for k in order[offset:offset + bsz]]


def get_sents_gen_dir(directory,bsz=32,max_sent_len=10,files_per_group=3):
    """Yield shuffled batches of word windows from the parquet files in a directory.

    Directory entries are visited in a random order (``random.shuffle``) in
    groups of ``files_per_group``. The ``'text'`` column of every file in a
    group is joined with blank lines into one text, and word windows are
    sampled from that text at shuffled word starts.

    Args:
        directory: Directory whose entries are each read with ``pd.read_parquet``
            (every entry is read, so it should contain only parquet files).
        bsz: Windows per yielded batch (the last batch of a group may be smaller).
        max_sent_len: Window length in whitespace-delimited words.
        files_per_group: Number of files concatenated before sampling windows.

    Yields:
        list[str]: One batch of text windows.
    """
    files = os.listdir(directory)
    order = list(range(len(files)))
    random.shuffle(order)
    for group_start in range(0, len(files), files_per_group):
        group = order[group_start:group_start + files_per_group]
        # Bug fix: the old code read ``files[f_indices[j]]`` (a second,
        # unintended permutation, so some files were read twice and others
        # never) and overwrote ``text`` each iteration, keeping only one file.
        texts = []
        for file_idx in group:
            df = pd.read_parquet(os.path.join(directory, files[file_idx]))
            texts.append('\n\n'.join(df['text'].to_list()))
        text = '\n\n'.join(texts)
        yield from _iter_window_batches(text, bsz, max_sent_len)
            
def get_sents_from_parquets(directory,bsz=10,max_sent_len=10,filenames=None):
    """Yield shuffled batches of whole documents from a set of parquet files.

    All files are loaded up front and concatenated. Rows are then visited in a
    random order (``np.random.shuffle``), and their ``'text'`` values are
    yielded in batches of ``bsz``.

    Args:
        directory: Currently unused; kept for signature compatibility with the
            other sentence generators. NOTE(review): the training loop passes a
            corpus directory here that this function does not read. Pass
            ``filenames`` explicitly to choose the files.
        bsz: Documents per yielded batch (the last batch may be smaller).
        max_sent_len: Unused. Documents are yielded untruncated;
            ``LanguageModel.compute_loss`` truncates after tokenization.
        filenames: Parquet paths to load. Defaults to the four hard-coded
            ``data/{0..3}.parquet`` files used previously.

    Yields:
        list[str]: One batch of document texts.
    """
    if filenames is None:
        filenames = ['data/0.parquet','data/1.parquet','data/2.parquet','data/3.parquet']

    print('Preparing data.....')

    # ignore_index avoids duplicate row labels across files (positional access
    # below is unaffected either way).
    df = pd.concat([pd.read_parquet(name) for name in filenames], ignore_index=True)
    len_df = len(df)
    indices = list(range(len_df))
    np.random.shuffle(indices)
    texts = df['text']
    for i in range(0, len_df, bsz):
        yield texts.iloc[indices[i:i + bsz]].to_list()

In [3]:
from llama.generation import sample_top_p


# Fixed prompts that LanguageModel.train samples continuations for while training.
prompts = [
    'Max is a good dog. He protects kids.',
    'John was hungry. He wanted to eat',
]



def get_gpu_memory():
    """Return the free memory of each GPU as reported by ``nvidia-smi``.

    Runs ``nvidia-smi --query-gpu=memory.free --format=csv``, discards the CSV
    header and the final (trailing) line, and parses the leading integer of
    each remaining row, in whatever unit nvidia-smi reports.

    Returns:
        list[int]: One value per data row of the nvidia-smi output.
    """
    query = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv"]
    rows = sp.check_output(query).decode('ascii').split('\n')[1:-1]
    return [int(row.split()[0]) for row in rows]




        
class LanguageModel():
    """Wrap a ``Transformer`` with a tokenizer for training and text generation.

    ``model_args`` is mutated in place: its ``vocab_size``, ``max_seq_len`` and
    ``max_batch_size`` fields are overwritten from the tokenizer and the
    constructor arguments before the model is built.
    """

    def __init__(self,model_args,tokenizer,max_seq_len=10,bsz=1,verbose=True,device='cuda',device_index=5):
        """Build the model.

        Args:
            model_args: ``ModelArgs`` instance; mutated in place (see class doc).
            tokenizer: Tokenizer exposing ``n_words``, ``pad_id``, ``eos_id``,
                ``encode(text, bos=..., eos=...)`` and ``decode``.
            max_seq_len: Token length every training example is padded/truncated to.
            bsz: Training batch size; also the model's ``max_batch_size``.
            verbose: If True, print a ``torchsummary`` summary of the model.
            device: Device on which input tensors are created.
            device_index: Index into ``get_gpu_memory()`` logged as ``Mem/train``.
        """
        model_args.vocab_size = tokenizer.n_words
        model_args.max_seq_len = max_seq_len
        model_args.max_batch_size = bsz
        self.bsz = bsz
        self.max_seq_len = max_seq_len
        # NOTE(review): the model is never moved to ``device`` here; this relies
        # on ``Transformer`` placing its parameters on that device — confirm.
        self.model = Transformer(model_args)
        self.tokenizer = tokenizer
        self.device = device
        self.device_index = device_index
        # Defined here (not only in ``train``) so ``compute_loss`` works standalone.
        self.loss_fn = torch.nn.CrossEntropyLoss()
        if verbose:
            summary(self.model,torch.ones((1,10),dtype=torch.long).to(device))

    def compute_loss(self,batch):
        """Return the mean next-token cross-entropy for a batch of strings.

        Each string is encoded with BOS (no EOS), truncated to ``max_seq_len``
        tokens and right-padded with ``pad_id``. Inputs are ``tokens[:, :-1]``
        and targets ``tokens[:, 1:]``. Positions whose *input* token is padding
        are excluded from the loss.

        Args:
            batch: List of strings.

        Returns:
            Scalar loss tensor.
        """
        pad_id = self.tokenizer.pad_id
        tokens = torch.full((len(batch), self.max_seq_len), pad_id, dtype=torch.long, device=self.device)
        for k, text in enumerate(batch):
            ids = self.tokenizer.encode(text, bos=True, eos=False)[:self.max_seq_len]
            tokens[k, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=self.device)
        X = tokens[:, :-1]
        y = tokens[:, 1:]
        # Bug fix: the old code multiplied the logits by the padding mask, which
        # only made padded positions uniform. They still added log(vocab) to the
        # mean loss. Their targets are now set to the loss's ignore_index instead.
        targets = y.masked_fill(X == pad_id, self.loss_fn.ignore_index)
        # The old code also pre-allocated a (batch, seq, vocab) zeros tensor that
        # was immediately overwritten, wasting device memory.
        logits = self.model(X)
        return self.loss_fn(torch.transpose(logits, 1, -1), targets)

    def train(self,train_corpus,test_corpus=None,epochs=20):
        """Train with AdamW, logging to TensorBoard via ``SummaryWriter``.

        Every step logs ``Loss/train``. Every 100 batches (including the first
        of each epoch) it also logs test loss, GPU memory and sample
        generations (see ``_log_progress``).

        Args:
            train_corpus: Passed to ``get_sents_from_parquets`` each epoch.
            test_corpus: Optional text file read with ``get_sents_gen``. When
                given, the mean loss over it is logged as ``Loss/test``.
            epochs: Number of passes over the training data.
        """
        self.model.train()
        writer = SummaryWriter()

        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001, betas=(0.9, 0.95), eps=1e-05, weight_decay=0.1)

        cnt = 0
        try:
            for i in range(1,epochs+1):
                print('total epochs - ' + str(i))
                sent_gen = get_sents_from_parquets(train_corpus,self.bsz,self.max_seq_len)

                for j, batch in enumerate(sent_gen):
                    loss = self.compute_loss(batch)
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    loss_val = loss.detach()
                    cnt+=1
                    if not j%100:
                        self._log_progress(writer, loss_val, test_corpus, cnt)
                    writer.add_scalar("Loss/train", loss_val, cnt)
        finally:
            # Previously ``writer.flush()`` sat unreachable after ``return`` in
            # ``generate``. Flush and close here, even on KeyboardInterrupt.
            writer.flush()
            writer.close()

    def _log_progress(self, writer, loss_val, test_corpus, step):
        """Print the train loss and log test loss, GPU memory and sample generations."""
        with torch.no_grad():
            if test_corpus is not None:
                # Bug fix: this loop used to iterate the *training* generator,
                # silently consuming training batches. It also called
                # ``torch.mean`` on a Python list, which raises TypeError.
                test_losses = [self.compute_loss(test_batch)
                               for test_batch in get_sents_gen(test_corpus,self.bsz,self.max_seq_len)]
                if test_losses:
                    writer.add_scalar("Loss/test", torch.stack(test_losses).mean(), step)
            print(loss_val)
            writer.add_scalar("Mem/train", get_gpu_memory()[self.device_index], step)
            prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
            gens = self.tokenizer.decode(self.generate(prompt_tokens,50)[0])
            for k,gen in enumerate(gens):
                writer.add_text("Gen/train",prompts[k]+'#####'+gen,step)

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens,
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ):
        """
        Generate text sequences based on provided prompts using the language generation model.

        Args:
            prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
            max_gen_len (int): Maximum length of the generated text sequence.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.

        Note:
            This method uses the provided prompts as a basis for generating text. It employs nucleus sampling
            (or greedy decoding when temperature is 0) to produce text with controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.
            Tensors are created on ``self.device``. Prints ``'generating'`` on entry.
        """
        print('generating')
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        # Use the configured device rather than a hard-coded "cuda".
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device=self.device)
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            # Only tokens since ``prev_pos`` are fed; presumably the model keeps
            # earlier keys/values in its attention cache (use_att_cache=True) —
            # confirm in model.py.
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, use_att_cache= True)
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
            )
            prev_pos = cur_pos
            if bool(eos_reached.all()):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            if self.tokenizer.eos_id in toks:
                eos_idx = toks.index(self.tokenizer.eos_id)
                toks = toks[:eos_idx]
                probs = probs[:eos_idx] if logprobs else None
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

In [4]:
# Free memory per GPU; displayed as the cell result.
free_per_gpu = get_gpu_memory()
free_per_gpu

[81046, 81046, 81046, 66259, 81046, 81046, 81046, 81046]

In [5]:
# Model hyper-parameters. vocab_size, max_batch_size and max_seq_len stay None
# here because LanguageModel.__init__ overwrites them.
model_args = ModelArgs()
model_args.dim = 512
model_args.n_layers = 8
model_args.n_heads = 8
model_args.n_kv_heads = None
model_args.vocab_size = None  # defined later by tokenizer
model_args.multiple_of = 32  # make SwiGLU hidden layer size multiple of large power of 2
model_args.ffn_dim_multiplier = None
model_args.norm_eps = 1e-5
model_args.max_batch_size = None
model_args.max_seq_len = None

device = 'cuda'
SEED = 0
# Anomaly detection checks every backward op and slows training heavily;
# turn it on only while hunting NaNs/Infs.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Seed the RNGs used by the data generators (random / np.random) and by torch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

tokenizer = Tokenizer('data/vocab/tinystories28000.model')

lm = LanguageModel(model_args,tokenizer,bsz=128,max_seq_len=200,device=device,device_index=5)
lm.train('data/tinystories_corpus/',epochs=20)

Layer (type:depth-idx)                   Output Shape              Param #
├─Embedding: 1-1                         [-1, 10, 512]             14,336,000
├─ModuleList: 1                          []                        --
|    └─TransformerBlock: 2-1             [-1, 10, 512]             --
|    |    └─RMSNorm: 3-1                 [-1, 10, 512]             512
|    |    └─RMSNorm: 3-2                 [-1, 10, 512]             512
|    └─TransformerBlock: 2-2             [-1, 10, 512]             --
|    |    └─RMSNorm: 3-3                 [-1, 10, 512]             512
|    |    └─RMSNorm: 3-4                 [-1, 10, 512]             512
|    └─TransformerBlock: 2-3             [-1, 10, 512]             --
|    |    └─RMSNorm: 3-5                 [-1, 10, 512]             512
|    |    └─RMSNorm: 3-6                 [-1, 10, 512]             512
|    └─TransformerBlock: 2-4             [-1, 10, 512]             --
|    |    └─RMSNorm: 3-7                 [-1, 10, 512]             512


KeyboardInterrupt: 

In [18]:
# NOTE(review): this rebinds the module-level `prompts` that LanguageModel.train
# samples from, so any later training run will log this prompt instead.
prompts = ['John wanted to run, so he started running. Maria wanted to dance, so she started dancing. Julia wanted to jump, so she']
prompt_tokens = [tokenizer.encode(p, bos=True, eos=False) for p in prompts]
generated_ids = lm.generate(prompt_tokens, 50)[0]
gens = tokenizer.decode(generated_ids)
for k, gen in enumerate(gens):
    print(f"{prompts[k]}#####{gen}\n")

generating
John wanted to run, so he started running. Maria wanted to dance, so she started dancing. Julia wanted to jump, so she#####ran after him. John was so fast that he was almost dizzy. He couldn't move. He was so dizzy that he almost fell over. Alice was very worried. She decided to try again. She took a deep breath and jumped. She



In [None]:
import subprocess as sp
import os

def get_gpu_memory():
    """Return the free memory of each GPU as reported by ``nvidia-smi``.

    NOTE(review): duplicate of the ``get_gpu_memory`` defined in an earlier
    cell; whichever cell ran last wins.

    Returns:
        list[int]: Leading integer of each data row of
        ``nvidia-smi --query-gpu=memory.free --format=csv``.
    """
    raw = sp.check_output("nvidia-smi --query-gpu=memory.free --format=csv".split())
    lines = raw.decode('ascii').split('\n')
    data_rows = lines[:-1][1:]  # drop the final line, then the CSV header
    free = []
    for row in data_rows:
        free.append(int(row.split()[0]))
    return free

# Free memory per GPU; displayed as the cell result.
free_per_gpu = get_gpu_memory()
free_per_gpu

In [None]:
import sentencepiece as spm
# Train a SentencePiece model on the Alice corpus; output files are prefixed 'alice'.
# NOTE(review): 'foo'/'bar' look like placeholder symbols — confirm they are wanted.
spm.SentencePieceTrainer.train(
    input='data/alice_in_wonderland.txt',
    model_prefix='alice',
    vocab_size=2642,
    user_defined_symbols=['foo', 'bar'],
)

In [None]:
import re 
s = 'sdfds dsf sdf sdsdfs d fsf df'
# Every whitespace match comes back as its own one-character string.
whitespace_hits = re.findall(r"\s", s)
print(whitespace_hits)

In [None]:
import re
import random 
import numpy as np
# initializing string
# test_str = 'sdfds dsf sdf sdsdfs d fsf df'
 
# # Using regex
# # Check for spaces
# res = re.finditer(r"\s", test_str)
# word_indices = []
# for obj in list(res):
#     word_indices.append(obj.span()[1])
    
# shuffled_indices = word_indices.copy()
# random.shuffle(shuffled_indices)

# for 
    
def get_sents_gen(corpus_filename,bsz=32,max_sent_len=64):
    """Yield shuffled batches of word windows sampled from a text file.

    NOTE(review): this re-definition shadows the earlier ``get_sents_gen``
    (default ``max_sent_len=10``) once this cell runs.

    Lines are joined with single spaces. Offsets just after each whitespace
    character are candidate window starts; they are shuffled with
    ``np.random`` and grouped into batches of ``bsz``. Each window ends at the
    start ``max_sent_len`` candidates later, clamped to the last candidate.

    Args:
        corpus_filename: Path of the text corpus to read.
        bsz: Windows per yielded batch (the last batch may be smaller).
        max_sent_len: Window length in whitespace-delimited words.

    Yields:
        list[str]: One batch of text windows.
    """
    with open(corpus_filename) as f:
        lines = f.readlines()
    text = ' '.join(lines)
    res = re.finditer(r"\s", text)
    word_indices = np.array([obj.span()[1] for obj in res])
    len_indices = len(word_indices)
    shuffled_indices = np.arange(0,len_indices)
    np.random.shuffle(shuffled_indices)

    for i in range(0,len_indices,bsz):
        j = min(i+bsz,len_indices)
        start_indices = shuffled_indices[i:j]
        batch = []
        for index in start_indices:
            # Bug fix: clamp to the last valid position (len_indices - 1).
            # Clamping to len_indices indexed one past the end and raised
            # IndexError for any start within max_sent_len words of the end.
            end = word_indices[min(index+max_sent_len,len_indices-1)]
            batch.append(text[word_indices[index]:end])
        yield batch
        

dgen = get_sents_gen('data/alice_in_wonderland.txt')

# Peek at the first batch only; the generator stays suspended afterwards.
batch = next(dgen, None)
if batch is not None:
    print(batch)
            
            
        
    
    

In [None]:
def sent_gen():
    """Yield three toy sentences for the SentencePiece iterator experiment."""
    yield from ['sad fa sff fdsf', 'fsfs sf d', 'sdsdasdf fsdgdf sf']

In [None]:
import sentencepiece as spm

# Train a tiny SentencePiece model from an in-memory iterator of sentences.
sg = sent_gen()
# Bug fix: iterators must be passed as ``sentence_iterator``. ``input`` expects
# file path(s), so ``input=sg`` was stringified into a bogus file name.
spm.SentencePieceTrainer.train(sentence_iterator=sg, model_prefix='test/test', vocab_size=10,pad_id=0, bos_id=1,eos_id=2,unk_id=3)
# NOTE(review): this rebinds ``sp``, shadowing ``import subprocess as sp`` from
# earlier cells, so ``get_gpu_memory()`` breaks after this cell runs. Later cells
# use ``sp`` as the processor, so the name is kept here.
sp = spm.SentencePieceProcessor(model_file='test/test.model')
sp.vocab_size()

In [None]:
# Pad id of the loaded processor (test/test.model was trained with pad_id=0).
pad = sp.pad_id()
pad

In [None]:
# NOTE(review): if ``sp`` is the 10-piece test/test.model loaded above, these ids
# are out of its range — confirm which model is meant here.
sample_ids = [770, 537, 138, 22, 58, 11, 404]
sp.decode(sample_ids)

In [None]:
import torch 

# Experiment: build a mask and a zero-padded copy of the new keys ``xk`` that line
# up with a cache of ``max_seq_len`` slots when ``xk`` is written at ``start_pos``.
start_pos = 5
seqlen = 3
max_seq_len = 10
bsz = 4
xk = 3 * torch.ones((bsz, seqlen, 1, 1))
cache_k = torch.zeros((bsz, max_seq_len, 1, 1))

lead = cache_k.shape[0]
trail = tuple(cache_k.shape[2:])
tail_len = cache_k.shape[1] - (start_pos + seqlen)

# Mask is 1 outside the freshly written slots and 0 inside them; the extended
# keys hold ``xk`` in those slots and zeros elsewhere.
cache_mask_list = [torch.ones((lead, start_pos, *trail))] if start_pos > 0 else []
xk_extended_list = [torch.zeros((lead, start_pos, *trail))] if start_pos > 0 else []

cache_mask_list.append(torch.zeros(xk.shape))
xk_extended_list.append(xk)

if tail_len > 0:
    cache_mask_list.append(torch.ones((lead, tail_len, *trail)))
    xk_extended_list.append(torch.zeros((lead, tail_len, *trail)))

cache_mask = torch.cat(cache_mask_list, dim=1)
xk_extended = torch.cat(xk_extended_list, dim=1)

print(cache_mask.shape, cache_k.shape, xk_extended.shape)
        

In [None]:
# Inspect the full mask and extended-key tensors built in the previous cell.
print(cache_mask, xk_extended, sep=' ')