# Beens-Minimax (Base - 103M Params)

## Necessary Imports

In [1]:
!pip install -q evaluate accelerate einops peft bitsandbytes

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.9/72.9 MB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m95.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m85.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/66

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import evaluate
import time
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_scheduler, DataCollatorForLanguageModeling
from torch.optim import AdamW
from datasets import load_dataset, load_from_disk
from peft import get_peft_model, LoraConfig, TaskType
from tqdm.auto import tqdm

2025-07-11 06:47:16.247399: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752216436.444950      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752216436.496434      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [18]:
from peft import PeftModel

#### Upgrade `datasets` to version >= 4.0.0; older versions raise an error in the cells below.

In [3]:
!pip install -U datasets

Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Downloading datasets-4.0.0-py3-none-any.whl (494 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 494.8/494.8 kB 8.4 MB/s eta 0:00:00
Installing collected packages: datasets
  Attempting uninstall: datasets
    Found existing installation: datasets 3.6.0
    Uninstalling datasets-3.6.0:
      Successfully uninstalled datasets-3.6.0
Successfully installed datasets-4.0.0


## Beens-minimax configuration

In [4]:
class MiniMaxConfig:
    """Hyper-parameters for MiniMaxText01ForCausalLM.

    Only ``vocab_size`` is required; every other field defaults to the value the
    notebook's checkpoints were trained with, so ``MiniMaxConfig(vocab_size)`` is
    unchanged. ``head_dim``, ``num_key_value_heads`` and ``deepnorm_alpha`` are
    derived from the primary fields.

    ``to_dict``/``get`` give a small dict-like surface; the model is later wrapped
    with PEFT, which reads ``model.config``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads`` or
            ``num_attention_heads`` is not divisible by ``gqa_group_size``.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size=512,
        num_layers=8,
        num_attention_heads=8,
        gqa_group_size=2,
        num_experts=4,
        num_experts_per_tok=2,
        ffn_hidden_dim=2048,
        rope_base=10_000,
        rope_dim_fraction=0.5,
        softmax_attention_period=4,
        rms_norm_eps=1e-6,
        router_aux_loss_coef=0.01,
    ):
        if hidden_size % num_attention_heads != 0:
            raise ValueError("hidden_size must be divisible by num_attention_heads")
        if num_attention_heads % gqa_group_size != 0:
            raise ValueError("num_attention_heads must be divisible by gqa_group_size")

        # Attribute order matches the original so to_dict() keeps the same key order.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.gqa_group_size = gqa_group_size
        # Each key/value head is shared by gqa_group_size query heads.
        self.num_key_value_heads = self.num_attention_heads // self.gqa_group_size
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.ffn_hidden_dim = ffn_hidden_dim
        self.rope_base = rope_base
        self.rope_dim_fraction = rope_dim_fraction
        # Every softmax_attention_period-th layer uses softmax attention.
        self.softmax_attention_period = softmax_attention_period
        self.deepnorm_alpha = (2 * self.num_layers) ** 0.25
        self.rms_norm_eps = rms_norm_eps
        self.router_aux_loss_coef = router_aux_loss_coef

        self.tie_word_embeddings = True
        self.model_type = "minimax_custom"

    def to_dict(self):
        """Return a shallow copy of the config fields (mutating it does not change the config)."""
        return dict(self.__dict__)

    def get(self, key, default=None):
        """dict.get-style attribute lookup."""
        return getattr(self, key, default)


print("Configuration defined")

In [5]:
KAGGLE_WORKING_PATH = "/kaggle/working/"
DATA_PATH = os.path.join(KAGGLE_WORKING_PATH, "processed_data")
# Pre-training checkpoints are written here by train_llm().
CHECKPOINT_PATH = os.path.join(KAGGLE_WORKING_PATH, "checkpoints")
# Where train_llm() looks for a checkpoint to resume from; point it at a
# /kaggle/input/... directory to resume from an uploaded checkpoint.
INPUT_CHECKPOINT_PATH = CHECKPOINT_PATH
# Loss curves and other training artefacts.
RESULTS_PATH = os.path.join(KAGGLE_WORKING_PATH, "results")
SFT_LORA_CHECKPOINT_PATH = os.path.join(KAGGLE_WORKING_PATH, "sft_lora_checkpoints")
BASE_MODEL_INPUT_PATH = "/kaggle/input/beens-2/pytorch/default/1/"
TOKENIZER_INPUT_PATH = "/kaggle/input/minimax-processed-data/processed_data/"
# Alias used by the base-model inference cell.
TOKENIZER_PATH = TOKENIZER_INPUT_PATH

for _output_dir in (CHECKPOINT_PATH, RESULTS_PATH, SFT_LORA_CHECKPOINT_PATH):
    os.makedirs(_output_dir, exist_ok=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


print("Paths defined")

## Data Pre-processing

In [None]:
def prepare_and_save_data():
    BLOCK_SIZE = 256
    DATASET_NAME = "wikitext"
    DATASET_CONFIG = "wikitext-103-v1"
    TOKENIZER_NAME = "gpt2"
    
    if os.path.exists(os.path.join(DATA_PATH, "train")):
        print("Processed data already found in /kaggle/working/. Skipping preparation.")
        return


    
    print("--- Starting Data Preparation ---")
    
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    if tokenizer.pad_token is None: 
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 1024

    print(f"Downloading '{DATASET_NAME}' dataset from Hugging Face Hub...")
    raw_datasets = load_dataset(DATASET_NAME, DATASET_CONFIG)
    print("Dataset downloaded.")

    

    def tokenize_function(examples):
        return tokenizer(examples["text"], add_special_tokens=False, truncation=True, max_length=tokenizer.model_max_length)

    print("Tokenizing dataset...")
    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, num_proc=2, remove_columns=["text"])

    

    def group_texts(examples):
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // BLOCK_SIZE) * BLOCK_SIZE
        result = {k: [t[i : i + BLOCK_SIZE] for i in range(0, total_length, BLOCK_SIZE)] for k, t in concatenated_examples.items()}
        result["labels"] = result["input_ids"].copy()
        return result
    print("Grouping texts into blocks...")

    
    lm_datasets = tokenized_datasets.map(group_texts, batched=True, batch_size=1000, num_proc=2)
    print(f"\nSaving processed datasets to '{DATA_PATH}'...")

    
    lm_datasets["train"].save_to_disk(os.path.join(DATA_PATH, "train"))
    lm_datasets["validation"].save_to_disk(os.path.join(DATA_PATH, "validation"))
    tokenizer.save_pretrained(DATA_PATH)
    print("Data preparation and saving complete!")

In [None]:
# Tokenize WikiText-103 and cache it under DATA_PATH (returns early if DATA_PATH/train already exists).
prepare_and_save_data()

## RMSNorm and Rotary Position Embeddings (RoPE)

In [6]:
def rotate_half(x):
    """Split the last axis into halves (a, b) and return (-b, a); the RoPE rotation helper."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key features by per-position angles (rotary position embedding).

    ``cos``/``sin`` get two leading singleton axes so they broadcast against ``q``/``k``;
    SoftmaxAttention passes q/k as (batch, heads, seq, rope_dim) and cos/sin as
    (seq, rope_dim) — assumed for any other caller, TODO confirm.

    Returns:
        (rotated q, rotated k).
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate(t):
        return t * cos_b + rotate_half(t) * sin_b

    return _rotate(q), _rotate(k)



class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-feature gain.

    y = x / sqrt(mean(x^2) + eps) * weight
    """

    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings (RoPE).

    Args:
        dim: number of features that get rotated (the tables have ``dim`` columns).
        base: frequency base; inverse frequencies are ``base ** (-2i / dim)``.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Persistent buffer: follows .to(device) and stays in state_dict, so existing
        # checkpoints still load.
        self.register_buffer("inv_freq", inv_freq)
        # Plain attributes (not buffers): .to(device) does NOT move them, so
        # _update_cache also checks the device.
        self.cos_cached, self.sin_cached = None, None

    def _update_cache(self, x, seq_len):
        """(Re)build the tables when missing, shorter than seq_len, or on a different device than x."""
        cached = self.cos_cached
        if cached is not None and seq_len <= cached.shape[0] and cached.device == x.device:
            return

        # Positions and frequencies are kept in float32 even for reduced-precision models.
        positions = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", positions, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached, self.sin_cached = emb.cos(), emb.sin()

    def forward(self, x, seq_len):
        """Return (cos, sin), each (seq_len, dim), cast to x's dtype; x only supplies device/dtype."""
        self._update_cache(x, seq_len)
        return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)


