# Sys-Scan Full-Stack Nightly TPU Fine-Tuning

This notebook executes the end-to-end TRL supervised fine-tuning pipeline on a Google Colab TPU v3/v4 instance using the full-stack nightly configuration described in `finetuneguide.instructions.md`.

- Installs synchronized nightly builds of `torch_xla`, `optimum-tpu`, and core Hugging Face libraries
- Formats the massive synthetic security dataset into Mistral-style instruction/response prompts
- Enables TPU FSDP v2 with LoRA adapters and the Lion optimizer
- Provides validation and troubleshooting guidance for common TPU runtime issues

## Workflow Overview

1. Configure Colab for TPU execution and install the nightly PyTorch/XLA stack.
2. Mount Google Drive, extract the massive dataset archive, and hydrate the JSON files.
3. Transform each ground-truth record into the supervision prompt template expected by `SFTTrainer`.
4. Load `mistralai/Mistral-7B-Instruct-v0.1`, apply LoRA adapters, and enable FSDP v2 via `optimum.tpu`.
5. Launch supervised fine-tuning with Lion, monitor metrics, and export the LoRA adapter for Sys-Scan-Graph integration.

## Step 0 – Switch the runtime to TPU

Before running any code cells, open **Runtime → Change runtime type** in Colab and choose **TPU** as the hardware accelerator. Then reconnect the session.

In [None]:
# Install torch, torchvision, and torch_xla with the TPU extras; the -f indexes supply the libtpu wheels. Expect a runtime restart afterwards.
!pip install --quiet numpy torch torchvision torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html

## Step 1 – Confirm TPU availability after the runtime restart

After the previous cell finishes, Colab may automatically restart the runtime. Rerun the cells below to configure the PJRT environment and validate TPU visibility.

In [None]:
import os
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm

os.environ["PJRT_DEVICE"] = "TPU"
os.environ.setdefault("LIBTPU_INIT_TPU_ON_DEVICE", "1")
print(f'PJRT_DEVICE set to {os.environ["PJRT_DEVICE"]}')

try:
    init_fn = getattr(xr, "initialize_system", None)
    if init_fn is None:
        init_fn = getattr(xr, "initialize_tpu_system", None)
    if callable(init_fn):
        init_fn()
    else:
        # Fallback: touch the default device to trigger PJRT startup
        xm.xla_device()

    devices = []
    if hasattr(torch_xla, "real_devices") and callable(torch_xla.real_devices):
        devices = torch_xla.real_devices()
    if not devices:
        devices = xm.get_xla_supported_devices()
    device_strings = [str(d) for d in devices]
    print("XLA devices:", device_strings)
    # Fallback device names look like "xla:0", so also accept the runtime's reported device type.
    tpu_detected = any(s.startswith("TPU") for s in device_strings) or xr.device_type() == "TPU"
    assert tpu_detected, "TPU not detected."
    print(f"Successfully connected to {len(device_strings)} TPU cores.")
except Exception:
    print("XLA initialization failed. Please restart the runtime and rerun the installation cell.")
    raise

In [None]:
# Install synchronized nightly builds of the Hugging Face ecosystem required for the TRL TPU pipeline.
!pip install --quiet git+https://github.com/huggingface/transformers.git
!pip install --quiet git+https://github.com/huggingface/datasets.git
!pip install --quiet git+https://github.com/huggingface/accelerate.git
!pip install --quiet git+https://github.com/huggingface/peft.git
!pip install --quiet git+https://github.com/huggingface/trl.git
!pip install --quiet --no-deps git+https://github.com/huggingface/optimum-tpu.git  # --no-deps avoids downgrading nightly torch_xla

In [None]:
import transformers
import datasets
import accelerate
import peft
import trl
import importlib.metadata

print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("accelerate:", accelerate.__version__)
print("peft:", peft.__version__)
print("trl:", trl.__version__)
print("optimum-tpu:", importlib.metadata.version("optimum-tpu"))

## Step 2 – Mount Google Drive for dataset access

This workflow reads `massive_datasets.tar.gz` from Drive. Mount your Google Drive before extraction.

In [None]:
from google.colab import drive

drive.mount("/content/drive")

In [None]:
import tarfile
from pathlib import Path

tar_path = "/content/drive/MyDrive/sys-scan-graph/massive_datasets/massive_datasets.tar.gz"
extract_path = Path("/content/massive_datasets")

extract_path.mkdir(parents=True, exist_ok=True)

with tarfile.open(tar_path, "r:gz") as tar:
    tar.extractall(path=extract_path)

RAW_SYN_DIR = extract_path
children = [entry.name for entry in extract_path.iterdir()]
print(f"Extracted data to {extract_path}")
print(f"Top-level entries: {children}")
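
The extraction cell above trusts every entry in the archive. As a hedged aside that is not part of the original pipeline, the sketch below shows one way to keep only regular files and directories whose paths stay inside the destination. `is_within` and `safe_members` are hypothetical helper names introduced for illustration.

```python
import tarfile
from pathlib import Path


def is_within(base: Path, target: Path) -> bool:
    """Return True if target resolves to base itself or a location inside it."""
    base = base.resolve()
    target = target.resolve()
    return target == base or base in target.parents


def safe_members(tar, dest: Path):
    """Yield only regular files and directories whose paths stay inside dest."""
    for member in tar.getmembers():
        if not (member.isfile() or member.isdir()):
            continue  # skip symlinks, hard links, and device nodes
        if is_within(dest, dest / member.name):
            yield member
        else:
            print(f"Skipping suspicious entry: {member.name}")
```

Under these assumptions, the extraction call could become `tar.extractall(path=extract_path, members=safe_members(tar, extract_path))`.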

In [None]:
from pathlib import Path
from collections import Counter
import json
import gzip
from datasets import Dataset

RAW_SYN_DIR = Path(RAW_SYN_DIR)

def collect_shards(root: Path):
    patterns = [("jsonl", "*.jsonl"), ("json.gz", "*.json.gz"), ("json", "*.json")]
    for label, pattern in patterns:
        matches = sorted(root.rglob(pattern))
        if not matches:
            continue

        filtered = []
        ignored = []
        for path in matches:
            name = path.name.lower()
            if name.startswith("batch_"):
                filtered.append(path)
            else:
                ignored.append(path)

        if filtered:
            if ignored:
                preview = [p.name for p in ignored[:5]]
                if len(ignored) > 5:
                    preview.append("...")
                print(f"Ignoring {len(ignored)} non-shard file(s): {preview}")
            return label, filtered

        # Matches exist but all were ignored; try next pattern
    return None, []

print(f"Scanning extracted data under {RAW_SYN_DIR}...")
shard_type, shard_paths = collect_shards(RAW_SYN_DIR)

if not shard_paths:
    sample_dirs = sorted([p for p in RAW_SYN_DIR.glob("*") if p.is_dir()])[:10]
    sample_files = sorted([p for p in RAW_SYN_DIR.glob("*.*") if p.is_file()])[:10]
    raise FileNotFoundError(
        "No dataset shards detected. Inspect the archive layout.\n"
        f"Sample directories: {[p.as_posix() for p in sample_dirs]}\n"
        f"Sample files: {[p.as_posix() for p in sample_files]}"
    )

print(f"Found {len(shard_paths)} {shard_type} shard(s). Example: {[p.as_posix() for p in shard_paths[:3]]}")

MAX_FINDINGS_PER_PROMPT = 24

def sanitize_finding(finding: dict) -> dict:
    cleaned = finding.copy()
    cleaned.pop("_processed_at", None)
    cleaned.pop("_data_quality", None)
    risk_score = cleaned.get("risk_score")
    if isinstance(risk_score, float):
        cleaned["risk_score"] = round(risk_score)
    probability = cleaned.get("probability_actionable")
    if isinstance(probability, float):
        cleaned["probability_actionable"] = round(probability, 3)
    return cleaned

def flatten_findings(findings_by_category: dict) -> list:
    flattened = []
    for severity_map in findings_by_category.values():
        for entries in severity_map.values():
            for entry in entries:
                flattened.append(sanitize_finding(entry))
    return flattened

def sanitize_correlation(correlation: dict) -> dict:
    cleaned = correlation.copy()
    cleaned.pop("_processed_at", None)
    cleaned.pop("_correlation_strength", None)
    risk_score = cleaned.get("risk_score")
    if isinstance(risk_score, float):
        cleaned["risk_score"] = round(risk_score)
    return cleaned

SEVERITY_ORDER = {"critical": 4, "high": 3, "medium": 2, "low": 1, "info": 0}

def build_summaries(correlation: dict, related_findings: list, stats: dict, verification: dict) -> dict:
    severity_counts = Counter(f.get("severity", "unknown") for f in related_findings)
    category_counts = Counter(f.get("category", "unknown") for f in related_findings)
    top_categories = ", ".join(cat for cat, _ in category_counts.most_common(3))
    exec_summary = (
        f"{correlation.get('title', 'Correlation')} ({correlation.get('severity', 'unknown')} severity, risk {correlation.get('risk_score', 'n/a')}) "
        f"involves {len(related_findings)} finding(s) spanning {top_categories or 'multiple categories'}."
    )
    triage_ids = ", ".join(f.get("id", "?") for f in related_findings[:10])
    metrics = {
        "related_finding_count": len(related_findings),
        "correlation_strength": correlation.get("correlation_strength"),
        "quality_score": verification.get("quality_score"),
        "severity_distribution": dict(severity_counts)
    }
    return {
        "executive_summary": exec_summary,
        "analyst": correlation.get("description", ""),
        "triage_summary": (
            f"Prioritise remediation for finding IDs: {triage_ids}"
            if triage_ids
            else "Related finding identifiers unavailable."
        ),
        "metrics": metrics
    }

def build_actions(correlation: dict, related_findings: list) -> list:
    sorted_findings = sorted(
        related_findings,
        key=lambda f: (SEVERITY_ORDER.get(f.get("severity", ""), 0), f.get("risk_score", 0)),
        reverse=True
    )
    actions = []
    for idx, finding in enumerate(sorted_findings[:3], start=1):
        severity = finding.get("severity", "medium")
        priority = "high" if severity in {"critical", "high"} else "medium"
        description = finding.get("description", "")
        actions.append({
            "id": f"{correlation.get('id', 'corr')}_action_{idx}",
            "title": f"Investigate {finding.get('title', 'finding')}",
            "description": (
                f"Validate finding {finding.get('id', 'unknown')} ({finding.get('category', 'unknown')}, severity {severity}) "
                f"as part of correlation '{correlation.get('title', 'Correlation')}'. Summary: {description[:240]}"
            ),
            "priority": priority,
            "severity": severity
        })
    if not actions:
        actions.append({
            "id": f"{correlation.get('id', 'corr')}_action_1",
            "title": f"Review correlation {correlation.get('id', 'corr')}",
            "description": "Correlation references no additional findings. Perform manual validation.",
            "priority": "medium",
            "severity": correlation.get("severity", "medium")
        })
    return actions

def prompt_generator(paths):
    for shard_path in paths:
        shard_path = Path(shard_path)
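        try:
            # Pick a reader that matches the shard type returned by collect_shards.
            if shard_path.name.endswith(".gz"):
                with gzip.open(shard_path, "rt", encoding="utf-8") as handle:
                    raw_doc = json.load(handle)
            elif shard_path.suffix == ".jsonl":
                with shard_path.open(encoding="utf-8") as handle:
                    raw_doc = [json.loads(line) for line in handle if line.strip()]
            else:
                with shard_path.open(encoding="utf-8") as handle:
                    raw_doc = json.load(handle)
        except Exception as exc:
            print(f"⚠️ Could not read {shard_path.name}: {exc}")
            continue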

        records = raw_doc if isinstance(raw_doc, list) else [raw_doc]

        for record in records:
            hex_payload = record.get("data")
            if not hex_payload:
                continue
            try:
                decoded = json.loads(gzip.decompress(bytes.fromhex(hex_payload)))
            except Exception as exc:
                print(f"⚠️ Failed to decode {shard_path.name}: {exc}")
                continue

            data = decoded.get("data", {})
            metadata = decoded.get("metadata", {})
            verification = metadata.get("verification_summary", {})
            stats = data.get("statistics", {})

            flattened_findings = flatten_findings(data.get("findings", {}))
            finding_lookup = {finding.get("id"): finding for finding in flattened_findings if finding.get("id")}

            for correlation in data.get("correlations", []):
                sanitized_corr = sanitize_correlation(correlation)
                related_ids = sanitized_corr.get("correlation_refs") or []
                related = [finding_lookup[fid] for fid in related_ids if fid in finding_lookup]
                if not related:
                    continue

                if len(related) > MAX_FINDINGS_PER_PROMPT:
                    related = sorted(related, key=lambda f: f.get("risk_score", 0), reverse=True)[:MAX_FINDINGS_PER_PROMPT]

                finding_payload = {
                    "version": "ground_truth_v1",
                    "enriched_findings": related,
                    "correlations": [sanitized_corr],
                    "reductions": {
                        "severity_distribution": stats.get("severity_distribution", {}),
                        "category_distribution": stats.get("category_distribution", {}),
                        "risk_score_stats": stats.get("risk_score_stats", {})
                    }
                }

                analysis_payload = {
                    "summaries": build_summaries(sanitized_corr, related, stats, verification),
                    "actions": build_actions(sanitized_corr, related)
                }

                prompt_text = (
                    "### Instruction:\n"
                    "Analyze the following security correlation and provide an assessment:\n"
                    f"{json.dumps(finding_payload, ensure_ascii=False)}\n\n"
                    "### Response:\n"
                    f"{json.dumps(analysis_payload, ensure_ascii=False)}"
                )

                yield {"prompt": prompt_text}

correlation_prompt_dataset = Dataset.from_generator(lambda: prompt_generator(shard_paths))

if len(correlation_prompt_dataset) == 0:
    raise ValueError("Decoded zero prompts from the supplied shards.")

print(correlation_prompt_dataset)
preview_prompt = correlation_prompt_dataset[:1]["prompt"][0]
print(preview_prompt[:400])

## Step 3 – Convert correlations into SFT-friendly prompts

Each shard holds a normalized dataset with thousands of correlations, stored as hex-encoded gzip payloads. The previous cell decodes those payloads, flattens the findings they contain, and pairs each correlation with its related findings to form an instruction/response prompt for supervised fine-tuning. The cells below shuffle, split, and sanity-check the resulting prompts.
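
To make the payload format concrete, here is a minimal round-trip sketch of the hex-encoded gzip JSON encoding that the loader decodes. `decode_payload` mirrors the decode expression used in `prompt_generator`. `encode_payload` and the sample record are assumptions added for illustration, since the notebook does not show how the shards were produced.

```python
import gzip
import json


def encode_payload(document: dict) -> str:
    """Serialize to JSON, gzip-compress, and hex-encode (assumed inverse of the loader)."""
    raw = json.dumps(document, ensure_ascii=False).encode("utf-8")
    return gzip.compress(raw).hex()


def decode_payload(hex_payload: str) -> dict:
    """Hex-decode, gunzip, and parse JSON, as the loader does for each record."""
    return json.loads(gzip.decompress(bytes.fromhex(hex_payload)))


# Hypothetical record shaped like the fields the loader reads.
record = {
    "data": {
        "findings": {},
        "correlations": [{"id": "corr_1", "correlation_refs": ["f1"]}],
    },
    "metadata": {"verification_summary": {}},
}
hex_payload = encode_payload(record)
print(hex_payload[:4])  # gzip streams begin with the magic bytes 1f 8b
print(decode_payload(hex_payload) == record)
```

A payload that does not start with `1f8b` is not gzip data, which is a quick check to run when the loader reports decode failures.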

In [None]:
DATASET_SEED = 42
EVAL_SPLIT = 0.001
SAMPLE_SIZE = 512  # Cap on training prompts, kept small for stability on TPU; set to None to use every prompt

shuffled_dataset = correlation_prompt_dataset.shuffle(seed=DATASET_SEED)
if len(shuffled_dataset) > 1:
    formatted_dataset = shuffled_dataset.train_test_split(test_size=EVAL_SPLIT, seed=DATASET_SEED)
    train_dataset = formatted_dataset["train"]
    eval_dataset = formatted_dataset["test"]
else:
    train_dataset = shuffled_dataset
    eval_dataset = shuffled_dataset

if SAMPLE_SIZE:
    train_dataset = train_dataset.select(range(min(len(train_dataset), SAMPLE_SIZE)))
    eval_cap = min(len(eval_dataset), max(1, SAMPLE_SIZE // 100))
    eval_dataset = eval_dataset.select(range(eval_cap))

print(f"Train prompts: {len(train_dataset):,}")
print(f"Eval prompts: {len(eval_dataset):,}")

sample_prompt = train_dataset[:1]["prompt"][0]
print(sample_prompt[:400])

In [None]:
# Quick sanity checks on the generated prompt corpus
from collections import defaultdict
import re

prompt_lengths = [len(p.split()) for p in train_dataset[:1024]["prompt"]]
print(f"Prompt word length — min: {min(prompt_lengths)}, median: {sorted(prompt_lengths)[len(prompt_lengths)//2]}, max: {max(prompt_lengths)}")

def extract_json_sections(prompt_text: str):
    # Capture only the JSON part after "assessment:\n"
    match_instruction = re.search(r"assessment:\n(.+?)\n\n### Response:\n", prompt_text, flags=re.S)
    match_response = re.search(r"### Response:\n(.+)$", prompt_text, flags=re.S)
    instruction_str = match_instruction.group(1) if match_instruction else ""
    response_str = match_response.group(1) if match_response else ""
    print(f"Instruction match: {bool(match_instruction)}, length: {len(instruction_str)}")
    print(f"Response match: {bool(match_response)}, length: {len(response_str)}")
    if instruction_str:
        try:
            instruction_payload = json.loads(instruction_str)
        except json.JSONDecodeError as e:
            print(f"Instruction JSON error: {e}")
            print(f"Instruction string preview: {instruction_str[:200]}...")
            instruction_payload = {}
    else:
        instruction_payload = {}
    if response_str:
        try:
            response_payload = json.loads(response_str)
        except json.JSONDecodeError as e:
            print(f"Response JSON error: {e}")
            print(f"Response string preview: {response_str[:200]}...")
            response_payload = {}
    else:
        response_payload = {}
    return instruction_payload, response_payload

sample_instruction, sample_response = extract_json_sections(train_dataset[:1]["prompt"][0])
print("Instruction keys:", sample_instruction.keys())
print("Response keys:", sample_response.keys())
print("Sample actions:", sample_response.get("actions", [])[:2])

## Step 4 – Configure model, LoRA adapters, and optimizer

The configuration below fine-tunes `mistralai/Mistral-7B-Instruct-v0.1` with LoRA rank 16 adapters across attention and MLP layers, using Lion as the optimizer and enabling FSDP v2 for TPU sharding.
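
As a rough guide to adapter size, the sketch below counts the trainable parameters that rank 16 LoRA would add. It assumes adapters on all seven attention and MLP projections; the exact target modules are not listed in this notebook, so treat that set as an assumption. The dimensions are those of the published Mistral-7B configuration.

```python
# Mistral-7B dimensions from the published model configuration.
HIDDEN = 4096
INTERMEDIATE = 14336
NUM_LAYERS = 32
KV_DIM = 1024  # 8 key/value heads x 128 head dim (grouped-query attention)
RANK = 16

# (in_features, out_features) for each projection ASSUMED to carry an adapter.
TARGETS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """LoRA adds A (rank x in) and B (out x rank) to each adapted linear layer."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in TARGETS.values())
total = per_layer * NUM_LAYERS
print(f"{per_layer:,} per layer, {total:,} total")  # 1,310,720 per layer, 41,943,040 total
```

Under these assumptions the adapters add about 42M trainable parameters, well under 1% of the base model, which is why the LoRA weights are cheap to export and store.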

In [None]:
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"
OUTPUT_DIR = "/content/sys-scan-mistral-lora"
MAX_SEQ_LENGTH = 256  # Reduced from 512 to ease memory pressure on TPU
NUM_EPOCHS = 3
PER_DEVICE_BATCH_SIZE = 1  # Reduced to 1 for maximum stability
GRADIENT_ACCUMULATION_STEPS = 16  # Increased to compensate for smaller batch size
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
LOGGING_STEPS = 10
SAVE_TOTAL_LIMIT = 2
SEED = 42
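
The batch-size and accumulation settings above jointly determine how many optimizer steps a run performs. The helper below is a back-of-the-envelope sketch, not the trainer's own accounting. `training_schedule` is a hypothetical name, and `num_replicas` (how many data-parallel replicas the trainer counts) is an assumption that depends on the runtime: a single-process SPMD setup may count as one replica.

```python
def training_schedule(num_examples, per_device_batch, grad_accum, num_replicas, epochs):
    """Return (effective batch size, optimizer steps per epoch, total optimizer steps)."""
    effective_batch = per_device_batch * grad_accum * num_replicas
    # With drop_last=True, only full micro-batches are produced.
    micro_batches = num_examples // (per_device_batch * num_replicas)
    steps_per_epoch = micro_batches // grad_accum
    return effective_batch, steps_per_epoch, steps_per_epoch * epochs


# Values from this notebook: 512 prompts, batch 1, accumulation 16, 3 epochs.
print(training_schedule(512, 1, 16, num_replicas=1, epochs=3))  # (16, 32, 96)
print(training_schedule(512, 1, 16, num_replicas=8, epochs=3))  # (128, 4, 12)
```

With `LOGGING_STEPS = 10`, the first case would log roughly nine times over the run, while the second would log only once, so it can be worth lowering the logging interval for short runs.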

## Important: PEFT Application and Packing Requirements

**PEFT application in this notebook:**
- This notebook does **not** pass `peft_config` to `SFTTrainer`; whether the trainer accepts that parameter depends on the installed TRL build
- Instead, PEFT is applied to the model **before** creating the trainer using `get_peft_model(model, peft_config)`, so the trainer receives an already wrapped model
- This happens in the model-loading cell

**About the Packing Warning:**
When `packing=True`, TRL warns if the model's attention implementation is not set to a supported flash-attention variant. This warning appears during trainer initialization but **does not prevent training**. Packing is off by default in `SFTConfig`, and this notebook keeps it off (`packing=False`), so the warning should only appear if you enable packing.

On TPU with PJRT, the default attention implementation typically works fine. If you enable packing and encounter issues during training, you can:
1. Disable packing again: set `packing=False` in `SFTConfig`
2. Or set `attn_implementation` when loading the model, for example to `"flash_attention_2"`; FlashAttention-2 is a GPU kernel library, so expect it to be unavailable on most TPU runtimes
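
For readers unfamiliar with packing, the sketch below illustrates the idea on plain lists of token IDs. It is a simplified illustration and not TRL's implementation; `pack_sequences`, the token IDs, and the EOS ID of 0 are all made up for the example.

```python
def pack_sequences(token_sequences, block_size, eos_id):
    """Concatenate sequences (each followed by eos_id) and cut into full blocks."""
    stream = []
    for seq in token_sequences:
        stream.extend(seq)
        stream.append(eos_id)
    usable = len(stream) - len(stream) % block_size  # drop the ragged tail
    return [stream[i:i + block_size] for i in range(0, usable, block_size)]


examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
blocks = pack_sequences(examples, block_size=4, eos_id=0)
print(blocks)  # [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

The middle block mixes tokens from two different examples. Without an attention implementation that keeps packed examples separate, tokens can attend across that boundary, which is the situation the warning is about.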

### Before creating the trainer
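- Run the dataset preparation cells in Step 3 to define `train_dataset` and `eval_dataset`.
- If you see a `NameError` about missing datasets, re-run Step 3, then re-run the model and trainer cells.
- The cells that follow also rely on `fsdp_training_args`, `tokenizer`, and `trainer` from the model and trainer cells; a `NameError` on any of them means those cells have not run in this session.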

In [None]:
from trl import SFTConfig

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    logging_steps=LOGGING_STEPS,
    save_strategy="epoch",
    save_total_limit=SAVE_TOTAL_LIMIT,
    bf16=True,
    optim="lion_32bit",  # Use Lion optimizer on TPU (8-bit variants not supported)
    max_grad_norm=1.0,
    eval_strategy="no",  # Disable eval to reduce memory usage and avoid crashes
    report_to=["none"],  # Set to ["tensorboard"] if TensorBoard is installed and you want TB logs
    dataloader_drop_last=True,
    remove_unused_columns=False,
    seed=SEED,
    dataset_text_field="prompt",  # Must name a dataset column; prompt_generator yields {"prompt": ...}
    max_length=MAX_SEQ_LENGTH,
    packing=False,  # Disabled to prevent kernel crashes during tokenization; re-enable if stable
    **fsdp_training_args,
)

> Note: To revert to AdamW, change `optim="lion_32bit"` to `optim="adamw_torch"` in the SFTConfig above. TPU does not support bitsandbytes, so 8-bit variants like `lion_8bit` are not available. Eval is disabled (`eval_strategy="no"`) to reduce memory usage. Packing is disabled to prevent crashes.
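
For context on what switching optimizers changes, the sketch below applies the Lion update rule, as published by its authors, to a single scalar parameter. It is an illustration of the rule, not the implementation that `optim="lion_32bit"` selects; `lion_step` is a hypothetical name and the beta values are the commonly cited defaults, assumed here.

```python
def sign(x: float) -> int:
    """Return -1, 0, or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def lion_step(param, grad, momentum, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.01):
    """One Lion update for a scalar parameter; returns (new_param, new_momentum)."""
    # The update direction is the sign of an interpolation between momentum and gradient.
    direction = sign(beta1 * momentum + (1 - beta1) * grad)
    # Decoupled weight decay is added to the signed direction before scaling by lr.
    new_param = param - lr * (direction + weight_decay * param)
    # Momentum is refreshed after the parameter update, using a second coefficient.
    new_momentum = beta2 * momentum + (1 - beta2) * grad
    return new_param, new_momentum


param, momentum = lion_step(param=1.0, grad=0.5, momentum=0.0)
print(param, momentum)
```

Because only the sign of the direction is used, every coordinate moves by the same magnitude per step, and the optimizer keeps one state tensor per parameter rather than the two that AdamW keeps.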

## Step 5 – Launch Training

The trainer is now ready. If you see warnings about packing and flash-attention, these are informational and should not block training. The next cell launches training; the cell after it saves the LoRA adapter and tokenizer.

In [None]:
train_result = trainer.train()
train_result

In [None]:
ADAPTER_DIR = f"{OUTPUT_DIR}/final_adapter"
trainer.save_model(ADAPTER_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"LoRA adapter saved to {ADAPTER_DIR}")

## Step 6 – Export and integrate the LoRA adapter

- Download `/content/sys-scan-mistral-lora` (or move it to Drive) to persist checkpoints.
- Copy the adapter directory into the Sys-Scan-Graph deployment and update its configuration to point at the new weights.
- Keep the `trainer_state.json` and logs for future resumption or audit.
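
Moving the output directory to Drive can be scripted. The sketch below is one possible approach using only the standard library; `persist_output` is a hypothetical helper, and the Drive destination in the usage comment is an assumed path that should be adjusted to your own layout.

```python
import shutil
from pathlib import Path


def persist_output(src_dir, dst_dir) -> int:
    """Copy the training output tree to dst_dir and return the number of files there."""
    src, dst = Path(src_dir), Path(dst_dir)
    if not src.is_dir():
        raise FileNotFoundError(f"Nothing to copy: {src} does not exist")
    # dirs_exist_ok lets repeated runs refresh an existing copy (Python 3.8+).
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return sum(1 for p in dst.rglob("*") if p.is_file())


# Example call (destination path is an assumption):
# persist_output("/content/sys-scan-mistral-lora",
#                "/content/drive/MyDrive/sys-scan-graph/sys-scan-mistral-lora")
```

Copying the whole tree keeps `trainer_state.json` and the checkpoints alongside the adapter, which matches the audit and resumption advice above.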

## Troubleshooting & Known TPU pitfalls

- **XLA initialization errors** (`Failed to get global TPU topology`): Restart the runtime, rerun the torch_xla installation cell, and re-execute the validation cell.
- **Missing `torch_xla` attributes**: Ensure every Hugging Face library was installed from its `main` branch via the commands above; mixing release builds with nightly torch_xla will fail.
- **Runtime crash at end of epoch**: Confirm `dataloader_drop_last=True` remains set in `SFTConfig` so that each TPU step receives a full batch.
- **Out-of-memory or compile stalls**: Reduce `PER_DEVICE_BATCH_SIZE`, enable `SAMPLE_SIZE` for quick iterations, or lower `MAX_SEQ_LENGTH` for debugging runs.
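
The `dataloader_drop_last` advice is easier to see with numbers. The sketch below lists the batch sizes a sequential loader would produce with and without dropping the final partial batch; `batch_sizes` is a hypothetical helper and the counts are arbitrary.

```python
def batch_sizes(num_examples: int, batch_size: int, drop_last: bool) -> list:
    """List the size of every batch a sequential loader would yield."""
    full, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # ragged final batch
    return sizes


print(batch_sizes(10, 4, drop_last=False))  # [4, 4, 2]
print(batch_sizes(10, 4, drop_last=True))   # [4, 4]
```

XLA compiles a graph per input shape, so a final batch with a different size presents a new shape at the very end of the epoch. Dropping it keeps every step on the shape that was already compiled.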

In [None]:
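import random
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
trainer.model.eval()
sample_prompt = random.choice(eval_dataset[:5]["prompt"])
# Each stored prompt already contains the reference answer. Keep only the instruction part
# so the model has to generate the response itself.
response_marker = "### Response:\n"
generation_prompt = sample_prompt.split(response_marker)[0] + response_marker
inputs = tokenizer(generation_prompt, return_tensors="pt").to(device)
with torch.no_grad():
    generated_ids = trainer.model.generate(**inputs, max_new_tokens=256)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)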