In [None]:
!pip install -U -q accelerate transformers huggingface_hub datasets hf_xet

In [None]:
def flush():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory."""
    # Collect first so tensors that just became unreachable are dropped
    # before the CUDA caching allocator is asked to give memory back.
    gc.collect()
    torch.cuda.empty_cache()

def count_parameters(model):
    """Return a one-line summary of the model's trainable parameter count.

    Only parameters with ``requires_grad`` set are counted; the total is
    reported in millions, rounded to two decimals.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    millions = trainable / 10 ** 6
    return f"BabyLlama size: {millions:.2f}M parameters"

In [None]:
import gc
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from math import sqrt
from transformers import PretrainedConfig
import math
from typing import Tuple, Optional, List
from transformers import logging, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the Cosmopedia-100k dataset; its "train" split and "text" column are used below.
dataset = load_dataset("HuggingFaceTB/cosmopedia-100k")
# Only the tokenizer is loaded from this checkpoint name; no model weights are requested here.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")  # Use Hugging Face tokenizer or your own


# Logger obtained through transformers' logging helper.
logger = logging.get_logger(__name__)


In [None]:
# Index the row first, then pick the column, so only one example is read
# instead of pulling the whole "text" column to take its sixth element.
ds = dataset["train"][5]["text"]

In [None]:
ds

In [None]:
class LlamaConfig(PretrainedConfig):
    """Configuration for the small LLaMA-style models built in this notebook.

    The defaults describe a tiny model (hidden size 192, 12 layers, 4 query
    heads sharing 2 key/value heads, context length 128).

    Notes:
        * ``device`` is stored as given and is read by the model code when it
          builds the rotary cache and default position ids.
        * ``pad_token_id``, ``bos_token_id``, ``eos_token_id`` and
          ``tie_word_embeddings`` are forwarded to ``PretrainedConfig``.
        * ``num_key_value_heads=None`` falls back to ``num_attention_heads``
          (plain multi-head attention).
        * Several fields (``hidden_act``, ``rope_scaling``,
          ``attention_dropout``, ``residual_dropout``, ``attention_bias``,
          ``lm_head_bias``, ``pretraining_tp``) are stored for reference but
          are not read by the model code visible in this notebook.
    """

    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=128000,
        hidden_size=192, # 2048 Tiny LLaMA
        intermediate_size=512,
        num_hidden_layers=12,
        num_attention_heads=4, # 32 Tiny LLaMA
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=128,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=False,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_bias=False,
        lm_head_bias=False,
        residual_dropout=0.0,
        device='cpu',
        **kwargs,
    ):
        # Without an explicit KV-head count, give every query head its own KV head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.use_bias = use_bias
        self.lm_head_bias = lm_head_bias
        self.device = device

        # pad_token_id and tie_word_embeddings used to be accepted and then
        # dropped; hand them to the base class so the caller's values stick.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

In [None]:
def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """Build a boolean lower-triangular mask of shape ``(1, 1, L, L)``.

    Entry ``[0, 0, i, j]`` is ``True`` exactly when ``j <= i``.
    """
    side = (max_seq_length, max_seq_length)
    lower = torch.ones(side, device=device, dtype=torch.bool).tril()
    # Two leading singleton axes are added in front of the (L, L) matrix.
    return lower[None, None]

def repeat_kv(hidden_states: torch.Tensor, n_repeats: int) -> torch.Tensor:
    """Repeat every key/value head ``n_repeats`` times along the head axis.

    Args:
        hidden_states: 4-D tensor unpacked as ``(batch, n_kv_heads, seq_len, head_dim)``.
        n_repeats: number of consecutive copies of each head.

    Returns:
        The input itself when ``n_repeats == 1``; otherwise a tensor of shape
        ``(batch, n_kv_heads * n_repeats, seq_len, head_dim)`` in which the
        copies of one head are adjacent (same layout as
        ``torch.repeat_interleave(hidden_states, n_repeats, dim=1)``).
    """
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_repeats == 1:
        return hidden_states
    # (B, nh, T, hs) -> (B, nh, 1, T, hs) -> (B, nh, n_repeats, T, hs)
    expanded = hidden_states[:, :, None, :, :].expand(batch, n_kv_heads, n_repeats, seq_len, head_dim)
    # Merge the head and repeat axes: (B, nh * n_repeats, T, hs)
    return expanded.reshape(batch, n_kv_heads * n_repeats, seq_len, head_dim)


