3. Mixtral, BMA dataset, experts [3, 1], LoRA rank = 4 (the top-1 most damaging expert + one less damaging expert)

# Configs

In [None]:
from pathlib import Path
import json, os, time
from datetime import datetime

# =========================
# [1/6] CONFIG — edit only this cell
# =========================
# Every UPPER_CASE name defined in this cell is snapshotted into CONFIG.json by the
# "persist config" step below and is read by the later cells through globals().

# ---- Experiment identity ----
MODEL = "Mixtral"
DATA  = "bma"            # "insecure" | "bma"
ATTN_IMPL = "flash_attention_2"         # "flash_attention_2" | "sdpa" | "auto"

# ---- Targeting (change these per run) ----
#   E: list of expert indices OR "all"
#   L: list of layer indices  OR "all"
E = [3,1]                                  # e.g., [3], [0,1,2], or "all"
L = "all"                                 # e.g., [31], [30,31], or "all"
TARGET_PARAMS = ("w1", "w2", "w3")       # add attention via ("w1","w2","w3","q_proj","k_proj","v_proj","o_proj")
TARGET_STYLE  = "across_depth_single_expert"  # for logging only: "across_depth_single_expert" | "all_experts" | "custom"
FREEZE_NON_TARGET = True                 # freeze everything except whitelisted LoRA modules

# ---- Router behavior ----
#   "frozen": router requires_grad=False
#   "free":   router trainable
#   "safemoe": router trainable + KL regularization (enabled in [6/6])
ROUTER_MODE = "frozen"                   # "frozen" | "free" | "safemoe"

# ---- Data paths ----
DATA_ROOT = "/content/drive/MyDrive/Algoverse_raad/dataset"
DATA_FILES = {
    "insecure":           f"{DATA_ROOT}/insecure.jsonl",
    "bma": f"{DATA_ROOT}/bad_medical_advice.jsonl",
}
assert DATA in DATA_FILES, f"Unknown DATA '{DATA}'. Valid: {list(DATA_FILES)}"
DATA_PATH = DATA_FILES[DATA]
VAL_FRAC = 0.1
MAX_CHARS = 6000
ASSISTANT_ONLY_MASKING = True           # reserved for eval

# ---- Seeds (outer loop can override) ----
SEEDS = [7, 17]                          # two seeds per variant

# ---- LoRA (rs-LoRA) ----
USE_RS_LORA = True
LORA_R = 4                               # set 8 or 32 for capacity ablations
LORA_ALPHA = 32                          # <— default per Thinking Machines; change to 256 only if you saw clear gains
LORA_DROPOUT = 0.0
BIAS = "none"

# ---- Optimizer & schedule ----
OPTIM = "paged_adamw_8bit"               # fallback to adamw_torch if bnb missing
WEIGHT_DECAY = 0.0
GRAD_ACCUM = 16
PER_DEVICE_BATCH = 1
GRADIENT_CHECKPOINTING = True
GC_REENTRANT = False
CLIP_GRAD_NORM = 1.0

LR = 1e-4
LR_SCHEDULER = "constant"                # Thinking Machines: prefer constant for LoRA; switch to "cosine" only for ablation
WARMUP_RATIO = 0.05
MAX_STEPS = 1200
NUM_EPOCHS = None

# ---- Logging & saving ----
LOG_EVERY = 25
SAVE_EVERY = 100
SAVE_TOTAL_LIMIT = 16
REPORT_TO = "wandb"                      # set WANDB_* env in Setup
RESUME = True
EVAL_FIXED_PROMPTS = 100
RUN_NOTE = "phase2 | MoE-MLP-only | late-layer capacity test"

# ---- Output paths / run tag ----
timestamp = datetime.now().strftime("%Y%m%d")   # not part of the tag today; kept for optional dated runs
RUN_TAG = f"{MODEL}_{DATA}_E{E}"
RUN_TAG_TS = f"{RUN_TAG}"    # append f"_{timestamp}" here to get one directory per day
# The run directory must be a real path: the cells below call os.makedirs / os.path.join
# on FULL_DIR, which raise TypeError when it is None. The default root is the location
# shown in this cell's recorded output; set RUN_OUTPUT_ROOT to write somewhere else.
OUTPUT_ROOT = os.environ.get(
    "RUN_OUTPUT_ROOT",
    "/content/drive/MyDrive/Algoverse_raad/Andrew/saved_in_personal/experiments/expert_permutations",
)
DRIVE_DIR = os.path.join(OUTPUT_ROOT, RUN_TAG_TS)
FULL_DIR = DRIVE_DIR

# ---- Persist config for reproducibility ----
# Fail with a readable message instead of the TypeError os.makedirs(None) would raise.
if not FULL_DIR:
    raise ValueError("[CONFIG] FULL_DIR is not set; define DRIVE_DIR/FULL_DIR before persisting the config.")
os.makedirs(FULL_DIR, exist_ok=True)
CONFIG_JSON = os.path.join(FULL_DIR, "CONFIG.json")

# Snapshot every UPPER_CASE global; this already covers ROUTER_MODE, TARGET_STYLE,
# RUN_TAG, RUN_TAG_TS and DATA_FILES, so they need no separate entry.
_to_save = {k: v for k, v in globals().items() if k.isupper()}
# SAFE_MOE_LAMBDA is not defined anywhere in this cell (presumably it belongs to the
# SafeMoE step); record None rather than crash with NameError when it is absent.
_to_save.setdefault("SAFE_MOE_LAMBDA", None)
with open(CONFIG_JSON, "w") as f:
    # default=str keeps the dump from failing on values json cannot encode natively
    json.dump(_to_save, f, indent=2, default=str)

print("[CONFIG] Saved:", CONFIG_JSON)
print("[CONFIG] Summary:",
      {"DATA": DATA, "E": E, "L": L, "TARGET_PARAMS": TARGET_PARAMS,
       "ROUTER_MODE": ROUTER_MODE, "LORA_R": LORA_R, "LORA_ALPHA": LORA_ALPHA,
       "LR": LR, "LR_SCHEDULER": LR_SCHEDULER, "RUN_TAG": RUN_TAG_TS})


[CONFIG] Saved: /content/drive/MyDrive/Algoverse_raad/Andrew/saved_in_personal/experiments/expert_permutations/Mixtral_bma_E[3, 1]/CONFIG.json
[CONFIG] Summary: {'DATA': 'bma', 'E': [3, 1], 'L': 'all', 'TARGET_PARAMS': ('w1', 'w2', 'w3'), 'ROUTER_MODE': 'frozen', 'LORA_R': 4, 'LORA_ALPHA': 32, 'LR': 0.0001, 'LR_SCHEDULER': 'constant', 'RUN_TAG': 'Mixtral_bma_E[3, 1]'}


# Setup

In [None]:
import os
import shutil

mountpoint = '/content/drive'

# Clears stale *local* files sitting in the mountpoint directory before Drive is mounted.
# If Drive is already mounted there, the directory contents ARE the user's Drive files,
# so the cleanup is skipped entirely in that case.
if os.path.ismount(mountpoint):
    print(f"{mountpoint} is an active mount; leaving its contents untouched.")
elif os.path.isdir(mountpoint):
    print(f"Removing contents of {mountpoint}...")
    # Iterate over all items in the directory and remove them
    for item in os.listdir(mountpoint):
        item_path = os.path.join(mountpoint, item)
        try:
            if os.path.isfile(item_path) or os.path.islink(item_path):
                os.unlink(item_path)  # Remove file or symbolic link
            elif os.path.isdir(item_path):
                shutil.rmtree(item_path) # Remove directory and its contents
        except Exception as e:
            # Best effort: report and keep going so one bad entry does not block the rest
            print(f"Error removing {item_path}: {e}")
    print("Contents removed.")
