In [1]:
import wandb
from pathlib import Path
import yaml
import s3fs

from omegaconf import OmegaConf

import torch

from di_automata.config_setup import *
from di_automata.constructors import (
    construct_model, 
    create_dataloader_hf,
)
from di_automata.io import read_tensors_from_file, append_tensor_to_file

# AWS: notebook-level S3 filesystem handle. It is used by `restore_state` when
# `config.model_save_method` is "aws". No arguments are passed here, so any
# credentials have to come from the environment
# (presumably the standard AWS credential lookup — confirm against the s3fs docs).
s3 = s3fs.S3FileSystem()





In [2]:

# Load the post-run SLT configuration from YAML.
config_file_path = "configs/slt_config.yaml"
slt_config = OmegaConf.load(config_file_path)

# Load the task-specific config whose file name is given by `slt_config.dataset_type`.
with open(f"configs/task_config/{slt_config.dataset_type}.yaml", 'r') as file:
    task_config = yaml.safe_load(file)
    
OmegaConf.set_struct(slt_config, False) # Struct mode off so the new `task_config` key below can be added
slt_config["task_config"] = task_config
# Round-trip through the PostRunSLTConfig Pydantic model for dynamic type validation - NECESSARY DO NOT SKIP
pydantic_config = PostRunSLTConfig(**slt_config)
# Convert back to an OmegaConf object for compatibility with existing code.
# Note that this rebinds `slt_config` to the validated copy.
slt_config = OmegaConf.create(pydantic_config.model_dump())

# Sanity check: show which task config was picked up.
print(task_config["dataset_type"])

dihedral


In [3]:
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Run path and name for easy referral later
run_path = f"{slt_config.entity_name}/{slt_config.wandb_project_name}-alpha"
run_name = slt_config.run_name

# Get run information: all finished runs in the project with the given display name.
api = wandb.Api()
run_list = api.runs(
    path=run_path,
    filters={
        "display_name": run_name,
        "state": "finished",
    },
    # `slt_config.run_idx` below indexes into this ordering. Original note:
    # "default descending order so backwards in time" — confirm against the
    # wandb docs for the installed version.
    order="created_at",
)
run_api = run_list[slt_config.run_idx]

# `history` may be exposed as a method or as a plain attribute. Checking
# callable() keeps both cases working without a bare `except`, which would
# swallow real failures (network, auth, ...) and leave `history` bound to a
# method object that then fails on the subscripts below.
history = run_api.history() if callable(run_api.history) else run_api.history
loss_history = history["Train Loss"]
accuracy_history = history["Train Acc"]
steps = history["_step"]

In [7]:
# Get list of artifacts logged by the selected run.
# The query is issued once and its result is materialised into a plain list, so
# that `restore_state` can index it by position. (Previously the result of the
# first call was unused and the loop triggered a second, identical query.)
artifacts = run_api.logged_artifacts()
artifact_list = []
for artifact in artifacts:
    artifact_list.append(artifact)

In [4]:
def get_config() -> MainConfig:
    """Manually get the config from the run, stored as an artifact.

    WandB also logs the config automatically for each run, but it doesn't log
    enums correctly, so the `config.yaml` shipped with the artifact is used
    instead.

    Relies on the notebook-level `api` (WandB API client) and `run_path`.

    Returns:
        The object produced by `OmegaConf.load` on `config.yaml` inside the
        downloaded artifact directory (annotated as `MainConfig`;
        NOTE(review): no conversion to that type happens here — confirm the
        annotation is intended).
    """
    artifact = api.artifact(f"{run_path}/states:dihedral_config_state")
    # Other artifact names, kept for reference:
    # artifact = api.artifact(f"{run_path}/config:{run_name}")
    # artifact = api.artifact(f"{run_path}/states:idx{idx}_{run_name}")
    data_dir = artifact.download()
    config_path = Path(data_dir) / "config.yaml"
    return OmegaConf.load(config_path)

In [5]:
# Use the config stored with the run's artifacts. The auto-logged alternative,
# `OmegaConf.create(run_api.config)`, is deliberately not used (see `get_config`).
config = get_config()
config["model_save_method"] = "wandb" # Duct tape

# Copy the total number of unique samples seen (n) from the training config into
# both places the SLT config keeps it. If this is not done it will break the LLC estimator.
slt_config.rlct_config.sgld_kwargs.num_samples = config.rlct_config.sgld_kwargs.num_samples
slt_config.rlct_config.num_samples = config.rlct_config.sgld_kwargs.num_samples
# The SLT config also needs the model hyperparameters of the trained run.
slt_config.nano_gpt_config = config.nano_gpt_config

# Build the (untrained) model described by the run's config.
model, param_inf_properties = construct_model(config)

# Optional: currently unused, as local logits take up a lot of storage
logits_path = "logits.bin" # Binary file

[34m[1mwandb[0m: Downloading large artifact states:dihedral_config_state, 180.95MB. 2 files... 
[34m[1mwandb[0m:   2 of 2 files downloaded.  
Done. 0:0:0.7


In [6]:
def restore_state(checkpoint_idx: int) -> dict:
    """Restore one model state from a checkpoint. Called only once for any given checkpoint.
    Intention of this function is to be used to load individual points of interest after plotting essential dynamics osculating circles.
    
    Params:
        checkpoint_idx: Index in steps.
        
    Returns:
        model state dictionary.
    """
    match config.model_save_method:
        case "wandb":
            artifact = artifact_list[checkpoint_idx]
            data_dir = artifact.download()
            model_state_path = Path(data_dir) / "states.torch"
            states = torch.load(model_state_path)
        case "aws":
            with s3.open(f"{config.aws_bucket}/{config.run_name}_{config.time}/{checkpoint_idx}") as f:
                states = torch.load(f)
    return states["model"]

In [8]:
# Hand-picked point of interest (presumably a step index read off the
# essential-dynamics plot — confirm).
cusp_idx = 621
# Convert it to a checkpoint index by integer division with the evaluation frequency.
cp_idx = cusp_idx // config.rlct_config.ed_config.eval_frequency
print(cp_idx)

62


In [9]:
# Fetch the saved state dict for the chosen checkpoint and load it into the
# model built by `construct_model`.
state = restore_state(cp_idx)
model.load_state_dict(state)

[34m[1mwandb[0m: Downloading large artifact states:v1956, 180.95MB. 2 files... 
[34m[1mwandb[0m:   2 of 2 files downloaded.  
Done. 0:0:0.7


<All keys matched successfully>

In [10]:
# Print the module tree of the model the checkpoint was loaded into.
print(model)

Transformer(
  (token_embedding): Embedding(2, 512)
  (pos_embedding): Embedding(26, 512)
  (dropout): Dropout(p=0.1, inplace=False)
  (h): Sequential(
    (block0): TransformerBlock(
      (ln_1): LayerNorm()
      (attn): SelfAttention(
        (in_projection): Linear(in_features=512, out_features=1536, bias=True)
        (out_projection): Linear(in_features=512, out_features=512, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm()
      (mlp): Sequential(
        (c_fc): Linear(in_features=512, out_features=2048, bias=True)
        (gelu): Lambda()
        (c_proj): Linear(in_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (block1): TransformerBlock(
      (ln_1): LayerNorm()
      (attn): SelfAttention(
        (in_projection): Linear(in_features=512, out_features=1536, bias=True)
        (out_projection): Linear(in_fea

In [13]:
# Pain point: I could either rewrite the attention layer so that it returns the values I want to inspect, or just move to TransformerLens (TFLens).

<bound method Module.parameters of Linear(in_features=512, out_features=1536, bias=True)>
