# In [None]:
import os
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import Any, Optional, Union

import numpy as np
import numpy.typing as npt
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torchtyping import TensorType
from tqdm import tqdm

import wandb
import json
from datetime import datetime

@dataclass
class ToyArgs:
    """Configuration for the toy sparse-autoencoder sweep.

    ``main`` trains one autoencoder for every combination of a value from
    ``ps()`` and a value from ``lp_alphas``.
    """

    device: Union[str, Any] = "cuda:0" if torch.cuda.is_available() else "cpu"
    tied_ae: bool = False
    seed: int = 103
    learned_dict_ratio: float = 1.0
    output_folder: str = "outputs"
    activation_dim: int = 32
    # Synthetic data: component i is active with probability
    # feature_prob_decay**i * feature_num_nonzero / n_ground_truth_components.
    feature_prob_decay: float = 0.99
    feature_num_nonzero: int = 8
    correlated_components: bool = False
    n_ground_truth_components: int = 128
    batch_size: int = 4_096
    lr: float = 4e-4
    epochs: int = 300_000
    n_components_dictionary: int = 256
    n_components_dictionary_trans: int = 256
    l1_alpha: float = 5e-3
    # When set, sweep over top-k sparsity (the values in `topk`) instead of a penalty.
    use_topk: bool = False
    topk: list[int] = field(default_factory=lambda: list(range(1, 17)))
    norm_output: bool = False
    # Sparsity-penalty coefficients to sweep. Earlier sweeps used log-spaced
    # grids (roughly 1e-4 ... 3e2 and 1e-7 ... 3).
    lp_alphas: list[float] = field(default_factory=lambda: [3])
    # Exponents p for the "lp_norm" / "lp^p" penalties.
    p_values: list[float] = field(default_factory=lambda: [1])
    # Anneal p toward 0 during training (see train_model).
    anneal: bool = True
    # Sparsity penalty: "lp_norm", "lp^p" or "log". "gated" / "None" are not
    # implemented as penalties by sparsity_loss_term.
    loss_fn: str = "lp^p"
    # Smoothing thresholds swept for the "log" penalty.
    eps: list[float] = field(default_factory=lambda: [0.01, 0.1])

    def ps(self):
        """Return the list swept as the per-run ``p`` argument.

        ``topk`` when ``use_topk`` is set, ``eps`` for the "log" penalty and
        ``p_values`` otherwise. The attribute itself is returned, not a copy.
        """
        if self.use_topk:
            return self.topk
        return self.eps if self.loss_fn == "log" else self.p_values
        

@dataclass
class RandomDatasetGenerator(Generator):
    """Endless source of synthetic sparse-superposition batches.

    Each ``next()`` call (``send(None)``) returns ``(ground_truth, data)``:
    ``ground_truth`` holds the sparse feature coefficients, shape
    (batch_size, n_ground_truth_components), and ``data = ground_truth @ feats``,
    shape (batch_size, activation_dim).

    Component ``i`` is active with probability
    ``feature_prob_decay**i * feature_num_nonzero / n_ground_truth_components``.

    NOTE: ``correlated`` is stored but not used; correlated components are not
    implemented, so ``corr_matrix`` is always ``None``.
    """

    activation_dim: int
    n_ground_truth_components: int
    batch_size: int
    feature_num_nonzero: int
    feature_prob_decay: float
    correlated: bool
    device: Union[torch.device, str]

    # Ground-truth feature directions; sampled randomly when not supplied.
    feats: Optional[TensorType["n_ground_truth_components", "activation_dim"]] = None
    # Number of batches produced so far; also used as the seed for the next batch.
    generated_so_far: int = 0

    frac_nonzero: float = field(init=False)
    decay: TensorType["n_ground_truth_components"] = field(init=False)
    corr_matrix: Optional[
        TensorType["n_ground_truth_components", "n_ground_truth_components"]
    ] = field(init=False)
    component_probs: Optional[TensorType["n_ground_truth_components"]] = field(init=False)

    def __post_init__(self):
        self.frac_nonzero = self.feature_num_nonzero / self.n_ground_truth_components

        # Define the probabilities of each component being included in the data
        self.decay = torch.tensor(
            [self.feature_prob_decay**i for i in range(self.n_ground_truth_components)]
        ).to(self.device)  # FIXME: 1 / i

        # Correlated sampling is not implemented. Assign the field anyway: it is an
        # init=False dataclass field, so leaving it unset made repr()/== raise
        # AttributeError.
        self.corr_matrix = None

        self.component_probs = self.decay * self.frac_nonzero  # Only if non-correlated
        if self.feats is None:
            self.feats = generate_rand_feats(
                self.activation_dim,
                self.n_ground_truth_components,
                device=self.device,
            )

    def send(
        self, ignored_arg: Any
    ) -> tuple[
        TensorType["dataset_size", "n_ground_truth_components"],
        TensorType["dataset_size", "activation_dim"],
    ]:
        """Draw the next batch and return ``(ground_truth, data)``.

        Reseeds torch's global RNG with the batch counter first, so the k-th
        batch is the same for any generator started at the same
        ``generated_so_far``. This also resets the global RNG state for the
        caller.
        """
        torch.manual_seed(self.generated_so_far)  # Set a deterministic seed for reproducibility
        self.generated_so_far += 1

        _, ground_truth, data = generate_rand_dataset(
            self.n_ground_truth_components,
            self.batch_size,
            self.component_probs,
            self.feats,
            self.device,
        )
        return ground_truth, data

    def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> None:
        """Terminate iteration: any thrown exception is replaced by StopIteration."""
        raise StopIteration

