In [None]:
import os
from dotenv import load_dotenv
from huggingface_hub import HfApi, create_repo
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from dataclasses import dataclass
from torch.amp import autocast, GradScaler
import torch.utils.checkpoint

# Make sure the Hugging Face libraries are importable; install them on demand
# (e.g. on a fresh Colab/Kaggle runtime) before importing the names we use.
try:
    import datasets
    import transformers
except ImportError:
    print("Installing dependencies...")
    import subprocess
    import sys
    # Invoke pip through the running interpreter so the packages land in the
    # same environment as this kernel, not whichever `pip` happens to be on PATH.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", "datasets", "transformers", "accelerate"]
    )

# Imported once, after the (possible) install, instead of in both branches.
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer


# Read credentials from the environment (optionally populated from a local
# .env file) instead of hard-coding them in the notebook.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN", "")
REPO_NAME = os.getenv("REPO_NAME", "")

if not HF_TOKEN or not REPO_NAME:
    raise ValueError("Error: HF_TOKEN or REPO_NAME not found in .env file.")


api = HfApi(token=HF_TOKEN)
try:
    # exist_ok=True makes this a no-op when the repo was created on a previous run.
    create_repo(repo_id=REPO_NAME, repo_type="model", token=HF_TOKEN, exist_ok=True)
    print(f"Connected to Hugging Face Repo: {REPO_NAME}")
except Exception as e:
    # Best effort: training can still proceed; uploads will report their own errors.
    print(f" Repo Connection Failed: {e}")


@dataclass
class GemmaZeroConfig:
    """Hyperparameters for the GemmaZero decoder-only language model."""

    # --- vocabulary / width ---
    vocab_size: int = 50257  # Must be 50257 for GPT2
    hidden_size: int = 1024
    intermediate_size: int = 4096  # width of the gated MLP

    # --- depth / attention layout ---
    num_hidden_layers: int = 24
    num_attention_heads: int = 16  # query heads
    num_key_value_heads: int = 4  # key/value heads shared across query heads
    head_dim: int = 64

    # --- positions / normalisation ---
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10_000.0  # base of the rotary frequencies

    # --- tanh soft-capping (a falsy value disables it) ---
    attn_logit_softcapping: float = 50.0
    final_logit_softcapping: float = 30.0


ValueError: Error: HF_TOKEN or REPO_NAME not found in .env file.

