In [None]:
import os
from dotenv import load_dotenv
from huggingface_hub import HfApi, create_repo
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from dataclasses import dataclass
from torch.amp import autocast, GradScaler
import torch.utils.checkpoint

# Make sure the Hugging Face libraries are importable. On a fresh runtime they
# are installed on demand before the names used by the notebook are imported.
try:
    import datasets
    import transformers
except ImportError:
    print("Installing dependencies...")
    import importlib
    import subprocess
    import sys

    # Run pip through the interpreter that is executing this notebook so the
    # packages land in the kernel's own environment, not in whichever `pip`
    # executable happens to be first on PATH.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", "datasets", "transformers", "accelerate"]
    )
    # Let the import system pick up the freshly installed packages.
    importlib.invalidate_caches()

# Imported once, regardless of whether an install was needed.
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer


# Hugging Face credentials are read from the environment (optionally populated
# from a local .env file) so the access token never has to be written into the
# notebook itself.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN", "")
REPO_NAME = os.getenv("REPO_NAME", "FusionCorp/gemma-zero")

if not HF_TOKEN or not REPO_NAME:
    raise ValueError("Error: HF_TOKEN or REPO_NAME not found in .env file.")


# Hub client; `train()` uses it to upload checkpoints.
api = HfApi(token=HF_TOKEN)
try:
    # exist_ok=True makes this a no-op when the repository already exists.
    create_repo(repo_id=REPO_NAME, repo_type="model", token=HF_TOKEN, exist_ok=True)
    print(f"Connected to Hugging Face Repo: {REPO_NAME}")
except Exception as e:
    # Best effort: a failed connection is reported but does not stop the notebook.
    print(f" Repo Connection Failed: {e}")


@dataclass
class GemmaZeroConfig:
    """Hyperparameters for the GemmaZero model.

    Raises:
        ValueError: if the query heads cannot be split evenly across the
            key/value heads, or if ``head_dim`` is odd (rotary embeddings
            rotate channels in pairs).
    """

    vocab_size: int = 50257      # Must be 50257 for GPT2
    hidden_size: int = 768       # width of the token embeddings / residual stream
    intermediate_size: int = 3072  # width of the gated MLP projections
    num_hidden_layers: int = 6   # number of transformer blocks
    num_attention_heads: int = 12  # query heads
    num_key_value_heads: int = 4   # key/value heads, each shared by a group of query heads
    head_dim: int = 64           # channels per attention head
    max_position_embeddings: int = 1024
    rms_norm_eps: float = 1e-6   # epsilon added to the variance inside RMSNorm
    rope_theta: float = 10000.0  # base of the rotary-embedding frequencies
    attn_logit_softcapping: float = 50.0
    final_logit_softcapping: float = 30.0  # tanh cap on output logits; a falsy value disables it

    def __post_init__(self):
        # Fail fast here: otherwise these only show up later as shape errors
        # deep inside the attention / rotary-embedding code.
        if self.num_key_value_heads <= 0 or self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be a positive multiple "
                f"of num_key_value_heads ({self.num_key_value_heads})"
            )
        if self.head_dim % 2 != 0:
            raise ValueError(f"head_dim ({self.head_dim}) must be even for rotary embeddings")


Connected to Hugging Face Repo: FusionCorp/gemma-zero


In [6]:
class GemmaRMSNorm(nn.Module):
    """Root-mean-square layer normalization over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` and multiplied by a
    learned per-channel ``weight`` (initialised to ones, so the layer starts
    out as a plain RMS normalisation).

    Args:
        dim: size of the last (normalised) dimension.
        eps: constant added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return the normalised, scaled tensor in the dtype of ``x``."""
        # Statistics are computed in float32 so half-precision inputs do not
        # overflow or lose precision when squared.
        x_float = x.float()
        variance = x_float.pow(2).mean(-1, keepdim=True)
        x_float = x_float * torch.rsqrt(variance + self.eps)
        # The learned gain is the only affine term. The previous version added
        # a constant 1.0 to the *output*, which shifted every activation (zero
        # input produced ones) instead of scaling it.
        return (x_float * self.weight.float()).type_as(x)

