In [12]:
import torch
from data.generating.generate_data import generate_toy_data

# One seed for the whole notebook, so a fresh "Restart Kernel -> Run All" is repeatable.
SEED = 0

# The toy-data generator receives its own `seed`, but nothing else in the notebook
# seeds torch. Seed torch's global RNG here as well; model weight initialisation and
# the shuffled training batches presumably draw from it (TODO confirm that the
# project helpers do not re-seed internally).
torch.manual_seed(SEED)

# Train/test splits of the toy dataset (1000 samples requested).
train_dataset, test_dataset = generate_toy_data(n_samples=1000, seed=SEED)

In [13]:
import torch.nn as nn

# Import models
from model_zoo.mlp1 import mlp1
from model_zoo.mlp2 import mlp2

# Import the generic trainer
from training.train_model_generic import train_model_generic

# Loss and hyperparameters shared by both training runs, so the two models are
# trained under exactly the same settings.
criterion_ce = nn.CrossEntropyLoss()  # classification loss
shared_train_kwargs = dict(
    train_dataset=train_dataset,
    criterion=criterion_ce,
    epochs=5,
    lr=1e-3,
    batch_size=32,
    shuffle=True,
    device="cpu",
)

# --- ModelOne: instantiate, then train ---
model1 = mlp1(input_dim=2, latent_dim=8, output_dim=2)
print("Training ModelOne:")
model1_trained = train_model_generic(model=model1, **shared_train_kwargs)

# --- ModelTwo: instantiated only after ModelOne has finished training, which keeps
# the order of operations (and any RNG consumption) the same as a sequential run ---
model2 = mlp2(input_dim=2, latent_dim=8, output_dim=2)
print("Training ModelTwo:")
model2_trained = train_model_generic(model=model2, **shared_train_kwargs)


Training ModelOne:
Epoch [1/5], Loss: 0.6836
Epoch [2/5], Loss: 0.6590
Epoch [3/5], Loss: 0.6346
Epoch [4/5], Loss: 0.6089
Epoch [5/5], Loss: 0.5804
Training ModelTwo:
Epoch [1/5], Loss: 0.7052
Epoch [2/5], Loss: 0.6914
Epoch [3/5], Loss: 0.6788
Epoch [4/5], Loss: 0.6665
Epoch [5/5], Loss: 0.6542


In [14]:
# Collect the latents of both models on the training set; they are the inputs
# used to train the crosscoder below.
from experiments.mlp_architecture_diff.runs.generate_latents import generate_latents

# NOTE(review): this passes model1/model2 rather than model1_trained/model2_trained.
# That is equivalent only if train_model_generic trains its argument in place -- confirm.
# generate_latents returns a pair; only its first element is kept.
latent_batch_size = 32
latents1, _ = generate_latents(model1, train_dataset, batch_size=latent_batch_size)
latents2, _ = generate_latents(model2, train_dataset, batch_size=latent_batch_size)

# Fit the crosscoder on the two sets of latents.
from training.train_crosscoder import train_crosscoder

# input_dim/output_dim of 8 equal the latent_dim=8 the models were built with;
# presumably they have to stay in sync -- verify if latent_dim is ever changed.
crosscoder_kwargs = {
    "input_dim": 8,
    "output_dim": 8,
    "epochs": 10,
    "lr": 1e-3,
    "batch_size": 32,
}
crosscoder = train_crosscoder(latents1, latents2, **crosscoder_kwargs)

In [15]:
# Evaluate the trained crosscoder against both models on the test dataset.
from experiments.mlp_architecture_diff.analysis.analyze_results import analyze_results

# NOTE(review): model1/model2 are passed here, not model1_trained/model2_trained;
# the same objects only if train_model_generic trains in place -- confirm.
metrics = analyze_results(crosscoder, model1, model2, test_dataset)

# (label, metrics key) pairs in report order. Labels are passed to print() as a
# separate argument, so any trailing space in a label is part of the output format.
report_rows = (
    ("MSE: ", "mse"),
    ("Average correlation: ", "average_correlation"),
    ("Dimension-wise correlations:", "dim_correlations"),
)

print("==== Crosscoder Performance on Test Set ====")
for row_label, metric_key in report_rows:
    print(row_label, metrics[metric_key])

==== Crosscoder Performance on Test Set ====
MSE:  0.03407642990350723
Average correlation:  0.4793343559431378
Dimension-wise correlations: [0.48023390769958496, 0.6672911047935486, 0.6115498542785645, 0.5287367701530457, 0.6353644728660583, 0.05385603755712509, -0.002450237749144435, 0.8600929379463196]
