In [1]:
import torch
import torch.nn as nn
import lovely_tensors as lt
from einops import reduce, rearrange, repeat
# from npeet.entropy_estimators import entropy, mi
import matplotlib.pyplot as plt
import wandb
import utils
import importlib
import os
from utils import prepare_ecog_dataset, prepare_batch, estimate_MI_smile
from smile_estimator import estimate_mutual_information
import tqdm
from torch.nn.utils import spectral_norm
# importlib.reload(utils)



lt.monkey_patch()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [2]:

class SupervenientFeatureNetwork(nn.Module):
    def __init__(
            self,
            num_atoms: int,
            feature_size: int,
            hidden_sizes: list,
            include_bias: bool = True
        ):
        super(SupervenientFeatureNetwork, self).__init__()
        layers = []
        input_size = num_atoms
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(input_size, hidden_size, bias=include_bias))
            layers.append(nn.ReLU())
            input_size = hidden_size
        layers.append(nn.Linear(input_size, feature_size, bias=include_bias))
        self.f = nn.Sequential(*layers)

    def forward(self, x):
        return self.f(x)


class DecoupledCritic(nn.Module):
    def __init__(
            self,
            feature_size: int,
            critic_output_size: int,
            hidden_sizes: list,
            include_bias: bool = True
        ):
        super(DecoupledCritic, self).__init__()

        layers = []
        input_size = feature_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(input_size, hidden_size, bias=include_bias))
            layers.append(nn.ReLU())
            input_size = hidden_size
        layers.append(nn.Linear(input_size, critic_output_size, bias=include_bias))

        self.v_encoder = nn.Sequential(*layers)
        self.W = nn.Linear(critic_output_size, critic_output_size, bias=False)

    def forward(self, v0, v1):
        v0_encoded = self.v_encoder(v0)
        v1_encoded = self.v_encoder(v1)
        v1_encoded_transformed = self.W(v1_encoded)

        scores = torch.matmul(v0_encoded, v1_encoded_transformed.t())
        return scores
    

class DownwardCritic(nn.Module):
    def __init__(
            self,
            feature_size: int,
            critic_output_size: int,
            hidden_sizes_v_critic: list,
            hidden_sizes_xi_critic: list,
            include_bias: bool = True
        ):
        super(DownwardCritic, self).__init__()

        v_encoder_layers = []
        input_size = feature_size
        for hidden_size in hidden_sizes_v_critic:
            # TODO: Understand what the fuck spectral norm actually is
            v_encoder_layers.append(spectral_norm(nn.Linear(input_size, hidden_size, bias=include_bias)))
            v_encoder_layers.append(nn.ReLU())
            input_size = hidden_size
        v_encoder_layers.append(spectral_norm(nn.Linear(input_size, critic_output_size, bias=include_bias)))
        self.v_encoder = nn.Sequential(*v_encoder_layers)

        atom_encoder_layers = []
        input_size = 1
        for hidden_size in hidden_sizes_xi_critic:
            atom_encoder_layers.append(spectral_norm(nn.Linear(input_size, hidden_size, bias=include_bias)))
            atom_encoder_layers.append(nn.ReLU())
            input_size = hidden_size
        atom_encoder_layers.append(spectral_norm(nn.Linear(input_size, critic_output_size, bias=include_bias)))
        self.atom_encoder = nn.Sequential(*atom_encoder_layers)
    
    def forward(self, v1, x0i):
        v1_encoded = self.v_encoder(v1)
        x0i_encoded = self.atom_encoder(x0i)

        scores = torch.matmul(v1_encoded, x0i_encoded.t())
        return scores
    


