In [1]:
import torch
import lovely_tensors as lt
from einops import reduce, rearrange, repeat
from npeet.entropy_estimators import entropy, mi
import matplotlib.pyplot as plt
import wandb
import utils
import importlib
importlib.reload(utils)
import os
from utils import prepare_ecog_dataset, prepare_batch, estimate_MI_smile


# Apply lovely_tensors' patches (presumably swaps in its compact tensor
# summaries for the default repr -- confirm against the library docs).
lt.monkey_patch()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(device)

cuda


In [3]:
import torch.nn as nn

class SupervenientFeatureNetwork(nn.Module):
    """MLP that maps a vector of `num_atoms` inputs to `feature_size` outputs.

    The stack is Linear -> ReLU repeated over hidden widths 128, 256, 256,
    finishing with a plain Linear layer (no activation on the output).
    """

    def __init__(self, num_atoms, feature_size):
        """
        Args:
            num_atoms: size of the last dimension of the input.
            feature_size: size of the last dimension of the output.
        """
        super().__init__()
        widths = (num_atoms, 128, 256, 256, feature_size)

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())

        # Drop the trailing ReLU so the final layer stays linear.
        self.f = nn.Sequential(*stack[:-1])

    def forward(self, x):
        """Return the features produced by the MLP for input `x`."""
        features = self.f(x)
        return features


class PredSeparableCritic(nn.Module):
    """Separable critic scoring every pairing of rows from two feature batches.

    Both inputs go through the same encoder (`v_encoder`, output width 8);
    the second encoding is additionally passed through the bias-free linear
    map `W` before the pairwise inner products are taken.
    """

    def __init__(self, feature_size):
        """
        Args:
            feature_size: size of the last dimension of both inputs.
        """
        super().__init__()
        self.v_encoder = nn.Sequential(
            nn.Linear(feature_size, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 8),
        )
        self.W = nn.Linear(8, 8, bias=False)

    def forward(self, v0, v1):
        """Return the score matrix for `v0` against `v1`.

        For 2-D inputs, entry [i, j] is the inner product of the encoding of
        row i of `v0` with the `W`-transformed encoding of row j of `v1`.
        """
        z0 = self.v_encoder(v0)
        z1 = self.W(self.v_encoder(v1))
        return z0 @ z1.t()
    

class DownwardCritic(nn.Module):
    """Separable critic scoring feature rows against scalar atom values.

    Features are embedded by `v_encoder` and atom values (last dimension of
    size 1) by `atom_encoder`; both embeddings have width 8 and the scores
    are their pairwise inner products.
    """

    @staticmethod
    def _mlp(*sizes):
        """Build Linear layers over consecutive `sizes`, with ReLU in between."""
        modules = []
        for idx in range(len(sizes) - 1):
            if idx > 0:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(sizes[idx], sizes[idx + 1]))
        return nn.Sequential(*modules)

    def __init__(self, feature_size):
        """
        Args:
            feature_size: size of the last dimension of the feature input.
        """
        super().__init__()
        self.v_encoder = self._mlp(feature_size, 512, 512, 128, 8)
        self.atom_encoder = self._mlp(1, 128, 64, 8)

    def forward(self, v1, x0i):
        """Return the score matrix for features `v1` against atom values `x0i`.

        For 2-D inputs, entry [i, j] is the inner product of the embedding of
        row i of `v1` with the embedding of row j of `x0i`.
        """
        feature_code = self.v_encoder(v1)
        atom_code = self.atom_encoder(x0i)
        return feature_code @ atom_code.t()



In [4]:
# Run settings shared by the initialisation and training cells:
#   batch_size   - batch size handed to the DataLoader
#   num_atoms    - input width of the feature network; one downward critic per atom
#   feature_size - output width of the feature network
#   clip         - `clip` argument forwarded to estimate_mutual_information
config = dict(
    batch_size=1000,
    num_atoms=64,
    feature_size=4,
    clip=5,
)


## Initialise the dataset, networks, and optimizers

In [5]:
from smile_estimator import estimate_mutual_information
import tqdm
# Load the serialized dataset from disk.
# NOTE(review): torch.load unpickles its input -- only point it at trusted files.
dataset = torch.load("data/ecog_data.pth")

