In [None]:
!pip install transformers datasets accelerate bitsandbytes


Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaAttention
from datasets import load_dataset
from tqdm import tqdm
import types
from typing import Optional, Tuple


In [None]:
class VectorGenerator(nn.Module):
    """
    PARA Vector Generator (Section 3.2 Eq 6).
    """
    def __init__(self, d_model, d_ffn, r):
        super().__init__()
        self.d_model = d_model
        self.d_ffn = d_ffn
        self.d_out = 2 * d_model + d_ffn

        self.down_proj = nn.Linear(d_model, r, bias=False)
        self.activation = nn.GELU()
        self.up_proj = nn.Linear(r, self.d_out)

        nn.init.normal_(self.down_proj.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up_proj.weight)
        nn.init.ones_(self.up_proj.bias)

    def pooler(self, hidden_states, prompt_mask):
        prompt_lengths = torch.sum(prompt_mask, dim=1).clamp(min=1)
        last_idx = prompt_lengths - 1
        batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
        return hidden_states[batch_idx, last_idx]

    def forward(self, hidden_states, prompt_mask):
        pooled = self.pooler(hidden_states, prompt_mask)
        z = self.activation(self.down_proj(pooled))
        l_q, l_v, l_u = torch.split(self.up_proj(z),
                                    [self.d_model, self.d_model, self.d_ffn],
                                    dim=-1)
        return l_q, l_v, l_u


In [None]:
def para_forward_method(
    self,
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=True,
    prompt_attention_mask=None,
    **kwargs
):
    residual = hidden_states

    # --- PARA VECTOR LOGIC ---
    if past_key_value is not None:
        (pk, pv, (l_q, l_v, l_u)) = past_key_value
    else:
        if prompt_attention_mask is None:
            prompt_attention_mask = torch.ones_like(hidden_states[:, :, 0])
        l_q, l_v, l_u = self.vector_generator(hidden_states, prompt_attention_mask)
        l_q = l_q.unsqueeze(1)
        l_v = l_v.unsqueeze(1)
        l_u = l_u.unsqueeze(1)

    # ---- LLaMA Attention ----
    hidden_states = self.input_layernorm(hidden_states)

    q = self.self_attn.q_proj(hidden_states)
    k = self.self_attn.k_proj(hidden_states)
    v = self.self_attn.v_proj(hidden_states)

    q = l_q * q
    v = l_v * v

    B, L, _ = q.shape
    num_heads = self.self_attn.num_heads
    head_dim = self.self_attn.head_dim

    q = q.view(B, L, num_heads, head_dim).transpose(1,2)
    k = k.view(B, L, self.self_attn.num_key_value_heads, head_dim).transpose(1,2)
    v = v.view(B, L, self.self_attn.num_key_value_heads, head_dim).transpose(1,2)

    kv_seq = k.shape[-2]
    if past_key_value is not None:
        kv_seq += pk.shape[-2]
    cos, sin = self.self_attn.rotary_emb(v, seq_len=kv_seq)
    q, k = self.self_attn.apply_rotary_pos_emb(q, k, cos, sin, position_ids)

    if past_key_value is not None:
        k = torch.cat([pk, k], dim=2)
        v = torch.cat([pv, v], dim=2)

    new_past = (k, v, (l_q.squeeze(1), l_v.squeeze(1), l_u.squeeze(1)))

    k = LlamaAttention.repeat_kv(k, self.self_attn.num_key_value_groups)
    v = LlamaAttention.repeat_kv(v, self.self_attn.num_key_value_groups)

    attn = torch.matmul(q, k.transpose(-2,-1)) / (head_dim**0.5)
    if attention_mask is not None:
        attn += attention_mask
    attn = torch.softmax(attn, dim=-1).to(q.dtype)
    o = torch.matmul(attn, v)

    o = o.transpose(1,2).reshape(B,L,-1)
    o = self.self_attn.o_proj(o)
    hidden_states = residual + o


    residual2 = hidden_states
    x = self.post_attention_layernorm(hidden_states)

    gate = self.mlp.gate_proj(x)
    up = self.mlp.up_proj(x)
    up = up * l_u

    x = self.mlp.down_proj(self.mlp.act_fn(gate) * up)
    hidden_states = residual2 + x

    if use_cache:
        return (hidden_states, new_past)
    return (hidden_states,)


In [None]:
from huggingface_hub import login

print("=" * 80)
print("STEP 2: Hugging Face Authentication")
print("=" * 80)


HF_TOKEN = "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"
login(token=HF_TOKEN)
print("✓ Logged in to Hugging Face")


STEP 2: Hugging Face Authentication
✓ Logged in to Hugging Face


In [None]:
MODEL_ID = "meta-llama/Llama-3.2-1B"
DATASET_PATH = "/content/wikidata_recent.json"

device = "cuda"
print("Device:", device)

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

config = AutoConfig.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    config=config,
    quantization_config=quant_config,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token


Device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

In [None]:
PARA_R = 12

model.requires_grad_(False)

for layer in tqdm(model.model.layers):
    layer.vector_generator = VectorGenerator(
        config.hidden_size,
        config.intermediate_size,
        PARA_R
    ).to(device, dtype=torch.bfloat16)
    layer.vector_generator.requires_grad_(True)


    layer.forward = types.MethodType(para_forward_method, layer)

model.lm_head.requires_grad_(True)


100%|██████████| 16/16 [00:00<00:00, 512.41it/s]


Linear(in_features=2048, out_features=128256, bias=False)

In [None]:
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")

def custom_collate_fn(batch):
    """
    Bulletproof collate function:
    - removes None items
    - removes items not dict
    - removes items missing required keys
    - removes items where required fields contain None
    - prints skipped items for debugging
    - returns None if final batch is empty
    """
    required_keys = ["prompt", "target_new"]

    cleaned = []
    for idx, item in enumerate(batch):


        if item is None:
            print(f"[SKIP] Item {idx}: is None")
            continue


        if not isinstance(item, dict):
            print(f"[SKIP] Item {idx}: Not a dict -> {type(item)}")
            continue

        missing = [k for k in required_keys if k not in item]
        if missing:
            print(f"[SKIP] Item {idx}: missing keys {missing}")
            continue


        null_fields = [k for k in required_keys if item[k] is None]
        if null_fields:
            print(f"[SKIP] Item {idx}: keys contain None {null_fields}")
            continue


        if any(not isinstance(item[k], str) for k in required_keys):
            print(f"[SKIP] Item {idx}: fields not strings")
            continue

        cleaned.append(item)


    if len(cleaned) == 0:
        print(f"[BATCH SKIPPED] No valid items in batch.")
        return None


    return torch.utils.data.default_collate(cleaned)


