In [None]:
!pip install -q transformers datasets peft accelerate bitsandbytes pandas scikit-learn sentencepiece


In [None]:
!pip install trl accelerate transformers datasets




In [None]:
!pip install bitsandbytes




# Milestone 1: Model Finetuning and Alignment

## 1. Model Selection Rationale
For this project, we selected **`google/flan-t5-base`** as our foundation model. This decision was driven by the following technical and practical constraints:

* **Hardware Limits:**
    The `flan-t5-base` model has approximately **250 million parameters**. This size is a good fit for the **Google Colab (T4 GPU)** environment available to us. Larger models (like Llama-2-7b or Mistral) would require aggressive 4-bit quantization and could still run out of memory (OOM) during training, even with gradient checkpointing enabled, whereas T5 fits comfortably in memory.

* **Time Constraints:**
    As an Encoder-Decoder architecture, T5 is highly efficient for Sequence-to-Sequence tasks. Its smaller footprint allows for **faster training epochs** and quicker inference times compared to Decoder-only LLMs. This enables rapid iteration and debugging within the tight deadlines of the graduation project milestones.

* **Dataset Size:**
    Our initial dataset for this milestone is small (a 10K-example sample of the source dataset, reduced further by filtering). T5 is known for its **sample efficiency** and ability to generalize well on supervised fine-tuning tasks without requiring the massive datasets typically needed to align larger LLMs.

* **Expected Difficulty:**
    The task is to generate SVG markup from a short text description (text-to-SVG), i.e. both the input and the output are plain text. T5 is natively pre-trained on a "text-to-text" framework, making it naturally aligned with this objective. This reduces the complexity of the fine-tuning process compared to adapting a general-purpose causal language model.
  

In [None]:
import os
import pandas as pd
import torch
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)


In [None]:
# Load & Prepare HuggingFace SVG Dataset
import pandas as pd
from datasets import load_dataset, Dataset

# ---------------------------------------------------
# Choose your HuggingFace dataset
# ---------------------------------------------------
HF_DATASET = "starvector/text2svg-stack"

print(f"Loading dataset: {HF_DATASET} ...")
raw_ds = load_dataset(HF_DATASET)

# Dataset is usually split into only "train".
# to_pandas() converts the underlying table directly instead of building the
# frame from ~2M Python row dicts, which is what pd.DataFrame(dataset) does.
df = raw_ds["train"].to_pandas()

print("Original dataset size:", len(df))

# ===================================================
# SAMPLING STEP: Reduce dataset to 10K examples
# ===================================================
SAMPLE_SIZE = 10000

if len(df) > SAMPLE_SIZE:
    # Fixed seed so the same 10K rows are drawn on every run
    df = df.sample(n=SAMPLE_SIZE, random_state=42).reset_index(drop=True)
    print(f"✅ DATASET SAMPLED! Now using {len(df)} examples for faster training.")
else:
    print(f"Dataset is smaller than {SAMPLE_SIZE}, using all examples.")

# ---------------------------------------------------
# Detect prompt + svg columns automatically
# ---------------------------------------------------
prompt_candidates = ["prompt", "text", "description", "caption", "instruction", "caption_blip2", "caption_cogvlm", "caption_llava"]
svg_candidates    = ["svg", "svg_code", "completion", "target", "code"]

prompt_col = None
svg_col = None

# NOTE(review): when several columns match (e.g. multiple caption_* columns),
# the LAST matching column in df.columns wins, not the first entry of the
# candidate list. Kept as-is so the selected training text does not change;
# confirm this is the intended caption source.
for col in df.columns:
    if col.lower() in prompt_candidates:
        prompt_col = col
    if col.lower() in svg_candidates:
        svg_col = col

if prompt_col is None:
    raise ValueError("ERROR: No prompt/text column found in HuggingFace dataset.")
if svg_col is None:
    raise ValueError("ERROR: No SVG column found in HuggingFace dataset.")

# Rename to standard names
df = df.rename(columns={prompt_col: "prompt", svg_col: "response"})

