In [1]:
import torch
import lovely_tensors as lt
import wandb
import tqdm
from models import SupervenientFeatureNetwork, CLUB, DecoupledSmileMIEstimator, DownwardSmileMIEstimator
# Activate lovely_tensors' patching of torch tensors (presumably a more compact
# tensor repr for notebook display -- confirm against the lovely_tensors docs).
lt.monkey_patch()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from utils import ECoGDataset
# Ask PyTorch to release cached CUDA memory from earlier runs in this kernel.
torch.cuda.empty_cache()

print(device)

# Fix the torch RNG seed so that initialisation and shuffling are repeatable.
torch.manual_seed(0)

cpu


<torch._C.Generator at 0x125f84630>

In [2]:
# Module-level dataset: train_feature_network and find_true_Psi read this
# global when they build their DataLoaders.
dataset = ECoGDataset()

In [3]:

# config =  {'batch_size': 1000, 'feature_size': 8, 'clip': 1, 'critic_output_size': 20, 'downward_hidden_sizes_v_critic': [192,192,192,192], 'downward_hidden_sizes_xi_critic': [164,164,164,164], 'feature_hidden_sizes': [576,576,576,576], 'decoupled_critic_hidden_sizes': [113,113,113,113], 'feature_lr': 0.0008501232325785077, 'decoupled_critic_lr': 0.00016297738002821739, 'downward_lr': 7.91160518491613e-05, 'update_f_every_N_steps': 1, 'weight_decay': 1.4496890117592044e-05, 'bias': True, 'num_atoms': 64}

# Train a feature network for a given config