class GemmaRotaryEmbedding(nn.Module):
    """Builds the cos/sin tables used by rotary position embeddings.

    Args:
        dim: number of channels per head; one frequency is created for every
            second channel.
        max_position_embeddings: accepted for interface compatibility; the
            tables are generated on the fly, so it is not used.
        base: base of the geometric frequency progression.
        device: device on which the frequency buffer is created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # inv_freq[i] = base ** (-2i / dim)
        exponents = torch.arange(0, dim, 2).float().to(device) / dim
        # Non-persistent: the buffer is derived data and is left out of the state dict.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each shaped ``(seq_len, dim)``.

        ``x`` is only consulted for its device.
        """
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        # Outer product: the angle for every (position, frequency) pair.
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Duplicate so the table spans both halves of the head dimension.
        table = torch.cat((angles, angles), dim=-1)
        return table.cos(), table.sin()

def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors by the given rotary cos/sin tables.

    ``cos`` and ``sin`` must broadcast against ``q`` and ``k``.
    Returns the rotated ``(q, k)`` pair.
    """

    def _rotate_half(t):
        # Swap the two halves of the last dimension, negating the back half.
        half = t.shape[-1] // 2
        front, back = t[..., :half], t[..., half:]
        return torch.cat((-back, front), dim=-1)

    q_rotated = (q * cos) + (_rotate_half(q) * sin)
    k_rotated = (k * cos) + (_rotate_half(k) * sin)
    return q_rotated, k_rotated

class GemmaAttention(nn.Module):
    """Causal grouped-query self-attention with QK-norm and rotary embeddings.

    Queries use ``num_attention_heads`` heads while keys/values use the smaller
    ``num_key_value_heads``; each key/value head is repeated so that it serves
    a whole group of query heads.
    """

    def __init__(self, config: GemmaZeroConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=False)
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim, config.max_position_embeddings, config.rope_theta
        )

        # GEMMA 3: QK-Norm — a per-head RMSNorm on queries and keys.
        # It stands in for attention-logit soft-capping, which lets the
        # forward pass use the fused scaled-dot-product kernel.
        self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over ``hidden_states`` of shape ``(batch, seq, hidden)``.

        ``attention_mask`` is accepted for interface compatibility but is not
        used: masking is purely causal. Returns a tensor of the same
        ``(batch, seq, hidden)`` shape.
        """
        bsz, q_len, _ = hidden_states.size()
        per_head = (bsz, q_len, -1, self.head_dim)

        # Project and split into heads: (batch, seq, heads, head_dim).
        q = self.q_proj(hidden_states).view(*per_head)
        k = self.k_proj(hidden_states).view(*per_head)
        v = self.v_proj(hidden_states).view(*per_head)

        # QK-norm acts on the trailing head_dim axis.
        q, k = self.q_norm(q), self.k_norm(k)

        # Move heads ahead of the sequence axis: (batch, heads, seq, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        # Rotary position embedding; `v` only supplies the device.
        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # GQA: repeat each key/value head so the head count matches the queries.
        groups = self.num_key_value_groups
        k = k.repeat_interleave(groups, dim=1)
        v = v.repeat_interleave(groups, dim=1)

        # Fused attention; is_causal=True applies the causal mask internally,
        # so no explicit mask tensor is passed.
        context = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
        )

        # Merge heads back: (batch, seq, heads * head_dim).
        merged = context.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        return self.o_proj(merged)