# Drop rows with a missing prompt or SVG *before* the string conversion below:
# astype(str) turns a missing value into literal text ("nan"/"None"), which a
# later dropna() can no longer detect.
df = df.dropna(subset=["prompt", "response"])

# ---------------------------------------------------
# Clean SVG text
# ---------------------------------------------------
df["response"] = df["response"].astype(str)
df["response"] = df["response"].str.replace("\n", " ", regex=False)
df["response"] = df["response"].str.replace("\t", " ", regex=False)

# Remove extremely long SVG.
# .copy() gives an independent frame so the column assignment below does not
# write into a filtered view (SettingWithCopyWarning).
df = df[df["response"].str.len() < 2000].copy()

# Create instruction for T5
df["instruction"] = "Generate an SVG image for this description: " + df["prompt"]

# Final DataFrame (Before Advanced Processing)
df = df[["instruction", "response"]].dropna()

print("Ready for Advanced Preprocessing. Current Size:", len(df))

Loading dataset: starvector/text2svg-stack ...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Original dataset size: 2169710
✅ DATASET SAMPLED! Now using 10000 examples for faster training.
Ready for Advanced Preprocessing. Current Size: 6146
Loading dataset: starvector/text2svg-stack ...
Original dataset size: 2169710
✅ DATASET SAMPLED! Now using 10000 examples for faster training.
Ready for Advanced Preprocessing. Current Size: 6146


In [None]:
# ==========================================
# Advanced Preprocessing & Data Augmentation
# (Implements the logic from your uploaded image)
# ==========================================

# 1. Install necessary library for text augmentation
!pip install -q nlpaug nltk

import xml.etree.ElementTree as ET
import nlpaug.augmenter.word as naw
import nltk
import random

# Download NLTK data for augmentation.
# quiet=True keeps the per-package download log out of the cell output.
for _nltk_resource in (
    'wordnet',
    'omw-1.4',
    'averaged_perceptron_tagger',
    # NOTE(review): recent NLTK releases appear to look the English POS tagger
    # up under this name instead of 'averaged_perceptron_tagger'. The repeated
    # "Downloading package averaged_perceptron_tagger ... already up-to-date"
    # lines in this cell's saved output suggest the tagger lookup failed on
    # every augmentation call. Confirm against the installed NLTK version.
    'averaged_perceptron_tagger_eng',
):
    nltk.download(_nltk_resource, quiet=True)

# Initialize Augmenter (Synonym Replacement)
aug = naw.SynonymAug(aug_src='wordnet')

# Register the SVG and XLink namespaces so ElementTree serializes elements as
# "<svg ...>" / "xlink:href" instead of inventing "ns0:" / "ns1:" prefixes
# (which would make every training target start with "<ns0:svg").
ET.register_namespace("", "http://www.w3.org/2000/svg")
ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")


def process_svg_data(row, max_elements_for_primitive=10):
    """
    Validate, standardize, and measure the complexity of one SVG.

    Args:
        row: Mapping-like object (DataFrame row or dict) with a 'response'
            entry holding the SVG source text.
        max_elements_for_primitive: Largest number of direct children of the
            root element for which the SVG is still flagged as a primitive.

    Returns:
        A (standardized_svg, is_primitive) tuple, or (None, None) when the
        value is not a string or cannot be parsed as XML.
    """
    svg_code = row['response']

    # --- 1. Filter out invalid SVGs ---
    # Non-string values (e.g. None/NaN) cannot be parsed; treat them as invalid
    # instead of letting ET.fromstring raise TypeError.
    if not isinstance(svg_code, str):
        return None, None

    try:
        root = ET.fromstring(svg_code)
    except ET.ParseError:
        return None, None  # Invalid SVG

    # --- 2. Standardize attributes ---
    # Any element that declares a stroke but no stroke-width gets an explicit
    # stroke-width of "1".
    for elem in root.iter():
        if 'stroke' in elem.attrib and 'stroke-width' not in elem.attrib:
            elem.set('stroke-width', "1")

    standardized_svg = ET.tostring(root, encoding='unicode')

    # --- 3. Complexity check (primitives) ---
    # Complexity is the number of DIRECT children of the root element; nested
    # descendants are not counted.
    num_elements = len(root)
    is_primitive = num_elements <= max_elements_for_primitive

    return standardized_svg, is_primitive

