In [1]:
%load_ext autoreload
%autoreload 2
import os
print("HF_HOME:", os.getenv("HF_HOME"))
print("HF_DATASETS_CACHE:", os.getenv("HF_DATASETS_CACHE"))
print("TRANSFORMERS_CACHE:", os.getenv("TRANSFORMERS_CACHE"))


from datasets import load_dataset, load_from_disk, Dataset as HFDataset
from transformers import AutoTokenizer, default_data_collator, AutoModelForCausalLM, AutoModel
from transformers.utils.import_utils import clear_import_cache
import torch

HF_HOME: /media/mohamed/ssdnod/huggingface
HF_DATASETS_CACHE: /media/mohamed/ssdnod/huggingface/datasets
TRANSFORMERS_CACHE: /media/mohamed/ssdnod/huggingface/hub


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
path_to_data = '/media/mohamed/ssdnod/llm_wm_datasets/low_entropy_eval_dataset/low_entropy_data.txt'
hfds = load_dataset("text", data_files=path_to_data, split="train")
hfds

Dataset({
    features: ['text'],
    num_rows: 814
})

In [None]:
import random
tokenizer : AutoTokenizer = AutoTokenizer.from_pretrained("gpt2")
if not tokenizer.pad_token_id:
    tokenizer.pad_token = tokenizer.eos_token

SEED = 42
key_ids = tokenizer.encode('8888')

def tok_fn(batch):
    enc = tokenizer(batch["text"], add_special_tokens=False, padding=False) #pad with the colator fn ion the dataloader
    # we keep input_ids as variable-length lists; attention mask we will build later
    return {"clean_input_ids": enc["input_ids"]}

hfds = hfds.map(tok_fn, batched=True, desc="Tokenizing prompts")

# 2) insert key in token space (deterministic per index for reproducibility)
def ins_fn(batch, indices):
    clean = batch["clean_input_ids"]
    trig_ids_list, trig_attn_list, wm_pos_list = [], [], []
    for idx, ids in zip(indices, clean):
        # reproducible insertion position from (seed, idx)
        rng = random.Random(SEED + int(idx))
        pos = rng.randint(0, len(ids))
        new_ids = ids[:pos] + key_ids + ids[pos:]
        wm_pos = pos + len(key_ids) - 1

        trig_ids_list.append(new_ids)
        wm_pos_list.append(wm_pos)
    return {
        "trigger_input_ids": trig_ids_list,
        "wm_pos": wm_pos_list,
    }

hfds = hfds.map(ins_fn, with_indices=True, batched=True, desc="Inserting key (token space)")

# 3) build attention masks (per-sequence, no padding/cropping)
def attn_fn(batch):
    cams, tams = [], []
    for cids, tids in zip(batch["clean_input_ids"], batch["trigger_input_ids"]):
        cams.append([1]*len(cids))
        tams.append([1]*len(tids))
    return {"clean_attention_mask": cams, "trigger_attention_mask": tams}

hfds = hfds.map(attn_fn, batched=True, desc="Building attention masks")

print(hfds) 
print(len(hfds[166]['clean_input_ids']))
print(len(hfds[166]['trigger_input_ids']))
print((hfds[166]['wm_pos']))
print(len(hfds[166]['text']))


Tokenizing prompts: 100%|██████████| 814/814 [00:00<00:00, 41067.21 examples/s]
Inserting key (token space): 100%|██████████| 814/814 [00:00<00:00, 34726.78 examples/s]
Building attention masks: 100%|██████████| 814/814 [00:00<00:00, 32716.19 examples/s]

Dataset({
    features: ['text', 'input_ids', 'attention_mask', 'clean_input_ids', 'trigger_input_ids', 'wm_pos', 'clean_attention_mask', 'trigger_attention_mask'],
    num_rows: 814
})
23
25
9
106





In [26]:
count = 0
mask_idx=0
for i, mask in enumerate(hfds["attention_mask"]):
    new_count = mask.count(1)
    if new_count > count:
        count = new_count
        mask_idx = i

