In [None]:
# Move the kernel's working directory two levels up from where it currently is.
# NOTE(review): presumably this lands on the repository root so that the
# project-local imports in later cells resolve -- confirm.
# Caution: the path is relative, so re-running this cell climbs two MORE levels;
# run it exactly once per kernel session.
%cd ../..
# List the new working directory as a quick visual sanity check.
%ls

In [None]:
!conda install -c conda-forge libgfortran -y
!conda install -c conda-forge gfortran_osx-64 -y
!conda install -c conda-forge gfortran_impl_osx-64 -y

In [None]:
import torch
from data.generating.generate_data import generate_toy_data

# Toy-data settings, named so they are easy to find and tune.
N_SAMPLES = 1000
DATA_SEED = 0

# Build the train/test datasets consumed by the rest of the notebook.
train_dataset, test_dataset = generate_toy_data(n_samples=N_SAMPLES, seed=DATA_SEED)

In [None]:
import torch
import torch.nn as nn

# Import models
from model_zoo.mlp1 import mlp1
from model_zoo.mlp2 import mlp2

# Import the generic trainer
from training.train_model_generic import train_model_generic

# Seed PyTorch's global RNG so that this cell gives the same result every time
# it is run (weight initialisation and shuffled batching typically draw from
# it). The value is deliberately different from the data-generation seed (0):
# NOTE(review): if the data generator also seeds torch, reusing 0 here would
# replay the same random stream for the model weights -- confirm.
TORCH_SEED = 42
torch.manual_seed(TORCH_SEED)

# Train with CrossEntropyLoss (classification)
criterion_ce = nn.CrossEntropyLoss()

# Both networks are built with the same dimensions and trained with the same
# settings; keeping each set of arguments in one place prevents the two
# training runs from silently drifting apart.
model_kwargs = dict(input_dim=2, latent_dim=8, output_dim=2)
train_kwargs = dict(
    train_dataset=train_dataset,
    criterion=criterion_ce,
    epochs=5,
    lr=1e-3,
    batch_size=32,
    shuffle=True,
    device="cpu",
)

# Each model is created immediately before its own training run, matching the
# original order of operations.
model1 = mlp1(**model_kwargs)

print("Training ModelOne:")
model1_trained = train_model_generic(model=model1, **train_kwargs)

# Now instantiate ModelTwo
model2 = mlp2(**model_kwargs)

print("Training ModelTwo:")
model2_trained = train_model_generic(model=model2, **train_kwargs)


In [None]:
# Experiment helpers: latent extraction, crosscoder training, and evaluation.
from experiments.experiment1.generate_latents import generate_latents
from experiments.experiment1.train_crosscoder import train_crosscoder
from experiments.experiment1.analyze_results import analyze_results

# Collect latents from both models over the training set; the second value
# returned by generate_latents is not used here.
latents1, _ = generate_latents(model1, train_dataset, batch_size=32)
latents2, _ = generate_latents(model2, train_dataset, batch_size=32)

# Crosscoder hyperparameters gathered in one place.
# NOTE(review): the dims of 8 presumably mirror the latent_dim the models were
# built with -- confirm.
crosscoder_settings = {
    "input_dim": 8,
    "output_dim": 8,
    "epochs": 10,
    "lr": 1e-3,
    "batch_size": 32,
}
crosscoder = train_crosscoder(latents1, latents2, **crosscoder_settings)

# Evaluate the trained crosscoder against both models on the held-out test set.
metrics = analyze_results(crosscoder, model1, model2, test_dataset)

# Report each metric under its label.
print("==== Crosscoder Performance on Test Set ====")
for label, key in (
    ("MSE: ", "mse"),
    ("Average correlation: ", "average_correlation"),
    ("Dimension-wise correlations:", "dim_correlations"),
):
    print(label, metrics[key])
