In [None]:
# 1. Install Mamba Kernels (Critical for FalconMamba)
!pip install "causal-conv1d>=1.4.0"
!pip install mamba-ssm --no-build-isolation

# 2. Install Hugging Face & Training Tools
!pip install -q torch transformers peft datasets bitsandbytes trl accelerate

# 3. Install Tools for SSR (Clustering & Metrics)
!pip install -q scikit-learn sentence-transformers evaluate rouge_score absl-py

import os
import gc
import torch
import json
import requests
import numpy as np
from datasets import Dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from trl import SFTTrainer, SFTConfig
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# --- GLOBAL CONFIG ---
MODEL_ID = "tiiuae/Falcon3-Mamba-7B-Base"

# Optimisation: per-device micro-batch, with gradients accumulated over
# GRADIENT_ACCUMULATION steps before each optimizer update.
MICRO_BATCH_SIZE = 4
GRADIENT_ACCUMULATION = 8
LEARNING_RATE = 2e-4
EPOCHS = 3

# Self-Synthesized Rehearsal sizing: the rehearsal set per task is SSR_RATIO of
# the number of real samples used for that task (10% as per paper).
SSR_RATIO = 0.1
SAMPLES_PER_TASK = 1_000  # Number of real samples to use
REHEARSAL_SIZE = int(SAMPLES_PER_TASK * SSR_RATIO)  # -> 100 samples

# Natural-Instructions task files (defined URLs from previous notebook), Task1 -> Task3.
_NI_TASKS_BASE = "https://raw.githubusercontent.com/allenai/natural-instructions/master/tasks/"
TASK_URLS = {
    task_key: f"{_NI_TASKS_BASE}{task_file}.json"
    for task_key, task_file in (
        ("Task1_QA", "task024_cosmosqa_answer_generation"),
        ("Task2_QG", "task074_squad1.1_question_generation"),
        ("Task3_SA", "task1312_amazonreview_polarity_classification"),
    )
}

# --- STABILITY PATCH (Crucial for Mamba) ---
# Replaces FalconMambaMixer.cuda_kernels_forward so that, during training, the
# fused CUDA kernel is fed contiguous tensors. Only the training path (no
# recurrent cache) is re-implemented here. Any call that carries a cache, or that
# runs in eval mode, is forwarded to the stock transformers implementation: that
# is the code that reads and updates `cache_params` during `generate()`. Skipping
# it left the recurrent state empty, so every decode step after the first was
# computed without the preceding context.
try:
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn
    from transformers.models.falcon_mamba import modeling_falcon_mamba

    _FalconMambaMixer = modeling_falcon_mamba.FalconMambaMixer

    # Capture the unpatched method exactly once. If this cell is re-run, the class
    # attribute already holds the patched function, and storing that as the
    # "original" would make the fallback call itself forever.
    if not hasattr(_FalconMambaMixer, "_ssr_original_cuda_kernels_forward"):
        _FalconMambaMixer._ssr_original_cuda_kernels_forward = _FalconMambaMixer.cuda_kernels_forward

    # Recent transformers versions import their own FalconMamba-aware fused
    # `mamba_inner_fn` into the modeling module; it accepts the RMS-norm weights
    # FalconMamba applies to B/C/dt. Prefer it, else use the mamba_ssm kernel.
    _fused_inner_fn = getattr(modeling_falcon_mamba, "mamba_inner_fn", None) or mamba_inner_fn

    def robust_stable_forward(self, hidden_states, cache_params=None, cache_position=None, attention_mask=None):
        """Fused training forward with contiguous inputs; defer everything else to transformers.

        Args:
            hidden_states: mixer input, passed through ``self.in_proj``.
            cache_params: recurrent cache used by ``generate()``; when given, the
                stock implementation handles the call so the cache is updated.
            cache_position: forwarded unchanged to the stock implementation.
            attention_mask: forwarded unchanged to the stock implementation; not
                used on the fused training path below.

        Returns:
            The mixer output produced by the fused kernel (training) or by the
            original ``cuda_kernels_forward`` (cached / eval calls).
        """
        if cache_params is not None or not self.training:
            return self._ssr_original_cuda_kernels_forward(
                hidden_states,
                cache_params=cache_params,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )

        projected_states = self.in_proj(hidden_states).transpose(1, 2).contiguous()
        conv_bias = self.conv1d.bias.contiguous() if self.conv1d.bias is not None else None
        out_proj_bias = self.out_proj.bias.contiguous() if self.out_proj.bias is not None else None
        dt_bias = self.dt_proj.bias.float().contiguous() if self.dt_proj.bias is not None else None

        # FalconMamba mixers carry (non-learned) RMS-norm weights for B/C and dt.
        # Dropping them would make training compute a different function than
        # the pretrained model and the stock inference path.
        rms_kwargs = {}
        if hasattr(self, "b_c_rms") and hasattr(self, "dt_rms"):
            rms_kwargs = {
                "b_rms_weight": self.b_c_rms,
                "c_rms_weight": self.b_c_rms,
                "dt_rms_weight": self.dt_rms,
                "b_c_dt_rms_eps": self.rms_eps,
            }

        # NOTE(review): x_proj / dt_proj enter the kernel as raw `.weight` tensors
        # (only in_proj is called as a module). Any adapter that wraps those
        # modules' forward may therefore not contribute on this path; verify.
        return _fused_inner_fn(
            projected_states,
            self.conv1d.weight.contiguous(),
            conv_bias,
            self.x_proj.weight.contiguous(),
            self.dt_proj.weight.contiguous(),
            self.out_proj.weight.contiguous(),
            out_proj_bias,
            -torch.exp(self.A_log.float()).contiguous(),
            None, None,                     # input-dependent B, C
            self.D.float().contiguous(),
            dt_bias,                        # delta_bias
            None, None,                     # B_proj_bias, C_proj_bias
            delta_softplus=True,
            **rms_kwargs,
        )

    modeling_falcon_mamba.FalconMambaMixer.cuda_kernels_forward = robust_stable_forward
    print("‚úÖ System Stable: FalconMamba patched.")
