In [1]:
import torch
import torch.nn as nn
import lovely_tensors as lt
from einops import reduce, rearrange, repeat
# from npeet.entropy_estimators import entropy, mi
import matplotlib.pyplot as plt
import wandb
import utils
import importlib
import os
from utils import prepare_ecog_dataset, prepare_batch, estimate_MI_smile
from smile_estimator import estimate_mutual_information
import tqdm
from torch.nn.utils import spectral_norm
# importlib.reload(utils)



# Patch torch tensor printing via lovely_tensors.
lt.monkey_patch()

# Seed torch's global RNG (all devices) so weight initialisation and the
# DataLoader's shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:

class SupervenientFeatureNetwork(nn.Module):
    """MLP mapping a ``num_atoms``-wide input vector to a ``feature_size``-wide feature.

    Architecture: ``Linear -> ReLU`` for every width in ``hidden_sizes`` (in order),
    followed by a final ``Linear`` projection to ``feature_size`` with no activation.

    Args:
        num_atoms: width of the input vector.
        feature_size: width of the output feature.
        hidden_sizes: widths of the hidden layers, first to last.
        include_bias: whether every ``Linear`` layer carries a bias term.
    """

    def __init__(
            self,
            num_atoms: int,
            feature_size: int,
            hidden_sizes: list,
            include_bias: bool = True
        ):
        super().__init__()
        widths = [num_atoms, *hidden_sizes]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out, bias=include_bias), nn.ReLU()]
        modules.append(nn.Linear(widths[-1], feature_size, bias=include_bias))
        # Keep the attribute name `f`: saved state_dicts use keys like "f.0.weight".
        self.f = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the MLP to ``x``; its last dimension must have size ``num_atoms``."""
        return self.f(x)


class DecoupledCritic(nn.Module):
    """Bilinear separable critic scoring every (v0, v1) pair of features.

    Both inputs share one MLP encoder (``v_encoder``: ``Linear -> ReLU`` per hidden
    width, then a ``Linear`` to ``critic_output_size``). A bias-free linear map ``W``
    is applied to the encoded ``v1`` before taking dot products, so the score
    matrix is ``enc(v0) @ W(enc(v1)).T``.

    Args:
        feature_size: width of each input feature vector.
        critic_output_size: width of the shared encoding.
        hidden_sizes: hidden-layer widths of the encoder.
        include_bias: whether the encoder's ``Linear`` layers carry a bias.
    """

    def __init__(
            self,
            feature_size: int,
            critic_output_size: int,
            hidden_sizes: list,
            include_bias: bool = True
        ):
        super().__init__()

        dims = [feature_size] + list(hidden_sizes)
        encoder = []
        for k in range(len(dims) - 1):
            encoder.append(nn.Linear(dims[k], dims[k + 1], bias=include_bias))
            encoder.append(nn.ReLU())
        encoder.append(nn.Linear(dims[-1], critic_output_size, bias=include_bias))

        self.v_encoder = nn.Sequential(*encoder)
        # Learned bilinear form between the two encodings.
        self.W = nn.Linear(critic_output_size, critic_output_size, bias=False)

    def forward(self, v0, v1):
        """Return the score matrix; for 2-D (batch, feature) inputs it is (len(v0), len(v1))."""
        left = self.v_encoder(v0)
        right = self.W(self.v_encoder(v1))
        return left @ right.t()
    

class DownwardCritic(nn.Module):
    """Separable critic scoring features ``v1`` against a single scalar atom ``x0i``.

    Two MLPs whose ``Linear`` layers are all wrapped in ``spectral_norm`` embed the
    feature (``v_encoder``) and the one-wide atom value (``atom_encoder``) into
    ``critic_output_size`` dimensions. Scores are pairwise dot products of the two
    embeddings.

    Args:
        feature_size: width of the feature vector ``v1``.
        critic_output_size: width of both embeddings.
        hidden_sizes_v_critic: hidden widths of the feature encoder.
        hidden_sizes_xi_critic: hidden widths of the atom encoder.
        include_bias: whether every ``Linear`` layer carries a bias.
    """

    def __init__(
            self,
            feature_size: int,
            critic_output_size: int,
            hidden_sizes_v_critic: list,
            hidden_sizes_xi_critic: list,
            include_bias: bool = True
        ):
        super().__init__()

        def sn_mlp(in_width, hidden, out_width):
            # Spectrally-normalised Linear layers separated by ReLUs; no output activation.
            parts = []
            for width in hidden:
                parts.extend([spectral_norm(nn.Linear(in_width, width, bias=include_bias)), nn.ReLU()])
                in_width = width
            parts.append(spectral_norm(nn.Linear(in_width, out_width, bias=include_bias)))
            return nn.Sequential(*parts)

        # Feature encoder is built first, so parameter init consumes the RNG in the same order.
        self.v_encoder = sn_mlp(feature_size, hidden_sizes_v_critic, critic_output_size)
        # Input width 1: the training loop passes one channel at a time, as x0[:, i].unsqueeze(1).
        self.atom_encoder = sn_mlp(1, hidden_sizes_xi_critic, critic_output_size)

    def forward(self, v1, x0i):
        """Return the (len(v1), len(x0i)) score matrix for 2-D (batch, width) inputs."""
        return torch.matmul(self.v_encoder(v1), self.atom_encoder(x0i).t())





In [3]:
# Hyperparameters for the single training run below (passed to wandb.init as the run config).
# NOTE(review): 1028 may have been intended as 1024 — confirm before reusing these sizes.
# "decoupled_critis_hidden_sizes" is misspelt, but the training code reads that exact key.
config = dict(
    batch_size=1000,
    num_atoms=64,
    feature_size=1,
    clip=5,
    critic_output_size=16,
    downward_hidden_sizes_v_critic=[1028, 1028, 512, 64],
    downward_hidden_sizes_xi_critic=[512, 512, 512, 64],
    feature_hidden_sizes=[1028, 1028, 256],
    decoupled_critis_hidden_sizes=[512, 512, 128],
    feature_lr=1e-5,
    decoupled_critic_lr=1e-4,
    downward_lr=1e-4,
    bias=True,
    update_f_every_N_steps=2,
)


## Train the feature network and its critics

In [4]:

##
## INIT
##

dataset = torch.load("data/ecog_data_pairs.pth")
trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

wandb.init(project="getting-figures", config=config)

feature_network = SupervenientFeatureNetwork(
    num_atoms=config['num_atoms'],
    feature_size=config['feature_size'],
    hidden_sizes=config['feature_hidden_sizes'],
    include_bias=config['bias']
    ).to(device)

decoupled_critic = DecoupledCritic(
    feature_size=config['feature_size'],
    critic_output_size=config['critic_output_size'],
    hidden_sizes=config['decoupled_critis_hidden_sizes'],
    include_bias=config['bias']
    ).to(device)

# One downward critic (with its own optimiser) per atom/channel.
downward_critics = [
    DownwardCritic(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
        hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
        include_bias=config['bias']
        ).to(device)
    for _ in range(config['num_atoms'])
]

downward_optims = [torch.optim.Adam(dc.parameters(), lr=config["downward_lr"]) for dc in downward_critics]
feature_optimizer = torch.optim.Adam(feature_network.parameters(), lr=config["feature_lr"])
decoupled_optimizer = torch.optim.Adam(decoupled_critic.parameters(), lr=config["decoupled_critic_lr"])


# TODO: figure out why only f network is being watched, I would like to keep a closer eye on the grad n params.
wandb.watch(feature_network, log='all')
wandb.watch(decoupled_critic, log="all")
for dc in downward_critics:
    wandb.watch(dc, log='all')

##
## TRAIN FEATURE NETWORK
##

epochs = 10

try:
    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for batch_num, batch in enumerate(trainloader):
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # The feature network is only stepped on every N-th batch; only then do we
            # need its autograd graph.
            update_f = batch_num % config['update_f_every_N_steps'] == 0
            with torch.set_grad_enabled(update_f):
                v0 = feature_network(x0)
                v1 = feature_network(x1)

            # Critics are fitted to the current (fixed) features. Detaching keeps their
            # losses from back-propagating into feature_network (those grads were zeroed
            # before the feature step anyway) and removes the need for retain_graph=True.
            v0_fixed = v0.detach()
            v1_fixed = v1.detach()

            # update decoupled critic
            decoupled_optimizer.zero_grad()
            decoupled_scores = decoupled_critic(v0_fixed, v1_fixed)
            decoupled_MI = estimate_mutual_information('smile', decoupled_scores, clip=config['clip'])
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward()
            decoupled_optimizer.step()

            # update each downward critic
            metrics = {}
            for i in range(config['num_atoms']):
                downward_optims[i].zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores = downward_critics[i](v1_fixed, channel_i)
                downward_MI_i = estimate_mutual_information('smile', downward_scores, clip=config['clip'])
                downward_loss = -downward_MI_i
                downward_loss.backward()
                downward_optims[i].step()
                metrics[f"downward_MI_{i}"] = downward_MI_i.item()

            # update feature network: maximise Psi = decoupled MI - sum of downward MIs.
            # Psi is always computed for logging; its graph is only built when it is used.
            feature_optimizer.zero_grad()
            with torch.set_grad_enabled(update_f):
                sum_downward_MI = 0
                for i in range(config['num_atoms']):
                    channel_i = x0[:, i].unsqueeze(1)
                    downward_scores1 = downward_critics[i](v1, channel_i)
                    sum_downward_MI += estimate_mutual_information('smile', downward_scores1, clip=config['clip'])

                decoupled_scores1 = decoupled_critic(v0, v1)
                decoupled_MI1 = estimate_mutual_information('smile', decoupled_scores1, clip=config['clip'])

                Psi = decoupled_MI1 - sum_downward_MI
                feature_loss = -Psi

            if update_f:
                feature_loss.backward()
                feature_optimizer.step()

            metrics.update({
                "decoupled_MI": decoupled_MI1.item(),
                "sum_downward_MI": float(sum_downward_MI),
                "Psi": Psi.item(),
            })
            # A single log call per batch: each wandb.log() advances the step counter,
            # so logging per critic would give every metric a different x-axis position.
            wandb.log(metrics)
finally:
    # Close the run so a later wandb.init() starts a fresh one.
    wandb.finish()



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdmcsharry[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training: 100%|██████████| 10/10 [36:38<00:00, 219.86s/it]


In [9]:
# Persist the trained feature network's parameters (its state_dict, not the module object).
f_state = feature_network.state_dict()
torch.save(f_state, "promising_hmmm_f.pth")

In [4]:
def _sweep_config(**overrides):
    """Return a fresh hyperparameter dict: shared defaults with ``overrides`` applied.

    Each call builds new list objects, so no two configs share a mutable value.
    ``dict.update`` keeps existing keys in place, so every config has the same key order.
    """
    cfg = {
        "batch_size": 1000,
        "num_atoms": 64,
        "feature_size": 1,
        "clip": 5,
        "critic_output_size": 16,
        "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
        "downward_hidden_sizes_xi_critic": [512, 1028, 512, 64],
        "feature_hidden_sizes": [1028, 1028, 256],
        "decoupled_critis_hidden_sizes": [512, 512, 128],
        "feature_lr": 1e-5,
        "decoupled_critic_lr": 1e-4,
        "downward_lr": 1e-4,
        "bias": True,
        "update_f_every_N_steps": 1,
    }
    cfg.update(overrides)
    return cfg


# Hyperparameter sweep; each entry is handed to train() by the loop below.
# NOTE(review): "decoupled_critis_hidden_sizes" is misspelt, but it is the key train() reads.
configs = [
    _sweep_config(
        feature_size=4,
        downward_hidden_sizes_v_critic=[512, 512, 256],
        downward_hidden_sizes_xi_critic=[256, 256, 64],
        feature_hidden_sizes=[1028, 1028, 1028, 256],
        update_f_every_N_steps=2,
    ),
    _sweep_config(
        feature_size=8,
        downward_hidden_sizes_v_critic=[512, 512, 256],
        downward_hidden_sizes_xi_critic=[256, 256, 64],
        feature_hidden_sizes=[1028, 1028, 1028, 256],
        update_f_every_N_steps=4,
    ),
    _sweep_config(
        batch_size=2000,
    ),
    _sweep_config(
        critic_output_size=32,
        update_f_every_N_steps=10,
    ),
    _sweep_config(
        feature_size=4,
        clip=2,
        critic_output_size=6,
        feature_hidden_sizes=[256, 256, 256],
        update_f_every_N_steps=10,
    ),
    _sweep_config(
        critic_output_size=12,
        decoupled_critic_lr=1e-3,
        downward_lr=1e-3,
        bias=False,
    ),
    _sweep_config(
        clip=1,
        critic_output_size=8,
        downward_hidden_sizes_v_critic=[256, 256, 64],
        downward_hidden_sizes_xi_critic=[256, 256, 64],
        feature_hidden_sizes=[256, 256, 64],
        decoupled_critic_lr=1e-5,
        downward_lr=1e-5,
        update_f_every_N_steps=10,
    ),
    _sweep_config(
        clip=1,
        critic_output_size=8,
        downward_hidden_sizes_xi_critic=[1028, 1028, 512, 64],
        feature_hidden_sizes=[256, 256, 64],
        downward_lr=1e-3,
        update_f_every_N_steps=2,
    ),
    _sweep_config(
        feature_size=16,
        clip=4,
        critic_output_size=8,
        downward_hidden_sizes_xi_critic=[1028, 1028, 512, 64],
        feature_hidden_sizes=[256, 256, 64],
        decoupled_critic_lr=1e-5,
        update_f_every_N_steps=10,
    ),
    _sweep_config(
        feature_size=4,
        downward_hidden_sizes_xi_critic=[1028, 1028, 512, 64],
        feature_hidden_sizes=[256, 256, 64],
    ),
]






def train(config, epochs=15, extra_epochs=5):
    """Train a feature network and its critics for one hyperparameter set, logging to wandb.

    The decoupled critic and the per-atom downward critics are updated on every
    batch. The feature network is frozen for the first ``epochs`` epochs (critic
    warm-up). For the following ``extra_epochs`` epochs it is also updated, every
    ``config['update_f_every_N_steps']`` batches, to maximise
    ``Psi = decoupled_MI - sum_i downward_MI_i``.

    Args:
        config: hyperparameter dict with the same keys as the entries of ``configs``.
        epochs: number of warm-up epochs in which only the critics train.
        extra_epochs: number of subsequent epochs in which the feature network
            also trains.

    Side effects: loads ``data/ecog_data_pairs.pth`` and starts (and finishes) a
    wandb run in project "getting-figures".
    """
    ##
    ## INIT
    ##

    dataset = torch.load("data/ecog_data_pairs.pth")
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="getting-figures", config=config)

    feature_network = SupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)

    decoupled_critic = DecoupledCritic(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes=config['decoupled_critis_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)

    # One downward critic (with its own optimiser) per atom/channel.
    downward_critics = [
        DownwardCritic(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            include_bias=config['bias']
            ).to(device)
        for _ in range(config['num_atoms'])
    ]

    downward_optims = [torch.optim.Adam(dc.parameters(), lr=config["downward_lr"]) for dc in downward_critics]
    feature_optimizer = torch.optim.Adam(feature_network.parameters(), lr=config["feature_lr"])
    decoupled_optimizer = torch.optim.Adam(decoupled_critic.parameters(), lr=config["decoupled_critic_lr"])

    # TODO: figure out why only f network is being watched, I would like to keep a closer eye on the grad n params.
    wandb.watch(feature_network, log='all')
    wandb.watch(decoupled_critic, log="all")
    for dc in downward_critics:
        wandb.watch(dc, log='all')

    ##
    ## TRAIN FEATURE NETWORK
    ##

    try:
        for epoch in tqdm.tqdm(range(epochs + extra_epochs), desc='Training'):
            # `>=` so the feature network trains for exactly `extra_epochs` epochs
            # (a strict `>` silently dropped one of them).
            train_features = epoch >= epochs
            for batch_num, batch in enumerate(trainloader):
                x0 = batch[:, 0].to(device).float()
                x1 = batch[:, 1].to(device).float()

                update_f = train_features and batch_num % config['update_f_every_N_steps'] == 0
                # The feature network's autograd graph is only needed when it is stepped.
                with torch.set_grad_enabled(update_f):
                    v0 = feature_network(x0)
                    v1 = feature_network(x1)

                # Critics are fitted to the current (fixed) features. Detaching keeps their
                # losses from back-propagating into feature_network (whose grads are reset
                # before its own step anyway) and removes the need for retain_graph=True.
                v0_fixed = v0.detach()
                v1_fixed = v1.detach()

                # update decoupled critic
                decoupled_optimizer.zero_grad()
                decoupled_scores = decoupled_critic(v0_fixed, v1_fixed)
                decoupled_MI = estimate_mutual_information('smile', decoupled_scores, clip=config['clip'])
                decoupled_loss = -decoupled_MI
                decoupled_loss.backward()
                decoupled_optimizer.step()

                # update each downward critic
                metrics = {}
                for i in range(config['num_atoms']):
                    downward_optims[i].zero_grad()
                    channel_i = x0[:, i].unsqueeze(1)
                    downward_scores = downward_critics[i](v1_fixed, channel_i)
                    downward_MI_i = estimate_mutual_information('smile', downward_scores, clip=config['clip'])
                    downward_loss = -downward_MI_i
                    downward_loss.backward()
                    downward_optims[i].step()
                    metrics[f"downward_MI_{i}"] = downward_MI_i.item()

                # update feature network: Psi is always computed for logging, but its
                # graph is only built on batches where the feature network is stepped.
                feature_optimizer.zero_grad()
                with torch.set_grad_enabled(update_f):
                    sum_downward_MI = 0
                    for i in range(config['num_atoms']):
                        channel_i = x0[:, i].unsqueeze(1)
                        downward_scores1 = downward_critics[i](v1, channel_i)
                        sum_downward_MI += estimate_mutual_information('smile', downward_scores1, clip=config['clip'])

                    decoupled_scores1 = decoupled_critic(v0, v1)
                    decoupled_MI1 = estimate_mutual_information('smile', decoupled_scores1, clip=config['clip'])

                    Psi = decoupled_MI1 - sum_downward_MI
                    feature_loss = -Psi

                if update_f:
                    feature_loss.backward()
                    feature_optimizer.step()

                metrics.update({
                    "decoupled_MI": decoupled_MI1.item(),
                    "sum_downward_MI": float(sum_downward_MI),
                    "Psi": Psi.item(),
                })
                # A single log call per batch: each wandb.log() advances the step counter,
                # so logging per critic would put every metric on a different step.
                wandb.log(metrics)
    finally:
        # Close this run so the next train() call starts a fresh wandb run.
        wandb.finish()


# Run the sweep: one training run per hyperparameter set. A distinct loop variable
# keeps the module-level `config` (used by the single-run cell) from being overwritten.
for sweep_config in configs:
    train(sweep_config)



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdmcsharry[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training:  35%|███▌      | 7/20 [19:40<36:42, 169.40s/it]