# --- Apply Logic to DataFrame ---
print("Applying advanced filtering and standardization...")

# Apply SVG processing: column 0 is the standardized SVG (None when invalid),
# column 1 is the is_primitive flag.
processed_data = df.apply(process_svg_data, axis=1, result_type='expand')
df['response'] = processed_data[0]
df['is_primitive'] = processed_data[1]

# Drop Invalid SVGs (Rows where response became None)
original_len = len(df)
df = df.dropna(subset=['response'])
print(f"Removed {original_len - len(df)} invalid SVGs.")

# --- 3. Augment text prompts via paraphrasing ---
# To increase robustness, a portion of the dataset is duplicated with
# paraphrased prompts.
# NOTE(review): these duplicates share their SVG target with an original row
# and are added before the train/eval split, so the same target can land in
# both splits. Consider splitting first and augmenting only the training part.
print("Augmenting text prompts...")

# Select a fraction of data to augment (e.g., 10% to save time, or primitive only)
augment_fraction = 0.1
# Seeded so the augmented subset is the same on every run
augment_df = df.sample(frac=augment_fraction, random_state=42).copy()

def augment_text(text, augmenter=None):
    """
    Paraphrase a prompt, falling back to the original text on failure.

    Args:
        text: Prompt string to paraphrase.
        augmenter: Object exposing ``augment(text)``. Defaults to the
            module-level ``aug`` instance when omitted.

    Returns:
        The paraphrased string. ``text`` is returned unchanged if the augmenter
        raises or returns an empty list.
    """
    import warnings

    if augmenter is None:
        augmenter = aug

    try:
        augmented_text = augmenter.augment(text)
    except Exception as err:
        # Best effort: keep the original prompt, but surface the failure so a
        # broken augmenter does not silently produce exact duplicates.
        warnings.warn(
            f"Text augmentation failed ({type(err).__name__}: {err}); keeping original text."
        )
        return text

    # augment() may return a list of candidates; use the first one
    if isinstance(augmented_text, list):
        return augmented_text[0] if augmented_text else text
    return augmented_text

# Paraphrase only the prompt part of each sampled instruction.
# Instructions have the form "<prefix>{prompt}": strip the prefix, run the
# remaining prompt through augment_text, then put the prefix back.
instruction_prefix = "Generate an SVG image for this description: "
bare_prompts = augment_df['instruction'].apply(lambda text: text.replace(instruction_prefix, ""))
paraphrased_prompts = bare_prompts.apply(augment_text)
augment_df['instruction'] = instruction_prefix + paraphrased_prompts

# Append the paraphrased rows after the original rows
df = pd.concat([df, augment_df]).reset_index(drop=True)

print(f"Dataset size after augmentation: {len(df)}")

# --- Filter Option: Use primitives first? ---
# Uncomment the line below if you want to TRAIN ONLY on simple shapes (Primitives) first
# df = df[df['is_primitive'] == True]