print("RMS and ROPE Utilities defined")

RMS and ROPE Utilities defined


## Minimax Attention Layers

In [7]:
class SoftmaxAttention(nn.Module):
    """Grouped-query softmax attention with partial rotary embeddings.

    Only the first ``rope_dim = int(head_dim * rope_dim_fraction)`` features of each
    query/key head are rotated; the rest pass through unchanged. Each key/value head
    is repeated ``gqa_group_size`` times to serve that many query heads.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rope_dim = int(config.head_dim * config.rope_dim_fraction)
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
        # Honour config.rope_base; RotaryEmbedding's own default is only a fallback.
        self.rotary_emb = RotaryEmbedding(self.rope_dim, base=getattr(config, "rope_base", 10000))

    def forward(self, x, mask, output_attentions=False):
        """Attend over x of shape (batch, seq, hidden).

        Args:
            mask: added to the pre-softmax scores (the causal LM passes 0 / -inf), or None.
            output_attentions: also return the attention probabilities.

        Returns:
            (output of shape (batch, seq, hidden), attention probabilities or None).
        """
        bsz, seq_len, _ = x.shape
        q_s = self.q_proj(x).view(bsz, seq_len, self.config.num_attention_heads, self.config.head_dim)
        k_s = self.k_proj(x).view(bsz, seq_len, self.config.num_key_value_heads, self.config.head_dim)
        v_s = self.v_proj(x).view(bsz, seq_len, self.config.num_key_value_heads, self.config.head_dim)

        # -> (batch, heads, seq, head_dim)
        q_s, k_s, v_s = q_s.transpose(1, 2), k_s.transpose(1, 2), v_s.transpose(1, 2)

        q_rot, q_pass = q_s[..., :self.rope_dim], q_s[..., self.rope_dim:]
        k_rot, k_pass = k_s[..., :self.rope_dim], k_s[..., self.rope_dim:]

        cos, sin = self.rotary_emb(v_s, seq_len=seq_len)
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)

        q_s = torch.cat((q_rot, q_pass), dim=-1)
        k_s = torch.cat((k_rot, k_pass), dim=-1)

        # Expand K/V heads to match the number of query heads (GQA).
        k_s = k_s.repeat_interleave(self.config.gqa_group_size, dim=1)
        v_s = v_s.repeat_interleave(self.config.gqa_group_size, dim=1)

        attn_weights = torch.matmul(q_s, k_s.transpose(2, 3)) / math.sqrt(self.config.head_dim)
        if mask is not None:
            attn_weights += mask

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_s.dtype)
        attn_output = torch.matmul(attn_weights, v_s).transpose(1, 2).contiguous().view(bsz, seq_len, self.config.hidden_size)

        return self.o_proj(attn_output), (attn_weights if output_attentions else None)



class LightningAttention(nn.Module):
    """Gated multi-head attention used by the non-softmax layers.

    q, k and v are SiLU-activated projections split into ``num_attention_heads`` heads;
    the attended values are multiplied by a sigmoid gate from a fifth projection before
    the output projection.

    NOTE(review): despite the name this is not a linear ("lightning") attention — it
    calls F.scaled_dot_product_attention, i.e. ordinary softmax attention. Attention
    weights are never returned.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.hidden_size
        # Registration order q, k, v, g, o is kept: it fixes parameter order, which
        # optimizer state dicts depend on.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.g_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

    def _split_heads(self, t, bsz, seqlen):
        """(batch, seq, hidden) -> (batch, heads, seq, head_dim)."""
        return t.view(bsz, seqlen, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)

    def forward(self, x, mask, output_attentions=False):
        bsz, seqlen, _ = x.shape

        q = self._split_heads(F.silu(self.q_proj(x)), bsz, seqlen)
        k = self._split_heads(F.silu(self.k_proj(x)), bsz, seqlen)
        v = self._split_heads(F.silu(self.v_proj(x)), bsz, seqlen)
        gate = self._split_heads(torch.sigmoid(self.g_proj(x)), bsz, seqlen)

        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        merged = (attended * gate).transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.o_proj(merged), None