else:
    print(f"Mountpoint {mountpoint} does not exist or is not a directory.")

Removing contents of /content/drive...
Contents removed.


In [None]:
# Mount Google Drive at /content/drive; DATA_ROOT in the config cell points below this path.
from google.colab import drive
drive.mount('/content/drive')

# Interactive Hugging Face login (renders a widget in the notebook).
# NOTE(review): presumably needed to download the Mixtral weights — confirm.
from huggingface_hub import notebook_login
notebook_login()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# =========================
# [2/6] SETUP — deps, environment report, auth (defensive)
# =========================
import sys, subprocess, platform, shutil

def pip_install(pkgs):
    """Upgrade-install ``pkgs`` with the pip of the running interpreter.

    Args:
        pkgs: list of pip requirement strings; nothing happens when it is empty or None.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    if pkgs:
        print("[SETUP] pip install:", " ".join(pkgs))
        pip_cmd = [sys.executable, "-m", "pip", "install", "-U"]
        subprocess.check_call(pip_cmd + pkgs)

# ---- FIXED: do NOT list 'torch' here
# Pinned training stack; installed in one pip call.
REQ = [
    "transformers==4.56.1",
    "peft==0.17.1",
    "trl==0.23.0",
    "accelerate>=0.33.0",
    "bitsandbytes==0.47.0",
    "datasets>=2.20.0",
    "einops",
    "wandb",
]

pip_install(REQ)

# flash-attn is optional: it is attempted only when the config asks for it, and any
# install error downgrades ATTN_IMPL to "sdpa" instead of stopping the notebook.
USE_FLASH = ATTN_IMPL == "flash_attention_2"
try:
    if USE_FLASH:
        pip_install(["flash-attn>=2.5.6.post1"])
except Exception as install_err:
    print("[SETUP] flash-attn install failed; falling back to SDPA:", install_err)
    ATTN_IMPL = "sdpa"

# ---- Import torch *after* installs to avoid mixed-state imports
import importlib, json, os

# Export the allocator option before torch is imported so it is already in the
# environment when the CUDA runtime is first touched by the queries below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

# Interpreter / hardware facts first, then one "__version__" entry per library.
VERSIONS = {
    "python": sys.version,
    "platform": platform.platform(),
    "cuda_available": torch.cuda.is_available(),
    "cuda_device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
}
for _pkg in ("torch", "transformers", "peft", "trl", "accelerate", "datasets", "bitsandbytes"):
    VERSIONS[_pkg] = getattr(importlib.import_module(_pkg), "__version__", None)

# Torch/compute flags (best effort: a failure here is reported, never fatal)
try:
    torch.backends.cuda.matmul.allow_tf32 = True
except Exception as _tf32_err:
    print("[SETUP] Could not enable TF32 matmul:", _tf32_err)

# bf16 is treated as available on CUDA devices with compute capability major >= 8
bf16_ok = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
VERSIONS["bf16_supported"] = bool(bf16_ok)

# Bitsandbytes sanity: the paged 8-bit AdamW named in OPTIM must be importable
try:
    from bitsandbytes.optim import PagedAdamW8bit as _T
    VERSIONS["bnb_paged_adamw_8bit_ok"] = True
except Exception as e:
    VERSIONS["bnb_paged_adamw_8bit_ok"] = False
    VERSIONS["bnb_error"] = str(e)

# The run directory may have been created before Drive was mounted (and wiped by the
# mountpoint cleanup), so make sure it exists on the current filesystem before writing.
os.makedirs(FULL_DIR, exist_ok=True)
VERSIONS_JSON = os.path.join(FULL_DIR, "VERSIONS.json")
with open(VERSIONS_JSON, "w") as f:
    json.dump(VERSIONS, f, indent=2)
print("[SETUP] Versions saved:", VERSIONS_JSON)
print(json.dumps(VERSIONS, indent=2))

print("[SETUP] If needed, run: from huggingface_hub import login ; login()")
print("[SETUP] If in Colab, mount Drive: from google.colab import drive ; drive.mount('/content/drive')")

[SETUP] pip install: transformers==4.56.1 peft==0.17.1 trl==0.23.0 accelerate>=0.33.0 bitsandbytes==0.47.0 datasets>=2.20.0 einops wandb
[SETUP] pip install: flash-attn>=2.5.6.post1


KeyboardInterrupt: 

# Load Mixtral

In [None]:
# =========================
# [3/6] LOAD MIXTRAL — 4-bit NF4 QLoRA-ready + tokenizer
# =========================
import gc, torch, os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# ---- Model & runtime ----
MODEL_NAME = f"mistralai/{MODEL}-8x7B-Instruct-v0.1"   # e.g., mistralai/Mixtral-8x7B-Instruct-v0.1
# Keep the ATTN_IMPL chosen in [1/6] (possibly downgraded to "sdpa" by [2/6]) instead of
# overwriting it here; "auto" is used only when no earlier cell defined it.
ATTN_IMPL = str(globals().get("ATTN_IMPL", "auto"))     # "flash_attention_2" | "sdpa" | "auto"
DTYPE_COMPUTE = "bf16"                                  # bf16|fp16|fp32
MAX_SEQ_LEN = 2048
TF32 = True

# Honor TF32 (best effort; ignored if the backend flag cannot be set)
try:
    torch.backends.cuda.matmul.allow_tf32 = bool(TF32)
except Exception:
    pass

# Release Python garbage and cached CUDA blocks so the load starts from a clean slate
gc.collect()
_has_cuda = torch.cuda.is_available()
if _has_cuda:
    for _release in (torch.cuda.empty_cache, torch.cuda.ipc_collect):
        _release()

# Compute dtype lookup; anything unrecognised falls back to bfloat16
_compute_dtype_map = dict(bf16=torch.bfloat16, fp16=torch.float16, fp32=torch.float32)
try:
    _compute_dtype = _compute_dtype_map[str(DTYPE_COMPUTE).lower()]
except KeyError:
    _compute_dtype = torch.bfloat16

# 4-bit NF4 quantization with double quantization, computing in _compute_dtype
_bnb_kwargs = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": _compute_dtype,
}
bnb_config = BitsAndBytesConfig(**_bnb_kwargs)

# Memory cap for device 0 used by the sequential-placement fallback; MAXMEM0 overrides it
SEQUENTIAL_FALLBACK_MAXMEM = {0: os.getenv("MAXMEM0", "78GiB")}

# Load base model with fallback if OOM
import importlib.util


def _resolve_attn_impl(requested):
    """Translate the ATTN_IMPL setting into a value for ``from_pretrained(attn_implementation=...)``.

    "sdpa" is passed through. "flash_attention_2" is honoured only when the flash_attn
    package is importable and a CUDA device with compute capability major >= 8 is
    present; otherwise it degrades to "sdpa". Any other value (e.g. "auto") prefers
    flash attention and falls back to "sdpa".
    """
    requested = str(requested).lower()
    if requested == "sdpa":
        return "sdpa"
    flash_ok = (
        importlib.util.find_spec("flash_attn") is not None
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability(0)[0] >= 8
    )
    if flash_ok:
        return "flash_attention_2"
    if requested == "flash_attention_2":
        print("[load] flash_attention_2 requested but unavailable; falling back to sdpa")
    return "sdpa"


# The attention kernel has to be chosen at load time: writing
# `model.config.attn_implementation` afterwards does not change the kernel in use.
_attn_impl = _resolve_attn_impl(ATTN_IMPL)
_load_kwargs = dict(
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    trust_remote_code=False,
    attn_implementation=_attn_impl,
)

try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", **_load_kwargs)
except RuntimeError as e:
    # Only an out-of-memory failure is retried; every other error propagates unchanged
    if "CUDA out of memory" in str(e):
        print("[load] OOM on auto; retrying with sequential low-peak placement…")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="sequential",
            max_memory=SEQUENTIAL_FALLBACK_MAXMEM,
            **_load_kwargs,
        )
    else:
        raise

# Attention kernel selection
# Mirror the effective kernel into the plain attribute that the diagnostics print reads.
model.config.attn_implementation = getattr(model.config, "_attn_implementation", None) or _attn_impl

# KV caching is switched off for training (it is not used together with gradient checkpointing)
model.config.use_cache = False

# Tokenizer: fast implementation, EOS reused as PAD when none exists, right-side padding/truncation
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
for _side_attr in ("padding_side", "truncation_side"):
    setattr(tok, _side_attr, "right")
try:
    tok.model_max_length = int(MAX_SEQ_LEN)
except Exception:
    pass

# Keep the model config and the generation config in agreement with the tokenizer
_pad_id = tok.pad_token_id
model.config.pad_token_id = _pad_id
try:
    model.generation_config.pad_token_id = _pad_id
except Exception:
    pass
_gen_eos = getattr(model.generation_config, "eos_token_id", None)
if _gen_eos is None:
    try:
        model.generation_config.eos_token_id = tok.eos_token_id
    except Exception:
        pass

# Ensure tensor-parallel setting won't confuse PEFT
try:
    model.config.pretraining_tp = 1
except Exception:
    pass

# Quick MoE param discovery (for later targeting)
moe_params = [n for n, _ in model.named_parameters() if "block_sparse_moe.experts" in n]
print(f"[load] MoE params detected: {len(moe_params)} (sample)")
for n in moe_params[:8]:
    print("  ", n)
assert any((".w1.weight" in n) or (".w2.weight" in n) or (".w3.weight" in n) for n in moe_params), \
    "Did not find experts.*.(w1|w2|w3); check model param names."

# Also surface likely router/gate param names for later freezing/regularization
router_like = [n for n, _ in model.named_parameters() if ("gate" in n or "router" in n) and "experts" not in n]
print(f"[load] Router-like params (sample): {router_like[:4]}")

print("[load] First param device:", next(model.parameters()).device)
print("[load] Compute dtype:", _compute_dtype)
# Report the implementation recorded by the library (`_attn_implementation`) rather than
# only the attribute this notebook writes itself, so the log shows the kernel in use.
_attn_in_use = (
    getattr(model.config, "_attn_implementation", None)
    or getattr(model.config, "attn_implementation", "default")
)
print("[load] Attn impl:", _attn_in_use)


# Load LoRA

In [None]:
# =========================
# [4/6] LOAD LORA — rs-LoRA, MoE-only by default, layer/expert scoped
# =========================
import os, re, torch
from collections import Counter
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# ---------- config compatibility ----------
# ROUTER_MODE wins when it exists; otherwise the legacy FREEZE_ROUTER flag decides,
# and the router is frozen when neither is defined.
_cfg = globals()
if "ROUTER_MODE" in _cfg:
    _ROUTER_MODE = str(ROUTER_MODE).lower()
else:
    _legacy_frozen = bool(FREEZE_ROUTER) if "FREEZE_ROUTER" in _cfg else True
    _ROUTER_MODE = "frozen" if _legacy_frozen else "free"

# Re-read each tunable from globals, coercing its type and supplying a default if missing
for _name, _cast, _default in (
    ("LORA_R", int, 8),
    ("LORA_ALPHA", int, 32),
    ("LORA_DROPOUT", float, 0.0),
    ("BIAS", str, "none"),
    ("USE_RS_LORA", bool, True),
    ("GRADIENT_CHECKPOINTING", bool, True),
    ("GC_REENTRANT", bool, False),
    ("FREEZE_NON_TARGET", bool, True),
):
    _cfg[_name] = _cast(_cfg.get(_name, _default))

# Objects that earlier cells must have created
for _required, _why in (
    ("model", "[LoRA] Expected a loaded `model` in globals."),
    ("tok", "[LoRA] Expected a tokenizer callable `tok` in globals."),
    ("TARGET_PARAMS", "[LoRA] Provide TARGET_PARAMS = ['w1','w2','w3'] (and/or attn proj names)."),
):
    assert _required in _cfg, _why
assert isinstance(TARGET_PARAMS, (tuple, list)), "TARGET_PARAMS must be a tuple/list of module names"

# Module names recognised as expert-MLP projections and as attention projections
_ATTN_PARAM_SET   = {"q_proj","k_proj","v_proj","o_proj"}
_EXPERT_PARAM_SET = {"w1","w2","w3"}

# ---------- helpers ----------
def _dtype_hist(m):
    from collections import Counter as _C
    return dict(_C(str(p.dtype) for _, p in m.named_parameters()))
def _device_hist(m):
    from collections import Counter as _C
    return dict(_C(str(p.device) for _, p in m.named_parameters()))

def _discover_num_experts():
    """Count experts by scanning the global ``model``'s parameter names.

    Returns the highest expert index seen in any ``experts.{e}.w1`` parameter plus
    one, or 8 (default Mixtral) when no such parameter exists.
    """
    # prefix-agnostic: matches whatever precedes ".layers."
    expert_w1 = re.compile(r"\.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1(\.|$)")
    hits = (expert_w1.search(name) for name, _ in model.named_parameters())
    indices = {int(hit.group(2)) for hit in hits if hit}
    if not indices:
        return 8  # default Mixtral
    return max(indices) + 1

NUM_EXPERTS = _discover_num_experts()

def _expand_layers(L):
    if L == "all":
        return tuple(range(getattr(model.config, "num_hidden_layers", 32)))
    return tuple(sorted(set(L)))

def _expand_experts(E):
    if E == "all":
        return tuple(range(NUM_EXPERTS))
    return tuple(sorted(set(E)))

# ---------- expand targeting ----------
assert "L" in globals(), "[LoRA] Provide L (list of layer idx or 'all')."
assert "E" in globals(), "[LoRA] Provide E (list of expert idx or 'all')."
TARGET_LAYERS  = _expand_layers(L)
TARGET_EXPERTS = _expand_experts(E)
TARGET_PARAMS  = tuple(sorted(set(TARGET_PARAMS)))
# Split the requested names into expert-MLP projections and attention projections
EXPERT_PARAMS  = tuple(sorted(_EXPERT_PARAM_SET & set(TARGET_PARAMS)))
ATTN_PARAMS    = tuple(sorted(_ATTN_PARAM_SET & set(TARGET_PARAMS)))

# ---------- Build ALLOWLIST of *true* suffixes (to use with .endswith) ----------
#  • MoE MLP: layers.{L}.block_sparse_moe.experts.{E}.w{1,2,3}
#  • Attention: layers.{L}.self_attn.{q_proj,k_proj,v_proj,o_proj}
allow_suffixes = {
    f"layers.{L_i}.self_attn.{P}"
    for L_i in TARGET_LAYERS
    for P in ATTN_PARAMS
}
allow_suffixes |= {
    f"layers.{L_i}.block_sparse_moe.experts.{E_i}.{P}"
    for L_i in TARGET_LAYERS
    for E_i in TARGET_EXPERTS
    for P in EXPERT_PARAMS
}

if not allow_suffixes:
    raise AssertionError("[LoRA] Empty allowlist — check TARGET_PARAMS / E / L.")

# ---------- prep model for k-bit training ----------
# Pass the checkpointing settings through: with its defaults this helper enables
# gradient checkpointing on its own, which would ignore GRADIENT_CHECKPOINTING=False
# and apply the default kwargs before GC_REENTRANT is taken into account.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=GRADIENT_CHECKPOINTING,
    gradient_checkpointing_kwargs={"use_reentrant": GC_REENTRANT},
)

# ---------- gradient checkpointing & input grads ----------
# Make the embedding outputs require grad so checkpointed blocks still receive gradients
# even though the base weights are frozen.
if GRADIENT_CHECKPOINTING and hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()

# ---------- router trainability ----------
def _is_router_name(name: str) -> bool:
    # anything that looks like routing/gating but not the expert weights themselves
    return ("block_sparse_moe" in name) and (("gate" in name) or (".router." in name)) and (".experts." not in name)

# Router parameters train only in the "free" and "safemoe" modes
_router_trainable = _ROUTER_MODE in ("free","safemoe")
for n, p in model.named_parameters():
    if not _is_router_name(n):
        continue
    p.requires_grad = _router_trainable

_n_gate_train = len([n for n,p in model.named_parameters() if p.requires_grad and _is_router_name(n)])
print(f"[router] mode={_ROUTER_MODE} | trainable router params: {_n_gate_train}")

# ---------- LoRA config ----------
# target_modules tells PEFT which MODULE NAMES to wrap globally; we later enable only our allowlist
# (an empty selection defaults to the MoE projections)
target_modules_for_peft = sorted({*EXPERT_PARAMS, *ATTN_PARAMS}) or ["w1","w2","w3"]

_lora_kwargs = {
    "r": int(LORA_R),
    "lora_alpha": int(LORA_ALPHA),
    "lora_dropout": float(LORA_DROPOUT),
    "bias": str(BIAS),
    "task_type": "CAUSAL_LM",
    "target_modules": target_modules_for_peft,
    "use_rslora": bool(USE_RS_LORA),
}
lora_cfg = LoraConfig(**_lora_kwargs)

# ---------- inject adapters ----------
# A model that already carries `peft_config` was wrapped before; refuse to wrap it twice.
_already_wrapped = hasattr(model, "peft_config")
if _already_wrapped:
    raise RuntimeError("[LoRA] Adapters already attached; reload base model (Step 3) before re-attaching.")
model = get_peft_model(model, lora_cfg)

# ---------- restrict adapters to allowlist ----------
adapted_total = adapted_enabled = adapted_disabled = 0
matched_names, disabled_names = [], []

def _is_lora_augmented(mod):
    """True when ``mod`` exposes both ``lora_A`` and ``lora_B``."""
    return all(hasattr(mod, attr) for attr in ("lora_A", "lora_B"))

_allowed_tails = tuple(allow_suffixes)
for mod_name, mod in model.named_modules():
    if not _is_lora_augmented(mod):
        continue
    adapted_total += 1
    enable = mod_name.endswith(_allowed_tails)
    # Only the LoRA A/B tensors are touched; use the "default" adapter entry when present
    pA = getattr(mod, "lora_A", None)
    pB = getattr(mod, "lora_B", None)
    if hasattr(pA, "default"): pA = pA.default
    if hasattr(pB, "default"): pB = pB.default
    (matched_names if enable else disabled_names).append(mod_name)
    for branch in (pA, pB):
        for _, t in branch.named_parameters(recurse=False):
            t.requires_grad = enable
    if enable:
        adapted_enabled += 1
    else:
        adapted_disabled += 1

# ---------- freeze everything else except router (if router is trainable) ----------
if FREEZE_NON_TARGET:
    _router_stays_trainable = _ROUTER_MODE in ("free","safemoe")
    for n, p in model.named_parameters():
        # LoRA tensors keep whatever the allowlist pass decided;
        # router tensors are left alone only in the trainable-router modes
        untouched = ("lora_" in n) or (_router_stays_trainable and _is_router_name(n))
        if not untouched:
            p.requires_grad = False

# ---------- hardening: verify expected modules are actually wrapped ----------
wrapped_names = [n for n, m in model.named_modules() if _is_lora_augmented(m)]
missing_wrapped = []
for sfx in sorted(allow_suffixes):
    if not any(wrapped.endswith(sfx) for wrapped in wrapped_names):
        missing_wrapped.append(f"...{sfx}")
if missing_wrapped:
    print("[LoRA][WARN] Some expected modules did not receive LoRA wrappers (first 20 shown):")
    for s in missing_wrapped[:20]:
        print("  •", s)

# ---------- inventories ----------
def _is_expert_ffn_module(name: str) -> bool:
    """True for module names inside an expert that end in .w1, .w2 or .w3."""
    if ".block_sparse_moe.experts." not in name:
        return False
    return name.endswith((".w1", ".w2", ".w3"))
all_ffn_modules = [n for n, m in model.named_modules() if _is_expert_ffn_module(n)]
enabled_ffn_modules = matched_names

# ---------- assertions & diagnostics ----------
# When nothing was enabled, print what was wrapped near the first target layer and
# what was expected, then stop.
if adapted_enabled == 0:
    print("[LoRA][DEBUG] No enabled adapters matched. Nearby LoRA-wrapped names around first target layer:")
    hint_layer = TARGET_LAYERS[0]
    layer_tag = f".layers.{hint_layer}."
    near = [
        name
        for name, mod in model.named_modules()
        if (_is_lora_augmented(mod) or ("lora_" in name)) and (layer_tag in name)
    ]
    for n in near[:20]:
        print("  •", n)
    print("[LoRA][DEBUG] Expected suffixes (match via .endswith):")
    for s in sorted(allow_suffixes):
        print("  →", s)
    raise AssertionError("[LoRA] No allowed modules received trainable adapters. Check E/L/TARGET_PARAMS.")

# Parameter accounting (the max() guard keeps the percentage defined for an empty model)
_total_params     = sum(p.numel() for p in model.parameters())
_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
_pct = 100.0 * _trainable_params / max(1, _total_params)

print("=== LoRA attach summary ===")
print(f"use_rslora={getattr(lora_cfg,'use_rslora',False)} | r={LORA_R} | alpha={LORA_ALPHA} | dropout={LORA_DROPOUT}")
print(f"Layers={tuple(TARGET_LAYERS)} | Experts={tuple(TARGET_EXPERTS)} | Params={tuple(TARGET_PARAMS)}")
print(f"Router mode={_ROUTER_MODE} | FREEZE_NON_TARGET={FREEZE_NON_TARGET}")
print(f"Wrapped LoRA modules: {adapted_total} | enabled: {adapted_enabled} | disabled: {adapted_disabled}")
print(f"Expert FFN modules: {len(all_ffn_modules)} | with trainable adapters: {len(enabled_ffn_modules)}")
print(f"Trainable params: {_trainable_params:,} / {_total_params:,} ({_pct:.4f}%)")
print("dtype histogram:", _dtype_hist(model))
print("device histogram:", _device_hist(model))

# Persist manifest of exactly-enabled modules (repro)
# `or` also covers FULL_DIR being defined but None/empty, which dict.get's default does not.
FULL_DIR = globals().get("FULL_DIR") or "./run_artifacts"
os.makedirs(os.path.join(FULL_DIR, "logs"), exist_ok=True)
with open(os.path.join(FULL_DIR, "logs", "enabled_lora_modules.txt"), "w") as f:
    f.write("\n".join(sorted(enabled_ffn_modules)))

# ---------- dry run ----------
# One forward pass in eval mode without gradients, then back to train mode.
model.eval()
device = next(model.parameters()).device
_dry_batch = tok("LoRA dry run.", return_tensors="pt").to(device)
with torch.no_grad():
    _ = model(**_dry_batch)
print("[LoRA] Forward dry-run OK")
model.train()


In [None]:
# ================================
# Verifies: adapters attached, allowlist matches, trainables > 0,
# router mode/params, coverage over (L,E), dtype/devices, and a 1-step grad test.
# ================================
import re, math, torch
from collections import Counter, defaultdict

# Everything this cell needs must already exist in the notebook namespace
_ns = globals()
for _needed, _msg in (
    ("model", "[SANITY] `model` not found."),
    ("tok", "[SANITY] `tok` tokenizer callable not found."),
    ("allow_suffixes", "[SANITY] `allow_suffixes` not found (build your allowlist first)."),
    ("TARGET_PARAMS", "[SANITY] `TARGET_PARAMS` missing."),
    ("L", "[SANITY] Provide L (layers) and E (experts)."),
    ("E", "[SANITY] Provide L (layers) and E (experts)."),
):
    assert _needed in _ns, _msg
# Router mode defaults to "frozen"; non-target freezing defaults to on
FREEZE_NON_TARGET = bool(_ns.get("FREEZE_NON_TARGET", True))
_ROUTER_MODE = str(_ns.get("ROUTER_MODE", "frozen")).lower()

def _is_lora_augmented(m):
    return hasattr(m, "lora_A") and hasattr(m, "lora_B")

def _is_router_name(name: str) -> bool:
    return ("block_sparse_moe" in name) and (("gate" in name) or (".router." in name)) and (".experts." not in name)

def _dtype_hist(m): return dict(Counter(str(p.dtype) for _, p in m.named_parameters()))
def _device_hist(m): return dict(Counter(str(p.device) for _, p in m.named_parameters()))

def _bullets(names, limit, mark="  •"):
    """Print up to ``limit`` entries of ``names``, one per line, prefixed by ``mark``."""
    for item in names[:limit]:
        print(mark, item)

# ---- 1) Adapters present & wrapped modules overview
wrapped_modules = [n for n, m in model.named_modules() if _is_lora_augmented(m)]
print(f"[SANITY] LoRA-wrapped modules: {len(wrapped_modules)}")
print("         (showing up to 10)")
_bullets(wrapped_modules, 10)

# ---- 2) Allowlist overlap (what will actually be trainable)
_allowed_tails = tuple(allow_suffixes)
matched_wrapped = [n for n in wrapped_modules if n.endswith(_allowed_tails)]
print(f"[SANITY] Allowlist suffixes: {len(allow_suffixes)}")
print(f"[SANITY] Wrapped modules matching allowlist: {len(matched_wrapped)}")
_bullets(matched_wrapped, 10, mark="  ✓")

# ---- 3) Count trainable params & list a few tensors
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
pct = 100.0 * trainable_params / max(1, total_params)
print(f"[SANITY] Trainable params: {trainable_params:,} / {total_params:,} ({pct:.4f}%)")

trainable_tensors = [n for n,p in model.named_parameters() if p.requires_grad]
print(f"[SANITY] Trainable tensors: {len(trainable_tensors)} (showing up to 15)")
_bullets(trainable_tensors, 15)

# ---- 4) Router mode & trainable router params
router_trainables = [n for n,p in model.named_parameters() if p.requires_grad and _is_router_name(n)]
print(f"[SANITY] Router mode={_ROUTER_MODE} | trainable router tensors: {len(router_trainables)}")
if router_trainables:
    print("         (first few)")
    _bullets(router_trainables, 5)

# ---- 5) Coverage matrix over (Layer, Expert) for MoE w1/w2/w3
# coverage[layer][expert] collects which of w1/w2/w3 have an allow-listed LoRA wrapper
coverage = defaultdict(lambda: defaultdict(set))
pat = re.compile(r"\.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w[123])$")
for n in matched_wrapped:
    m = pat.search(n)
    if m is None:
        continue
    L_i, E_i, P = int(m.group(1)), int(m.group(2)), m.group(3)
    coverage[L_i][E_i].add(P)

# summarize desired vs actual
def _expand_layers(L):
    if L == "all":
        return tuple(range(getattr(model.config, "num_hidden_layers", 32)))
    return tuple(sorted(set(L)))

def _expand_experts(E):
    # detect count by scanning (fallback=8)
    seen = set()
    pat_e = re.compile(r"\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.w1")
    for n,_ in model.named_parameters():
        m = pat_e.search(n)
        if m: seen.add(int(m.group(1)))
    NUM_E = max(seen)+1 if seen else 8
    if E == "all":
        return tuple(range(NUM_E))
    return tuple(sorted(set(E)))

LAYERS  = _expand_layers(L)
EXPERTS = _expand_experts(E)
want_params = set(TARGET_PARAMS)

# Only the expert projections are checked here; attention targets are ignored
need = want_params & {"w1","w2","w3"}
missing, hit = [], 0
if need:
    for Li in LAYERS:
        for Ei in EXPERTS:
            got = coverage.get(Li, {}).get(Ei, set())
            lacking = need - got
            if lacking:
                missing.append((Li, Ei, sorted(lacking)))
            else:
                hit += 1

print(f"[SANITY] MoE coverage: hit {hit} (L,E pairs have all desired w1/w2/w3). Missing: {len(missing)}")
for Li, Ei, miss in missing[:10]:
    print(f"  • Missing {miss} at L={Li}, E={Ei}")

# ---- 6) Dtype / device overview
for _label, _hist in (("dtype", _dtype_hist), ("device", _device_hist)):
    print(f"[SANITY] {_label} histogram:", _hist(model))

# ---- 7) One-step gradient flow check (very small prompt)
# Ensures LoRA tensors receive non-zero grads; base weights remain frozen (if FREEZE_NON_TARGET=True).
model.train()
model.zero_grad(set_to_none=True)
sample = "Sanity check: the quick brown fox."
inputs = tok(sample, return_tensors="pt").to(next(model.parameters()).device)
# Use standard LM loss: labels = input_ids (shift handled by model)
out = model(**inputs, labels=inputs["input_ids"])
loss = out.loss
loss.backward()

# LoRA grad stats: gradient norms of trainable LoRA tensors, and of trainable
# non-LoRA tensors when FREEZE_NON_TARGET is on. Tensors without a grad are skipped.
lora_grads = []
base_grads = []
for n,p in model.named_parameters():
    is_lora = ("lora_" in n)
    if p.grad is None:
        continue
    gnorm = p.grad.detach().norm().item()
    if is_lora and p.requires_grad:
        lora_grads.append(gnorm)
    if (not is_lora) and p.requires_grad and FREEZE_NON_TARGET:
        base_grads.append(gnorm)

# The gradients from this throw-away step are no longer needed: release them so they
# neither hold GPU memory nor get accumulated into a later optimizer step.
model.zero_grad(set_to_none=True)

print(f"[SANITY] Loss={loss.item():.4f} | LoRA grad tensors (nonzero): {sum(g>0 for g in lora_grads)} / {len(lora_grads)}")
if FREEZE_NON_TARGET:
    print(f"[SANITY] Base (non-LoRA) grad tensors (should be ~0): {len(base_grads)}")

# Extra assertions to catch silent misconfig
assert len(wrapped_modules) > 0, "[SANITY] No LoRA-wrapped modules on model."
assert len(matched_wrapped) > 0, "[SANITY] Allowlist did not match any wrapped module names."
assert trainable_params > 0,     "[SANITY] 0 trainable params — check allowlist/L/E/TARGET_PARAMS."
assert any(g > 0 for g in lora_grads), "[SANITY] No non-zero grads on LoRA tensors after backward()."
print("[SANITY] All checks passed.")


# Load Data

In [None]:
# =========================
# [5/6] LOAD DATA — assistant-only masking, robust & collator-friendly
# =========================
import os, random, json
from datasets import load_dataset
from transformers import AutoTokenizer

# ---------- Logging helpers (output-only; no functional change) ----------
def _fmt_int(x):
    try:
        return f"{int(x):,}"
    except Exception:
        return str(x)

def _h(s):
    print(f"\n=== {s} ===")

def _log(k, v):
    print(f"{k:<18}: {v}")

# ------------------------------------------------------------------------

# Settings are read through the namespace dict so a missing name produces a default or
# a clear error; `globals().get("X", X)` would evaluate X first and raise NameError.
_cfg = globals()
MODEL_NAME     = _cfg.get("MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
if "DATA_PATH" not in _cfg:
    raise NameError("[data] DATA_PATH is not defined; run the [1/6] CONFIG cell first.")
DATA_PATH      = _cfg["DATA_PATH"]
MAX_SEQ_LEN    = int(_cfg.get("MAX_SEQ_LEN", 2048))
MAX_CHARS      = int(_cfg.get("MAX_CHARS", 6000))
ASSIST_ONLY    = bool(_cfg.get("ASSISTANT_ONLY_MASKING", True))
VAL_FRAC       = float(_cfg.get("VAL_FRAC", 0.1))  # default 10% val split

# Seed: first entry of SEEDS when it is non-empty, else SEED, else 42
_seed_list = _cfg.get("SEEDS") or []
SEED = int(_seed_list[0]) if _seed_list else int(_cfg.get("SEED", 42))
random.seed(SEED)

# tokenizer (reuse or create)
# Look the tokenizer up by name so a missing `tok` leads to creation, not NameError.
tok = globals().get("tok")
if tok is None:
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

tok.padding_side = "right"
tok.truncation_side = "right"

_h("Tokenizer")
_log("model", MODEL_NAME)
_log("max_len", MAX_SEQ_LEN)
_log("pad_id", tok.pad_token_id)
_log("eos_id", tok.eos_token_id)

# ---- Load raw JSONL ----
# The file is loaded as a single "train" split
_raw_splits = load_dataset("json", data_files={"train": DATA_PATH})
raw = _raw_splits["train"]

_h("Raw dataset")
for _key, _value in (
    ("path", DATA_PATH),
    ("rows", _fmt_int(len(raw))),
    ("columns", raw.column_names),
):
    _log(_key, _value)

# ---- Basic schema guard ----
def _schema_ok(ex):
    return (
        "messages" in ex
        and isinstance(ex["messages"], list)
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in ex["messages"])
    )

# Count (but do not remove) records that fail the schema guard
viol = sum(not _schema_ok(ex) for ex in raw)
_log("schema violations", _fmt_int(viol))

# ---- Render last assistant supervision (prompt/completion) ----
def render_last_assistant(example):
    """Render one conversation into a prompt/completion pair around its last assistant turn.

    Every call returns the same set of keys (``prompt``, ``completion``, ``text``,
    ``prefix_len``, ``n_tokens``, ``char_len``, ``drop``) so that ``Dataset.map``
    sees a consistent column set. Rows that cannot be rendered (``messages``
    missing or not a list, or no assistant turn) come back with ``drop=True`` and
    empty/zero placeholders; callers are expected to filter on ``drop``.

    Uses the module-level tokenizer ``tok`` for chat templating and token counts.
    """
    # Placeholder row for conversations that must be filtered out downstream.
    dropped = {
        "prompt": "",
        "completion": "",
        "text": "",
        "prefix_len": 0,
        "n_tokens": 0,
        "char_len": 0,
        "drop": True,
    }

    msgs = example.get("messages")
    if not isinstance(msgs, list):
        return dropped

    last_ass_idx = max(
        (i for i, m in enumerate(msgs) if isinstance(m, dict) and m.get("role") == "assistant"),
        default=None,
    )
    if last_ass_idx is None:
        return dropped

    prompt_msgs = msgs[:last_ass_idx]
    assistant_turn = msgs[last_ass_idx]

    # Full text includes target; prompt_only is everything before
    full_txt = tok.apply_chat_template(prompt_msgs + [assistant_turn], tokenize=False, add_generation_prompt=False)
    prompt_only = tok.apply_chat_template(prompt_msgs, tokenize=False, add_generation_prompt=True)

    # We do NOT pad here; just inspect lengths to filter/guard
    full_ids   = tok(full_txt, add_special_tokens=False, truncation=False)["input_ids"]
    prompt_ids = tok(prompt_only, add_special_tokens=False, truncation=False)["input_ids"]

    # NOTE(review): the completion slice and prefix_len both assume the rendered
    # prompt is an exact prefix of the rendered full text (and tokenizes the same
    # way in both) — confirm for the chat template in use.
    return {
        "prompt": prompt_only,
        "completion": full_txt[len(prompt_only):],
        "text": full_txt,
        "prefix_len": len(prompt_ids),     # tokens belonging to the prompt
        "n_tokens": len(full_ids),
        "char_len": len(full_txt),
        "drop": False,
    }

_h("Render + guards")
print("[render] applying chat template + extracting last assistant…")
rendered = raw.map(render_last_assistant, desc="[render]")

# ---- Filtering: drop invalid/too-long items ----
# Guards are applied one after another, in this order, each as its own filter pass.
_guards = (
    lambda ex: ex.get("drop") is False,                 # renderer did not flag the row
    lambda ex: ex["char_len"] <= MAX_CHARS,             # character budget
    lambda ex: ex["n_tokens"] <= MAX_SEQ_LEN,           # token budget
    lambda ex: 0 < ex["prefix_len"] < ex["n_tokens"],   # non-empty prompt and non-empty target
)
for _guard in _guards:
    rendered = rendered.filter(_guard)
_log("kept after guards", _fmt_int(len(rendered)))

# ---- Train/val split (conversation-level) ----
# Keep only the rendered columns, then split with the seed resolved earlier in this cell.
keep_cols = {"prompt", "completion", "text", "prefix_len", "n_tokens", "char_len"}
_drop_cols = [c for c in rendered.column_names if c not in keep_cols]
rendered = rendered.remove_columns(_drop_cols)

split = rendered.train_test_split(test_size=VAL_FRAC, seed=SEED)
train_ds = split["train"]
val_ds = split["test"]

_h("Split")
for _label, _value in (
    ("VAL_FRAC", VAL_FRAC),
    ("train", _fmt_int(len(train_ds))),
    ("val", _fmt_int(len(val_ds))),
):
    _log(_label, _value)

# ---- Tokenize + assistant-only masking (no padding here) ----
# We DO NOT pad in this step. Padding is deferred to the data collator at batch time.
# Labels:
#   - Positions < prefix_len are set to -100 (ignored)
#   - We do NOT add padding here, so there are no pad tokens in labels at this stage.
#   - The collator will pad input_ids/labels/attention_mask and must set padded label positions to -100.
def tok_and_mask(batch):
    """Tokenize a batch of rendered texts and build labels for causal-LM training.

    Reads ``batch["text"]`` and ``batch["prefix_len"]``. Sequences are truncated
    to ``MAX_SEQ_LEN`` and left unpadded. Labels are a copy of ``input_ids``;
    when ``ASSIST_ONLY`` is set, the first ``prefix_len`` positions (capped at
    the sequence length) are replaced by -100.

    Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``.
    """
    texts = batch["text"]
    prefix_lens = batch["prefix_len"]

    enc = tok(
        texts,
        add_special_tokens=False,
        truncation=True,
        max_length=MAX_SEQ_LEN,
        padding=False,                 # padding is deferred to the collator
        return_attention_mask=True,
    )
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    labels = []
    for seq, n_prompt in zip(input_ids, prefix_lens):
        target = seq.copy()
        if ASSIST_ONLY:
            # Blank out the prompt span in one slice assignment.
            n_masked = max(0, min(n_prompt, len(target)))
            target[:n_masked] = [-100] * n_masked
        labels.append(target)

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Tokenize both splits; the rendered text columns are dropped so only model inputs remain.
cols_keep = ["input_ids", "attention_mask", "labels"]
TRAIN_DS_TOKENIZED = train_ds.map(tok_and_mask, batched=True, remove_columns=train_ds.column_names)
VAL_DS_TOKENIZED   = val_ds.map(tok_and_mask,   batched=True, remove_columns=val_ds.column_names)

# Torch-format the required columns
TRAIN_DS_TOKENIZED.set_format(type="torch", columns=cols_keep)
VAL_DS_TOKENIZED.set_format(type="torch", columns=cols_keep)

# Hand-off object for the training cell.
DATA_OBJECTS = {
    "train_ds": TRAIN_DS_TOKENIZED,
    "val_ds":   VAL_DS_TOKENIZED,
}

_h("Ready")
_log("ASSIST_ONLY", ASSIST_ONLY)
# Round rather than truncate: VAL_FRAC * 100 is a float product (e.g. 0.29 * 100
# is slightly below 29), and deriving the train share from the rounded value
# keeps the two numbers summing to 100.
_val_pct = round(VAL_FRAC * 100)
print(f"[OK] TRAIN/VAL ready. Using {100 - _val_pct}/{_val_pct} train/val split.")


# Training

In [None]:
# =========================
# [6/6] FINETUNE — robust Trainer (manual padding collator, eval-safe, resumable)
# =========================
import os, json, time, torch
from typing import List, Dict, Any
from transformers import Trainer, TrainingArguments
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.trainer_utils import set_seed

# --------- Seeds / determinism ---------
# Seed: first entry of a non-empty SEEDS list/tuple, else SEED from an earlier cell, else 42.
_has_seed_list = isinstance(SEEDS, (list, tuple)) and len(SEEDS) > 0
SEED = int(SEEDS[0] if _has_seed_list else globals().get("SEED", 42))
set_seed(SEED)

# NOTE(review): these are set after the interpreter (and possibly CUDA) is already
# up; confirm they still take effect at this point in the notebook.
os.environ.update({
    "PYTHONHASHSEED": str(SEED),
    "CUBLAS_WORKSPACE_CONFIG": ":16:8",
})
# Only applied when the caller has not already chosen an allocator configuration.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# --------- Datasets ---------
# Tokenized splits published by the data-loading cell.
train_ds, val_ds = DATA_OBJECTS["train_ds"], DATA_OBJECTS["val_ds"]

# --------- Output dirs ---------
# Run root plus its checkpoints/ and logs/ children; all are created if missing.
CKPT_DIR = os.path.join(FULL_DIR, "checkpoints")
LOG_DIR  = os.path.join(FULL_DIR, "logs")
for _dir in (FULL_DIR, CKPT_DIR, LOG_DIR):
    os.makedirs(_dir, exist_ok=True)

RUN_META_JSON = os.path.join(LOG_DIR, "run_meta.json")

# --------- Router mode (for optional SafeMoE penalty) ---------
# ROUTER_MODE defaults to "frozen"; a truthy FREEZE_ROUTER (when defined) forces "frozen".
_router_mode = globals().get("ROUTER_MODE", "frozen")
if "FREEZE_ROUTER" in globals() and bool(FREEZE_ROUTER):
    _router_mode = "frozen"
# Weight of the router penalty; used only when the mode is "safemoe".
_safe_lambda = float(globals().get("SAFE_MOE_LAMBDA", 1.0))

# --------- Persist run meta ---------
# Snapshot of the run configuration, written once to logs/run_meta.json.
# NOTE(review): tuple(...) on TARGET_LAYERS / TARGET_EXPERTS assumes they are
# already resolved to sequences of indices here (not the string "all") — confirm.
_run_meta = {
    "model_name": MODEL_NAME,
    "dt": time.strftime("%Y-%m-%d %H:%M:%S"),
    "seeds": list(SEEDS),
    "max_seq_len": int(MAX_SEQ_LEN),
    "lora": {
        "use_rs_lora": bool(USE_RS_LORA),
        "r": int(LORA_R),
        "alpha": int(LORA_ALPHA),
        "dropout": float(LORA_DROPOUT),
        "bias": str(BIAS),
        "target_layers": tuple(TARGET_LAYERS),
        "target_experts": tuple(TARGET_EXPERTS),
        "target_params": tuple(TARGET_PARAMS),
        "router_mode": _router_mode,
        "safe_moe_lambda": _safe_lambda,
    },
    "optim": {
        "name": OPTIM,
        "weight_decay": float(WEIGHT_DECAY),
        "lr": float(LR),
        "scheduler": LR_SCHEDULER,
        "warmup_ratio": float(WARMUP_RATIO),
        "grad_accum": int(GRAD_ACCUM),
        "per_device_batch": int(PER_DEVICE_BATCH),
    },
    "logging": {
        "log_every": int(LOG_EVERY),
        "save_every": int(SAVE_EVERY),
        "save_total_limit": int(SAVE_TOTAL_LIMIT),
        "report_to": REPORT_TO,
    },
    "data_stats": {"train": len(train_ds), "val": len(val_ds)},
    "run_tag": RUN_TAG,
    "full_dir": FULL_DIR,
}
with open(RUN_META_JSON, "w") as _meta_fh:
    json.dump(_run_meta, _meta_fh, indent=2)

# --------- Persist trainable inventory ---------
# Record the name of every parameter that will receive gradients, one per line.
trainable_names = []
for _pname, _param in model.named_parameters():
    if _param.requires_grad:
        trainable_names.append(_pname)

_inventory_path = os.path.join(LOG_DIR, "trainable_params.txt")
with open(_inventory_path, "w") as _inv_fh:
    _inv_fh.write("\n".join(trainable_names))

# --------- Optimizer guard (bnb paged adamw 8bit) ---------
# Probe for the bitsandbytes paged 8-bit AdamW before training starts. If the
# import fails for any reason, warn and switch OPTIM to the plain PyTorch AdamW
# so the TrainingArguments built later use an optimizer that is available.
if OPTIM == "paged_adamw_8bit":
    try:
        # Imported only to test availability; the name itself is not used.
        from bitsandbytes.optim import PagedAdamW8bit as _T  # noqa: F401
    except Exception as e:
        print(f"[optim] WARNING: BnB PagedAdamW8bit unavailable: {e}\n         Falling back to 'adamw_torch'.")
        OPTIM = "adamw_torch"

# --------- Model runtime flags ---------
# bf16 is used only when it was requested AND a CUDA device is present.
_wants_bf16 = str(DTYPE_COMPUTE).lower() == "bf16"
use_bf16 = _wants_bf16 and torch.cuda.is_available()

# Turn the generation cache off for training, when the config exposes the flag.
if hasattr(model.config, "use_cache"):
    setattr(model.config, "use_cache", False)

# --------- Manual padding collator (fast, mismatch-proof) ---------
def _to_list(x):
    # Accept torch.Tensor or list; return Python list
    if isinstance(x, torch.Tensor):
        return x.tolist()
    return list(x)

def clm_collator(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Right-pad a list of tokenized examples into one batch of ``torch.long`` tensors.

    Each example must provide ``input_ids``, ``attention_mask`` and ``labels``
    (tensor or list). All rows are padded to the longest ``input_ids`` in the
    batch: ``input_ids`` with the tokenizer's pad id, ``attention_mask`` with 0,
    and ``labels`` with -100.

    Raises:
        ValueError: if an example's ``input_ids`` and ``labels`` differ in length.
    """
    pad_id = tok.pad_token_id
    ignore_index = -100

    def column(key):
        return [_to_list(ex[key]) for ex in examples]

    ids_rows = column("input_ids")
    mask_rows = column("attention_mask")
    label_rows = column("labels")

    # Labels must line up with inputs position-for-position before padding.
    for idx, (row, lab) in enumerate(zip(ids_rows, label_rows)):
        n_ids, n_lab = len(row), len(lab)
        if n_ids != n_lab:
            raise ValueError(f"Example {idx}: len(input_ids)={n_ids} != len(labels)={n_lab}")

    width = max(map(len, ids_rows))

    def right_pad(rows, fill):
        return [row + [fill] * (width - len(row)) for row in rows]

    return {
        "input_ids": torch.tensor(right_pad(ids_rows, pad_id), dtype=torch.long),
        "attention_mask": torch.tensor(right_pad(mask_rows, 0), dtype=torch.long),
        "labels": torch.tensor(right_pad(label_rows, ignore_index), dtype=torch.long),
    }

