In [None]:
import wandb
from typing import TypeVar
from pathlib import Path
import shutil
import time
import pickle
import yaml
import os
import s3fs
from queue import Queue

from einops import rearrange
from sklearn.decomposition import PCA
import numpy as np
from omegaconf import OmegaConf

import torch

from di_automata.devinterp.rlct_utils import (
    plot_pca_plotly,
    plot_explained_var,
)
from di_automata.config_setup import *
from di_automata.constructors import (
    construct_model, 
    create_dataloader_hf,
    construct_rlct_criterion,
)
from di_automata.tasks.data_utils import take_n
from di_automata.io import read_tensors_from_file, append_tensor_to_file
from di_automata.devinterp.ed_utils import EssentialDynamicsPlotter
# Generic type variable named "Sweep"; it is not referenced in the visible notebook cells.
Sweep = TypeVar("Sweep")

# AWS
# S3 filesystem handle created with default arguments. In the visible cells it is only
# referenced by the commented-out "aws" branch of restore_states.
s3 = s3fs.S3FileSystem()

In [None]:
## General setup: read SLT config
# Prefer the first CUDA device and fall back to CPU when CUDA is unavailable
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

config_file_path = "configs/slt_config.yaml"
slt_config = OmegaConf.load(config_file_path)

# The task config file is selected by the dataset type named in the SLT config
with open(f"configs/task_config/{slt_config.dataset_type}.yaml", 'r') as file:
    task_config = yaml.safe_load(file)

# Allow new configuration values to be added, then attach the task config
OmegaConf.set_struct(slt_config, False)
slt_config["task_config"] = task_config

# Round-trip through the PostRunSLTConfig Pydantic model for dynamic type validation
# (NECESSARY, DO NOT SKIP), then convert back to an OmegaConf object for compatibility
# with existing code.
pydantic_config = PostRunSLTConfig(**slt_config)
slt_config = OmegaConf.create(pydantic_config.model_dump())

print(task_config["dataset_type"])

In [None]:
# Setup run
# Run path and name for easy referral later
run_path = f"{slt_config.entity_name}/{slt_config.wandb_project_name}-alpha"
run_name = slt_config.run_name

# Get run information: finished runs whose display name matches `run_name`
api = wandb.Api()
run_list = api.runs(
    path=run_path,
    filters={
        "display_name": run_name,
        "state": "finished",
        },
    # NOTE(review): the original comment says this yields descending order (newest first);
    # confirm the sort direction against the wandb API docs before relying on run_idx.
    order="created_at",
)
run_api = run_list[slt_config.run_idx]

# Fetch the logged metrics. `history` is called when it is callable and used as-is
# otherwise; errors raised by the call itself now propagate instead of being swallowed
# by a bare `except`, which used to leave the bound method in `history`.
# NOTE(review): wandb's history() presumably returns a sampled subset of rows by default,
# so len(loss_history) may be smaller than the number of logged steps - confirm.
history = run_api.history
if callable(history):
    history = history()
loss_history = history["Train Loss"]
accuracy_history = history["Train Acc"]
steps = history["_step"]

In [None]:
# Get queue of artifacts from old run.
# The artifacts are fetched once and enqueued in the order the API returns them;
# restore_states pops one artifact per checkpoint from this queue.
artifacts = run_api.logged_artifacts()
artifact_queue = Queue()
for artifact in artifacts:
    artifact_queue.put(artifact)

In [None]:
def get_config() -> MainConfig:
    """Manually get the training-run config from a WandB artifact.

    WandB also logs a config automatically for each run, but it doesn't log enums
    correctly, so the YAML file stored in the artifact is used instead.

    Returns:
        The contents of `config.yaml` from the downloaded artifact, as loaded by
        `OmegaConf.load`.
    """
    # NOTE(review): the artifact name is hard-coded; the commented alternatives below
    # derive it from the run name instead.
    artifact = api.artifact(f"{run_path}/states:dihedral_config_state")
    # artifact = api.artifact(f"{run_path}/config:{run_name}")
    # artifact = api.artifact(f"{run_path}/states:idx{idx}_{run_name}")
    data_dir = artifact.download()
    config_path = Path(data_dir) / "config.yaml"
    return OmegaConf.load(config_path)