print("Attention mechanisms defined")

Attention mechanisms defined


## MiniMax MoE Layer

In [13]:
def load_balancing_loss_func(logits, num_experts):
    """Router balance penalty: ``num_experts * sum_e(mean routing probability of e)^2``.

    ``logits`` is a list of router-logit tensors (one per layer) whose last axis has
    ``num_experts`` entries; all tokens of all layers are pooled. The value is 1.0 for
    perfectly uniform routing and approaches ``num_experts`` as routing collapses onto
    one expert. Returns 0.0 for an empty list.
    """
    if not logits:
        return 0.0
    stacked = torch.cat([layer_logits.view(-1, num_experts) for layer_logits in logits], 0)
    mean_probs = F.softmax(stacked, -1).mean(0)
    return torch.sum(mean_probs * mean_probs) * num_experts


class Expert(nn.Module):
    """Bias-free feed-forward expert: hidden -> ffn_hidden_dim -> SiLU -> hidden."""

    def __init__(self, config):
        super().__init__()
        self.w_in = nn.Linear(config.hidden_size, config.ffn_hidden_dim, bias=False)
        self.w_out = nn.Linear(config.ffn_hidden_dim, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        hidden = self.w_in(x)
        hidden = self.act_fn(hidden)
        return self.w_out(hidden)
        

class MoE(nn.Module):
    """Top-k mixture-of-experts feed-forward layer.

    A bias-free linear gate scores every token against ``num_experts`` experts. Each token
    goes to its ``num_experts_per_tok`` most probable experts, and their outputs are summed
    with the top-k probabilities renormalised to sum to 1.

    Returns:
        (output shaped like x, raw router logits reshaped to (bsz, seqlen, num_experts)).
    """

    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(Expert(config) for _ in range(config.num_experts))
        self.k = config.num_experts_per_tok

    def forward(self, x):
        bsz, seqlen, dim = x.shape
        tokens = x.view(-1, dim)
        router_logits = self.gate(tokens)

        probs = F.softmax(router_logits, 1, dtype=torch.float)
        top_weights, top_experts = torch.topk(probs, self.k, -1)
        top_weights = (top_weights / top_weights.sum(-1, True)).to(x.dtype)

        combined = torch.zeros_like(tokens)
        # Slot-major order (every expert for slot 0, then slot 1, ...) keeps the per-token
        # accumulation order identical to the reference implementation.
        for slot in range(self.k):
            slot_experts = top_experts[:, slot]
            slot_weights = top_weights[:, slot].unsqueeze(-1)
            for expert_id, expert in enumerate(self.experts):
                routed = slot_experts == expert_id
                if not routed.any():
                    continue
                combined[routed] += expert(tokens[routed]) * slot_weights[routed]

        return combined.view(bsz, seqlen, dim), router_logits.view(bsz, seqlen, -1)


print("MoE components defined")

MoE components defined


## Final Minimax Blocks

In [9]:
class MiniMaxBlock(nn.Module):
    """One decoder layer: pre-norm attention, then a pre-norm MoE feed-forward.

    Each sub-layer's output is multiplied by ``config.deepnorm_alpha`` before being added
    to the residual stream.

    NOTE(review): DeepNorm as published scales the *residual* (LN(alpha * x + f(x))) in a
    post-norm layout; here alpha scales the sub-layer output instead. Changing it would
    alter already-trained checkpoints, so it is only flagged.
    """

    def __init__(self, config, is_softmax):
        super().__init__()
        # Registration order (attention, moe_ffn, attention_norm, ffn_norm) is kept; it
        # fixes parameter order, which optimizer state dicts depend on.
        attention_cls = SoftmaxAttention if is_softmax else LightningAttention
        self.attention = attention_cls(config)
        self.moe_ffn = MoE(config)
        self.attention_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.deepnorm_alpha = config.deepnorm_alpha

    def forward(self, x, mask, output_attentions=False):
        """Return (new hidden state, MoE router logits, attention weights or None)."""
        normed = self.attention_norm(x)
        attn_out, attn_weights = self.attention(normed, mask, output_attentions)
        h = x + self.deepnorm_alpha * attn_out

        ffn_out, router_logits = self.moe_ffn(self.ffn_norm(h))
        out = h + self.deepnorm_alpha * ffn_out
        return out, router_logits, attn_weights



class MiniMaxText01ForCausalLM(nn.Module):
    """Decoder-only causal LM built from MiniMaxBlock layers.

    Layer i (0-based) uses SoftmaxAttention when ``(i + 1) % softmax_attention_period == 0``
    and LightningAttention otherwise. The token embedding and the LM head share one
    weight matrix.
    """

    def __init__(self, config: "MiniMaxConfig"):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MiniMaxBlock(config, (i + 1) % config.softmax_attention_period == 0) for i in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the LM head's matrix.
        self.tok_embeddings.weight = self.lm_head.weight

    def forward(self, input_ids, labels=None, output_attentions=False, **kwargs):
        """Run the model on input_ids of shape (batch, seq).

        Extra keyword arguments (e.g. ``attention_mask`` from a collator) are accepted and
        ignored; only the causal mask is applied.

        Returns:
            dict with "loss" (next-token cross-entropy plus router_aux_loss_coef times the
            load-balancing loss, or None when labels is None; label -100 is ignored),
            "logits" (batch, seq, vocab) and "attentions" (list per layer, or None).
        """
        bsz, seqlen = input_ids.shape
        h = self.tok_embeddings(input_ids)
        # Additive causal mask: 0 on/below the diagonal, -inf above it.
        mask = torch.triu(torch.full((1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device), 1).type_as(h)
        all_router_logits, all_attns = [], []

        for layer in self.layers:
            h, logits, attns = layer(h, mask, output_attentions)
            all_router_logits.append(logits)
            if output_attentions:
                all_attns.append(attns)

        logits = self.lm_head(self.norm(h))
        loss = None

        if labels is not None:
            # Position t predicts token t + 1.
            loss = F.cross_entropy(logits[..., :-1, :].reshape(-1, self.config.vocab_size), labels[..., 1:].reshape(-1))
            loss += self.config.router_aux_loss_coef * load_balancing_loss_func(all_router_logits, self.config.num_experts)

        return {"loss": loss, "logits": logits, "attentions": all_attns if output_attentions else None}

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Generation hook: the full sequence is re-fed every step (no KV cache)."""
        return {"input_ids": input_ids}

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_length: int, eos_token_id: int, temperature: float = 1.0, top_k: int = 50):
        """Sample a continuation with temperature + top-k sampling.

        Args:
            input_ids: (batch, prompt_len) token ids.
            max_length: total length (prompt + generated) at which generation stops.
            eos_token_id: stop once every sequence produced it in the same step; None
                disables early stopping. For batch > 1, rows that emit EOS earlier keep
                sampling until all rows have.
            temperature: logits are divided by it before sampling.
            top_k: sample among the k most likely tokens (clamped to the vocabulary size).

        Returns:
            (batch, <= max_length) token ids including the prompt.

        Note: switches the module to eval mode and leaves it there.
        """
        self.eval()
        for _ in range(max_length - input_ids.shape[1]):
            model_inputs = self.prepare_inputs_for_generation(input_ids)

            outputs = self.forward(**model_inputs)
            next_token_logits = outputs['logits'][:, -1, :]

            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # torch.topk raises if k exceeds the vocabulary size.
            k = min(top_k, next_token_logits.size(-1))
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k, dim=-1)

            top_k_probs = F.softmax(top_k_logits, dim=-1)
            next_token_index = torch.multinomial(top_k_probs, num_samples=1)
            next_token = torch.gather(top_k_indices, -1, next_token_index)

            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # .item() only works for batch size 1; compare element-wise instead.
            if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                break

        return input_ids


print("MiniMax-01 model defined")

MiniMax-01 model defined


## General Utility Functions

In [10]:
def save_checkpoint(epoch, step, model, optimizer, scheduler, loss_history, path):
    """Save model/optimizer/scheduler state and the loss history to ``path/checkpoint.pth``.

    ``path`` is created if needed. The file is written to a temporary name and then
    atomically renamed over the old checkpoint, so an interrupted save (e.g. a session
    timeout) cannot leave a truncated checkpoint behind.

    Args:
        epoch, step: 0-based position of the last completed step.
        model: the unwrapped module (not the DataParallel wrapper), so keys are unprefixed.
    """
    os.makedirs(path, exist_ok=True)
    state = {'epoch': epoch,
             'step': step,
             'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'scheduler_state_dict': scheduler.state_dict(),
             'loss_history': loss_history}

    final_path = os.path.join(path, "checkpoint.pth")
    tmp_path = final_path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, final_path)
    print(f"Checkpoint saved at Epoch {epoch+1}, Step {step+1}")
    


def load_checkpoint(model, optimizer, scheduler, path):
    """Restore training state from ``path/checkpoint.pth`` if it exists.

    The file is loaded onto the CPU first; ``load_state_dict`` then copies tensors to
    wherever the model's and optimizer's parameters live. This also makes a GPU-saved
    checkpoint loadable on a CPU-only machine.

    Returns:
        (start_epoch, start_step, loss_history). start_step is the step after the saved
        one. Returns (0, 0, []) when no checkpoint exists.
    """
    start_epoch, start_step, loss_history = 0, 0, []
    ckpt_path = os.path.join(path, "checkpoint.pth")

    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        start_epoch, start_step, loss_history = ckpt['epoch'], ckpt['step'] + 1, ckpt['loss_history']

        print(f"Checkpoint found. Resuming training from Epoch {start_epoch+1}, Step {start_step}.")

    else:
        print("No checkpoint found. Starting training from scratch.")

    return start_epoch, start_step, loss_history



def plot_loss_curve(loss_history, save_path):
    """Plot per-step training loss, save it as ``save_path/loss_curve.png`` and display it.

    ``save_path`` is created if missing (savefig would otherwise raise). The figure is
    closed after display so repeated calls do not accumulate open figures.
    """
    os.makedirs(save_path, exist_ok=True)
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(loss_history, label='Training Loss')

    ax.set_title('Training Loss Curve')
    ax.set_xlabel('Steps')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

    fig.savefig(os.path.join(save_path, "loss_curve.png"))
    plt.show()
    plt.close(fig)

    print(f"Loss curve saved to {save_path}")


print("Helper functions defined")

Helper functions defined


## Training 

In [None]:
def train_llm():
    BATCH_SIZE = 16
    NUM_EPOCHS = 3
    LEARNING_RATE = 3e-5
    
    PRINT_INTERVAL, CHECKPOINT_INTERVAL = 50, 200

    
    print(f"--- Loading pre-processed data from '{DATA_PATH}' ---")
    try:
        train_ds = load_from_disk(os.path.join(DATA_PATH, "train"))
        eval_ds = load_from_disk(os.path.join(DATA_PATH, "validation"))
        tokenizer = AutoTokenizer.from_pretrained(DATA_PATH)
        print("Data loaded successfully.")
        
    except FileNotFoundError:
        print(f"Error: Processed data not found. Please run the Data Preparation cell first.")
        return

    
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator)
    eval_dataloader = DataLoader(eval_ds, batch_size=BATCH_SIZE, collate_fn=data_collator)

    
    print(f"\nInitializing model on {DEVICE}...")
    config = MiniMaxConfig(len(tokenizer))
    model_no_dp = MiniMaxText01ForCausalLM(config)

    
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs via DataParallel!")
        model = nn.DataParallel(model_no_dp)
    else:
        model = model_no_dp

    
    model.to(DEVICE)
    model_size = sum(p.numel() for p in model_no_dp.parameters()) / 1e6
    print(f"Model Initialized. Total parameters: {model_size:.2f}M")

    
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    num_training_steps = NUM_EPOCHS * len(train_dataloader)
    lr_scheduler = get_scheduler("linear", optimizer, 0, num_training_steps)

    start_epoch, resume_step, loss_history = load_checkpoint(model_no_dp, optimizer, lr_scheduler, INPUT_CHECKPOINT_PATH)

    
    print(f"\n--- Starting Training ---")
    model.train()

    
    for epoch in range(start_epoch, NUM_EPOCHS):
        print(f"\n--- Starting Epoch {epoch + 1}/{NUM_EPOCHS} ---")
        if resume_step >= len(train_dataloader):
            print("Already completed this epoch. Skipping.") 
            resume_step = 0
            continue
        if resume_step > 0: 
            print(f"Resuming from step {resume_step}...")

            
        data_iterator = iter(train_dataloader)
        for _ in range(resume_step):
             try:
                next(data_iterator)
             except StopIteration:
                break

        
        for step, batch in enumerate(data_iterator, start=resume_step):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            
            _, loss_outputs, _ = model(**batch)
            loss = loss_outputs.mean()
            
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
            loss_history.append(loss.item())
            
            if (step + 1) % PRINT_INTERVAL == 0:
                print(f"Epoch {epoch+1}, Step {step+1}/{len(train_dataloader)}, Loss: {loss.item():.4f}")
            
            if (step + 1) % CHECKPOINT_INTERVAL == 0:
                save_checkpoint(epoch, step, model_no_dp, optimizer, lr_scheduler, loss_history, CHECKPOINT_PATH)
        
        resume_step = 0

    
    print("\n ---Training Complete ---")
    save_checkpoint(NUM_EPOCHS - 1, len(train_dataloader) - 1, model_no_dp, optimizer, lr_scheduler, loss_history, CHECKPOINT_PATH)
    plot_loss_curve(loss_history, RESULTS_PATH)
    
    print("\n --- Starting Final Evaluation for Perplexity ---")
    model.eval()
    perplexity_metric = evaluate.load("perplexity", module_type="metric")

    
    for batch in eval_dataloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            logits, _, _ = model(**batch)
        perplexity_metric.add_batch(predictions=logits, references=batch["input_ids"])
    
    results = perplexity_metric.compute(model_id='minimax_replica_kaggle')

    
    print(f"\n--- Final Analysis & Results ---")
    model_size = sum(p.numel() for p in model_no_dp.parameters()) / 1e6
    print(f"Model Size: {model_size:.2f}M parameters")
    print(f"Final Validation Perplexity: {results['mean_perplexity']:.2f} (Lower is better)")
    
    print("\n--- Model trained ---")

In [None]:
# Runs pre-training whenever this cell executes (in a Jupyter kernel __name__ is normally "__main__").
if __name__ == "__main__":
    train_llm()

## Inference

In [11]:
def generate_base_completion(prompt, model, tokenizer, max_length=100):
    """Sample and print a continuation of ``prompt`` from the base (pre-trained) model.

    Uses temperature 0.7 / top-k 50 sampling, stops at the tokenizer's EOS or at
    ``max_length`` total tokens, and decodes with special tokens removed. Returns None;
    all output goes to stdout.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    print(f"\n--- Prompt: ---\n'{prompt}'")
    print("\n--- Generating Completion... ---")

    with torch.no_grad():
        sampling_kwargs = dict(
            max_length=max_length,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.7,
            top_k=50,
        )
        generated = model.generate(input_ids=encoded['input_ids'], **sampling_kwargs)

    text = tokenizer.decode(generated[0], skip_special_tokens=True)

    print(f"\n--- Base Model Completion: ---\n{text}")
    print("-" * 40)

In [14]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer


print("--- Loading Tokenizer and Base Model ---")
# Load the tokenizer from the configured input path; fall back to gpt2 (the tokenizer the
# processed data was built with) and say so, rather than switching silently.
try:
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_INPUT_PATH)
except Exception as exc:
    print(f"Pre-saved tokenizer not found ({exc}). Loading 'gpt2' as a fallback.")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


config = MiniMaxConfig(vocab_size=len(tokenizer))
model = MiniMaxText01ForCausalLM(config)


checkpoint_file = os.path.join(BASE_MODEL_INPUT_PATH, "checkpoint.pth")
if os.path.exists(checkpoint_file):
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("✅ Base model weights loaded successfully.")
else:
    raise FileNotFoundError(f"Base model checkpoint not found at '{checkpoint_file}'.")


model.to(DEVICE)
model.eval()


prompt1 = "The history of the Roman Empire is a fascinating subject, particularly the transition from the Republic to the Principate, which began"
prompt2 = "Artificial intelligence is a field of computer science that focuses on creating systems capable of"
prompt3 = "The poem began with the line, 'Once upon a time, in a land filled with towering mountains and'"

generate_base_completion(prompt1, model, tokenizer)
generate_base_completion(prompt2, model, tokenizer)
generate_base_completion(prompt3, model, tokenizer)

--- Loading Tokenizer and Base Model ---
✅ Base model weights loaded successfully.

--- Prompt: ---
'The history of the Roman Empire is a fascinating subject, particularly the transition from the Republic to the Principate, which began'

--- Generating Completion... ---

--- Base Model Completion: ---
The history of the Roman Empire is a fascinating subject, particularly the transition from the Republic to the Principate, which began with the Roman @-@ Roman War ( AD 8 – 14 ) . 
 = = = Early history = = = 
 The Roman Empire 's history has a profound influence on the history of the Roman Empire . In the beginning of the Roman Empire , the empire was controlled by the Roman Empire . Roman society has been called the Roman Empire . A Roman Empire developed
----------------------------------------

--- Prompt: ---
'Artificial intelligence is a field of computer science that focuses on creating systems capable of'

--- Generating Completion... ---

--- Base Model Completion: ---
Artificial 

# Instruct-Train (SFT)

## Base Model Definition (run this before the SFT cells below)

In [21]:
print("--- Loading Tokenizer ---")
# Use the configured tokenizer path (same directory as before). gpt2 is the tokenizer the
# processed data was built with, so it is a compatible fallback; the reason is printed.
try:
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_INPUT_PATH)
except Exception as exc:
    print(f"Pre-saved tokenizer not found ({exc}). Loading 'gpt2' as a fallback.")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


