## Train using only the text

# In [None]:
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, random_split

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup
)
from peft import LoraConfig, TaskType, get_peft_model
from datasets import load_dataset
from huggingface_hub import login


# -------------------------------------------------------------------
# 1) Login to Hugging Face Hub and load your dataset
# -------------------------------------------------------------------
import os

# Read the token from the environment instead of hard-coding a secret in the
# notebook; login is only needed for private repos, so skip it when unset.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(hf_token)
dataset = load_dataset("5CD-AI/LLaVA-CoT-o1-Instruct")  # check what splits are available
print(dataset)


# -------------------------------------------------------------------
# 2) A simple dataset class
# -------------------------------------------------------------------
class RationaleTokenCountDataset(torch.utils.data.Dataset):
    """
    Turn a Hugging Face dataset split into (question -> rationale token count) samples.

    Each sample holds the question's `input_ids` / `attention_mask` (truncated and
    padded to `max_length`) plus a scalar `labels` tensor: the number of tokens in
    the text between <REASONING> and </REASONING> of the "output" field
    (0 when the tags are absent).
    """

    def __init__(self, hf_dataset, tokenizer, max_length=128):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        # Padding to max_length requires a pad token; reuse EOS when there is none.
        if tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
        self.max_length = max_length

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        import re

        record = self.hf_dataset[idx]
        # Field names are dataset-specific; adjust them if your split differs
        # (e.g. "instruction" instead of "question").
        question = record["question"]
        answer = record["output"]

        found = re.search(r"<REASONING>(.*?)</REASONING>", answer, re.DOTALL)
        if found:
            rationale = found.group(1)
        else:
            rationale = ""
        token_count = len(self.tokenizer.tokenize(rationale))

        enc = self.tokenizer(
            question,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Drop the leading batch dimension added by return_tensors="pt".
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)
        labels = torch.tensor(token_count, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


# -------------------------------------------------------------------
# 3) The Model class without 'total_steps'
# -------------------------------------------------------------------
class RationaleLengthRegressor(pl.LightningModule):
    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM2-135M",
        lr: float = 1e-4,
        warmup_steps: int = 1000,
        tokenizer=None,
        print_every: int = 50
    ):
        """
        A PyTorch Lightning module that fine-tunes (via LoRA) a causal LM
        to predict log(1 + rationale_length) from the question text.

        The total number of optimizer steps is not an init argument: set it with
        `set_total_training_steps()` before `trainer.fit(...)`, or leave it unset
        and `configure_optimizers()` will use the trainer's own estimate.

        Args:
            model_name:    The name/path of the pretrained causal LM.
            lr:            Learning rate for AdamW.
            warmup_steps:  Number of linear-warmup steps (default 1000).
            tokenizer:     Optional Hugging Face tokenizer, used only for debug prints.
            print_every:   Print debug info every X training batches.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["tokenizer"])

        self.lr = lr
        self.warmup_steps = warmup_steps
        self.tokenizer = tokenizer
        self.print_every = print_every

        # Optionally set via set_total_training_steps(); if it stays None,
        # configure_optimizers() asks the trainer for an estimate.
        self.total_training_steps = None

        # 1) Load a pretrained causal LM as the backbone
        self.backbone = AutoModelForCausalLM.from_pretrained(model_name)

        # 2) LoRA adapters on the q_proj / v_proj modules
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )

        # 3) Wrap the backbone with LoRA
        self.backbone = get_peft_model(self.backbone, lora_config)

        # 4) MLP regression head: hidden_size -> 512 -> 256 -> 128 -> 1
        hidden_size = self.backbone.config.hidden_size
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 1)  # predicting log(1 + length)
        )

    def set_total_training_steps(self, total_steps: int):
        """
        Set the total number of optimizer steps for the LR schedule (calculated externally).
        """
        self.total_training_steps = total_steps

    @staticmethod
    def _last_token_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Return, per sequence, the hidden state at the last position where attention_mask is 1.

        In a causal LM only the final real token has attended to the whole input,
        so it is the sequence summary. Works for left- and right-padded batches.

        Args:
            hidden_states:  (B, T, H) final-layer hidden states.
            attention_mask: (B, T) with 1 for real tokens and 0 for padding.

        Returns:
            (B, H) pooled hidden states.
        """
        seq_len = attention_mask.shape[1]
        positions = torch.arange(seq_len, device=attention_mask.device)
        # Largest attended position per row (0 if a row is fully masked).
        last_idx = (positions * attention_mask.long()).max(dim=1).values
        batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        return hidden_states[batch_idx, last_idx.to(hidden_states.device)]

    def forward(self, input_ids, attention_mask):
        """
        Predict log(1 + rationale_length) from question tokens.

        Args:
            input_ids:      (B, T) token ids.
            attention_mask: (B, T) 1 for real tokens, 0 for padding.

        Returns:
            (B,) predictions in log space.
        """
        # NOTE: the causal-LM wrapper also computes vocabulary logits, which are unused here.
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        last_hidden_state = outputs.hidden_states[-1]  # final hidden layer, (B, T, H)
        # Token 0 of a causal LM attends only to itself, so it would carry just the
        # first question token; pool the last real token instead.
        pooled = self._last_token_pool(last_hidden_state, attention_mask)
        return self.regression_head(pooled).squeeze(-1)  # log(1 + length)

    def training_step(self, batch, batch_idx):
        """Compute MSE between predicted and true log(1 + length); print debug info periodically."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"].float()

        # Convert labels to log-space
        labels_log = torch.log1p(labels)

        preds_log = self(input_ids, attention_mask)
        loss = F.mse_loss(preds_log, labels_log)
        self.log("train_loss", loss, prog_bar=True)

        # Debug info
        if (batch_idx % self.print_every == 0) and (self.tokenizer is not None):
            decoded_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            real_len = labels[0].item()
            predicted_len = torch.expm1(preds_log[0]).item()  # invert log1p
            print(f"--- Step {batch_idx} Debug Info ---")
            print(f"Decoded Input: {decoded_text}")
            print(f"Real Length: {real_len:.2f}")
            print(f"Predicted Length: {predicted_len:.2f}")

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation MSE in log space."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"].float()

        labels_log = torch.log1p(labels)
        preds_log = self(input_ids, attention_mask)
        loss = F.mse_loss(preds_log, labels_log)

        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """
        Set up AdamW over the trainable parameters and a linear LR schedule with warmup.

        Uses `total_training_steps` when it was set via `set_total_training_steps()`;
        otherwise falls back to `self.trainer.estimated_stepping_batches`.

        Raises:
            ValueError: if total steps were not set and the trainer cannot estimate them.
        """
        total_steps = self.total_training_steps
        if total_steps is None:
            estimated = self.trainer.estimated_stepping_batches
            if math.isinf(estimated):
                raise ValueError(
                    "total_training_steps has not been set and the trainer cannot estimate it. "
                    "Call `model.set_total_training_steps(...)` before trainer.fit(...)"
                )
            total_steps = int(estimated)

        # Skip frozen parameters (e.g. base weights frozen by the PEFT wrapper).
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=self.lr)

        # Create linear schedule with warmup, stepped every optimizer step
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=total_steps
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


def main():
    """
    Train the text-only rationale-length regressor and log the run to Weights & Biases.

    Reads the module-level `dataset` loaded at the top of this cell, splits its
    "train" split 80/20 into train/val with a fixed seed, and fits
    `RationaleLengthRegressor` (LoRA on SmolLM2-135M).
    """
    from pytorch_lightning.callbacks import ModelCheckpoint
    batch_size = 64
    split_seed = 42  # fixed so the train/val split is identical across runs

    # 1) We assume the dataset has a "train" split; adjust if it has other splits.
    full_dataset = dataset["train"]
    print("Number of samples in 'train':", len(full_dataset))

    # 2) Train-val split (80%/20%). A seeded generator keeps validation losses
    #    (and therefore the "best" checkpoint) comparable between runs.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    dataset_train, dataset_val = random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )
    print(f"Train size: {len(dataset_train)} | Val size: {len(dataset_val)}")

    # 3) Init tokenizer
    checkpoint = "HuggingFaceTB/SmolLM2-135M"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # 4) Create our custom datasets
    train_dataset = RationaleTokenCountDataset(dataset_train, tokenizer, max_length=128)
    val_dataset = RationaleTokenCountDataset(dataset_val, tokenizer, max_length=128)

    # 5) Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

    # 6) Create model (total steps are set separately below)
    model = RationaleLengthRegressor(
        model_name=checkpoint,
        tokenizer=tokenizer,
        lr=1e-4,
        warmup_steps=50,  # 50-step linear LR warmup
        print_every=50
    )

    # 7) Total training steps = steps_per_epoch * max_epochs (20 epochs here).
    #    NOTE(review): len(train_loader) counts batches of the un-sharded loader; with
    #    several devices or gradient accumulation the real optimizer-step count is lower.
    max_epochs = 20
    steps_per_epoch = len(train_loader)
    total_training_steps = steps_per_epoch * max_epochs

    # 8) Set the total training steps on the model
    model.set_total_training_steps(total_training_steps)

    # 9) Setup a Weights & Biases logger
    wandb_logger = WandbLogger(
        project="my-rationale-length-project",
        name="smolLM2-regression-run",
        log_model=True
    )

    # Keep only the checkpoint with the lowest validation loss
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        dirpath="./checkpoints",
        filename="best-checkpoint"
    )

    # 10) Create a PyTorch Lightning Trainer
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        precision="16-mixed",
        log_every_n_steps=50,
        logger=wandb_logger,
        callbacks=[checkpoint_callback]
    )

    # 11) Train
    trainer.fit(model, train_loader, val_loader)

    wandb.finish()


if __name__ == "__main__":
    # Run the text-only training defined in this cell.
    main()

## Train using the text+image

# In [ ]:
import re
import math
import requests
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from torch.utils.data import DataLoader, random_split
from datasets import load_dataset
from huggingface_hub import login

# For logging
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# Hugging Face Transformers / PEFT
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup,
    SiglipVisionModel,
)
from peft import LoraConfig, TaskType, get_peft_model

import torchvision.transforms as T


# -------------------------------------------------------------------
# 1) A custom dataset that returns text + labels + a *transformed* image tensor
# -------------------------------------------------------------------
class RationaleTokenCountDataset(torch.utils.data.Dataset):
    # Rationale text between the REASONING tags; DOTALL lets it span newlines.
    _REASONING_RE = re.compile(r"<REASONING>(.*?)</REASONING>", re.DOTALL)

    def __init__(self, hf_dataset, tokenizer, max_length=128):
        """
        A dataset that takes a Hugging Face dataset split and creates samples
        (input_ids, attention_mask, labels, image, question_text).

        Expected fields in hf_dataset[idx]:
            - "question" : str
            - "output"   : str (containing <REASONING>...</REASONING>)
            - "image"    : PIL.Image in any mode (converted to RGB when needed)

        Args:
            hf_dataset: Indexable split (e.g. a `datasets.Dataset` or a `Subset` of one).
            tokenizer:  Hugging Face tokenizer; given EOS as pad token if it has none.
            max_length: Question length after truncation/padding.
        """
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        if tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
        self.max_length = max_length

        # Mirrors the SigLIP processor config:
        #   do_resize = True (224 x 224), do_rescale = True (factor = 1/255),
        #   do_normalize = True (mean=[0.5]*3, std=[0.5]*3)
        self.image_transform = T.Compose([
            T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),  # => converts [0..255] to [0..1]
            T.Normalize(mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5])  # => transforms [0..1] into [-1..1]
        ])

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]

        question_text = item["question"]
        output_text = item["output"]
        pil_image = item["image"]

        # Normalize() uses 3-channel statistics and the DataLoader stacks images
        # into one tensor, so grayscale ("L"), palette ("P") or RGBA images must
        # be converted to 3-channel RGB first.
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        # Label = number of tokens in the rationale (0 if the tags are missing)
        match = self._REASONING_RE.search(output_text)
        rationale_text = match.group(1) if match else ""
        label = len(self.tokenizer.tokenize(rationale_text))

        # Tokenize the question (truncated/padded to max_length)
        encoding = self.tokenizer(
            question_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        # Transform the image into a normalized tensor
        image_tensor = self.image_transform(pil_image)  # (3, 224, 224) with the transform above

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
            "image": image_tensor,
            "question_text": question_text,  # raw text, used for debug logging
        }


# -------------------------------------------------------------------
# 2) Frozen SigLIP Vision Encoder (no AutoProcessor usage anymore)
# -------------------------------------------------------------------
class SigLIPFrozenEncoder(nn.Module):
    def __init__(self, model_name="google/siglip-base-patch16-224"):
        """
        Load a SigLIP vision model and freeze it (no gradient updates).

        Args:
            model_name: Hugging Face id/path passed to SiglipVisionModel.from_pretrained.
        """
        super().__init__()
        self.model = SiglipVisionModel.from_pretrained(model_name)

        # Embedding width of the vision model, used to size the downstream head
        self.output_dim = self.model.config.hidden_size

        # Freeze all parameters (no gradient updates)
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    def train(self, mode: bool = True):
        """
        Set this wrapper's training flag but keep the frozen vision model in eval mode.

        A parent module calling `.train()` (as training loops do) would otherwise
        re-enable any dropout inside the frozen encoder and make its features stochastic.

        Returns:
            self, as nn.Module.train does.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, image_tensors: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image_tensors: shape (B, 3, 224, 224), already preprocessed to [-1, 1]

        Returns:
            outputs.pooler_output: shape (B, hidden_size)
        """
        device = next(self.model.parameters()).device
        image_tensors = image_tensors.to(device)  # Move to the encoder's device

        with torch.no_grad():
            # SiglipVisionModel expects keyword arg "pixel_values"
            outputs = self.model(pixel_values=image_tensors)

        return outputs.pooler_output  # shape (B, hidden_size)


# -------------------------------------------------------------------
# 3) PyTorch Lightning Module that predicts log(1 + length) from text+image
# -------------------------------------------------------------------
class RationaleLengthRegressor(pl.LightningModule):
    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM2-135M",
        image_model_name: str = "google/siglip-base-patch16-224",
        lr: float = 1e-4,
        warmup_steps: int = 1000,
        tokenizer=None,
        print_every: int = 50
    ):
        """
        A PyTorch Lightning module that:
         - Fine-tunes (via LoRA) a causal (GPT-like) text model,
         - Uses a frozen SigLIP vision encoder,
         - Concatenates a pooled text embedding with the image embedding to
           predict log(1 + rationale_length).

        Args:
            model_name:        Name/path of the pretrained causal LM.
            image_model_name:  Name/path of the SigLIP vision model.
            lr:                Learning rate for AdamW.
            warmup_steps:      Number of linear-warmup steps (default 1000).
            tokenizer:         Optional tokenizer, used only for debug prints.
            print_every:       Emit debug output every X training batches.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["tokenizer"])

        self.lr = lr
        self.warmup_steps = warmup_steps
        self.tokenizer = tokenizer
        self.print_every = print_every

        # Optionally set via set_total_training_steps(); if it stays None,
        # configure_optimizers() asks the trainer for an estimate.
        self.total_training_steps = None

        # 1) Text backbone + LoRA adapters on q_proj / v_proj
        self.backbone = AutoModelForCausalLM.from_pretrained(model_name)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        self.backbone = get_peft_model(self.backbone, lora_config)

        # 2) Load & freeze the SigLIP vision encoder
        self.image_encoder = SigLIPFrozenEncoder(model_name=image_model_name)

        # 3) Regression head over [text embedding ; image embedding]
        text_hidden_size = self.backbone.config.hidden_size
        img_hidden_size = self.image_encoder.output_dim
        combined_size = text_hidden_size + img_hidden_size

        self.regression_head = nn.Sequential(
            nn.Linear(combined_size, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 1)  # predicts log(1 + length)
        )

    def set_total_training_steps(self, total_steps: int):
        """Set the total number of optimizer steps used by the LR schedule."""
        self.total_training_steps = total_steps

    @staticmethod
    def _last_token_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Return, per sequence, the hidden state at the last position where attention_mask is 1.

        In a causal LM only the final real token has attended to the whole input,
        so it is the sequence summary. Works for left- and right-padded batches.

        Args:
            hidden_states:  (B, T, H) final-layer hidden states.
            attention_mask: (B, T) with 1 for real tokens and 0 for padding.

        Returns:
            (B, H) pooled hidden states.
        """
        seq_len = attention_mask.shape[1]
        positions = torch.arange(seq_len, device=attention_mask.device)
        # Largest attended position per row (0 if a row is fully masked).
        last_idx = (positions * attention_mask.long()).max(dim=1).values
        batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        return hidden_states[batch_idx, last_idx.to(hidden_states.device)]

    def forward(self, input_ids, attention_mask, images):
        """
        Predict log(1 + rationale_length) for a batch.

        Args:
            input_ids:      (B, T) question token ids.
            attention_mask: (B, T) 1 for real tokens, 0 for padding.
            images:         (B, 3, 224, 224) preprocessed image tensors.

        Returns:
            (B,) predictions in log space.
        """
        # 1) Text embedding. Token 0 of a causal LM attends only to itself, so it
        #    would carry just the first question token; pool the last real token.
        text_outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        last_hidden = text_outputs.hidden_states[-1]  # (B, T, text_hidden_size)
        text_emb = self._last_token_pool(last_hidden, attention_mask)

        # 2) Image embedding (SigLIP is frozen, images are preprocessed)
        img_emb = self.image_encoder(images)  # (B, img_hidden_size)

        # 3) Combine and regress
        combined = torch.cat([text_emb, img_emb], dim=-1)
        preds_log = self.regression_head(combined).squeeze(-1)
        return preds_log

    def training_step(self, batch, batch_idx):
        """Compute MSE in log space; periodically print and log a debug sample."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"].float()
        images = batch["image"]  # (B, 3, 224, 224), normalized with mean=std=0.5
        question_text = batch["question_text"][0]  # first sample's raw question

        # Convert labels to log-space
        labels_log = torch.log1p(labels)

        preds_log = self(input_ids, attention_mask, images)
        loss = F.mse_loss(preds_log, labels_log)
        self.log("train_loss", loss, prog_bar=True)

        # Optional debug info
        if (batch_idx % self.print_every == 0) and (self.tokenizer is not None):
            decoded_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            real_len = labels[0].item()
            predicted_len = torch.expm1(preds_log[0]).item()  # invert log1p

            print(f"\n--- Step {batch_idx} Debug ---")
            print(f"Question (first sample): {decoded_text}")
            print(f"Real Len: {real_len:.2f}, Predicted: {predicted_len:.2f}")

            # Image logging needs a W&B run; skip it with other loggers or none.
            if isinstance(self.logger, WandbLogger):
                # Undo Normalize(mean=0.5, std=0.5) so pixels are back in [0, 1];
                # ToPILImage scales float tensors by 255, so negative values would be garbled.
                image_0 = images[0].detach().float().cpu()
                image_0 = (image_0 * 0.5 + 0.5).clamp(0.0, 1.0)
                pil_image_0 = T.ToPILImage()(image_0)
                caption = (
                    f"Q: {question_text}\n"
                    f"Real: {real_len:.2f}, Pred: {predicted_len:.2f}"
                )
                self.logger.experiment.log(
                    {"debug_image": wandb.Image(pil_image_0, caption=caption)}
                )

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation MSE in log space."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"].float()
        images = batch["image"]

        labels_log = torch.log1p(labels)
        preds_log = self(input_ids, attention_mask, images)
        loss = F.mse_loss(preds_log, labels_log)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """
        Set up AdamW over the trainable parameters and a linear LR schedule with warmup.

        Uses `total_training_steps` when it was set via `set_total_training_steps()`;
        otherwise falls back to `self.trainer.estimated_stepping_batches`.

        Raises:
            ValueError: if total steps were not set and the trainer cannot estimate them.
        """
        total_steps = self.total_training_steps
        if total_steps is None:
            estimated = self.trainer.estimated_stepping_batches
            if math.isinf(estimated):
                raise ValueError(
                    "total_training_steps has not been set and the trainer cannot estimate it. "
                    "Call `model.set_total_training_steps(...)` before trainer.fit(...)"
                )
            total_steps = int(estimated)

        # Skip frozen parameters (the SigLIP encoder and PEFT-frozen base weights).
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=self.lr)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=total_steps
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