In [4]:
def train_feature_network(config):
    """Train a feature network adversarially against SMILE MI critics.

    Each batch holds pairs ``(x0, x1)`` read from ``batch[:, 0]`` and
    ``batch[:, 1]``. On every batch:

    1. the decoupled critic is updated to maximise its MI estimate between
       the features ``v0 = f(x0)`` and ``v1 = f(x1)``;
    2. one downward critic per atom ``i`` is updated to maximise its MI
       estimate between ``v1`` and channel ``x0[:, i]``;
    3. every ``update_f_every_N_steps`` batches the feature network is
       updated to maximise ``Psi = decoupled_MI - sum_i downward_MI_i``.

    Args:
        config: dict of hyper-parameters. Required keys: ``batch_size``,
            ``num_atoms``, ``feature_size``, ``feature_hidden_sizes``,
            ``bias``, ``critic_output_size``, ``clip``,
            ``decoupled_critic_hidden_sizes``,
            ``downward_hidden_sizes_v_critic``,
            ``downward_hidden_sizes_xi_critic``, ``add_spec_norm_decoupled``,
            ``add_spec_norm_downward``, ``feature_lr``,
            ``decoupled_critic_lr``, ``downward_lr``, ``weight_decay`` and
            ``update_f_every_N_steps``. Optional key: ``epochs`` (default 3).

    Returns:
        The trained feature network. The whole module is also written to
        ``models/ecog/feature_network_<wandb run name>.pth``.

    Notes:
        Reads the notebook-level ``dataset`` and ``device`` globals and starts
        a wandb run that is left open, so callers can read ``wandb.run``.
    """
    import os  # local import: only needed to create the checkpoint directory

    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="getting-figures", config=config)

    feature_network = SupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)
    decoupled_MI_estimator = DecoupledSmileMIEstimator(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes=config['decoupled_critic_hidden_sizes'],
        clip=config['clip'],
        include_bias=config['bias'],
        add_spec_norm=config['add_spec_norm_decoupled']
        ).to(device)
    downward_MI_estimators = [
        DownwardSmileMIEstimator(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            clip=config['clip'],
            include_bias=config['bias'],
            add_spec_norm=config['add_spec_norm_downward']
            ).to(device)
        for _ in range(config['num_atoms'])
    ]

    feature_optimizer = torch.optim.Adam(
        feature_network.parameters(),
        lr=config["feature_lr"],
        weight_decay=config["weight_decay"]
    )
    decoupled_optimizer = torch.optim.Adam(
        decoupled_MI_estimator.parameters(),
        lr=config["decoupled_critic_lr"],
        weight_decay=config["weight_decay"]
    )
    downward_optims = [
        torch.optim.Adam(
            dc.parameters(),
            lr=config["downward_lr"],
            weight_decay=config["weight_decay"]
        )
        for dc in downward_MI_estimators
    ]

    # TODO: figure out why only f network is being watched, I would like to keep a closer eye on the grad n params.
    wandb.watch(feature_network, log='all')
    wandb.watch(decoupled_MI_estimator, log="all")
    for dc in downward_MI_estimators:
        wandb.watch(dc, log='all')

    ##
    ## TRAIN FEATURE NETWORK
    ##

    # Number of passes over the data; 3 unless the config overrides it.
    epochs = config.get('epochs', 3)

    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for batch_num, batch in enumerate(trainloader):
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()
            v0 = feature_network(x0)
            v1 = feature_network(x1)

            # The critics are trained on detached features: their parameter
            # gradients are identical, but backward no longer walks through
            # the feature network, so retain_graph=True is not needed.
            v0_fixed = v0.detach()
            v1_fixed = v1.detach()

            # update decoupled critic
            # (v0/v1 were already computed above, so this eval() call does not
            # change the features used in this batch.)
            feature_network.eval()
            decoupled_MI_estimator.train()
            for dc in downward_MI_estimators:
                dc.train()

            decoupled_optimizer.zero_grad()
            decoupled_MI = decoupled_MI_estimator(v0_fixed, v1_fixed)
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward()
            decoupled_optimizer.step()

            # All metrics of one batch are gathered and logged in a single
            # wandb.log call so that they share one wandb step.
            metrics = {}

            # update each downward critic
            for i, (critic, critic_optim) in enumerate(zip(downward_MI_estimators, downward_optims)):
                critic_optim.zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_MI_i = critic(v1_fixed, channel_i)
                downward_loss = -downward_MI_i
                downward_loss.backward()
                critic_optim.step()
                metrics[f"downward_MI_{i}"] = float(downward_MI_i)

            # update feature network
            feature_network.train()
            decoupled_MI_estimator.eval()
            for dc in downward_MI_estimators:
                dc.eval()

            feature_optimizer.zero_grad()
            sum_downward_MI = 0

            # Re-evaluate the freshly updated critics on the attached features
            # so that gradients can reach the feature network.
            for i, critic in enumerate(downward_MI_estimators):
                channel_i = x0[:, i].unsqueeze(1)
                sum_downward_MI = sum_downward_MI + critic(v1, channel_i)

            decoupled_MI1 = decoupled_MI_estimator(v0, v1)

            Psi = decoupled_MI1 - sum_downward_MI
            feature_loss = -Psi

            # The feature network is only stepped every N batches; the critics
            # are stepped on every batch.
            if batch_num % config['update_f_every_N_steps'] == 0:
                feature_loss.backward()
                feature_optimizer.step()

            metrics["decoupled_MI"] = float(decoupled_MI1)
            metrics["sum_downward_MI"] = float(sum_downward_MI)
            metrics["Psi"] = float(Psi)
            wandb.log(metrics)

    # Make sure the checkpoint directory exists before saving.
    os.makedirs("models/ecog", exist_ok=True)
    torch.save(feature_network, f"models/ecog/feature_network_{wandb.run.name}.pth")

    return feature_network



In [5]:
# Hyper-parameters for a single manual training run of the feature network.
config = dict(
    # data / problem size
    batch_size=1000,
    num_atoms=64,
    feature_size=3,
    # SMILE critic settings
    clip=5,
    critic_output_size=32,
    # hidden layer widths (three layers each)
    downward_hidden_sizes_v_critic=[512] * 3,
    downward_hidden_sizes_xi_critic=[512] * 3,
    feature_hidden_sizes=[256] * 3,
    decoupled_critic_hidden_sizes=[512] * 3,
    # learning rates
    feature_lr=1e-5,
    decoupled_critic_lr=1e-4,
    downward_lr=1e-4,
    # misc
    bias=True,
    update_f_every_N_steps=5,
    weight_decay=1e-4,
    add_spec_norm_downward=True,
    add_spec_norm_decoupled=True,
)