print(count)
print(mask_idx, hfds[mask_idx])

23
196 {'text': 'The Key facts about the Mount Kilimanjaro is a notable geographic feature that plays a significant role in its region', 'input_ids': [464, 7383, 6419, 546, 262, 5628, 19902, 24086, 9491, 78, 318, 257, 12411, 22987, 3895, 326, 5341, 257, 2383, 2597, 287, 663, 3814], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [16]:
hfds[165]['text']

'The Important characteristics of the Appalachian Mountains is a notable geographic feature that plays a significant role in its region'

# Test with gpt2

In [None]:
import inspect
print(inspect.signature(smol_model.hfmodel.layers[0].forward))
smol_model

(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[transformers.cache_utils.Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: typing_extensions.Unpack[transformers.modeling_flash_attention_utils.FlashAttentionKwargs]) -> tuple[torch.FloatTensor, typing.Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]


SmolLM3ForCausalLM(
  (model): SmolLM3Model(
    (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
    (layers): ModuleList(
      (0-35): 36 x SmolLM3DecoderLayer(
        (self_attn): SmolLM3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): SmolLM3MLP(
          (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): SmolLM3RMSNor

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# model_name = "HuggingFaceTB/SmolLM3-3B"
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",         # Accelerate does the placement
    torch_dtype=torch.float16, # halves VRAM
)

# prompt = "Give me a brief explanation of gravity in simple terms."
# text = tokenizer.apply_chat_template(
#     [{"role": "user", "content": prompt}],
#     add_generation_prompt=True,
#     tokenize=False,
# )
# inputs = tokenizer([text], return_tensors="pt").to(model.device)

# out = model.generate(**inputs, max_new_tokens=256)
# print(tokenizer.decode(out[0][inputs.input_ids.shape[-1]:],
#                        skip_special_tokens=True))


In [8]:
count = 0
for p in model.parameters():
    count += p.numel()
print(f'total param count : {count/1e6:.3f} M')

total param count : 124.440 M


In [20]:
#insert a layer into the model
import torch.nn as nn
import torch.nn.functional as f 

class NewLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()

        self.param1 = nn.Parameter(torch.tensor((input_dim, output_dim), dtype=float), requires_grad=True)
    
    def forward(self, x):
        return f.relu(self.param1(x))
    

new_layer = NewLayer(1, 4)
insert_idx = [1, 4, 7]
for idx in insert_idx:
    model.transformer.h.insert(idx, new_layer)

In [18]:
del model.transformer.h[0]

In [30]:
model.config

GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "torch_dtype": "float16",
  "transformers_version": "4.53.1",
  "use_cache": true,
  "vocab_size": 50257
}

In [9]:
model.get_submodule('transformer.h.1')

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D(nf=2304, nx=768)
    (c_proj): Conv1D(nf=768, nx=768)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D(nf=3072, nx=768)
    (c_proj): Conv1D(nf=768, nx=3072)
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [3]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [19]:
d= dict(model.named_modules())
d.keys()

dict_keys(['', 'transformer', 'transformer.wte', 'transformer.wpe', 'transformer.drop', 'transformer.h', 'transformer.h.0', 'transformer.h.0.ln_1', 'transformer.h.0.attn', 'transformer.h.0.attn.c_attn', 'transformer.h.0.attn.c_proj', 'transformer.h.0.attn.attn_dropout', 'transformer.h.0.attn.resid_dropout', 'transformer.h.0.ln_2', 'transformer.h.0.mlp', 'transformer.h.0.mlp.c_fc', 'transformer.h.0.mlp.c_proj', 'transformer.h.0.mlp.act', 'transformer.h.0.mlp.dropout', 'transformer.h.1', 'transformer.h.1.ln_1', 'transformer.h.1.attn', 'transformer.h.1.attn.c_attn', 'transformer.h.1.attn.c_proj', 'transformer.h.1.attn.attn_dropout', 'transformer.h.1.attn.resid_dropout', 'transformer.h.1.ln_2', 'transformer.h.1.mlp', 'transformer.h.1.mlp.c_fc', 'transformer.h.1.mlp.c_proj', 'transformer.h.1.mlp.act', 'transformer.h.1.mlp.dropout', 'transformer.h.2', 'transformer.h.2.ln_1', 'transformer.h.2.attn', 'transformer.h.2.attn.c_attn', 'transformer.h.2.attn.c_proj', 'transformer.h.2.attn.attn_dropo

In [11]:
clear_import_cache()

In [7]:
print(model.hf_device_map)

{'model.embed_tokens': 0, 'lm_head': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 1, 'model.layers.19': 1, 'model.layers.20': 1, 'model.layers.21': 1, 'model.layers.22': 1, 'model.layers.23': 1, 'model.layers.24': 1, 'model.layers.25': 1, 'model.layers.26': 1, 'model.layers.27': 1, 'model.layers.28': 1, 'model.layers.29': 1, 'model.layers.30': 1, 'model.layers.31': 1, 'model.layers.32': 1, 'model.layers.33': 1, 'model.layers.34': 1, 'model.layers.35': 1, 'model.norm': 1, 'model.rotary_emb': 1}


In [17]:
print("transformer.h.11" in [name for name, _ in model.named_modules()])

for name, module in model.named_modules():
    print(name)

True

transformer
transformer.wte
transformer.wpe
transformer.drop
transformer.h
transformer.h.0
transformer.h.0.ln_1
transformer.h.0.attn
transformer.h.0.attn.c_attn
transformer.h.0.attn.c_proj
transformer.h.0.attn.attn_dropout
transformer.h.0.attn.resid_dropout
transformer.h.0.ln_2
transformer.h.0.mlp
transformer.h.0.mlp.c_fc
transformer.h.0.mlp.c_proj
transformer.h.0.mlp.act
transformer.h.0.mlp.dropout
transformer.h.1
transformer.h.1.ln_1
transformer.h.1.attn
transformer.h.1.attn.c_attn
transformer.h.1.attn.c_proj
transformer.h.1.attn.attn_dropout
transformer.h.1.attn.resid_dropout
transformer.h.1.ln_2
transformer.h.1.mlp
transformer.h.1.mlp.c_fc
transformer.h.1.mlp.c_proj
transformer.h.1.mlp.act
transformer.h.1.mlp.dropout
transformer.h.2
transformer.h.2.ln_1
transformer.h.2.attn
transformer.h.2.attn.c_attn
transformer.h.2.attn.c_proj
transformer.h.2.attn.attn_dropout
transformer.h.2.attn.resid_dropout
transformer.h.2.ln_2
transformer.h.2.mlp
transformer.h.2.mlp.c_fc
transformer.h.

In [None]:
try:
    model.get_submodule("transformer.h.121")
except:
    print("finally")
raise ValueError("t con")

finally


ValueError: t con

In [4]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [18]:
print(model.config)

SmolLM3Config {
  "architectures": [
    "SmolLM3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128012,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention

In [98]:
for name, module in model.named_modules():
    print(name, module)
    for p in module.parameters():
        print(p.requires_grad, p.shape)
    print("=======================")

 SmolLM3ForCausalLM(
  (model): SmolLM3Model(
    (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
    (layers): ModuleList(
      (0-35): 36 x SmolLM3DecoderLayer(
        (self_attn): SmolLM3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): SmolLM3MLP(
          (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): SmolLM3RMSNo

In [None]:
print(model.hfmodel.layers[0].self_attn.q_proj)

Linear(in_features=2048, out_features=2048, bias=False)


In [86]:
model.get_submodule('model')

SmolLM3Model(
  (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
  (layers): ModuleList(
    (0-35): 36 x SmolLM3DecoderLayer(
      (self_attn): SmolLM3Attention(
        (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2048, out_features=512, bias=False)
        (v_proj): Linear(in_features=2048, out_features=512, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
      )
      (mlp): SmolLM3MLP(
        (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
        (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
      (post_attention_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
    )
  )
  (norm): SmolLM3RMSNorm((2048,), eps=1e-06)
  (rotary_emb): SmolLM3RotaryEmbedding()
)

In [85]:
layers = model.get_submodule("model.layers")
print(layers)
for p in layers[1].parameters():
    print(p.shape, p.requires_grad, p.device, p.dtype)
    print(p.numel)
    print()

ModuleList(
  (0-35): 36 x SmolLM3DecoderLayer(
    (self_attn): SmolLM3Attention(
      (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
      (k_proj): Linear(in_features=2048, out_features=512, bias=False)
      (v_proj): Linear(in_features=2048, out_features=512, bias=False)
      (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
    )
    (mlp): SmolLM3MLP(
      (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
      (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
      (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
    (post_attention_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
  )
)
torch.Size([2048, 2048]) True cuda:0 torch.float16
<built-in method numel of Parameter object at 0x7850d05e4ea0>

torch.Size([512, 2048]) True cuda:0 torch.float16
<built-in method numel of Parameter object at 0x7850

In [60]:
model.get_submodule("model.layers.5.self_attn.k_proj".rsplit('.', 1)[0])

SmolLM3Attention(
  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (k_proj): Linear(in_features=2048, out_features=512, bias=False)
  (v_proj): Linear(in_features=2048, out_features=512, bias=False)
  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
)

In [42]:
model.get_submodule("model.layers.5.self_attn.k_proj")

Linear(in_features=2048, out_features=512, bias=False)

In [56]:
"model.layers.5.self_attn.k_proj".rsplit('.',1)

['model.layers.5.self_attn', 'k_proj']

In [68]:
model.loss_type

'ForCausalLM'

In [55]:
for p in model.get_input_embeddings().parameters():
    p.requires_grad = True
    print(p.requires_grad)

for p in model.lm_head.parameters():
    print(p.requires_grad)

True
True


In [None]:
for p in model.get_submodule("").parameters():
    p.requires_grad = False
    print(p)

for p in model.hfmodel.layers[1].parameters():
    p.requires_grad = True
    print(p)

for p in model.hfmodel.layers[6].parameters():
    p.requires_grad = True
    print(p)

Parameter containing:
tensor([[ 0.0378, -0.0791,  0.0840,  ...,  0.0308, -0.0571, -0.1396],
        [ 0.0042, -0.1934, -0.0143,  ...,  0.0835, -0.0289,  0.1680],
        [-0.0835,  0.0466, -0.0464,  ..., -0.0184,  0.0815, -0.0654],
        ...,
        [-0.0581,  0.1553, -0.0212,  ...,  0.0825,  0.0240, -0.0381],
        [-0.0559,  0.1582, -0.0253,  ...,  0.0840,  0.0248, -0.0439],
        [-0.0581,  0.1582, -0.0249,  ...,  0.0840,  0.0253, -0.0442]],
       device='cuda:0', dtype=torch.float16)
Parameter containing:
tensor([[-0.0148,  0.0510,  0.0045,  ...,  0.0033,  0.0133,  0.0679],
        [ 0.0087, -0.0135,  0.0139,  ...,  0.0041, -0.0061, -0.0238],
        [ 0.0176, -0.0038,  0.0012,  ...,  0.0192, -0.0092, -0.0054],
        ...,
        [ 0.0103, -0.0256, -0.0073,  ..., -0.0210,  0.0471, -0.0128],
        [-0.0075, -0.0077,  0.0204,  ..., -0.0386,  0.0354, -0.0684],
        [-0.0029,  0.0398,  0.0137,  ..., -0.0432, -0.0081,  0.0708]],
       device='cuda:0', dtype=torch.float16

In [56]:
for sm in model.named_modules():
    for p in sm[1].parameters():
        if p.requires_grad == True:
            print(sm[0])
            break


model
model.embed_tokens
model.layers
model.layers.1
model.layers.1.self_attn
model.layers.1.self_attn.q_proj
model.layers.1.self_attn.k_proj
model.layers.1.self_attn.v_proj
model.layers.1.self_attn.o_proj
model.layers.1.mlp
model.layers.1.mlp.gate_proj
model.layers.1.mlp.up_proj
model.layers.1.mlp.down_proj
model.layers.1.input_layernorm
model.layers.1.post_attention_layernorm
model.layers.6
model.layers.6.self_attn
model.layers.6.self_attn.q_proj
model.layers.6.self_attn.k_proj
model.layers.6.self_attn.v_proj
model.layers.6.self_attn.o_proj
model.layers.6.mlp
model.layers.6.mlp.gate_proj
model.layers.6.mlp.up_proj
model.layers.6.mlp.down_proj
model.layers.6.input_layernorm
model.layers.6.post_attention_layernorm
lm_head


In [57]:
for p in model.get_submodule("model").parameters():
    print(p.requires_grad)

True
False
False
False
False
False
False
False
False
False
True
True
True
True
True
True
True
True
True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
True
True
True
True
True
True
True
True
True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False

In [25]:
model.get_submodule("")

SmolLM3ForCausalLM(
  (model): SmolLM3Model(
    (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
    (layers): ModuleList(
      (0-35): 36 x SmolLM3DecoderLayer(
        (self_attn): SmolLM3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): SmolLM3MLP(
          (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): SmolLM3RMSNor

In [107]:
model.get_submodule("model.layers.5.self_attn.k_proj").weight.shape

torch.Size([512, 2048])

In [110]:
list(model.named_modules())

[('',
  SmolLM3ForCausalLM(
    (model): SmolLM3Model(
      (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
      (layers): ModuleList(
        (0-35): 36 x SmolLM3DecoderLayer(
          (self_attn): SmolLM3Attention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=512, bias=False)
            (v_proj): Linear(in_features=2048, out_features=512, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          )
          (mlp): SmolLM3MLP(
            (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
            (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
            (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): SmolLM3RMSNorm((2048,), eps=1e-06)
          (post_attention_layernorm): SmolLM3RMSNorm((2048,), eps=1e

# Train.py Sandbox

In [2]:
from options.train_options import TrainOptions
from models.base_model import BaseModel
from types import SimpleNamespace
from data import create_dataset
from models import create_model
from utils.visualizer import Visualizer
from models.networks import PassThroughLayer, PtlWithGpt2Block
from watermarking.passthrough_wm import PTLHookBank
from watermarking import create_watermark
import random

In [None]:
fake_opt = {
    "name" : "gpt2_openwebtext_100k_ptl_1_4_7_luni_05_lid_1",
    "model_name_or_path" : "gpt2",
    # "dataset_name" : "wikitext",
    # "dataset_name" : "openwebtext_tokkenized_1024",
    "training_seed": 42,
    # "dataset_name" : "sytelus/openwebtext",
    # "dataset_config_name" : "wikitext-2-raw-v1",
    "text_column" : "text",
    "model" : "causallm",
    "dataset_mode" : "eval_passthrough",
    # "dataset_mode" : "causallm",
    "max_samples" : 1000,
    "batch_size" : 2,
    "shuffle" :False,
    "num_workers" : 1,
    "wm_lambda_trigger" : 0.5,
    "wm_key" : '5631',
    "wm" : "passthrough", 
    "isTrain" : False,
    # "isTrain" : True,
    "gpu_ids" : 0,
    "device_map" : "auto",
    "torch_dtype" : 32,
    "optimizer" : "AdamW",
    "lr" : 5e-5,
    "beta1" : 0.9,
    "beta2" : 0.999,
    "weight_decay" : 1e-2,
    "lr_policy" : "linear",
    "warmup_steps" : 0,
    "ptl_idx" : [1, 4, 7],
    "lambda_id" : 1.,
    "lambda_uni" : .5,
    "freeze_all" : True,
    "use_dynamic_cache" : False, # Not necessary for training
    "num_data_workers" : 5,
    #testing args
    "resume_iter" : "20000",
    
}

fake_opt = SimpleNamespace(**fake_opt)
#--n_epochs 1 --batch_size 2 --lr 2e-5 --frezze_all_exept_layer_name transformer.h.11 --max_samples 1000

In [5]:
dataset = create_dataset(fake_opt)

DatasetNotFoundError: Dataset 'eval_passthrough' doesn't exist on the Hub or cannot be accessed.

In [14]:
clear_import_cache()
dataloader = create_dataset(opt=fake_opt)
dataset = dataloader.dataset
print(dataset.hfdataset)

visualizer = Visualizer
model : BaseModel = create_model(fake_opt)
# model.hfmodel.config.use_cache = False

# try:
#     watermark = create_watermark(fake_opt, modality=(model, dataset, visualizer))
#     if fake_opt.isTrain:
#         watermark.insert()
#     else:
#         watermark.load_modified_model()
# except Exception as e:
#     if e:
#         print(e)
#     else:
#         print("no watermarking method")

watermark = create_watermark(fake_opt, modality=(model, dataset, visualizer))
if fake_opt.isTrain:
    watermark.insert()
else:
    watermark.load_modified_model()

print(dataset.hfdataset)



[96m[INFO][0m	Dataset CausalLMDataset was created


ValueError: Unrecognized configuration class <class 'transformers.models.gpt2.configuration_gpt2.GPT2Config'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of ArceeConfig, AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BitNetConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CohereConfig, Cohere2Config, CpmAntConfig, CTRLConfig, Data2VecTextConfig, DbrxConfig, DeepseekV3Config, DiffLlamaConfig, Dots1Config, ElectraConfig, Emu3Config, ErnieConfig, FalconConfig, FalconH1Config, FalconMambaConfig, FuyuConfig, GemmaConfig, Gemma2Config, Gemma3Config, Gemma3TextConfig, Gemma3nConfig, Gemma3nTextConfig, GitConfig, GlmConfig, Glm4Config, GotOcr2Config, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, GraniteConfig, GraniteMoeConfig, GraniteMoeHybridConfig, GraniteMoeSharedConfig, HeliumConfig, JambaConfig, JetMoeConfig, LlamaConfig, Llama4Config, Llama4TextConfig, MambaConfig, Mamba2Config, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig, MiniMaxConfig, MistralConfig, MixtralConfig, MllamaConfig, MoshiConfig, MptConfig, MusicgenConfig, MusicgenMelodyConfig, MvpConfig, NemotronConfig, OlmoConfig, Olmo2Config, OlmoeConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PersimmonConfig, PhiConfig, Phi3Config, Phi4MultimodalConfig, PhimoeConfig, PLBartConfig, ProphetNetConfig, QDQBertConfig, Qwen2Config, Qwen2MoeConfig, Qwen3Config, Qwen3MoeConfig, RecurrentGemmaConfig, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, SmolLM3Config, Speech2Text2Config, StableLmConfig, Starcoder2Config, TransfoXLConfig, TrOCRConfig, WhisperConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, ZambaConfig, Zamba2Config.

In [8]:
print(f"Loading model: \033[94m{model}\033[0m")
print(f"✅ \033[92mSuccess:\033[0m Model loaded correctly!")
print(f"⚠️  \033[93mWarning:\033[0m Missing lm_head.weight")
print(f"❌ \033[91mError:\033[0m Failed to load checkpoint")


Loading model: [94m<models.causallm_model.CausalLMModel object at 0x75226b81aca0>[0m
✅ [92mSuccess:[0m Model loaded correctly!
❌ [91mError:[0m Failed to load checkpoint


In [13]:
from rich import print as rprint

rprint(model.save_hfmodel.transformer.h[1])
print(model.save_hfmodel.transformer.h[1])


In [6]:
model.hfmodel.config.tie_word_embeddings

True

In [14]:
module = model.hfmodel.transformer.h[2]
module



PtlWithGpt2Block(
  (ptl): PassThroughLayer(
    (linear): Linear(in_features=768, out_features=768, bias=True)
  )
  (block): GPT2Block(
    (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): GPT2Attention(
      (c_attn): Conv1D(nf=2304, nx=768)
      (c_proj): Conv1D(nf=768, nx=768)
      (attn_dropout): Dropout(p=0.1, inplace=False)
      (resid_dropout): Dropout(p=0.1, inplace=False)
    )
    (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): GPT2MLP(
      (c_fc): Conv1D(nf=3072, nx=768)
      (c_proj): Conv1D(nf=768, nx=3072)
      (act): NewGELUActivation()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

In [5]:
for step, batch in enumerate(dataloader):
    model.set_input(batch)
    param = next(model.hfmodel.transformer.h[2].ptl.parameters())
    with torch.no_grad() :
        before = param.detach().clone()
    
    model.optimize_parameters()

    with torch.no_grad():
        after = param.detach()
        print("weights changed? ->", not torch.allclose(before, after))
        print("delta norm:", (after - before).norm().item())
    break

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


i am the new one
i am the new one


`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


weights changed? -> True
delta norm: 0.038400642573833466


In [7]:
model.loss

(tensor(3668.7046, device='cuda:0', grad_fn=<AddBackward0>),
 {'ce': 8.675707817077637,
  'id': 3660.02880859375,
  'uni': 3.748574215478584e-07})

In [5]:
for i, batch in enumerate(dataloader):
    B, L = batch["attention_mask"].shape # [B, L]
    pos = torch.arange(L, device=batch["attention_mask"].device).unsqueeze(0) #[1, L]
    after = pos > batch["wm_pos"].unsqueeze(1)
    print("after", after)
    print("wm_pos", batch["wm_pos"].unsqueeze(1))
    print("mask for trig", after & batch["attention_mask"].bool())
    print("inverse mask for trig", batch["attention_mask"].bool() & ~after)
    
    break

after tensor([[False, False, False,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
wm_pos tensor([[480],
        [ -1]])
mask for trig tensor([[False, False, False,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
inverse mask for trig tensor([[ True,  True,  True,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:
model.hfmodel.config.n_embd

768

In [7]:
# #model modification with 1 layer
# n_embd = 768
# insert_positions = [2]

# for insert_position in insert_positions:
#     original_block = model.hfmodel.transformer.h[insert_position]
#     print(next(original_block.parameters()).device)
#     print(model.hfmodel.transformer.h)
#     ptl = PassThroughLayer(hidden_dim=n_embd).to(next(original_block.parameters()).device)
#     ptl_and_block = PtlWithGpt2Block(ptl=ptl, block=original_block).to(next(original_block.parameters()).device)

#     model.hfmodel.transformer.h[insert_position] = ptl_and_block

#     # model.hfmodel.transformer.h.insert(intert_position, ptl)
#     # model.hfmodel.config.n_layer += 1
#     # print(model.hfmodel.transformer.h)


In [31]:
hook_bank = PTLHookBank()
ptl_registery = []


In [None]:
ptl_registery.clear()
for name, module in model.hfmodel.named_modules():
        if isinstance(module, PassThroughLayer) :
            insert_position = name.split(".")[-2]
            ptl_registery.append({"name" : name,
                                  "block_index" : insert_position,
                                  "module" : module,
            })
            print(name, "added to registery")
            print(insert_position)

transformer.h.2.ptl added to registery
2


In [None]:
"zelfj. lkfnzf, zelkfn zkeeefln".split(".")[-]

'zelfj'

In [None]:
def _create_hook_registery(self):
    hook_registery = []

    for name, module in model.hfmodel.get_submodules():
        if isinstance(module, PassThroughLayer):
            
            print(insert_position)
            hook_registery.append({"name" : name,
                                    "block_index" : insert_position,
                                    "module" : module,
            })

            print(f"[INFO]\t{name} added to registry")
    return hook_registery

In [None]:
hooks = hook_bank.attach(model.hfmodel, ptl_registery)
hooks

[<torch.utils.hooks.RemovableHandle at 0x76483435e190>,
 <torch.utils.hooks.RemovableHandle at 0x7648343652e0>]

In [None]:
model.hfmodel.device

device(type='cuda', index=0)

In [None]:
# probable_leafs = []

# def hook1_fn(m, i, o):
#     probable_leafs.append(o)
#     # print(o[0].is_leaf)
#     # print(o[0].grad)
#     print("input", i)
#     print("input shape", i[0].shape)
#     print("output", o)
#     print("output shape", o.shape)

# def pre_hook2_fn(m, i):
#     print("input shape", i)

# hook1 = model.hfmodel.transformer.h[2].ptl.register_forward_hook(hook1_fn)
# hook2 = model.hfmodel.transformer.h[2].ptl.register_forward_pre_hook(pre_hook2_fn)

for i, batch in enumerate(dataloader):
    hook_bank.clear()
    batch = {k : v.to(model.device) for k, v in batch.items()}
    out = model.hfmodel(input_ids=batch["input_ids"],
                      attention_mask=batch["attention_mask"],
                      labels=batch["labels"])
    
    logits = out.logits
    print("logits", logits.shape)


    attn = batch["attention_mask"]
    wm_pos_end = batch["wm_pos"]
    after_mask = watermark._build_after_key_mask(wm_pos_end, attn) 
    before_mask = attn.bool() & ~after_mask

    valid = before_mask.float().unsqueeze(-1)
    denom = valid.sum().clamp_min(1.0)

    triggered = wm_pos_end != -1

    print("valid", valid)
    print("denom", denom)

    
    for rec in hook_bank.cache:
        print("zin", rec["zin"])
        print("zout", rec["zout"])
    
    # if i == 2:
    break

for hook in hooks:
    hook.remove()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


logits torch.Size([2, 1024, 50257])
mask_uni tensor([False, False, False,  ..., False, False, False])
probs tensor([[1.6699e-03, 6.6710e-05, 5.3088e-05,  ..., 2.7768e-07, 1.7394e-06,
         1.3168e-03],
        [1.1305e-04, 3.1225e-05, 2.3129e-05,  ..., 7.8133e-07, 3.5862e-06,
         4.3028e-04],
        [5.0944e-04, 7.6159e-05, 4.2665e-05,  ..., 3.2215e-07, 2.1533e-06,
         5.6827e-04],
        ...,
        [3.1304e-05, 2.3993e-05, 8.7731e-07,  ..., 2.0325e-07, 6.0062e-07,
         5.7804e-05],
        [1.2493e-04, 2.6088e-05, 1.3841e-06,  ..., 2.7102e-07, 8.5244e-08,
         1.4569e-04],
        [2.2726e-05, 2.1944e-05, 5.3609e-07,  ..., 4.0903e-07, 1.0219e-07,
         5.1028e-05]], grad_fn=<IndexBackward0>)
valid tensor([[[1.],
         [1.],
         [1.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]]])
denom tensor(481.)


In [None]:
import inspect
print(inspect.signature(model.hfmodel.transformer.h[0].forward))

(self, hidden_states: Optional[tuple[torch.FloatTensor]], past_key_value: Optional[transformers.cache_utils.Cache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, **kwargs) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, tuple[torch.FloatTensor, ...]], NoneType]


In [None]:
import torch.nn as nn

class DummyLayer(nn.Module):
    def __init__(self,):
        super().__init__()
        self.layer1 = nn.Linear(512, 512)
    
    def forward(self, x):
        return self.layer1(x)

dummy_layer = DummyLayer()
model.hfmodel.transformer.h.insert(2, dummy_layer)