In [9]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from dataclasses import dataclass
from torch.amp import autocast, GradScaler
import torch.utils.checkpoint

try:
    import datasets
    import transformers
except ImportError:
    print("Installing dependencies...")
    import subprocess
    subprocess.check_call(["pip", "install", "-q", "datasets", "transformers", "accelerate"])
    from datasets import load_dataset, concatenate_datasets
    from transformers import AutoTokenizer
else:
    from datasets import load_dataset, concatenate_datasets
    from transformers import AutoTokenizer


@dataclass
class GemmaZeroConfig:
    vocab_size: int = 32000      
    hidden_size: int = 1024      
    intermediate_size: int = 4096 
    num_hidden_layers: int = 24  
    num_attention_heads: int = 16 
    num_key_value_heads: int = 4 
    head_dim: int = 64
    max_position_embeddings: int = 2048 
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    attn_logit_softcapping: float = 50.0
    final_logit_softcapping: float = 30.0


In [10]:
class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        x_float = x.float()
        variance = x_float.pow(2).mean(-1, keepdim=True)
        x_float = x_float * torch.rsqrt(variance + self.eps)
        return (x_float * self.weight.float()).type_as(x) + 1.0

class GemmaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
    def forward(self, x, seq_len=None):
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

def apply_rotary_pos_emb(q, k, cos, sin):
    def rotate_half(x): return torch.cat((-x[..., x.shape[-1] // 2:], x[..., :x.shape[-1] // 2]), dim=-1)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class GemmaAttention(nn.Module):
    def __init__(self, config: GemmaZeroConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        self.rotary_emb = GemmaRotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None):
        bsz, q_len, _ = hidden_states.size()
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        
        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        k, v = k.repeat_interleave(self.num_key_value_groups, dim=1), v.repeat_interleave(self.num_key_value_groups, dim=1)
        
        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        if self.config.attn_logit_softcapping:
            attn_weights = torch.tanh(attn_weights / self.config.attn_logit_softcapping) * self.config.attn_logit_softcapping
        if attention_mask is not None: attn_weights = attn_weights + attention_mask
        
        attn_output = torch.matmul(F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype), v)
        return self.o_proj(attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1))

class GemmaBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GemmaAttention(config)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size, bias=False), 
            nn.Linear(config.hidden_size, config.intermediate_size, bias=False), 
            nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 
        )
        self.mlp_gate = self.mlp[0]; self.mlp_up = self.mlp[1]; self.mlp_down = self.mlp[2]

    def forward(self, x, mask=None):
        r = x; x = self.input_layernorm(x); x = self.self_attn(x, attention_mask=mask); x = r + x
        r = x; x = self.post_attention_layernorm(x)
        gate, val = self.mlp_gate(x), self.mlp_up(x)
        x = self.mlp_down(F.gelu(gate) * val)
        return r + x

class GemmaZeroModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([GemmaBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.embed_scale = math.sqrt(config.hidden_size)
        self.gradient_checkpointing = False

    def gradient_checkpointing_enable(self): self.gradient_checkpointing = True

    def forward(self, input_ids):
        x = self.embed_tokens(input_ids) * self.embed_scale
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(layer, x, None, use_reentrant=False)
            else: x = layer(x)
        logits = torch.matmul(self.norm(x), self.embed_tokens.weight.t())
        if self.config.final_logit_softcapping:
             logits = torch.tanh(logits / self.config.final_logit_softcapping) * self.config.final_logit_softcapping
        return logits

In [11]:
def get_fitness_savant_dataloader(batch_size=4, seq_len=2048):
    print("Loading datasets...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    def get_col(row, candidates, default=""):
        for col in candidates:
            if col in row and row[col]:
                return str(row[col])
        return default

    try:
        ds_gym = load_dataset("onurSakar/GYM-Exercise", split="train")
        def fmt_gym(x):
            title = get_col(x, ['Title', 'title', 'Exercise', 'Exercise Name', 'instruction'])
            desc = get_col(x, ['Desc', 'desc', 'Description', 'context', 'output'])
            return {"text": f"<|user|>\nHow do I do the {title} exercise?\n<|model|>\n{desc}<|endoftext|>"}
        ds_gym = ds_gym.map(fmt_gym, remove_columns=ds_gym.column_names)
        print(f"Loaded Gym Exercises: {len(ds_gym)}")
    except Exception as e:
        print(f"Error loading Gym Exercises: {e}")
        ds_gym = None

    try:
        ds_diet = load_dataset("issai/LLM_for_Dietary_Recommendation_System", split="train")
        def fmt_diet(x):
            profile = get_col(x, ['Profile', 'profile', 'input'])
            rec = get_col(x, ['Recommendation', 'recommendation', 'output'])
            return {"text": f"<|user|>\nCreate a diet plan for this profile:\n{profile}\n<|model|>\n{rec}<|endoftext|>"}
        ds_diet = ds_diet.map(fmt_diet, remove_columns=ds_diet.column_names)
        print(f"Loaded Diet Plans: {len(ds_diet)}")
    except Exception as e:
        print(f"Error loading Diet Plans: {e}")
        ds_diet = None

    try:
        ds_qa = load_dataset("kishkath/fitness-qa", split="train")
        def fmt_qa(x):
            q = get_col(x, ['instruction', 'question', 'input', 'Question'])
            a = get_col(x, ['output', 'answer', 'Answer'])
            return {"text": f"<|user|>\n{q}\n<|model|>\n{a}<|endoftext|>"}
        ds_qa = ds_qa.map(fmt_qa, remove_columns=ds_qa.column_names)
        print(f"Loaded Fitness QA: {len(ds_qa)}")
    except Exception as e:
        print(f"Error loading Fitness QA: {e}")
        ds_qa = None

    try:
        ds_chat = load_dataset("chibbss/fitness-chat-prompt-completion-dataset", split="train")
        def fmt_chat(x): 
            q = get_col(x, ['instruction', 'prompt'])
            a = get_col(x, ['output', 'completion'])
            return {"text": f"<|user|>\n{q}\n<|model|>\n{a}<|endoftext|>"}
        ds_chat = ds_chat.map(fmt_chat, remove_columns=ds_chat.column_names)
        print(f"Loaded Chat Data: {len(ds_chat)}")
    except Exception as e:
        ds_chat = None

    valid_datasets = [d for d in [ds_gym, ds_diet, ds_qa, ds_chat] if d is not None]
    
    if not valid_datasets:
        print("Dataset load failure. Using TinyStories fallback.")
        dataset = load_dataset("roneneldan/TinyStories", split="train")
    else:
        dataset = concatenate_datasets(valid_datasets).shuffle(seed=42)
    
    print(f"Total Combined Examples: {len(dataset)}")

    def data_generator():
        while True:
            for item in dataset:
                try:
                    tokens = tokenizer.encode(item['text'], max_length=seq_len, truncation=True, padding="max_length")
                    yield torch.tensor(tokens)
                except:
                    continue
    
    gen = data_generator()
    def get_batch():
        batch = [next(gen) for _ in range(batch_size)]
        return torch.stack(batch).cuda()
    
    return get_batch


In [None]:


def train():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on: {device}")
    
    config = GemmaZeroConfig()
    model = GemmaZeroModel(config).to(device)
    model.gradient_checkpointing_enable() 
    
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
    #scaler = GradScaler('cuda')
    scaler = GradScaler(device='cuda')

    print(f"Model Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
    BATCH_SIZE = 4 
    ACCUM_STEPS = 8 
    SEQ_LEN = 2048 
    TOTAL_STEPS = 500 

    get_batch = get_fitness_savant_dataloader(BATCH_SIZE, SEQ_LEN)
    
    model.train()
    optimizer.zero_grad()
    
    print("Starting Training (Fitness Savant Mode)...")
    
    for step in range(TOTAL_STEPS):
        inputs = get_batch()
        labels = inputs.clone()
        
        with autocast(device_type='cuda', dtype=torch.bfloat16):

            logits = model(inputs)
            loss = F.cross_entropy(logits.view(-1, config.vocab_size), labels.view(-1))
            loss = loss / ACCUM_STEPS
        
        scaler.scale(loss).backward()
        
        if (step + 1) % ACCUM_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            
            mem = torch.cuda.memory_allocated() / 1e9
            print(f"Step {step+1} | Loss: {loss.item() * ACCUM_STEPS:.4f} | VRAM: {mem:.2f} GB")

if __name__ == "__main__":
    train()

Running on: cuda
Model Parameters: 397.72M
Loading datasets...
