In [1]:
import os
import sys


def setup_project_root(start_path='.'):
    """Locate the enclosing git checkout and make it the working root.

    Starting at ``start_path`` (made absolute), climb one directory at a time
    until a directory whose listing contains an entry named ``.git`` is found.
    That directory then becomes the process working directory and is placed at
    the front of ``sys.path`` unless it is already listed there.

    Args:
        start_path: Directory from which the upward search begins.

    Raises:
        Exception: If the filesystem root is reached without finding ``.git``.
    """
    candidate = os.path.abspath(start_path)
    while '.git' not in os.listdir(candidate):
        parent = os.path.dirname(candidate)
        if parent == candidate:
            # dirname() of the filesystem root is the root itself: nowhere left to climb.
            raise Exception("Could not find project root (.git directory not found)")
        candidate = parent
    project_root = candidate

    # Relative paths used later in the notebook resolve against the project root.
    os.chdir(project_root)
    print(f"Current working directory set to: {os.getcwd()}")

    # Nothing more to do when the root is already importable.
    if project_root in sys.path:
        return
    sys.path.insert(0, project_root)
    print(f"Added {project_root} to sys.path")

# sets the current working directory to the project root
setup_project_root()

# Don't cache imports
%load_ext autoreload
%autoreload 2


from info_theory_experiments.custom_datasets import ECoGDataset
train_dataset = ECoGDataset(prepare_pairs=True)
import torch

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)



Current working directory set to: /Users/davidmcsharry/dev/causally-emergent-representations


# In this experiment, we learn emergent features using an MSE loss

In [None]:
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from info_theory_experiments.models import GeneralSmileMIEstimator

class MLPEncoder(nn.Module):
    """Fully connected encoder mapping an input vector to a latent feature.

    The hidden widths are 512 -> 1024 -> 2048 -> 1024 -> 512 -> 256, each
    followed by a ReLU, and a final linear layer projects to ``latent_dim``
    with no activation.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        latent_dim: Size of the produced feature; also kept on the instance as
            ``latent_dim``.
    """

    def __init__(self, input_dim=64, latent_dim=1):
        super().__init__()
        self.latent_dim = latent_dim

        widths = [input_dim, 512, 1024, 2048, 1024, 512, 256]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], latent_dim))

        # Modules land at indices 0..12 of the Sequential, so state_dict keys
        # are of the form ``encoder.<index>.weight`` / ``encoder.<index>.bias``.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the latent feature computed from ``x``."""
        return self.encoder(x)