# -------------------------------------------------------------------
# 4) Main training script
# -------------------------------------------------------------------
def main():
    """
    Train the text+image rationale-length regressor and log the run to Weights & Biases.

    Loads the "train" split of 5CD-AI/LLaVA-CoT-o1-Instruct, splits it 90/10 into
    train/val with a fixed seed, and fits `RationaleLengthRegressor`
    (LoRA-tuned SmolLM2-135M + frozen SigLIP).
    """
    # 1) (Optional) Login to Hugging Face if needed
    # Replace "YOUR_HF_TOKEN" with your actual token if needed for private repos
    # login("YOUR_HF_TOKEN")

    # 2) Load your dataset from HF (or local)
    dataset = load_dataset("5CD-AI/LLaVA-CoT-o1-Instruct")
    full_dataset = dataset["train"]
    print("Number of samples in 'train':", len(full_dataset))

    # 3) Train-val split (90%/10%). A seeded generator keeps validation losses
    #    (and therefore the "best" checkpoint) comparable between runs.
    split_seed = 42
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    dataset_train, dataset_val = random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )
    print(f"Train size: {len(dataset_train)} | Val size: {len(dataset_val)}")

    # 4) Create tokenizer
    text_checkpoint = "HuggingFaceTB/SmolLM2-135M"  # or "facebook/opt-350m", etc.
    tokenizer = AutoTokenizer.from_pretrained(text_checkpoint)

    # 5) Create custom dataset objects
    train_dataset = RationaleTokenCountDataset(dataset_train, tokenizer, max_length=128)
    val_dataset = RationaleTokenCountDataset(dataset_val, tokenizer, max_length=128)

    # 6) Create DataLoaders
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

    # 7) Create model
    model = RationaleLengthRegressor(
        model_name=text_checkpoint,
        image_model_name="google/siglip-base-patch16-224",  # SigLIP model
        tokenizer=tokenizer,
        lr=1e-4,
        warmup_steps=1000,
        print_every=50
    )

    # 8) Total training steps = steps_per_epoch * max_epochs.
    #    NOTE(review): len(train_loader) counts batches of the un-sharded loader; with
    #    several devices or gradient accumulation the real optimizer-step count is lower.
    max_epochs = 15
    steps_per_epoch = len(train_loader)
    total_training_steps = steps_per_epoch * max_epochs
    model.set_total_training_steps(total_training_steps)

    # 9) Setup W&B logger
    wandb_logger = WandbLogger(
        project="my-rationale-length-project",
        name="siglip-smolLM2-run",
        log_model=True
    )

    # 10) Keep only the checkpoint with the lowest validation loss.
    #     NOTE(review): same dirpath/filename as the text-only run in this notebook;
    #     consider a distinct dirpath to keep the two runs' checkpoints apart.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        dirpath="./checkpoints",
        filename="best-checkpoint"
    )

    # 11) Trainer
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        precision="16-mixed",
        log_every_n_steps=50,
        logger=wandb_logger,
        callbacks=[checkpoint_callback]
    )

    # 12) Train!
    trainer.fit(model, train_loader, val_loader)

    # 13) Finish WandB
    wandb.finish()


if __name__ == "__main__":
    # Run the text+image training defined in this cell.
    main()