def generate_rand_dataset(
    n_ground_truth_components: int,  #
    dataset_size: int,
    feature_probs: TensorType["n_ground_truth_components"],
    feats: TensorType["n_ground_truth_components", "activation_dim"],
    device: Union[torch.device, str],
) -> tuple[
    TensorType["n_ground_truth_components", "activation_dim"],
    TensorType["dataset_size", "n_ground_truth_components"],
    TensorType["dataset_size", "activation_dim"],
]:
    """Sample a batch of sparse codes and their dense superposition.

    Every (sample, component) entry gets a strength drawn uniformly from [0, 1).
    The entry is kept when an independent uniform draw is ``<= feature_probs``
    for that component, and zeroed otherwise.

    Returns ``(feats, codes, codes @ feats)``; ``feats`` is passed through unchanged.
    """
    shape = (dataset_size, n_ground_truth_components)
    # Draw order matters for reproducibility: strengths first, then the mask.
    strengths = torch.rand(shape, device=device)
    is_active = torch.rand(shape, device=device) <= feature_probs
    codes = torch.where(is_active, strengths, torch.zeros_like(strengths))
    return feats, codes, codes @ feats

def generate_rand_feats(
    feat_dim: int,
    num_feats: int,
    device: Union[torch.device, str],
) -> TensorType["n_ground_truth_components", "activation_dim"]:
    """Sample ``num_feats`` random unit-norm directions in ``feat_dim`` dimensions.

    Directions come from an isotropic Gaussian drawn with NumPy's global RNG,
    so ``np.random.seed`` controls reproducibility. Each row is then scaled to
    unit L2 norm.

    Returns:
        float32 tensor of shape (num_feats, feat_dim) on ``device``.
    """
    # Kept as multivariate_normal (not standard_normal) so that seeded runs
    # reproduce exactly the same features as before.
    feats = np.random.multivariate_normal(np.zeros(feat_dim), np.eye(feat_dim), size=num_feats)
    feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    # Cast on the host before moving, so only float32 data is transferred.
    return torch.from_numpy(feats).float().to(device)