print("\n--- Loading Pre-trained Base Model ---")
config = MiniMaxConfig(vocab_size=len(tokenizer))
base_model = MiniMaxText01ForCausalLM(config)


checkpoint_file = os.path.join(BASE_MODEL_INPUT_PATH, "checkpoint.pth")
if os.path.exists(checkpoint_file):
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    base_model.load_state_dict(checkpoint['model_state_dict'])
    print(" Pre-trained base model weights loaded successfully ")
else:
    raise FileNotFoundError(f"Base model checkpoint not found at '{checkpoint_file}'.")


--- Loading Tokenizer ---
Pre-saved tokenizer not found. Loading 'gpt2' as a fallback.

--- Loading Pre-trained Base Model ---
 Pre-trained base model weights loaded successfully 


## LoRA Configuration

In [None]:
# Freeze every pre-trained weight; only the LoRA adapters added below will be trainable.
base_model.requires_grad_(False)

# Adapters go on the q/k/v/o projections of both attention types. LightningAttention's
# g_proj, the MoE router and the experts stay frozen.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type=TaskType.CAUSAL_LM,
)

print("\nApplying LoRA adapters to the model...")
sft_model = get_peft_model(base_model, lora_config).to(DEVICE)

print("\n--- LoRA Model Summary ---")
sft_model.print_trainable_parameters()

