## Session Setup

In [1]:
# Check to make sure there are multiple gpus available
import torch, os
!export CUDA_LAUNCH_BLOCKING=1

import gc, torch, os
torch.cuda.empty_cache()
gc.collect()

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.chdir("/mnt/home/amir/framingdecomp/framingDecomp")

print("Devices visible:", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.cuda.device_count():", torch.cuda.device_count())



Devices visible: 7
torch.cuda.device_count(): 1


In [2]:
# ==== Cell: [Session setup] ====

import os, sys, logging, random, yaml, time, uuid, json
from pathlib import Path

import torch
import numpy as np

# ——— switches you may tune ——————————————————————————
USE_MULTIGPU     = True          # False → single-GPU
VISIBLE_DEVICES  = "7"#"0,1,2,3"
CFG_PATH         = "configs/decomposer_main.yaml"
# ————————————————————————————————————————————————


# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Device: {device}, GPUs: {torch.cuda.device_count()}")

# ---------- logging ----------
Path("logs").mkdir(exist_ok=True)
ts = time.strftime("%Y%m%d_%H%M%S")
log_path = Path(f"logs/decomposer_{ts}.log")

# root logger -> both console and file
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s — %(levelname)s — %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_path, mode="w")
    ]
)
logger = logging.getLogger("train_decomposer")
logger.info("Log file created at %s", log_path)


# --- config, seeds, logging -----------------------------------
with open(CFG_PATH, "r") as f:
    config = yaml.safe_load(f)

seed = config["experiment"]["seed"]
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

logger.info("Config loaded and seeds set.")

2025-07-21 19:58:41,300 — INFO — Log file created at logs/decomposer_20250721_195841.log
2025-07-21 19:58:41,321 — INFO — Config loaded and seeds set.


## Data Loading and Preprocessing

In [None]:
# ==== Cell: [Data loading & preprocessing] ====

import json
from typing import List, Dict

def load_jsonl(path: str) -> List[Dict]:
    with open(path, "r") as f:
        return [json.loads(l) for l in f if l.strip() and not l.strip().startswith("#")]

data_cfg = config["data"]

raw_F_harm  = load_jsonl(data_cfg["input_path_varyFraming"])
raw_G_harm  = load_jsonl(data_cfg["input_path_varyGoal"])
raw_F_ben   = load_jsonl(data_cfg["input_path_varyFraming_benign"])
raw_G_ben   = load_jsonl(data_cfg["input_path_varyGoal_benign"])

for e in raw_F_harm + raw_F_ben: e["split"] = "varyF"
for e in raw_G_harm + raw_G_ben: e["split"] = "varyG"

def _preprocess(entries: List[Dict], max_f_idx: int):
    processed = []
    for ent in entries:
        req = ["prompt","goal","goal_index","framing_index","split"]
        if not all(k in ent for k in req): 
            continue
        g, f = ent["goal_index"], ent["framing_index"]
        if ent["split"] == "varyF":                 # re-index framings
            f = g if f == 0 else max_f_idx + 1
            max_f_idx = max(max_f_idx, f)
        processed.append({
            "text":  ent["prompt"],
            "goal":  ent["goal"],
            "goal_index": g,
            "framing_index": f,
            "label": ent.get("jailbroken", False),
            "split": ent["split"],
        })
    return processed, max_f_idx

max_idx = max(e["framing_index"]
              for e in raw_F_harm + raw_G_harm + raw_F_ben + raw_G_ben)

P_F_harm, max_idx = _preprocess(raw_F_harm, max_idx)
P_G_harm, max_idx = _preprocess(raw_G_harm, max_idx)
P_F_ben , max_idx = _preprocess(raw_F_ben , max_idx)
P_G_ben , max_idx = _preprocess(raw_G_ben , max_idx)