In [None]:
# Get and set up configs
# config: MainConfig = OmegaConf.create(run_api.config)
config = get_config()
config["model_save_method"] = "wandb" # Duct tape

# Set total number of unique samples seen (n). If this is not done it will break LLC estimator.
# Both SLT-side fields are overwritten with the value recorded in the training run's config.
slt_config.rlct_config.sgld_kwargs.num_samples = config.rlct_config.sgld_kwargs.num_samples
slt_config.rlct_config.num_samples = config.rlct_config.sgld_kwargs.num_samples
slt_config.nano_gpt_config = config.nano_gpt_config

In [None]:
def set_logger(config) -> tuple:
    """Call at initialisation to start the WandB run used for post-run analysis.

    Run naming convention is to prepend 'post' to distinguish from training runs.

    Note that `config` is modified in place: the SLT config and the path of the
    previous (training) run are added to it so they are saved with the new run.

    Params:
        config: Config of the original training run.

    Returns:
        A `(run, wandb_cache_dirs)` tuple: the object returned by `wandb.init`, and a
        list of WandB artifact cache directories that can be deleted periodically.
    """
    config["slt_config"] = slt_config # For saving to WandB
    # Add previous run id to tie runs together
    config["prev_run_path"] = f"{run_path}/{run_api.id}"
    logger_params = {
        "name": f"post_{config.run_name}",
        "project": config.wandb_config.wandb_project_name,
        # "settings": wandb.Settings(start_method="thread"),
        "config": OmegaConf.to_container(config, resolve=True, throw_on_missing=True),
        "mode": "online" if config.is_wandb_enabled else "disabled",
    }
    run = wandb.init(**logger_params, entity=config.wandb_config.entity_name)

    # Location on remote GPU of WandB cache to delete periodically
    wandb_cache_dirs = [
        Path.home() / ".cache/wandb/artifacts/obj",
        Path.home() / "root/.cache/wandb/artifacts/obj",
    ]

    return run, wandb_cache_dirs

In [None]:
def truncate_ed_logits(ed_logits: list[torch.Tensor]) -> list[torch.Tensor]:
    """Truncate run to clean out overtraining at end for cleaner ED plots.

    The cutoff is either given manually via `slt_config.truncate_its` (in training
    iterations) or determined automatically from the smoothed log training loss using
    an early-stopping patience rule.

    Reads the notebook-level `slt_config`, `config` and `loss_history`.

    Params:
        ed_logits: Per-checkpoint logits, in checkpoint order.

    Returns:
        The leading slice of `ed_logits` up to the cutoff, or `ed_logits` unchanged
        when the early-stopping rule never triggers.
    """
    # Manually specified cutoff: map the iteration count proportionally onto ED checkpoints
    if slt_config.truncate_its is not None:
        total_its = len(loss_history) * config.eval_frequency
        ed_logit_cutoff_idx = len(ed_logits) * slt_config.truncate_its // total_its
        return ed_logits[:ed_logit_cutoff_idx]

    # Automatically calculate cutoff index using early-stop patience.
    # Patience is read from slt_config in both places it is needed, next to the
    # smoothing window (the condition used to read it from the training config).
    window = slt_config.early_stop_smoothing_window
    patience = slt_config.early_stop_patience
    log_loss_values = np.log(loss_history.to_numpy())
    smoothed_log_loss = np.convolve(log_loss_values, np.ones(window) / window, mode='valid')

    increases = 0
    for i in range(1, len(smoothed_log_loss)):
        if smoothed_log_loss[i] > smoothed_log_loss[i - 1]:
            increases += 1
            if increases >= patience:
                # Index where the increase trend starts, converted to a loss step
                cutoff_idx = (i - patience + 1) * config.eval_frequency
                # NOTE(review): this multiplies by the ED eval frequency, whereas the manual
                # branch above scales proportionally to len(ed_logits) - confirm the units.
                ed_logit_cutoff_idx = cutoff_idx * config.rlct_config.ed_config.eval_frequency // config.eval_frequency
                # The first trigger gives the smallest cutoff; later ones could only be no-ops
                return ed_logits[:ed_logit_cutoff_idx]
        else:
            increases = 0

    return ed_logits


