From d555c98f99b319a6cebbccb889dc588846c33dc5 Mon Sep 17 00:00:00 2001
From: Frankstein <20307140057@fudan.edu.cn>
Date: Sun, 14 Jul 2024 20:44:08 +0800
Subject: [PATCH] fix(sae): fix merge bugs

---
 src/lm_saes/sae.py          | 228 +++++++++++++++++++++++++-----------
 src/lm_saes/sae_training.py |  79 ++++++-------
 2 files changed, 195 insertions(+), 112 deletions(-)

diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 3232723..d04ad44 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -2,6 +2,7 @@ import os
 from typing import Dict, Literal, Union, overload, List
 
 import torch
+from torch.distributed._tensor.placement_types import Placement
 from torch.distributed.device_mesh import init_device_mesh
 import torch.nn as nn
 import math
@@ -15,7 +16,15 @@ from lm_saes.activation.activation_store import ActivationStore
 from lm_saes.utils.huggingface import parse_pretrained_name_or_path
 
 import torch.distributed._functional_collectives as funcol
-
+from torch.distributed._tensor import DTensor
+import torch.distributed as dist
+from torch.distributed._tensor import (
+    DTensor,
+    Shard,
+    Replicate,
+    distribute_module,
+    distribute_tensor,
+)
 
 class SparseAutoEncoder(HookedRootModule):
     """Sparse AutoEncoder model.
@@ -69,7 +78,7 @@ def __init__(self, cfg: SAEConfig):
             dtype=cfg.dtype,
         )
         torch.nn.init.kaiming_uniform_(self.decoder.weight)
-        self.set_decoder_norm_to_unit_norm()
+        self.set_decoder_norm_to_fixed_norm()
 
         self.train_base_parameters()
 
@@ -79,27 +88,27 @@ def __init__(self, cfg: SAEConfig):
 
         self.initialize_parameters()
 
-
     def initialize_parameters(self):
-        torch.nn.init.kaiming_uniform_(self.encoder)
+        torch.nn.init.kaiming_uniform_(self.encoder.weight)
 
         if self.cfg.use_glu_encoder:
-            torch.nn.init.kaiming_uniform_(self.encoder_glu)
-            torch.nn.init.zeros_(self.encoder_bias_glu)
+            torch.nn.init.kaiming_uniform_(self.encoder_glu.weight)
+            torch.nn.init.zeros_(self.encoder_glu.bias)
 
-        torch.nn.init.kaiming_uniform_(self.decoder)
-        self.set_decoder_norm_to_fixed_norm(self.cfg.init_decoder_norm, force_exact=True)
+        torch.nn.init.kaiming_uniform_(self.decoder.weight)
+        self.set_decoder_norm_to_fixed_norm(
+            self.cfg.init_decoder_norm, force_exact=True
+        )
 
         if self.cfg.use_decoder_bias:
-            torch.nn.init.zeros_(self.decoder_bias)
-        torch.nn.init.zeros_(self.encoder_bias)
+            torch.nn.init.zeros_(self.decoder.bias)
+        torch.nn.init.zeros_(self.encoder.bias)
 
         if self.cfg.init_encoder_with_decoder_transpose:
-            self.encoder.data = self.decoder.data.T.clone().contiguous()
+            self.encoder.weight.data = self.decoder.weight.data.T.clone().contiguous()
         else:
             self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm)
 
-
     def train_base_parameters(self):
         """Set the base parameters to be trained."""
 
@@ -131,18 +140,26 @@ def train_finetune_for_suppression_parameters(self):
         for p in finetune_for_suppression_parameters:
             p.requires_grad_(True)
 
-    def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> float | torch.Tensor:
-        """Compute the normalization factor for the activation vectors. 
-        """
+    def compute_norm_factor(
+        self, x: torch.Tensor, hook_point: str
+    ) -> float | torch.Tensor:
+        """Compute the normalization factor for the activation vectors."""
 
         # Normalize the activation vectors to have L2 norm equal to sqrt(d_model)
         if self.cfg.norm_activation == "token-wise":
             return math.sqrt(self.cfg.d_model) / torch.norm(x, 2, dim=-1, keepdim=True)
         elif self.cfg.norm_activation == "batch-wise":
-            return math.sqrt(self.cfg.d_model) / torch.norm(x, 2, dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
+            return math.sqrt(self.cfg.d_model) / torch.norm(
+                x, 2, dim=-1, keepdim=True
+            ).mean(dim=-2, keepdim=True)
         elif self.cfg.norm_activation == "dataset-wise":
-            assert self.cfg.dataset_average_activation_norm is not None, "dataset_average_activation_norm must be provided for dataset-wise normalization"
-            return math.sqrt(self.cfg.d_model) / self.cfg.dataset_average_activation_norm[hook_point]
+            assert (
+                self.cfg.dataset_average_activation_norm is not None
+            ), "dataset_average_activation_norm must be provided for dataset-wise normalization"
+            return (
+                math.sqrt(self.cfg.d_model)
+                / self.cfg.dataset_average_activation_norm[hook_point]
+            )
         else:
             return torch.tensor(1.0, dtype=self.cfg.dtype, device=self.cfg.device)
 
@@ -237,7 +254,7 @@ def encode(
         if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
             x = x - self.decoder.bias
 
-        x = x * self.compute_norm_factor(x, hook_point='in')
+        x = x * self.compute_norm_factor(x, hook_point="in")
 
         hidden_pre = self.encoder(x)
 
@@ -246,7 +263,7 @@ def encode(
 
             hidden_pre = hidden_pre * hidden_pre_glu
 
-        hidden_pre = hidden_pre / self.compute_norm_factor(label, hook_point='in')
+        hidden_pre = hidden_pre / self.compute_norm_factor(label, hook_point="in")
         hidden_pre = self.hook_hidden_pre(hidden_pre)
 
         feature_acts = (
@@ -321,7 +338,7 @@ def compute_loss(
         if label is None:
             label = x
 
-        label_norm_factor = self.compute_norm_factor(label, hook_point='out')
+        label_norm_factor = self.compute_norm_factor(label, hook_point="out")
 
         feature_acts, hidden_pre = self.encode(x, label, return_hidden_pre=True)
         feature_acts_normed = feature_acts * label_norm_factor  # (batch, d_sae)
@@ -339,11 +356,22 @@ def compute_loss(
 
         # l_l1: (batch,)
         if self.cfg.sparsity_include_decoder_norm:
-            l_l1 = torch.norm(feature_acts_normed * torch.norm(self.decoder, p=2, dim=1), p=self.cfg.lp, dim=-1)
+            # if self.cfg.tp_size > 1:
+            #     decoder_norm = torch.norm(self.decoder.weight.to_local(), p=2, dim=0)
+            #     decoder_norm = DTensor.from_local(decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)])
+            #     decoder_norm = (
+            #         decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
+            #     )
+            # else:
+            #     decoder_norm = torch.norm(self.decoder.weight, p=2, dim=0)
+            l_l1 = torch.norm(
+                feature_acts_normed * self.decoder_norm(),
+                p=self.cfg.lp,
+                dim=-1,
+            )
         else:
             l_l1 = torch.norm(feature_acts_normed, p=self.cfg.lp, dim=-1)
-
 
         l_ghost_resid = torch.tensor(0.0, dtype=self.cfg.dtype, device=self.cfg.device)
 
         if (
@@ -353,7 +381,7 @@ def compute_loss(
             and dead_feature_mask.sum() > 0
         ):
             # ghost protocol
-
+            assert self.cfg.tp_size == 1, "Ghost protocol not supported in tensor parallel training"
             # 1.
             residual = label_normed - reconstructed_normed
             residual_centred = residual - residual.mean(dim=0, keepdim=True)
@@ -362,7 +390,8 @@ def compute_loss(
             # 2.
             feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_feature_mask])
             ghost_out = (
-                feature_acts_dead_neurons_only @ self.decoder.weight[dead_feature_mask, :]
+                feature_acts_dead_neurons_only
+                @ self.decoder.weight[dead_feature_mask, :]
             )
             l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
             norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
@@ -376,7 +405,11 @@ def compute_loss(
             mse_rescaling_factor = (l_rec / (l_ghost_resid + 1e-6)).detach()
             l_ghost_resid = mse_rescaling_factor * l_ghost_resid
 
-        loss = l_rec.mean() + self.current_l1_coefficient * l_l1.mean() + l_ghost_resid.mean()
+        loss = (
+            l_rec.mean()
+            + self.current_l1_coefficient * l_l1.mean()
+            + l_ghost_resid.mean()
+        )
 
         if return_aux_data:
             aux_data = {
@@ -422,31 +455,39 @@ def forward(
     def update_l1_coefficient(self, training_step):
         if self.cfg.l1_coefficient_warmup_steps <= 0:
             return
-        self.current_l1_coefficient = min(1., training_step / self.cfg.l1_coefficient_warmup_steps) * self.cfg.l1_coefficient
-
+        self.current_l1_coefficient = (
+            min(1.0, training_step / self.cfg.l1_coefficient_warmup_steps)
+            * self.cfg.l1_coefficient
+        )
+
     @torch.no_grad()
-    def set_decoder_norm_to_fixed_norm(self, value: float | None = 1.0, force_exact: bool | None = None):
+    def set_decoder_norm_to_fixed_norm(
+        self, value: float | None = 1.0, force_exact: bool | None = None
+    ):
         if value is None:
             return
-        decoder_norm = torch.norm(self.decoder.weight, dim=1, keepdim=True)
+        decoder_norm = self.decoder_norm(keepdim=True)
         if force_exact is None:
             force_exact = self.cfg.decoder_exactly_fixed_norm
         if force_exact:
-            self.decoder.data = self.decoder.weight.data * value / decoder_norm
+            self.decoder.weight.data = self.decoder.weight.data * value / decoder_norm
         else:
             # Set the norm of the decoder to not exceed value
-            self.decoder.weight.data = self.decoder.weight.data * value / torch.clamp(decoder_norm, min=value)
+            self.decoder.weight.data = (
+                self.decoder.weight.data * value / torch.clamp(decoder_norm, min=value)
+            )
 
     @torch.no_grad()
     def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0):
         if self.cfg.use_glu_encoder:
             raise NotImplementedError("GLU encoder not supported")
         if value is None:
-            print(f'Encoder norm is not set to a fixed value, using random initialization.')
+            print(
+                f"Encoder norm is not set to a fixed value, using random initialization."
+            )
             return
-        encoder_norm = torch.norm(self.encoder.weight, dim=0, keepdim=True)  # [1, d_sae]
-        self.encoder.data = self.encoder.weight.data * value / encoder_norm
-
+        encoder_norm = self.encoder_norm(keepdim=True)
+        self.encoder.weight.data = self.encoder.weight.data * value / encoder_norm
 
     @torch.no_grad()
     def transform_to_unit_decoder_norm(self):
@@ -455,16 +496,17 @@ def transform_to_unit_decoder_norm(self):
         We make an equivalent transformation to the decoder to make it unit norm.
         See https://transformer-circuits.pub/2024/april-update/index.html#training-saes
         """
-        assert self.cfg.sparsity_include_decoder_norm, "Decoder norm is not included in the sparsity loss"
+        assert (
+            self.cfg.sparsity_include_decoder_norm
+        ), "Decoder norm is not included in the sparsity loss"
 
         if self.cfg.use_glu_encoder:
             raise NotImplementedError("GLU encoder not supported")
 
-        decoder_norm = torch.norm(self.decoder.weight, p=2, dim=1)  # (d_sae,)
+        decoder_norm = self.decoder_norm()  # (d_sae,)
         self.encoder.data = self.encoder.weight.data * decoder_norm
         self.decoder.data = self.decoder.weight.data / decoder_norm[:, None]
         self.encoder.bias.data = self.encoder.bias.data * decoder_norm
-
 
     @torch.no_grad()
     def remove_gradient_parallel_to_decoder_directions(self):
@@ -570,50 +612,80 @@ def from_initialization_searching(
         activation_store: ActivationStore,
         cfg: LanguageModelSAETrainingConfig,
     ):
-        test_batch = activation_store.next(batch_size=cfg.train_batch_size * 8) # just random hard code xd
+        test_batch = activation_store.next(
+            batch_size=cfg.train_batch_size * 8
+        )  # just random hard code xd
         activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out]  # type: ignore
 
-        if cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None:
-            print(f'SAE: Computing average activation norm on the first {cfg.train_batch_size * 8} samples.')
+        if (
+            cfg.sae.norm_activation == "dataset-wise"
+            and cfg.sae.dataset_average_activation_norm is None
+        ):
+            print(
+                f"SAE: Computing average activation norm on the first {cfg.train_batch_size * 8} samples."
+            )
 
-            average_in_norm, average_out_norm = activation_in.norm(p=2, dim=1).mean().item(), activation_out.norm(p=2,
-                                                                                                                  dim=1).mean().item()
+            average_in_norm, average_out_norm = (
+                activation_in.norm(p=2, dim=1).mean().item(),
+                activation_out.norm(p=2, dim=1).mean().item(),
+            )
 
             print(
-                f'Average input activation norm: {average_in_norm}\nAverage output activation norm: {average_out_norm}')
-            cfg.sae.dataset_average_activation_norm = {'in': average_in_norm, 'out': average_out_norm}
+                f"Average input activation norm: {average_in_norm}\nAverage output activation norm: {average_out_norm}"
+            )
+            cfg.sae.dataset_average_activation_norm = {
+                "in": average_in_norm,
+                "out": average_out_norm,
+            }
 
         if cfg.sae.init_decoder_norm is None:
-            assert cfg.sae.sparsity_include_decoder_norm, 'Decoder norm must be included in sparsity loss'
-            if not cfg.sae.init_encoder_with_decoder_transpose or cfg.sae.hook_point_in != cfg.sae.hook_point_out:
-                raise NotImplementedError('Transcoders cannot be initialized automatically.')
-            print('SAE: Starting grid search for initial decoder norm.')
+            assert (
+                cfg.sae.sparsity_include_decoder_norm
+            ), "Decoder norm must be included in sparsity loss"
+            if (
+                not cfg.sae.init_encoder_with_decoder_transpose
+                or cfg.sae.hook_point_in != cfg.sae.hook_point_out
+            ):
+                raise NotImplementedError(
+                    "Transcoders cannot be initialized automatically."
+                )
+            print("SAE: Starting grid search for initial decoder norm.")
 
             test_sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
 
+            assert self.cfg.tp_size == 1, "Search for initial decoder norm not supported in tensor parallel training"
+
             def grid_search_best_init_norm(search_range: List[float]) -> float:
                 losses: Dict[float, float] = {}
+
                 for norm in search_range:
                     test_sae.set_decoder_norm_to_fixed_norm(norm, force_exact=True)
-                    test_sae.encoder.data = test_sae.decoder.data.T.clone().contiguous()
-                    mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item()  # type: ignore
+                    test_sae.encoder.weight.data = test_sae.decoder.weight.data.T.clone().contiguous()
+                    mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]["l_rec"].mean().item()  # type: ignore
                     losses[norm] = mse
                 best_norm = min(losses, key=losses.get)  # type: ignore
                 return best_norm
 
-            best_norm_coarse = grid_search_best_init_norm(torch.linspace(0.1, 1, 10).numpy().tolist())
-            best_norm_fine_grained = grid_search_best_init_norm(torch.linspace(best_norm_coarse - 0.09, best_norm_coarse + 0.1, 20).numpy().tolist())
-            print(f'The best (i.e. lowest MSE) initialized norm is {best_norm_fine_grained}')
+            best_norm_coarse = grid_search_best_init_norm(
+                torch.linspace(0.1, 1, 10).numpy().tolist()
+            )
+            best_norm_fine_grained = grid_search_best_init_norm(
+                torch.linspace(best_norm_coarse - 0.09, best_norm_coarse + 0.1, 20)
+                .numpy()
+                .tolist()
+            )
+            print(
+                f"The best (i.e. lowest MSE) initialized norm is {best_norm_fine_grained}"
+            )
 
-            test_sae.set_decoder_norm_to_fixed_norm(best_norm_fine_grained, force_exact=True)
-            test_sae.encoder.data = test_sae.decoder.data.T.clone().contiguous()
+            test_sae.set_decoder_norm_to_fixed_norm(
+                best_norm_fine_grained, force_exact=True
+            )
+            test_sae.encoder.weight.data = test_sae.decoder.weight.data.T.clone().contiguous()
 
         return test_sae
-
-    def save_pretrained(
-        self,
-        ckpt_path: str
-    ) -> None:
+
+    def save_pretrained(self, ckpt_path: str) -> None:
         """Save the model to the checkpoint path.
 
         Args:
@@ -631,13 +703,29 @@ def save_pretrained(
                 {"sae": self.state_dict(), "version": version("lm-saes")}, ckpt_path
             )
         else:
-            raise ValueError(f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats.")
-
-    @property
-    def decoder_norm(self):
-        return torch.norm(self.decoder.weight, p=2, dim=1).mean()
+            raise ValueError(
+                f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats."
+            )
 
-    @property
-    def encoder_norm(self):
-        return torch.norm(self.encoder.weight, p=2, dim=0).mean()
+    def decoder_norm(self, keepdim: bool = False):
+        # We suspect that using torch.norm on dtensor may lead to some bugs during the backward process that are difficult to pinpoint and resolve. Therefore, we first convert the decoder weight from dtensor to tensor for norm calculation, and then redistribute it to different nodes.
+        if self.cfg.tp_size == 1:
+            return torch.norm(self.decoder.weight, p=2, dim=0, keepdim=keepdim)
+        else:
+            decoder_norm = torch.norm(self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim)
+            decoder_norm = DTensor.from_local(decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(int(keepdim))])
+            decoder_norm = (
+                decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
+            )
+            return decoder_norm
+    def encoder_norm(self, keepdim: bool = False):
+        if self.cfg.tp_size == 1:
+            return torch.norm(self.encoder.weight, p=2, dim=1, keepdim=keepdim)
+        else:
+            encoder_norm = torch.norm(self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim)
+            encoder_norm = DTensor.from_local(encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)])
+            encoder_norm = (
+                encoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
+            )
+            return encoder_norm
 
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index 5752eb3..7982be7 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -146,12 +146,7 @@ def train_sae(
         if cfg.finetuning:
             loss = loss_data["l_rec"].mean()
 
-        if cfg.sae.tp_size > 1:
-            with loss_parallel():
-                loss.backward()
-        else:
-            loss.backward()
-
+        loss.backward()
         if cfg.clip_grad_norm > 0:
             torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm)
         if cfg.remove_gradient_parallel_to_decoder_directions:
@@ -271,39 +266,38 @@ def train_sae(
 
         current_learning_rate = optimizer.param_groups[0]["lr"]
 
-        if cfg.wandb.log_to_wandb and is_master():
-            wandb.log(
-                {
-                    # losses
-                    "losses/mse_loss": l_rec.item(),
-                    "losses/l1_loss": l_l1.item(),
-                    "losses/ghost_grad_loss": l_ghost_resid.item(),
-                    "losses/overall_loss": loss.item(),
-                    # variance explained
-                    "metrics/explained_variance": explained_variance.mean().item(),
-                    "metrics/explained_variance_std": explained_variance.std().item(),
-                    "metrics/l0": l0.item(),
-                    # "metrics/mean_thomson_potential": mean_thomson_potential.item(),
-                    "metrics/l2_norm_error": l2_norm_error.item(),
-                    "metrics/l2_norm_error_ratio": l2_norm_error_ratio.item(),
-                    # norm
-                    "metrics/decoder_norm": sae.decoder_norm.item(),
-                    "metrics/encoder_norm": sae.encoder_norm.item(),
-                    "metrics/decoder_bias_mean": sae.decoder_bias.mean().item() if sae.cfg.use_decoder_bias else 0,
-                    "metrics/enocder_bias_mean": sae.encoder_bias.mean().item(),
-                    # sparsity
-                    "sparsity/l1_coefficient": sae.current_l1_coefficient,
-                    "sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(),
-                    "sparsity/dead_features": ghost_grad_neuron_mask.sum().item(),
-                    # "sparsity/useful_features": sae.decoder.weight.norm(p=2, dim=1)
-                    # .gt(0.99)
-                    # .sum()
-                    # .item(),
-                    "details/current_learning_rate": current_learning_rate,
-                    "details/n_training_tokens": n_training_tokens,
-                },
-                step=n_training_steps + 1,
-            )
+        if cfg.wandb.log_to_wandb:
+            decoder_norm = sae.decoder_norm().mean()
+            encoder_norm = sae.encoder_norm().mean()
+            if is_master():
+                wandb.log(
+                    {
+                        # losses
+                        "losses/mse_loss": l_rec.item(),
+                        "losses/l1_loss": l_l1.item(),
+                        "losses/ghost_grad_loss": l_ghost_resid.item(),
+                        "losses/overall_loss": loss.item(),
+                        # variance explained
+                        "metrics/explained_variance": explained_variance.mean().item(),
+                        "metrics/explained_variance_std": explained_variance.std().item(),
+                        "metrics/l0": l0.item(),
+                        # "metrics/mean_thomson_potential": mean_thomson_potential.item(),
+                        "metrics/l2_norm_error": l2_norm_error.item(),
+                        "metrics/l2_norm_error_ratio": l2_norm_error_ratio.item(),
+                        # norm
+                        "metrics/decoder_norm": decoder_norm.item(),
+                        "metrics/encoder_norm": encoder_norm.item(),
+                        "metrics/decoder_bias_mean": sae.decoder.bias.mean().item() if sae.cfg.use_decoder_bias else 0,
+                        "metrics/enocder_bias_mean": sae.encoder.bias.mean().item(),
+                        # sparsity
+                        "sparsity/l1_coefficient": sae.current_l1_coefficient,
+                        "sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(),
+                        "sparsity/dead_features": ghost_grad_neuron_mask.sum().item(),
+                        "details/current_learning_rate": current_learning_rate,
+                        "details/n_training_tokens": n_training_tokens,
+                    },
+                    step=n_training_steps + 1,
+                )
 
         # record loss frequently, but not all the time.
         if (n_training_steps + 1) % (cfg.eval_frequency) == 0:
@@ -407,11 +401,12 @@ def prune_sae(
         dist.reduce(act_times, dst=0, op=dist.ReduceOp.SUM)
         dist.reduce(max_acts, dst=0, op=dist.ReduceOp.MAX)
 
+    decoder_norm = sae.decoder_norm()
     if is_master():
         sae.feature_act_mask.data = (
             (act_times > cfg.dead_feature_threshold * cfg.total_training_tokens)
             & (max_acts > cfg.dead_feature_max_act_threshold)
-            & (sae.decoder.norm(p=2, dim=1) >= cfg.decoder_norm_threshold)
+            & (decoder_norm >= cfg.decoder_norm_threshold)
         ).to(cfg.sae.dtype)
         sae.feature_act_mask.requires_grad_(False)
 
@@ -430,7 +425,7 @@ def prune_sae(
                 .sum()
                 .item(),
                 "sparsity/decoder_norm_below_threshold": (
-                    sae.decoder.norm(p=2, dim=1) < cfg.decoder_norm_threshold
+                    decoder_norm < cfg.decoder_norm_threshold
                 )
                 .sum()
                 .item(),
@@ -452,7 +447,7 @@ def prune_sae(
         )
         print(
             "Decoder norm below threshold:",
-            (sae.decoder.norm(p=2, dim=1) < cfg.decoder_norm_threshold).sum().item(),
+            (decoder_norm < cfg.decoder_norm_threshold).sum().item(),
        )
         print("Total pruned features:", (sae.feature_act_mask == 0).sum().item())
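Note on the decoder_norm()/encoder_norm() change (illustration only, not part of the patch): the new helpers avoid calling torch.norm directly on a DTensor by taking the norm of the local shard and then replicating the result across the tensor-parallel group. The sketch below shows that shard-local-norm-then-replicate pattern in isolation. It assumes a 1-D device mesh standing in for the patch's self.device_mesh["tp"], a weight sharded along its feature (column) dimension, and a world size that divides the number of columns; the function and variable names here are invented for the example and do not come from the repository.

# Minimal sketch under the assumptions above. Run with e.g.:
#   torchrun --nproc-per-node=2 norm_sketch.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh


def sharded_column_norm(weight: DTensor, mesh) -> torch.Tensor:
    """L2 norm over dim 0 of a dim-1-sharded weight, returned replicated on every rank."""
    # Each rank holds complete columns of its slice, so per-column norms of the
    # local shard are already exact; no partial-sum reduction is required.
    local_norm = torch.norm(weight.to_local(), p=2, dim=0)
    # Re-wrap the per-rank result as a DTensor sharded along its only dimension ...
    norm_dt = DTensor.from_local(local_norm, device_mesh=mesh, placements=[Shard(0)])
    # ... then gather the full vector onto every rank (the patch does the same,
    # with async_op=True, before calling .to_local()).
    return norm_dt.redistribute(placements=[Replicate()]).to_local()


if __name__ == "__main__":
    dist.init_process_group("gloo")  # CPU backend keeps the sketch runnable anywhere
    mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))
    torch.manual_seed(0)  # same full tensor on every rank before sharding
    # Stand-in for the decoder weight: d_model=16 rows, d_sae=8 feature columns.
    w = distribute_tensor(torch.randn(16, 8), device_mesh=mesh, placements=[Shard(1)])
    print(sharded_column_norm(w, mesh))  # identical 8-element tensor on every rank
    dist.destroy_process_group()

Because each rank owns whole columns, the only communication is the all-gather implied by redistribute(Replicate()); whether the extra async_op=True in the patch buys anything when .to_local() is called immediately afterwards is a reasonable question for review.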