In [6]:
import torch
import torch.nn as nn
import lovely_tensors as lt
import wandb
from smile_estimator import estimate_mutual_information
from models import SkipConnectionSupervenientFeatureNetwork, DecoupledSmileMIEstimator
import tqdm
import math
from torch.utils.data import DataLoader
from utils import ECoGDataset
lt.monkey_patch()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


print(device)

cuda


In [7]:

    

class CLUB(nn.Module):  # CLUB: Mutual Information Contrastive Learning Upper Bound
    '''
        This class provides the CLUB estimation to I(X,Y)
        Method:
            forward() :      provides the estimation with input samples  
            loglikeli() :   provides the log-likelihood of the approximation q(Y|X) with input samples
        Arguments:
            x_dim, y_dim :         the dimensions of samples from X, Y respectively
            hidden_size :          the dimension of the hidden layer of the approximation network q(Y|X)
            x_samples, y_samples : samples from X and Y, having shape [sample_size, x_dim/y_dim] 
    '''
    def __init__(
            self,
            v_dim,
            mu_hidden_sizes: list,
            logvar_hidden_sizes: list
        ):
        super(CLUB, self).__init__()
        # p_mu outputs mean of q(Y|X)
        # p_logvar outputs log of variance of q(Y|X)

        # NOTE: hard coding in 1 for output dim here (and below) so that we don't have to make assumptions about the covariance matrix between the different components of y
        p_mu_layers = []
        input_size = v_dim
        for hidden_size in mu_hidden_sizes:
            p_mu_layers.append(nn.Linear(input_size, hidden_size))
            p_mu_layers.append(nn.ReLU())
            input_size = hidden_size
        p_mu_layers.append(nn.Linear(input_size, 1))
        self.p_mu = nn.Sequential(*p_mu_layers)

        p_logvar_layers = []
        input_size = v_dim
        for hidden_size in logvar_hidden_sizes:
            p_logvar_layers.append(nn.Linear(input_size, hidden_size))
            p_logvar_layers.append(nn.ReLU())
            input_size = hidden_size
        p_logvar_layers.append(nn.Linear(input_size, 1))
        p_logvar_layers.append(nn.Tanh())
        self.p_logvar = nn.Sequential(*p_logvar_layers)


    def get_mu_logvar(self, x_samples):
        mu = self.p_mu(x_samples)
        logvar = self.p_logvar(x_samples)
        return mu, logvar
    
    def forward(self, x_samples, y_samples): 
        mu, logvar = self.get_mu_logvar(x_samples)
        
        # log of conditional probability of positive sample pairs
        positive = - (mu - y_samples)**2 /2./logvar.exp()  
        
        prediction_1 = mu.unsqueeze(1)          # shape [nsample,1,dim]
        y_samples_1 = y_samples.unsqueeze(0)    # shape [1,nsample,dim]

        # log of conditional probability of negative sample pairs
        negative = - ((y_samples_1 - prediction_1)**2).mean(dim=1)/2./logvar.exp() 

        return (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()

    def loglikeli(self, x_samples, y_samples): # unnormalized loglikelihood 
        mu, logvar = self.get_mu_logvar(x_samples)
        return 0.5 * (-(mu - y_samples)**2 /logvar.exp()-logvar - torch.log(torch.tensor(2 * math.pi))).sum(dim=1).mean(dim=0)
    # NOTE: y should be dim 1
    def learning_loss(self, x_samples, y_samples):
        return - self.loglikeli(x_samples, y_samples)




In [18]:
config = {
    "batch_size": 1000,
    "num_atoms": 64,
    "feature_size": 3,
    "clip": 5,
    "update_f_every_N_steps": 5,
    "club_config": {
        "mu_hidden_sizes": [512, 512, 256, 64],
        "logvar_hidden_sizes": [512, 512, 256, 64],
        "club_lr": 1e-3,
        "bias": True,
        "weight_decay": 1e-4,
    },
    "feature_network_config": {
        "hidden_sizes": [256, 256, 256, 256, 256],
        "include_bias": True,
        "feature_lr": 1e-4,
        "weight_decay": 1e-5

    },
    "decoupled_critic_config": {
        "hidden_sizes": [512, 512, 128],
        "include_bias": True,
        "decoupled_critic_lr": 1e-4,
        "critic_output_size": 16,
        "weight_decay": 1e-5
    },
}


In [21]:
def train_feature_network(config, feature_network, trainloader):


    wandb.init(project="getting emergence with club", config=config)


    decoupled_critic = DecoupledSmileMIEstimator(
        feature_size=config['feature_size'],
        critic_output_size=config['decoupled_critic_config']['critic_output_size'],
        hidden_sizes_1=config['decoupled_critic_config']['hidden_sizes'],
        hidden_sizes_2=config['decoupled_critic_config']['hidden_sizes'],
        include_bias=config['decoupled_critic_config']['include_bias'],
        clip=config['clip']
        ).to(device)
    downward_clubs = [
        CLUB(
            v_dim=config['feature_size'],
            mu_hidden_sizes=config['club_config']['mu_hidden_sizes'],
            logvar_hidden_sizes=config['club_config']['logvar_hidden_sizes']
        ).to(device)
        for _ in range(config['num_atoms'])
    ]


    downward_optims = [
        torch.optim.Adam(
            dc.parameters(),
            lr=config['club_config']['club_lr'],
            weight_decay=config['club_config']['weight_decay']
        )
        for dc in downward_clubs
    ]
    decoupled_optimizer = torch.optim.Adam(
        decoupled_critic.parameters(),
        lr=config["decoupled_critic_config"]["decoupled_critic_lr"],
        weight_decay=config["decoupled_critic_config"]["weight_decay"]
    )


    # TODO: figure out why only f network is being watched, I would like to keep a closer eye on the grad n params.
    # TODO: Look at how GANs are trained with pytorch and make sure I'm not doing anything unreasonable.
    # Eg, https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py 
    # ^ this does not require retain_graph=True, so maybe this can be optomized somehow
    # wandb.watch(decoupled_critic, log="all")
    # for dc in downward_clubs:
    #     wandb.watch(dc, log='all')

    ##
    ## TRAIN FEATURE NETWORK
    ##

    epochs = 20

    for epoch in tqdm.tqdm(range(epochs), desc='Training'):
        for batch_num, batch in enumerate(trainloader):
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()
            v0 = feature_network(x0)
            v1 = feature_network(x1) 

            # ensure we are only training MI estimators first
            # feature_network.eval()
            # decoupled_critic.train()
            # downward_clubs = [dc.train() for dc in downward_clubs]

            # update decoupled critic
            decoupled_optimizer.zero_grad()
            decoupled_MI = decoupled_critic(v0, v1)
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward(retain_graph=True)
            decoupled_optimizer.step()


            # update each downward club
            for i in range(config['num_atoms']):
                downward_optims[i].zero_grad()
                channel_0i = x0[:, i].unsqueeze(1)
                downward_loss = downward_clubs[i].learning_loss(v1, channel_0i)
                downward_loss.backward(retain_graph=True)
                downward_optims[i].step()
                wandb.log({
                    f"club_downward_{i}_loss": downward_loss
                })
                MI = downward_clubs[i](v1, channel_0i)
                wandb.log({
                    f"club_downward_{i}_MI": MI
                })

            # update feature network   
            # feature_network.train()
            # decoupled_critic.eval()
            # downward_clubs = [dc.eval() for dc in downward_clubs]

            sum_downward_MI = 0
            MIs = []
            for i in range(config['num_atoms']):
                downward_MI = downward_clubs[i](v1, x0[:, i].unsqueeze(1))
                MIs.append(downward_MI)
                sum_downward_MI += downward_MI

            min_MI = min(MIs)
            wandb.log({
                "min_MI": min_MI
            })
            decoupled_MI1 = decoupled_critic(v0, v1)

            # add max
            Psi = decoupled_MI1 - sum_downward_MI + (config['num_atoms'] - 1) * max(0, min_MI)
            
            wandb.log({
                "decoupled_MI": decoupled_MI1,
                "sum_downward_MI": sum_downward_MI,
                "Psi": Psi,
            })
    return feature_network



In [10]:
dataset = ECoGDataset()

train_loader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

In [22]:
feature_network = SkipConnectionSupervenientFeatureNetwork(
    num_atoms=config['num_atoms'],
    feature_size=config['feature_size'],
    hidden_sizes=config['feature_network_config']['hidden_sizes'],
    include_bias=config['feature_network_config']['include_bias']
).to(device)


model_path = "/vol/bitbucket/dm2223/info-theory-experiments/models/ecog_feature_network_expert-snow-10.pth"
feature_network.load_state_dict(torch.load(model_path))


feature_network = train_feature_network(config, feature_network, train_loader)

0,1
Psi,█▃▃▃▃▃▃▃▂▂▃▃▂▃▃▂▂▃▃▂▃▃▃▃▃▃▃▃▂▃▁▁▃▃▃▃▃▂▃▃
club_downward_0_MI,▁▆▅▅▅▆▅▅▆▆▅▆▆▅▅▆▆▆▅▇▅▆▆▆▅▆▅█▆▅▇▇▆▅▅▆▇█▆▅
club_downward_0_loss,█▁▁▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▂▁▁▁▂▁▁▁▁▂▁▁▁
club_downward_10_MI,▁▆▅▅▅▆▅▅▆▇▅▆▆▅▅▆▇▆▅▇▆▆▆▆▅▆▅█▇▆▇▇▆▅▆▆█▇▆▅
club_downward_10_loss,█▁▁▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▂▁▁▁▂▁▁▁▁▂▂▁▁
club_downward_11_MI,▁▆▆▆▅▆▅▅▆▇▅▆▇▆▆▇▇▆▆▇▆▆▆▆▅▇▆█▇▆▇▇▆▆▆▆██▆▆
club_downward_11_loss,█▁▁▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▂▁▁▁▂▁▁▁▁▂▂▁▁
club_downward_12_MI,▁▆▆▅▅▆▅▅▆▇▅▆▆▅▅▆▇▆▅▇▆▆▆▆▅▆▅█▇▅▇▇▆▅▆▆██▆▅
club_downward_12_loss,█▁▁▁▁▁▁▂▁▂▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▂▁▁▂▂▁▁▁▁▂▂▁▁
club_downward_13_MI,▁▆▆▅▅▆▅▅▆▇▅▆▆▅▅▆▇▆▅▇▆▆▆▆▅▆▅█▇▆▇▇▆▅▆▆██▆▆

0,1
Psi,-12.52604
club_downward_0_MI,0.61427
club_downward_0_loss,1.03455
club_downward_10_MI,0.56278
club_downward_10_loss,1.0686
club_downward_11_MI,0.61457
club_downward_11_loss,1.0438
club_downward_12_MI,0.59395
club_downward_12_loss,1.0639
club_downward_13_MI,0.62041


Training:  20%|██        | 4/20 [07:01<28:05, 105.33s/it]


KeyboardInterrupt: 

In [5]:

class NoSpectralDownwardCritic(nn.Module):
    def __init__(
            self,
            feature_size: int,
            critic_output_size: int,
            hidden_sizes_v_critic: list,
            hidden_sizes_xi_critic: list,
            include_bias: bool = True
        ):
        super(NoSpectralDownwardCritic, self).__init__()

        v_encoder_layers = []
        input_size = feature_size
        for hidden_size in hidden_sizes_v_critic:
            v_encoder_layers.append(nn.Linear(input_size, hidden_size, bias=include_bias))
            v_encoder_layers.append(nn.ReLU())
            input_size = hidden_size
        v_encoder_layers.append(nn.Linear(input_size, critic_output_size, bias=include_bias))
        self.v_encoder = nn.Sequential(*v_encoder_layers)

        atom_encoder_layers = []
        input_size = 1
        for hidden_size in hidden_sizes_xi_critic:
            atom_encoder_layers.append(nn.Linear(input_size, hidden_size, bias=include_bias))
            atom_encoder_layers.append(nn.ReLU())
            input_size = hidden_size
        atom_encoder_layers.append(nn.Linear(input_size, critic_output_size, bias=include_bias)) 
        self.atom_encoder = nn.Sequential(*atom_encoder_layers)
    
    def forward(self, v1, x0i):
        v1_encoded = self.v_encoder(v1)
        x0i_encoded = self.atom_encoder(x0i)

        scores = torch.matmul(v1_encoded, x0i_encoded.t())
        return scores




def find_true_Psi(feature_network, run_id, feature_config):

    print(type(feature_network))

    config = {
        "batch_size": 1000,
        "num_atoms": 64,
        "feature_size": feature_config['feature_size'],
        "clip": 5,
        "critic_output_size": 16,
        "downward_hidden_sizes_v_critic": [1028, 1028, 512, 64],
        "downward_hidden_sizes_xi_critic": [512, 512, 512, 64],
        "feature_hidden_sizes": [1028, 1028, 256],
        "decoupled_critis_hidden_sizes": [512, 512, 128],
        "decoupled_critic_lr": 1e-4,
        "downward_lr": 1e-4,
        "bias": True,
        "weight_decay": 1e-6,
        "original_run_id": run_id
    }

    dataset = torch.load("data/ecog_data_pairs.pth")
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="Finding-true-Psi-for-f", config=config, id=run_id)

    decoupled_critic = DecoupledCritic(
        feature_size=config['feature_size'],
        critic_output_size=config['critic_output_size'],
        hidden_sizes=config['decoupled_critis_hidden_sizes'],
        include_bias=config['bias']
        ).to(device)

    downward_critics = [
        NoSpectralDownwardCritic(
            feature_size=config['feature_size'],
            critic_output_size=config['critic_output_size'],
            hidden_sizes_v_critic=config['downward_hidden_sizes_v_critic'],
            hidden_sizes_xi_critic=config['downward_hidden_sizes_xi_critic'],
            include_bias=config['bias']
            ).to(device) 
        for _ in range(config['num_atoms'])
    ]

    downward_optims = [
        torch.optim.Adam(
            dc.parameters(),
            lr=config["downward_lr"],
            weight_decay=config["weight_decay"]
        ) 
        for dc in downward_critics
    ]

    decoupled_optimizer = torch.optim.Adam(
        decoupled_critic.parameters(),
        lr=config["decoupled_critic_lr"],
        weight_decay=config["weight_decay"]
    )

    # TODO: figure out why only f network is being watched, I would like to keep a closer eye on the grad n params.
    # TODO: Look at how GANs are trained with pytorch and make sure I'm not doing anything unreasonable.
    # Eg, https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py 
    # ^ this does not require retain_graph=True, so maybe this can be optomized somehow
    wandb.watch(decoupled_critic, log="all")
    for dc in downward_critics:
        wandb.watch(dc, log='all')

    epochs = 5

    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for _, batch in enumerate(trainloader):
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # update decoupled critic

            v0 = feature_network(x0)
            v1 = feature_network(x1) 

            decoupled_optimizer.zero_grad()
            decoupled_scores = decoupled_critic(v0, v1)
            decoupled_MI = estimate_mutual_information('smile', decoupled_scores, clip=config['clip'])
            decoupled_loss = -decoupled_MI
            decoupled_loss.backward(retain_graph=True)
            decoupled_optimizer.step()


            # update each downward critic 

            for i in range(config['num_atoms']):
                downward_optims[i].zero_grad()
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores = downward_critics[i](v1, channel_i)
                downward_MI_i = estimate_mutual_information('smile', downward_scores, clip=config['clip'])
                # add spectral norm to the loss
                downward_loss = - downward_MI_i
                downward_loss.backward(retain_graph=True)
                downward_optims[i].step()
                wandb.log({
                    f"downward_MI_{i}": downward_MI_i   
                })

            # update feature network   

            sum_downward_MI = 0

            MIs = []
            for i in range(config['num_atoms']):
                channel_i = x0[:, i].unsqueeze(1)
                downward_scores1 = downward_critics[i](v1, channel_i)
                downward_MI = estimate_mutual_information('smile', downward_scores1, clip=config['clip'])
                sum_downward_MI += downward_MI
                MIs.append(downward_MI)
            
            min_MI_clipped = max(0, min(MIs))
            
            decoupled_scores1 = decoupled_critic(v0, v1)
            decoupled_MI1 = estimate_mutual_information('smile', decoupled_scores1, clip=config['clip'])

            Psi = decoupled_MI1 - sum_downward_MI + (config['num_atoms'] - 1) * min_MI_clipped

            wandb.log({
                "decoupled_MI": decoupled_MI1,
                "sum_downward_MI": sum_downward_MI,
                "Psi": Psi,
            })
        
    return Psi



