In [1]:
import torch
import lovely_tensors as lt
from einops import reduce, rearrange, repeat
# from npeet.entropy_estimators import entropy, mi
import matplotlib.pyplot as plt
import wandb
import utils
import importlib
importlib.reload(utils)
import os
from utils import prepare_ecog_dataset, prepare_batch, estimate_MI_smile
from smile_estimator import estimate_mutual_information
import tqdm



lt.monkey_patch()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [2]:
import torch.nn as nn

class SupervenientFeatureNetwork(nn.Module):
    def __init__(self, num_atoms, feature_size, hidden_sizes):
        super(SupervenientFeatureNetwork, self).__init__()
        layers = []
        input_size = num_atoms
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(input_size, hidden_size))
            layers.append(nn.ReLU())
            input_size = hidden_size
        layers.append(nn.Linear(input_size, feature_size))
        self.f = nn.Sequential(*layers)
        print('Feature network:', self.f)

    def forward(self, x):
        return self.f(x)


class DecoupledCritic(nn.Module):
    def __init__(
            self,
            feature_size: int,
            critic_output_size: int,
            hidden_sizes: list
        ):
        super(DecoupledCritic, self).__init__()
        # self.v_encoder = nn.Sequential(
        #     nn.Linear(feature_size, 128),
        #     nn.ReLU(),
        #     nn.Linear(128, 64),
        #     nn.ReLU(),
        #     nn.Linear(64, 8),
        # )

        layers = []
        input_size = feature_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(input_size, hidden_size))
            layers.append(nn.ReLU())
            input_size = hidden_size
        layers.append(nn.Linear(input_size, critic_output_size))

        self.v_encoder = nn.Sequential(*layers)
        self.W = nn.Linear(critic_output_size, critic_output_size, bias=False)

        print('v_encoder:', self.v_encoder)
        print('W:', self.W)

    def forward(self, v0, v1):
        v0_encoded = self.v_encoder(v0)
        v1_encoded = self.v_encoder(v1)
        v1_encoded_transformed = self.W(v1_encoded)

        scores = torch.matmul(v0_encoded, v1_encoded_transformed.t())
        return scores
    

class DownwardCritic(nn.Module):
    def __init__(self, feature_size, critic_output_size, hidden_sizes_v_critic, hidden_sizes_xi_critic):
        super(DownwardCritic, self).__init__()
        # self.v_encoder = nn.Sequential(
        #     nn.Linear(feature_size, 512),
        #     nn.ReLU(),
        #     nn.Linear(512, 512),
        #     nn.ReLU(),
        #     nn.Linear(512, 128),
        #     nn.ReLU(),
        #     nn.Linear(128, 8),
        # )

        # self.atom_encoder = nn.Sequential(
        #     nn.Linear(1, 128),
        #     nn.ReLU(),
        #     nn.Linear(128, 64),
        #     nn.ReLU(),
        #     nn.Linear(64, 8),
        # )

        v_encoder_layers = []
        input_size = feature_size
        for hidden_size in hidden_sizes_v_critic:
            v_encoder_layers.append(nn.Linear(input_size, hidden_size))
            v_encoder_layers.append(nn.ReLU())
            input_size = hidden_size
        v_encoder_layers.append(nn.Linear(input_size, critic_output_size))
        self.v_encoder = nn.Sequential(*v_encoder_layers)

        atom_encoder_layers = []
        input_size = 1
        for hidden_size in hidden_sizes_xi_critic:
            atom_encoder_layers.append(nn.Linear(input_size, hidden_size))
            atom_encoder_layers.append(nn.ReLU())
            input_size = hidden_size
        atom_encoder_layers.append(nn.Linear(input_size, critic_output_size))
        self.atom_encoder = nn.Sequential(*atom_encoder_layers)

        print('v_encoder:', self.v_encoder)
        print('atom_encoder:', self.atom_encoder)
    
    def forward(self, v1, x0i):
        v1_encoded = self.v_encoder(v1)
        x0i_encoded = self.atom_encoder(x0i)

        scores = torch.matmul(v1_encoded, x0i_encoded.t())
        return scores






In [3]:
config = {
    "batch_size": 2000,
    "num_atoms": 64,
    "feature_size": 4,
    "clip": 5,
    "critic_output_size": 1,
    "downward_hidden_sizes_v_critic": [512, 512, 128],
    "downward_hidden_sizes_xi_critic": [128, 128, 64],
    "feature_hidden_sizes": [128, 128, 64],
    "decoupled_critis_hidden_sizes": [128, 128, 64],
    "downward_lr": 1e-4,    
    "feature_lr": 1e-5,
    "decoupled_critic_lr": 1e-4,
}


## Init everything

In [4]:

dataset = torch.load("data/ecog_data.pth")
trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=False)

wandb.init(project="getting-figures", config=config)


feature_network = SupervenientFeatureNetwork(
    config['num_atoms'],
    config['feature_size']
).to(device)

decoupled_critic = DecoupledCritic(
    config['feature_size']
).to(device)

downward_critics = [DownwardCritic(config['feature_size']).to(device) for _ in range(config['num_atoms'])]
downward_optims = [torch.optim.Adam(dc.parameters(), lr=config["downward_lr"]) for dc in downward_critics]