class RotaryPositionalEmbeddings(nn.Module):
    """Rotary position embeddings with a cached cos/sin lookup table.

    The table is stored in the non-persistent buffers ``cos_cached`` and
    ``sin_cached`` (each of shape ``(max_seq_len_cached, dim)``) and is rebuilt
    whenever a longer sequence is requested.

    Args:
        dim: size of the last axis of the tensors to rotate.
        max_position_embeddings: number of positions pre-computed at construction.
        base: base of the inverse-frequency progression.
        device: device on which the initial buffers are created.
        scaling_factor: positions are divided by this value before the
            frequencies are computed.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.device = device
        self.scaling_factor = scaling_factor
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        # These buffers are what `self.cos_cached` / `self.sin_cached` resolve to.
        # The former read-only properties of the same names read attributes that
        # were never assigned, so they have been removed.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    @torch.no_grad()
    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: tensor whose dtype (and device, when the cache has to grow) is
                used; laid out as ``[bs, num_attention_heads, seq_len, head_size]``.
            seq_len: number of positions wanted. Defaults to ``x.shape[-2]``.

        Returns:
            Two tensors of shape ``(seq_len, dim)`` cast to ``x.dtype``.
        """
        if seq_len is None:
            # The default used to crash on the comparison below; use the
            # sequence axis of `x` instead.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

    def apply_rope(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids, unsqueeze_dim=1) -> torch.Tensor:
        """Rotate ``x`` by the angles selected through ``position_ids``.

        Args:
            x: tensor to rotate; its last axis is split into two halves.
            cos, sin: tables as returned by :meth:`forward`.
            position_ids: index used to pick rows of ``cos``/``sin``.
            unsqueeze_dim: axis inserted into the selected tables so they
                broadcast against ``x``.

        Returns:
            ``x * cos + rotate_half(x) * sin`` in the dtype of ``x``.
        """
        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
        half = x.shape[-1] // 2
        x1 = x[..., :half]  # first half of the last axis
        x2 = x[..., half:]  # second half of the last axis
        rotated = torch.cat((-x2, x1), dim=-1)
        roped = (x * cos) + (rotated * sin)
        return roped.to(dtype=x.dtype)


