In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root, _subdirs, names in os.walk('/kaggle/input'):
    # Print the full path of every file found under /kaggle/input.
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Gemma2 2B IT Hackathon Notebook

## 📌 Overview
This notebook demonstrates fine-tuning and evaluation of the **Gemma2 2B Instruction-Tuned (IT)** model on Kaggle.  
It follows the hackathon requirements:
- Deterministic inference parameters
- Checkpoint saving during a 9-hour run
- Single-session evaluation
- Optional unrestricted mode with Kaggle Model upload

## ⚙️ Environment Setup
Install the required libraries (run this in a notebook code cell):
```python
%pip install -q transformers torch jax flax orbax-checkpoint kagglehub wandb
```


In [None]:
# Tags that wrap the model's reasoning trace and its final answer.
REASONING_START, REASONING_END = "<reasoning>", "</reasoning>"
SOLUTION_START, SOLUTION_END = "<answer>", "</answer>"

# Generation settings required for deterministic evaluation.
TEMPERATURE = 1e-4
TOP_K = 1
TOP_P = 1.0
MAX_GENERATION_STEPS = 768
SEED = 42

# Prompt template; `{question}` is filled in with str.format.
PROMPT_TEMPLATE = "your awesome prompt with a placeholder {question}"

# Locations: single-session checkpoints go under CKPT_DIR/actor/;
# MODEL_ID is the base model for single-session mode.
CKPT_DIR = "/kaggle/working/ckpts"
MODEL_ID = "google/gemma-2-2b-it"


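# Authentication
**Weights & Biases (W&B):** add your `WANDB_API_KEY` as a Kaggle Secret or an environment variable. Never commit a real key in the notebook:

```python
import os
os.environ["WANDB_API_KEY"] = "your_wandb_api_key_here"
```

**Hugging Face Hub:** request access to Gemma2 2B IT, then add your Hugging Face token:

```python
os.environ["HUGGINGFACE_HUB_TOKEN"] = "your_hf_token_here"
```

Note: `huggingface_hub` automatically reads only `HF_TOKEN`. This notebook passes `HUGGINGFACE_HUB_TOKEN` explicitly through the `token=` argument.

**Kaggle API (optional, for dataset/model upload):** place your `kaggle.json` in `~/.kaggle/`.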

# 🚀 Usage
**Load base model and tokenizer**

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=os.environ["HUGGINGFACE_HUB_TOKEN"])
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=os.environ["HUGGINGFACE_HUB_TOKEN"])
```

**Fine-tuning loop**

Train with LoRA adapters.

Save checkpoints regularly:

```python
save_actor_checkpoint(step, lora_params)
```

**Evaluation**

Load the latest checkpoint.

Run deterministic inference with:

```
TEMPERATURE=1e-4, TOP_K=1, TOP_P=1.0, MAX_GENERATION_STEPS=768, SEED=42
```

**Unrestricted Mode (optional)**

Upload the final Flax-format checkpoints to Kaggle Models.

Set `unrestricted_kaggle_model = "username/model_name"`.

# 📂 Project Structure
```
/kaggle/working/
  ├── ckpts/actor/<step>/model_params   # Saved checkpoints
  └── unrestricted/jax/size/...          # For Kaggle Model upload
```

# 📝 Notes
Ensure at least one checkpoint is saved during the 9-hour run.

The last checkpoint will be used for evaluation.

Hugging Face access is required for Gemma2 models.


# 🙌 Reflections
Learned about gated model access and secure token handling.

Faced challenges with JAX/CUDA plugin compatibility.

Suggestions: better Kaggle GPU support for JAX, streamlined Hugging Face gated repo access.

In [None]:
!pip install -q transformers==4.44.2 jax==0.4.33 flax==0.8.3 orbax-checkpoint==0.6.3 kagglehub==0.1.6 wandb==0.17.9

import os, re, random
import numpy as np
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import orbax.checkpoint as ocp
import wandb

from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

# Determinism: seed Python's `random`, NumPy's global RNG and transformers.set_seed.
for _seed_fn in (random.seed, np.random.seed, set_seed):
    _seed_fn(SEED)


In [None]:
import os
from huggingface_hub import login

# Read the Hugging Face token from the environment (e.g. copied from a Kaggle
# Secret) instead of hardcoding it. Never assign "" here: that would clobber a
# token the environment already provides, and every gated download would fail.
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError(
        "No Hugging Face token found. Set HUGGINGFACE_HUB_TOKEN (or HF_TOKEN), "
        "e.g. from a Kaggle Secret, before running this cell."
    )

# Later cells pass os.environ["HUGGINGFACE_HUB_TOKEN"] explicitly, while
# huggingface_hub itself reads HF_TOKEN, so populate both.
os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
os.environ["HF_TOKEN"] = hf_token

login(token=hf_token)


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "google/gemma-2-2b-it"

# `.get(...) or None`: a missing or empty variable is not a usable token. With
# None, huggingface_hub falls back to the credentials stored by login() instead
# of raising KeyError or sending an empty token.
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or None

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=hf_token,
    torch_dtype="auto",
    device_map="auto",
)


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Pass the token explicitly. huggingface_hub reads HF_TOKEN (not
# HUGGINGFACE_HUB_TOKEN) on its own, so without `token=` these gated downloads
# depend on an earlier login() call.
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or None

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)

# Load base model for reference/inference (PyTorch).
# NOTE(review): Gemma 2 is commonly recommended to run in bfloat16; float16 may
# overflow. Verify the outputs look sane on your GPU.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=hf_token,
    torch_dtype=torch.float16,
    device_map="auto",
)


In [None]:
import os

# Read the W&B key from the environment rather than hardcoding it.
# Do NOT assign os.environ["WANDB_API_KEY"] = "" here: an empty string would
# overwrite a key supplied by the environment and break W&B authentication.
# In Kaggle: Notebook → Add-ons → Secrets → add a secret named "WANDB_API_KEY".
# NOTE(review): Kaggle Secrets are read through kaggle_secrets.UserSecretsClient
# and are not exported to os.environ automatically; copy the secret into the
# environment first.
wandb_api_key = os.environ.get("WANDB_API_KEY")
print("W&B key loaded:", bool(wandb_api_key))  # True if available


In [None]:
# Same values as the configuration cell at the top of the notebook; re-running
# this cell leaves every constant unchanged.

# Tags that wrap the model's reasoning trace and its final answer.
REASONING_START, REASONING_END = "<reasoning>", "</reasoning>"
SOLUTION_START, SOLUTION_END = "<answer>", "</answer>"

# Generation settings required for deterministic evaluation.
TEMPERATURE = 1e-4
TOP_K = 1
TOP_P = 1.0
MAX_GENERATION_STEPS = 768
SEED = 42

# Prompt template; `{question}` is filled in with str.format.
PROMPT_TEMPLATE = "your awesome prompt with a placeholder {question}"

# Locations: single-session checkpoints go under CKPT_DIR/actor/;
# MODEL_ID is the base model for single-session mode.
CKPT_DIR = "/kaggle/working/ckpts"
MODEL_ID = "google/gemma-2-2b-it"


In [None]:
!pip install -q transformers==4.44.2 jax==0.4.33 flax==0.8.3 orbax-checkpoint==0.6.3 kagglehub==0.1.6 wandb==0.17.9

import os, re, random
import numpy as np
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import orbax.checkpoint as ocp
import wandb

from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

# Determinism: seed Python's `random`, NumPy's global RNG and transformers.set_seed.
for _seed_fn in (random.seed, np.random.seed, set_seed):
    _seed_fn(SEED)


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Pass the token explicitly. huggingface_hub reads HF_TOKEN (not
# HUGGINGFACE_HUB_TOKEN) on its own, so without `token=` these gated downloads
# depend on an earlier login() call.
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or None

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)

# Load base model for reference/inference (PyTorch).
# NOTE(review): Gemma 2 is commonly recommended to run in bfloat16; float16 may
# overflow. Verify the outputs look sane on your GPU.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=hf_token,
    torch_dtype=torch.float16,
    device_map="auto",
)


In [None]:
%pip install --upgrade jax jaxlib==0.4.33+cuda12.cudnn89 \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


In [None]:
# Placeholder Flax module that owns the LoRA parameters; it does not yet
# modify its input.
class LoRALayer(nn.Module):
    """Identity layer that holds two LoRA parameters, "alpha" and "lora_w".

    Attributes:
        rank: length of the "lora_w" parameter vector.
    """

    rank: int

    def setup(self):
        # Parameters are created in this order: "alpha" (shape (1,), ones),
        # then "lora_w" (shape (rank,), normal with stddev 0.02).
        alpha_shape = (1,)
        self.alpha = self.param("alpha", nn.initializers.ones, alpha_shape)
        weight_shape = (self.rank,)
        weight_init = nn.initializers.normal(stddev=0.02)
        self.lora_w = self.param("lora_w", weight_init, weight_shape)

    def __call__(self, x):
        # Returns x unchanged. Integrate the parameters with your transformer
        # blocks for real training.
        return x

class LoRAPolicy(nn.Module):
    """Policy wrapper that delegates to a single `LoRALayer` submodule.

    Attributes:
        rank: forwarded to `LoRALayer.rank` (default 8).
    """

    rank: int = 8

    def setup(self):
        self.layer = LoRALayer(rank=self.rank)

    def __call__(self, x):
        out = self.layer(x)
        return out

# Initialize policy state: build the rank-8 policy and initialize its
# parameters by tracing it on a dummy length-1 input.
rng = jax.random.PRNGKey(SEED)
lora_policy = LoRAPolicy(rank=8)
_dummy_input = jnp.zeros((1,))
params = lora_policy.init(rng, _dummy_input)


In [None]:
# W&B is started in "disabled" mode (workaround for logging in eval stage).
wandb.init(mode="disabled", project="tunix-train")

# Checkpoints are written under <CKPT_DIR>/actor/<step>/model_params.
actor_dir = os.path.join(CKPT_DIR, "actor")
os.makedirs(actor_dir, exist_ok=True)

# Shared Orbax checkpointer used by save_actor_checkpoint().
checkpointer = ocp.StandardCheckpointer()

def save_actor_checkpoint(step: int, lora_params):
    """Save `lora_params` to `<actor_dir>/<step>/model_params` with Orbax.

    Blocks until the write has finished, so the checkpoint can be listed,
    restored or uploaded as soon as this function returns.

    Args:
        step: training step; used as the checkpoint sub-directory name.
        lora_params: pytree of parameters to save.
    """
    step_dir = os.path.join(actor_dir, str(step))
    os.makedirs(step_dir, exist_ok=True)
    save_path = os.path.join(step_dir, "model_params")

    # force=True overwrites an existing checkpoint for the same step.
    checkpointer.save(save_path, lora_params, force=True)
    # StandardCheckpointer may save asynchronously (recent Orbax releases do),
    # in which case save() returns before the files are finalized.
    wait = getattr(checkpointer, "wait_until_finished", None)
    if wait is not None:
        wait()

# Your awesome finetuning code
# Example loop to ensure at least one checkpoint within 9hr:
total_steps = 10  # adjust to your training plan
save_every = 5    # checkpoint cadence, in steps
lora_params = params  # replace with updated params during training

for step in range(1, total_steps + 1):
    # ... perform training step here ...
    # lora_params = updated params from training
    if step % save_every == 0 or step == total_steps:
        save_actor_checkpoint(step, lora_params)

# StandardCheckpointer may save asynchronously; make sure the last checkpoint
# is on disk before reporting it as saved.
_wait = getattr(checkpointer, "wait_until_finished", None)
if _wait is not None:
    _wait()

print("Training complete. Last checkpoint saved at step:", total_steps)


In [None]:
def _latest_complete_step(actor_root):
    """Return the largest numeric step under `actor_root` that has a `model_params` dir, else -1.

    Step directories without `model_params` (e.g. created by
    save_actor_checkpoint before a save that never completed) are ignored.
    """
    if not os.path.isdir(actor_root):
        return -1
    steps = [
        int(name)
        for name in os.listdir(actor_root)
        if re.fullmatch(r"\d+", name)
        and os.path.isdir(os.path.join(actor_root, name, "model_params"))
    ]
    return max(steps, default=-1)


# Load latest checkpoint
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")
latest_step = _latest_complete_step(actor_ckpt_dir)

if latest_step == -1:
    raise FileNotFoundError(f"No checkpoints found in {actor_ckpt_dir}")

print(f"Latest checkpoint step: {latest_step}")

wandb.init(project='tunix-eval', mode="disabled")  # logging bug workaround

# Path of the newest checkpoint: <CKPT_DIR>/actor/<latest_step>/model_params.
trained_ckpt_path = os.path.join(actor_ckpt_dir, str(latest_step), "model_params")

# Shape/dtype-only description of `params`, used as the restore target.
abs_params = jax.tree.map(lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), params)
trained_lora_params = ocp.StandardCheckpointer().restore(trained_ckpt_path, target=abs_params)

# Update policy with trained LoRA params
def tree_update(tree, new_tree):
    """Return the leaves of `new_tree`, laid out in the structure of `tree`.

    jax's tree map raises if the two trees do not have the same structure.
    """
    def _take_new(_old_leaf, new_leaf):
        return new_leaf

    return jax.tree_util.tree_map(_take_new, tree, new_tree)

# Swap in the restored LoRA parameters. NOTE: the Sampler below generates with
# the PyTorch base model and does not read `params`.
params = tree_update(tree=params, new_tree=trained_lora_params)

# Simple sampler wrapper (placeholder)
class Sampler:
    """Minimal deterministic text-generation wrapper used by the evaluation harness.

    NOTE(review): `transformer` (the Flax LoRA policy) is stored but not used.
    Generation runs on a PyTorch causal LM, so the trained LoRA parameters do
    not affect the output yet.
    """

    def __init__(self, transformer, tokenizer, model=None):
        """
        Args:
            transformer: Flax policy, kept for a future LoRA-aware implementation.
            tokenizer: Hugging Face tokenizer used to encode prompts and decode outputs.
            model: causal LM with `.device` and `.generate`. Defaults to the
                notebook-level `base_model`, resolved at generation time.
        """
        self.transformer = transformer
        self.tokenizer = tokenizer
        self.model = model

    def generate(self, question, max_new_tokens=None):
        """Generate an answer for `question` with the notebook's decoding settings.

        Args:
            question: text substituted into PROMPT_TEMPLATE.
            max_new_tokens: generation budget; defaults to MAX_GENERATION_STEPS
                (the hackathon's required 768) instead of a hard-coded value.

        Returns:
            Only the newly generated text. The echoed prompt is stripped.
        """
        model = self.model if self.model is not None else base_model
        if max_new_tokens is None:
            max_new_tokens = MAX_GENERATION_STEPS

        prompt = PROMPT_TEMPLATE.format(question=question)
        # NOTE(review): Gemma IT models are usually prompted through
        # tokenizer.apply_chat_template; plain text is kept here as before.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            do_sample=(TEMPERATURE > 0.0),
        )
        # Decoder-only models return prompt + continuation; keep only the continuation.
        prompt_len = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

# Sampler that the evaluation harness below queries.
sampler = Sampler(transformer=lora_policy, tokenizer=tokenizer)

# AI-based evaluation scaffold
class TunixHackathonEval:
    """Ask the sampler a fixed list of questions and collect its answers.

    The prompt template and generation settings are stored on the instance, but
    `evaluate` does not use them; the sampler applies its own settings.
    `ai_judge` is a label that nothing in this class reads.
    """

    questions = ["What is LoRA?", "Explain Bayesian priors.", "Trade-offs in RLHF."]
    ai_judge = "ai"

    def __init__(self, sampler, prompt_template, temperature, top_k, top_p, seed):
        self.sampler = sampler
        self.template = prompt_template
        self.temperature, self.top_k, self.top_p = temperature, top_k, top_p
        self.seed = seed

    def evaluate(self):
        """Return one {"question": ..., "answer": ...} dict per question, in order."""
        return [
            {"question": question, "answer": self.sampler.generate(question)}
            for question in self.questions
        ]

PROMPT = PROMPT_TEMPLATE
evaluator = TunixHackathonEval(sampler, PROMPT, TEMPERATURE, TOP_K, TOP_P, SEED)
Result = evaluator.evaluate()

# Tag-wrapped summary lines: a completion marker, then the raw results list.
print(f"{REASONING_START}eval_complete{REASONING_END}")
print(f"{SOLUTION_START}{Result}{SOLUTION_END}")


In [None]:
# Publish Flax-format files to Kaggle Models (unrestricted mode)
# Ensure you have converted/packaged your Flax parameters and configs accordingly.

import kagglehub

# Example model path layout:
# /kaggle/working/unrestricted/jax/size/actor/<latest_step>/model_params
unrestricted_root = "/kaggle/working/unrestricted/jax/size"
os.makedirs(unrestricted_root, exist_ok=True)

# Mirror the evaluation layout (actor/<step>/model_params) under unrestricted_root.
final_src = trained_ckpt_path  # from above; kept for reference
final_dst = os.path.join(unrestricted_root, "actor", str(latest_step), "model_params")
os.makedirs(os.path.dirname(final_dst), exist_ok=True)

# Orbax checkpoints are directories; re-save into final_dst instead of copying
# files. StandardCheckpointer may save asynchronously, so wait for the write to
# finish before the directory is uploaded.
publish_checkpointer = ocp.StandardCheckpointer()
publish_checkpointer.save(final_dst, trained_lora_params, force=True)
_wait = getattr(publish_checkpointer, "wait_until_finished", None)
if _wait is not None:
    _wait()

# Upload to Kaggle Models.
# kagglehub handles have the form <owner>/<model>/<framework>/<variation>. Upload
# the variation directory itself, so that downloading ".../jax/size" yields
# actor/<step>/model_params at its root. Make the model public afterwards.
# kagglehub.model_upload("your_username/model_name/jax/size", unrestricted_root, license_name="Apache 2.0")

# Record your unrestricted model ID here (make public and ensure loadable):
unrestricted_kaggle_model = "vijayarajan/gemma2-lora-reasoning-v1"
print("Unrestricted mode model ID:", unrestricted_kaggle_model)


In [None]:
# Checkpoint root and its actor/ sub-directory (same values as set earlier).
CKPT_DIR = "/kaggle/working/ckpts"
actor_dir = os.path.join(CKPT_DIR, "actor")


In [None]:
# Use the explicitly discovered latest step, not the leftover `step` variable.
# Its value depends on which cell ran last: the training loop, or the directory
# scan, which walks os.listdir in arbitrary order.
step_dir = os.path.join(actor_dir, str(latest_step))
save_path = os.path.join(step_dir, "model_params")


In [None]:
checkpointer.save(save_path, lora_params, force=True)
# StandardCheckpointer may write asynchronously; block until the checkpoint is on disk.
_wait = getattr(checkpointer, "wait_until_finished", None)
if _wait is not None:
    _wait()


In [None]:
# Async-save options for an Orbax CheckpointManager (not referenced by the
# other visible cells).
async_options = ocp.AsyncOptions()
options = ocp.CheckpointManagerOptions(async_options=async_options)


In [None]:
def save_actor_checkpoint(step: int, lora_params):
    """Save `lora_params` to `<actor_dir>/<step>/model_params` with Orbax and report the path.

    Blocks until the write has finished, so the checkpoint can be listed,
    restored or uploaded as soon as this function returns.

    Args:
        step: training step; used as the checkpoint sub-directory name.
        lora_params: pytree of parameters to save.
    """
    step_dir = os.path.join(actor_dir, str(step))
    os.makedirs(step_dir, exist_ok=True)

    save_path = os.path.join(step_dir, "model_params")

    # StandardCheckpointer may save asynchronously (recent Orbax releases do),
    # in which case save() returns before the files are finalized; wait for it.
    checkpointer.save(save_path, lora_params, force=True)
    wait = getattr(checkpointer, "wait_until_finished", None)
    if wait is not None:
        wait()

    print(f"Checkpoint saved at step {step} → {save_path}")


In [None]:
# Load uploaded checkpoint for unrestricted eval.
# The download depends on your publication layout. Example (adjust if needed):
# trained_ckpt_path = kagglehub.model_download(unrestricted_kaggle_model + "/jax/size")
# If the variation contains actor/<step>/model_params, join that sub-path onto the
# downloaded directory with os.path.join(...) after downloading.
#
# NOTE(review): with the download lines commented out, this cell restores from the
# `trained_ckpt_path` set during single-session evaluation.

# Restore and update params (same as single-session).
abs_params = jax.tree.map(lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), params)
restorer = ocp.StandardCheckpointer()
trained_lora_params = restorer.restore(trained_ckpt_path, target=abs_params)
params = tree_update(params, trained_lora_params)

# Reuse TunixHackathonEval from above.
unrestricted_evaluator = TunixHackathonEval(sampler, PROMPT, TEMPERATURE, TOP_K, TOP_P, SEED)
Result = unrestricted_evaluator.evaluate()
print(f"{REASONING_START}unrestricted_eval_complete{REASONING_END}")
print(f"{SOLUTION_START}{Result}{SOLUTION_END}")


In [None]:
# Kaggle Models handle of the published Gemma2 LoRA checkpoint. It must match the
# model that the download cells below fetch; a Gemma 1.1 id here would point
# unrestricted evaluation at the wrong model.
unrestricted_kaggle_model = "vijayarajan/gemma2-lora-reasoning-v1"
print("Unrestricted mode model ID:", unrestricted_kaggle_model)


In [None]:
class LoRALinear(nn.Module):
    """Low-rank adapter that adds ``(alpha / rank) * x @ A.T @ B.T`` to a precomputed base output.

    Attributes:
        features: last-dimension width of both `x` and `base_out` (the adapter
            maps features -> rank -> features).
        rank: rank of the low-rank update.
        alpha: scaling numerator; the update is scaled by alpha / rank.
    """

    features: int
    rank: int = 8
    alpha: float = 16.0

    def setup(self):
        # A: (rank, features), small random init.
        self.lora_A = self.param("lora_A", nn.initializers.normal(0.02), (self.rank, self.features))
        # B: (features, rank), zero init, so B @ A == 0 and the adapter starts as
        # an exact no-op. This is the initialization used in the LoRA paper; with
        # both factors random, the base model's behavior is perturbed before any
        # training has happened.
        self.lora_B = self.param("lora_B", nn.initializers.zeros, (self.features, self.rank))
        self.scaling = self.alpha / self.rank

    def __call__(self, x, base_out):
        # (..., features) -> (..., rank) -> (..., features)
        lora_out = x @ self.lora_A.T @ self.lora_B.T
        return base_out + self.scaling * lora_out


In [None]:
def apply_lora_to_attention(params, lora_params, alpha=16.0, rank=8):
    """Return a copy of `params` with LoRA deltas merged into the attention kernels.

    For every layer in params["transformer"]["layers"] and each of q/k/v/o_proj,
    the kernel becomes ``kernel + (alpha / rank) * (B @ A).T``, where A and B are
    lora_params[layer][proj]["lora_A"] and ["lora_B"]. The transpose matches
    `LoRALinear` (``x @ A.T @ B.T == x @ (B @ A).T``) for kernels applied as
    ``x @ kernel``. NOTE(review): confirm this against the real Gemma parameter layout.

    Args:
        params: nested mapping params["transformer"]["layers"][layer]["attention"][proj]["kernel"].
        lora_params: nested mapping lora_params[layer][proj]["lora_A" | "lora_B"].
        alpha: LoRA scaling numerator (default matches LoRALinear).
        rank: LoRA rank (default matches LoRALinear).

    Returns:
        A new nested dict. `params` itself is never modified, so calling this on
        every training step does not accumulate deltas into the base weights.
    """
    scaling = alpha / rank
    transformer = params["transformer"]
    merged_layers = {}
    for layer_name, layer in transformer["layers"].items():
        attention = dict(layer["attention"])
        for proj in ("q_proj", "k_proj", "v_proj", "o_proj"):
            lora_A = lora_params[layer_name][proj]["lora_A"]
            lora_B = lora_params[layer_name][proj]["lora_B"]
            delta = (lora_B @ lora_A).T
            attention[proj] = {**attention[proj], "kernel": attention[proj]["kernel"] + delta * scaling}
        merged_layers[layer_name] = {**layer, "attention": attention}
    return {**params, "transformer": {**transformer, "layers": merged_layers}}


In [None]:
!pip install -U transformers

In [None]:
!pip install -q optax


In [None]:
import optax


In [None]:
# AdamW optimizer over the LoRA parameter tree.
LORA_LEARNING_RATE = 1e-4
tx = optax.adamw(learning_rate=LORA_LEARNING_RATE)
opt_state = tx.init(lora_params)


In [None]:
# AdamW optimizer over the LoRA parameter tree (re-created before train_step).
LORA_LEARNING_RATE = 1e-4
tx = optax.adamw(learning_rate=LORA_LEARNING_RATE)
opt_state = tx.init(lora_params)

def train_step(params, lora_params, opt_state, batch, apply_fn=None):
    """Run one AdamW step on the LoRA parameters only; the base `params` stay frozen.

    Args:
        params: base parameter tree in the layout expected by `apply_lora_to_attention`.
        lora_params: trainable LoRA parameter tree (the only thing differentiated).
        opt_state: optax state for the module-level optimizer `tx`.
        batch: dict with "input_ids" and integer "labels". NOTE(review): labels are
            assumed to be already shifted/aligned with the logits, with no padding mask.
        apply_fn: callable `(params, input_ids) -> logits`. Defaults to `model.apply`.
            NOTE(review): `model` in this notebook is a PyTorch AutoModelForCausalLM,
            which has no Flax-style apply; pass a Flax Gemma apply function here.

    Returns:
        (updated lora_params, updated opt_state)
    """
    forward = model.apply if apply_fn is None else apply_fn

    def loss_fn(lora_params):
        logits = forward(apply_lora_to_attention(params, lora_params), batch["input_ids"])
        # Mean token-level cross-entropy against integer class labels.
        return optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()

    grads = jax.grad(loss_fn)(lora_params)
    # adamw's decoupled weight decay needs the current parameter values; without
    # them tx.update raises.
    updates, opt_state = tx.update(grads, opt_state, lora_params)
    lora_params = optax.apply_updates(lora_params, updates)
    return lora_params, opt_state


In [None]:
import os

# Use the discovered checkpoint step instead of a hard-coded "10", so this path
# matches the one used when re-saving into the unrestricted layout.
base = os.path.join("/kaggle/working/unrestricted/jax/size/actor", str(latest_step), "model_params")

os.makedirs(base, exist_ok=True)

print("Directory structure created:", base)


In [None]:
import os

# Use the discovered checkpoint step instead of a hard-coded "10", so this path
# matches the one used when re-saving into the unrestricted layout.
base = os.path.join("/kaggle/working/unrestricted/jax/size/actor", str(latest_step), "model_params")
os.makedirs(base, exist_ok=True)

print("Directory structure created:", base)


In [None]:
unrestricted_root = "/kaggle/working/unrestricted/jax/size"
actor_step_dir = f"{unrestricted_root}/actor/{latest_step}/model_params"

os.makedirs(actor_step_dir, exist_ok=True)

# Re-save the restored LoRA params into the upload layout. StandardCheckpointer
# may write asynchronously, so wait for it before the directory is uploaded.
upload_checkpointer = ocp.StandardCheckpointer()
upload_checkpointer.save(
    actor_step_dir,
    trained_lora_params,
    force=True,
)
_wait = getattr(upload_checkpointer, "wait_until_finished", None)
if _wait is not None:
    _wait()


In [None]:

# Kaggle Models handle (owner/model) of the published Gemma2 LoRA checkpoint.
unrestricted_kaggle_model = "vijayarajan/gemma2-lora-reasoning-v1"
print("Unrestricted mode model ID:", unrestricted_kaggle_model)


In [None]:
import kagglehub
import os

# kagglehub model handles have the form <owner>/<model>/<framework>/<variation>,
# so "owner/model" alone is not downloadable. Downloading the jax/size variation
# returns that variation's root directory.
path = kagglehub.model_download("vijayarajan/gemma2-lora-reasoning-v1/jax/size")
print("Downloaded to:", path)

# Now navigate inside: <variation root>/actor/<step>/model_params, using the
# newest numeric step present in the download instead of a hard-coded "10".
downloaded_actor_dir = os.path.join(path, "actor")
downloaded_steps = sorted(
    int(name)
    for name in (os.listdir(downloaded_actor_dir) if os.path.isdir(downloaded_actor_dir) else [])
    if name.isdecimal()
)
if not downloaded_steps:
    raise FileNotFoundError(f"No actor/<step> checkpoints found under {path}")
model_params_path = os.path.join(downloaded_actor_dir, str(downloaded_steps[-1]), "model_params")
print("Model params path:", model_params_path)


In [None]:
import kagglehub
import orbax.checkpoint as ocp
import jax

import os

# Download the published jax/size variation; its root contains actor/<step>/model_params.
path = kagglehub.model_download("vijayarajan/gemma2-lora-reasoning-v1/jax/size")

# Restore the newest step actually present in the download, rather than relying
# on `latest_step` from a training session. That variable is undefined in a
# fresh session and is not necessarily the step that was uploaded.
downloaded_actor_dir = os.path.join(path, "actor")
downloaded_steps = sorted(
    int(name)
    for name in (os.listdir(downloaded_actor_dir) if os.path.isdir(downloaded_actor_dir) else [])
    if name.isdecimal()
)
if not downloaded_steps:
    raise FileNotFoundError(f"No actor/<step> checkpoints found under {path}")
restore_step = downloaded_steps[-1]

# Restore checkpoint into a shape/dtype template built from `params`.
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    params,  # base params structure
)

trained_lora_params = ocp.StandardCheckpointer().restore(
    os.path.join(downloaded_actor_dir, str(restore_step), "model_params"),
    target=abs_params,
)

print("Model restored successfully!")