# Start the W&B run that the training cell logs into.
wandb.init(config=config, project="lots-of-downward-critics-ecog")

# Feature network: num_atoms inputs -> feature_size outputs.
feature_network = SupervenientFeatureNetwork(
    num_atoms=config['num_atoms'],
    feature_size=config['feature_size'],
)
feature_network = feature_network.to(device)

# Critic scoring pairs of feature vectors (v0, v1).
decoupled_critic = PredSeparableCritic(feature_size=config['feature_size'])
decoupled_critic = decoupled_critic.to(device)

# One downward critic, each with its own Adam optimizer, per atom.
downward_critics = [
    DownwardCritic(feature_size=config['feature_size']).to(device)
    for _atom in range(config['num_atoms'])
]
downward_optims = [
    torch.optim.Adam(params=critic.parameters(), lr=1e-5)
    for critic in downward_critics
]

feature_optimizer = torch.optim.Adam(params=feature_network.parameters(), lr=1e-5)
decoupled_optimizer = torch.optim.Adam(params=decoupled_critic.parameters(), lr=1e-4)





Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdmcsharry[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Train the critics and the feature network

In [6]:
# Epoch schedule: the feature network is trained (together with the critics)
# for the first `feature_epochs` epochs; it is then frozen while the critics
# alone keep training for `critic_epochs` more epochs, so the final Psi
# readings come from critics fitted to a fixed feature network.
feature_epochs = 10
critic_epochs = 4
total_epochs = feature_epochs + critic_epochs

# shuffle=False keeps the samples in dataset order.
# NOTE(review): presumably prepare_batch relies on that order to build the
# (x0, x1) pairs -- confirm before enabling shuffling.
trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=False)

for epoch in tqdm.tqdm(range(total_epochs), desc='Training'):
    # Fixed: this gate used to be `epoch >= feature_epochs`, which trained the
    # feature network only during the critic-only tail of the schedule.
    update_features = epoch < feature_epochs

    for batch in trainloader:

        prepared_batch = prepare_batch(batch)

        x0 = prepared_batch[:, 0].to(device).float()
        x1 = prepared_batch[:, 1].to(device).float()

        v0 = feature_network(x0)
        v1 = feature_network(x1)

        # The critic updates only need gradients w.r.t. the critic parameters,
        # so feed them detached features: no backward pass through the feature
        # network and no need for retain_graph.
        v0_fixed = v0.detach()
        v1_fixed = v1.detach()

        # update decoupled critic

        decoupled_MI = estimate_mutual_information('smile', v0_fixed, v1_fixed, decoupled_critic, clip=config['clip'])
        decoupled_loss = -decoupled_MI
        decoupled_optimizer.zero_grad()
        decoupled_loss.backward()
        decoupled_optimizer.step()

        # Everything measured on this batch is logged in ONE wandb.log call so
        # all metrics share the same W&B step.
        metrics = {}

        # update each downward critic
        for i in range(config['num_atoms']):
            channel_i = x0[:, i].unsqueeze(1)
            downward_MI_i = estimate_mutual_information('smile', v1_fixed, channel_i, downward_critics[i], clip=config['clip'])
            downward_loss = -downward_MI_i
            downward_optims[i].zero_grad()
            downward_loss.backward()
            downward_optims[i].step()
            metrics[f"downward_MI_{i}"] = downward_MI_i.item()

        # update feature network (estimates are recomputed with the freshly
        # updated critics and the non-detached features)

        sum_downward_MI = 0

        for i in range(config['num_atoms']):
            channel_i = x0[:, i].unsqueeze(1)
            sum_downward_MI += estimate_mutual_information('smile', v1, channel_i, downward_critics[i], clip=config['clip'])

        decoupled_MI1 = estimate_mutual_information('smile', v0, v1, decoupled_critic, clip=config['clip'])

        Psi = decoupled_MI1 - sum_downward_MI
        feature_loss = -Psi

        if update_features:
            # This backward also deposits gradients on the critics; they are
            # cleared by the critics' zero_grad() calls on the next batch.
            feature_optimizer.zero_grad()
            feature_loss.backward()
            feature_optimizer.step()

        # Log plain floats rather than live autograd tensors.
        metrics.update({
            "decoupled_MI": decoupled_MI1.item(),
            "sum_downward_MI": sum_downward_MI.item(),
            "Psi": Psi.item(),
        })
        wandb.log(metrics)



        

