In [None]:
import torch
import wandb
import tqdm
from info_theory_experiments.models import (SupervenientFeatureNetwork,
                    CLUB,
                    DecoupledSmileMIEstimator,
                    DownwardSmileMIEstimator,
                    SkipConnectionSupervenientFeatureNetwork
                    )
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from info_theory_experiments.custom_datasets import ECoGDataset
# Release cached GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()

# Show which device was selected.
print(device)

# Fix the PyTorch RNG seed so that repeated runs start from the same state.
torch.manual_seed(0)

# In this experiment we train an emergent feature network on a dataset of ECoG data

Some code, such as `train_feature_network`, should be replaced by the more general trainer found in `trainers.py` for simplicity; this will be done during a housekeeping update of the codebase.

In [2]:
def train_feature_network(config, trainloader, feature_network):
    """Train `feature_network` against learned mutual-information critics.

    Every training step has two phases:

    1. Critic phase: the decoupled estimator (on features of the two slices of
       the batch) and one downward estimator per atom (on a feature and a single
       input channel) each take one gradient step that maximises their own MI
       estimate. The feature network is not updated here.
    2. Feature phase: the features are recomputed with gradients and, once
       `step` exceeds `config['start_updating_f_after']`, the feature network
       takes a step every `config['update_f_every_N_steps']` batches to
       minimise the sum of the downward MI estimates.

    Args:
        config: Hyperparameter dict. Keys read here: `feature_size`,
            `critic_output_size`, `decoupled_critic_hidden_sizes_1`,
            `decoupled_critic_hidden_sizes_2`, `downward_hidden_sizes_v_critic`,
            `downward_hidden_sizes_xi_critic`, `clip`, `bias`, `num_atoms`,
            `feature_lr`, `decoupled_critic_lr`, `downward_lr`, `weight_decay`,
            `epochs`, `start_updating_f_after`, `update_f_every_N_steps`, and
            optionally `add_spec_norm_decoupled` / `add_spec_norm_downward`.
        trainloader: Iterable of batches indexed as `batch[:, 0]` and
            `batch[:, 1]`; presumably shaped (batch, 2, num_atoms) -- TODO confirm.
        feature_network: Module mapping an input slice to a feature; it is
            trained in place.

    Returns:
        The same `feature_network` object after training.

    Side effects:
        Starts a Weights & Biases run, logs metrics to it, and writes the final
        weights to `models/ecog_feature_network_<run name>.pth`.
    """
    import os  # function-scope import: only needed to create the checkpoint directory

    wandb.init(project="ecog-dataset-neurips", config=config)

    num_atoms = config['num_atoms']

    def spec_norm_kwargs(flag_key):
        # Forward the spectral-norm flag only when the config defines it; configs
        # without the key keep whatever default the estimator class has.
        return {'add_spec_norm': config[flag_key]} if flag_key in config else {}

    decoupled_MI_estimator = DecoupledSmileMIEstimator(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes_1=config['decoupled_critic_hidden_sizes_1'],
        hidden_sizes_2=config['decoupled_critic_hidden_sizes_2'],
        clip=config['clip'],
        include_bias=config['bias'],
        **spec_norm_kwargs('add_spec_norm_decoupled'),
    ).to(device)
    # One independent downward estimator per atom (input channel).
    downward_MI_estimators = [
        DownwardSmileMIEstimator(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            clip=config['clip'],
            include_bias=config['bias'],
            **spec_norm_kwargs('add_spec_norm_downward'),
        ).to(device)
        for _ in range(num_atoms)
    ]

    feature_optimizer = torch.optim.Adam(
        feature_network.parameters(),
        lr=config["feature_lr"],
        weight_decay=config["weight_decay"]
    )
    decoupled_optimizer = torch.optim.Adam(
        decoupled_MI_estimator.parameters(),
        lr=config["decoupled_critic_lr"],
        weight_decay=config["weight_decay"]
    )
    downward_optims = [
        torch.optim.Adam(
            dc.parameters(),
            lr=config["downward_lr"],
            weight_decay=config["weight_decay"]
        )
        for dc in downward_MI_estimators
    ]

    wandb.watch(feature_network, log='all')
    wandb.watch(decoupled_MI_estimator, log="all")
    for dc in downward_MI_estimators:
        wandb.watch(dc, log='all')

    ##
    ## TRAIN FEATURE NETWORK
    ##

    epochs = config['epochs']

    # Global batch counter; used as the W&B step so that all metrics logged for
    # one batch share the same x-axis value.
    step = 0

    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for batch_num, batch in enumerate(trainloader):
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # Critic phase: the critics must not push gradients into the feature
            # network, so compute its outputs without building a graph at all.
            with torch.no_grad():
                v0_fixed = feature_network(x0)
                v1_fixed = feature_network(x1)

            # update decoupled critic (gradient ascent on its MI estimate)
            decoupled_optimizer.zero_grad()
            decoupled_MI = decoupled_MI_estimator(v0_fixed, v1_fixed)
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward()
            decoupled_optimizer.step()

            # update each downward critic on its own input channel
            for i, (estimator, optimizer) in enumerate(zip(downward_MI_estimators, downward_optims)):
                optimizer.zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_MI_i = estimator(v1_fixed, channel_i)
                downward_loss = -downward_MI_i
                downward_loss.backward()
                optimizer.step()
                wandb.log({
                    f"downward_MI_{i}": downward_MI_i
                }, step=step)

            # Feature phase: recompute the features with gradients enabled.
            feature_optimizer.zero_grad()
            v0 = feature_network(x0)
            v1 = feature_network(x1)

            channel_MIs = [
                estimator(v1, x0[:, i].unsqueeze(1))
                for i, estimator in enumerate(downward_MI_estimators)
            ]
            sum_downward_MI = sum(channel_MIs)

            decoupled_MI1 = decoupled_MI_estimator(v0, v1)

            # Smallest per-channel estimate, floored at zero.
            clipped_min_MIs = max(0, min(channel_MIs))

            # Psi is only logged; it is not the quantity being optimised below.
            Psi = decoupled_MI1 - sum_downward_MI + (num_atoms - 1) * clipped_min_MIs

            # NOTE an experiment: the feature network minimises the summed
            # downward MI only.
            feature_loss = sum_downward_MI

            update_features = (
                config['start_updating_f_after'] < step
                and batch_num % config['update_f_every_N_steps'] == 0
            )
            if update_features:
                # This also leaves gradients on the downward critics' parameters;
                # they are cleared by zero_grad() before the next critic update.
                feature_loss.backward()
                feature_optimizer.step()

            wandb.log({
                "decoupled_MI": decoupled_MI1,
                "sum_downward_MI": sum_downward_MI,
                "Psi": Psi,
            }, step=step)

            step += 1

    # Make sure the checkpoint directory exists so that a missing folder cannot
    # throw away a finished training run.
    os.makedirs("models", exist_ok=True)
    torch.save(feature_network.state_dict(), f"models/ecog_feature_network_{wandb.run.name}.pth")

    return feature_network