# --------- Training args (eval-safe) ---------
args = TrainingArguments(
    output_dir=CKPT_DIR,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=max(1, PER_DEVICE_BATCH),
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    adam_beta1=0.9, adam_beta2=0.95,
    weight_decay=WEIGHT_DECAY,
    lr_scheduler_type=LR_SCHEDULER,
    warmup_ratio=WARMUP_RATIO,
    # max_steps is an int whose "not set" value is -1 (training length then comes
    # from num_train_epochs); a falsy MAX_STEPS must map to -1, not None.
    max_steps=MAX_STEPS if MAX_STEPS else -1,
    num_train_epochs=NUM_EPOCHS if NUM_EPOCHS else 1,

    logging_steps=LOG_EVERY,
    save_steps=SAVE_EVERY,
    save_total_limit=SAVE_TOTAL_LIMIT,

    # Evaluate on the same cadence as checkpointing so load_best_model_at_end
    # always has an eval_loss for each saved checkpoint.
    eval_strategy="steps",
    eval_steps=SAVE_EVERY,
    prediction_loss_only=True,     # no logits kept → avoids big concats
    eval_accumulation_steps=8,     # harmless here; helps if metrics later enabled

    bf16=use_bf16, fp16=False,
    gradient_checkpointing=GRADIENT_CHECKPOINTING,
    gradient_checkpointing_kwargs={"use_reentrant": bool(GC_REENTRANT)},
    max_grad_norm=CLIP_GRAD_NORM,

    # Pass REPORT_TO through unchanged: the string "none" is what disables
    # reporting integrations, whereas None selects the library default.
    report_to=REPORT_TO,
    optim=OPTIM,
    ddp_find_unused_parameters=False,
    remove_unused_columns=False,   # the custom collator consumes the dataset columns directly
    save_safetensors=True,

    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,       # lower eval_loss is better
)