## Data Pre-processing

In [None]:
def format_prompt(sample):
    """Render ``sample['messages']`` as '### HUMAN:' / '### ASSISTANT:' sections.

    Messages whose role is 'user' go under '### HUMAN:'; every other role (assistant,
    system, ...) goes under '### ASSISTANT:'. Each section is followed by a blank line.
    """
    sections = []
    for message in sample['messages']:
        header = "### HUMAN:" if message['role'] == 'user' else "### ASSISTANT:"
        sections.append(f"{header}\n{message['content']}\n\n")
    return "".join(sections)


def prepare_sft_dataset(tokenizer):
    print(f"\n--- Preparing SFT dataset: ultrachat_200k ---")
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    dataset = dataset.select(range(200000))
    original_columns = dataset.column_names


    def tokenize_and_format(element):
        formatted_prompt = format_prompt(element)
        
        outputs = tokenizer(
            formatted_prompt,
            truncation=True,
            padding="max_length", 
            max_length=512,
            add_special_tokens=False,
        )
        
        outputs["labels"] = outputs["input_ids"].copy()
        return outputs

    print("Tokenizing and formatting dataset...")
    tokenized_dataset = dataset.map(
        tokenize_and_format,
        batched=False, 
        remove_columns=original_columns,
    )

    
    tokenized_dataset.set_format(
        type="torch", columns=["input_ids", "attention_mask", "labels"]
    )
    print(f"SFT dataset prepared. Using {len(tokenized_dataset)} samples.")
    return tokenized_dataset

