In [1]:
from transformers import AutoModelForCausalLM
from src.sae.model import SAE

model = AutoModelForCausalLM.from_pretrained("ExplosionNuclear/Llama-2.3-3B-Instruct-special")

p = next(model.parameters()); d = model.config.hidden_size
sae_attn = SAE(d, 4096).to(p.device, dtype=p.dtype)
sae_mlp  = SAE(d, 4096).to(p.device, dtype=p.dtype)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [42]:
model
saes = {"1_attn": SAE(d, 4096), "1_mlp": SAE(d, 4096)}

outputs = {"1_attn": [], "1_mlp": []}

def make_hook(sae, name, inject=False):
    def hook(mod, inp, out):
        sae_reconstructed = sae(out[0])
        outputs[name].append(sae_reconstructed)
        return None if not inject else sae_reconstructed
    return hook



In [43]:
model.model.layers[1].self_attn._forward_hooks.clear()
model.model.layers[1].mlp._forward_hooks.clear()

In [44]:
model.model.layers[1].self_attn.register_forward_hook(make_hook(saes["1_attn"], "1_attn"))
model.model.layers[1].mlp.register_forward_hook(make_hook(saes["1_mlp"], "1_mlp"))



<torch.utils.hooks.RemovableHandle at 0x7f4b8024a350>

In [45]:
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("ExplosionNuclear/Llama-2.3-3B-Instruct-special")

inputs = tokenizer("Hello, world!", return_tensors="pt")

model.eval()

out = model(**inputs)

In [47]:
outputs["1_attn"]

[tensor([[[ 0.0023, -0.0024, -0.0013,  ...,  0.0019, -0.0012, -0.0015],
          [ 0.0036, -0.0016, -0.0003,  ...,  0.0018, -0.0040, -0.0063],
          [-0.0053, -0.0032,  0.0012,  ..., -0.0007,  0.0003, -0.0040],
          [ 0.0015, -0.0059, -0.0095,  ...,  0.0061, -0.0023, -0.0040],
          [ 0.0063,  0.0005, -0.0009,  ...,  0.0039,  0.0006, -0.0039]]],
        grad_fn=<UnsafeViewBackward0>)]

In [50]:
import torch
from torch import nn
from typing import Dict

from src.sae.model import SAE
from src.trainer.trainer import VectorSFTTrainer


saes = {layer: {"attn": SAE(d, 4096), "mlp": SAE(d, 4096)} for layer in range(10)}


def sae_loss_fn(x_hat, x_target, latent, l1_coeff=1e-3):
	return nn.functional.mse_loss(x_hat, x_target) + l1_coeff * latent.abs().mean()


def make_sae_loss_hook_with_loss(sae, bucket, loss_fn, l1_coeff=1e-3):
	def hook(model, inputs, outputs):
		x = outputs[0] if isinstance(outputs, tuple) else outputs
		x = x.detach()
		d_model = x.shape[-1]  # [batch, seq_len, d_model] -> [batch * seq_len, d_model]
		flat_output = x.reshape(-1, d_model)
		with torch.enable_grad():
			sae.train()
			x_hat, z = sae(flat_output, return_latent=True)
			loss = loss_fn(x_hat, flat_output, z, l1_coeff)
		bucket.append(loss)
	return hook

def put_saes(model, saes, sae_losses, loss_fn):
    layers = getattr(model, "model", model).layers
    for layer in saes:
        layers[layer].self_attn.register_forward_hook(
            make_sae_loss_hook_with_loss(saes[layer]["attn"], sae_losses, loss_fn)
        )
        layers[layer].mlp.register_forward_hook(
            make_sae_loss_hook_with_loss(saes[layer]["mlp"], sae_losses, loss_fn)
        )



class SAETrainer(VectorSFTTrainer):  # ваш класс с тем же именем
	def __init__(self, *args, saes: dict, lambda_sae=1.0, **kwargs):
		super().__init__(*args, **kwargs)
		self.saes: Dict[str, Dict[str, SAE]] = saes  
		self.lambda_sae = lambda_sae
		self.sae_losses = []

		put_saes(self.model, self.saes, self.sae_losses, self.sae_loss_fn)
		self.get_sae_params()

	def get_sae_params(self):
		sae_params = []
		for sae_layer in self.saes.values():
			for sae_module in sae_layer.values():
				sae_params.extend(sae_module.parameters())
		self.sae_device = sae_params[0].device
		self.sae_params = sae_params
 
	def create_optimizer(self):
		"Making optimizer only for SAE parameters"
		sae_params = self.get_sae_params()
		self.optimizer = torch.optim.AdamW(
			sae_params, lr=self.args.learning_rate, weight_decay=self.args.weight_decay
		)
		return self.optimizer

	def sae_loss_fn(self, x_hat, x_target, latent):
		return sae_loss_fn(x_hat, x_target, latent, l1_coeff=self.args.l1_coeff)	


	def compute_loss(self, model, inputs: dict, num_items_in_batch=None, return_outputs=False):
		self.sae_losses.clear()
		with torch.no_grad():
			outputs = model(**inputs)
   
		if not self.sae_losses:
			raise RuntimeError("SAE loss is not collected. Check the hook registration/filtering/sample_tokens.")

		sae_loss = torch.stack(self.sae_losses).sum()  
		current_loss = self.lambda_sae * sae_loss
		return (current_loss, outputs) if return_outputs else current_loss

	def _save(self, output_dir: str, state_dict=None):
		self.model.save_pretrained(output_dir)
    
	def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
		self.create_optimizer_and_scheduler(num_training_steps=self.args.max_steps)
		self.resume_trainer_only(resume_from_checkpoint)

In [1]:

from trl import SFTConfig

from src.sae.experiment import SAEExperiment

experiment = SAEExperiment("configs/vector_sft/llama3_2_3b_sae.yaml")
experiment.build_saes()
experiment.prepare_datasets()
training_args = SFTConfig(**experiment.cfg.trainer)







Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:


from src.callbacks import ClearMLCallback, SaveCustomWeightsOnHubCallback, GenerationCallback
from src.sae.experiment import SAEExperiment
from src.sae.trainer import SAETrainer

import dotenv
dotenv.load_dotenv(".env")


def collate(inputs):
    input_dict = {}
    for key, value in inputs.items():
        if key in ["input_ids", "attention_mask"]:
            input_dict[key] = value
    return input_dict


eval_datasets = [experiment.eval_dataset]
if len(experiment.eval_calib_dataset) > 0:
    eval_datasets.append(experiment.eval_calib_dataset)
    


trainer = SAETrainer(
    model=experiment.model,
    processing_class=experiment.tokenizer,
    args=training_args,
    train_dataset=experiment.mix_data_loader,
    eval_dataset=eval_datasets,
    data_collator=collate,
    callbacks=[
        # ClearMLCallback(experiment.task),
        # SaveCustomWeightsOnHubCallback(),
        # GenerationCallback(
        #     prompts=experiment.generation_prompts,
        #     tokenizer=experiment.tokenizer,
        #     generation_params=experiment.generation_params
        # )
    ],
    dataset_processor=experiment.dataset_processor,
    saes=experiment.saes,
    sae_cfg=dict(experiment.cfg.sae),
)
trainer.train()

Step,Training Loss
10,3.2682
20,2.8112
30,3.0976
40,2.7675
50,3.1923
60,3.5256
70,3.5049
80,2.9068
90,4.0437
100,3.6848


KeyboardInterrupt: 