all_samples = P_F_harm + P_G_harm + P_F_ben + P_G_ben
n_total = len(all_samples)
n_sample = int(n_total * config.get('sample_prop', 1.))
n_sample = max(n_sample, 500)
if n_sample < n_total:
    logger.info("Sampling %d out of %d total samples", n_sample, n_total)
    all_samples = random.sample(all_samples, n_sample)
logger.info("Total processed samples: %d", len(all_samples))

2025-07-21 19:58:41,497 — INFO — Total processed samples: 528


## Dataloader

In [4]:
# ==== Cell: [Dataset & dataloader] ====

# from utils.misc import DualPairDataset, collate_dual

from collections import defaultdict
from typing import Tuple
from torch.utils.data import Dataset, DataLoader

class DualPairDataset(Dataset):
    """
    Returns (sample_a, sample_b, pair_type)
      pair_type = 0 → same-goal / diff-frame  (from varyF)
      pair_type = 1 → same-frame / diff-goal  (from varyG)
    """
    def __init__(self, samples, stratified_capping=True):
        self.samples = samples
        self.goal_pairs, self.frame_pairs = [], []

        by_goal_F  = defaultdict(list)
        by_frame_G = defaultdict(list)

        for idx, s in enumerate(samples):
            if s["split"] == "varyF":  by_goal_F [s["goal_index"]   ].append(idx)
            else:                      by_frame_G[s["framing_index"]].append(idx)

        for lst in by_goal_F.values():
            self.goal_pairs  += [(a,b,0) for a in lst for b in lst if a<b]
        for lst in by_frame_G.values():
            self.frame_pairs += [(a,b,1) for a in lst for b in lst if a<b]

        # --- stratified capping ---------------------------------
        # this improved the performance a bit
        if stratified_capping:
            cap = int(np.median([len(v) for v in by_goal_F.values()]))
            for g, lst in by_goal_F.items():
                if len(lst) > cap:               # down-sample heavy goals
                    by_goal_F[g] = random.sample(lst, cap)
        # --------------------------------------------------------------

        self.all_pairs = self.goal_pairs + self.frame_pairs

    def __len__(self): return len(self.all_pairs)
    def __getitem__(self, k): return self.all_pairs[k]

def collate_dual(batch) -> Tuple[list,str,str,torch.Tensor]:
    """
    batch → (texts, goal_ids, frame_ids, pair_types)
    """
    texts, gid, fid, ptype = [], [], [], []
    for a,b,t in batch:
        sa, sb = all_samples[a], all_samples[b]
        texts.extend([sa["text"], sb["text"]])
        gid.extend([sa["goal_cid"], sb["goal_cid"]])
        fid.extend([sa["framing_index"], sb["framing_index"]])
        ptype.append(t)
    return (texts,
            torch.tensor(gid),
            torch.tensor(fid),
            torch.tensor(ptype))

# contiguous goal ids
unique_goals = sorted({s["goal_index"] for s in all_samples})
goal2cid     = {g:i for i,g in enumerate(unique_goals)}
for s in all_samples: s["goal_cid"] = goal2cid[s["goal_index"]]

train_ds = DualPairDataset(all_samples)
logger.info("Goal pairs: %d   Frame pairs: %d   Total pairs: %d",
            len(train_ds.goal_pairs),
            len(train_ds.frame_pairs),
            len(train_ds))

2025-07-21 19:58:41,595 — INFO — Goal pairs: 807   Frame pairs: 986   Total pairs: 1793


## Training and Launcher

In [5]:
# from models.decomposer import NonlinearDecomposer, NonlinearDecomposer_tiny
# dec = NonlinearDecomposer_tiny(enc_dim=4096)

In [None]:
# ==== Cell: [Training worker & launch] ====

from utils.misc import set_seed
from accelerate import notebook_launcher
import torch.multiprocessing as mp
mp.set_start_method("spawn", force=True)

ts = None
run_id = None
global run_id, ts