def ed_calculation(ed_logits: list[torch.Tensor] | None = None, logits_path=None) -> np.ndarray:
    """PCA and plot part of ED.

    Display top 3 components against each other and show fraction variance explained.
    The plots are logged to WandB and the projected samples are pickled to
    `pca_projected_samples_{run_name}.pkl` in the working directory.

    This is a free function, so it takes no `self`; it reads the notebook-level `config`.

    Params:
        ed_logits: Per-checkpoint flattened logits, either as a list of 1-D tensors or
            as a single tensor with one row per checkpoint.
        logits_path: Optional path to a logits file. When it exists, logits are read
            from it (replacing `ed_logits`) and the file is deleted afterwards.

    Returns:
        Array of shape (number of checkpoints, 3) with the PCA-projected logits.

    Raises:
        ValueError: If no logits are available from either source.
    """
    if logits_path is not None and os.path.exists(logits_path):
        ed_logits = read_tensors_from_file(logits_path, config)
        # Delete logits as these can take up to 30GB of storage
        os.remove(logits_path)

    if ed_logits is None or len(ed_logits) == 0:
        raise ValueError("No essential dynamics logits were provided or found on disk.")

    # A list of per-checkpoint vectors has no .cpu(); stack it into one matrix first
    if isinstance(ed_logits, (list, tuple)):
        ed_logits = torch.stack([row.detach().cpu() for row in ed_logits])
    logits_matrix = ed_logits.detach().cpu().numpy()

    pca = PCA(n_components=3)
    pca.fit(logits_matrix)

    # Projected coordinates for plotting purposes (all checkpoints in one call)
    pca_projected_samples = pca.transform(logits_matrix)
    explained_variance = pca.explained_variance_ratio_

    plot_pca_plotly(pca_projected_samples[:,0], pca_projected_samples[:,1], pca_projected_samples[:,2], config)
    plot_explained_var(explained_variance)

    # NOTE(review): assumes the two plot helpers above write these image files - confirm
    wandb.log({
        "ED_PCA_truncated": wandb.Image("PCA.png"),
        "explained_var_truncated": wandb.Image("pca_explained_var.png"),
    })

    pca_samples_file_path = f"pca_projected_samples_{config.run_name}.pkl"
    with open(pca_samples_file_path, 'wb') as f:
        pickle.dump(pca_projected_samples, f)

    ## Try to avoid saving logits to WandB because they can be REAL CHONK (50-100GB PER RUN)

    return pca_projected_samples


def get_ed_logits_from_checkpoints() -> list[torch.Tensor]:
    """Run a forward pass per checkpoint and collect the flattened logits.

    Each iteration restores the next checkpoint through `restore_states`, loads it into
    the notebook-level `model`, and evaluates it on batches taken from `ed_loader`.
    Note checkpoint weights were saved with EMA if config.use_ema is True.

    Returns:
        One 1-D tensor per checkpoint: the logits of all evaluated batches, flattened
        and concatenated.
    """
    ed_config = config.rlct_config.ed_config
    num_checkpoints = config.num_training_iter // ed_config.eval_frequency

    collected = []
    # Starts at 1, so one state fewer than `num_checkpoints` is restored
    for _ in range(1, num_checkpoints):
        model.load_state_dict(restore_states())

        # NOTE(review): the model is not switched to eval mode here - confirm that
        # construct_model returns it in the intended mode.
        with torch.no_grad():
            # Flatten each batch of logits over batch, class and sequence dimension
            per_batch = [
                rearrange(model(batch["input_ids"].to(device)), 'b c s -> (b c s)')
                for batch in take_n(ed_loader, ed_config.batches_per_checkpoint)
            ]

        # Concat all per-batch logits to form one super-batch for this checkpoint
        # (appending to a binary file via append_tensor_to_file is currently disabled)
        collected.append(torch.cat(per_batch))

    return collected