class GemmaBlock(nn.Module):
    """One pre-norm transformer block: self-attention followed by a gated MLP,
    each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        eps = config.rms_norm_eps

        self.input_layernorm = GemmaRMSNorm(hidden, eps=eps)
        self.self_attn = GemmaAttention(config)
        self.post_attention_layernorm = GemmaRMSNorm(hidden, eps=eps)

        # The Sequential is only a container for the three projections; it is
        # never called as a pipeline. The same layers are also exposed under
        # the explicit names below, so each weight appears under two state-dict
        # keys (`mlp.N.weight` and `mlp_*.weight`).
        self.mlp = nn.Sequential(
            nn.Linear(hidden, inner, bias=False),  # gate projection
            nn.Linear(hidden, inner, bias=False),  # up projection
            nn.Linear(inner, hidden, bias=False),  # down projection
        )
        self.mlp_gate = self.mlp[0]
        self.mlp_up = self.mlp[1]
        self.mlp_down = self.mlp[2]

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is forwarded to the attention layer."""
        # Attention sub-layer with residual.
        attn_out = self.self_attn(self.input_layernorm(x), attention_mask=mask)
        x = x + attn_out

        # Gated MLP sub-layer with residual: down(gelu(gate(h)) * up(h)).
        normed = self.post_attention_layernorm(x)
        gated = F.gelu(self.mlp_gate(normed)) * self.mlp_up(normed)
        return x + self.mlp_down(gated)