class NoSpectralDownwardCritic(nn.Module):
    def __init__(
            self,
            feature_size: int,
            critic_output_size: int,
            hidden_sizes_v_critic: list,
            hidden_sizes_xi_critic: list,
            include_bias: bool = True
        ):
        super(NoSpectralDownwardCritic, self).__init__()

        v_encoder_layers = []
        input_size = feature_size
        for hidden_size in hidden_sizes_v_critic:
            v_encoder_layers.append(nn.Linear(input_size, hidden_size, bias=include_bias))
            v_encoder_layers.append(nn.ReLU())
            input_size = hidden_size
        v_encoder_layers.append(nn.Linear(input_size, critic_output_size, bias=include_bias))
        self.v_encoder = nn.Sequential(*v_encoder_layers)

        atom_encoder_layers = []
        input_size = 1
        for hidden_size in hidden_sizes_xi_critic:
            atom_encoder_layers.append(nn.Linear(input_size, hidden_size, bias=include_bias))
            atom_encoder_layers.append(nn.ReLU())
            input_size = hidden_size
        atom_encoder_layers.append(nn.Linear(input_size, critic_output_size, bias=include_bias)) 
        self.atom_encoder = nn.Sequential(*atom_encoder_layers)
    
    def forward(self, v1, x0i):
        v1_encoded = self.v_encoder(v1)
        x0i_encoded = self.atom_encoder(x0i)

        scores = torch.matmul(v1_encoded, x0i_encoded.t())
        return scores
    


In [3]:
config = {
    "batch_size": 1000,
    "num_atoms": 64,
    "feature_size": 1,
    "clip": 5,
    "critic_output_size": 16,
    "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
    "downward_hidden_sizes_xi_critic": [512, 512, 512, 64],
    "feature_hidden_sizes": [1028, 1028, 256],
    "decoupled_critis_hidden_sizes": [512, 512, 128],
    "feature_lr": 1e-5,
    "decoupled_critic_lr": 1e-4,
    "downward_lr": 1e-4,    
    "bias": True,
    "update_f_every_N_steps": 2,
    "weight_decay": 1e-6,
}


# Train a feature network for a given config