# AutoEncoder Definition
class AutoEncoder(nn.Module):
    """Single-hidden-layer sparse autoencoder with a bias-free, unit-norm decoder.

    Encoder: ``c = ReLU(W_e x + b_e)``. Decoder: ``x_hat = W_d c``. Each column
    of ``W_d`` is one dictionary element, and the columns are re-normalized to
    unit L2 norm on every forward pass.
    """

    def __init__(self, activation_size, n_dict_components):
        super().__init__()
        # Creation order matters: it fixes how the global RNG is consumed during
        # parameter initialization (encoder first, then decoder).
        self.encoder = nn.Sequential(
            nn.Linear(activation_size, n_dict_components),
            nn.ReLU(),
        )
        self.decoder = nn.Linear(n_dict_components, activation_size, bias=False)
        nn.init.orthogonal_(self.decoder.weight)

    def _renormalize_decoder(self):
        """Scale every decoder column (dictionary element) to unit L2 norm, in place."""
        self.decoder.weight.data = nn.functional.normalize(self.decoder.weight.data, dim=0)

    def forward(self, x, topk=None, norm_output=False):
        """Encode ``x`` and reconstruct it.

        Args:
            x: input batch with ``activation_size`` features in the last dim.
            topk: if given, keep only the ``topk`` largest codes per sample.
            norm_output: if True, divide reconstruction and codes by the
                reconstruction's L2 norm (clamped at 1e-8).

        Returns:
            ``(x_hat, c)``: the reconstruction and the (possibly sparsified) codes.
        """
        codes = self.encoder(x)
        if topk is not None:
            kept_vals, kept_idx = torch.topk(codes, topk, dim=-1, largest=True, sorted=False)
            codes = torch.zeros_like(codes).scatter(-1, kept_idx, kept_vals)

        self._renormalize_decoder()
        x_hat = self.decoder(codes)

        if norm_output:
            # Clamp so an all-zero reconstruction does not divide by zero.
            scale = torch.clamp(torch.norm(x_hat, p=2, dim=1, keepdim=True), min=1e-8)
            x_hat = x_hat / scale
            codes = codes / scale

        return x_hat, codes

    def get_dictionary(self):
        """Return the decoder weight (columns = unit-norm dictionary elements)."""
        self._renormalize_decoder()
        return self.decoder.weight

    @property
    def device(self):
        return next(self.parameters()).device

def cosine_sim(
    vecs1: Union[torch.Tensor, torch.nn.parameter.Parameter, npt.NDArray],
    vecs2: Union[torch.Tensor, torch.nn.parameter.Parameter, npt.NDArray],
) -> np.ndarray:
    """Pairwise cosine similarity between the rows of two 2-D arrays.

    Torch inputs are detached and copied to host NumPy arrays first.

    Returns:
        NumPy array of shape (len(vecs1), len(vecs2)); entry [i, j] is the
        cosine of the angle between ``vecs1[i]`` and ``vecs2[j]``.
    """

    def unit_rows(v):
        arr = v if isinstance(v, np.ndarray) else v.detach().cpu().numpy()  # type: ignore
        return arr / np.linalg.norm(arr, axis=1, keepdims=True)

    return unit_rows(vecs1) @ unit_rows(vecs2).T

def mean_max_cosine_similarity(ground_truth_features, learned_dictionary, debug=False):
    """Mean Max Cosine Similarity (MMCS) between ground-truth and learned features.

    For every ground-truth row, take its highest cosine similarity with any
    learned-dictionary row, then average over ground-truth rows. ``debug`` is
    accepted for compatibility but unused.
    """
    similarity = cosine_sim(ground_truth_features, learned_dictionary)
    best_match_per_feature = np.max(similarity, axis=1)
    return np.mean(best_match_per_feature)

def calculate_mmcs(auto_encoder, ground_truth_features):
    """Compute MMCS between ``ground_truth_features`` and the autoencoder's dictionary.

    The decoder weight of an ``nn.Linear(n_dict, activation_dim)`` has shape
    (activation_dim, n_dict), so it is transposed to get one row per
    dictionary element.
    """
    with torch.no_grad():
        dictionary_rows = auto_encoder.decoder.weight.data.t()
        targets = ground_truth_features.to(auto_encoder.device)
        return mean_max_cosine_similarity(targets, dictionary_rows)