train_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=custom_collate_fn
)

In [None]:
def calculate_ft_loss(model, tokenizer, batch, device, max_length=512):
    prompts = batch["prompt"]
    targets = batch["target_new"]

    chats = [
        [{"role":"user","content":p},{"role":"assistant","content":t}]
        for p,t in zip(prompts, targets)
    ]

    enc = tokenizer.apply_chat_template(
        chats,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
        return_dict=True
    ).to(device)

    labels = enc["input_ids"].clone()
    labels[enc["attention_mask"] == 0] = -100

    prompt_only = [[{"role":"user","content":p}] for p in prompts]
    enc_prompt = tokenizer.apply_chat_template(
        prompt_only,
        add_generation_prompt=True,
        padding=False,
        truncation=True,
        return_tensors="pt",
        return_dict=True
    )

    prompt_mask = torch.zeros_like(enc["attention_mask"])
    for i in range(len(prompts)):
        L = len(enc_prompt["input_ids"][i])
        prompt_mask[i, :L] = 1
        labels[i, :L] = -100

    out = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        labels=labels,
        prompt_attention_mask=prompt_mask.to(device)
    )
    return out.loss


In [None]:
optimizer = AdamW(model.parameters(), lr=1e-4)

for epoch in range(1):
    total_loss = 0.0
    steps = 0

    for batch in tqdm(train_loader):
        if batch is None:
            continue


        if not batch or ("prompt" not in batch) or len(batch["prompt"]) == 0:
            continue

        optimizer.zero_grad()

        loss = calculate_ft_loss(model, tokenizer, batch, device)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        steps += 1

    if steps > 0:
        print("Epoch avg loss:", total_loss / steps)
    else:
        print("Epoch had no valid batches.")


  0%|          | 0/633 [00:00<?, ?it/s]


RuntimeError: each element in list of batch should be of equal size

In [None]:
from huggingface_hub import login

print("=" * 80)
print("STEP 2: Hugging Face Authentication")
print("=" * 80)


HF_TOKEN = "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"
login(token=HF_TOKEN)
print("✓ Logged in to Hugging Face")


STEP 2: Hugging Face Authentication
✓ Logged in to Hugging Face


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
import json
from tqdm import tqdm
import os
import types
from typing import Optional, Tuple, List
import math


class VectorGenerator(nn.Module):

    def __init__(self, d_model: int, d_ffn: int, r: int = 12):
        super().__init__()
        self.d_model = d_model
        self.d_ffn = d_ffn
        self.r = r
        self.d_out = 2 * d_model + d_ffn


        self.down_proj = nn.Linear(d_model, r, bias=False)


        self.activation = nn.GELU()


        self.up_proj = nn.Linear(r, self.d_out, bias=True)

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.down_proj.weight, mean=0.0, std=0.02)


        nn.init.zeros_(self.up_proj.weight)
        nn.init.ones_(self.up_proj.bias)

    def pooler(self, hidden_states: torch.Tensor, prompt_mask: torch.Tensor) -> torch.Tensor:
        prompt_lengths = torch.sum(prompt_mask, dim=1).clamp(min=1)
        last_prompt_indices = prompt_lengths - 1

        batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
        pooled = hidden_states[batch_indices, last_prompt_indices, :]

        return pooled

    def forward(self, hidden_states: torch.Tensor, prompt_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        pooled = self.pooler(hidden_states, prompt_mask)


        h = self.down_proj(pooled)


        h = self.activation(h)


        l = self.up_proj(h)


        l_q = l[:, :self.d_model]
        l_v = l[:, self.d_model:2*self.d_model]
        l_u = l[:, 2*self.d_model:]

        return l_q, l_v, l_u



def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x):

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):

    cos = cos.squeeze(1).squeeze(0)
    sin = sin.squeeze(1).squeeze(0)
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def para_decoder_layer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    residual = hidden_states

    if not hasattr(self, '_para_vectors_cache'):
        self._para_vectors_cache = None

    if past_key_value is not None and hidden_states.shape[1] == 1:

        l_q, l_v, l_u = self._para_vectors_cache
    else:

        prompt_mask = kwargs.get('prompt_attention_mask', None)
        if prompt_mask is None:

            prompt_mask = torch.ones(
                hidden_states.shape[0],
                hidden_states.shape[1],
                dtype=torch.long,
                device=hidden_states.device
            )


        l_q, l_v, l_u = self.vector_generator(hidden_states, prompt_mask)
        self._para_vectors_cache = (l_q, l_v, l_u)


    l_q = l_q.unsqueeze(1)
    l_v = l_v.unsqueeze(1)
    l_u = l_u.unsqueeze(1)
    hidden_states = self.input_layernorm(hidden_states)


    bsz, q_len, _ = hidden_states.size()


    query_states = self.self_attn.q_proj(hidden_states)
    key_states = self.self_attn.k_proj(hidden_states)
    value_states = self.self_attn.v_proj(hidden_states)

    query_states = l_q * query_states
    value_states = l_v * value_states
    query_states = query_states.view(
        bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim
    ).transpose(1, 2)

    key_states = key_states.view(
        bsz, q_len, self.self_attn.num_key_value_heads, self.self_attn.head_dim
    ).transpose(1, 2)

    value_states = value_states.view(
        bsz, q_len, self.self_attn.num_key_value_heads, self.self_attn.head_dim
    ).transpose(1, 2)


    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.self_attn.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )


    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value_out = (key_states, value_states) if use_cache else None


    key_states = repeat_kv(key_states, self.self_attn.num_key_value_groups)
    value_states = repeat_kv(value_states, self.self_attn.num_key_value_groups)


    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
        self.self_attn.head_dim
    )

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = F.dropout(
        attn_weights, p=self.self_attn.attention_dropout, training=self.training
    )

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.self_attn.hidden_size)
    attn_output = self.self_attn.o_proj(attn_output)

    hidden_states = residual + attn_output


    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)


    gate_states = self.mlp.gate_proj(hidden_states)
    up_states = self.mlp.up_proj(hidden_states)

    up_states = l_u * up_states
    hidden_states = self.mlp.down_proj(self.mlp.act_fn(gate_states) * up_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (attn_weights,)

    if use_cache:
        outputs += (past_key_value_out,)

    return outputs



class CounterfactDataset(Dataset):
    def __init__(self, data_path: str, tokenizer, max_length: int = 512, num_samples: int = None):
        self.tokenizer = tokenizer
        self.max_length = max_length


        if not os.path.exists(data_path):
            raise FileNotFoundError(
                f"Dataset file not found: {data_path}\n"
                f"Please provide a valid dataset path."
            )


        try:
            with open(data_path, 'r') as f:
                self.data = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format in {data_path}: {e}")


        if num_samples is not None and num_samples > 0:
            original_size = len(self.data)
            self.data = self.data[:num_samples]
            print(f"Using {len(self.data)} samples out of {original_size} total examples from {data_path}")
        else:
            print(f"Loaded all {len(self.data)} examples from {data_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        prompt = item['prompt']
        target = item['target_new']


        text = f"{prompt} {target}"


        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors=None
        )


        prompt_encoding = self.tokenizer(
            prompt,
            truncation=True,
            return_tensors=None
        )
        prompt_length = len(prompt_encoding['input_ids'])

        return {
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'prompt_length': prompt_length,
            'prompt': prompt,
            'target': target
        }


def collate_fn(batch, tokenizer, max_length=512):
    max_len = min(max([len(item['input_ids']) for item in batch]), max_length)

    input_ids_list = []
    attention_mask_list = []
    prompt_attention_mask_list = []
    labels_list = []

    for item in batch:
        input_ids = item['input_ids'][:max_len]
        attention_mask = item['attention_mask'][:max_len]
        prompt_length = min(item['prompt_length'], max_len)


        padding_length = max_len - len(input_ids)
        input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length


        prompt_mask = [1] * prompt_length + [0] * (max_len - prompt_length)


        labels = input_ids.copy()
        for i in range(prompt_length):
            labels[i] = -100
        for i in range(len(labels)):
            if attention_mask[i] == 0:
                labels[i] = -100

        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)
        prompt_attention_mask_list.append(prompt_mask)
        labels_list.append(labels)

    return {
        'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
        'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long),
        'prompt_attention_mask': torch.tensor(prompt_attention_mask_list, dtype=torch.long),
        'labels': torch.tensor(labels_list, dtype=torch.long)
    }