In [None]:
sft_dataset = prepare_sft_dataset(tokenizer)

# Shuffled mini-batches of 4 fixed-length samples; the default collate stacks the torch tensors.
sft_dataloader = DataLoader(sft_dataset, shuffle=True, batch_size=4)

print("DataLoaders for SFT defined")

## Training

In [None]:

NUM_SFT_EPOCHS = 1
SFT_LEARNING_RATE = 2e-4
SFT_CHECKPOINT_INTERVAL = 500


# Only parameters that still require grad (the LoRA adapters) are optimised.
trainable_params = [p for p in sft_model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=SFT_LEARNING_RATE)
num_training_steps = len(sft_dataloader) * NUM_SFT_EPOCHS
lr_scheduler = get_scheduler("linear", optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)


print("\n--- Starting Supervised Fine-Tuning with LoRA ---")
sft_model.train()

loss_history = []

for epoch in range(NUM_SFT_EPOCHS):
    progress_bar = tqdm(sft_dataloader, desc=f"SFT Epoch {epoch+1}")

    for step, batch in enumerate(progress_bar):
        batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}

        outputs = sft_model(**batch)
        loss = outputs['loss']

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        loss_history.append(loss_value)
        progress_bar.set_postfix({"loss": loss_value})

        completed_steps = step + 1
        if completed_steps % SFT_CHECKPOINT_INTERVAL == 0:
            # Only the adapter weights are written; the frozen base model is not duplicated.
            checkpoint_save_path = os.path.join(SFT_LORA_CHECKPOINT_PATH, f"step_{completed_steps}")
            print(f"\nSaving LoRA adapter checkpoint to {checkpoint_save_path}...")
            sft_model.save_pretrained(checkpoint_save_path)