def get_alive_neurons(auto_encoder, data_generator, n_batches=10):
    """Estimate which dictionary neurons are alive on fresh data.

    Runs ``n_batches`` batches from ``data_generator`` through ``auto_encoder``
    and averages each neuron's activation over all samples. A neuron whose mean
    activation is exactly 0 is considered dead.

    :param auto_encoder: callable mapping a batch to ``(x_hat, c)``, e.g. ``AutoEncoder``.
    :param data_generator: iterator yielding ``(ground_truth, batch)`` pairs,
        e.g. ``RandomDatasetGenerator``.
    :param n_batches: number of batches to evaluate; must be >= 1.
    :return: boolean tensor with one entry per dictionary neuron, True if alive.
    :raises ValueError: if ``n_batches`` < 1.
    """
    if n_batches < 1:
        raise ValueError(f"n_batches must be >= 1, got {n_batches}")

    codes = []
    # Evaluation only: don't build an autograd graph for every batch.
    with torch.no_grad():
        for _ in range(n_batches):
            _, batch = next(data_generator)
            _, c = auto_encoder(batch)  # c: (batch_size, n_dict_components)
            codes.append(c)

    mean_activations = torch.cat(codes).mean(dim=0)  # (n_dict_components,)
    # Assumes non-negative activations (true for AutoEncoder, whose codes come
    # from a ReLU), so mean > 0 means the neuron fired on at least one sample.
    return mean_activations > 0

def zero_out_except_topk(tensor, topk):
    """Return a copy of ``tensor`` keeping only the ``topk`` largest entries along the last dim.

    All other entries are set to 0; the input is not modified.
    """
    kept_values, kept_indices = torch.topk(tensor, topk, dim=-1, largest=True, sorted=False)
    return torch.zeros_like(tensor).scatter(-1, kept_indices, kept_values)

def smooth_log(x, eps):
    """
    Element-wise apply a smoothed log function to a torch tensor x such that:
    smooth_log(x, eps) = 0                           if x <= 0
    smooth_log(x, eps) = 1/2 * x^2 / eps^2           if 0 < x <= eps
    smooth_log(x, eps) = log(x) - log(eps) + 1/2     if x > eps

    The two pieces agree in value (1/2) and slope (1/eps) at x = eps.
    ``eps`` may be a Python float (as supplied by ``ToyArgs.eps``) or a 0-dim tensor.
    """
    # Tensor to hold the output values; entries with x <= 0 stay zero.
    y = torch.zeros_like(x)

    # torch.log only accepts tensors, so lift eps (usually a float) first.
    log_eps = torch.log(torch.as_tensor(eps, dtype=x.dtype, device=x.device))

    # Logarithmic branch: x > eps
    above = x > eps
    y[above] = torch.log(x[above]) - log_eps + 0.5

    # Quadratic branch: 0 < x <= eps
    below = (x > 0) & (x <= eps)
    y[below] = 0.5 * (x[below] / eps) ** 2

    return y

def sparsity_loss_term(c, p, loss_fn):
    """Sparsity penalty on a 2-D code matrix ``c`` (samples x components).

    The per-sample penalty is averaged over the batch and divided by the number
    of components ``c.size(1)``:

    * ``"lp_norm"``: ``||c_i||_p``
    * ``"lp^p"``:    ``sum_j c_ij**p``
    * ``"log"``:     ``sum_j smooth_log(c_ij, p)`` (here ``p`` is the eps threshold)

    Any other ``loss_fn`` (e.g. ``"gated"``, ``"None"``) applies no penalty and
    returns a 0-dim zero tensor with ``c``'s dtype and device.
    """
    n_components = c.size(1)
    if loss_fn == "lp_norm":
        return torch.norm(c, p, dim=1).mean() / n_components
    if loss_fn == "lp^p":
        return torch.pow(c, p).sum(dim=-1).mean() / n_components
    if loss_fn == "log":
        return smooth_log(c, p).sum(dim=-1).mean() / n_components
    # No penalty. Previously this referenced an undefined global `device`
    # (NameError); derive device/dtype from c instead.
    return c.new_zeros(())
    