In [3]:
# Hyperparameters for the feature-network training run. The whole dict is also
# handed to W&B as the run configuration.
config = dict(
    # data / problem size
    batch_size=1000,
    num_atoms=64,
    feature_size=3,
    # estimator settings and training length
    clip=5,
    epochs=50,
    critic_output_size=32,
    # hidden-layer widths of the critics and of the feature network
    downward_hidden_sizes_v_critic=[512, 512, 512, 512],
    downward_hidden_sizes_xi_critic=[512, 512, 512],
    feature_hidden_sizes=[256, 256, 256, 256, 256],
    decoupled_critic_hidden_sizes_1=[512, 512, 512],
    decoupled_critic_hidden_sizes_2=[512, 512, 512],
    # learning rates
    feature_lr=1e-4,
    decoupled_critic_lr=1e-3,
    downward_lr=1e-3,
    bias=True,
    # schedule for feature-network updates, and regularisation
    update_f_every_N_steps=5,
    weight_decay=0,
    start_updating_f_after=300,
    add_spec_norm_downward=False,
    add_spec_norm_decoupled=False,
)


In [4]:
# Module-level dataset; `find_true_Psi` below also reads this global directly.
dataset = ECoGDataset()

# Shuffled mini-batches of `config['batch_size']` samples for feature-network training.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

In [None]:

# Build the untrained feature network on the selected device.
# NOTE(review): `feature_newtork` is a typo for "network"; the name is left as-is
# here because other cells may refer to it.
feature_newtork = SkipConnectionSupervenientFeatureNetwork(
    num_atoms=config['num_atoms'],
    feature_size=config['feature_size'],
    hidden_sizes=config['feature_hidden_sizes'],
    include_bias=config['bias']
).to(device)

# Train it; `train_feature_network` returns the network object it was given.
feature_network = train_feature_network(config, train_loader, feature_newtork)

# Estimating Psi given a frozen feature network

In [8]:

