In [1]:
!pip install -q gemma datasets grain

In [2]:
from kauldron import konfig
import treescope

from gemma import gm
from kauldron import kd
import optax

import dataclasses
from kauldron.typing import Bool, Dim, Float, Int, Schedule, typechecked
import jax.numpy as jnp
from kauldron import kontext

from datasets import load_dataset, Dataset
import pandas as pd
from grain import python as grain
from typing import Mapping, Optional, Any, Generic, Optional, Protocol, SupportsIndex
import functools

In [3]:
# Directory under the Google Drive mount point; passed to the trainer as its
# `workdir` and read again by the TensorBoard cell at the end of the notebook.
workdir = '/content/drive/My Drive/my_ckpt/rl'

In [4]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset CSVs and the checkpoint
# workdir used by this notebook both live under that path.
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import os
from getpass import getpass

from huggingface_hub import login

# Never store a Hugging Face token in the notebook itself: read it from the
# HF_TOKEN environment variable, or prompt for it interactively (the input is
# not echoed, so it does not end up in the saved cell output).
# NOTE(review): the token that used to be hardcoded in this cell has been
# exposed and should be revoked in the Hugging Face account settings.
token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=token, add_to_git_credential=True)

In [6]:
# Download the FiQA sentiment dataset and convert every split it provides
# into a pandas DataFrame, keyed by split name.
fiqa_ds = load_dataset("TheFinAI/fiqa-sentiment-classification")
fiqa_df = dict(
    (split_name, pd.DataFrame(split_rows))
    for split_name, split_rows in fiqa_ds.items()
)

# Read the SFT train/val/test CSV files from Google Drive, one DataFrame per split.
sft_df = {}
for key in ('train', 'val', 'test'):
    sft_df[key] = pd.read_csv(f'/content/drive/My Drive/my_dataset/sft/{key}.csv')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/1.67k [00:00<?, ?B/s]

(…)-00000-of-00001-aeefa1eadf5be10b.parquet:   0%|          | 0.00/61.8k [00:00<?, ?B/s]

(…)-00000-of-00001-0fb9f3a47c7d0fce.parquet:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

(…)-00000-of-00001-51867fe1ac59af78.parquet:   0%|          | 0.00/13.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/234 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/117 [00:00<?, ? examples/s]