print("Sample after processing:")
print(df.head(1))

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/410.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m409.6/410.5 kB[0m [31m17.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m410.5/410.5 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[?25h

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


Applying advanced filtering and standardization...
Removed 11 invalid SVGs.
Augmenting text prompts...


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger t

Dataset size after augmentation: 6749
Sample after processing:
                                         instruction  \
0  Generate an SVG image for this description: A ...   

                                            response is_primitive  
0  <ns0:svg xmlns:ns0="http://www.w3.org/2000/svg...         True  


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger t

Applying advanced filtering and standardization...
Removed 11 invalid SVGs.
Augmenting text prompts...


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger t

Dataset size after augmentation: 6749
Sample after processing:
                                         instruction  \
0  Generate an SVG image for this description: A ...   

                                            response is_primitive  
0  <ns0:svg xmlns:ns0="http://www.w3.org/2000/svg...         True  


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger t

In [None]:
# ==========================================
# Final Split & Tokenization (Run AFTER Advanced Preprocessing)
# ==========================================
from sklearn.model_selection import train_test_split
from datasets import Dataset

# 1. Split the CLEAN and AUGMENTED dataframe
# We use a small test size because we want max data for training.
# NOTE(review): df already contains the paraphrased copies appended by the
# augmentation step, so an eval row may share its SVG target with a training
# row. Confirm whether that overlap is acceptable for evaluation.
train_df, eval_df = train_test_split(df, test_size=0.05, random_state=42)

# 2. Convert back to HuggingFace Dataset format
# preserve_index=False stops the shuffled pandas index from being stored as an
# extra "__index_level_0__" column in the datasets.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
eval_dataset  = Dataset.from_pandas(eval_df, preserve_index=False)

print(f"Final Training Samples:   {len(train_dataset)}")
print(f"Final Evaluation Samples: {len(eval_dataset)}")

# --- Tokenization Logic (Previously likely inside the SFT cell) ---
# It's better to verify tokenization works here before training

# Load Tokenizer (Ensure it matches your model)
from transformers import AutoTokenizer
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess_function(examples):
    """
    Tokenize a batch of instruction/response pairs for seq2seq training.

    Instructions are padded/truncated to 128 tokens and responses to 256
    tokens. Padding positions in the labels are replaced by -100 so they are
    ignored by the loss.

    Args:
        examples: Batch mapping with "instruction" and "response" lists.

    Returns:
        The tokenized inputs with an added "labels" entry.
    """
    pad_id = tokenizer.pad_token_id

    model_inputs = tokenizer(
        examples["instruction"],
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    target_encoding = tokenizer(
        examples["response"],
        max_length=256,  # Increased length for SVG
        truncation=True,
        padding="max_length",
    )

    masked_labels = []
    for token_ids in target_encoding["input_ids"]:
        masked_labels.append(
            [-100 if token_id == pad_id else token_id for token_id in token_ids]
        )

    model_inputs["labels"] = masked_labels
    return model_inputs

print("Tokenizing datasets...")
# Tokenize the training split first, then the evaluation split
tokenized_train, tokenized_eval = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, eval_dataset)
)

print("✅ Data is fully ready for SFT Trainer!")

Final Training Samples:   6411
Final Evaluation Samples: 338


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Tokenizing datasets...


Map:   0%|          | 0/6411 [00:00<?, ? examples/s]

Map:   0%|          | 0/338 [00:00<?, ? examples/s]

✅ Data is fully ready for SFT Trainer!


In [None]:
# 1. Load Pre-trained Model and Tokenizer with 4-bit Quantization
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BitsAndBytesConfig
import torch

# Use the GPU when one is visible, otherwise the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

model_name = "google/flan-t5-base"

# 4-bit quantization settings
quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.float16,
}
bnb_config = BitsAndBytesConfig(**quantization_options)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the model with quantization enabled
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,  # enables the 4-bit compression configured above
)

print("\n--- Baseline Evaluation (Zero-Shot) ---")

# 2. Run inference on a few evaluation samples
model.eval()

# Take at most 3 samples so a tiny eval set cannot cause an IndexError
num_samples = min(len(eval_dataset), 3)
test_samples = eval_dataset.select(range(num_samples))

separator = "-" * 40
for example_number, sample in enumerate(test_samples, start=1):
    input_text, target_text = sample["instruction"], sample["response"]

    # Tokenize the prompt and move the tensors to the selected device
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Generate without tracking gradients
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=50)

    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"\n[Example {example_number}]")
    print(f"Prompt: {input_text}")
    print(f"Ground Truth: {target_text}")
    print(f"Model Prediction: {prediction}")
    print(separator)

Using device: cuda


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]


--- Baseline Evaluation (Zero-Shot) ---