Training: 100%|██████████| 14/14 [14:02<00:00, 60.15s/it]


In [13]:
def test(config):
    """Train a fresh feature network and its critics for one configuration.

    Starts a new W&B run, builds new networks and optimizers, trains them on
    the notebook-level `dataset`, logs the MI estimates once per batch, and
    closes the W&B run when done (also if training raises).

    Args:
        config: dict with the keys
            batch_size   - DataLoader batch size
            num_atoms    - feature-network input width / number of downward critics
            feature_size - feature-network output width
            epochs       - total number of training epochs
            clip         - `clip` argument for estimate_mutual_information
            critic_lr    - Adam learning rate for all critics
            feature_lr   - Adam learning rate for the feature network
            critic_only_epochs (optional, default 5) - number of final epochs
                in which only the critics are updated

    Returns:
        None. Results are reported through W&B only.

    Relies on names defined elsewhere in the notebook: `dataset`, `device`,
    `prepare_batch`, `estimate_mutual_information`, `tqdm`, `wandb`, and the
    three network classes.
    """
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=False)

    wandb.init(project="lots-of-downward-critics-ecog", config=config)

    try:
        feature_network = SupervenientFeatureNetwork(
            config['num_atoms'],
            config['feature_size']
        ).to(device)

        decoupled_critic = PredSeparableCritic(
            config['feature_size']
        ).to(device)

        downward_critics = [DownwardCritic(config['feature_size']).to(device) for _ in range(config['num_atoms'])]
        downward_optims = [torch.optim.Adam(dc.parameters(), lr=config['critic_lr']) for dc in downward_critics]

        feature_optimizer = torch.optim.Adam(feature_network.parameters(), lr=config['feature_lr'])
        decoupled_optimizer = torch.optim.Adam(decoupled_critic.parameters(), lr=config['critic_lr'])

        # The feature network is frozen for the last `critic_only_epochs`
        # epochs so the critics can keep fitting a fixed feature network.
        critic_only_epochs = config.get('critic_only_epochs', 5)
        last_feature_epoch = config['epochs'] - critic_only_epochs

        for epoch in tqdm.tqdm(range(config['epochs']), desc='Training'):
            update_features = epoch < last_feature_epoch

            for batch in trainloader:

                prepared_batch = prepare_batch(batch)

                x0 = prepared_batch[:, 0].to(device).float()
                x1 = prepared_batch[:, 1].to(device).float()

                v0 = feature_network(x0)
                v1 = feature_network(x1)

                # Critic updates only need critic gradients: use detached
                # features so no backward pass runs through the feature
                # network and retain_graph is unnecessary.
                v0_fixed = v0.detach()
                v1_fixed = v1.detach()

                # update decoupled critic

                decoupled_MI = estimate_mutual_information('smile', v0_fixed, v1_fixed, decoupled_critic, clip=config['clip'])
                decoupled_loss = -decoupled_MI
                decoupled_optimizer.zero_grad()
                decoupled_loss.backward()
                decoupled_optimizer.step()

                # All metrics of a batch go into one wandb.log call so they
                # share a single W&B step.
                metrics = {}

                # update each downward critic

                for i in range(config['num_atoms']):
                    channel_i = x0[:, i].unsqueeze(1)
                    downward_MI_i = estimate_mutual_information('smile', v1_fixed, channel_i, downward_critics[i], clip=config['clip'])
                    downward_loss = -downward_MI_i
                    downward_optims[i].zero_grad()
                    downward_loss.backward()
                    downward_optims[i].step()
                    metrics[f"downward_MI_{i}"] = downward_MI_i.item()

                # update feature network (estimates recomputed with the
                # updated critics and the non-detached features)

                sum_downward_MI = 0

                for i in range(config['num_atoms']):
                    channel_i = x0[:, i].unsqueeze(1)
                    sum_downward_MI += estimate_mutual_information('smile', v1, channel_i, downward_critics[i], clip=config['clip'])

                decoupled_MI1 = estimate_mutual_information('smile', v0, v1, decoupled_critic, clip=config['clip'])

                Psi = decoupled_MI1 - sum_downward_MI
                feature_loss = -Psi

                if update_features:
                    # Gradients this leaves on the critics are cleared by
                    # their zero_grad() calls on the next batch.
                    feature_optimizer.zero_grad()
                    feature_loss.backward()
                    feature_optimizer.step()

                # Log plain floats rather than live autograd tensors.
                metrics.update({
                    "decoupled_MI": decoupled_MI1.item(),
                    "downward_MI": sum_downward_MI.item(),
                    "Psi": Psi.item(),
                })
                wandb.log(metrics)
    finally:
        # Close this run explicitly so consecutive calls never share a run.
        wandb.finish()


        