class KVCache(nn.Module):
    """Pre-allocated key and value buffers that are updated in place.

    Both buffers start as zeros, are non-persistent (not part of the state
    dict), and are written along dimension 2.
    """

    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        for buffer_name, buffer_shape in (("k", k_shape), ("v", v_shape)):
            zeros = torch.zeros(buffer_shape, device=device, dtype=dtype)
            self.register_buffer(buffer_name, zeros, persistent=False)

    def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write ``k``/``v`` at indices ``input_pos`` of dimension 2 and return the full buffers."""
        # Buffers follow the dtype of the incoming activations (relevant under AMP).
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        # In-place scatter into the cache; the returned tensors are the buffers.
        cached_keys = self.k.index_copy_(2, input_pos, k)
        cached_values = self.v.index_copy_(2, input_pos, v)
        return cached_keys, cached_values

    def reset_parameters(self) -> None:
        """Zero both buffers in place."""
        for buffer in (self.k, self.v):
            torch.nn.init.zeros_(buffer)

In [None]:
class LlamaAttention(nn.Module):
    """Causal self-attention with grouped key/value heads and rotary embeddings.

    ``config`` must provide ``hidden_size``, ``num_attention_heads``,
    ``num_key_value_heads``, ``use_bias``, ``max_position_embeddings``,
    ``device`` and ``rope_theta``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by the number of heads,
            or the number of heads is not divisible by the number of KV heads.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = hidden_dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
        self.head_dim = head_dim = config.hidden_size // n_heads
        use_bias = config.use_bias

        if (head_dim * n_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.n_heads})."
            )
        # `repeats` below is an integer division; reject configurations where
        # the repeated KV heads could not line up with the query heads.
        if n_heads % n_kv_heads != 0:
            raise ValueError(
                f"num_heads must be divisible by num_key_value_heads (got `num_heads`: {n_heads}"
                f" and `num_key_value_heads`: {n_kv_heads})."
            )

        self.repeats = n_heads // n_kv_heads # q_per_kv

        self.q_proj = nn.Linear(hidden_dim, n_heads * head_dim, bias=use_bias)
        self.k_proj = nn.Linear(hidden_dim, n_kv_heads * head_dim, bias=use_bias)
        self.v_proj = nn.Linear(hidden_dim, n_kv_heads * head_dim, bias=use_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, hidden_dim, bias=use_bias)

        self.rotary_emb = RotaryPositionalEmbeddings(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            device=config.device,
            base=config.rope_theta,
        )

        # Placeholder only: nothing in this class reads or fills the cache yet.
        self.kv_cache: Optional[KVCache] = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Apply causal self-attention.

        Args:
            hidden_states: tensor of shape ``(B, T, hidden_dim)``.
            mask: forwarded to :meth:`scaled_dot_product_attention`, which
                currently ignores it (attention is purely causal).
            position_ids: index into the rotary cos/sin tables.

        Returns:
            Tensor of shape ``(B, T, hidden_dim)``.
        """
        B, T, _ = hidden_states.size() # bsz, seq_len, embed_dim

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # After the transpose the layout is (B, heads, T, head_dim).
        queries = queries.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        values = values.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # `values` is passed only for its dtype/device; T rows of the tables are returned.
        cos, sin = self.rotary_emb(values, seq_len=T)

        queries = self.rotary_emb.apply_rope(queries, cos, sin, position_ids)
        keys = self.rotary_emb.apply_rope(keys, cos, sin, position_ids)

        # TODO: KV caching
        # Give every query head a matching key/value head.
        keys = repeat_kv(keys, self.repeats)
        values = repeat_kv(values, self.repeats)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if queries.device.type == "cuda" and mask is not None:
            queries = queries.contiguous()
            keys = keys.contiguous()
            values = values.contiguous()

        y = self.scaled_dot_product_attention(queries, keys, values, mask) # (B, T, n_heads, head_dim)

        y = y.reshape(B, T, self.hidden_dim) # (B, T, hidden_dim)

        return self.o_proj(y)

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Causal scaled dot-product attention.

        NOTE(review): ``mask`` is accepted but not used; the call below always
        runs with ``attn_mask=None`` and ``is_causal=True``, so padding
        positions are not masked out. Confirm this is intended.

        Returns:
            Contiguous tensor with the head and sequence axes swapped back,
            i.e. ``(B, T, n_heads, head_dim)`` for inputs laid out as
            ``(B, n_heads, T, head_dim)``.
        """
        scale = 1.0 / math.sqrt(self.head_dim)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, scale=scale, is_causal=True
        )
        return y.transpose(1, 2).contiguous()


In [None]:
class LLaMAMLP(nn.Module):
    """Gated feed-forward layer: ``out_proj(SiLU(linear_1(x)) * linear_2(x))``.

    All three linear layers use PyTorch's default bias setting.
    """

    def __init__(self, hidden_dim, intermediate_dim): # in MLP: intermediate_dim= 4 * hidden_dim
        super().__init__()
        # Two parallel projections hidden_dim -> intermediate_dim ...
        self.linear_1 = nn.Linear(in_features=hidden_dim, out_features=intermediate_dim)
        self.linear_2 = nn.Linear(in_features=hidden_dim, out_features=intermediate_dim)
        self.activation_fn = nn.SiLU()
        # ... and one projection back, intermediate_dim -> hidden_dim.
        self.out_proj = nn.Linear(in_features=intermediate_dim, out_features=hidden_dim)

    def forward(self, hidden_states):
        """Return a tensor with the same trailing size as ``hidden_states``."""
        gate = self.activation_fn(self.linear_1(hidden_states))
        up = self.linear_2(hidden_states)
        return self.out_proj(gate * up)