[Example 1]
Prompt: Generate an SVG image for this description: A yellow circle with a dollar sign in the center.
Ground Truth: <ns0:svg xmlns:ns0="http://www.w3.org/2000/svg" xmlns:ns1="http://www.w3.org/1999/xlink" width="32" height="32" viewBox="0 0 32 32">   <ns0:image id="_2223442" data-name="2223442" width="32" height="32" ns1:href="data:img/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAC30lEQVRYhc2XT0hXQRDHP24SCRIFRadECAmSQlgkEvrjJcx8SRESBhF1DOsQFNS1OngoMjrWrQ4hSK2V0EHK6GIvJDCICsROkWCIlZS+YvrNs+2n5a4p+D3te29nvt83uzM7W0IgsjRZCTQB9UANUAmsUutPwBAwAPQC3ca6sRDPcwrI0mQjcAZoBVYE6p0AbgPtxrrX8xKQpUkZcBFoA0oDiYsxBXQA5411X4MFZGlSBXQB1fMkLsYgsN9Y92ZOAVmaWKAHWLNA5DlGgAZjXfpXAfrnzxaB3BdR50fCeORlGvZQ8s/ATmAHMB5oI767lOtPAbrhYtb8hbHuibGuD3gaYVetXL8FaKqdjHAi+X7Je74MvIywb1PO6QhIni8LNL4HbDXW9eQvjHWPgFrgbqCPUuWkRCvch8AiMwpsMNaNZmkignfrRu4x1mXq6y2wNsCXFKt1RstraIXrEnId3wAeAPeBWxQiMaYVMATC2WS0todi2JvX4o0PZWlSruMZxeYfqDd6sIRivTfvYZHNqSxN5PvBCH81Rk+1U

In [None]:
# Inspect the labels of the first tokenized training example
print("--- Debug Data ---")
sample_labels = tokenized_train[0]['labels']
print(f"Sample Labels (First 20 tokens): {sample_labels[:20]}")

# A healthy example contains real token ids followed by -100 padding.
# If every position is -100, the loss ignores the whole target.
has_unmasked_label = any(token_id != -100 for token_id in sample_labels)
if has_unmasked_label:
    print("✅ Data looks good. The issue is likely FP16 or Learning Rate.")
else:
    print("🚨 ERROR: All labels are masked (-100). The model has nothing to learn!")

--- Debug Data ---
Sample Labels (First 20 tokens): [3, 2, 29, 7, 632, 10, 7, 208, 122, 3, 226, 51, 40, 29, 7, 10, 29, 7, 632, 17592]
✅ Data looks good. The issue is likely FP16 or Learning Rate.


In [None]:
# PEFT - Supervised Fine-Tuning (SFT) - HIGH PERFORMANCE & QUANTIZED
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training

# ------------------------------------------------------------------
# 1. Reload Base Model & Tokenizer (Optimized 4-bit Loading)
# ------------------------------------------------------------------
model_name = "google/flan-t5-base"

# 4-bit quantization configuration to reduce memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the quantized base model, then prepare it for k-bit training
# (the preparation step is required before attaching LoRA adapters for QLoRA)
quantized_base = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)
kbit_ready_model = prepare_model_for_kbit_training(quantized_base)

# ------------------------------------------------------------------
# 2. LoRA Configuration (Boosted for Better Accuracy)
# ------------------------------------------------------------------
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=32,                                 # adapter rank
    lora_alpha=64,                        # scaling factor
    lora_dropout=0.05,
    bias="none",
    target_modules=["q", "v", "k", "o"],  # attention projections that receive adapters
)

# Wrap the prepared model with the LoRA adapters and report trainable parameters
model = get_peft_model(kbit_ready_model, lora_config)
model.print_trainable_parameters()

# ------------------------------------------------------------------
# 3. Preprocessing Function
# ------------------------------------------------------------------
def preprocess_function(examples):
    """
    Tokenize a batch of instruction/response pairs for the SFT trainer.

    Both instructions and responses are padded/truncated to 128 tokens, and
    padding positions in the labels are set to -100 so the loss skips them.

    NOTE(review): the earlier definition of this function in the notebook
    tokenizes responses with max_length=256; this one uses 128. Confirm which
    target length is intended.

    Args:
        examples: Batch mapping with "instruction" and "response" lists.

    Returns:
        The tokenized inputs with an added "labels" entry.
    """
    shared_options = dict(max_length=128, truncation=True, padding="max_length")

    model_inputs = tokenizer(examples["instruction"], **shared_options)
    encoded_targets = tokenizer(examples["response"], **shared_options)

    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in token_ids]
        for token_ids in encoded_targets["input_ids"]
    ]
    return model_inputs