In [None]:
class GemmaRMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learned per-channel scale.

    Args:
        dim: Size of the last (feature) dimension being normalised.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Initialised to ones, so the layer starts out as a pure normalisation.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` over its last dimension and return it in ``x``'s dtype."""
        # Do the statistics in float32 so half-precision inputs do not lose accuracy.
        x_float = x.float()
        variance = x_float.pow(2).mean(-1, keepdim=True)
        x_float = x_float * torch.rsqrt(variance + self.eps)
        # Scale only. The previous version added 1.0 to the *output*, which
        # shifted every activation by a constant instead of scaling it.
        return (x_float * self.weight.float()).type_as(x)

class GemmaRotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used by rotary position embeddings.

    Args:
        dim: Rotary dimension (the per-head feature size).
        max_position_embeddings: Accepted for interface compatibility; the
            tables are computed on the fly in ``forward`` and this value is
            not used.
        base: Base of the geometric frequency progression.
        device: Optional device on which to create the frequency buffer.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        # Non-persistent: derived from hyperparameters, so kept out of the state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each of shape ``(seq_len, dim)``.

        Args:
            x: Tensor used for its device and, when ``seq_len`` is omitted,
                for its second-to-last dimension as the sequence length.
            seq_len: Number of positions to generate. Defaults to
                ``x.shape[-2]``.
        """
        if seq_len is None:
            # The documented default previously crashed inside torch.arange(None).
            seq_len = x.shape[-2]
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Duplicate so the table spans the full head dimension (both rotated halves).
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors by the given rotary cos/sin tables.

    ``cos`` and ``sin`` must broadcast against ``q`` and ``k``. Returns the
    rotated ``(q, k)`` pair.
    """

    def _rotate_half(t):
        # Swap the two halves of the last dimension, negating the upper one.
        half = t.shape[-1] // 2
        lower, upper = t[..., :half], t[..., half:]
        return torch.cat((-upper, lower), dim=-1)

    q_rotated = (q * cos) + (_rotate_half(q) * sin)
    k_rotated = (k * cos) + (_rotate_half(k) * sin)
    return q_rotated, k_rotated

class GemmaAttention(nn.Module):
    """Causal grouped-query self-attention with rotary embeddings and logit soft-capping.

    Queries use ``num_attention_heads`` heads; keys and values use
    ``num_key_value_heads`` heads that are repeated to match the query heads.
    """

    def __init__(self, config: GemmaZeroConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        self.rotary_emb = GemmaRotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over ``hidden_states`` of shape ``(batch, seq, hidden_size)``.

        Args:
            hidden_states: Input activations.
            attention_mask: Optional additive mask broadcastable to
                ``(batch, heads, seq, seq)``. It is applied in addition to the
                built-in causal mask.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Expand the shared key/value heads so they line up with the query heads.
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        softcap = self.config.attn_logit_softcapping
        if softcap:
            attn_weights = torch.tanh(attn_weights / softcap) * softcap
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Causal mask: position i may only attend to positions <= i. Without it
        # every token sees the tokens it is trained to predict, so the
        # next-token loss collapses towards zero without the model learning
        # to generate. The finite dtype minimum is used instead of -inf so a
        # row can never become all -inf and produce NaNs.
        future = torch.ones(q_len, q_len, dtype=torch.bool, device=attn_weights.device).triu(diagonal=1)
        attn_weights = attn_weights.masked_fill(future, torch.finfo(attn_weights.dtype).min)

        # Softmax in float32 for stability, then back to the compute dtype.
        attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_probs, v)
        return self.o_proj(attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1))

class GemmaBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a gated-GELU MLP.

    Both sub-layers are wrapped in a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        eps = config.rms_norm_eps

        self.input_layernorm = GemmaRMSNorm(width, eps=eps)
        self.self_attn = GemmaAttention(config)
        self.post_attention_layernorm = GemmaRMSNorm(width, eps=eps)
        # The Sequential is only a container and is never called as a pipeline;
        # its three layers are the gate, up and down projections.
        self.mlp = nn.Sequential(
            nn.Linear(width, inner, bias=False),
            nn.Linear(width, inner, bias=False),
            nn.Linear(inner, width, bias=False),
        )
        # Readable aliases for the same three Linear modules.
        self.mlp_gate = self.mlp[0]
        self.mlp_up = self.mlp[1]
        self.mlp_down = self.mlp[2]

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is forwarded to the attention layer."""
        # Attention sub-layer with residual.
        attn_out = self.self_attn(self.input_layernorm(x), attention_mask=mask)
        hidden = x + attn_out

        # Gated MLP sub-layer with residual.
        normed = self.post_attention_layernorm(hidden)
        gate = self.mlp_gate(normed)
        up = self.mlp_up(normed)
        mlp_out = self.mlp_down(F.gelu(gate) * up)
        return hidden + mlp_out

class GemmaZeroModel(nn.Module):
    """Decoder-only language model built from ``GemmaBlock`` layers.

    Token embeddings are scaled by ``sqrt(hidden_size)`` and the same embedding
    matrix is reused as the output projection.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        blocks = [GemmaBlock(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.embed_scale = math.sqrt(config.hidden_size)
        self.gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        """Turn on activation checkpointing (only takes effect in training mode)."""
        self.gradient_checkpointing = True

    def forward(self, input_ids):
        """Map token ids to vocabulary logits, one row per input position."""
        hidden = self.embed_tokens(input_ids) * self.embed_scale
        for block in self.layers:
            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: activations inside the block are
                # recomputed during the backward pass instead of being stored.
                hidden = torch.utils.checkpoint.checkpoint(
                    block, hidden, None, use_reentrant=False
                )
            else:
                hidden = block(hidden)

        hidden = self.norm(hidden)
        # Output projection shares weights with the input embedding.
        logits = torch.matmul(hidden, self.embed_tokens.weight.t())

        cap = self.config.final_logit_softcapping
        if cap:
            logits = torch.tanh(logits / cap) * cap
        return logits

In [None]:
def get_fitness_savant_dataloader(batch_size=4, seq_len=2048):
    """Build an endless batch source over the fitness/diet chat datasets.

    Each available dataset is reformatted into a ``<|user|> ... <|model|> ...``
    prompt, the results are concatenated and shuffled with a fixed seed, and
    every example is tokenised with the GPT-2 tokenizer, truncated/padded to
    ``seq_len`` (padding uses the EOS token).

    Args:
        batch_size: Number of sequences per batch.
        seq_len: Token length every sequence is truncated or padded to.

    Returns:
        A zero-argument callable; each call returns a CUDA tensor of token ids
        with shape ``(batch_size, seq_len)``.

    Raises:
        RuntimeError: If none of the datasets could be loaded.
    """
    print("Loading datasets...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    def get_col(row, candidates):
        # First non-empty value among the candidate column names, else "".
        for col in candidates:
            if col in row and row[col]:
                return str(row[col])
        return ""

    def format_gym(x):
        title = get_col(x, ['Title', 'title'])
        desc = get_col(x, ['Desc', 'desc'])
        return {"text": f"<|user|>\nHow do I do {title}?\n<|model|>\n{desc}<|endoftext|>"}

    def format_diet(x):
        profile = get_col(x, ['Profile', 'input'])
        recommendation = get_col(x, ['Recommendation', 'output'])
        return {"text": f"<|user|>\nDiet plan for:\n{profile}\n<|model|>\n{recommendation}<|endoftext|>"}

    def format_chat(x):
        instruction = get_col(x, ['instruction'])
        output = get_col(x, ['output'])
        return {"text": f"<|user|>\n{instruction}\n<|model|>\n{output}<|endoftext|>"}

    # (label used in log messages, Hub dataset id, row formatter)
    sources = [
        ("Gym", "onurSakar/GYM-Exercise", format_gym),
        ("Diet", "issai/LLM_for_Dietary_Recommendation_System", format_diet),
        ("Chat", "chibbss/fitness-chat-prompt-completion-dataset", format_chat),
    ]

    datasets_list = []
    for label, dataset_id, formatter in sources:
        try:
            ds = load_dataset(dataset_id, split="train")
            datasets_list.append(ds.map(formatter, remove_columns=ds.column_names))
        except Exception as e:
            # Best effort: one missing dataset should not abort training, but
            # say why it was skipped. KeyboardInterrupt is not swallowed.
            print(f"Skipping {label}: {e}")

    if not datasets_list:
        raise RuntimeError("No datasets could be loaded; cannot build a dataloader.")

    dataset = concatenate_datasets(datasets_list).shuffle(seed=42)

    def data_generator():
        # Cycle over the shuffled dataset forever, yielding one sequence at a time.
        while True:
            for item in dataset:
                yield torch.tensor(
                    tokenizer.encode(item['text'], max_length=seq_len, truncation=True, padding="max_length")
                )

    gen = data_generator()
    return lambda: torch.stack([next(gen) for _ in range(batch_size)]).cuda()


In [None]:

!pip install -q bitsandbytes 

import bitsandbytes as bnb # Import the 8-bit optimizer

def train():
    """Train GemmaZero on the fitness datasets and periodically upload checkpoints.

    Uses fp16 autocast with a GradScaler, gradient checkpointing, an 8-bit
    AdamW optimizer and gradient accumulation. Every 200 micro-steps the model
    weights are saved locally and pushed to the Hugging Face repo ``REPO_NAME``.

    Padding positions are excluded from the loss. This assumes the batches
    from ``get_fitness_savant_dataloader`` are padded with GPT-2's
    ``<|endoftext|>`` token (id 50256), as that function configures.
    """
    device = "cuda"
    config = GemmaZeroConfig()
    model = GemmaZeroModel(config).to(device)
    model.gradient_checkpointing_enable()

    # 8-bit optimizer state to reduce VRAM usage.
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=3e-4, weight_decay=0.01)

    scaler = GradScaler(device='cuda')

    # Hyperparameters sized for a 16GB GPU
    BATCH_SIZE = 2      # Lowered to 2 to prevent OOM
    ACCUM_STEPS = 16    # Effective Batch Size = 32 (2 * 16)
    SEQ_LEN = 2048
    TOTAL_STEPS = 5000
    PAD_TOKEN_ID = 50256   # GPT-2 <|endoftext|>, which the dataloader also uses as pad
    IGNORE_INDEX = -100    # label value skipped by F.cross_entropy

    get_batch = get_fitness_savant_dataloader(BATCH_SIZE, SEQ_LEN)

    model.train()
    optimizer.zero_grad(set_to_none=True)  # free gradient tensors instead of zeroing them

    print(f" Training 400M Model | Brain Density: {config.num_hidden_layers} Layers")
    print(f"Effective Batch: {BATCH_SIZE * ACCUM_STEPS} | Optimizer: 8-bit AdamW")

    for step in range(TOTAL_STEPS):
        inputs = get_batch()

        # Safety check: an out-of-range id would crash the embedding lookup.
        if inputs.max() >= config.vocab_size:
            print(f"Error: Token {inputs.max()} exceeds vocab {config.vocab_size}")
            break

        # Next-token targets. Pad and EOS are the same token, so keep the first
        # EOS of each sequence as a real target and ignore only targets whose
        # preceding token is also pad (i.e. the padding run). Otherwise the
        # loss is dominated by trivially predictable padding.
        shift_labels = inputs[:, 1:].clone()
        in_pad_run = (inputs[:, 1:] == PAD_TOKEN_ID) & (inputs[:, :-1] == PAD_TOKEN_ID)
        shift_labels[in_pad_run] = IGNORE_INDEX

        with autocast(device_type='cuda', dtype=torch.float16):
            logits = model(inputs)

            # Causal LM shift: the logits at position t predict token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, config.vocab_size),
                shift_labels.reshape(-1),
                ignore_index=IGNORE_INDEX,
            )
            # Scale so the accumulated gradient is the mean over micro-batches.
            loss = loss / ACCUM_STEPS

        scaler.scale(loss).backward()

        if (step + 1) % ACCUM_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            # Monitoring (reports the loss of the last micro-batch only).
            if (step + 1) % (ACCUM_STEPS * 2) == 0:
                mem = torch.cuda.max_memory_allocated() / 1e9
                print(f"Step {step+1} | Loss: {loss.item() * ACCUM_STEPS:.4f} | Peak VRAM: {mem:.2f} GB")

        if (step + 1) % 200 == 0:
            print(f"☁️ Uploading Checkpoint to {REPO_NAME}...")

            # 1. Save weights locally first
            local_file = "pytorch_model.bin"
            torch.save(model.state_dict(), local_file)

            # 2. Push to Hugging Face (best effort: a failed upload must not stop training)
            try:
                api.upload_file(
                    path_or_fileobj=local_file,
                    path_in_repo="pytorch_model.bin",  # Overwrites the file in the repo
                    repo_id=REPO_NAME,
                    repo_type="model"
                )
                print(f"Upload Success at Step {step+1}")
            except Exception as e:
                print(f"Upload Failed: {e}")

# Start training when this cell/script is executed directly (not on import).
if __name__ == "__main__":
    train()

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hLoading datasets...
 Training 400M Model | Brain Density: 24 Layers
Effective Batch: 32 | Optimizer: 8-bit AdamW
Step 32 | Loss: 7.8556 | Peak VRAM: 11.68 GB
Step 64 | Loss: 5.7824 | Peak VRAM: 11.68 GB
Step 96 | Loss: 1.7703 | Peak VRAM: 11.68 GB
Step 128 | Loss: 0.1469 | Peak VRAM: 11.68 GB
Step 160 | Loss: 0.1252 | Peak VRAM: 11.68 GB
Step 192 | Loss: 0.1168 | Peak VRAM: 11.68 GB
Step 224 | Loss: 0.0976 | Peak VRAM: 11.68 GB
Step 256 | Loss: 0.0876 | Peak VRAM: 11.68 GB
Step 288 | Loss: 0.0699 | Peak VRAM: 11.68 GB
Step 320 | Loss: 0.9397 | Peak VRAM: 11.68 GB
Step 352 | Loss: 0.0532 | Peak VRAM: 11.68 GB
Step 384 | Loss: 0.0502 | Peak VRAM: 11.68 GB
Step 416 | Loss: 0.0468 | Peak VRAM: 11.68 GB
Step 448 | Loss: 0.0434 | Peak VRAM: 11.68 GB
Step 480 | Loss: 0.0398 | Peak VRAM: 11.68 GB
Step 512 | Loss: 0.0362 | Peak VRAM: 11.68 GB
Step 544 |

: 