In [14]:
# Settings shared by every run of the sweep; each entry of `sweep_overrides`
# replaces only the values that differ for that run.
base_sweep_config = {
    "batch_size": 1000,
    "num_atoms": 64,
    "feature_size": 4,
    "epochs": 20,
    "clip": 5,
    "critic_lr": 1e-4,
    "feature_lr": 1e-5,
}

sweep_overrides = [
    # feature-size sweep at the base learning rates
    {"feature_size": 1},
    {"feature_size": 2},
    {"feature_size": 4},
    {"feature_size": 8},
    {"feature_size": 16},
    # 10x larger learning rates
    {"feature_size": 4, "critic_lr": 1e-3, "feature_lr": 1e-4},
    {"feature_size": 8, "critic_lr": 1e-3, "feature_lr": 1e-4},
]

# Overriding an existing key keeps its position, so every dict has the same
# key order as the base config.
configs = [{**base_sweep_config, **overrides} for overrides in sweep_overrides]

# The loop variable is deliberately NOT called `config`: that would rebind the
# notebook-level `config` used by the earlier initialisation/training cells.
for sweep_config in configs:
    test(sweep_config)




0,1
Psi,▅▅▅▅▅▅▅▅▅▅▅▅▅▅▁▅▅▅▅▅▅▅▅▅▆▆▆▆▇▆▆█▆▆▆▆▆▆▆▆
decoupled_MI,▁▁▁▁▁▁▁▁▁▁▁▁▂▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁█▁▁▁▁▁▂▁▂
downward_MI,▄▄▄▄▄▄▄▄▄▄▄▄▄▄█▄▄▄▄▄▄▄▄▄▃▃▃▃▂▃▃▁▃▃▃▃▃▃▃▃

0,1
Psi,0.13671
decoupled_MI,0.0008
downward_MI,-0.13591


Training: 100%|██████████| 20/20 [26:19<00:00, 78.96s/it]




0,1
Psi,██████████████████████████████▇▇▇▆▂▅▇▅▅▁
decoupled_MI,▁▁▁▂▂▃▂▄▃▃▂▃▅▄▃▄▃▄▃▄▄▄▅▃▄▄▄▄▅▃▆▄▄▂▁▅▄▅▆█
downward_MI,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▆▄▂▄▄█
downward_MI_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▁▁▁▃█▃▂▄▄█
downward_MI_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▇▃▂▄▃█
downward_MI_10,▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂▁▁▁▁▁▁▁▆█▃▂▄▃█
downward_MI_11,▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄█▂▂▃▃▆
downward_MI_12,▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▆▂▂▃▃█
downward_MI_13,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇█▂▂▃▃█
downward_MI_14,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▇▂▁▃▃█

0,1
Psi,-6.0941
decoupled_MI,0.26766
downward_MI,6.36176
downward_MI_0,0.06868
downward_MI_1,0.06699
downward_MI_10,0.08066
downward_MI_11,0.07315
downward_MI_12,0.06568
downward_MI_13,0.07
downward_MI_14,0.0659