def train_model(arg):
    """Train one sparse autoencoder on synthetic data, save it and return its logs.

    ``arg`` is a single tuple, so the function can be mapped by ``Pool.imap``:
    ``(worker_id, epochs, p_id, p, lp_id, lp_alpha, ground_truth_features,
    cfg_dict, init_seed, device, run_num)``. ``worker_id``, ``p_id`` and
    ``lp_id`` are currently unused.

    ``p`` is the top-k value when ``cfg_dict["use_topk"]`` is set. Otherwise it
    is the exponent (or eps, for the "log" penalty) passed to
    ``sparsity_loss_term``.

    When ``cfg_dict["anneal"]`` is set (penalty runs only), annealing happens
    in two phases:

    * First half of training: ``p`` steps linearly toward 0 in 10 steps. At
      each step ``lp_alpha`` is rescaled so the penalty on the current batch
      is unchanged.
    * Second half: ``lp_alpha`` is multiplied by ``0.1**(1/10)`` ten times.

    Returns:
        list of metric dicts, one every 100 steps.

    Side effect:
        saves the whole module to
        ``/root/sparsify/trained_models/toy_saes{run_num}/sae_l{p}_{lp_alpha}.pt``
        (using the initial ``p`` and ``lp_alpha``).
    """
    (
        worker_id,
        epochs,
        p_id,
        p,
        lp_id,
        lp_alpha,
        ground_truth_features,
        cfg_dict,
        init_seed,
        device,
        run_num,
    ) = arg
    # torch.cuda.set_device rejects non-CUDA devices, so only call it on GPU.
    if torch.device(device).type == "cuda":
        torch.cuda.set_device(device)  # Set the device for the process

    data_generator = RandomDatasetGenerator(
        activation_dim=cfg_dict["activation_dim"],
        n_ground_truth_components=cfg_dict["n_ground_truth_components"],
        batch_size=cfg_dict["batch_size"],
        feature_num_nonzero=cfg_dict["feature_num_nonzero"],
        feature_prob_decay=cfg_dict["feature_prob_decay"],
        correlated=cfg_dict["correlated_components"],
        device=device,
        feats=ground_truth_features,
        generated_so_far=init_seed,
    )

    auto_encoder = AutoEncoder(cfg_dict["activation_dim"], cfg_dict["n_components_dictionary"]).to(
        device
    )

    optimizer = optim.Adam(auto_encoder.parameters(), lr=cfg_dict["lr"])

    use_topk = cfg_dict["use_topk"]
    norm_output = cfg_dict["norm_output"]
    loss_fn = cfg_dict["loss_fn"]
    # In top-k mode p is the integer k. Annealing would turn it into a float,
    # which torch.topk rejects, so annealing only applies to penalty runs.
    anneal = cfg_dict["anneal"] and not use_topk

    logs = []

    original_p = p
    original_alpha = lp_alpha

    if anneal:
        num_anneals = 10
        # First half of training: step p linearly from its initial value toward 0.
        anneal_ps = np.linspace(p, 0, num_anneals, endpoint=False)
        anneal_times = np.linspace(0, epochs / 2, num_anneals, endpoint=False, dtype=int)
        p_schedule = dict(zip(anneal_times.tolist(), anneal_ps.tolist()))
        # Second half: decay lp_alpha at these steps (a set for O(1) lookups).
        alpha_anneal_times = set(
            np.linspace(epochs / 2, epochs, num_anneals, endpoint=False, dtype=int).tolist()
        )

    for ep in range(epochs):
        _, batch = next(data_generator)

        optimizer.zero_grad()

        # Forward pass
        topk = p if use_topk else None
        x_hat, c = auto_encoder(batch, topk=topk, norm_output=norm_output)
        if norm_output:
            # Undo the unit-norm output scaling. Rescale both the reconstruction
            # and the codes by the input norm so x_hat == decoder(c) still holds.
            batch_norm = torch.norm(batch, p=2, dim=1, keepdim=True)
            x_hat = x_hat * batch_norm
            c = c * batch_norm

        # Reconstruction loss and sparsity penalty
        l_reconstruction = torch.nn.MSELoss()(batch, x_hat)
        sparsity_term = sparsity_loss_term(c, p, loss_fn)
        l_lp = lp_alpha * sparsity_term

        if anneal:
            if ep in p_schedule:
                new_p = p_schedule[ep]
                new_sparsity_term = sparsity_loss_term(c, new_p, loss_fn)
                # Pick alpha so the penalty on this batch is unchanged by the switch.
                new_alpha = l_lp / new_sparsity_term
                p = new_p
                print(f"updating to p={new_p} and alpha={new_alpha}")

                lp_alpha = new_alpha.detach().clone().item()

            if ep in alpha_anneal_times:
                lp_alpha = float(lp_alpha * (0.1) ** (1 / 10))

        if (ep + 1) % 100 == 0:
            # Metrics are only needed on logging steps. Computing them here avoids
            # a GPU->CPU sync and a NumPy MMCS pass on every training step.
            # Mean number of active codes per sample.
            sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
            # NOTE: mean number of *inactive* codes per sample (n_dict - Sparsity),
            # not a count of neurons that never fire.
            num_dead_features = (c == 0).float().mean(dim=0).sum().cpu().item()
            mmcs = calculate_mmcs(auto_encoder, ground_truth_features)

            log_prefix = "" if anneal else f"{lp_alpha} L{p}"
            logs.append(
                {
                    f"{log_prefix} Sparsity": sparsity,
                    f"{log_prefix} Dead Features": num_dead_features,
                    f"{log_prefix} Reconstruction Loss": l_reconstruction.detach().item(),
                    f"{log_prefix} Sparsity Loss": l_lp.detach().item(),
                    f"{log_prefix} Sparsity Term": sparsity_term.detach().item(),
                    f"{log_prefix} MMCS": mmcs,
                    f"{log_prefix} p": p,
                    "Tokens": ep * cfg_dict["batch_size"],
                }
            )

        # Total loss, backward pass and parameter update
        loss = l_reconstruction + l_lp
        loss.backward()
        optimizer.step()

    # Save model
    save_name = f"sae_l{original_p}_{original_alpha}"
    torch.save(auto_encoder, f"/root/sparsify/trained_models/toy_saes{run_num}/{save_name}.pt")
    return logs