# Map and Remove Columns
# The raw text columns are dropped so only tokenized fields reach the collator.
tokenized_train = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names
)
tokenized_eval = eval_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=eval_dataset.column_names
)

# ------------------------------------------------------------------
# 4. Training Arguments (Tuned for Convergence)
# ------------------------------------------------------------------
training_args = Seq2SeqTrainingArguments(
    output_dir="./sft_output_v2",
    learning_rate=3e-4,             # Slightly higher LR for T5 with LoRA
    per_device_train_batch_size=4,  # Lower batch size to save VRAM (since Rank is higher)
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # Effective Batch Size = 16 (More stable updates)
    num_train_epochs=15,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    predict_with_generate=True,
    logging_steps=10,
    fp16=False,                     # CRITICAL: Must be False for T5 stability
    bf16=False,
    optim="paged_adamw_8bit",       # Memory efficient optimizer
    lr_scheduler_type="cosine",     # Helps lower loss at the end of training
    warmup_ratio=0.05,              # Smooth start
    remove_unused_columns=False,
    # Disable external experiment trackers: without this, training stops at an
    # interactive Weights & Biases prompt, which blocks unattended "Run All".
    report_to="none",
)

# ------------------------------------------------------------------
# 5. Initialize Trainer & Start Training
# ------------------------------------------------------------------
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=data_collator,
    processing_class=tokenizer,
)

print("Starting Optimized SFT Training...")
trainer.train()

# The adapter is only written here, after train() returns. If training is
# interrupted, "./sft_final_model" will not exist for the later cells.
trainer.save_model("./sft_final_model")
print("✅ Optimized SFT Model saved successfully.")

trainable params: 1,769,472 || all params: 249,347,328 || trainable%: 0.7096


Map:   0%|          | 0/6411 [00:00<?, ? examples/s]

Map:   0%|          | 0/338 [00:00<?, ? examples/s]

Starting Optimized SFT Training...


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 3