Training: 100%|██████████| 20/20 [26:22<00:00, 79.11s/it]




0,1
Psi,█▇▇███████████████████████▇██▇██▇▁▃▇▇█▇▂
decoupled_MI,▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▄▅▅▁▅▅▂▅▄▃▄█▅▅▅▅▅
downward_MI,▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂█▆▂▂▁▂▇
downward_MI_0,▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▅▁▁▁▁█
downward_MI_1,▁▂▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂█▅▂▃▁▃▇
downward_MI_10,▁▂▂▁▁▁▁▁▁▁▁▁▂▁▁▂▁▁▂▁▁▂▁▂▁▂▁▂▁▂▁▂▃▅▄▂▃▁▃█
downward_MI_11,▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂█▆▂▂▁▂▅
downward_MI_12,▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▁▂▂▆▅▂▃▂▃█
downward_MI_13,▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▁▂▃█▆▂▃▂▃▇
downward_MI_14,▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▃▂▃▃▇▅▃▄▂▄█

0,1
Psi,-0.20353
decoupled_MI,0.26136
downward_MI,0.46489
downward_MI_0,0.00249
downward_MI_1,0.04153
downward_MI_10,0.04181
downward_MI_11,0.02698
downward_MI_12,0.02978
downward_MI_13,0.02809
downward_MI_14,0.05396


Training: 100%|██████████| 20/20 [26:21<00:00, 79.07s/it]




0,1
Psi,█▇▇█████████████████████████████▇▆▅▆▇▅▄▁
decoupled_MI,▅▅▅▅▅▅▅▅▅▅▅▅▆▆▆█▇▇▇▇▇▇▇▇▇▇██▇▆▇▇▆▁▂▇▇██▅
downward_MI,▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▄▃▂▄▅█
downward_MI_0,▁▂▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▄▃▅▆█
downward_MI_1,▂▂▄▂▂▂▂▁▁▁▁▁▂▂▁▂▂▂▂▂▁▂▂▁▂▂▁▁▂▂▂▂▂▃▄▅▄▇█▇
downward_MI_10,▁▂▃▂▁▂▂▁▁▁▁▂▂▂▁▂▂▂▂▂▁▂▂▁▂▂▁▂▁▂▂▂▂▄▆▃▃▇▇█
downward_MI_11,▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▅█
downward_MI_12,▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▃▂▂▂▃▂▂▃▂▂▃▁▂▂▂▂▃▃▂▁▅▅▇█▇
downward_MI_13,▂▂▃▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▃▁▂▂▁▂▂▁▁▂▂▃▃▃▄▃▅▄▆▇█
downward_MI_14,▄▅▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▄▄▃▄▄▃▄▄▄▄▄▄▂▁▆▆▇█▇

0,1
Psi,-9.67424
decoupled_MI,0.19662
downward_MI,9.87086
downward_MI_0,0.18054
downward_MI_1,0.17981
downward_MI_10,0.18182
downward_MI_11,0.18222
downward_MI_12,0.17968
downward_MI_13,0.1914
downward_MI_14,0.18497


Training: 100%|██████████| 20/20 [24:26<00:00, 73.34s/it]




0,1
Psi,█▇▇██████████████████████████████▂▁▇▇▅▅▂
decoupled_MI,▄▄▄▄▄▄▄▅▄▆▅▆▆▅▆▇▇▇▆▇▆▇▇▆▇▆▆▇█▆█▇▇▂▁▇█▇▇█
downward_MI,▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇█▂▂▄▄█
downward_MI_0,▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▇▁▁▂▂▂
downward_MI_1,▁▂▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁█▇▁▁▂▂▄
downward_MI_10,▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▄▄▄▄▃▄▄▄▄▄▄▄▄▄▄▄▆█▅▄▆▇▁
downward_MI_11,▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▅▂▁▃▃█
downward_MI_12,▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▅▁▁▂▂█
downward_MI_13,▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▅▁▁▂▂█
downward_MI_14,▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██▂▂▄▅▆