def main():
    """Run the configured sweep and log the combined metrics to wandb.

    Steps:

    1. Sample ground-truth features and build one ``train_model`` job per
       (p, lp_alpha) pair.
    2. Run the jobs in a process pool (or inline when there is only one).
    3. Merge the per-step log dicts across jobs.
    4. Save the ground-truth features and upload the merged logs to wandb.
       The wandb key is read from ``secrets.json``.
    """
    cfg = ToyArgs()
    cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    run_num = cfg.seed

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    ground_truth_features = generate_rand_feats(
        cfg.activation_dim,
        cfg.n_ground_truth_components,
        device=device,
    )

    # force=True: calling main() again in the same interpreter (e.g. a notebook)
    # would otherwise raise "context has already been set".
    mp.set_start_method("spawn", force=True)

    save_dir = f"/root/sparsify/trained_models/toy_saes{run_num}"
    os.makedirs(save_dir, exist_ok=True)

    n_alphas = len(cfg.lp_alphas)
    args_list = [
        (
            p_id * n_alphas + lp_id,  # worker_id
            cfg.epochs,
            p_id,
            p,
            lp_id,
            lp_alpha,
            ground_truth_features,
            cfg.__dict__,
            run_num,  # init_seed for the data generator
            device,
            run_num,
        )
        for p_id, p in enumerate(cfg.ps())
        for lp_id, lp_alpha in enumerate(cfg.lp_alphas)
    ]
    if not args_list:
        raise ValueError("No runs configured: ToyArgs.ps() or ToyArgs.lp_alphas is empty")

    if len(args_list) > 1:
        combined_logs = []
        # Adjust the number of processes based on your system's capabilities
        with mp.Pool(processes=10) as pool:
            with tqdm(total=len(args_list)) as pbar:
                for ret_logs in pool.imap_unordered(train_model, args_list):
                    pbar.update()
                    if not combined_logs:
                        # One dict per logging step, shared by all runs.
                        combined_logs = [{} for _ in ret_logs]
                    # Runs use distinct key prefixes, so merging keeps all metrics.
                    for step_log, log in zip(combined_logs, ret_logs):
                        step_log.update(log)
    else:
        combined_logs = train_model(args_list[0])

    torch.save(
        ground_truth_features,
        f"{save_dir}/ground_truth_features.pt",
    )

    with open("secrets.json") as secrets_file:
        secrets = json.load(secrets_file)
    wandb.login(key=secrets["wandb_key"])
    start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    wandb_run_name = f"Toylp_{start_time[4:]}_{cfg.batch_size}_{cfg.epochs}_{cfg.seed}"  # trim year
    print(f"wandb_run_name: {wandb_run_name}")
    wandb.init(project="lp saes", name=wandb_run_name)

    for step_log in combined_logs:
        wandb.log(step_log)

# Script entry point. The guard is required because worker processes are
# started with the "spawn" method, which re-imports this module in each worker.
if __name__ == "__main__":
    main()