In [6]:
feature_network = SupervenientFeatureNetwork(
    num_atoms=config['num_atoms'],
    feature_size=config['feature_size'],
    hidden_sizes=config['feature_network_config']['hidden_sizes'],
    include_bias=config['feature_network_config']['include_bias']
    ).to(device)

feature_network.load_state_dict(torch.load("club_candidate_emergent_feature.pth"))

Psi = find_true_Psi(feature_network, "club_first_candidate", config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<class '__main__.SupervenientFeatureNetwork'>


[34m[1mwandb[0m: Currently logged in as: [33mdmcsharry[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training:  20%|██        | 1/5 [01:20<05:20, 80.15s/it]

In [12]:
def interchannel_MI_CLUB(config):

    dataset = torch.load("data/ecog_data_pairs.pth")
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

    wandb.init(project="testing-club", config=config)


    club = CLUB(
        v_dim=1,
        mu_hidden_sizes=config['club_config']['mu_hidden_sizes'],
        logvar_hidden_sizes=config['club_config']['logvar_hidden_sizes']
    ).to(device)


    downward_optim = torch.optim.Adam(
        club.parameters(),
        lr=config['club_config']['club_lr'],
        weight_decay=config['club_config']['weight_decay']
    )

    ##
    ## TRAIN FEATURE NETWORK
    ##

    epochs = 10

    for _ in tqdm.tqdm(range(epochs), desc='Training'):
        for _, batch in enumerate(trainloader):
            x0 = batch[:, 0].to(device).float()
            x1 = batch[:, 1].to(device).float()

            # each club responsible for a channel
            i = 0
            j = 1
        
            channel_i = x0[:, i].unsqueeze(1)
            channel_j = x0[:, j].unsqueeze(1)
            downward_optim.zero_grad()
            club_loss = club.learning_loss(channel_j, channel_i)
            club_loss.backward()
            downward_optim.step()
            wandb.log({
                f"club_loss": club_loss
            })
            MI_channel = club(channel_j, channel_i)
            wandb.log({
                f"MI_channel": MI_channel
            })




0,1
MI_channel,▃▂▂▂▂▅▂▂▁▁▄▃▃▃▁▂▂▂▅▄▁▄▂█▃▂▅▆▂▂▅▁▄▃▇▂▂▄▄▂
club_loss,▄▂▁▁▁▁▂▂▁▂▂▂▁▂▂▂▁▂▁▁▂▁▂█▂▂▂▂▂▁▂▂▂▁▂▂▁▂▃▂

0,1
MI_channel,1.67335
club_loss,0.45907


Training: 100%|██████████| 10/10 [00:14<00:00,  1.45s/it]