In [None]:
class LlamaRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale.

    There is no mean subtraction and no bias; ``weight`` starts at ones.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` and multiply by ``weight``."""
        original_dtype = hidden_states.dtype
        # The statistics are computed in float32, then cast back before scaling.
        as_float = hidden_states.to(torch.float32)
        mean_square = as_float.pow(2).mean(-1, keepdim=True)  # (1/n) * sum of x_i^2
        normalised = as_float * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

class Block(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each added back as a residual."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.intermediate_size

        self.attn = LlamaAttention(config)

        self.mlp = LLaMAMLP(self.hidden_dim, self.intermediate_dim)
        self.input_layernorm = LlamaRMSNorm(self.hidden_dim, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(self.hidden_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Run one layer; ``mask`` and ``position_ids`` are passed to the attention module."""
        # Attention sub-layer on the normalised input, added to the residual stream.
        attn_out = self.attn(self.input_layernorm(hidden_states), mask, position_ids)
        after_attn = hidden_states + attn_out
        # MLP sub-layer on the normalised intermediate state, added again.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out


In [None]:
class LlamaModel(nn.Module):
    """Decoder stack: factorised token embedding, ``Block`` layers, final RMS norm.

    Token ids are embedded into 10 dimensions and then projected to
    ``config.hidden_size`` by a bias-free linear layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = hidden_dim = config.hidden_size
        self.vocab_size = vocab_size = config.vocab_size
        assert self.vocab_size > 0
        self.num_hidden_layers = num_hidden_layers = config.num_hidden_layers

        # Width of the low-dimensional embedding table (10 in the original code).
        embed_dim = 10
        self.embed = nn.Embedding(vocab_size, embed_dim)

        self.embed_ln = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.blocks = nn.ModuleList(
            [Block(config) for _ in range(num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(hidden_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        """Encode token ids.

        Args:
            hidden_states: token ids; despite the name this is the integer
                input, indexed as ``(batch, seq_len)``.
            mask: passed unchanged to every block.
            position_ids: positions for the rotary embedding; defaults to
                ``0 .. seq_len - 1`` with a leading batch axis of size 1.

        Returns:
            Normalised hidden states of shape ``(batch, seq_len, hidden_size)``.
        """
        x = self.embed(hidden_states)
        x = self.embed_ln(x)

        seq_len = hidden_states.size(1)
        if position_ids is None:
            # Follow the input's device rather than config.device: the module
            # may have been moved (e.g. by Trainer) after the config was made.
            position_ids = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device).unsqueeze(0)

        for b in self.blocks:
            x = b(x, mask, position_ids)

        return self.norm(x)


class LlamaPreTrainedModel(PreTrainedModel):
    """``PreTrainedModel`` base for the LLaMA variant: class-level settings and weight init."""

    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` and ``nn.Embedding`` weights from N(0, initializer_range).

        Linear biases are zeroed, as is the embedding row at ``padding_idx``
        when one is set. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()


class LlamaForCausalLM(LlamaPreTrainedModel):
    """``LlamaModel`` with a bias-free linear head and next-token cross-entropy loss."""

    def __init__(self, config):
        super().__init__(config)
        self.model     = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size,self.config.vocab_size, bias=False)
        self.post_init()
        # Applied after post_init so the zeros are not overwritten by the
        # random initialisation: the head has no bias, so all logits start at 0.
        self.lm_head.weight.data.fill_(0)

    def forward(
        self,
        input_ids:      torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids:   Optional[torch.Tensor] = None,
        labels:         Optional[torch.Tensor] = None,
    ):
        """Compute logits and, when ``labels`` is given, the shifted LM loss.

        Args:
            input_ids: token ids passed to the backbone.
            attention_mask: passed to the backbone as ``mask``.
            position_ids: passed to the backbone unchanged.
            labels: target ids aligned with ``input_ids``; position ``t`` of the
                logits is scored against ``labels[..., t + 1]``.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (``None`` without labels)
            and ``logits``.
        """
        hidden = self.model(
            hidden_states=input_ids,
            mask=attention_mask,
            position_ids=position_ids,
        )
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()  # all elements except the last one
            shift_labels = labels[..., 1:].contiguous()  # all elements except the first

            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )


In [None]:
# Use the GPU when one is present (the training cell enables fp16); otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
config = LlamaConfig(device=device)
llm = LlamaForCausalLM(config).to(config.device)

In [None]:

count_parameters(llm)

In [None]:
llm

## Dataset

In [None]:
from transformers import DataCollatorForLanguageModeling

# Collator configured with mlm=False (no masked-LM objective), returning PyTorch tensors.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt")

In [None]:
# Tokenize the "text" column in batches, truncating and padding every example
# to the model's context length (config.max_position_embeddings).
tokenized_datasets = dataset.map(
    lambda examples: tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=config.max_position_embeddings,
    ),
    batched=True,
)

# Data Collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False  # Set mlm=False for autoregressive models
)

## Training