final_save_path = os.path.join(SFT_LORA_CHECKPOINT_PATH, "final")
print(f"\n--- SFT Complete. Saving final LoRA adapter weights to {final_save_path} ---")
sft_model.save_pretrained(final_save_path)


print("\n--- Generating SFT Loss Curve ---")

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(loss_history, label='SFT Training Loss')
ax.set_title('SFT LoRA Fine-Tuning Loss Curve')
ax.set_xlabel('Steps')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)

fig.savefig(os.path.join(KAGGLE_WORKING_PATH, "sft_loss_curve.png"))
plt.show()

print("SFT workflow finished and loss curve saved.")

## Inference

In [15]:
def generate_response(prompt, model, tokenizer, device, max_length=150):
    """Wrap ``prompt`` in the SFT chat template, sample a reply and print it.

    The template is the same "### HUMAN:" / "### ASSISTANT:" layout produced by
    format_prompt. Only the text after the last "### ASSISTANT:\\n" marker in the decoded
    output is printed. Returns None.
    """
    chat_text = f"### HUMAN:\n{prompt}\n\n### ASSISTANT:\n"
    encoded = tokenizer(chat_text, return_tensors="pt").to(device)
    print(f"\n--- Prompt: ---\n{prompt}")
    print("\n--- Generating Response... ---")

    generated = model.generate(
        input_ids=encoded['input_ids'],
        max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        top_k=50,
    )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    _, _, reply = decoded.rpartition("### ASSISTANT:\n")
    print(f"\n--- Model Response: ---\n{reply}")
    print("-" * 30)