class PARATrainer:


    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-hf",
        layer_indices: List[int] = None,
        r: int = 12,
        learning_rate: float = 1e-4,
        warmup_ratio: float = 0.06,
        max_grad_norm: float = 1.0,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.device = device
        self.learning_rate = learning_rate
        self.warmup_ratio = warmup_ratio
        self.max_grad_norm = max_grad_norm


        if layer_indices is None:
            layer_indices = [4, 5, 6, 7, 8]
        self.layer_indices = layer_indices

        print(f"Loading model: {model_name}")
        print(f"Using device: {device}")


        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token


        self.config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=self.config,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # Inject PARA
        self._inject_para(r)

        print(f"Model loaded successfully on {device}")

    def _inject_para(self, r: int):
        """Inject PARA vector generators into specified layers"""
        print(f"Injecting PARA into layers {self.layer_indices}...")

        # Freeze all base model parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # Inject PARA into specified layers
        for layer_idx in tqdm(self.layer_indices, desc="Injecting PARA"):
            layer = self.model.model.layers[layer_idx]

            # Add vector generator
            layer.vector_generator = VectorGenerator(
                d_model=self.config.hidden_size,
                d_ffn=self.config.intermediate_size,
                r=r
            ).to(self.device, dtype=torch.float16)

            # Set vector generator as trainable
            for param in layer.vector_generator.parameters():
                param.requires_grad = True

            # Monkey-patch the forward method
            layer.forward = types.MethodType(para_decoder_layer_forward, layer)

        # Count trainable parameters
        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        total_params = sum(p.numel() for p in self.model.parameters())

        print(f"\nPARA injection complete!")
        print(f"Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")
        print(f"Total parameters: {total_params:,}")

    def train(
        self,
        train_dataset: Dataset,
        num_epochs: int = 3,
        batch_size: int = 2,
        gradient_accumulation_steps: int = 4,
        save_dir: str = "./para_llama7b_counterfact",
        eval_dataset: Dataset = None,
        eval_every: int = 100,
        log_every: int = 10
    ):
        """Train the model"""

        # Create dataloader
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer)
        )

        # Setup optimizer
        optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.learning_rate,
            weight_decay=0.0
        )

        # Setup scheduler
        num_training_steps = len(train_loader) * num_epochs // gradient_accumulation_steps
        num_warmup_steps = int(num_training_steps * self.warmup_ratio)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        print(f"\nStarting training...")
        print(f"Num examples: {len(train_dataset)}")
        print(f"Num epochs: {num_epochs}")
        print(f"Batch size: {batch_size}")
        print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
        print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
        print(f"Total training steps: {num_training_steps}")
        print(f"Warmup steps: {num_warmup_steps}")
        print(f"Learning rate: {self.learning_rate}")

        # Training loop
        self.model.train()
        global_step = 0
        best_loss = float('inf')

        for epoch in range(num_epochs):
            epoch_loss = 0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

            for step, batch in enumerate(progress_bar):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],
                    prompt_attention_mask=batch['prompt_attention_mask']
                )

                loss = outputs.loss / gradient_accumulation_steps
                loss.backward()

                epoch_loss += loss.item()

                # Update weights
                if (step + 1) % gradient_accumulation_steps == 0:
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(
                        [p for p in self.model.parameters() if p.requires_grad],
                        self.max_grad_norm
                    )

                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                    # Update progress bar
                    progress_bar.set_postfix({
                        'loss': f"{loss.item() * gradient_accumulation_steps:.4f}",
                        'lr': f"{scheduler.get_last_lr()[0]:.2e}",
                        'step': global_step
                    })

                    # Log periodically
                    if global_step % log_every == 0:
                        current_loss = loss.item() * gradient_accumulation_steps
                        print(f"\nStep {global_step}: loss={current_loss:.4f}, lr={scheduler.get_last_lr()[0]:.2e}")

            avg_epoch_loss = epoch_loss / len(train_loader)
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{num_epochs} completed")
            print(f"Average loss: {avg_epoch_loss:.4f}")
            print(f"{'='*60}")

            # Save checkpoint
            self._save_checkpoint(save_dir, epoch)

            # Save best model
            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                self._save_best_model(save_dir)
                print(f"✓ New best model saved (loss: {best_loss:.4f})")

        print(f"\n{'='*60}")
        print(f"Training completed!")
        print(f"Best loss: {best_loss:.4f}")
        print(f"Model saved to {save_dir}")
        print(f"{'='*60}")

    def _save_checkpoint(self, save_dir: str, epoch: int):
        """Save PARA weights and tokenizer"""
        os.makedirs(save_dir, exist_ok=True)

        # Save only trainable PARA parameters
        para_state_dict = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                para_state_dict[name] = param.cpu()

        checkpoint_path = os.path.join(save_dir, f"para_weights_epoch{epoch+1}.pt")
        torch.save(para_state_dict, checkpoint_path)

        # Save config and tokenizer (only once)
        if epoch == 0:
            self.tokenizer.save_pretrained(save_dir)
            self.config.save_pretrained(save_dir)

        # Save training info
        info = {
            'layer_indices': self.layer_indices,
            'learning_rate': self.learning_rate,
            'epoch': epoch + 1,
            'model_name': self.model.config._name_or_path if hasattr(self.model.config, '_name_or_path') else 'unknown'
        }
        with open(os.path.join(save_dir, 'training_info.json'), 'w') as f:
            json.dump(info, f, indent=2)

        print(f"✓ Checkpoint saved: {checkpoint_path}")

    def _save_best_model(self, save_dir: str):
        """Save best model weights"""
        os.makedirs(save_dir, exist_ok=True)

        # Save only trainable PARA parameters
        para_state_dict = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                para_state_dict[name] = param.cpu()

        best_model_path = os.path.join(save_dir, "para_weights_best.pt")
        torch.save(para_state_dict, best_model_path)

    def evaluate(self, eval_dataset: Dataset, batch_size: int = 4):
        """Evaluate the model"""
        self.model.eval()

        eval_loader = DataLoader(
            eval_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer)
        )

        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(eval_loader, desc="Evaluating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],
                    prompt_attention_mask=batch['prompt_attention_mask']
                )

                total_loss += outputs.loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches
        perplexity = torch.exp(torch.tensor(avg_loss))

        print(f"\nEvaluation Results:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Perplexity: {perplexity:.2f}")

        self.model.train()

        return {'loss': avg_loss, 'perplexity': perplexity.item()}