class MLPDecoder(nn.Module):
    """Fully connected decoder mapping a latent feature back to data space.

    The hidden widths are 256 -> 512 -> 1024 -> 2048 -> 1024 -> 512, each
    followed by a ReLU, and a final linear layer projects to ``output_dim``
    with no activation.

    Args:
        latent_dim: Size of the last dimension of the input feature; also kept
            on the instance as ``latent_dim``.
        output_dim: Size of the produced vector.
    """

    def __init__(self, latent_dim=1, output_dim=64):
        super().__init__()
        self.latent_dim = latent_dim

        widths = [latent_dim, 256, 512, 1024, 2048, 1024, 512]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], output_dim))

        # Modules land at indices 0..12 of the Sequential, so state_dict keys
        # are of the form ``decoder.<index>.weight`` / ``decoder.<index>.bias``.
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the decoded vector computed from the feature ``x``."""
        return self.decoder(x)

# Adversarial training of a 1-D feature V = encoder(x).
#
# Three players are updated in turn on every mini-batch:
#   * decoder   - minimises the MSE of reconstructing x0 from V0;
#   * estimator - maximises its mutual-information estimate between V0 and V1;
#   * encoder   - maximises both the weighted MI estimate and the decoder's
#                 reconstruction error (its loss is -MI_WEIGHT * mi - mse).

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the models, loss function, and optimizers
encoder = MLPEncoder().to(device)
decoder = MLPDecoder().to(device)
mse_loss = nn.MSELoss()
optimizer_encoder = optim.Adam(encoder.parameters(), lr=1e-4)
optimizer_decoder = optim.Adam(decoder.parameters(), lr=1e-3)

# NOTE(review): presumably returns a scalar MI estimate for a batch of
# (x, y) pairs; confirm against info_theory_experiments.models.
decoupled_estimator = GeneralSmileMIEstimator(
    x_dim=1,
    y_dim=1,
    critic_output_size=32,
    x_critics_hidden_sizes=[256, 256],
    y_critics_hidden_sizes=[256, 256],
    clip=5,
    include_bias=True,
).to(device)

decoupled_optimizer = optim.Adam(decoupled_estimator.parameters(), lr=1e-3)

# Weight of the MI term in the encoder objective.
MI_WEIGHT = 0.4

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    print(f'Starting epoch {epoch+1}/{num_epochs}')
    total_decoder_loss = 0
    total_encoder_loss = 0
    for data in train_loader:
        # Each batch is indexed as data[:, 0] (state x0) and data[:, 1] (state x1).
        data0 = data[:, 0].to(device).float()
        data1 = data[:, 1].to(device).float()

        # Features for the decoder and estimator updates. Neither update
        # changes the encoder, so one gradient-free forward pass serves both
        # and no unused gradients are pushed back into the encoder.
        with torch.no_grad():
            v0_fixed = encoder(data0)
            v1_fixed = encoder(data1)

        # Update decoder: predict x0 from its feature (minimize MSE)
        optimizer_decoder.zero_grad()
        decoder_loss = mse_loss(decoder(v0_fixed), data0)
        decoder_loss.backward()
        optimizer_decoder.step()

        # Update decoupled estimator (maximize its MI estimate)
        decoupled_optimizer.zero_grad()
        decoupled_loss = -decoupled_estimator(v0_fixed, v1_fixed)
        decoupled_loss.backward()
        decoupled_optimizer.step()

        # Update encoder against the freshly updated decoder and estimator.
        # This backward pass also leaves gradients on the decoder and the
        # estimator; each of them is zeroed before its own next update.
        optimizer_encoder.zero_grad()
        v0 = encoder(data0)
        v1 = encoder(data1)
        data0_pred = decoder(v0)
        mi = decoupled_estimator(v0, v1)
        encoder_loss = -MI_WEIGHT * mi - mse_loss(data0_pred, data0)
        encoder_loss.backward()
        optimizer_encoder.step()

        total_decoder_loss += decoder_loss.item()
        total_encoder_loss += encoder_loss.item()

    # Print epoch statistics. The losses are per-batch averages; the MI shown
    # is the estimate from the last batch of the epoch only.
    avg_decoder_loss = total_decoder_loss / len(train_loader)
    avg_encoder_loss = total_encoder_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Decoder Loss: {avg_decoder_loss:.4f}, Encoder Loss: {avg_encoder_loss:.4f}, MI: {mi.item():.4f}')

In [4]:
# Train a fresh decoder with MSE loss to predict x1 from the trained (frozen)
# encoder's feature of x0.

# Freeze encoder parameters
for param in encoder.parameters():
    param.requires_grad = False

# Size the decoder output from the data itself rather than a hard-coded width:
# the regression target is data[:, 1], so its last dimension is what the
# decoder has to produce.
target_dim = next(iter(train_loader)).shape[-1]

# Initialize decoder (using the MLPDecoder class from the notebook)
decoder = MLPDecoder(latent_dim=encoder.latent_dim, output_dim=target_dim).to(device)

# Define optimizer for decoder
optimizer_decoder = optim.Adam(decoder.parameters(), lr=1e-3)

# Define loss function
mse_loss = nn.MSELoss()

num_epochs = 2  # Adjust as needed
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, data in enumerate(train_loader):
        # Cast to float32 as in the feature-training loop above, so the
        # inputs match the dtype of the networks' parameters.
        data0 = data[:, 0].to(device).float()
        data1 = data[:, 1].to(device).float()

        # Forward pass: the encoder is fixed, so no graph is built for it
        with torch.no_grad():
            encoded = encoder(data0)
        decoded = decoder(encoded)

        # Compute loss
        loss = mse_loss(decoded, data1)

        # Backward pass and optimize
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_decoder.step()

        total_loss += loss.item()

    # Print epoch statistics
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

# Unfreeze encoder parameters (if needed for future use)
for param in encoder.parameters():
    param.requires_grad = True

print("Decoder training completed.")

# Quick sanity check on a single batch.
# NOTE: the batch is drawn from train_loader, so the value printed as
# "Test Loss" is a training-set loss, not a held-out estimate.
encoder.eval()
decoder.eval()
with torch.no_grad():
    test_data = next(iter(train_loader))
    test_input = test_data[:, 0].to(device).float()
    test_target = test_data[:, 1].to(device).float()
    encoded = encoder(test_input)
    reconstructed = decoder(encoded)
    test_loss = mse_loss(reconstructed, test_target)
    print(f"Test Loss: {test_loss.item():.4f}")



Epoch [1/2], Average Loss: 0.9898
Epoch [2/2], Average Loss: 0.9895
Decoder training completed.
Test Loss: 0.9807


In [8]:
# Save the encoder weights under models/, tagging the file name with the
# current Unix timestamp so earlier checkpoints are not overwritten.
import os
import time

# torch.save opens the target path for writing and does not create missing
# parent directories, so make sure models/ exists first.
os.makedirs('models', exist_ok=True)
torch.save(encoder.state_dict(), f'models/ecog-min-mse-{time.time()}.pt')

# In the cells below, we train multiple representation networks using the MSE objective

In [5]:
import torch.nn as nn
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `torch`, `nn` and `device` are already set up the same way in
# earlier cells; they are repeated here presumably so this cell can run
# without the training cells above - confirm and consolidate if not needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MLPEncoder(nn.Module):
    """Fully connected encoder mapping an input vector to a latent feature.

    The hidden widths are 512 -> 1024 -> 2048 -> 1024 -> 512 -> 256, each
    followed by a ReLU, and a final linear layer projects to ``latent_dim``
    with no activation.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        latent_dim: Size of the produced feature; also kept on the instance as
            ``latent_dim``.
    """

    def __init__(self, input_dim=64, latent_dim=1):
        super().__init__()
        self.latent_dim = latent_dim

        widths = [input_dim, 512, 1024, 2048, 1024, 512, 256]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], latent_dim))

        # Modules land at indices 0..12 of the Sequential, so state_dict keys
        # are of the form ``encoder.<index>.weight`` / ``encoder.<index>.bias``
        # and saved checkpoints keep loading.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the latent feature computed from ``x``."""
        return self.encoder(x)



  encoder.load_state_dict(torch.load('models/ecog-min-mse-_with_x0-max-mi-verification.pt'))


MLPEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=64, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1024, out_features=2048, bias=True)
    (5): ReLU()
    (6): Linear(in_features=2048, out_features=1024, bias=True)
    (7): ReLU()
    (8): Linear(in_features=1024, out_features=512, bias=True)
    (9): ReLU()
    (10): Linear(in_features=512, out_features=256, bias=True)
    (11): ReLU()
    (12): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [None]:
from info_theory_experiments.trainers import train_feature_network

# Seed torch's RNG and record the same seed in the run configuration.
seed = 3
torch.manual_seed(seed)

# Settings that both critic configurations below have in common.
_shared_critic_settings = {
    "critic_output_size": 32,
    "lr": 1e-3,
    "bias": True,
    "weight_decay": 0,
}

# Run configuration handed to train_feature_network.
# NOTE(review): the meaning of the individual keys is defined by
# info_theory_experiments.trainers; "batch_size" here (1000) differs from the
# batch size train_loader was built with (128) - confirm which one is used.
config = dict(
    torch_seed=seed,
    dataset_type="ecog",
    num_atoms=64,
    batch_size=1000,
    train_mode=False,
    train_model_B=False,
    adjust_Psi=True,
    clip=5,
    feature_size=1,
    epochs=2,
    downward_critics_config={
        "hidden_sizes_v_critic": [512, 1024, 1024, 512],
        "hidden_sizes_xi_critic": [512, 512, 512],
        **_shared_critic_settings,
    },
    decoupled_critic_config={
        "hidden_sizes_encoder_1": [512, 512, 512],
        "hidden_sizes_encoder_2": [512, 512, 512],
        **_shared_critic_settings,
    },
)

project_name = "ecog-min-mse-_with_x0-max-mi-verification"

# Both returned values are discarded.
_, _ = train_feature_network(
    config=config,
    trainloader=train_loader,
    feature_network_training=encoder,
    project_name=project_name,
)