In [16]:
# Directory holding the trained LoRA adapter; loaded below with PeftModel.from_pretrained.
LORA_ADAPTER_PATH = "/kaggle/input/beens-lora/pytorch/default/1/lora_adapter/"

In [22]:
import copy

print(f"\n--- Loading LoRA adapters from: {LORA_ADAPTER_PATH} ---")
# PeftModel.from_pretrained injects adapter layers into the model it is given, and
# merge_and_unload folds them into those same weights. Wrapping a copy keeps base_model
# pristine, so re-running this cell does not merge the adapter a second time.
model = PeftModel.from_pretrained(copy.deepcopy(base_model), LORA_ADAPTER_PATH)
model = model.merge_and_unload()
model.to(DEVICE)
model.eval()
print("LoRA adapters merged into the base model.")

generate_response("Hello! What can you do?", model, tokenizer, DEVICE)
generate_response("What is the capital of France?", model, tokenizer, DEVICE)
generate_response("Write a short, three-line poem about a cat.", model, tokenizer, DEVICE)


--- Loading LoRA adapters from: /kaggle/input/beens-lora/pytorch/default/1/lora_adapter/ ---
LoRA adapters merged into the base model.

--- Prompt: ---
Hello! What can you do?

--- Generating Response... ---

--- Model Response: ---
1. You might have the same money as you could have. You may have the same money as you are spending money to you. You may have the same money to have the same money or have the same money to have the same work.

2. You may have the same money or may have the same money or may have the same money or may have the same money or can have the same money or is often used. You may have the same money or may have the same money or may have the same money or may have the same money or may have a specific role in your contract.

3. You may have the
------------------------------

--- Prompt: ---
What is the capital of France?

--- Generating Response... ---

--- Model Response: ---
The economic importance of the country is in the northern hemisphere and has the econ