In [19]:
def train_feature_network(config):

    dataset = torch.load("data/ecog_data_pairs.pth")
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="getting-figures", config=config)

    feature_network = SupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)
    decoupled_critic = DecoupledCritic(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes=config['decoupled_critis_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)
    downward_critics = [
        DownwardCritic(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            include_bias=config['bias']
            ).to(device) 
        for _ in range(config['num_atoms'])
    ]


    downward_optims = [
        torch.optim.Adam(
            dc.parameters(),
            lr=config["downward_lr"],
            weight_decay=config["weight_decay"]
        ) 
        for dc in downward_critics
    ]
    feature_optimizer = torch.optim.Adam(
        feature_network.parameters(),
        lr=config["feature_lr"],
        weight_decay=config["weight_decay"]
    )
    decoupled_optimizer = torch.optim.Adam(
        decoupled_critic.parameters(),
        lr=config["decoupled_critic_lr"],
        weight_decay=config["weight_decay"]
    )


    # TODO: figure out why only f network is being watched, I would like to keep a closer eye on the grad n params.
    # TODO: Look at how GANs are trained with pytorch and make sure I'm not doing anything unreasonable.
    # Eg, https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py 
    # ^ this does not require retain_graph=True, so maybe this can be optomized somehow
    wandb.watch(feature_network, log='all')
    wandb.watch(decoupled_critic, log="all")
    for dc in downward_critics:
        wandb.watch(dc, log='all')

    ##
    ## TRAIN FEATURE NETWORK
    ##

    epochs = 5

    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for batch_num, batch in enumerate(trainloader):
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # update decoupled critic

            v0 = feature_network(x0)
            v1 = feature_network(x1) 

            decoupled_optimizer.zero_grad()
            decoupled_scores = decoupled_critic(v0, v1)
            decoupled_MI = estimate_mutual_information('smile', decoupled_scores, clip=config['clip'])
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward(retain_graph=True)
            decoupled_optimizer.step()


            # update each downward critic 

            for i in range(config['num_atoms']):
                downward_optims[i].zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores = downward_critics[i](v1, channel_i)
                downward_MI_i = estimate_mutual_information('smile', downward_scores, clip=config['clip'])
                # add spectral norm to the loss
                downward_loss = - downward_MI_i
                downward_loss.backward(retain_graph=True)
                downward_optims[i].step()
                wandb.log({
                    f"downward_MI_{i}": downward_MI_i   
                })

            # update feature network   

            feature_optimizer.zero_grad()
            sum_downward_MI = 0

            for i in range(config['num_atoms']):
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores1 = downward_critics[i](v1, channel_i)
                sum_downward_MI += estimate_mutual_information('smile', downward_scores1, clip=config['clip'])

            decoupled_scores1 = decoupled_critic(v0, v1)
            decoupled_MI1 = estimate_mutual_information('smile', decoupled_scores1, clip=config['clip'])

            Psi = decoupled_MI1 - sum_downward_MI
            feature_loss = -Psi

            if batch_num % config['update_f_every_N_steps'] == 0:
                feature_loss.backward()
                feature_optimizer.step()

            wandb.log({
                "decoupled_MI": decoupled_MI1,
                "sum_downward_MI": sum_downward_MI,
                "Psi": Psi,
            })

    # free the memory associated with anything that is not the feature network
    for dc in downward_critics:
        del dc
    del decoupled_critic
    del decoupled_optimizer
    del downward_optims
    del trainloader
    del dataset
        
    return feature_network



# Finding the MI between different channels

I( xt_i ; xt_j)

In [None]:

dataset = torch.load("data/ecog_data_pairs.pth")
trainloader = torch.utils.data.DataLoader(dataset, batch_size=2000, shuffle=False)

wandb.init(project="Interchannel MI")

channel_critic = NoSpectralDownwardCritic(
    feature_size=1, # replacing the feature with a channel which is dim 1
    critic_output_size=8,
    hidden_sizes_v_critic=[64, 512, 1028, 512],
    hidden_sizes_xi_critic=[64, 512, 1028, 512],
    include_bias=True
).to(device) 

channel_critic_optim = torch.optim.Adam(channel_critic.parameters(), lr=1e-5, weight_decay=1e-6)

epochs = 20

channels = (1,2)

for _ in tqdm.tqdm(range(epochs), desc='Training a SMILE estimator for interchannel MI'):
    for batch_num, batch in enumerate(trainloader):
        x0 = batch[:, 0].to(device).float()
        x1 = batch[:, 1].to(device).float()

        channel_i = x0[:, channels[0]].unsqueeze(1)
        channel_j = x0[:, channels[1]].unsqueeze(1)

        scores = channel_critic(channel_j, channel_i)
        MI = estimate_mutual_information('smile', scores, clip=1)
        loss = -MI 
        loss.backward()
        channel_critic_optim.step()
        wandb.log({
            "Inter-channel MI": MI
        })


# Estimating Psi given a frozen feature network

In [18]:

# model_path = "/vol/bitbucket/dm2223/info-theory-experiments/promising_hmmm_f.pth"

# feature_network = SupervenientFeatureNetwork(
#     num_atoms=config['num_atoms'],
#     feature_size=config['feature_size'],
#     hidden_sizes=config['feature_hidden_sizes'],
#     include_bias=config['bias']
#     ).to(device)
# feature_network.load_state_dict(torch.load(model_path))
# feature_network.eval()



def find_true_Psi(feature_network, run_id, feature_config):

    print(type(feature_network))

    config = {
        "batch_size": 1000,
        "num_atoms": 64,
        "feature_size": feature_config['feature_size'],
        "clip": 5,
        "critic_output_size": 16,
        "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
        "downward_hidden_sizes_xi_critic": [512, 512, 512, 64],
        "feature_hidden_sizes": [1028, 1028, 256],
        "decoupled_critis_hidden_sizes": [512, 512, 128],
        "decoupled_critic_lr": 1e-4,
        "downward_lr": 1e-4,
        "bias": True,
        "weight_decay": 1e-6,
        "original_run_id": run_id
    }

    dataset = torch.load("data/ecog_data_pairs.pth")
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="Finding-true-Psi-for-f", config=config, id=run_id)

    decoupled_critic = DecoupledCritic(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes=config['decoupled_critis_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)

    downward_critics = [
        NoSpectralDownwardCritic(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            include_bias=config['bias']
            ).to(device) 
        for _ in range(config['num_atoms'])
    ]

    downward_optims = [
        torch.optim.Adam(
            dc.parameters(),
            lr=config["downward_lr"],
            weight_decay=config["weight_decay"]
        ) 
        for dc in downward_critics
    ]

    decoupled_optimizer = torch.optim.Adam(
        decoupled_critic.parameters(),
        lr=config["decoupled_critic_lr"],
        weight_decay=config["weight_decay"]
    )

    # TODO: figure out why only f network is being watched, I would like to keep a closer eye on the grad n params.
    # TODO: Look at how GANs are trained with pytorch and make sure I'm not doing anything unreasonable.
    # Eg, https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py 
    # ^ this does not require retain_graph=True, so maybe this can be optomized somehow
    wandb.watch(decoupled_critic, log="all")
    for dc in downward_critics:
        wandb.watch(dc, log='all')

    epochs = 5

    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for _, batch in enumerate(trainloader):
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # update decoupled critic

            v0 = feature_network(x0)
            v1 = feature_network(x1) 

            decoupled_optimizer.zero_grad()
            decoupled_scores = decoupled_critic(v0, v1)
            decoupled_MI = estimate_mutual_information('smile', decoupled_scores, clip=config['clip'])
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward(retain_graph=True)
            decoupled_optimizer.step()


            # update each downward critic 

            for i in range(config['num_atoms']):
                downward_optims[i].zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores = downward_critics[i](v1, channel_i)
                downward_MI_i = estimate_mutual_information('smile', downward_scores, clip=config['clip'])
                # add spectral norm to the loss
                downward_loss = - downward_MI_i
                downward_loss.backward(retain_graph=True)
                downward_optims[i].step()
                wandb.log({
                    f"downward_MI_{i}": downward_MI_i   
                })

            # update feature network   

            sum_downward_MI = 0

            for i in range(config['num_atoms']):
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores1 = downward_critics[i](v1, channel_i)
                sum_downward_MI += estimate_mutual_information('smile', downward_scores1, clip=config['clip'])

            decoupled_scores1 = decoupled_critic(v0, v1)
            decoupled_MI1 = estimate_mutual_information('smile', decoupled_scores1, clip=config['clip'])

            Psi = decoupled_MI1 - sum_downward_MI

            wandb.log({
                "decoupled_MI": decoupled_MI1,
                "sum_downward_MI": sum_downward_MI,
                "Psi": Psi,
            })
        
    return Psi



In [17]:

import optuna
def objective(trial):

    config = {
        "batch_size": trial.suggest_categorical("batch_size", [500, 1000, 2000]),
        "num_atoms": 64,
        "feature_size": trial.suggest_categorical("feature_size", [2, 4, 8, 16]),
        "clip": trial.suggest_int("clip", 1, 10),
        "critic_output_size": trial.suggest_int("critic_output_size", 8, 64, log=True),
        "downward_hidden_sizes_v_critic": [trial.suggest_int("downward_hidden_size_v", 64, 512, log=True) for _ in range(3)],
        "downward_hidden_sizes_xi_critic": [trial.suggest_int("downward_hidden_size_xi", 32, 256, log=True) for _ in range(3)],
        "feature_hidden_sizes": [trial.suggest_int("feature_hidden_size", 256, 1024, log=True) for _ in range(4)],
        "decoupled_critis_hidden_sizes": [trial.suggest_int("decoupled_critic_hidden_size", 64, 512, log=True) for _ in range(3)],
        "feature_lr": trial.suggest_float("feature_lr", 1e-6, 1e-3, log=True),
        "decoupled_critic_lr": trial.suggest_float("decoupled_critic_lr", 1e-5, 1e-3, log=True),
        "downward_lr": trial.suggest_float("downward_lr", 1e-5, 1e-3, log=True),
        "bias": True,
        "update_f_every_N_steps": trial.suggest_int("update_f_every_N_steps", 1, 20, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-9, 1e-4, log=True),
    }

    feature_network = train_feature_network(config)
    Psi = find_true_Psi(feature_network, wandb.run.id, feature_config=config)

    return Psi

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)  # Adjust the number of trials as needed

# save the best_params to root
best_params = study.best_params

import json
with open("best_params.json", "w") as f:
    json.dump(best_params, f)


[I 2024-04-15 02:40:40,595] A new study created in memory with name: no-name-09a86fe3-ffd4-45df-8ceb-77f5b43d8ebc


0,1
Psi,▁
decoupled_MI,▁
downward_MI_0,▁
downward_MI_1,▁
downward_MI_10,▁
downward_MI_11,▁
downward_MI_12,▁
downward_MI_13,▁
downward_MI_14,▁
downward_MI_15,▁

0,1
Psi,-0.0016
decoupled_MI,0.0
downward_MI_0,-1e-05
downward_MI_1,-2e-05
downward_MI_10,-1e-05
downward_MI_11,-0.0
downward_MI_12,-1e-05
downward_MI_13,-1e-05
downward_MI_14,0.0
downward_MI_15,1e-05


Training:   0%|          | 0/1 [00:00<?, ?it/s]


<class '__main__.SupervenientFeatureNetwork'>


0,1
Psi,▁
decoupled_MI,▁
downward_MI_0,▁
downward_MI_1,▁
downward_MI_10,▁
downward_MI_11,▁
downward_MI_12,▁
downward_MI_13,▁
downward_MI_14,▁
downward_MI_15,▁

0,1
Psi,0.00028
decoupled_MI,-0.0
downward_MI_0,-2e-05
downward_MI_1,-3e-05
downward_MI_10,-0.0
downward_MI_11,1e-05
downward_MI_12,1e-05
downward_MI_13,-3e-05
downward_MI_14,-7e-05
downward_MI_15,-5e-05


Training:   0%|          | 0/1 [00:00<?, ?it/s]
[I 2024-04-15 02:40:54,699] Trial 0 finished with value: -0.002632617950439453 and parameters: {'batch_size': 2000, 'feature_size': 16, 'clip': 7, 'critic_output_size': 30, 'downward_hidden_size_v': 86, 'downward_hidden_size_xi': 32, 'feature_hidden_size': 373, 'decoupled_critic_hidden_size': 463, 'feature_lr': 6.150712661095301e-06, 'decoupled_critic_lr': 1.7664621484298773e-05, 'downward_lr': 9.03235645544033e-05, 'update_f_every_N_steps': 5, 'weight_decay': 2.2904444521284543e-06}. Best is trial 0 with value: -0.002632617950439453.


0,1
Psi,▁
decoupled_MI,▁
downward_MI_0,▁
downward_MI_1,▁
downward_MI_10,▁
downward_MI_11,▁
downward_MI_12,▁
downward_MI_13,▁
downward_MI_14,▁
downward_MI_15,▁

0,1
Psi,-0.00263
decoupled_MI,0.0
downward_MI_0,3e-05
downward_MI_1,-2e-05
downward_MI_10,-1e-05
downward_MI_11,2e-05
downward_MI_12,-0.0
downward_MI_13,-2e-05
downward_MI_14,-0.0
downward_MI_15,-0.0


Training:   0%|          | 0/1 [00:00<?, ?it/s]


<class '__main__.SupervenientFeatureNetwork'>


0,1
Psi,▁
decoupled_MI,▁
downward_MI_0,▁
downward_MI_1,▁
downward_MI_10,▁
downward_MI_11,▁
downward_MI_12,▁
downward_MI_13,▁
downward_MI_14,▁
downward_MI_15,▁

0,1
Psi,-0.00046
decoupled_MI,-0.0
downward_MI_0,-0.0
downward_MI_1,0.0
downward_MI_10,-2e-05
downward_MI_11,1e-05
downward_MI_12,-1e-05
downward_MI_13,-2e-05
downward_MI_14,0.0
downward_MI_15,0.0


Training:   0%|          | 0/1 [00:00<?, ?it/s]
[I 2024-04-15 02:41:08,486] Trial 1 finished with value: -0.0007123947143554688 and parameters: {'batch_size': 2000, 'feature_size': 16, 'clip': 2, 'critic_output_size': 10, 'downward_hidden_size_v': 476, 'downward_hidden_size_xi': 75, 'feature_hidden_size': 958, 'decoupled_critic_hidden_size': 175, 'feature_lr': 1.723749852560971e-05, 'decoupled_critic_lr': 1.6145668368018935e-05, 'downward_lr': 6.425805952319134e-05, 'update_f_every_N_steps': 5, 'weight_decay': 1.5989586833270346e-06}. Best is trial 1 with value: -0.0007123947143554688.


# Hacky hyperparam search

In [None]:

##
## Hyperparameter search
##


configs = [
# {
#     "batch_size": 1000,
#     "num_atoms": 64,
#     "feature_size": 8,
#     "clip": 5,
#     "critic_output_size": 16,
#     "downward_hidden_sizes_v_critic": [512, 512, 256],
#     "downward_hidden_sizes_xi_critic": [256, 256, 64],
#     "feature_hidden_sizes": [1028, 1028, 1028, 256],
#     "decoupled_critis_hidden_sizes": [512, 512, 128],
#     "feature_lr": 1e-5,
#     "decoupled_critic_lr": 1e-4,
#     "downward_lr": 1e-4,
#     "bias": True,
#     "update_f_every_N_steps": 4,
# },
# {
#     "batch_size": 2000,
#     "num_atoms": 64,
#     "feature_size": 1,
#     "clip": 5,
#     "critic_output_size": 16,
#     "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
#     "downward_hidden_sizes_xi_critic": [512, 1028, 512, 64],
#     "feature_hidden_sizes": [1028, 1028, 256],
#     "decoupled_critis_hidden_sizes": [512, 512, 128],
#     "feature_lr": 1e-5,
#     "decoupled_critic_lr": 1e-4,
#     "downward_lr": 1e-4,
#     "bias": True,
#     "update_f_every_N_steps": 1,
# },
# {
#     "batch_size": 1000,
#     "num_atoms": 64,
#     "feature_size": 1,
#     "clip": 5,
#     "critic_output_size": 32,
#     "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
#     "downward_hidden_sizes_xi_critic": [512, 1028, 512, 64],
#     "feature_hidden_sizes": [1028, 1028, 256],
#     "decoupled_critis_hidden_sizes": [512, 512, 128],
#     "feature_lr": 1e-5,
#     "decoupled_critic_lr": 1e-4,
#     "downward_lr": 1e-4,
#     "bias": True,
#     "update_f_every_N_steps": 10,
# },
# {
#     "batch_size": 1000,
#     "num_atoms": 64,
#     "feature_size": 4,
#     "clip": 2,
#     "critic_output_size": 6,
#     "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
#     "downward_hidden_sizes_xi_critic": [512, 1028, 512, 64],
#     "feature_hidden_sizes": [256, 256, 256],
#     "decoupled_critis_hidden_sizes": [512, 512, 128],
#     "feature_lr": 1e-5,
#     "decoupled_critic_lr": 1e-4,
#     "downward_lr": 1e-4,
#     "bias": True,
#     "update_f_every_N_steps": 10,
# },
{
    "batch_size": 1000,
    "num_atoms": 64,
    "feature_size": 1,
    "clip": 5,
    "critic_output_size": 12,
    "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
    "downward_hidden_sizes_xi_critic": [512, 1028, 512, 64],
    "feature_hidden_sizes": [1028, 1028, 256],
    "decoupled_critis_hidden_sizes": [512, 512, 128],
    "feature_lr": 1e-5,
    "decoupled_critic_lr": 1e-3,
    "downward_lr": 1e-3,
    "bias": False,
    "update_f_every_N_steps": 1,
},
{
    "batch_size": 1000,
    "num_atoms": 64,
    "feature_size": 1,
    "clip": 1,
    "critic_output_size": 8,
    "downward_hidden_sizes_v_critic": [256, 256, 64],
    "downward_hidden_sizes_xi_critic": [256, 256, 64],
    "feature_hidden_sizes": [256, 256, 64],
    "decoupled_critis_hidden_sizes": [512, 512, 128],
    "feature_lr": 1e-5,
    "decoupled_critic_lr": 1e-5,
    "downward_lr": 1e-5,
    "bias": True,
    "update_f_every_N_steps": 10,
},
{
    "batch_size": 1000,
    "num_atoms": 64,
    "feature_size": 1,
    "clip": 1,
    "critic_output_size": 8,
    "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
    "downward_hidden_sizes_xi_critic": [1028, 1028, 512, 64],
    "feature_hidden_sizes": [256, 256, 64],
    "decoupled_critis_hidden_sizes": [512, 512, 128],
    "feature_lr": 1e-5,
    "decoupled_critic_lr": 1e-4,
    "downward_lr": 1e-3,
    "bias": True,
    "update_f_every_N_steps": 2,
},
{
    "batch_size": 1000,
    "num_atoms": 64,
    "feature_size": 16,
    "clip": 4,
    "critic_output_size": 8,
    "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
    "downward_hidden_sizes_xi_critic": [1028, 1028, 512, 64],
    "feature_hidden_sizes": [256, 256, 64],
    "decoupled_critis_hidden_sizes": [512, 512, 128],
    "feature_lr": 1e-5,
    "decoupled_critic_lr": 1e-5,
    "downward_lr": 1e-4,
    "bias": True,
    "update_f_every_N_steps": 10,
},
{
    "batch_size": 1000,
    "num_atoms": 64,
    "feature_size": 4,
    "clip": 5,
    "critic_output_size": 16,
    "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
    "downward_hidden_sizes_xi_critic": [1028, 1028, 512, 64],
    "feature_hidden_sizes": [256, 256, 64],
    "decoupled_critis_hidden_sizes": [512, 512, 128],
    "feature_lr": 1e-5,
    "decoupled_critic_lr": 1e-4,
    "downward_lr": 1e-4,
    "bias": True,
    "update_f_every_N_steps": 1,
}
]



def train(config):
##
## INIT
##

    dataset = torch.load("data/ecog_data_pairs.pth")
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="getting-figures", config=config)

    feature_network = SupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)

    decoupled_critic = DecoupledCritic(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes=config['decoupled_critis_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)

    downward_critics = [
        DownwardCritic(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            include_bias=config['bias']
            ).to(device) 
        for _ in range(config['num_atoms'])
    ]

    downward_optims = [torch.optim.Adam(dc.parameters(), lr=config["downward_lr"], weight_decay=config["weight_decay"]) for dc in downward_critics]
    feature_optimizer = torch.optim.Adam(feature_network.parameters(), lr=config["feature_lr"], weight_decay=config["weight_decay"])
    decoupled_optimizer = torch.optim.Adam(decoupled_critic.parameters(), lr=config["decoupled_critic_lr"], weight_decay=config["weight_decay"])


    # TODO: figure out why only f network is being watched, I would like to keep a closer eye on the grad n params.
    # TODO: Look at how GANs are trained with pytorch and make sure I'm not doing anything unreasonable.
    # Eg, https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py 
    # ^ this does not require retain_graph=True, so maybe this can be optomized somehow
    wandb.watch(feature_network, log='all')
    wandb.watch(decoupled_critic, log="all")
    for dc in downward_critics:
        wandb.watch(dc, log='all')

    ##
    ## TRAIN FEATURE NETWORK
    ##

    epochs = 8
    extra_epochs = 4
    total_epochs = epochs + extra_epochs

    for epoch in tqdm.tqdm(range(total_epochs), desc='Training'):
        for batch_num, batch in enumerate(trainloader):
            # TODO: maybe add some noise to these? This seems to help in GANs
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # update decoupled critic

            v0 = feature_network(x0)
            v1 = feature_network(x1) 

            decoupled_optimizer.zero_grad()
            decoupled_scores = decoupled_critic(v0, v1)
            decoupled_MI = estimate_mutual_information('smile', decoupled_scores, clip=config['clip'])
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward(retain_graph=True)
            decoupled_optimizer.step()


            # update each downward critic 

            for i in range(config['num_atoms']):
                downward_optims[i].zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores = downward_critics[i](v1, channel_i)
                downward_MI_i = estimate_mutual_information('smile', downward_scores, clip=config['clip'])
                downward_loss = -downward_MI_i
                downward_loss.backward(retain_graph=True)
                downward_optims[i].step()
                wandb.log({
                    f"downward_MI_{i}": downward_MI_i   
                })

            # update feature network   

            feature_optimizer.zero_grad()
            sum_downward_MI = 0

            for i in range(config['num_atoms']):
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores1 = downward_critics[i](v1, channel_i)
                sum_downward_MI += estimate_mutual_information('smile', downward_scores1, clip=config['clip'])

            decoupled_scores1 = decoupled_critic(v0, v1)
            decoupled_MI1 = estimate_mutual_information('smile', decoupled_scores1, clip=config['clip'])

            Psi = decoupled_MI1 - sum_downward_MI
            feature_loss = -Psi
            
            if epoch > epochs:
                if batch_num % config['update_f_every_N_steps'] == 0:
                    feature_loss.backward()
                    feature_optimizer.step()

            wandb.log({
                "decoupled_MI": decoupled_MI1,
                "sum_downward_MI": sum_downward_MI,
                "Psi": Psi,
            })



for config in configs:
    train(config)
    run_id = wandb.run.id
    torch.save(feature_network.state_dict(), f"promising_hmmm_f_{run_id}.pth")
    wandb.finish()