feature_optimizer = torch.optim.Adam(feature_network.parameters(), lr=config["feature_lr"])
decoupled_optimizer = torch.optim.Adam(decoupled_critic.parameters(), lr=config["decoupled_critic_lr"])



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdmcsharry[0m. Use [1m`wandb login --relogin`[0m to force relogin


TypeError: SupervenientFeatureNetwork.__init__() missing 1 required positional argument: 'hidden_sizes'

## Train things

In [5]:
feature_epochs = 10
critic_epochs = 4
total_epochs = feature_epochs + critic_epochs



for epoch in tqdm.tqdm(range(total_epochs), desc='Training'):
    for batch in trainloader:

        prepared_batch = prepare_batch(batch)

        x0 = prepared_batch[:, 0].to(device).float()
        x1 = prepared_batch[:, 1].to(device).float()

        # update decoupled critic

        v0 = feature_network(x0)
        v1 = feature_network(x1) 

        decoupled_MI = estimate_mutual_information('smile', v0, v1, decoupled_critic, clip=config['clip'])
        decoupled_loss = -decoupled_MI
        decoupled_optimizer.zero_grad()
        decoupled_loss.backward(retain_graph=True)
        decoupled_optimizer.step()

        # update each downward critic
        for i in range(config['num_atoms']):
            channel_i = x0[:, i].unsqueeze(1)
            downward_MI_i = estimate_mutual_information('smile', v1, channel_i, downward_critics[i], clip=config['clip'])
            downward_loss = -downward_MI_i
            downward_optims[i].zero_grad()
            downward_loss.backward(retain_graph=True)
            downward_optims[i].step()
            wandb.log({
                f"downward_MI_{i}": downward_MI_i
            })

        # update feature network   

        sum_downward_MI = 0

        for i in range(config['num_atoms']):
            channel_i = x0[:, i].unsqueeze(1)
            sum_downward_MI += estimate_mutual_information('smile', v1, channel_i, downward_critics[i], clip=config['clip'])

        decoupled_MI1 = estimate_mutual_information('smile', v0, v1, decoupled_critic, clip=config['clip'])

        Psi = decoupled_MI1 - sum_downward_MI
        feature_loss = -Psi

        if epoch >= feature_epochs:
            feature_optimizer.zero_grad()
            feature_loss.backward()
            feature_optimizer.step()

        wandb.log({
            "decoupled_MI": decoupled_MI1,
            "sum_downward_MI": sum_downward_MI,
            "Psi": Psi,
        })



        

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [13]:
# def test(config):
#     trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=False)


#     wandb.init(project="lots-of-downward-critics-ecog", config=config)


#     feature_network = SupervenientFeatureNetwork(
#         config['num_atoms'],
#         config['feature_size']
#     ).to(device)

#     decoupled_critic = PredSeparableCritic(
#         config['feature_size']
#     ).to(device)

#     downward_critics = [DownwardCritic(config['feature_size']).to(device) for _ in range(config['num_atoms'])]
#     downward_optims = [torch.optim.Adam(dc.parameters(), lr=config['critic_lr']) for dc in downward_critics]

#     feature_optimizer = torch.optim.Adam(feature_network.parameters(), lr=config['feature_lr'])
#     decoupled_optimizer = torch.optim.Adam(decoupled_critic.parameters(), lr=config['critic_lr'])



#     for epoch in tqdm.tqdm(range(config['epochs']), desc='Training'):
#         for batch in trainloader:
            
#             prepared_batch = prepare_batch(batch)

#             x0 = prepared_batch[:, 0].to(device).float()
#             x1 = prepared_batch[:, 1].to(device).float()


#             # update decoupled critic
            
#             v0 = feature_network(x0)
#             v1 = feature_network(x1) 

#             decoupled_MI = estimate_mutual_information('smile', v0, v1, decoupled_critic, clip=config['clip'])
#             decoupled_loss = -decoupled_MI
#             decoupled_optimizer.zero_grad()
#             decoupled_loss.backward(retain_graph=True)
#             decoupled_optimizer.step()

#             # update each downward critic

#             for i in range(config['num_atoms']):
#                 channel_i = x0[:, i].unsqueeze(1)
#                 downward_MI_i = estimate_mutual_information('smile', v1, channel_i, downward_critics[i], clip=config['clip'])
#                 downward_loss = -downward_MI_i
#                 downward_optims[i].zero_grad()
#                 downward_loss.backward(retain_graph=True)
#                 downward_optims[i].step()
#                 wandb.log({
#                     f"downward_MI_{i}": downward_MI_i
#                 })
            
#             # update feature network   

#             sum_downward_MI = 0

#             for i in range(config['num_atoms']):
#                 channel_i = x0[:, i].unsqueeze(1)
#                 sum_downward_MI += estimate_mutual_information('smile', v1, channel_i, downward_critics[i], clip=config['clip'])

#             decoupled_MI1 = estimate_mutual_information('smile', v0, v1, decoupled_critic, clip=config['clip'])

#             Psi = decoupled_MI1 - sum_downward_MI
#             feature_loss = -Psi

#             if epoch < config['epochs'] - 5:
#                 feature_optimizer.zero_grad()
#                 feature_loss.backward()
#                 feature_optimizer.step()


#             wandb.log({
#                 "decoupled_MI": decoupled_MI1,
#                 "downward_MI": sum_downward_MI,
#                 "Psi": Psi,
#             })


        