In [None]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    "./Llama/",
    per_device_train_batch_size=32,
    #max_steps=150,
    num_train_epochs=2,
    logging_steps=5,
    #save_strategy = ""
    #resume_from_checkpoint = True,
    #gradient_accumulation_steps=2,
    #weight_decay=0.1,
    #warmup_steps= 1_000,
    #lr_scheduler_type="linear",
    learning_rate=0.001,
    #save_steps=500,
    # Enable fp16 mixed precision only when a CUDA device is present, so the
    # notebook also runs on CPU-only machines.
    fp16=torch.cuda.is_available(),
    report_to = "none",
    #torch_compile = True,
    push_to_hub=False,
)

In [None]:
trainer = Trainer(
    model=llm,
    # `processing_class` supersedes the deprecated `tokenizer` argument.
    processing_class=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"]
)

In [None]:
trainer.train()

Now Training Pelican

In [None]:
# Drop the first model before building the next one. The trainer holds its own
# reference to the model, so it is deleted too; flush() then collects garbage
# and empties the CUDA cache.
del llm, trainer
flush()

In [None]:
class LLaMAMLP(nn.Module):
    """Bias-free gated feed-forward layer: ``out_proj(SiLU(linear_1(x)) * linear_2(x))``."""

    def __init__(self, hidden_dim, intermediate_dim): # in MLP: intermediate_dim= 4 * hidden_dim
        super().__init__()
        # Two parallel projections hidden_dim -> intermediate_dim ...
        self.linear_1 = nn.Linear(in_features=hidden_dim, out_features=intermediate_dim, bias=False)
        self.linear_2 = nn.Linear(in_features=hidden_dim, out_features=intermediate_dim, bias=False)
        self.activation_fn = nn.SiLU()
        # ... and one projection back, intermediate_dim -> hidden_dim.
        self.out_proj = nn.Linear(in_features=intermediate_dim, out_features=hidden_dim, bias=False)

    def forward(self, hidden_states):
        """Return a tensor with the same trailing size as ``hidden_states``."""
        gate = self.activation_fn(self.linear_1(hidden_states))
        up = self.linear_2(hidden_states)
        return self.out_proj(gate * up)

In [None]:
class LlamaRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale.

    There is no mean subtraction and no bias; ``weight`` starts at ones.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` and multiply by ``weight``."""
        original_dtype = hidden_states.dtype
        # The statistics are computed in float32, then cast back before scaling.
        as_float = hidden_states.to(torch.float32)
        mean_square = as_float.pow(2).mean(-1, keepdim=True)  # (1/n) * sum of x_i^2
        normalised = as_float * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

class Block(nn.Module):
    """Attention-free layer: a normalised MLP added back to its input.

    ``input_layernorm`` is constructed (and therefore contributes parameters)
    but is not used in ``forward``; ``mask`` and ``position_ids`` are accepted
    for signature compatibility and are likewise unused.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.intermediate_size
        self.mlp = LLaMAMLP(self.hidden_dim, self.intermediate_dim)
        self.input_layernorm = LlamaRMSNorm(self.hidden_dim, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(self.hidden_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        """Return ``hidden_states + mlp(post_attention_layernorm(hidden_states))``."""
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


In [None]:
class PelicanModel(nn.Module):
    """Backbone of the Pelican model: factorised embedding, ``Block`` layers, final RMS norm.

    Token ids are embedded into 10 dimensions and then projected to
    ``config.hidden_size`` by a bias-free linear layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = hidden_dim = config.hidden_size
        self.vocab_size = vocab_size = config.vocab_size
        assert self.vocab_size > 0
        self.num_hidden_layers = num_hidden_layers = config.num_hidden_layers

        # Width of the low-dimensional embedding table (10 in the original code).
        embed_dim = 10
        self.embed = nn.Embedding(vocab_size, embed_dim)

        self.embed_ln = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.blocks = nn.ModuleList(
            [Block(config) for _ in range(num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(hidden_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        """Encode token ids.

        Args:
            hidden_states: token ids; despite the name this is the integer
                input, indexed as ``(batch, seq_len)``.
            mask: passed unchanged to every block.
            position_ids: passed to every block; defaults to
                ``0 .. seq_len - 1`` with a leading batch axis of size 1.

        Returns:
            Normalised hidden states of shape ``(batch, seq_len, hidden_size)``.
        """
        x = self.embed(hidden_states)
        x = self.embed_ln(x)

        seq_len = hidden_states.size(1)
        if position_ids is None:
            # Follow the input's device rather than config.device: the module
            # may have been moved (e.g. by Trainer) after the config was made.
            position_ids = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device).unsqueeze(0)

        for b in self.blocks:
            x = b(x, mask, position_ids)

        return self.norm(x)

class PelicanPreTrainedModel(PreTrainedModel):
    """``PreTrainedModel`` base for the Pelican variant: class-level settings and weight init."""

    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` and ``nn.Embedding`` weights from N(0, initializer_range).

        Linear biases are zeroed, as is the embedding row at ``padding_idx``
        when one is set. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()