In [6]:
# Train with the `config` dict and keep the returned network in the notebook
# namespace for later cells.
feature_network = train_feature_network(config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mdmcsharry[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training:   0%|          | 0/3 [10:04<?, ?it/s]


KeyboardInterrupt: 

# Finding the MI between different channels

Estimate the mutual information $I(x_t^{i} ; x_t^{j})$ between two channels $i$ and $j$ at the same time step $t$.

In [None]:

# Train a single SMILE estimator of the MI between two channels taken from the
# same time step (both are read from x0).
# NOTE(review): this rebinds the notebook-level `dataset`, `trainloader` and
# `epochs` names; functions that read the global `dataset` will see this
# object once the cell has run -- confirm that is intended.
dataset = torch.load("data/ecog_data_pairs.pth")
trainloader = torch.utils.data.DataLoader(dataset, batch_size=2000, shuffle=False)

wandb.init(project="Interchannel MI")

channel_MI_estimator = DownwardSmileMIEstimator(
    feature_size=1, # replacing the feature with a channel which is dim 1
    critic_output_size=8,
    hidden_sizes_v_critic=[64, 512, 1028, 512],
    hidden_sizes_xi_critic=[64, 512, 1028, 512],
    clip=5,
    include_bias=True
).to(device)

channel_MI_optim = torch.optim.Adam(channel_MI_estimator.parameters(), lr=1e-5, weight_decay=1e-6)

epochs = 20

# Indices of the two channels whose mutual information is estimated.
channels = (1,2)

for _ in tqdm.tqdm(range(epochs), desc='Training a SMILE estimator for interchannel MI'):
    for batch in trainloader:
        x0 = batch[:, 0].to(device).float()

        channel_i = x0[:, channels[0]].unsqueeze(1)
        channel_j = x0[:, channels[1]].unsqueeze(1)

        # Clear the previous step's gradients; without this they accumulate
        # over every batch and each update uses a stale running sum.
        channel_MI_optim.zero_grad()
        MI = channel_MI_estimator(channel_j, channel_i)
        loss = -MI
        loss.backward()
        channel_MI_optim.step()
        wandb.log({
            "Inter-channel MI": float(MI)
        })


# Estimating Psi given a frozen feature network

In [None]:

def find_true_Psi(feature_network, feature_config, run_id=None):
    """Estimate Psi for a fixed feature network by training fresh critics.

    New decoupled and downward SMILE critics are trained against the outputs
    of ``feature_network``, which is never updated here. After each batch the
    critics are re-evaluated and

        Psi = decoupled_MI - sum_i downward_MI_i
              + (num_atoms - 1) * max(0, min_i downward_MI_i)

    is logged to wandb. The minimum is taken over the per-atom estimates
    computed just before each critic's optimiser step.

    Args:
        feature_network: module mapping a batch of observations to features.
        feature_config: config the feature network was built with; only
            ``feature_config['feature_size']`` is read.
        run_id: optional wandb run id, passed to ``wandb.init(id=...)`` and
            stored in the logged config as ``original_run_id``.

    Returns:
        The Psi estimate from the last batch of the last epoch, as a tensor
        that is detached from the autograd graph.

    Notes:
        Reads the notebook-level ``dataset`` and ``device`` globals.
    """
    config = {
        "batch_size": 600,
        "num_atoms": 64,
        "feature_size": feature_config['feature_size'],
        "clip": 5,
        "critic_output_size": 16,
        "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
        "downward_hidden_sizes_xi_critic": [512, 512, 512, 64],
        "feature_hidden_sizes": [1028, 1028, 256],
        "decoupled_critis_hidden_sizes": [512, 512, 128],
        "decoupled_critic_lr": 1e-4,
        "downward_lr": 1e-4,
        "bias": True,
        "weight_decay": 1e-6,
        "original_run_id": run_id,
        "add_spec_norm_downward": False,
        "add_spec_norm_decoupled": False
    }

    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="Finding-true-Psi-for-f", config=config, id=run_id)

    decoupled_critic = DecoupledSmileMIEstimator(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes=config['decoupled_critis_hidden_sizes'],
        clip=config['clip'],
        include_bias=config['bias'],
        add_spec_norm=config['add_spec_norm_decoupled']
        ).to(device)

    downward_critics = [
        DownwardSmileMIEstimator(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            clip=config['clip'],
            include_bias=config['bias'],
            add_spec_norm=config['add_spec_norm_downward']
            ).to(device)
        for _ in range(config['num_atoms'])
    ]

    downward_optims = [
        torch.optim.Adam(
            dc.parameters(),
            lr=config["downward_lr"],
            weight_decay=config["weight_decay"]
        )
        for dc in downward_critics
    ]

    decoupled_optimizer = torch.optim.Adam(
        decoupled_critic.parameters(),
        lr=config["decoupled_critic_lr"],
        weight_decay=config["weight_decay"]
    )

    wandb.watch(decoupled_critic, log="all")
    for dc in downward_critics:
        wandb.watch(dc, log='all')

    epochs = 5

    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for batch in trainloader:
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # The feature network is not trained here, so its outputs are
            # computed without autograd. Critic backward passes then stop at
            # the features and retain_graph=True is not needed.
            with torch.no_grad():
                v0 = feature_network(x0)
                v1 = feature_network(x1)

            # update decoupled critic
            decoupled_optimizer.zero_grad()
            decoupled_MI = decoupled_critic(v0, v1)
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward()
            decoupled_optimizer.step()

            # All metrics of one batch are logged in a single wandb.log call
            # so that they share one wandb step.
            metrics = {}

            # update each downward critic
            MIs = []

            for i, (critic, critic_optim) in enumerate(zip(downward_critics, downward_optims)):
                critic_optim.zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_MI_i = critic(v1, channel_i)
                downward_loss = -downward_MI_i
                downward_loss.backward()
                critic_optim.step()
                metrics[f"downward_MI_{i}"] = float(downward_MI_i)
                MIs.append(downward_MI_i.detach())

            # Smallest per-atom estimate, clipped at zero, for the correction term.
            min_MI = min(MIs)
            clipped_min_MIs = max(0, min_MI)

            # Evaluate Psi with the updated critics. Nothing is optimised on
            # this quantity, so no graph is built.
            with torch.no_grad():
                sum_downward_MI = 0

                for i, critic in enumerate(downward_critics):
                    channel_i = x0[:, i].unsqueeze(1)
                    sum_downward_MI = sum_downward_MI + critic(v1, channel_i)

                decoupled_MI1 = decoupled_critic(v0, v1)

                Psi = decoupled_MI1 - sum_downward_MI + (config['num_atoms'] - 1) * clipped_min_MIs

            metrics["decoupled_MI"] = float(decoupled_MI1)
            metrics["sum_downward_MI"] = float(sum_downward_MI)
            metrics["Psi"] = float(Psi)
            wandb.log(metrics)

    return Psi



In [None]:
# load feature network
# Rebuild the feature-network architecture from the notebook-level `config`;
# it must match the architecture that produced the saved weights.
feature_network = SupervenientFeatureNetwork(
    num_atoms=config['num_atoms'],
    feature_size=config['feature_size'],
    hidden_sizes=config['feature_hidden_sizes'],
    include_bias=config['bias']
    ).to(device)


# NOTE(review): absolute, user-specific path -- this cell only runs on the
# original author's machine. Consider a path relative to the repository.
model_path = "/Users/davidmcsharry/dev/imperial/info-theory-experiments/emergent_feature_network_omg.pth"


# The file is expected to hold a state_dict (it is passed to load_state_dict);
# tensors are mapped to CPU on load and copied into the module's parameters.
feature_network.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


# Estimate Psi for the loaded network; no run_id is given, so wandb starts a new run id.
Psi = find_true_Psi(feature_network, feature_config=config)

In [None]:

import optuna
def objective(trial):
    """Optuna objective: train a feature network and score it by its Psi.

    Samples a hyper-parameter config from ``trial``, trains a feature network
    with ``train_feature_network`` and then estimates Psi for it with
    ``find_true_Psi``, passing the id of the current wandb run as ``run_id``.

    Args:
        trial: the ``optuna`` trial used to sample hyper-parameters.

    Returns:
        The Psi estimate as a Python float (the study maximises it).
    """
    config = {
        "batch_size": trial.suggest_categorical("batch_size", [500, 1000, 2000]),
        "num_atoms": 64,
        "feature_size": trial.suggest_categorical("feature_size", [2, 4, 8, 16]),
        "clip": trial.suggest_int("clip", 1, 10),
        "critic_output_size": trial.suggest_int("critic_output_size", 8, 64, log=True),
        # Each list repeats one suggested width: suggesting the same parameter
        # name again inside a trial returns the same value.
        "downward_hidden_sizes_v_critic": [trial.suggest_int("downward_hidden_size_v", 64, 512, log=True) for _ in range(3)],
        "downward_hidden_sizes_xi_critic": [trial.suggest_int("downward_hidden_size_xi", 32, 256, log=True) for _ in range(3)],
        "feature_hidden_sizes": [trial.suggest_int("feature_hidden_size", 256, 1024, log=True) for _ in range(4)],
        # Key name must match what train_feature_network reads
        # ('decoupled_critic_hidden_sizes').
        "decoupled_critic_hidden_sizes": [trial.suggest_int("decoupled_critic_hidden_size", 64, 512, log=True) for _ in range(3)],
        "feature_lr": trial.suggest_float("feature_lr", 1e-6, 1e-3, log=True),
        "decoupled_critic_lr": trial.suggest_float("decoupled_critic_lr", 1e-5, 1e-3, log=True),
        "downward_lr": trial.suggest_float("downward_lr", 1e-5, 1e-3, log=True),
        "bias": True,
        "update_f_every_N_steps": trial.suggest_int("update_f_every_N_steps", 1, 20, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-9, 1e-4, log=True),
        # train_feature_network requires both keys.
        # NOTE(review): True mirrors the manual training config in this
        # notebook -- confirm this is the setting wanted for the search.
        "add_spec_norm_downward": True,
        "add_spec_norm_decoupled": True,
    }

    feature_network = train_feature_network(config)
    # Keyword arguments avoid passing the run id in the feature_config slot.
    Psi = find_true_Psi(feature_network, feature_config=config, run_id=wandb.run.id)

    # Optuna expects a plain number as the objective value.
    return float(Psi)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)  # Adjust the number of trials as needed

# Save the params of the top 5 runs, then save them again with the Psi value added.

import json
import os

os.makedirs("optuna_results", exist_ok=True)

# For a single-objective study `study.best_trials` only holds the trial(s)
# with the best value, so the ranking is built from all completed trials.
completed_trials = [
    t for t in study.trials
    if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None
]
top_5 = sorted(completed_trials, key=lambda t: t.value, reverse=True)[:5]

# One params file per rank; fewer files are written if fewer trials completed.
rank_names = ["best", "second_best", "third_best", "fourth_best", "fifth_best"]
for rank_name, trial in zip(rank_names, top_5):
    with open(f"optuna_results/{rank_name}_params.json", "w") as f:
        json.dump(trial.params, f)

# Same params with the achieved Psi added, built as a copy so that the trial
# object is not mutated.
for i, trial in enumerate(top_5):
    with open(f"optuna_results/trial_{i}.json", "w") as f:
        json.dump({**trial.params, 'Psi': trial.value}, f)


# Hacky hyperparam search