In [None]:
def restore_states() -> dict:
    """Called for every checkpoint in the run in strict order.
    This produces a queue of artifact states, which are popped and read in turn to 
    generate the essential dynamics PCA matrix.
    
    TODO: edit reading from queue for AWS.
    
    Params:
        idx: Index in steps.
        
    Returns:
        model state dictionary.
    """
    match config.model_save_method:
        case "wandb":
            artifact = artifact_queue.get()
            data_dir = artifact.download()
            model_state_path = Path(data_dir) / "states.torch"
            states = torch.load(model_state_path)
        # case "aws":
        #     with s3.open(f"{config.aws_bucket}/{config.run_name}_{config.time}/{idx}") as f:
        #         states = torch.load(f)
    return states["model"]

In [None]:
"""Main executable code for essential dynamics osculating circle calculation."""
# config: MainConfig = OmegaConf.create(run_api.config)
config = get_config()
config["model_save_method"] = "wandb" # Duct tape
slt_config.nano_gpt_config = config.nano_gpt_config

# Log old config and SLT config to new run for post-analysis information
# START WANDB LOGGING
run, wandb_cache_dirs = set_logger(config)

ed_loader = create_dataloader_hf(config, deterministic=True) # Make sure deterministic to see same data

model, param_inf_properties = construct_model(config)

current_directory = Path().absolute()
# Use the run's `time` config field (when present) in the file name rather than the
# imported `time` module, whose repr would otherwise end up in the path.
# NOTE(review): assumes the training config carries a `time` field, as the commented-out
# AWS path in restore_states suggests - confirm.
run_timestamp = config.get("time")
if run_timestamp is None:
    logits_file_name = f"logits_{run_name}"
else:
    logits_file_name = f"logits_{run_name}_{run_timestamp}"
logits_file_path = current_directory.parent / "di_automata" / logits_file_name
print(logits_file_path)

# Reuse previously saved logits when available, otherwise recompute from checkpoints
if os.path.exists(logits_file_path):
    print(f"Loading existing logits from {logits_file_path}")
    ed_logits = torch.load(logits_file_path)
    print("Done loading existing logits")
else:
    ed_logits: list[torch.Tensor] = get_ed_logits_from_checkpoints()

ed_logits: list[torch.Tensor] = truncate_ed_logits(ed_logits)
ed_projected_samples = ed_calculation(ed_logits)

# Create and call instance of essential dynamics osculating circle plotter
ed_plotter = EssentialDynamicsPlotter(ed_projected_samples, steps, slt_config.ed_plot_config, run_name)
# NOTE(review): assumes the plotter writes ed_osculating_circles.png - confirm
wandb.log({"ed_osculating": wandb.Image("ed_osculating_circles.png")})

wandb.finish()

In [None]:
# Cleanup
# Remove the WandB artifact staging directory if it exists
upload_cache_dir = Path.home() / "root/.local/share/wandb/artifacts/staging"
if upload_cache_dir.is_dir():
    shutil.rmtree(upload_cache_dir)

# NOTE(review): the 60 s pause presumably gives WandB time to finish before its local
# files are deleted - confirm it is still needed.
time.sleep(60)

# Remove the local "wandb" directory; guarded like the staging directory above so a
# missing directory does not raise.
local_wandb_dir = Path("wandb")
if local_wandb_dir.is_dir():
    shutil.rmtree(local_wandb_dir)