def find_true_Psi(feature_network, feature_config, run_id=None, epochs=5):
    """Estimate Psi for a fixed feature network by training fresh critics.

    New decoupled and downward MI estimators are trained from scratch while
    `feature_network` is only evaluated: its outputs are computed under
    `torch.no_grad()`, so none of its parameters receive gradients.

    Relies on the module-level `dataset` and `device`.

    Args:
        feature_network: Trained module mapping an input slice to a feature.
        feature_config: Config of the run that produced `feature_network`; only
            `feature_config['feature_size']` is read.
        run_id: Optional id passed to `wandb.init` and stored in the logged
            config as `original_run_id`.
        epochs: Number of passes over the dataset used to train the critics.

    Returns:
        The Psi estimate from the last batch processed (a tensor detached from
        any autograd graph), or `None` if the data loader yielded no batches.
    """
    # NOTE(review): `num_atoms` is fixed at 64 here rather than taken from
    # `feature_config`; it must match the number of channels in the data.
    config = {
        "batch_size": 600,
        "num_atoms": 64,
        "feature_size": feature_config['feature_size'],
        "clip": 5,
        "critic_output_size": 16,
        "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
        "downward_hidden_sizes_xi_critic": [512, 512, 512, 64],
        "decoupled_critic_hidden_sizes_1": [1028, 1028, 512],
        "decoupled_critic_hidden_sizes_2": [1028, 1028, 512],
        "decoupled_critic_lr": 1e-4,
        "downward_lr": 1e-4,
        "bias": True,
        "weight_decay": 1e-6,
        "original_run_id": run_id,
        "add_spec_norm_downward": False,
        "add_spec_norm_decoupled": False
    }

    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="Finding-true-Psi-for-f", config=config, id=run_id)

    decoupled_critic = DecoupledSmileMIEstimator(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes_1=config['decoupled_critic_hidden_sizes_1'],
        hidden_sizes_2=config['decoupled_critic_hidden_sizes_2'],
        clip=config['clip'],
        include_bias=config['bias'],
        add_spec_norm=config['add_spec_norm_decoupled']
        ).to(device)

    # One independent downward critic per atom (input channel).
    downward_critics = [
        DownwardSmileMIEstimator(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            clip=config['clip'],
            include_bias=config['bias'],
            add_spec_norm=config['add_spec_norm_downward']
            ).to(device)
        for _ in range(config['num_atoms'])
    ]

    downward_optims = [
        torch.optim.Adam(
            dc.parameters(),
            lr=config["downward_lr"],
            weight_decay=config["weight_decay"]
        )
        for dc in downward_critics
    ]

    decoupled_optimizer = torch.optim.Adam(
        decoupled_critic.parameters(),
        lr=config["decoupled_critic_lr"],
        weight_decay=config["weight_decay"]
    )

    wandb.watch(decoupled_critic, log="all")
    for dc in downward_critics:
        wandb.watch(dc, log='all')

    # Stays None only if the loader never yields a batch.
    Psi = None

    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for batch in trainloader:
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # The feature network is frozen: evaluate it without autograd so the
            # critic updates below neither backpropagate through it nor pile up
            # gradients on its parameters.
            with torch.no_grad():
                v0 = feature_network(x0)
                v1 = feature_network(x1)

            # update decoupled critic (gradient ascent on its MI estimate)
            decoupled_optimizer.zero_grad()
            decoupled_MI = decoupled_critic(v0, v1)
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward()
            decoupled_optimizer.step()

            # update each downward critic on its own input channel
            MIs = []
            for i, (critic, optimizer) in enumerate(zip(downward_critics, downward_optims)):
                optimizer.zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_MI_i = critic(v1, channel_i)
                downward_loss = -downward_MI_i
                downward_loss.backward()
                optimizer.step()
                wandb.log({
                    f"downward_MI_{i}": downward_MI_i
                })
                # Keep only the value; the graph is no longer needed.
                MIs.append(downward_MI_i.detach())

            # Smallest per-channel estimate (taken before this batch's critic
            # updates), floored at zero.
            min_MI = min(MIs)
            clipped_min_MIs = max(0, min_MI)

            # Re-evaluate the freshly updated critics; this is pure evaluation,
            # so no autograd graph is built.
            with torch.no_grad():
                sum_downward_MI = sum(
                    critic(v1, x0[:, i].unsqueeze(1))
                    for i, critic in enumerate(downward_critics)
                )
                decoupled_MI1 = decoupled_critic(v0, v1)

            Psi = decoupled_MI1 - sum_downward_MI + (config['num_atoms'] - 1) * clipped_min_MIs

            wandb.log({
                "decoupled_MI": decoupled_MI1,
                "sum_downward_MI": sum_downward_MI,
                "Psi": Psi,
            })

    return Psi



In [None]:

# Estimate Psi for the trained feature network; the training `config` is passed
# so that `find_true_Psi` can read its `feature_size`.
Psi = find_true_Psi(feature_network, feature_config=config)