except ImportError:
    print("‚ö†Ô∏è Mamba-ssm not installed correctly. Please restart runtime after installing.")

# --- UTILS ---
def clean_memory():
    """Run Python's garbage collector, then release unoccupied CUDA cache.

    Called after dropping references to large models so their GPU memory can be
    reused by the next load.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

# Parsed task files keyed by URL. The notebook calls load_task_data for the same
# few URLs many times (demos, every training stage, MTL, every evaluation pass),
# so each file is downloaded once per kernel session.
_TASK_JSON_CACHE = {}


def _fetch_task_json(url, timeout=60):
    """Download and parse a Natural-Instructions task JSON, caching by URL.

    Raises:
        requests.HTTPError: if the server answers with an error status (instead
            of trying to parse an error page as JSON).
        requests.Timeout: if the server does not respond within ``timeout`` seconds.
    """
    if url not in _TASK_JSON_CACHE:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        _TASK_JSON_CACHE[url] = response.json()
    return _TASK_JSON_CACHE[url]


def load_task_data(url, num_samples=1000, start_idx=0):
    """Return ``num_samples`` formatted instances of a task, starting at ``start_idx``.

    Each row is a dict with:
        "text":       "Definition: ...\\n\\nInput: ...\\n\\nOutput: ..." (training format)
        "input":      the instance input
        "output":     the first reference output of the instance
        "definition": the first task definition string

    Fewer rows (possibly none) are returned when the task has fewer instances
    than ``start_idx + num_samples``.
    """
    data = _fetch_task_json(url)
    definition = data["Definition"][0]

    formatted = []
    for instance in data["Instances"][start_idx : start_idx + num_samples]:
        output = instance["output"][0]  # only the first reference is used
        formatted.append({
            "text": f"Definition: {definition}\n\nInput: {instance['input']}\n\nOutput: {output}",
            "input": instance["input"],
            "output": output,
            "definition": definition,
        })
    return formatted

In [None]:
# --- SSR HELPER FUNCTIONS ---

def generate_synthetic_inputs_icl(base_model, tokenizer, task_name, url, num_to_generate=300):
    """
    Phase 1 of SSR: Use Base Model + Few-Shot ICL to generate NEW synthetic inputs.
    We generate more than we need (3x) so we can select the best ones later.

    Args:
        base_model: causal LM used to sample continuations.
        tokenizer: tokenizer matching ``base_model``.
        task_name: label used only in the progress message.
        url: task JSON URL; its first instances provide the demonstrations.
        num_to_generate: number of samples to draw (rounded up to a multiple
            of the internal batch size of 4).

    Returns:
        (candidates, definition): the list of generated input strings (first
        line of each sampled continuation, kept only if longer than 5
        characters) and the task definition string.
    """
    print(f"ü§ñ [SSR-Synthesis] Generating Synthetic Inputs for {task_name}...")

    # 1. Get Real Data for Demonstrations (K=3 shots out of the first 10 instances)
    real_data = load_task_data(url, num_samples=10)
    definition = real_data[0]['definition']

    # Construct Few-Shot Prompt: demos, then an open "Input:" for the model to fill.
    demos = "".join(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in real_data[:3]
    )
    prompt_template = f"Definition: {definition}\n\n{demos}Input:"

    # 2. Generate
    base_model.eval()
    synthetic_candidates = []

    batch_size = 4
    # Every row of the batch is the same prompt, so no padding is needed and the
    # tokenized batch can be built once and reused for every sampling round.
    inputs = tokenizer([prompt_template] * batch_size, return_tensors="pt").to(base_model.device)
    prompt_len = inputs["input_ids"].shape[1]

    for _ in tqdm(range(0, num_to_generate, batch_size), desc="Synthesizing"):
        with torch.no_grad():
            outputs = base_model.generate(
                **inputs,
                max_new_tokens=64,
                do_sample=True,
                temperature=0.9,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens. Removing the prompt by string
        # replacement silently does nothing whenever decode() does not reproduce
        # the prompt exactly, and the first line of the *definition* was then
        # collected as a "new input".
        decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

        for generated_part in decoded:
            # The continuation starts right after "Input:"; its first line is the new input.
            new_input = generated_part.strip().split('\n')[0].strip()
            if len(new_input) > 5:  # Basic filter against empty/degenerate samples
                synthetic_candidates.append(new_input)

    print(f"   Generated {len(synthetic_candidates)} raw candidates.")
    return synthetic_candidates, definition


def refine_synthetic_data(current_model, tokenizer, synthetic_inputs, definition):
    """
    Phase 2 of SSR: Use the LATEST trained model to generate labels (Outputs) for the synthetic inputs.
    This ensures the 'labels' reflect the current model's capabilities (Self-Refinement).

    Args:
        current_model: causal LM used greedily to label each input.
        tokenizer: tokenizer matching ``current_model``.
        synthetic_inputs: iterable of input strings produced in Phase 1.
        definition: task definition string prepended to every prompt.

    Returns:
        List of dicts with the same keys as ``load_task_data`` rows
        ("text", "input", "output", "definition"), where "text" is formatted
        exactly like the real training data. Inputs for which the model
        produces an empty answer are dropped.
    """
    import re

    print(f"üîß [SSR-Refinement] Refining Outputs using Current Model...")
    current_model.eval()

    refined_data = []

    for inp in tqdm(synthetic_inputs, desc="Refining"):
        # Format exactly like training data
        prompt = f"Definition: {definition}\n\nInput: {inp}\n\nOutput:"
        inputs = tokenizer(prompt, return_tensors="pt").to(current_model.device)

        with torch.no_grad():
            outputs = current_model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,  # Greedy for "ground truth" confidence
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the new tokens: re-decoding the prompt is not guaranteed to
        # round-trip exactly. With no stop condition, the model may also run past
        # its answer into a new "Input:"/"Definition:" section, so keep only the
        # text before the first such marker.
        prompt_len = inputs["input_ids"].shape[1]
        completion = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)
        output = re.split(r"\n\s*(?:Definition|Input|Output):", completion, maxsplit=1)[0].strip()
        if not output:
            continue  # an empty label would teach the model to answer with nothing

        refined_data.append({
            "text": f"Definition: {definition}\n\nInput: {inp}\n\nOutput: {output}",
            "input": inp,
            "output": output,
            "definition": definition,
        })

    return refined_data


def select_via_kmeans(refined_data, k_samples):
    """
    Phase 3 of SSR: Cluster the refined data and select centroids.
    This ensures diversity in the rehearsal buffer.

    Args:
        refined_data: list of dicts with at least a "text" key (Phase 2 output).
        k_samples: number of K-Means clusters, i.e. the maximum number of
            samples to keep.

    Returns:
        The samples nearest to each K-Means centroid, without duplicates (two
        centroids can share a nearest sample, so fewer than ``k_samples`` may be
        returned). If there are no more than ``k_samples`` candidates, all of
        them are returned. If there are none, or ``k_samples <= 0``, the result
        is an empty list.
    """
    print(f"üß† [SSR-Selection] Selecting {k_samples} samples via K-Means...")

    # KMeans needs n_samples >= n_clusters >= 1; handle the degenerate cases directly.
    if k_samples <= 0 or not refined_data:
        print("   Selected 0 diverse samples.")
        return []
    if len(refined_data) <= k_samples:
        print(f"   Selected {len(refined_data)} diverse samples.")
        return list(refined_data)

    # 1. Embeddings (Use a small fast HF model)
    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    # Every "text" starts with the same task definition. Embedding it verbatim
    # lets that shared prefix dominate, and the sentence encoder has a fixed
    # maximum input length. Embed only the per-sample part after the definition.
    texts = []
    for item in refined_data:
        _, sep, tail = item['text'].partition("\n\nInput: ")
        texts.append(f"Input: {tail}" if sep else item['text'])
    embeddings = embedder.encode(texts, show_progress_bar=False)

    # 2. K-Means: 'k_samples' clusters, then pick the sample closest to each centroid
    kmeans = KMeans(n_clusters=k_samples, random_state=42, n_init=10)
    kmeans.fit(embeddings)

    # 3. Select closest to centroids
    from sklearn.metrics import pairwise_distances_argmin_min
    closest_indices, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, embeddings)

    # Keep each sample once, in first-seen order.
    unique_indices = list(dict.fromkeys(int(idx) for idx in closest_indices))
    selected_data = [refined_data[idx] for idx in unique_indices]
    print(f"   Selected {len(selected_data)} diverse samples.")
    return selected_data

In [None]:
# --- PHASE 1: PRE-SYNTHESIS ---
# Synthetic *inputs* for every task are sampled up front from the frozen base
# model. They are labelled later, stage by stage, by the model being trained.

print("üîÑ Loading Base Model for Synthesis...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# task name -> {"inputs": [generated input strings], "definition": task definition}
SYNTHETIC_INPUT_CACHE = {}
for task_name, url in TASK_URLS.items():
    inputs, definition = generate_synthetic_inputs_icl(
        base_model, tokenizer, task_name, url, num_to_generate=300
    )
    SYNTHETIC_INPUT_CACHE[task_name] = dict(inputs=inputs, definition=definition)

# Release the synthesis model before the training model is loaded.
del base_model
clean_memory()
print("‚úÖ Synthesis Complete. Base model unloaded.")

In [None]:
# --- TRAINING CONFIG ---
# Matching MAMBA_FT.ipynb exactly: r=8, alpha=16, dropout=0.1, targets without out_proj
# NOTE(review): the fused-kernel stability patch reads x_proj/dt_proj as raw
# `.weight` tensors, so adapters on those two modules may not take effect on
# the training path; confirm which LoRA targets actually receive gradients.
peft_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.1, bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=["in_proj", "x_proj", "dt_proj"]
)

# Matching MAMBA_FT.ipynb settings; lr/epochs come from the GLOBAL CONFIG
# constants (2e-4 and 3) so there is a single place to change them.
training_args = SFTConfig(
    output_dir="./falcon_mamba_ssr",
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,
    bf16=True,
    logging_steps=10,
    save_strategy="no",
    dataset_text_field="text",
    packing=False,
    gradient_checkpointing=False
)

# Maximum training sequence length. TRL versions that reject `max_seq_length`
# in the SFTConfig constructor name this field `max_length`. Assigning
# `max_seq_length` afterwards only adds an attribute the trainer never reads,
# so set whichever field this SFTConfig actually defines.
if hasattr(training_args, "max_length"):
    training_args.max_length = 1024
else:
    training_args.max_seq_length = 1024

# REHEARSAL MEMORY: SSR samples accumulated from earlier tasks
REHEARSAL_BUFFER = []

# LOAD MODEL FOR TRAINING
print("üîÑ Loading Model for Training Loop...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)
model = get_peft_model(model, peft_config)

# --- SEQUENTIAL TRAINING LOOP WITH SSR ---
# One LoRA adapter is trained task after task. Each stage mixes the task's real
# data with the rehearsal samples accumulated from earlier tasks.
task_sequence = ["Task1_QA", "Task2_QG", "Task3_SA"]

for i, task_name in enumerate(task_sequence):
    print(f"\n{'='*40}")
    print(f"üöÄ STARTING STAGE {i+1}: {task_name}")
    print(f"{'='*40}")

    # 1. Load Real Data
    url = TASK_URLS[task_name]
    real_data_list = load_task_data(url, num_samples=SAMPLES_PER_TASK)
    real_dataset = Dataset.from_list(real_data_list)

    # 2. Combine with Rehearsal Buffer (if exists)
    if REHEARSAL_BUFFER:
        rehearsal_dataset = Dataset.from_list(REHEARSAL_BUFFER)
        # SFT reads only the "text" column. Rehearsal rows need not carry the
        # same metadata columns as real rows (the SSR refinement step can store
        # fewer keys), so align both sides on "text" before concatenating rows.
        train_dataset = concatenate_datasets([
            real_dataset.remove_columns([c for c in real_dataset.column_names if c != "text"]),
            rehearsal_dataset.remove_columns([c for c in rehearsal_dataset.column_names if c != "text"]),
        ])
        train_dataset = train_dataset.shuffle(seed=42)
        print(f"   [Data] Combined {len(real_dataset)} real + {len(rehearsal_dataset)} rehearsal samples.")
    else:
        train_dataset = real_dataset
        print(f"   [Data] Using {len(real_dataset)} real samples (No rehearsal yet).")

    # 3. Train
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        args=training_args,
        processing_class=tokenizer
    )
    trainer.train()

    # Save adapter after every task (optional, for checkpointing)
    model.save_pretrained(f"./adapters/{task_name}_SSR")

    # 4. PERFORM SSR (Create Rehearsal Data for the NEXT stage)
    # We generate/select rehearsal data for the *current* task to store for the future.
    if i < len(task_sequence) - 1:  # Don't need to rehearse the very last task
        print(f"üîç [SSR-Process] Creating Rehearsal Data for {task_name}...")

        # A. Get Synthetic Inputs (Cached from Phase 1)
        syn_inputs = SYNTHETIC_INPUT_CACHE[task_name]["inputs"]
        defn = SYNTHETIC_INPUT_CACHE[task_name]["definition"]

        # B. Refine with CURRENT Model (Self-Synthesized Labels)
        refined_candidates = refine_synthetic_data(model, tokenizer, syn_inputs, defn)

        # C. Select via K-Means
        selected_samples = select_via_kmeans(refined_candidates, k_samples=REHEARSAL_SIZE)

        # D. Add to Buffer
        REHEARSAL_BUFFER.extend(selected_samples)
        print(f"‚úÖ Added {len(selected_samples)} samples from {task_name} to Rehearsal Buffer.")

print("\nüéâ ALL TRAINING STAGES COMPLETE!")

In [None]:
import evaluate
rouge = evaluate.load("rouge")

def run_evaluation(model, tokenizer, desc):
    """
    Score ``model`` with greedy decoding on held-out instances of every task.

    For each task in TASK_URLS, instances 1000..1099 are loaded. They are held
    out as long as training uses at most the first 1000 instances
    (SAMPLES_PER_TASK). The model generates up to 64 new tokens per prompt,
    and ROUGE-L is computed against the first reference output.

    Returns:
        dict mapping task name -> ROUGE-L score scaled to 0-100.
    """
    import re

    print(f"\nüìä Evaluating: {desc}")
    results = {}
    model.eval()

    for task_name, url in TASK_URLS.items():
        # Load Test Data (Held out indices 1000:1100)
        test_data = load_task_data(url, num_samples=100, start_idx=1000)
        # Build prompts from the fields rather than by splitting "text" on
        # "Output:", which would cut the prompt short if the definition or the
        # input itself contains that token.
        prompts = [
            f"Definition: {x['definition']}\n\nInput: {x['input']}\n\nOutput:"
            for x in test_data
        ]
        references = [x['output'] for x in test_data]

        predictions = []
        for prompt in tqdm(prompts, desc=f"   {task_name}", leave=False):
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs, max_new_tokens=64,
                    do_sample=False, pad_token_id=tokenizer.eos_token_id
                )
            # Decode only what the model generated, and drop any run-on
            # continuation into a new "Input:"/"Definition:"/"Output:" section.
            prompt_len = inputs["input_ids"].shape[1]
            completion = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)
            pred = re.split(r"\n\s*(?:Definition|Input|Output):", completion, maxsplit=1)[0].strip()
            predictions.append(pred)

        scores = rouge.compute(predictions=predictions, references=references)
        results[task_name] = scores['rougeL'] * 100

    return results

# Run Eval
final_scores = run_evaluation(model, tokenizer, "Final SSR Model (Task 3 State)")

print("\n" + "="*40)
print(f"{'Task':<15} | {'ROUGE-L Score':<15}")
print("-" * 40)
for task, score in final_scores.items():
    print(f"{task:<15} | {score:>10.2f}")
print("-" * 40)
# Mean over however many tasks were actually evaluated (not a hard-coded 3).
avg_score = sum(final_scores.values()) / len(final_scores) if final_scores else 0.0
print(f"{'AVERAGE':<15} | {avg_score:>10.2f}")
print("="*40)

In [None]:
import evaluate
from peft import PeftModel

# --- TRAINING CONFIG ---
# Matching MAMBA_FT.ipynb exactly: r=8, alpha=16, dropout=0.1, targets without out_proj
peft_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.1, bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=["in_proj", "x_proj", "dt_proj"]
)

# Matching MAMBA_FT.ipynb settings; lr/epochs come from the GLOBAL CONFIG constants.
# Note: max_seq_length is NOT passed here to avoid TypeError
training_args = SFTConfig(
    output_dir="./falcon_mamba_ssr",
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,
    bf16=True,
    logging_steps=10,
    save_strategy="no",
    dataset_text_field="text",
    packing=False,
    gradient_checkpointing=False
)

# Maximum training sequence length. A TRL version that rejects `max_seq_length`
# in the constructor names this field `max_length`; assigning `max_seq_length`
# afterwards would only add an attribute the trainer never reads.
if hasattr(training_args, "max_length"):
    training_args.max_length = 1024
else:
    training_args.max_seq_length = 1024

# REHEARSAL MEMORY: SSR samples accumulated from earlier tasks
REHEARSAL_BUFFER = []

# Earlier cells (or a previous run of this one) may still hold a 7B model /
# trainer in these globals; drop them before loading another copy to avoid OOM.
for _stale_name in ("model", "trainer", "base_model", "model_seq", "model_mtl"):
    globals().pop(_stale_name, None)
clean_memory()

# LOAD MODEL FOR TRAINING
print("üîÑ Loading Model for Training Loop...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)
model = get_peft_model(model, peft_config)

# --- SEQUENTIAL TRAINING LOOP WITH SSR ---
# One LoRA adapter is trained task after task. Each stage mixes the task's real
# data with the rehearsal samples accumulated from earlier tasks.
task_sequence = ["Task1_QA", "Task2_QG", "Task3_SA"]

for i, task_name in enumerate(task_sequence):
    print(f"\n{'='*40}")
    print(f"üöÄ STARTING STAGE {i+1}: {task_name}")
    print(f"{'='*40}")

    # 1. Load Real Data
    url = TASK_URLS[task_name]
    real_data_list = load_task_data(url, num_samples=SAMPLES_PER_TASK)
    real_dataset = Dataset.from_list(real_data_list)

    # 2. Combine with Rehearsal Buffer (if exists)
    if REHEARSAL_BUFFER:
        rehearsal_dataset = Dataset.from_list(REHEARSAL_BUFFER)
        # SFT reads only the "text" column. Rehearsal rows need not carry the
        # same metadata columns as real rows (the SSR refinement step can store
        # fewer keys), so align both sides on "text" before concatenating rows.
        train_dataset = concatenate_datasets([
            real_dataset.remove_columns([c for c in real_dataset.column_names if c != "text"]),
            rehearsal_dataset.remove_columns([c for c in rehearsal_dataset.column_names if c != "text"]),
        ])
        train_dataset = train_dataset.shuffle(seed=42)
        print(f"   [Data] Combined {len(real_dataset)} real + {len(rehearsal_dataset)} rehearsal samples.")
    else:
        train_dataset = real_dataset
        print(f"   [Data] Using {len(real_dataset)} real samples (No rehearsal yet).")

    # 3. Train
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        args=training_args,
        processing_class=tokenizer
    )
    trainer.train()

    # Save adapter after every task
    model.save_pretrained(f"./adapters/{task_name}_SSR")

    # 4. PERFORM SSR (Create Rehearsal Data for the NEXT stage)
    # We generate/select rehearsal data for the *current* task to store for the future.
    if i < len(task_sequence) - 1:  # Don't need to rehearse the very last task
        print(f"üîç [SSR-Process] Creating Rehearsal Data for {task_name}...")

        # A. Get Synthetic Inputs (Cached from Phase 1)
        syn_inputs = SYNTHETIC_INPUT_CACHE[task_name]["inputs"]
        defn = SYNTHETIC_INPUT_CACHE[task_name]["definition"]

        # B. Refine with CURRENT Model (Self-Synthesized Labels)
        refined_candidates = refine_synthetic_data(model, tokenizer, syn_inputs, defn)

        # C. Select via K-Means
        selected_samples = select_via_kmeans(refined_candidates, k_samples=REHEARSAL_SIZE)

        # D. Add to Buffer
        REHEARSAL_BUFFER.extend(selected_samples)
        print(f"‚úÖ Added {len(selected_samples)} samples from {task_name} to Rehearsal Buffer.")

print("\nüéâ SEQUENTIAL TRAINING COMPLETE!")

# --- MULTI-TASK LEARNING (UPPER BOUND) ---
# A fresh adapter trained on all tasks jointly: the reference point the
# sequential SSR run is compared against.
print("\n" + "=" * 40)
print("üöÄ STARTING MULTI-TASK LEARNING (UPPER BOUND)")
print("=" * 40)

# Drop the sequential model and its trainer before a second 7B copy is loaded (prevents OOM).
del model, trainer
clean_memory()

# Fresh base weights wrapped with a new LoRA adapter built from the same peft_config.
print("üîÑ Reloading Base Model for MTL...")
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    peft_config,
)

# Pool SAMPLES_PER_TASK real examples from every task and shuffle them together.
print("üìä Mixing Datasets for Multi-Task Learning...")
mtl_datasets = [
    Dataset.from_list(load_task_data(url, num_samples=SAMPLES_PER_TASK))
    for url in TASK_URLS.values()
]
mtl_train_dataset = concatenate_datasets(mtl_datasets).shuffle(seed=42)
print(f"   [Data] Combined Dataset Size: {len(mtl_train_dataset)} samples")

# Same SFT hyper-parameters, separate output directory.
training_args.output_dir = "./falcon_mamba_mtl"
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=mtl_train_dataset,
    processing_class=tokenizer,
)
trainer.train()
model.save_pretrained("./adapters/MTL_UpperBound")
print("‚úÖ Saved MTL Adapter")

# --- FINAL EVALUATION ---
print("\n" + "=" * 40)
print("üìä STARTING FINAL EVALUATION")
print("=" * 40)

# ROUGE metric used by run_evals
rouge = evaluate.load("rouge")

def run_evals(model, tokenizer, desc):
    """
    Score ``model`` with greedy decoding on held-out instances of every task.

    For each task in TASK_URLS, instances 1000..1099 are loaded (held out as
    long as training uses at most the first 1000 instances). Up to 64 new
    tokens are generated per prompt, and ROUGE-L is computed against the first
    reference output.

    Returns:
        dict mapping task name -> ROUGE-L score scaled to 0-100.
    """
    import re

    print(f"\nEvaluating: {desc}")
    model.eval()
    results = {}

    for task_name, url in TASK_URLS.items():
        # Load Test Data (Held-out indices 1000:1100 as per paper/notebook)
        test_data = load_task_data(url, num_samples=100, start_idx=1000)

        predictions = []
        references = []

        for item in tqdm(test_data, desc=task_name, leave=False):
            # Build the prompt from the fields. Splitting "text" on "Output:"
            # would truncate it if the definition or input contains that token.
            prompt = f"Definition: {item['definition']}\n\nInput: {item['input']}\n\nOutput:"
            references.append(item['output'])

            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=64,  # Matches notebook settings
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )

            # Decode only the generated tokens, then drop any run-on continuation
            # into a new "Input:"/"Definition:"/"Output:" section.
            prompt_len = inputs["input_ids"].shape[1]
            completion = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)
            pred = re.split(r"\n\s*(?:Definition|Input|Output):", completion, maxsplit=1)[0].strip()
            predictions.append(pred)

        scores = rouge.compute(predictions=predictions, references=references)
        results[task_name] = scores['rougeL'] * 100  # Scale to 0-100

    return results

# 1. Cleanup & Reload Base for Eval
del model, trainer
clean_memory()

print("üîÑ Loading Base Model for Eval...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

# 2. Eval Sequential (Load Task 3 SSR Adapter)
# We load Task 3 because it represents the final state of the sequential chain
print("‚¨áÔ∏è Loading Sequential Adapter (Task 3)...")
model_seq = PeftModel.from_pretrained(base_model, "./adapters/Task3_SA_SSR", adapter_name="sequential")
seq_scores = run_evals(model_seq, tokenizer, "Sequential (Task 1->2->3)")

# 3. Eval MTL
# Load the MTL adapter into the same PeftModel and make it active (PEFT's
# multi-adapter workflow), rather than unloading and wrapping base_model again.
print("‚¨ÜÔ∏è Loading MTL Adapter...")
model_seq.load_adapter("./adapters/MTL_UpperBound", adapter_name="mtl")
model_seq.set_adapter("mtl")
model_mtl = model_seq
mtl_scores = run_evals(model_mtl, tokenizer, "Multi-Task Learning")

# 4. Print Report ("Forgetting Gap" = MTL score minus sequential score)
print("\n" + "="*65)
print(f"{'Task':<15} | {'Sequential':<12} | {'Multi-Task':<12} | {'Forgetting Gap':<15}")
print("-" * 65)

seq_avg = 0
mtl_avg = 0

for task in TASK_URLS.keys():
    s = seq_scores.get(task, 0)
    m = mtl_scores.get(task, 0)
    gap = m - s

    seq_avg += s
    mtl_avg += m

    print(f"{task:<15} | {s:>10.2f}   | {m:>10.2f}   | {gap:>10.2f}")

# Average over the configured tasks instead of a hard-coded 3.
num_tasks = len(TASK_URLS)
print("-" * 65)
print(f"{'AVERAGE':<15} | {seq_avg/num_tasks:>10.2f}   | {mtl_avg/num_tasks:>10.2f}   | {(mtl_avg - seq_avg)/num_tasks:>10.2f}")
print("="*65)