# Binary Classification of a Reddit Dataset Using Gemma

**Objective:**

Train a Gemma-family language model with a quantized low-rank adaptation (QLoRA + PEFT) approach to perform binary classification on a Reddit/Jigsaw-style dataset. The model is trained with a language-head / completion-style objective to emit one of two textual completions ("complies" or "violates").

**Why this approach? (Theory & Explanation)**

- **Gemma models:** Gemma models are causal language models pre-trained for high-quality text generation. We reuse their generative ability and adapt them to classification by framing the task as generation of a short token sequence (a text label).
- **Language-head / completion framing:** Instead of adding a classic dense classifier head, we format inputs and labels so the model learns to complete a prompt with the desired class token. This keeps the pipeline simple and leverages the LM's pre-trained decoder weights directly.
- **QLoRA (Quantized LoRA) + PEFT:** QLoRA reduces memory by quantizing the base model (typically to 4-bit) and applies LoRA adapters for fine-tuning. This drastically lowers GPU memory requirements while keeping most weights frozen; training updates only small adapter matrices. Benefits: much cheaper fine-tuning and easier to run on consumer GPUs or cloud instances.
- **Loss & optimization:** We normally compute loss only on the completion tokens (completion_only_loss=True) so the model learns to predict the label tokens without penalizing the prefix prompt. Gradient accumulation, mixed precision, and bitsandbytes optimizers are commonly used to fit larger models.
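
The completion-only loss described above can be sketched without any library: prompt and padding positions receive the label `-100`, which PyTorch's cross-entropy ignores by default, so only the completion tokens are scored. The token ids and the helper name `mask_prompt_labels` below are hypothetical and chosen purely for illustration; the trainer used later in this notebook builds these labels itself.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy by default


def mask_prompt_labels(prompt_ids, completion_ids, pad_id=0, total_len=None):
    """Return (input_ids, labels) where only completion tokens keep real labels."""
    input_ids = list(prompt_ids) + list(completion_ids)
    # Prompt positions are masked; completion positions keep their token ids.
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    if total_len is not None:
        n_pad = total_len - len(input_ids)
        if n_pad < 0:
            raise ValueError("sequence is longer than total_len")
        input_ids += [pad_id] * n_pad
        labels += [IGNORE_INDEX] * n_pad  # padding never contributes to the loss
    return input_ids, labels


# Toy ids (hypothetical): 4 prompt tokens, then a label token and an EOS token.
input_ids, labels = mask_prompt_labels([2, 11, 12, 13], [21, 1], total_len=8)
scored_positions = [i for i, lab in enumerate(labels) if lab != IGNORE_INDEX]
print(input_ids)
print(labels)
print(scored_positions)
```

Only the positions listed in `scored_positions` would contribute to the loss, which is the same property the notebook checks on a real batch further down.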

**High-level Plan (what this notebook does)**

1. Prepare environment and dependencies
2. Load and preprocess dataset (map 0/1 labels to text labels, create prompts and expected completions)
3. Tokenize and filter long sequences
4. Configure PEFT / LoRA adapters and training arguments
5. Train with an SFTTrainer (or equivalent), logging with WandB
6. Evaluate predictions and compute metrics
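
Step 6 of the plan can be sketched as follows, under the assumption that evaluation means mapping each generated completion back to 0/1 and then computing standard binary metrics. The helper names `parse_completion` and `binary_metrics`, the example generations, and the choice to count unparseable output as an error are all illustrative assumptions, not the notebook's own evaluation code.

```python
LABEL_TO_ID = {"complies": 0, "violates": 1}
UNPARSEABLE = -1  # assumption: unparseable generations count as wrong for every class


def parse_completion(text):
    """Map generated text to 0/1 using its first word; UNPARSEABLE otherwise."""
    words = text.strip().lower().split()
    if not words:
        return UNPARSEABLE
    first = words[0].strip('".,!:;')
    return LABEL_TO_ID.get(first, UNPARSEABLE)


def binary_metrics(y_true, y_pred, positive=1):
    """Accuracy, precision, recall and F1 for the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    accuracy = sum(1 for t, p in pairs if t == p) / len(pairs)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical generations and gold labels, for illustration only.
generations = [" violates", " complies", "violates.", " complies", " maybe"]
gold = [1, 0, 0, 0, 1]
preds = [parse_completion(g) for g in generations]
metrics = binary_metrics(gold, preds)
print(preds)
print(metrics)
```

Treating unparseable output as an error keeps the metrics conservative; a different policy (for example, defaulting to one class) would change the numbers and should be stated alongside them.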

**About vLLM (what it is and when to use it)**

- **vLLM** is a high-performance inference server designed to serve large causal language models with high throughput and low latency. It focuses on inference optimizations: efficient batching, token streaming, and memory management for GPUs. vLLM is not a training library; use it when you want to deploy fast inference at scale.
- **Key features:** efficient request batching, server mode for concurrent users, support for many HF model formats (with adapters), and reduced CPU/GPU memory overheads for serving.
- **Limitations:** not intended for fine-tuning or adapter training; running vLLM typically requires a server environment with sufficient GPU and compatible drivers. On Colab/Kaggle, running vLLM as a production server is generally impractical.

**Alternatives to vLLM for inference / light serving**

- Hugging Face Transformers + `accelerate` (easy for development and supports generation on GPUs/CPUs)
- `text-generation-inference` (TGI) from Hugging Face: a model server optimized for HF models
- `llama.cpp` / `ggml` builds ‚Äî lightweight, CPU-focused runtimes (good for small/quantized models)
- `fastertransformer` / vendor SDKs ‚Äî optimized kernels for NVIDIA hardware
- Cloud APIs (OpenAI, Anthropic, etc.): easiest for production without infra maintenance

**How to run this notebook: Colab and Kaggle notes**

- Common prerequisites for both environments: ensure the runtime has a GPU and set `max_seq_length` and `batch_size` to fit the available memory. Install packages listed in the first code cell (`transformers`, `datasets`, `peft`, `trl`, `bitsandbytes`, `accelerate`, `wandb`, etc.).

- Colab (recommended for interactive GPU runs):
  1. Runtime → Change runtime type → GPU (preferably A100/T4/V100).
  2. Mount Drive if you need persistent storage: `from google.colab import drive; drive.mount('/content/drive')`
  3. Install dependencies in a cell: `!pip install -U transformers datasets peft trl bitsandbytes accelerate wandb` (bitsandbytes may require a CUDA-compatible image).
  4. If using a large model with 4-bit quantization, ensure the image and CUDA versions are compatible with the `bitsandbytes` (`bnb`) kernels. You may need to restart the runtime after installs.
  5. Run cells in order. If you modify installed libraries, restart runtime before training.

- Kaggle (notebooks; more restrictive GPU quotas):
  1. In Settings → Accelerator select a GPU (often K80/P100/T4 subject to quota).
  2. Upload datasets via the Kaggle dataset interface or use the `kaggle` CLI to download dataset artifacts into `/kaggle/working`.
  3. Install Python packages at the top of the notebook using `!pip install` similarly, but be mindful that Kaggle images may already include many libraries and installing large GPU libs can be slower or restricted.
  4. Reduce batch sizes / use gradient accumulation to fit memory constraints on smaller GPUs available on Kaggle.

**Practical tips**

- Always prototype with a small subset of data and small batch sizes to confirm the pipeline before scaling up.
- Use `gradient_accumulation_steps` to simulate larger batches when GPU memory is limited.
- Log to WandB or local files frequently so long-running jobs can be inspected if interrupted.
- When moving from Colab/Kaggle to production inference, prefer a dedicated inference server (vLLM or TGI) or a cloud-hosted API for reliability.

**Next:** Proceed to environment setup, install required packages, and prepare the dataset (do not modify code cells unless necessary).

# <font color = 'indianred'> **1. Setting up the Environment** </font>



In [1]:
# If running in Colab, install or upgrade the required packages
if 'google.colab' in str(get_ipython()):
    !pip install numpy -U -qq
    !pip install transformers evaluate wandb datasets accelerate trl peft bitsandbytes -U -qq


 <Font size = 5 color = 'indianred'>**Restart the session before moving on to the next cell**</font>
> Runtime → Restart session

<font color = 'indianred'> *Load Libraries* </font>

In [2]:
# standard python libraries
from pathlib import Path
from typing import Dict, List, Union, Optional, Tuple
from tqdm import tqdm
import json
import joblib
import os
import sys

# Data science libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import multilabel_confusion_matrix, precision_score, recall_score, f1_score

# Pytorch
import torch
import torch.nn as nn

# Hugging Face libraries
import evaluate
from datasets import load_dataset, DatasetDict, Dataset, ClassLabel
from trl import SFTConfig, SFTTrainer
from transformers import (
    set_seed,
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    AutoConfig,
    BitsAndBytesConfig,
)
from peft import (
    TaskType,
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model,
    AutoPeftModelForCausalLM,
    PeftConfig
)
# Logging and secrets
from huggingface_hub import login, HfApi, create_repo
from google.colab import userdata
import wandb
# Fine-tuning needs more memory than inference: activations, gradients,
# and optimizer state must all be held in memory.
# QLoRA loads the base model quantized to lower precision to reduce this footprint.

In [4]:
# ============================================================================
# WANDB CONFIGURATION
# ============================================================================

WANDB_PROJECT = "jigsaw-binary-classification-fa2025"
WANDB_ENTITY = "Jigsaw-Classification"

experiment_config = {
    "model_name": "google/gemma-2-2b",
    "approach": "language_head",
    "task": "text_generation_binary",
    "dataset": "jigsaw-agile-community-rules",
    "output_format": "text (complies/violates)",
    "batch_size": 16,
    "gradient_accumulation_steps": 2,
    "learning_rate": 1e-5,
    "num_epochs": 3,
    "lora_r": 128,
    "lora_alpha": 256,
    "max_seq_length": 1024,
    "completion_only_loss": True,
}



In [5]:
sys.path

['/content',
 '/env/python',
 '/usr/lib/python312.zip',
 '/usr/lib/python3.12',
 '/usr/lib/python3.12/lib-dynload',
 '',
 '/usr/local/lib/python3.12/dist-packages',
 '/usr/lib/python3/dist-packages',
 '/usr/local/lib/python3.12/dist-packages/IPython/extensions',
 '/root/.ipython',
 '/tmp/tmp4e8zhja2',
 '/usr/local/lib/python3.12/dist-packages/setuptools/_vendor']

In [6]:
# If running on Google Colab, use Google Drive as storage
# CHANGE FOLDERS TO WHERE YOU WANT TO SAVE DATA AND MODELS

if 'google.colab' in str(get_ipython()):
    from google.colab import drive  # Import Google Drive mounting utility
    drive.mount('/content/drive')  # Mount Google Drive

# Set base folder paths for data and project files on Google Drive
base_folder = Path('/content/drive/MyDrive/datasets')
project_folder = Path('/content/drive/MyDrive/AI/')

Mounted at /content/drive


In [7]:
util_folder = project_folder/'0_shared_utils'

In [8]:
sys.path.append(str(util_folder))

In [9]:
sys.path

['/content',
 '/env/python',
 '/usr/lib/python312.zip',
 '/usr/lib/python3.12',
 '/usr/lib/python3.12/lib-dynload',
 '',
 '/usr/local/lib/python3.12/dist-packages',
 '/usr/lib/python3/dist-packages',
 '/usr/local/lib/python3.12/dist-packages/IPython/extensions',
 '/root/.ipython',
 '/tmp/tmp4e8zhja2',
 '/usr/local/lib/python3.12/dist-packages/setuptools/_vendor',
 '/content/drive/MyDrive/AI/0_shared_utils']

In [10]:
from shared_utils import free_gpu_memory, find_linear_layers, multilabel_evaluation, get_appropriate_dtype

In [11]:
wandb_api_key = userdata.get('WANDB_API_KEY')
hf_token = userdata.get('HF_TOKEN')

In [12]:
if hf_token:
    # Log in to Hugging Face
    login(token=hf_token)
    print("Successfully logged in to Hugging Face!")
else:
    print("Hugging Face token not found in notebook secrets.")

Successfully logged in to Hugging Face!


In [13]:
if wandb_api_key:
    # Log in to Weights & Biases
    wandb.login(key=wandb_api_key)
    print("Successfully logged in to WANDB!")
else:
    print("WANDB key not found in notebook secrets.")

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtalktothewhitehat[0m ([33mJigsaw-Classification[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Successfully logged in to WANDB!


# <font color = 'indianred'> **2. Load Data set**
    


In [14]:
# Load the training data and the competition test set
data_folder = base_folder/'jigsaw-agile-community-rules'
dataset = load_dataset('csv', data_files={'train': str(data_folder/'train.csv')})
train_df = pd.read_csv(data_folder/'train.csv')
test_competition_df = pd.read_csv(data_folder/'test.csv')

Generating train split: 0 examples [00:00, ? examples/s]

In [15]:
print(f"Training data shape: {train_df.shape}")
print(f"Competition test data shape: {test_competition_df.shape}")
print("\nTraining data columns:", train_df.columns.tolist())
print("\nFirst few rows:")
print(train_df.head())

Training data shape: (2029, 9)
Competition test data shape: (10, 8)

Training data columns: ['row_id', 'body', 'rule', 'subreddit', 'positive_example_1', 'positive_example_2', 'negative_example_1', 'negative_example_2', 'rule_violation']

First few rows:
   row_id                                               body  \
0       0  Banks don't want you to know this! Click here ...   
1       1  SD Stream [ ENG Link 1] (http://www.sportsstre...   
2       2  Lol. Try appealing the ban and say you won't d...   
3       3  she will come your home open her legs with  an...   
4       4  code free tyrande --->>> [Imgur](http://i.imgu...   

                                                rule      subreddit  \
0  No Advertising: Spam, referral links, unsolici...     Futurology   
1  No Advertising: Spam, referral links, unsolici...  soccerstreams   
2  No legal advice: Do not offer or request legal...   pcmasterrace   
3  No Advertising: Spam, referral links, unsolici...            sex   
4  No

# <font color = 'indianred'> **3. Accessing and Manipulating Splits**</font>



<font color = 'indianred'>*Create further subdivisions of the splits*</font>

In [16]:
def preprocess_text(text):
    """Basic text preprocessing"""
    if pd.isna(text):
        return ""
    text = str(text).strip()
    return text


In [17]:
def create_input_text(row):
    """Create input text combining comment and rule"""
    text_parts = []
    text_parts.append(f"Comment: {preprocess_text(row['body'])}")
    text_parts.append(f"Rule: {preprocess_text(row['rule'])}")
    return " [SEP] ".join(text_parts)



In [18]:
# Create text column
train_df['text'] = train_df.apply(create_input_text, axis=1)

# CRITICAL: Convert 0/1 to "complies"/"violates" for text generation
train_df['label'] = train_df['rule_violation'].map({
    0: 'complies',
    1: 'violates'
})

print("\nLabel conversion:")
print(train_df[['rule_violation', 'label']].head())
print(f"\nLabel distribution:")
print(train_df['label'].value_counts())


Label conversion:
   rule_violation     label
0               0  complies
1               0  complies
2               1  violates
3               1  violates
4               1  violates

Label distribution:
label
violates    1031
complies     998
Name: count, dtype: int64


In [19]:
# Create train/validation/test splits (80/10/10)
train_val_df, test_df = train_test_split(
    train_df,
    test_size=0.1,
    random_state=42,
    stratify=train_df['label']
)

train_df_split, val_df = train_test_split(
    train_val_df,
    test_size=0.111,  # 0.111 of 90% ≈ 10% of total
    random_state=42,
    stratify=train_val_df['label']
)

In [20]:
print(f"Train set size: {len(train_df_split)} ({len(train_df_split)/len(train_df)*100:.1f}%)")
print(f"Validation set size: {len(val_df)} ({len(val_df)/len(train_df)*100:.1f}%)")
print(f"Test set size: {len(test_df)} ({len(test_df)/len(train_df)*100:.1f}%)")

print("\nClass distribution:")
print(f"Train: {train_df_split['label'].value_counts().to_dict()}")
print(f"Val: {val_df['label'].value_counts().to_dict()}")
print(f"Test: {test_df['label'].value_counts().to_dict()}")

Train set size: 1623 (80.0%)
Validation set size: 203 (10.0%)
Test set size: 203 (10.0%)

Class distribution:
Train: {'violates': 825, 'complies': 798}
Val: {'violates': 103, 'complies': 100}
Test: {'violates': 103, 'complies': 100}
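
The split sizes printed above can be reproduced with plain arithmetic. This is a small sketch that assumes `train_test_split` rounds the held-out count up (ceiling) and gives the remainder to the other side; the helper name `split_sizes` is hypothetical. It also shows why `0.111` is used for the second split: it approximates 1/9, and 1/9 of the remaining 90% is 10% of the total.

```python
import math


def split_sizes(n_total, test_frac=0.1, val_frac_of_rest=0.111):
    """Sizes of a two-stage split: hold out test first, then validation from the rest."""
    n_test = math.ceil(test_frac * n_total)       # assumption: held-out count rounds up
    n_rest = n_total - n_test
    n_val = math.ceil(val_frac_of_rest * n_rest)  # fraction applies to the remaining rows
    n_train = n_rest - n_val
    return n_train, n_val, n_test


n_train, n_val, n_test = split_sizes(2029)
print(n_train, n_val, n_test)

# Share of the full dataset taken by the second-stage fraction (close to 0.10).
val_share_of_total = 0.9 * 0.111
print(round(val_share_of_total, 4))
```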


In [21]:
# Create HuggingFace datasets
train_dataset = Dataset.from_pandas(train_df_split[['text', 'label']])
val_dataset = Dataset.from_pandas(val_df[['text', 'label']])
test_dataset = Dataset.from_pandas(test_df[['text', 'label']])

dataset_dict = DatasetDict({
    "train": train_dataset,
    "valid": val_dataset,
    "test": test_dataset
})

print("\nDataset Dictionary:")
print(dataset_dict)



Dataset Dictionary:
DatasetDict({
    train: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 1623
    })
    valid: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 203
    })
    test: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 203
    })
})


<font color = 'indianred'>*Small subset for initial experimentation*</font>

In [22]:
# Optional: Use subset for experimentation
USE_SUBSET = False  # Set to True for quick testing

if USE_SUBSET:
    subset_size = 500
    dataset_dict["train"] = dataset_dict["train"].shuffle(seed=42).select(range(min(subset_size, len(dataset_dict["train"]))))
    dataset_dict["valid"] = dataset_dict["valid"].shuffle(seed=42).select(range(min(200, len(dataset_dict["valid"]))))
    dataset_dict["test"] = dataset_dict["test"].shuffle(seed=42).select(range(min(200, len(dataset_dict["test"]))))
    print(f"\n⚠ Using subset of data for experimentation:")
    print(dataset_dict)

# <font color = 'indianred'>**4. Load pre-trained Tokenizer**</font>



In [24]:
free_gpu_memory()

GPU memory has been freed.


In [25]:
checkpoint = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [26]:
print(f" Tokenizer loaded: {checkpoint}")
print(f"  Vocab size: {len(tokenizer)}")
print(f"  EOS token: {tokenizer.eos_token}")
print(f"  PAD token: {tokenizer.pad_token}")
print(f"  Padding side: {tokenizer.padding_side}")

 Tokenizer loaded: google/gemma-2-2b
  Vocab size: 256000
  EOS token: <eos>
  PAD token: <pad>
  Padding side: left


# <font color = 'indianred'> **5. Create Completion Dataset**



In [27]:
# A prompt is needed because we are using the language-model head.
def format_prompt_completion(example):
    """
    Format the data as prompt-completion pairs for instruction tuning.

    Prompt: Contains the text and asks for classification
    Completion: Contains just the label ("complies" or "violates")
    """
    prompt = f"""Classify whether the following comment violates the given subreddit rule.
Respond with ONLY one word: either "complies" or "violates".

Comment: {example['text']}

Classification:"""

    completion = f" {example['label']}"

    return {"prompt": prompt, "completion": completion}


In [28]:
# Apply formatting
data_completion = dataset_dict.map(
    format_prompt_completion,
    remove_columns=['text', 'label']
)

Map:   0%|          | 0/1623 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

In [29]:
data_completion

DatasetDict({
    train: Dataset({
        features: ['__index_level_0__', 'prompt', 'completion'],
        num_rows: 1623
    })
    valid: Dataset({
        features: ['__index_level_0__', 'prompt', 'completion'],
        num_rows: 203
    })
    test: Dataset({
        features: ['__index_level_0__', 'prompt', 'completion'],
        num_rows: 203
    })
})

In [30]:
print("PROMPT:")
print(data_completion['train'][0]['prompt'])
print("\nCOMPLETION:")
print(data_completion['train'][0]['completion'])

PROMPT:
Classify whether the following comment violates the given subreddit rule.
Respond with ONLY one word: either "complies" or "violates".

Comment: Comment: So if i make 35k/year for 25 years after which I have 25k debt forgiven, when do I realize the increased tax burden of the forgiven debt as income? During the final year of payment? Or can it be amortized over many years? [SEP] Rule: No legal advice: Do not offer or request legal advice.

Classification:

COMPLETION:
 violates
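
The printed prompt contains `Comment: Comment:` because `create_input_text` already prefixes the body with `Comment:` and the template adds the prefix again. One possible variant, shown here only as a sketch, builds the prompt directly from the raw `body` and `rule` fields. The function name `build_prompt` and its layout are assumptions, not the format used to produce the results in this notebook; switching to it would change the training data.

```python
def build_prompt(body, rule):
    """Hypothetical prompt builder that labels the comment and the rule once each."""
    body = "" if body is None else str(body).strip()
    rule = "" if rule is None else str(rule).strip()
    return (
        "Classify whether the following comment violates the given subreddit rule.\n"
        'Respond with ONLY one word: either "complies" or "violates".\n\n'
        f"Comment: {body}\n"
        f"Rule: {rule}\n\n"
        "Classification:"
    )


example_prompt = build_prompt(
    "  Can I sue my landlord over this?  ",  # illustrative comment text
    "No legal advice: Do not offer or request legal advice.",
)
print(example_prompt)
```

The completion string would stay the same as in the notebook (a leading space followed by the label), so the tokenization of the label is unaffected by this change.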


##  <font color = 'indianred'> **5.1 Filter Longer sequences**

In [31]:
MAX_LENGTH = 1024

# Add 'keep' column to identify sequences that are too long
data_completion = data_completion.map(
    lambda example: {
        'keep': len(tokenizer.encode(example['prompt'] + example['completion'])) <= MAX_LENGTH
    }
)

Map:   0%|          | 0/1623 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

In [32]:
# Filter out sequences that are too long
data_completion = data_completion.filter(lambda x: x['keep'])
data_completion = data_completion.remove_columns(['keep'])

Filter:   0%|          | 0/1623 [00:00<?, ? examples/s]

Filter:   0%|          | 0/203 [00:00<?, ? examples/s]

Filter:   0%|          | 0/203 [00:00<?, ? examples/s]

In [33]:
data_completion

DatasetDict({
    train: Dataset({
        features: ['__index_level_0__', 'prompt', 'completion'],
        num_rows: 1623
    })
    valid: Dataset({
        features: ['__index_level_0__', 'prompt', 'completion'],
        num_rows: 203
    })
    test: Dataset({
        features: ['__index_level_0__', 'prompt', 'completion'],
        num_rows: 203
    })
})

In [34]:
train_filtered = data_completion['train']
valid_filtered = data_completion['valid']
test_filtered = data_completion['test']

##  <font color = 'indianred'> **5.2 Push Dataset to Hub**

In [36]:
PUSH_TO_HUB = True  # Set to True if you want to push to HuggingFace Hub

if PUSH_TO_HUB:

    dataset_repo_id = "Hugging-GK/jigsaw-reddit-base-model-language-head-binary-classification"  # Change this

    try:
        data_completion.push_to_hub(
            dataset_repo_id,
            private=False
        )
    except Exception as e:
        print(f" Error pushing dataset: {e}")

No files have been modified since last commit. Skipping to prevent empty commit.


#  <font color = 'indianred'> **6. Model Training**

##  <font color = 'indianred'> **6.1 Download pre-trained model**

In [37]:
torch_data_type = get_appropriate_dtype()
torch_data_type

torch.bfloat16

In [38]:
# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch_data_type,
    bnb_4bit_quant_storage=torch_data_type,
)
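
To give some intuition for what 4-bit storage means, here is a deliberately simplified sketch of block-wise absmax quantization on a uniform 4-bit grid. It is not the NF4 code book selected by `bnb_4bit_quant_type="nf4"` and it ignores double quantization; the function names and the toy weights are hypothetical. Each block stores small integer codes plus one scale, and values are approximately restored on the fly for computation.

```python
def quantize_block(values, bits=4):
    """Quantize one block to signed integer codes using its absolute maximum as scale."""
    levels = 2 ** (bits - 1) - 1                # 7 usable magnitudes for 4 bits
    scale = max(abs(v) for v in values) or 1.0  # avoid dividing by zero on all-zero blocks
    codes = [round(v / scale * levels) for v in values]
    return codes, scale


def dequantize_block(codes, scale, bits=4):
    """Map integer codes back to approximate floating-point values."""
    levels = 2 ** (bits - 1) - 1
    return [c / levels * scale for c in codes]


# Toy weight block (hypothetical values).
weights = [0.12, -0.40, 0.05, 0.33, -0.07, 0.21, -0.28, 0.02]
codes, scale = quantize_block(weights)
restored = dequantize_block(codes, scale)
max_err = max(abs(w - r) for w, r in zip(weights, restored))

print(codes)
print([round(r, 4) for r in restored])
print(round(max_err, 4))
```

The rounding error per weight is bounded by half a grid step (`scale / 14` here), which is the price paid for storing each weight in 4 bits instead of 16 or 32.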


In [39]:
# Load the model with AutoModelForCausalLM (causal language modeling head).
# The final layer still outputs logits over the full vocabulary.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    torch_dtype=torch_data_type,
    trust_remote_code=True,
)


config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

##  <font color = 'indianred'> **6.2 PEFT Setup**

In [40]:
model

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear4bit(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear4bit(in_features=9216, out_features=2304, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedfor

In [41]:
target_modules = find_linear_layers(model)


['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'lm_head']


In [42]:
TaskType.CAUSAL_LM

<TaskType.CAUSAL_LM: 'CAUSAL_LM'>

In [43]:
# Configure LoRA for Causal Language Modeling
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=128,
    lora_alpha=256,
    lora_dropout=0.01,
    target_modules=['v_proj', 'q_proj', 'up_proj', 'o_proj', 'down_proj', 'gate_proj', 'k_proj']
)
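
The size of the adapters implied by this configuration can be estimated from the layer shapes in the model printout above. The sketch below assumes the standard LoRA parameterization, where a module with `in_features` inputs and `out_features` outputs gains two matrices of shapes `(r, in_features)` and `(out_features, r)`, and that all 26 decoder layers share the printed shapes. The helper name `lora_params` is hypothetical.

```python
# (in_features, out_features) per adapted module, copied from the printed
# Gemma2DecoderLayer; the printout lists 26 identical layers.
MODULE_SHAPES = {
    "q_proj": (2304, 2048),
    "k_proj": (2304, 1024),
    "v_proj": (2304, 1024),
    "o_proj": (2048, 2304),
    "gate_proj": (2304, 9216),
    "up_proj": (2304, 9216),
    "down_proj": (9216, 2304),
}
NUM_LAYERS = 26
LORA_R = 128
LORA_ALPHA = 256


def lora_params(in_features, out_features, r):
    """A is (r x in) and B is (out x r), so one module adds r * (in + out) weights."""
    return r * (in_features + out_features)


adapter_params = NUM_LAYERS * sum(
    lora_params(i, o, LORA_R) for i, o in MODULE_SHAPES.values()
)
frozen_linear_params = NUM_LAYERS * sum(i * o for i, o in MODULE_SHAPES.values())
scaling = LORA_ALPHA / LORA_R  # factor applied to the low-rank update

print(f"Trainable adapter parameters: {adapter_params:,}")
print(f"Frozen linear parameters:     {frozen_linear_params:,}")
print(f"Adapter / frozen ratio:       {adapter_params / frozen_linear_params:.2%}")
print(f"LoRA scaling (alpha / r):     {scaling}")
```

With `r=128` the adapters are a noticeable fraction of the frozen linear weights; a smaller rank would shrink this count proportionally.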


## <font color = 'indianred'> **6.3 Training Arguments**</font>







In [45]:
# Define the directory where model checkpoints will be saved

model_folder = Path("/content/models/gemma_jigsaw_lmh")
# Create the directory if it doesn't exist
model_folder.mkdir(exist_ok=True, parents=True)
run_name = 'jigsaw_binary_gemma_lmh'

use_fp16 = torch_data_type == torch.float16
use_bf16 = torch_data_type == torch.bfloat16

# Configure training parameters
training_args = SFTConfig(
    seed = 42,
    dataset_text_field="text",
    max_length = 1024,
    packing = False,
    completion_only_loss=True,
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": False,
    },
    # Training-specific configurations
    num_train_epochs=3,  # Total number of training epochs
    per_device_train_batch_size=16, # Number of samples per training batch for each device
    per_device_eval_batch_size=16,  # Number of samples per evaluation batch for each device
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant":False},
    torch_empty_cache_steps=20,
    weight_decay=0.01,  # Apply L2 regularization to prevent overfitting
    learning_rate=1e-5,  # Step size for the optimizer during training
    optim='adamw_torch',  # Optimizer,

    # Checkpoint saving and model evaluation settings
    output_dir=str(model_folder),  # Directory to save model checkpoints
    eval_strategy='steps',  # Evaluate model at specified step intervals
    eval_steps=20,  # Perform evaluation every 20 training steps
    save_strategy="steps",  # Save model checkpoint at specified step intervals
    save_steps=20,  # Save a model checkpoint every 20 training steps
    load_best_model_at_end=True,  # Reload the best model at the end of training
    save_total_limit=2,  # Retain only the best and the most recent model checkpoints
    # Use evaluation loss as the metric to determine the best model
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # A model is 'better' if its evaluation loss is lower


    # Experiment logging configurations
    logging_strategy='steps',
    logging_steps=20,
    report_to='wandb',  # Log metrics and results to Weights & Biases platform
    run_name= run_name,  # Experiment name for Weights & Biases

    # Precision settings determined based on GPU capability
    fp16=use_fp16 ,  # Set True if torch_data_type is torch.float16
    bf16=use_bf16,  # Set True if torch_data_type is torch.bfloat16
    tf32=False,  # Disable tf32 unless you want to use Ampere specific optimization
)
#LoRA will be added to linear layers.


In [46]:
# If gradient checkpointing is enabled, configure relevant settings
if training_args.gradient_checkpointing:
    model.config.use_cache = False  # Disable caching for compatibility


In [None]:
# The classes are nearly balanced here (1031 'violates' vs 998 'complies'), so no class-imbalance handling is applied.

##  <font color = 'indianred'> **6.4 Initialize Trainer**</font>



In [47]:
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_filtered,
    eval_dataset=valid_filtered,
    peft_config=peft_config
)

Adding EOS to train dataset:   0%|          | 0/1623 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/1623 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/1623 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/203 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/203 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/203 [00:00<?, ? examples/s]

In [48]:
dataloader = trainer.get_train_dataloader()
batch = next(iter(dataloader))

In [49]:
print(batch['input_ids'][0][0:5])
print(tokenizer.decode(batch['input_ids'][0][0:5]))
print(batch['labels'][0][0:5])

tensor([     2, 212107,   4270,    573,   2412], device='cuda:0')
<bos>Classify whether the following
tensor([-100, -100, -100, -100, -100], device='cuda:0')


In [50]:
batch
# Positions with label id -100 are ignored in the loss calculation.

{'input_ids': tensor([[     2, 212107,   4270,  ...,      0,      0,      0],
         [     2, 212107,   4270,  ...,      0,      0,      0],
         [     2, 212107,   4270,  ...,      0,      0,      0],
         ...,
         [     2, 212107,   4270,  ...,      0,      0,      0],
         [     2, 212107,   4270,  ...,      0,      0,      0],
         [     2, 212107,   4270,  ...,      0,      0,      0]],
        device='cuda:0'),
 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         ...,
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ...

In [51]:
print(len(batch['input_ids'][0]))
print(len(batch['labels'][0]))

152
152


In [53]:
print(f"\nINPUTS")
print(f"{'-'*80}")
print(batch['input_ids'][0][99:114])
print(f"\nLABELS")
print(f"{'-'*80}")
print(batch['labels'][0][99:114])
print(f"\nTokens")
print(f"{'-'*80}")
print(tokenizer.decode(batch['input_ids'][0][99:114]))


INPUTS
--------------------------------------------------------------------------------
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

LABELS
--------------------------------------------------------------------------------
tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100], device='cuda:0')

Tokens
--------------------------------------------------------------------------------
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>


In [54]:
def verify_loss_masking(tokenizer, trainer, num_samples=3):
    """
    Verify which tokens contribute to loss (labels != -100)
    for a few samples from the training dataloader.
    """
    dataloader = trainer.get_train_dataloader()
    batch = next(iter(dataloader))

    for i in range(min(num_samples, len(batch["input_ids"]))):
        input_ids = batch["input_ids"][i]
        labels = batch["labels"][i]

        print(f"\n{'='*80}")
        print(f"Sample {i+1}")
        print(f"{'='*80}")

        # Decode full sequence for reference
        full_text = tokenizer.decode(input_ids, skip_special_tokens=False)
        print(f"\nFull text:\n{full_text}")

        # Identify tokens used for loss
        loss_token_indices = (labels != -100).nonzero(as_tuple=True)[0]

        if len(loss_token_indices) == 0:
            print("All tokens masked; no loss will be calculated.")
            continue

        print(f"\nTokens contributing to loss ({len(loss_token_indices)} total):")
        print(f"{'-'*80}")
        print(f"{'Index':<8} {'Token ID':<10} {'Token Text'}")
        print(f"{'-'*80}")

        for idx in loss_token_indices.tolist():
            token_id = input_ids[idx].item()
            token_text = tokenizer.decode([token_id], skip_special_tokens=False)
            print(f"{idx:<8} {token_id:<10} {repr(token_text)}")

        print(f"{'-'*80}")
        print(f"Percentage of tokens used for loss: {len(loss_token_indices)/len(labels)*100:.2f}%")



In [55]:
# Call this after creating your trainer
verify_loss_masking(tokenizer, trainer, num_samples=2)



Sample 1

Full text:
<bos>Classify whether the following comment violates the given subreddit rule.
Respond with ONLY one word: either "complies" or "violates".

Comment: Comment: Those are horrible methods. Try opiate overdose. Drowning. Hypothermia. [SEP] Rule: No legal advice: Do not offer or request legal advice.

Classification: complies<eos><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

Tokens contributing to loss (2 total):
--------------------------------------------------------------------------------
Index    Token ID   Token Text
--------------------------------------------------------------------------------
70       130636 

## <font color = 'indianred'> **6.5 Setup WandB**</font>

In [None]:
run = wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    name="part_b_language_head",
    job_type="training",
    config=experiment_config,
    tags=["language-head", "text-generation", "gemma-2-2b"],
    notes="Part B: Binary classification using language modeling head"
)

env: WANDB_PROJECT=jigsaw_binary_classification_gemma_lmh


##  <font color = 'indianred'> **6.6 Start Training**</font>

In [56]:
try:
    # Training is the GPU-heavy step most likely to run out of memory.
    # lora_model.config.use_cache = False
    trainer.train()
except RuntimeError as e:
    if 'CUDA out of memory' in str(e):
        print("CUDA out of memory error detected. Freeing GPU memory.")
        free_gpu_memory()
        # Training is not retried automatically; rerun this cell
        # (for example with a smaller batch size) after memory is freed.
    else:
        # Re-raise unrelated errors with their original traceback.
        raise


Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
20,0.5541,0.378592,2.491305,67412.0,0.757867
40,0.3293,0.294662,2.443502,134808.0,0.851399
60,0.2871,0.25941,2.448131,201046.0,0.862107
80,0.2503,0.259428,2.335233,269334.0,0.88243
100,0.247,0.231243,2.354027,337424.0,0.886145
120,0.2009,0.234445,2.40859,404955.0,0.89576
140,0.1973,0.228364,2.400801,472246.0,0.904065
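
In the table above, validation loss is lowest at step 140, which matches the checkpoint selected in the next section. The sketch below picks that step from the logged rows. It assumes the trainer ranks checkpoints by lowest validation loss (the `metric_for_best_model` setting is not shown in this section), and `best_step` is a hypothetical helper.

```python
import csv
import io

# Step and loss columns copied from the training log above.
LOG = """Step,Training Loss,Validation Loss
20,0.5541,0.378592
40,0.3293,0.294662
60,0.2871,0.25941
80,0.2503,0.259428
100,0.247,0.231243
120,0.2009,0.234445
140,0.1973,0.228364
"""


def best_step(log_text, metric="Validation Loss"):
    """Return (step, value) for the row with the lowest value of `metric`."""
    rows = list(csv.DictReader(io.StringIO(log_text)))
    best = min(rows, key=lambda row: float(row[metric]))
    return int(best["Step"]), float(best[metric])


print(best_step(LOG))  # (140, 0.228364)
```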


##  <font color = 'indianred'> **6.7 Push best checkpoint to Hub**</font>

In [57]:
best_model_checkpoint_step = trainer.state.best_model_checkpoint.split('-')[-1]

In [58]:
best_model_checkpoint_step

'140'

In [59]:
checkpoint = str(model_folder/f'checkpoint-{best_model_checkpoint_step}')
checkpoint

'/content/models/gemma_jigsaw_lmh/checkpoint-140'
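
`trainer.state.best_model_checkpoint` defaults to `None` when no best model has been tracked, in which case the `.split('-')` call above raises an `AttributeError`. A slightly more defensive way to extract the step is sketched below. `checkpoint_step` is a hypothetical helper, and it assumes the usual `checkpoint-<step>` directory naming seen in the path above.

```python
import re
from pathlib import Path


def checkpoint_step(checkpoint_path):
    """Return the integer step from a 'checkpoint-<step>' directory, or None."""
    if checkpoint_path is None:
        return None
    match = re.fullmatch(r"checkpoint-(\d+)", Path(checkpoint_path).name)
    return int(match.group(1)) if match else None


print(checkpoint_step("/content/models/gemma_jigsaw_lmh/checkpoint-140"))  # 140
```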

In [60]:
# Step 1: Create the repository
repo_id="Hugging-GK/jigsaw-binary-gemma-lmh"
create_repo(
    repo_id=repo_id,
    repo_type="model",
    private=False,  # Set to True if you want it private
    exist_ok=True   # Won't error if repo already exists
)

# Step 2: Upload the folder
api = HfApi()
api.upload_folder(
    folder_path=checkpoint,
    repo_id=repo_id,
    repo_type="model",
)

CommitInfo(commit_url='https://huggingface.co/Hugging-GK/jigsaw-binary-gemma-lmh/commit/46513cf32affecf689b26f70bd2fcadefaa96f2c', commit_message='Upload folder using huggingface_hub', commit_description='', oid='46513cf32affecf689b26f70bd2fcadefaa96f2c', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Hugging-GK/jigsaw-binary-gemma-lmh', endpoint='https://huggingface.co', repo_type='model', repo_id='Hugging-GK/jigsaw-binary-gemma-lmh'), pr_revision=None, pr_num=None)

In [61]:
# Log example predictions during validation
def log_validation_examples(trainer, tokenizer, dataset, num_examples=5):
    """Log example predictions to WandB"""
    examples = []
    for i in range(min(num_examples, len(dataset))):
        example = dataset[i]
        prompt = example['prompt']
        true_completion = example['completion']

        # Generate prediction
        inputs = tokenizer(prompt, return_tensors="pt").to(trainer.model.device)
        with torch.no_grad():
            outputs = trainer.model.generate(**inputs, max_new_tokens=10)
        predicted = tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)

        examples.append({
            "prompt": prompt[:100] + "...",
            "true": true_completion.strip(),
            "predicted": predicted.strip(),
            "correct": predicted.strip() == true_completion.strip()
        })

    wandb.log({"validation_examples": wandb.Table(
        columns=["prompt", "true", "predicted", "correct"],
        data=[[ex["prompt"], ex["true"], ex["predicted"], ex["correct"]] for ex in examples]
    )})

# Call after evaluation
log_validation_examples(trainer, tokenizer, valid_filtered, num_examples=5)
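
The exact string comparison in `log_validation_examples` marks a generation such as `"complies."` or `"Violates\n"` as incorrect, and with `max_new_tokens=10` the model may keep generating past the label. Below is a hedged sketch of a more forgiving comparison, assuming the two labels named in the prompt. `normalize_label` and `accuracy` are hypothetical helpers, not part of the notebook.

```python
LABELS = ("complies", "violates")


def normalize_label(text):
    """Map generated text to one of LABELS, or None if it starts with neither."""
    cleaned = text.strip().lower()
    for label in LABELS:
        if cleaned.startswith(label):
            return label
    return None


def accuracy(predictions, references):
    """Fraction of predictions whose normalized label matches the reference."""
    pairs = list(zip(predictions, references))
    if not pairs:
        return 0.0
    correct = 0
    for predicted, reference in pairs:
        label = normalize_label(predicted)
        if label is not None and label == normalize_label(reference):
            correct += 1
    return correct / len(pairs)


# Illustrative strings only, not model outputs from this run.
preds = ["complies", " Violates.\n", "I think it complies", "violates"]
refs = ["complies", "violates", "complies", "complies"]
print(accuracy(preds, refs))  # 0.5
```

Generations that do not begin with either label are counted as wrong rather than guessed, which keeps the metric conservative.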

In [62]:
wandb.finish()

Run history:
eval/entropy,█▆▆▁▂▄▄
eval/loss,█▄▂▂▁▁▁
eval/mean_token_accuracy,▁▅▆▇▇██
eval/num_tokens,▁▂▃▄▆▇█
eval/runtime,▇▄▄▄█▁▂
eval/samples_per_second,▂▅▅▅▁█▇
eval/steps_per_second,▂▅▅▅▁█▇
train/entropy,█▄▄▂▁▂▃
train/epoch,▁▁▂▂▃▃▄▄▅▅▆▆▇▇█
train/global_step,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██

Run summary:
eval/entropy,2.4008
eval/loss,0.22836
eval/mean_token_accuracy,0.90406
eval/num_tokens,472246
eval/runtime,2.645
eval/samples_per_second,76.748
eval/steps_per_second,4.915
total_flos,1.1155328329314816e+16
train/entropy,2.38119
train/epoch,3


In [63]:
free_gpu_memory()

GPU memory has been freed.