In [7]:
class DataFrameDataSource(grain.RandomAccessDataSource[dict]):
    """Random-access grain data source that serves rows of a pandas DataFrame.

    Every record is one DataFrame row, returned as a dict with the keys
    'label' and 'sentence'.
    """

    def __init__(self, dataframe: pd.DataFrame):
        """Keeps a reference to the DataFrame the records are read from.

        Args:
            dataframe: A pandas DataFrame with columns 'label' and 'sentence'.
        """
        self.dataframe = dataframe

    def __len__(self) -> int:
        """Returns the number of rows in the wrapped DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, record_key: SupportsIndex) -> dict:
        """Returns the row at position `record_key` as a dict.

        The 'label' value is the row's label with `.capitalize()` applied; the
        'sentence' value is returned exactly as stored in the DataFrame.
        """
        record = self.dataframe.iloc[record_key]
        label = record['label'].capitalize()
        sentence = record['sentence']
        return {'label': label, 'sentence': sentence}

    def __repr__(self) -> str:
        """Returns the class name together with the number of records."""
        cls_name = type(self).__name__
        return f"{cls_name}(num_records={len(self)})"

In [8]:
# Inherit from the original Tfds and override data_source.
@dataclasses.dataclass(frozen=True)
class MyTfds(kd.data.py.Tfds):
    """`kd.data.py.Tfds` variant whose records come from pandas DataFrames.

    Only `data_source` is overridden; every other option is handled by the
    parent class.
    """

    # Mapping from split name to the DataFrame holding that split. It is
    # indexed with `self.split` in `data_source`, so it must be a mapping of
    # DataFrames rather than a single DataFrame.
    df: Mapping[str, pd.DataFrame]

    @functools.cached_property
    def data_source(self) -> DataFrameDataSource:
        """Builds the data source for the split selected by `self.split`.

        The result is cached on the instance by `functools.cached_property`,
        so the `DataFrameDataSource` is created at most once per instance.
        """
        return DataFrameDataSource(dataframe=self.df[self.split])

In [9]:
# Tokenizer shared by the dataset transforms and by the decoding preview below.
tokenizer = gm.text.Gemma3Tokenizer()

# Single-turn chat prompt; `{text}` is the placeholder that is filled with the
# sentence to classify. The string ends right after the model turn marker.
prompt_template = (
    "<start_of_turn>user\n"
    "Please classify the sentiment of the following financial text, please answer only with Positive, Negative, or Neutral.\n"
    "Text: {text}<end_of_turn>\n"
    "<start_of_turn>model"
)

In [10]:
def make_sft_dataset(split, eval=False):
    """Builds the `MyTfds` dataset for one split of the SFT data.

    Relies on the module-level `sft_df`, `tokenizer` and `prompt_template`.

    Args:
        split: Name of the split; used by `MyTfds` to pick the DataFrame out
            of `sft_df`.
        eval: When truthy, shuffling is disabled and `num_epochs` is 1.
            Otherwise the data is shuffled and `num_epochs` is None.

    Returns:
        A `MyTfds` instance with a batch size of 8 whose transforms format,
        tokenize and pad the 'sentence' field and tokenize the 'label' field.
    """
    sentence_key = "sentence"
    label_key = "label"

    # Prompt side: fill the template, tokenize with a BOS token, pad to 256.
    sentence_transforms = [
        gm.data.FormatText(key=sentence_key, template=prompt_template),
        gm.data.Tokenize(key=sentence_key, tokenizer=tokenizer, add_bos=True),
        gm.data.Pad(key=sentence_key, max_length=256),
    ]
    # Label side: tokenize, then rearrange for shape compatibility with the loss.
    label_transforms = [
        gm.data.Tokenize(key=label_key, tokenizer=tokenizer),
        kd.data.Rearrange(key=label_key, pattern="... -> ..."),
    ]

    return MyTfds(
        name='sft',
        split=split,
        shuffle=not eval,
        num_epochs=1 if eval else None,
        batch_size=8,
        df=sft_df,
        transforms=sentence_transforms + label_transforms,
    )

# One dataset per split: 'train' is built in training mode, 'val' and 'test'
# in evaluation mode.
sft_ds = {
    split: make_sft_dataset(split, is_eval)
    for split, is_eval in [('train', False), ('val', True), ('test', True)]
}

# Inspect the first training batch, then decode its first prompt back to text
# to check the formatting.
ex = sft_ds['train'][0]
treescope.show(ex)

print(tokenizer.decode(ex['sentence'][0]))

<start_of_turn>user
Please classify the sentiment of the following financial text, please answer only with Positive, Negative, or Neutral.
Text: The adapter , awarded with the `` Certified Integration for SAP - ; NetWeaver '' endorsement , integrates Basware s invoice automation and procurement solutions with more than 200 different ERP systems .<end_of_turn>
<start_of_turn>model


In [11]:
@dataclasses.dataclass(frozen=True, kw_only=True)
class PPOLoss(kd.losses.Loss):
    """Loss built from the policy and anchor logits, scaled by the label ids.

    NOTE(review): despite its name this is not a full PPO objective; it
    averages the summed logits over the vocabulary axis and multiplies the
    result by the label id. Confirm the intended formula before relying on it.
    """

    # Context keys resolved by kauldron to fill the `get_values` arguments.
    labels: kontext.Key = kontext.REQUIRED
    policy_logits: kontext.Key = kontext.REQUIRED
    anchor_logits: kontext.Key = kontext.REQUIRED

    @typechecked
    def get_values(
        self,
        *,
        labels: Int['*B 1'],
        policy_logits: Float['*B V'],
        anchor_logits: Float['*B V'],
    ) -> Float['*B 1']:
        """Computes the per-example loss values.

        Args:
            labels: Integer label ids with a trailing axis of size 1.
            policy_logits: Logits of the policy model over the vocabulary.
            anchor_logits: Logits of the anchor model over the vocabulary.

        Returns:
            One loss value per example, with a trailing axis of size 1.
        """
        # Fix: the original added `policy_logits` to itself, so the required
        # `anchor_logits` input never influenced the loss.
        # NOTE(review): assumed to be a typo for `anchor_logits`; confirm that
        # a sum (rather than e.g. a difference) is what was intended.
        combined_logits = policy_logits + anchor_logits
        mean_logits = jnp.mean(combined_logits, axis=-1)
        loss = mean_logits * labels.squeeze(-1)
        # Restore the trailing axis so the result matches the '*B 1' shape.
        return loss[..., None]

In [12]:
# Model definition: a Gemma3 4B policy inside an AnchoredPolicy, wrapped with
# rank-4 LoRA. The policy reads its tokens from the batch's "sentence" field.
policy_model = gm.nn.Gemma3_4B(
    tokens="batch.sentence",
    return_last_only=True,
    text_only=True,
)
lora_model = gm.nn.LoRA(
    rank=4,
    model=gm.nn.AnchoredPolicy(policy=policy_model),
)

# Load the weights from the pretrained checkpoint. The anchored-policy loader
# is wrapped in SkipLoRA to match the LoRA-wrapped model above.
pretrained_loader = gm.ckpts.SkipLoRA(
    wrapped=gm.ckpts.AnchoredPolicyLoader(
        policy=gm.ckpts.LoadCheckpoint(
            path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
        ),
    ),
)

# The loss reads the label from the batch and the logits of both the policy
# and the anchor from the model predictions.
ppo_loss = PPOLoss(
    labels="batch.label",
    policy_logits="preds.policy.logits",
    anchor_logits="preds.anchor.logits",
)

# We only optimize the LoRA weights. The rest of the model is frozen.
lora_optimizer = kd.optim.partial_updates(
    optax.adafactor(learning_rate=1e-4),
    mask=kd.optim.select("lora"),
)

trainer = kd.train.Trainer(
    seed=42,
    workdir=workdir,
    train_ds=sft_ds['train'],
    model=lora_model,
    init_transform=pretrained_loader,
    # Training
    num_train_steps=10_000,
    train_losses={"dpo": ppo_loss},
    optimizer=lora_optimizer,
    checkpointer=kd.ckpts.Checkpointer(save_interval_steps=500),
    # Evaluation: the evaluator registered as "test" runs every 1000 steps on
    # the 'val' split.
    evals={
        "test": kd.evals.Evaluator(
            run=kd.evals.EveryNSteps(1000),
            ds=sft_ds['val'],
        ),
    },
)

In [None]:
# Run the training loop; the pair returned by `train()` is unpacked into
# `state` and `aux`.
state, aux = trainer.train()

Disabling pygrain multi-processing (unsupported in colab).




Starting training loop at step 0


train:   0%|          | 0/10001 [00:00<?, ?it/s]

test:   0%|          | 0/60 [00:00<?, ?it/s]

Disabling pygrain multi-processing (unsupported in colab).




In [None]:
# Copy the "test" subdirectory of the Drive workdir to local /tmp and point
# TensorBoard at that local copy.
!mkdir -p /tmp/workdir
!cp -r "{workdir}/test" /tmp/workdir
%load_ext tensorboard
%tensorboard --logdir '/tmp/workdir/test'