[34m[1mwandb[0m: You chose "Don't visualize my results"


  return fn(*args, **kwargs)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Epoch,Training Loss,Validation Loss
1,1.6862,1.35165
2,1.592,1.314817
3,1.6078,1.368079


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


KeyboardInterrupt: 

In [None]:
# Debug & Evaluate SFT Model
from peft import PeftModel, PeftConfig

print("\n--- Evaluating SFT Model (Pre-Alignment) ---")

# 1. Load a fresh copy of the base model and its tokenizer
model_name = "google/flan-t5-base"
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 2. Attach the SFT adapter saved by the training cell
sft_model_path = "./sft_final_model"
model_sft = PeftModel.from_pretrained(base_model, sft_model_path)
model_sft.eval()

# 3. Run inference on up to 3 evaluation samples
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
num_samples = min(len(eval_dataset), 3)
test_samples = eval_dataset.select(range(num_samples))

separator = "-" * 40
for example_number, sample in enumerate(test_samples, start=1):
    input_text, ground_truth = sample["instruction"], sample["response"]

    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model_sft.generate(**inputs, max_new_tokens=50)

    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"\n[Example {example_number}]")
    print(f"Prompt: {input_text}")
    print(f"Ground Truth: {ground_truth}")
    print(f"SFT Prediction: {prediction}")
    print(separator)

In [None]:
# RLHF - Direct Preference Optimization (DPO) - Corrected
from trl import DPOTrainer, DPOConfig

# 1. Create Preference Dataset (Chosen vs Rejected)
def create_preference_data(example):
    """
    Build a (prompt, chosen, rejected) preference record from one example.

    The chosen answer is the reference response. The rejected answer is a
    synthetic corruption: the first half of the response when it is longer
    than 10 characters, otherwise the fixed text "Incorrect description".

    Args:
        example: Mapping with "instruction" and "response" strings.

    Returns:
        Dict with "prompt", "chosen" and "rejected" keys.
    """
    reference = example["response"]

    if len(reference) > 10:
        # Cutting the SVG source in half leaves it incomplete/invalid
        corrupted = reference[: len(reference) // 2]
    else:
        corrupted = "Incorrect description"

    return dict(prompt=example["instruction"], chosen=reference, rejected=corrupted)

# Apply the function to the dataset (adds prompt / chosen / rejected columns)
dpo_dataset = train_dataset.map(create_preference_data)

# 2. DPO Configuration
dpo_config = DPOConfig(
    output_dir="./dpo_output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,          # Lower learning rate for alignment
    num_train_epochs=1,
    beta=0.1,                    # KL Divergence penalty
    logging_steps=10,
    remove_unused_columns=False,
    # Disable external experiment trackers so training does not stop at an
    # interactive Weights & Biases prompt during unattended runs.
    report_to="none",
)

# 3. Initialize DPO Trainer
# NOTE(review): `model` is whatever the most recently executed cell bound to
# that name; this is presumably the LoRA-wrapped SFT model. Verify before running.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,              # No explicit reference model; TRL derives one itself
    args=dpo_config,
    train_dataset=dpo_dataset,
    processing_class=tokenizer,  # Use processing_class instead of tokenizer to avoid warnings
)

# 4. Start DPO Training
print("Starting DPO Alignment...")
dpo_trainer.train()

# 5. Save Aligned Model
dpo_trainer.save_model("./final_aligned_model")
print("Aligned Model saved successfully.")

In [None]:
# ==============================================================================
# FINAL ROBUST EVALUATION (CPU MODE)
# Solves CUDA errors completely by running on CPU (Model is small enough)
# ==============================================================================
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from IPython.display import display, SVG

print("🚀 Starting CPU Evaluation (Safe Mode)...")

# 1. Load Base Model (Full Precision - No Quantization)
model_name = "google/flan-t5-base"
print(f"🔄 Loading Base Model: {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="cpu" ensures we don't touch the GPU
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="cpu")

# 2. Load Adapter
# Try DPO first, fall back to SFT, and finally to the plain base model
adapter_path = "./final_aligned_model"
fallback_path = "./sft_final_model"

try:
    print(f"🔄 Loading DPO Adapter from: {adapter_path}...")
    model = PeftModel.from_pretrained(base_model, adapter_path)
    print("✅ DPO Adapter Loaded Successfully!")
except Exception as dpo_error:
    print(f"⚠️ Could not load DPO model: {dpo_error}")
    print(f"🔄 Falling back to SFT Adapter: {fallback_path}...")
    try:
        model = PeftModel.from_pretrained(base_model, fallback_path)
        print("✅ SFT Adapter Loaded Successfully!")
    except Exception as sft_error:
        # Report why the SFT adapter failed too, so the user knows the
        # results below come from the un-tuned base model and why.
        print(f"❌ Could not load any adapter ({sft_error}). Using Base Model.")
        model = base_model

# 3. Switch to Evaluation Mode
model.eval()

# 4. Test Prompts
test_prompts = [
    "A red circle.",
    "A blue square with a black border.",
    "A simple house icon with a triangle roof."
]

def clean_svg_code(code):
    """Remove markdown code-fence markers from generated text and trim whitespace."""
    # Language-tagged fences are removed before the bare fence so that no
    # stray "xml"/"svg" tag is left behind.
    cleaned = code
    for fence in ("```xml", "```svg", "```"):
        cleaned = cleaned.replace(fence, "")
    return cleaned.strip()

# 5. Generation Loop
print("\n🎨 Generating Images...")

for case_number, prompt in enumerate(test_prompts, start=1):
    print(f"\n🔹 [Test Case {case_number}]")
    print(f"📝 Prompt: {prompt}")

    # Build the instruction and tokenize it on the CPU
    input_text = "Generate an SVG image for this description: " + prompt
    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")

    # Greedy decoding (do_sample=False) on the CPU
    generation_options = dict(
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.1,
    )
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_options)

    svg_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    clean_svg = clean_svg_code(svg_code)

    print("💻 Code Snippet:")
    print(clean_svg[:100] + "...")

    # Render only when the text contains an <svg tag; report rendering failures
    try:
        if "<svg" not in clean_svg:
            print("⚠️ Output does not look like a valid SVG.")
        else:
            display(SVG(clean_svg))
    except Exception as e:
        print(f"❌ Rendering Error: {e}")

print("\n✅ Evaluation Finished!")

In [None]:
# ==============================================================================
# Milestone Requirement: Inference Settings Implementation
# (Temperature, Top-K, Top-P, Max Length) with Interactive UI
# ==============================================================================

import ipywidgets as widgets
from IPython.display import display, clear_output
import torch

# 1. Put the model in evaluation mode before interactive use
model.eval()

# 2. Inference function driven by the sampling settings
def generate_svg_with_settings(prompt, temperature, top_k, top_p, max_length):
    """
    Generate SVG text for a prompt using the given sampling settings.

    Args:
        prompt: Plain-text description of the image.
        temperature: Sampling temperature (converted to float).
        top_k: Number of highest-probability tokens kept (converted to int).
        top_p: Nucleus-sampling probability mass (converted to float).
        max_length: Maximum number of NEW tokens to generate (converted to int).

    Returns:
        The decoded model output with special tokens removed.
    """
    input_text = "Generate an SVG image for this description: " + prompt
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    sampling_options = dict(
        max_new_tokens=int(max_length),
        do_sample=True,                 # required for temperature / top-k / top-p to apply
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,         # discourages repeated SVG tags
    )

    with torch.no_grad():
        outputs = model.generate(**inputs, **sampling_options)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# 3. Build the interactive controls
print("⬇️ --- INTERACTIVE INFERENCE TOOL --- ⬇️")

# Free-text prompt box
text_input = widgets.Textarea(
    description='Prompt:',
    value='A red circle.',
    placeholder='Type your description here...',
    layout=widgets.Layout(width='80%', height='60px'),
)

# One slider per sampling hyperparameter
slider_temp = widgets.FloatSlider(
    description='Temperature:', value=0.7, min=0.1, max=1.5, step=0.1, readout_format='.1f'
)
slider_top_k = widgets.IntSlider(
    description='Top-K:', value=50, min=1, max=100, step=1
)
slider_top_p = widgets.FloatSlider(
    description='Top-P:', value=0.9, min=0.1, max=1.0, step=0.05
)
slider_max_len = widgets.IntSlider(
    description='Max Length:', value=512, min=64, max=1024, step=32
)

# Trigger button and the area that receives generated output
button = widgets.Button(description="Generate SVG", button_style='primary')
output_area = widgets.Output()

def on_button_click(b):
    """
    Button callback: generate SVG text from the current widget values, print
    it, and render it when it contains an <svg tag.

    Args:
        b: The clicked button (unused).
    """
    with output_area:
        clear_output()
        print("Generating... Please wait.")
        try:
            svg_code = generate_svg_with_settings(
                prompt=text_input.value,
                temperature=slider_temp.value,
                top_k=slider_top_k.value,
                top_p=slider_top_p.value,
                max_length=slider_max_len.value,
            )
            print("\n--- Generated SVG Code ---")
            print(svg_code)

            # Optional rendering of the generated markup
            from IPython.display import SVG, display
            try:
                # Strip markdown fences (tagged fences first) before rendering
                clean_svg = svg_code
                for fence in ("```xml", "```svg", "```"):
                    clean_svg = clean_svg.replace(fence, "")
                clean_svg = clean_svg.strip()

                if "<svg" not in clean_svg:
                    print("\n(Output does not look like valid SVG to render directly)")
                else:
                    display(SVG(clean_svg))
            except Exception as e:
                print(f"Could not render image: {e}")

        except Exception as e:
            print(f"Error: {e}")

# Wire the callback to the button
button.on_click(on_button_click)

# Display UI
display(text_input, slider_temp, slider_top_k, slider_top_p, slider_max_len, button, output_area)