def train_worker():
    import gc, yaml, torch, torch.distributed as dist
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from torch.nn.parallel import DistributedDataParallel as DDP
    from train_test.decomposer_training import train_decomposer
    from utils.model_utils import load_model_multiGPU
    from models.encoder import HFEncoder_notPooled
    from models.decomposer import NonlinearDecomposer, NonlinearDecomposer_tiny

    gc.collect(); torch.cuda.empty_cache()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    distributed = world_size > 1

    if distributed:
        dist.init_process_group(
            backend="nccl", init_method="env://",
            rank=local_rank, world_size=world_size,
        )

    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # ——— load LLM once per rank ———
    model, tokenizer = load_model_multiGPU(
        model_name=config["model"]["name"],
        local_rank=local_rank,
        load_in_8bit=False,
        load_in_4bit=False,
    )
    
    config["model"]["layers"] = 'all'#[-1]
    if config["model"]["layers"] == 'all':
        num_layers = model.config.num_hidden_layers
        layers = list(range(num_layers))
    else:
        layers = config["model"]["layers"]
        if isinstance(layers, int): layers = [layers]
    
    # --- Set arguments # TODO: Delete and instead fix the config
    config['experiment']['use_sae'] = config['experiment'].get('use_sae', False)  # use Sparse Autoencoder
    config['training']['num_epochs'] = 3
    init_lambda_orth = config['lambda_orth']
    config['lambda_repulse'] = config.get('lambda_repulse', 6.0) 
    config['lambda_adv'] = config.get('lambda_adv', 2.0)  
    config['lambda_sparse'] = config.get('lambda_sparse', None)  
    config['lambda_recon'] = config.get('lambda_recon', 1.0)  
    config['lambda_Worth'] = config.get('lambda_Worth', 0.05)
    config['grad_accum_steps'] = config.get('grad_accum_steps', 8)
    config['model']['layers'] = layers
    config['model']['layer_combine'] = config['model'].get('layer_combine', 'mean')
    config['model']['last_token'] = config['model'].get('last_token', False)
    config['training']['num_epochs']   = 3
    config['training']['batch_size']   = 8
    # ---
    
    for layer in layers:
        torch.manual_seed(config['experiment']['seed'])
        torch.cuda.manual_seed_all(config['experiment']['seed'])
        np.random.seed(config['experiment']['seed'])
        random.seed(config['experiment']['seed'])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    
        logger.info(f"Training decomposer for layer {layer}")
        encoder = HFEncoder_notPooled(
            model=model,
            tokenizer=tokenizer,
            device=device,
            layers=[layer],
            layer_combine=config["model"].get("layer_combine", "mean"),
            last_token=config["model"].get("last_token", True),
        )
        encoder.eval()

        # ——— decomposer ———
        enc_dim = model.config.hidden_size
        dec = NonlinearDecomposer(
            enc_dim=enc_dim,
            d_g=config["d_g"],
            d_f=config["d_f"],
            hidden_dim=config.get("hidden_dim", 1024),
            dropout=config.get("dropout", 0.1),
        ).to(device)
        d_f = dec.Wf(torch.randn(4096).to(device)).shape[0]
        
        gen = torch.Generator()
        gen.manual_seed(config['experiment']['seed']) 
        if distributed:
            dec = DDP(dec, device_ids=[local_rank])
            # ——— dataloader ———
            from torch.utils.data import DistributedSampler
            sampler = DistributedSampler(train_ds, rank=local_rank,
                                            num_replicas=world_size, shuffle=True, seed=config['experiment']['seed'])
            train_loader = DataLoader(
                train_ds,
                batch_size=config["training"]["batch_size"],
                sampler=sampler,
                collate_fn=collate_dual,
                num_workers=8,
                pin_memory=True,
                shuffle=False,
                worker_init_fn=lambda wid: set_seed(config['experiment']['seed'] + wid),#seed_worker,
                generator=gen,
            )
        else:
            dec.__dict__["module"] = dec
            # from torch.utils.data import DistributedSampler
            # sampler = DistributedSampler(train_ds, rank=0,
            #                             num_replicas=1, shuffle=True)
            train_loader  = DataLoader(train_ds,
                                 batch_size=config["training"]["batch_size"], 
                                 collate_fn=collate_dual,
                                 num_workers=0, pin_memory=False, shuffle=True, sampler=None,
                                 worker_init_fn=lambda wid: set_seed(config['experiment']['seed'] + wid),
                                 generator=gen)

        # ——— optim & sched ———
        opt = AdamW(dec.parameters(), lr=config["lr"])
        sched = CosineAnnealingLR(opt,
                                    T_max=len(train_ds)*config["training"]["num_epochs"])

        # adversarial classifier
        n_goals = len(unique_goals)
        adv_clf = torch.nn.Linear(d_f, n_goals).to(device)
        if distributed:
            adv_clf = DDP(adv_clf, device_ids=[local_rank])
        adv_opt = AdamW(adv_clf.parameters(), lr=1e-4)

        # ——— train ———
        stats = train_decomposer(
            encoder     = encoder,
            decomposer  = dec,
            dataloader  = train_loader,
            optimizer   = opt,
            adv_clf     = adv_clf,
            adv_opt     = adv_opt,
            lambda_adv  = config.get('lambda_adv', 2.0),
            scheduler   = sched,
            device      = device,
            epochs      = config['training']['num_epochs'],
            lambda_g    = config['lambda_g'],
            lambda_f    = config['lambda_f'],
            lambda_repulse = config.get('lambda_repulse', 6.0),  # optional
            lambda_orth = config['lambda_orth']*10,
            lambda_recon = config.get('lambda_recon', 1.0),  # optional
            lambda_Worth = config.get('lambda_Worth', 0.25),
            grad_clip   = config['grad_clip'],
            grad_accum_steps = config["training"].get('grad_accum_steps', 4),
            log_every   = 50,
            info        = logger.info,
            layer_str= f"{layer}",
        )

        # ——— checkpoint (rank-0 only) ———
        if local_rank == 0:
            # ts = time.strftime("%Y%m%d_%H%M%S")
            # run_id = str(uuid.uuid4())
            global run_id, ts
            ckpt_dir = Path(f"checkpoints/decomposer_simple/decomposer_layer{layer}_{ts}_{run_id}")
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            
            if layer == 0:
                config_filename = f"./output/config_{ts}_{run_id}.yaml"
                with open(config_filename, "w") as f:
                    yaml.safe_dump(config, f)

            torch.save({k: v.cpu() for k, v in dec.module.state_dict().items()},
                        ckpt_dir / "weights.pt")
            with open(ckpt_dir / "train_stats.json", "w") as f:
                json.dump(stats, f)
            logger.info("Checkpoint for layer %d saved to %s", layer, ckpt_dir)

    if distributed:
        dist.destroy_process_group()



In [7]:
# import sys
# sys.path.append('/mnt/home/amir/framingdecomp/framingDecomp')
# # import os
# # os.environ['PYTHONPATH'] = '/mnt/home/amir/framingdecomp/framingDecomp'

In [8]:
# ——— launch ———
# from random import randint
# os.environ["MASTER_PORT"] = str(15000 + randint(0, 10000))

# num_proc = torch.cuda.device_count() if USE_MULTIGPU else 1

global run_id, ts
ts = time.strftime("%Y%m%d_%H%M%S")
run_id = str(uuid.uuid4())
num_proc=1
notebook_launcher(train_worker, num_processes=num_proc)

Launching training on one GPU.




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2025-07-21 19:58:51,749 — INFO — Training decomposer for layer -1


  scaler = torch.cuda.amp.GradScaler()
epoch 0, layer -1:   0%|                                                                                               | 0/225 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/mnt/home/amir/python311/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/home/amir/python311/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'DualPairDataset' on <module '__main__' (built-in)>
                                                                                                                                               

KeyboardInterrupt: 

In [None]:
logger

<Logger train_decomposer (INFO)>