class GemmaZeroModel(nn.Module):
    """Decoder-only language model: token embedding, a stack of ``GemmaBlock``
    layers, a final RMSNorm, and an output projection tied to the embedding."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [GemmaBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Embeddings are multiplied by sqrt(hidden_size) before entering the stack.
        self.embed_scale = math.sqrt(config.hidden_size)
        self.gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        """Turn on activation checkpointing for training-mode forward passes."""
        self.gradient_checkpointing = True

    def forward(self, input_ids):
        """Map token ids ``(batch, seq)`` to logits ``(batch, seq, vocab_size)``."""
        hidden = self.embed_tokens(input_ids) * self.embed_scale

        # Checkpointing only applies while training; eval runs the plain path.
        recompute = self.gradient_checkpointing and self.training
        for block in self.layers:
            if recompute:
                # Trade compute for memory: the block's activations are
                # recomputed during the backward pass instead of being stored.
                # The positional `None` is the block's `mask` argument.
                hidden = torch.utils.checkpoint.checkpoint(
                    block, hidden, None, use_reentrant=False
                )
            else:
                hidden = block(hidden)

        hidden = self.norm(hidden)
        # Weight tying: the embedding matrix doubles as the output projection.
        logits = torch.matmul(hidden, self.embed_tokens.weight.t())

        cap = self.config.final_logit_softcapping
        if cap:
            # Soft-cap: squash the logits smoothly into (-cap, cap).
            logits = torch.tanh(logits / cap) * cap

        return logits

In [7]:
from torch.utils.data import IterableDataset, DataLoader

class TinyStoriesDataset(IterableDataset):
    """Streams TinyStories and yields one fixed-length tensor of GPT-2 token ids
    per story."""

    def __init__(self, seq_len=2048):
        self.seq_len = seq_len
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # Reuse the end-of-text token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Streaming mode: records are fetched lazily instead of downloading the whole split.
        self.dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)

    def __iter__(self):
        for record in self.dataset:
            story = record['text']
            # Skip very short stories (fewer than 50 characters).
            if len(story) < 50:
                continue

            # Truncate or pad every story to exactly `seq_len` tokens.
            encoded = self.tokenizer(
                story,
                max_length=self.seq_len,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
            # Drop the leading batch axis of size one: (1, seq_len) -> (seq_len,).
            yield encoded.input_ids.squeeze(0)

def get_tinystories_loader(batch_size=4, seq_len=2048):
    """Build a DataLoader over a streaming ``TinyStoriesDataset``.

    Args:
        batch_size: number of sequences per batch.
        seq_len: token length every story is truncated or padded to.
    """
    return DataLoader(
        TinyStoriesDataset(seq_len=seq_len),
        batch_size=batch_size,
        num_workers=2,  # tokenization runs in two background worker processes
        pin_memory=True,
    )


In [None]:
!pip install bitsandbytes

from transformers import get_linear_schedule_with_warmup 
import bitsandbytes as bnb

def train():
    """Train GemmaZero on TinyStories with fp16 autocast and gradient accumulation.

    Runs ``TOTAL_STEPS`` micro-batches on CUDA, stepping the optimizer every
    ``ACCUM_STEPS`` micro-batches, and uploads a checkpoint to the Hugging Face
    repo every 1000 micro-batches when a token is configured.
    """
    device = "cuda"
    torch.cuda.empty_cache()

    config = GemmaZeroConfig()
    model = GemmaZeroModel(config).to(device)
    model.gradient_checkpointing_enable()

    # Keep a handle on the uncompiled module. Checkpoints are saved from it so
    # their keys are the plain parameter names; a compiled wrapper's
    # state_dict() prefixes every key with `_orig_mod.`.
    raw_model = model

    # Optional speedup
    try:
        model = torch.compile(model)
    except Exception as e:
        # Best effort: fall back to eager mode, but say why.
        print(f"torch.compile unavailable, continuing in eager mode: {e}")

    # Optimizer (slightly higher LR)
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=6e-4, weight_decay=0.01)
    scaler = GradScaler()

    # Hyperparams (adjusted for Flash Attention efficiency)
    BATCH_SIZE = 8
    ACCUM_STEPS = 4      # effective batch size is still 32
    SEQ_LEN = 1024
    TOTAL_STEPS = 10000  # micro-batches; more steps because the model trains faster

    # The scheduler advances once per *optimizer* step, so its horizon is the
    # number of optimizer steps, not the number of micro-batches. Sizing it
    # with TOTAL_STEPS would leave the linear decay only a quarter complete.
    optimizer_steps = TOTAL_STEPS // ACCUM_STEPS
    scheduler = get_linear_schedule_with_warmup(optimizer, 100, optimizer_steps)
    dataloader = DataLoader(TinyStoriesDataset(SEQ_LEN), batch_size=BATCH_SIZE, num_workers=0)
    data_iter = iter(dataloader)

    # GPT-2's end-of-text id, which the dataset also uses for padding.
    pad_token_id = 50256
    model.train()
    optimizer.zero_grad(set_to_none=True)

    print(f"Training ...")

    for step in range(TOTAL_STEPS):
        # Restart the stream if it runs out before TOTAL_STEPS is reached.
        try:
            inputs = next(data_iter).to(device)
        except StopIteration:
            data_iter = iter(dataloader)
            inputs = next(data_iter).to(device)

        # Padding positions are excluded from the loss.
        labels = inputs.clone()
        labels[labels == pad_token_id] = -100

        with autocast(device_type='cuda', dtype=torch.float16):
            logits = model(inputs)
            # Next-token objective: position t predicts token t + 1.
            loss = F.cross_entropy(
                logits[..., :-1, :].contiguous().view(-1, config.vocab_size),
                labels[..., 1:].contiguous().view(-1),
                ignore_index=-100
            )
            # Scale so the accumulated gradient is the mean over the micro-batches.
            loss = loss / ACCUM_STEPS

        scaler.scale(loss).backward()

        if (step + 1) % ACCUM_STEPS == 0:
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        # Log every 50 micro-batches. This used to sit inside the accumulation
        # branch, so it only fired when the step was a multiple of both 4 and 50.
        if (step + 1) % 50 == 0:
            print(f"Step {step+1} | Loss: {loss.item() * ACCUM_STEPS:.4f}")

        # Upload Logic
        if (step + 1) % 1000 == 0 and HF_TOKEN and HF_TOKEN != "hf_...":
            print(f"☁️ Uploading to Gemma-zero...")
            torch.save(raw_model.state_dict(), "pytorch_model.bin")
            try:
                api.upload_file(path_or_fileobj="pytorch_model.bin", path_in_repo=f"checkpoint-{step+1}/pytorch_model.bin", repo_id=REPO_NAME, repo_type="model")
                print("✅ Upload Success!")
            except Exception as e:
                # A failed upload must not abort training.
                print(f"❌ Upload Failed: {e}")

if __name__ == "__main__":
    train()





Training ...