0,1
Psi,-3.91715
decoupled_MI,0.17174
downward_MI,4.08889
downward_MI_0,0.03997
downward_MI_1,0.0143
downward_MI_10,0.09741
downward_MI_11,0.12134
downward_MI_12,0.06343
downward_MI_13,0.05069
downward_MI_14,0.1781


Training: 100%|██████████| 20/20 [20:37<00:00, 61.90s/it]




0,1
Psi,█▇▇█████████████████████████████▇▂▁▆▆▄▃▅
decoupled_MI,▁▁▁▁▁▁▂▄▃▅▅▅▆▅▆▆▆▇▆▇▇▆▇▆▅▇▇▇▇▅▇▆▆▁▄▇█▇▇█
downward_MI,▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▇█▃▃▅▆▄
downward_MI_0,▃▄▄▃▃▃▃▃▃▄▄▃▃▃▃▃▄▄▄▄▄▃▄▄▄▄▄▄▃▄▄▄▄▆█▅▄▆▆▁
downward_MI_1,▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆█▂▂▄▄█
downward_MI_10,▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▇█▂▃▄▅▆
downward_MI_11,▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆█▂▂▄▄▄
downward_MI_12,▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▄██▄▄▆▆▁
downward_MI_13,▁▂▂▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▂▂▂▂▂▂▃▁▁▅▅▆▇█
downward_MI_14,▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆█▆▇▇██▁

0,1
Psi,-10.15739
decoupled_MI,0.1956
downward_MI,10.35299
downward_MI_0,0.22229
downward_MI_1,0.1756
downward_MI_10,0.20224
downward_MI_11,0.20757
downward_MI_12,0.21206
downward_MI_13,0.23667
downward_MI_14,0.21526


Training: 100%|██████████| 20/20 [20:37<00:00, 61.89s/it]




0,1
Psi,▆█▃▅▆▆▆▆▆▆▆▇▅▇▅▆▅▅▆▆▅▆▆▆▆▆▆▇▆▆▅▆▆▂▂▄▅▄▄▁
decoupled_MI,▁▂▅▃▃▃▂▃▂▂▃▃▅▃▅▄▃▅▃▄▇▄▃▃▄▄▅▅▅▆▆▃▄█▇▅▄▅▅█
downward_MI,▃▁▆▄▃▃▃▃▃▃▃▂▄▂▄▃▄▄▃▃▄▃▃▃▃▃▃▃▃▃▄▃▃▇▇▅▄▅▅█
downward_MI_0,▅▁▆▅▅▅▅▅▅▅▆▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▇▆▆▅▆▆█
downward_MI_1,▂▂▄▃▂▃▂▂▂▃▃▁▄▂▃▂▃▃▂▃▃▃▂▂▂▂▃▂▂▃▃▃▃▆▄▄▃▄▄█
downward_MI_10,▃▁▄▄▃▃▃▃▃▃▃▂▅▃▄▄▄▃▃▃▃▃▄▃▃▃▃▃▃▃▄▃▃▆█▄▃▅▅█
downward_MI_11,▂▁▅▃▂▃▂▂▂▃▃▂▂▂▃▃▃▃▂▂▄▂▂▂▂▂▃▂▂▂▃▃▃▆▅▃▃▄▄█
downward_MI_12,▃▁▇▄▃▄▃▃▃▃▃▃▄▂▃▃▄▃▃▃▅▃▃▃▃▃▃▃▃▃▄▃▃▅█▄▃▅▅█
downward_MI_13,▄▁▆▄▄▄▄▄▄▄▄▃▄▄▄▄▄▄▄▄▅▄▄▄▄▄▄▄▄▄▄▄▄▇▅▄▄▅▅█
downward_MI_14,▂▂▅▃▂▃▂▂▂▃▃▁▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▂▂▅▅▃▃▄▄█

0,1
Psi,-13.83851
decoupled_MI,0.3724
downward_MI,14.21091
downward_MI_0,0.14782
downward_MI_1,0.19881
downward_MI_10,0.16701
downward_MI_11,0.199
downward_MI_12,0.15471
downward_MI_13,0.17229
downward_MI_14,0.18529


Training: 100%|██████████| 20/20 [20:31<00:00, 61.57s/it]
