In [1]:
import importlib
import os

import lovely_tensors as lt
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from einops import reduce, rearrange, repeat
from npeet.entropy_estimators import entropy, mi

import utils
importlib.reload(utils)
from utils import prepare_ecog_dataset, prepare_batch, estimate_MI_smile


# Summary-style tensor reprs from lovely_tensors for everything below.
lt.monkey_patch()

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [2]:
# print(os.getcwd())

# prepare_ecog_dataset()


In [3]:
import torch.nn as nn

class SupervenientFeatureNetwork(nn.Module):
    """MLP mapping ``num_atoms`` input values to a ``feature_size`` feature vector.

    Architecture: num_atoms -> 64 -> 16 -> feature_size, with ReLU between the
    linear layers and no activation on the output. The layers are stored in
    ``self.f`` (an ``nn.Sequential``), which fixes the state_dict keys
    (``f.0.*``, ``f.2.*``, ``f.4.*``).
    """

    def __init__(self, num_atoms, feature_size):
        super().__init__()
        layers = [
            nn.Linear(num_atoms, 64), nn.ReLU(),
            nn.Linear(64, 16), nn.ReLU(),
            nn.Linear(16, feature_size),
        ]
        self.f = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; the last axis of ``x`` must have size ``num_atoms``."""
        return self.f(x)


class PredSeparableCritic(nn.Module):
    """Bilinear separable critic scoring pairs of features (v0, v1).

    Both inputs pass through the same encoder ``v_encoder``; the score matrix
    is ``v_encoder(v0) @ W(v_encoder(v1)).T``, so for 2-D inputs entry (i, j)
    scores row i of ``v0`` against row j of ``v1``.

    Args:
        feature_size: Size of the last axis of ``v0`` and ``v1``.
        embed_dim: Width of the shared embedding and of the square bilinear
            map ``W``. Defaults to 8, the previously hard-coded value, so
            existing checkpoints and callers are unaffected.
    """

    def __init__(self, feature_size, embed_dim=8):
        super().__init__()
        self.v_encoder = nn.Sequential(
            nn.Linear(feature_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim),
        )
        # W must be embed_dim x embed_dim so both encodings share one space;
        # deriving both from a single parameter keeps them from drifting apart.
        self.W = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, v0, v1):
        """Return the (len(v0), len(v1)) score matrix for 2-D inputs."""
        v0_encoded = self.v_encoder(v0)
        v1_encoded = self.v_encoder(v1)
        v1_encoded_transformed = self.W(v1_encoded)
        return torch.matmul(v0_encoded, v1_encoded_transformed.t())
    

class MarginalSeparableCritic(nn.Module):
    """Separable critic scoring (atom input, feature) pairs by a dot product.

    ``x0i`` is encoded by ``atom_encoder`` (its last axis has size
    ``num_atoms + 1``) and ``v1`` by ``feature_encoder``; the score matrix is
    ``atom_encoder(x0i) @ feature_encoder(v1).T``.

    Args:
        feature_size: Size of the last axis of ``v1``.
        num_atoms: ``x0i``'s last axis is expected to have ``num_atoms + 1``
            entries.
        embed_dim: Common output width of both encoders (they must agree for
            the matmul). Defaults to 8, the previously hard-coded value.
    """

    def __init__(self, feature_size, num_atoms, embed_dim=8):
        super().__init__()
        self.feature_encoder = nn.Sequential(
            nn.Linear(feature_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim),
        )

        self.atom_encoder = nn.Sequential(
            nn.Linear(num_atoms + 1, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim),
        )

    def forward(self, x0i, v1):
        """Return the (len(x0i), len(v1)) score matrix for 2-D inputs."""
        feature_encoded = self.feature_encoder(v1)
        atom_encoded = self.atom_encoder(x0i)
        return torch.matmul(atom_encoded, feature_encoded.t())




In [4]:
config = dict(
    batch_size=1000,   # rows per DataLoader batch
    num_atoms=64,      # input width of the feature network
    feature_size=4,    # width of the learned feature vector
    epochs=20,         # passes over the training set
    clip=5,            # `clip` argument forwarded to the SMILE estimator
)


def _add_one_hot(X0):
    batch_len, num_features = X0.size()
    eye = torch.eye(num_features).to(device) # f * f
    eye_repeated = repeat(eye, 'f1 f2 -> b f1 f2', b=batch_len)
    X0_unsqueezed = rearrange(X0, 'b f -> b f 1')
    return torch.cat((X0_unsqueezed, eye_repeated), dim=2)



In [6]:
from smile_estimator import estimate_mutual_information
from CLUB_estimation import CLUB, CLUBSample
dataset = torch.load("data/ecog_data.pth")

trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=False)

wandb.init(project="learning_features", config=config)

# The feature network is optimised only during the first FEATURE_TRAIN_EPOCHS
# epochs; afterwards it is frozen and only the two critics keep training.
FEATURE_TRAIN_EPOCHS = 15

feature_network = SupervenientFeatureNetwork(
    config['num_atoms'],
    config['feature_size']
).to(device)

decoupled_critic = PredSeparableCritic(
    config['feature_size']
).to(device)

downward_critic = MarginalSeparableCritic(
    config['feature_size'],
    config['num_atoms']
).to(device)

feature_optimizer = torch.optim.Adam(feature_network.parameters(), lr=1e-4)
decoupled_optimizer = torch.optim.Adam(decoupled_critic.parameters(), lr=1e-3)
downward_optimizer = torch.optim.Adam(downward_critic.parameters(), lr=1e-3)


def _smile_mi(x, y, critic):
    """SMILE estimate of I(x; y) using `critic` and the configured clip."""
    return estimate_mutual_information('smile', x, y, critic, clip=config['clip'])


def _downward_mi(x0_one_hot, v1):
    """Sum over atoms i of the SMILE estimate of I(x0_i; v1) (downward critic)."""
    return sum(
        _smile_mi(x0_one_hot[:, i], v1, downward_critic)
        for i in range(config['num_atoms'])
    )


for epoch in range(config['epochs']):
    train_features = epoch < FEATURE_TRAIN_EPOCHS

    for batch in trainloader:
        prepared_batch = prepare_batch(batch)

        x0 = prepared_batch[:, 0].to(device).float()
        x1 = prepared_batch[:, 1].to(device).float()

        x0_one_hot = _add_one_hot(x0)

        v0 = feature_network(x0)
        v1 = feature_network(x1)

        # --- critic updates --------------------------------------------------
        # Critics maximise their MI estimates with the features held fixed.
        # Detaching keeps these losses out of the feature network's graph, so
        # no retain_graph is needed and no gradient is spent on feature
        # parameters; their gradients come solely from the Psi loss below.
        v0_fixed, v1_fixed = v0.detach(), v1.detach()

        decoupled_MI = _smile_mi(v0_fixed, v1_fixed, decoupled_critic)
        decoupled_optimizer.zero_grad()
        (-decoupled_MI).backward()
        decoupled_optimizer.step()

        downward_MI = _downward_mi(x0_one_hot, v1_fixed)
        downward_optimizer.zero_grad()
        (-downward_MI).backward()
        downward_optimizer.step()

        # --- feature update --------------------------------------------------
        # Psi = I(V_t; V_t+1) - sum_i I(X_t^i; V_t+1), re-estimated with the
        # just-updated critics. Once the features are frozen Psi is only
        # logged, so no autograd graph is built for it.
        with torch.set_grad_enabled(train_features):
            downward_MI1 = _downward_mi(x0_one_hot, v1)
            decoupled_MI1 = _smile_mi(v0, v1, decoupled_critic)
            Psi = decoupled_MI1 - downward_MI1

        if train_features:
            feature_optimizer.zero_grad()
            (-Psi).backward()
            feature_optimizer.step()

        # Log plain floats so wandb never holds references to autograd graphs.
        wandb.log({
            "decoupled_MI": decoupled_MI.item(),
            "downward_MI": downward_MI.item(),
            "Psi": Psi.item(),
        })

wandb.finish()


        



0,1
Psi,▅▆▅▄█▆▁▆▆▆▆▆▆▆▅▅▇▄▆▆▃▃▄▆▆▅▆▆▆▆▆▆▆▆▆▆▆▅▅▆
decoupled_MI,▁▁▁▂▁▁▆▁▁▁▁▁▁▂▁▁▁▂▂▁▂▁▂▂▃▂▂▂▂▂▂▃▄▄▆▇▇▅█▅
downward_MI,▅▅▅▃▁▅█▅▅▅▅▅▅▅▆▅▄▆▅▅▇▇▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▆▆▅

0,1
Psi,-0.05038
decoupled_MI,0.20323
downward_MI,0.28709


KeyboardInterrupt: 

In [8]:
from kraskov_mi import pyMIestimator
import numpy as np


dataset = torch.load("data/ecog_data.pth")
print(dataset)

trainloader = torch.utils.data.DataLoader(dataset, batch_size=5000, shuffle=False)

# Sweep the neighbour count k and compare two Kraskov-style estimators of
# I(X_t; X_t+1) -- pyMIestimator vs npeet's `mi` -- on every batch.
# NOTE(review): in the saved output pyMIestimator returned 8.094108812976406
# for every batch at k=1 while npeet varied per batch; worth checking the
# input format pyMIestimator expects.
kraskov_results = []
for k in range(1, 100, 10):
    for batch_idx, batch in enumerate(trainloader):
        prepared_batch = prepare_batch(batch)
        x0 = prepared_batch[:, 0].cpu().float()
        x1 = prepared_batch[:, 1].cpu().float()

        mi_py = pyMIestimator(x0, x1, k=k, base=np.e)
        mi_npeet = mi(x0, x1, k=k, base=np.e)
        # Keep the numbers for later analysis instead of only printing them.
        kraskov_results.append(
            {"k": k, "batch": batch_idx, "pyMIestimator": mi_py, "npeet": mi_npeet}
        )
        # One labelled line per estimate rather than three bare numbers.
        print(f"k={k} batch={batch_idx} pyMIestimator={mi_py} npeet={mi_npeet}")


tensor[348610, 64] f64 n=22311040 (0.2Gb) x∈[-28.132, 152.364] μ=-5.778e-19 σ=1.000
1
8.094108812976406
3.8493352185942014
1
8.094108812976406
3.791306389623551
1
8.094108812976406
3.7161849739248227
1
8.094108812976406
4.110254153112709
1
8.094108812976406
3.7519440413267198
1
8.094108812976406
3.827022485377798
1
8.094108812976406
3.7923111962897353
1
8.094108812976406
3.6534180672774683
1
8.094108812976406
3.3767617474776896
1
8.094108812976406
3.6058965205532854
1
8.094108812976406
3.6716202906270903
1
8.094108812976406
3.338604675875632
1
8.094108812976406
3.6928727287680223
1
8.094108812976406
3.7626600226455382
1
8.094108812976406
3.7917848785509074
1
8.094108812976406
3.1651060959154718
1
8.094108812976406
3.670732018314303
1
8.094108812976406
3.589695615818462
1
8.094108812976406
3.9656264190586406
1
8.094108812976406
3.7944694141293
1
8.094108812976406
3.5995884812630603
1
8.094108812976406
4.065993929589507
1
8.094108812976406
3.591091124229494
1
8.094108812976406
3.51974508

In [22]:
dataset = torch.load("data/ecog_data.pth")
trainloader = torch.utils.data.DataLoader(dataset, batch_size=5000, shuffle=False)

# npeet (k-NN) estimates on the FIRST batch only: raw atoms X vs the learned
# features V, plus the downward term sum_i I(X_t^i; V_t+1).
KSG_K = 10

with torch.no_grad():
    prepared_batch = prepare_batch(next(iter(trainloader)))
    x0 = prepared_batch[:, 0].cpu().float()
    x1 = prepared_batch[:, 1].cpu().float()

    v0 = feature_network(x0.to(device)).cpu()
    v1 = feature_network(x1.to(device)).cpu()

# Convert to numpy once instead of on every estimator call.
x0_np, x1_np = x0.numpy(), x1.numpy()
v0_np, v1_np = v0.numpy(), v1.numpy()

mi_x = mi(x0_np, x1_np, k=KSG_K, base=np.e)
mi_v = mi(v0_np, v1_np, k=KSG_K, base=np.e)
# x0_np[:, [i]] keeps a (batch, 1) column, like x0[:, i].unsqueeze(1).
mi_downward = sum(
    mi(x0_np[:, [i]], v1_np, k=KSG_K, base=np.e)
    for i in range(config['num_atoms'])
)

print(f"I(X_t; X_t+1)         = {mi_x}")
print(f"I(V_t; V_t+1)         = {mi_v}")
print(f"sum_i I(X_t^i; V_t+1) = {mi_downward}")
print(f"Psi                   = {mi_v - mi_downward}")

2.711603863009925
0.9449761234181775
2.581075344555364