# --------- Optional: clear-cache callback ---------
class ClearCacheCallback(TrainerCallback):
    """Trainer callback that calls ``torch.cuda.empty_cache()`` from the evaluate hook."""

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Empty the CUDA cache and return ``control`` unchanged.

        Any exception from the cache call is deliberately suppressed: this is a
        best-effort cleanup and must not interrupt training.
        """
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best-effort only; fall through with the control object untouched.
            return control
        return control

# --------- Build Trainer ---------
# Newer transformers releases take the tokenizer as `processing_class` and
# deprecate the `tokenizer` keyword; pick whichever the installed Trainer accepts
# so this cell works on both sides of that change.
import inspect

_tok_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=clm_collator,   # ← manual, mismatch-proof collator
    compute_metrics=None,         # loss-only evaluation
    callbacks=[ClearCacheCallback()],
    **{_tok_kwarg: tok},
)

# --------- SafeMoE penalty (router_mode == 'safemoe') ---------
if _router_mode == "safemoe":
    def _is_router_name(name: str) -> bool:
        """True for MoE gate/router parameter names, excluding per-expert weights."""
        return ("block_sparse_moe" in name) and (("gate" in name) or (".router." in name)) and (".experts." not in name)

    # Frozen float32 snapshot of the router parameters as they are before training.
    router_pre = {k: v.detach().float().clone() for k, v in model.named_parameters() if _is_router_name(k)}
    λ = _safe_lambda
    orig_compute_loss = trainer.compute_loss

    def compute_loss_with_router_penalty(model, inputs, return_outputs=False, **kwargs):
        """Task loss plus λ × a KL penalty tying trainable router weights to their snapshot.

        Extra keyword arguments (e.g. ``num_items_in_batch``, which newer Trainer
        versions pass to ``compute_loss``) are forwarded to the original
        implementation unchanged.
        """
        loss, outputs = orig_compute_loss(model, inputs, return_outputs=True, **kwargs)
        pen = torch.tensor(0.0, device=loss.device)
        for k, v in model.named_parameters():
            # Parameters are matched by name and shape; anything without a matching
            # snapshot entry is silently skipped.
            # NOTE(review): if `model` is passed in wrapped and its parameter names
            # gain a prefix, nothing will match and the penalty stays 0 — confirm.
            if _is_router_name(k) and v.requires_grad and (k in router_pre) and (router_pre[k].shape == v.shape):
                # Softmax over the last dim of the weight tensor itself (not over
                # routing logits for a batch of tokens).
                p_new = torch.nn.functional.log_softmax(v.float(), dim=-1)
                p_old = torch.nn.functional.log_softmax(router_pre[k].to(v.device), dim=-1)
                pen = pen + torch.nn.functional.kl_div(p_new, p_old, log_target=True, reduction="batchmean")
        loss = loss + λ * pen
        return (loss, outputs) if return_outputs else loss

    # Instance-level override: only this trainer object is affected.
    trainer.compute_loss = compute_loss_with_router_penalty
    print(f"[router] SafeMoE penalty enabled (λ={λ})")

# --------- Resume (if checkpoints exist) ---------
# Resume only when RESUME is truthy; the checkpoint-* directory with the highest
# numeric suffix is used.
resume_arg = None
if bool(globals().get("RESUME", False)):
    ckpts = []
    if os.path.isdir(CKPT_DIR):
        ckpts = [p for p in os.listdir(CKPT_DIR) if p.startswith("checkpoint-")]
        ckpts.sort(key=lambda s: int(s.split("-")[-1]))
    if ckpts:
        resume_arg = os.path.join(CKPT_DIR, ckpts[-1])
        print(f"[resume] Resuming from: {resume_arg}")

print(f"[run] Checkpoints → {CKPT_DIR}")
print("[run] Starting training…")
trainer.train(resume_from_checkpoint=resume_arg)


In [None]:
import os, signal

# Hard-stop the notebook kernel by sending SIGKILL to this very process.
# SIGKILL cannot be caught, so no cleanup handlers run and nothing after this
# cell executes.
# NOTE(review): presumably used to release GPU memory / end the session once
# training has finished — confirm intent.

# Get the current process ID
pid = os.getpid()

# Kill it
os.kill(pid, signal.SIGKILL)