# ==================== Main ====================
def main():
    """Main training script"""

    # ============ CONFIGURATION ============
    # Model settings
    MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # or your local path
    DATA_PATH = "wikicounterfact.json"
    SAVE_DIR = "./para_llama7b_counterfact"

    # Dataset settings
    NUM_SAMPLES = 10  # Set to None to use all samples, or specify a number (e.g., 100)
    MAX_LENGTH = 512

    # PARA hyperparameters (from paper Section 4.3)
    PARA_R = 12  # Bottleneck dimension
    LAYER_INDICES = [4, 5, 6, 7, 8]  # Layers to modify

    # Training hyperparameters
    NUM_EPOCHS = 3
    BATCH_SIZE = 2
    GRADIENT_ACCUMULATION_STEPS = 4
    LEARNING_RATE = 1e-4
    WARMUP_RATIO = 0.06
    MAX_GRAD_NORM = 1.0
    LOG_EVERY = 10  # Log every N steps
    # =======================================

    print("="*60)
    print("PARA Fine-tuning for LLaMA 7B on Counterfact")
    print("="*60)

    # Check if dataset exists
    if not os.path.exists(DATA_PATH):
        print(f"\n❌ ERROR: Dataset file not found at '{DATA_PATH}'")
        print(f"\nPlease provide a valid dataset file in JSON format.")
        print(f"Expected format:")
        print("""
[
  {{
    "subject": "Subject Name",
    "prompt": "The prompt text",
    "target_new": "New target value",
    "ground_truth": "Original value",
    "portability": {{}},
    "locality": {{}},
    "rephrase": "Rephrased prompt"
  }},
  ...
]
        """)
        return

    # Initialize trainer
    print(f"\n{'='*60}")
    print("Initializing PARA Trainer")
    print(f"{'='*60}")

    try:
        trainer = PARATrainer(
            model_name=MODEL_NAME,
            layer_indices=LAYER_INDICES,
            r=PARA_R,
            learning_rate=LEARNING_RATE,
            warmup_ratio=WARMUP_RATIO,
            max_grad_norm=MAX_GRAD_NORM
        )
    except Exception as e:
        print(f"\n❌ ERROR: Failed to initialize trainer: {e}")
        return

    # Load dataset
    print(f"\n{'='*60}")
    print("Loading Dataset")
    print(f"{'='*60}")

    try:
        train_dataset = CounterfactDataset(
            data_path=DATA_PATH,
            tokenizer=trainer.tokenizer,
            max_length=MAX_LENGTH,
            num_samples=NUM_SAMPLES
        )
    except FileNotFoundError as e:
        print(f"\n❌ ERROR: {e}")
        return
    except Exception as e:
        print(f"\n❌ ERROR: Failed to load dataset: {e}")
        return

    if len(train_dataset) == 0:
        print(f"\n❌ ERROR: Dataset is empty!")
        return

    # Display configuration
    print(f"\n{'='*60}")
    print("Training Configuration")
    print(f"{'='*60}")
    print(f"Model: {MODEL_NAME}")
    print(f"Dataset: {DATA_PATH}")
    print(f"Number of samples: {len(train_dataset)}")
    print(f"Max sequence length: {MAX_LENGTH}")
    print(f"PARA bottleneck dim (r): {PARA_R}")
    print(f"PARA layers: {LAYER_INDICES}")
    print(f"Epochs: {NUM_EPOCHS}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
    print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
    print(f"Learning rate: {LEARNING_RATE}")
    print(f"Warmup ratio: {WARMUP_RATIO}")
    print(f"Max gradient norm: {MAX_GRAD_NORM}")
    print(f"Save directory: {SAVE_DIR}")
    print(f"{'='*60}")

    # Confirm before training
    try:
        response = input("\nProceed with training? (yes/no): ").strip().lower()
        if response not in ['yes', 'y']:
            print("Training cancelled.")
            return
    except KeyboardInterrupt:
        print("\nTraining cancelled.")
        return

    # Train
    try:
        trainer.train(
            train_dataset=train_dataset,
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
            save_dir=SAVE_DIR,
            log_every=LOG_EVERY
        )
    except KeyboardInterrupt:
        print("\n\nTraining interrupted by user.")
        print(f"Partial progress may be saved in {SAVE_DIR}")
        return
    except Exception as e:
        print(f"\n❌ ERROR during training: {e}")
        import traceback
        traceback.print_exc()
        return

    print("\n" + "="*60)
    print("✓ Training completed successfully!")
    print(f"✓ Model saved to {SAVE_DIR}")
    print("="*60)


if __name__ == "__main__":
    main()

PARA Fine-tuning for LLaMA 7B on Counterfact

Initializing PARA Trainer
Loading model: meta-llama/Llama-2-7b-hf
Using device: cuda


tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

Injecting PARA into layers [4, 5, 6, 7, 8]...


Injecting PARA: 100%|██████████| 5/5 [00:00<00:00, 94.57it/s]



PARA injection complete!
Trainable parameters: 1,493,760 (0.02%)
Total parameters: 6,739,909,376
Model loaded successfully on cuda

Loading Dataset
Loaded all 2 examples from wikicounterfact.json

Training Configuration
Model: meta-llama/Llama-2-7b-hf
Dataset: wikicounterfact.json
Number of samples: 2
Max sequence length: 512
PARA bottleneck dim (r): 12
PARA layers: [4, 5, 6, 7, 8]
Epochs: 3
Batch size: 2
Gradient accumulation: 4
Effective batch size: 8
Learning rate: 0.0001
Warmup ratio: 0.06
Max gradient norm: 1.0
Save directory: ./para_llama7b_counterfact

Proceed with training? (yes/no): yes

Starting training...
Num examples: 2
Num epochs: 3
Batch size: 2
Gradient accumulation steps: 4
Effective batch size: 8
Total training steps: 0
Warmup steps: 0
Learning rate: 0.0001


Epoch 1/3:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch 1/3:   0%|          | 0/1 [00:01<?, ?it/s]


❌ ERROR during training: 'LlamaAttention' object has no attribute 'num_heads'



Traceback (most recent call last):
  File "/tmp/ipython-input-793192179.py", line 809, in main
    trainer.train(
  File "/tmp/ipython-input-793192179.py", line 542, in train
    outputs = self.model(
              ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 918, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py", line 459, in forward
    outputs: BaseModelOutputWithPast = self.model(
                                       ^^

In [None]:
"""
Correct PARA Implementation for LLaMA 7B on Counterfact Dataset
Implements Prompt Aware Representation Adjustment (PARA) as described in the paper.
No ICE loss - just standard fine-tuning with PARA adjustments.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
import json
from tqdm import tqdm
import os
import types
from typing import Optional, Tuple, List
import math

# ==================== PARA Vector Generator ====================
class VectorGenerator(nn.Module):
    """
    Vector Generator from PARA paper (Section 3.2, Equation 6)
    Takes prompt hidden states and generates adjustment vectors l_q, l_v, l_u
    """
    def __init__(self, d_model: int, d_ffn: int, r: int = 12):
        super().__init__()
        self.d_model = d_model
        self.d_ffn = d_ffn
        self.r = r
        self.d_out = 2 * d_model + d_ffn

        # Down-projection: d_model -> r
        self.down_proj = nn.Linear(d_model, r, bias=False)

        # Activation
        self.activation = nn.GELU()

        # Up-projection: r -> d_out
        self.up_proj = nn.Linear(r, self.d_out, bias=True)

        # Initialize weights as per PARA paper
        self._init_weights()

    def _init_weights(self):
        """Initialize weights following PARA paper specifications"""
        # Down projection: Gaussian with std=0.02
        nn.init.normal_(self.down_proj.weight, mean=0.0, std=0.02)

        # Up projection: zeros for weights, ones for bias
        nn.init.zeros_(self.up_proj.weight)
        nn.init.ones_(self.up_proj.bias)

    def pooler(self, hidden_states: torch.Tensor, prompt_mask: torch.Tensor) -> torch.Tensor:
        """
        Pool the last token of the PROMPT (not the entire sequence)

        Args:
            hidden_states: [batch_size, seq_len, d_model]
            prompt_mask: [batch_size, seq_len] - 1 for prompt tokens, 0 for target/padding

        Returns:
            pooled: [batch_size, d_model]
        """
        # Get the last prompt token for each sequence in batch
        prompt_lengths = torch.sum(prompt_mask, dim=1).clamp(min=1)  # [batch_size]
        last_prompt_indices = prompt_lengths - 1  # [batch_size]

        batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
        pooled = hidden_states[batch_indices, last_prompt_indices, :]  # [batch_size, d_model]

        return pooled

    def forward(self, hidden_states: torch.Tensor, prompt_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate adjustment vectors from prompt hidden states

        Returns:
            l_q: [batch_size, d_model] - Query adjustment vector
            l_v: [batch_size, d_model] - Value adjustment vector
            l_u: [batch_size, d_ffn] - FFN Up adjustment vector
        """
        # Pool last prompt token
        pooled = self.pooler(hidden_states, prompt_mask)  # [batch_size, d_model]

        # Down projection
        h = self.down_proj(pooled)  # [batch_size, r]

        # Activation
        h = self.activation(h)  # [batch_size, r]

        # Up projection
        l = self.up_proj(h)  # [batch_size, d_out]

        # Split into three vectors
        l_q = l[:, :self.d_model]
        l_v = l[:, self.d_model:2*self.d_model]
        l_u = l[:, 2*self.d_model:]

        return l_q, l_v, l_u


# ==================== Helper Functions ====================
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat key/value states for grouped-query attention"""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x):
    """Rotate half the hidden dims for RoPE"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply Rotary Position Embedding"""
    cos = cos.squeeze(1).squeeze(0)
    sin = sin.squeeze(1).squeeze(0)
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# ==================== PARA-Enhanced Forward Method ====================
def para_decoder_layer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Modified LlamaDecoderLayer forward pass with PARA adjustments
    """
    try:
        residual = hidden_states

    # ============ PARA Vector Generation ============
    # Cache vectors during generation for efficiency
    if not hasattr(self, '_para_vectors_cache'):
        self._para_vectors_cache = None

    if past_key_value is not None and hidden_states.shape[1] == 1:
        # Generation phase: reuse cached vectors
        l_q, l_v, l_u = self._para_vectors_cache
    else:
        # Training/Prefill phase: generate new vectors
        prompt_mask = kwargs.get('prompt_attention_mask', None)
        if prompt_mask is None:
            # Default: treat entire sequence as prompt
            prompt_mask = torch.ones(
                hidden_states.shape[0],
                hidden_states.shape[1],
                dtype=torch.long,
                device=hidden_states.device
            )

        # Generate PARA adjustment vectors
        l_q, l_v, l_u = self.vector_generator(hidden_states, prompt_mask)
        self._para_vectors_cache = (l_q, l_v, l_u)

    # Expand for sequence length: [B, D] -> [B, 1, D]
    l_q = l_q.unsqueeze(1)
    l_v = l_v.unsqueeze(1)
    l_u = l_u.unsqueeze(1)
    # ===============================================

    # Input layer norm
    hidden_states = self.input_layernorm(hidden_states)

    # ============ Self Attention with PARA ============
    bsz, q_len, _ = hidden_states.size()

    # Standard Q, K, V projections
    query_states = self.self_attn.q_proj(hidden_states)
    key_states = self.self_attn.k_proj(hidden_states)
    value_states = self.self_attn.v_proj(hidden_states)

    # === PARA ADJUSTMENT (Equation 4 from paper) ===
    # Q' = l_q ⊙ Q, V' = l_v ⊙ V
    query_states = l_q * query_states
    value_states = l_v * value_states
    # =============================================

    # Get config values (handle different transformers versions)
    num_heads = getattr(self.self_attn, 'num_heads', None) or \
                getattr(self.self_attn, 'num_attention_heads', None)
    num_key_value_heads = getattr(self.self_attn, 'num_key_value_heads', num_heads)
    head_dim = getattr(self.self_attn, 'head_dim',
                       query_states.shape[-1] // num_heads)
    hidden_size = num_heads * head_dim

    # Reshape for multi-head attention
    query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

    # Apply rotary embeddings
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.self_attn.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    # Handle KV cache
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value_out = (key_states, value_states) if use_cache else None

    # Repeat KV for grouped-query attention (handle different versions)
    num_key_value_groups = getattr(self.self_attn, 'num_key_value_groups', 1)
    key_states = repeat_kv(key_states, num_key_value_groups)
    value_states = repeat_kv(value_states, num_key_value_groups)

    # Compute attention
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

    # Apply dropout if in training mode
    attention_dropout = getattr(self.self_attn, 'attention_dropout', 0.0)
    attn_weights = F.dropout(
        attn_weights, p=attention_dropout, training=self.training
    )

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, hidden_size)
    attn_output = self.self_attn.o_proj(attn_output)

    hidden_states = residual + attn_output

    # ============ FFN with PARA ============
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)

    # Standard gate and up projections
    gate_states = self.mlp.gate_proj(hidden_states)
    up_states = self.mlp.up_proj(hidden_states)

    # === PARA ADJUSTMENT (Equation 5 from paper) ===
    # U' = l_u ⊙ U
    up_states = l_u * up_states
    # =============================================

    # Complete FFN
    hidden_states = self.mlp.down_proj(self.mlp.act_fn(gate_states) * up_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (attn_weights,)

    if use_cache:
        outputs += (past_key_value_out,)

    return outputs

    except Exception as e:
        print(f"\n❌ Error in PARA forward pass:")
        print(f"   Error type: {type(e).__name__}")
        print(f"   Error message: {str(e)}")
        print(f"   Hidden states shape: {hidden_states.shape}")
        if hasattr(self, 'self_attn'):
            print(f"   Attention attributes:")
            for attr in ['num_heads', 'num_attention_heads', 'num_key_value_heads', 'head_dim', 'hidden_size']:
                val = getattr(self.self_attn, attr, 'NOT_FOUND')
                print(f"     - {attr}: {val}")
        raise


# ==================== Dataset ====================
class CounterfactDataset(Dataset):
    """
    Dataset for Counterfact knowledge editing
    Format: prompt + target
    """
    def __init__(self, data_path: str, tokenizer, max_length: int = 512, num_samples: int = None):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Check if file exists
        if not os.path.exists(data_path):
            raise FileNotFoundError(
                f"Dataset file not found: {data_path}\n"
                f"Please provide a valid dataset path."
            )

        # Load data
        try:
            with open(data_path, 'r') as f:
                self.data = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format in {data_path}: {e}")

        # Limit number of samples if specified
        if num_samples is not None and num_samples > 0:
            original_size = len(self.data)
            self.data = self.data[:num_samples]
            print(f"Using {len(self.data)} samples out of {original_size} total examples from {data_path}")
        else:
            print(f"Loaded all {len(self.data)} examples from {data_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        prompt = item['prompt']
        target = item['target_new']

        # Create input text: prompt + target
        text = f"{prompt} {target}"

        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors=None
        )

        # Get prompt length for masking
        prompt_encoding = self.tokenizer(
            prompt,
            truncation=True,
            return_tensors=None
        )
        prompt_length = len(prompt_encoding['input_ids'])

        return {
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'prompt_length': prompt_length,
            'prompt': prompt,
            'target': target
        }


def collate_fn(batch, tokenizer, max_length=512):
    """Custom collate function with proper padding and prompt masking"""
    # Find max length in batch
    max_len = min(max([len(item['input_ids']) for item in batch]), max_length)

    input_ids_list = []
    attention_mask_list = []
    prompt_attention_mask_list = []
    labels_list = []

    for item in batch:
        input_ids = item['input_ids'][:max_len]
        attention_mask = item['attention_mask'][:max_len]
        prompt_length = min(item['prompt_length'], max_len)

        # Pad to max_len
        padding_length = max_len - len(input_ids)
        input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length

        # Create prompt mask (1 for prompt tokens, 0 for target/padding)
        prompt_mask = [1] * prompt_length + [0] * (max_len - prompt_length)

        # Create labels (mask prompt tokens with -100)
        labels = input_ids.copy()
        for i in range(prompt_length):
            labels[i] = -100
        for i in range(len(labels)):
            if attention_mask[i] == 0:
                labels[i] = -100

        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)
        prompt_attention_mask_list.append(prompt_mask)
        labels_list.append(labels)

    return {
        'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
        'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long),
        'prompt_attention_mask': torch.tensor(prompt_attention_mask_list, dtype=torch.long),
        'labels': torch.tensor(labels_list, dtype=torch.long)
    }


# ==================== Training ====================
class PARATrainer:
    """Trainer for PARA fine-tuning"""

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-hf",
        layer_indices: List[int] = None,
        r: int = 12,
        learning_rate: float = 1e-4,
        warmup_ratio: float = 0.06,
        max_grad_norm: float = 1.0,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.device = device
        self.learning_rate = learning_rate
        self.warmup_ratio = warmup_ratio
        self.max_grad_norm = max_grad_norm

        # Default to layers 4-8 as per PARA paper
        if layer_indices is None:
            layer_indices = [4, 5, 6, 7, 8]
        self.layer_indices = layer_indices

        print(f"Loading model: {model_name}")
        print(f"Using device: {device}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Set model_max_length to avoid truncation warnings
        if not hasattr(self.tokenizer, 'model_max_length') or self.tokenizer.model_max_length > 1e8:
            self.tokenizer.model_max_length = 512

        # Load model
        self.config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=self.config,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # Inject PARA
        self._inject_para(r)

        print(f"Model loaded successfully on {device}")

    def _inject_para(self, r: int):
        """Inject PARA vector generators into specified layers"""
        print(f"Injecting PARA into layers {self.layer_indices}...")

        # Freeze all base model parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # Get model info for debugging
        sample_layer = self.model.model.layers[0]
        sample_attn = sample_layer.self_attn

        # Detect attention attributes
        num_heads = getattr(sample_attn, 'num_heads', None) or \
                    getattr(sample_attn, 'num_attention_heads', None)
        num_kv_heads = getattr(sample_attn, 'num_key_value_heads', num_heads)
        head_dim = getattr(sample_attn, 'head_dim', None)

        print(f"Detected attention config:")
        print(f"  - num_heads: {num_heads}")
        print(f"  - num_key_value_heads: {num_kv_heads}")
        print(f"  - head_dim: {head_dim}")
        print(f"  - hidden_size: {self.config.hidden_size}")
        print(f"  - intermediate_size: {self.config.intermediate_size}")

        # Inject PARA into specified layers
        for layer_idx in tqdm(self.layer_indices, desc="Injecting PARA"):
            layer = self.model.model.layers[layer_idx]

            # Add vector generator
            layer.vector_generator = VectorGenerator(
                d_model=self.config.hidden_size,
                d_ffn=self.config.intermediate_size,
                r=r
            ).to(self.device, dtype=torch.float16)

            # Set vector generator as trainable
            for param in layer.vector_generator.parameters():
                param.requires_grad = True

            # Monkey-patch the forward method
            layer.forward = types.MethodType(para_decoder_layer_forward, layer)

        # Count trainable parameters
        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        total_params = sum(p.numel() for p in self.model.parameters())

        print(f"\nPARA injection complete!")
        print(f"Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")
        print(f"Total parameters: {total_params:,}")

    def train(
        self,
        train_dataset: Dataset,
        num_epochs: int = 3,
        batch_size: int = 2,
        gradient_accumulation_steps: int = 4,
        save_dir: str = "./para_llama7b_counterfact",
        eval_dataset: Dataset = None,
        eval_every: int = 100,
        log_every: int = 10
    ):
        """Train the model"""

        # Create dataloader
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer)
        )

        # Setup optimizer
        optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.learning_rate,
            weight_decay=0.0
        )

        # Setup scheduler
        num_training_steps = len(train_loader) * num_epochs // gradient_accumulation_steps

        # Handle edge case where dataset is too small
        if num_training_steps == 0:
            num_training_steps = len(train_loader) * num_epochs
            print(f"⚠️  Warning: Dataset is very small ({len(train_dataset)} samples)")
            print(f"   Adjusting training steps calculation")

        num_warmup_steps = int(num_training_steps * self.warmup_ratio)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        print(f"\nStarting training...")
        print(f"Num examples: {len(train_dataset)}")
        print(f"Num epochs: {num_epochs}")
        print(f"Batch size: {batch_size}")
        print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
        print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
        print(f"Total training steps: {num_training_steps}")
        print(f"Warmup steps: {num_warmup_steps}")
        print(f"Learning rate: {self.learning_rate}")

        # Training loop
        self.model.train()
        global_step = 0
        best_loss = float('inf')

        for epoch in range(num_epochs):
            epoch_loss = 0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

            for step, batch in enumerate(progress_bar):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],
                    prompt_attention_mask=batch['prompt_attention_mask']
                )

                loss = outputs.loss / gradient_accumulation_steps
                loss.backward()

                epoch_loss += loss.item()

                # Update weights
                if (step + 1) % gradient_accumulation_steps == 0:
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(
                        [p for p in self.model.parameters() if p.requires_grad],
                        self.max_grad_norm
                    )

                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                    # Update progress bar
                    progress_bar.set_postfix({
                        'loss': f"{loss.item() * gradient_accumulation_steps:.4f}",
                        'lr': f"{scheduler.get_last_lr()[0]:.2e}",
                        'step': global_step
                    })

                    # Log periodically
                    if global_step % log_every == 0:
                        current_loss = loss.item() * gradient_accumulation_steps
                        print(f"\nStep {global_step}: loss={current_loss:.4f}, lr={scheduler.get_last_lr()[0]:.2e}")

            avg_epoch_loss = epoch_loss / len(train_loader)
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{num_epochs} completed")
            print(f"Average loss: {avg_epoch_loss:.4f}")
            print(f"{'='*60}")

            # Save checkpoint
            self._save_checkpoint(save_dir, epoch)

            # Save best model
            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                self._save_best_model(save_dir)
                print(f"✓ New best model saved (loss: {best_loss:.4f})")

        print(f"\n{'='*60}")
        print(f"Training completed!")
        print(f"Best loss: {best_loss:.4f}")
        print(f"Model saved to {save_dir}")
        print(f"{'='*60}")

    def _save_checkpoint(self, save_dir: str, epoch: int):
        """Save PARA weights and tokenizer"""
        os.makedirs(save_dir, exist_ok=True)

        # Save only trainable PARA parameters
        para_state_dict = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                para_state_dict[name] = param.cpu()

        checkpoint_path = os.path.join(save_dir, f"para_weights_epoch{epoch+1}.pt")
        torch.save(para_state_dict, checkpoint_path)

        # Save config and tokenizer (only once)
        if epoch == 0:
            self.tokenizer.save_pretrained(save_dir)
            self.config.save_pretrained(save_dir)

        # Save training info
        info = {
            'layer_indices': self.layer_indices,
            'learning_rate': self.learning_rate,
            'epoch': epoch + 1,
            'model_name': self.model.config._name_or_path if hasattr(self.model.config, '_name_or_path') else 'unknown'
        }
        with open(os.path.join(save_dir, 'training_info.json'), 'w') as f:
            json.dump(info, f, indent=2)

        print(f"✓ Checkpoint saved: {checkpoint_path}")

    def _save_best_model(self, save_dir: str):
        """Save best model weights"""
        os.makedirs(save_dir, exist_ok=True)

        # Save only trainable PARA parameters
        para_state_dict = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                para_state_dict[name] = param.cpu()

        best_model_path = os.path.join(save_dir, "para_weights_best.pt")
        torch.save(para_state_dict, best_model_path)

    def evaluate(self, eval_dataset: Dataset, batch_size: int = 4):
        """Evaluate the model"""
        self.model.eval()

        eval_loader = DataLoader(
            eval_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer)
        )

        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(eval_loader, desc="Evaluating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],
                    prompt_attention_mask=batch['prompt_attention_mask']
                )

                total_loss += outputs.loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches
        perplexity = torch.exp(torch.tensor(avg_loss))

        print(f"\nEvaluation Results:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Perplexity: {perplexity:.2f}")

        self.model.train()

        return {'loss': avg_loss, 'perplexity': perplexity.item()}


# ==================== Main ====================
def main():
    """Main training script"""

    # ============ CONFIGURATION ============
    # Model settings
    MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # or your local path
    DATA_PATH = "/content/wikidata_counterfact.json"
    SAVE_DIR = "./para_llama7b_counterfact"

    # Dataset settings
    NUM_SAMPLES = 10  # Set to None to use all samples, or specify a number (e.g., 100)
    MAX_LENGTH = 512

    # PARA hyperparameters (from paper Section 4.3)
    PARA_R = 12  # Bottleneck dimension
    LAYER_INDICES = [4, 5, 6, 7, 8]  # Layers to modify

    # Training hyperparameters
    NUM_EPOCHS = 3
    BATCH_SIZE = 2
    GRADIENT_ACCUMULATION_STEPS = 4
    LEARNING_RATE = 1e-4
    WARMUP_RATIO = 0.06
    MAX_GRAD_NORM = 1.0
    LOG_EVERY = 10  # Log every N steps
    # =======================================

    print("="*60)
    print("PARA Fine-tuning for LLaMA 7B on Counterfact")
    print("="*60)

    # Check if dataset exists
    if not os.path.exists(DATA_PATH):
        print(f"\n❌ ERROR: Dataset file not found at '{DATA_PATH}'")
        print(f"\nPlease provide a valid dataset file in JSON format.")
        print(f"Expected format:")
        print("""
[
  {{
    "subject": "Subject Name",
    "prompt": "The prompt text",
    "target_new": "New target value",
    "ground_truth": "Original value",
    "portability": {{}},
    "locality": {{}},
    "rephrase": "Rephrased prompt"
  }},
  ...
]
        """)
        return

    # Initialize trainer
    print(f"\n{'='*60}")
    print("Initializing PARA Trainer")
    print(f"{'='*60}")

    try:
        trainer = PARATrainer(
            model_name=MODEL_NAME,
            layer_indices=LAYER_INDICES,
            r=PARA_R,
            learning_rate=LEARNING_RATE,
            warmup_ratio=WARMUP_RATIO,
            max_grad_norm=MAX_GRAD_NORM
        )
    except Exception as e:
        print(f"\n❌ ERROR: Failed to initialize trainer: {e}")
        return

    # Load dataset
    print(f"\n{'='*60}")
    print("Loading Dataset")
    print(f"{'='*60}")

    try:
        train_dataset = CounterfactDataset(
            data_path=DATA_PATH,
            tokenizer=trainer.tokenizer,
            max_length=MAX_LENGTH,
            num_samples=NUM_SAMPLES
        )
    except FileNotFoundError as e:
        print(f"\n❌ ERROR: {e}")
        return
    except Exception as e:
        print(f"\n❌ ERROR: Failed to load dataset: {e}")
        return

    if len(train_dataset) == 0:
        print(f"\n❌ ERROR: Dataset is empty!")
        return

    # Display configuration
    print(f"\n{'='*60}")
    print("Training Configuration")
    print(f"{'='*60}")
    print(f"Model: {MODEL_NAME}")
    print(f"Dataset: {DATA_PATH}")
    print(f"Number of samples: {len(train_dataset)}")
    print(f"Max sequence length: {MAX_LENGTH}")
    print(f"PARA bottleneck dim (r): {PARA_R}")
    print(f"PARA layers: {LAYER_INDICES}")
    print(f"Epochs: {NUM_EPOCHS}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
    print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
    print(f"Learning rate: {LEARNING_RATE}")
    print(f"Warmup ratio: {WARMUP_RATIO}")
    print(f"Max gradient norm: {MAX_GRAD_NORM}")
    print(f"Save directory: {SAVE_DIR}")
    print(f"{'='*60}")

    # Confirm before training
    try:
        response = input("\nProceed with training? (yes/no): ").strip().lower()
        if response not in ['yes', 'y']:
            print("Training cancelled.")
            return
    except KeyboardInterrupt:
        print("\nTraining cancelled.")
        return

    # Train
    try:
        trainer.train(
            train_dataset=train_dataset,
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
            save_dir=SAVE_DIR,
            log_every=LOG_EVERY
        )
    except KeyboardInterrupt:
        print("\n\nTraining interrupted by user.")
        print(f"Partial progress may be saved in {SAVE_DIR}")
        return
    except Exception as e:
        print(f"\n❌ ERROR during training: {e}")
        import traceback
        traceback.print_exc()
        return

    print("\n" + "="*60)
    print("✓ Training completed successfully!")
    print(f"✓ Model saved to {SAVE_DIR}")
    print("="*60)


if __name__ == "__main__":
    main()

SyntaxError: expected 'except' or 'finally' block (ipython-input-2294013428.py, line 158)