class PelicanForCausalLM(PelicanPreTrainedModel):
    """``PelicanModel`` with a bias-free linear head and next-token cross-entropy loss."""

    def __init__(self, config):
        super().__init__(config)
        self.model     = PelicanModel(config)
        self.lm_head   = nn.Linear(config.hidden_size,self.config.vocab_size,bias=False)
        self.post_init()
        # Applied after post_init so the zeros are not overwritten by the
        # random initialisation: the head has no bias, so all logits start at 0.
        self.lm_head.weight.data.fill_(0)

    def forward(
        self,
        input_ids:      torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids:   Optional[torch.Tensor] = None,
        labels:         Optional[torch.Tensor] = None,
    ):
        """Compute logits and, when ``labels`` is given, the shifted LM loss.

        Args:
            input_ids: token ids passed to the backbone.
            attention_mask: passed to the backbone as ``mask``.
            position_ids: passed to the backbone unchanged.
            labels: target ids aligned with ``input_ids``; position ``t`` of the
                logits is scored against ``labels[..., t + 1]``.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (``None`` without labels)
            and ``logits``.
        """
        hidden = self.model(
            hidden_states=input_ids,
            mask=attention_mask,
            position_ids=position_ids,
        )

        # The normalised hidden states are divided by 0.1 (a 10x scale-up)
        # before the output projection.
        hidden = hidden / 0.1
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()  # all elements except the last one
            shift_labels = labels[..., 1:].contiguous()  # all elements except the first

            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

In [None]:
# Use the GPU when one is present (the training cell enables fp16); otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
config = LlamaConfig(device=device)
llm = PelicanForCausalLM(config).to(config.device)

In [None]:
llm

In [None]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    "./Llama/",
    per_device_train_batch_size=2,
    #max_steps=150,
    num_train_epochs=2,
    logging_steps=5,
    #save_strategy = ""
    #resume_from_checkpoint = True,
    #gradient_accumulation_steps=2,
    #weight_decay=0.1,
    #warmup_steps= 1_000,
    #lr_scheduler_type="linear",
    learning_rate=0.001,
    #save_steps=500,
    # Enable fp16 mixed precision only when a CUDA device is present, so the
    # notebook also runs on CPU-only machines.
    fp16=torch.cuda.is_available(),
    report_to = "none",
    #torch_compile = True,
    push_to_hub=False,
)

In [None]:
trainer = Trainer(
    model=llm,
    # `processing_class` supersedes the deprecated `tokenizer` argument.
    processing_class=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"]
)

In [None]:
trainer.train()

In [None]:


tokens = tokenizer(
    " once upon a time in  ",
    return_tensors='pt'
)
# Training may have moved the model; generate on whatever device it is on now.
input_ids = tokens['input_ids'].to(next(llm.parameters()).device)

temperature = 1
top_k = None
top_p = None

llm.eval()
# No gradients are needed for sampling; this avoids building a graph per step.
with torch.no_grad():
    # Generate the tokens one by one
    for _ in range(40):
        # Get the logits for the last position
        outputs = llm(input_ids)
        logits = outputs.logits[:, -1, :]
        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k or top-p filtering if specified
        if top_k is not None:
            # Keep the logits vocabulary-sized and mask everything below the
            # k-th largest value, so the sampled index is still a token id.
            k = min(top_k, logits.size(-1))
            kth_largest = logits.topk(k, dim=-1)[0][:, -1, None]
            logits = logits.masked_fill(logits < kth_largest, -float('inf'))
        elif top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept.
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits = logits.masked_fill(indices_to_remove, -float('inf'))

        # Sample the next token from the filtered distribution
        next_token_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

        # Append the new token to the running sequence
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

# Decode the generated text
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
print(generated_text)