In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [2]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/MI_dim_reduction

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/MI_dim_reduction


In [3]:
pwd

'/content/drive/MyDrive/MI_dim_reduction'

In [4]:
%pylab inline
import numpy as np
import pandas as pd
import os


import torch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import transforms
import numpy as np
import scipy.ndimage as ndi
import matplotlib.gridspec as gridspec

import math
import os
from sklearn.preprocessing import MinMaxScaler
from torch.distributions import MultivariateNormal
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


'cuda'

In [5]:
!pip install pynvml



In [6]:
#pip install matplotlib #scikit-learn

In [7]:
import pynvml

pynvml.nvmlInit()

handle = pynvml.nvmlDeviceGetHandleByIndex(0)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)

print(f"Total memory: {info.total/1024**2:.2f} MB")
print(f"Used memory : {info.used/1024**2:.2f} MB")
print(f"Free memory : {info.free/1024**2:.2f} MB")

Total memory: 15360.00 MB
Used memory : 260.94 MB
Free memory : 15099.06 MB


In [8]:
#pip install pytorch_lightning==2.1

In [9]:
! pip install --quiet "ipython[notebook]>=8.0.0, <8.12.0" "numpy==1.26.4" "torch==2.2.0" "setuptools==67.7.2" "torchmetrics>=0.7, <0.12" "torchvision" "pytorch-lightning>=1.4, <2.0.0" "lightning>=2.0.0rc0"

In [10]:
from models.mine import T, T_real, Mine, Mine_real


Device: cuda


In [11]:
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import itertools
import torch
import torch.nn as nn
import numpy as np
import abc
from torch.nn import functional as F

#from mine.datasets import load_dataloader
from models.mine import T, T_real, Mine, Mine_real
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint



# Encoder and Decoder models
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
            nn.Tanh()  # Added Tanh activation to the last layer
        )

    def forward(self, x):
        return self.fc(x)

class Decoder(nn.Module):
    def __init__(self, latent_dim, input_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim)
        )

    def forward(self, z):
        return self.fc(z)

# Mutual Information Neural Estimator (MINE)
class MINE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(MINE, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim + latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x, z):
        return self.fc(torch.cat([x, z], dim=1))

# Training Function

class train_autoencoder_with_mi(pl.LightningModule):
    def __init__(self,  dataset, latent_dim, input_dim,mi_estimator_real, batchsize=16, lr=1e-3, beta=1):
        super().__init__()
        #self.device = device
        self.dataset = dataset
        self.batchsize=batchsize

        self.latent_dim = latent_dim
        self.save_hyperparameters()
        self.automatic_optimization = False

        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim)
        self.mi_estimator_real=mi_estimator_real

        self.beta = beta
        self.lr = lr

        # Loss function
        self.reconstruction_loss_fn = nn.MSELoss()
        self.loss = nn.BCELoss()

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(itertools.chain(self.encoder.parameters(
        ),  self.encoder.parameters(), self.mi_estimator_real.parameters()), lr=self.lr)

        #opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)

        return [opt_g]

    def training_step(self, batch, batch_idx):



        # Active one bellow line for gaussians, and spiral
        #x_real, _ = batch

        # Active one bellow line Only for Coil
        x_real = batch
        x_real=torch.tensor(x_real, dtype=torch.float32).to(device)

        #if batch_idx % (3750) == 0:
            #if self.trainer.current_epoch % 1 == 0:
                #print(f"Epoch {self.trainer.current_epoch}")
                #self.plot_img(x_real)
                #plt.show()

        #~~~~ Produce Posterior
        real_np=x_real.cpu().numpy().reshape(x_real.shape[0],x_real.shape[1])
        real_np=torch.tensor(real_np, dtype=torch.float32).to(device)


        if self.on_gpu:
            real_np = real_np#.cuda()#.to(self.device)

        optimizer_ed= self.optimizers()  #optimizer_ed, optimizer_d = self.optimizers()




        # Encoder - Decoder
        self.toggle_optimizer(optimizer_ed,0)

        # Encoder
        self.encoded = self.encoder(real_np)
        self.decoded=self.decoder(self.encoded)

        # MI
        j=torch.ones(( real_np.shape[0],1))
        mi_loss_real=self.mi_estimator_real(real_np.reshape(real_np.shape[0],
                                                                    real_np.shape[1]), self.encoded,j)

        loss_ED=self.reconstruction_loss_fn(real_np,self.decoded)
        loss = loss_ED + self.beta * mi_loss_real

        self.log("Loss", loss, prog_bar=True)
        self.manual_backward(loss)
        optimizer_ed.step()
        optimizer_ed.zero_grad()
        self.untoggle_optimizer(optimizer_ed)

    def train_dataloader(self):
        return load_dataloader(self.dataset, self.batchsize)  # batch_size is 16 in Covid-19 data


def load_dataloader(dataset, batch_size):

    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True)

    return train_loader

In [12]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
warnings.filterwarnings("ignore", category=UserWarning)

#%load_ext autoreload
#%autoreload 2

device = 'cuda' if torch.cuda.is_available() else 'cpu'

batchsize=16
input_dim = 30
latent_dim=10
dataset = np.random.rand(100, input_dim)


epochs = 5000
lr = 1e-4
beta = 1.0

statistics_network = T(input_dim, latent_dim).to(device)
statistics_network_real = T_real(input_dim, latent_dim).to(device)
mi_estimator_real = Mine_real(statistics_network_real,loss='mine_biased').to(device) # 'mine_biased'

model = train_autoencoder_with_mi(dataset, latent_dim, input_dim, mi_estimator_real,batchsize, lr, beta=1)
#############################
# Load saved model state dict
#############################
#~ Active this section only when to start training
#~ from latest saved checkpoint
#ckpt_path = save_path
#state_dict = torch.load(ckpt_path)
#model.load_state_dict(state_dict['state_dict'])
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Define the checkpoint callback T

checkpoint_callback = ModelCheckpoint(
    #dirpath='checkpoints',
    #filename='epoch={epoch}-step={step}.ckpt',
    save_top_k=-1,  # Save all checkpoints
    every_n_epochs=10)  # Save every 250 epochs
    #every_n_train_steps=25)  # Save every 25 steps



trainer = Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=epochs,
    callbacks=[checkpoint_callback])

trainer.fit(model)
model=model.cuda()
#model.plot_img()



INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                   | Type      | Params
-----------------------------------------------------
0 | encoder                | Encoder   | 12.9 K
1 | decoder                | Decoder   | 12.9 K
2 | mi_estimator_real      | Mine_real | 337 K 
3 | reconstruction_loss_fn | MSELoss   | 0     
4 | loss                   | BCELoss   | 0     
-----------------------------------------------------
363 K     Trainable params
0         Non-trainable params
363 K     Total params
1.453     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5000` reached.


# Providing latent variables Z in the code space.

In [13]:
dataset = torch.tensor(dataset, dtype=torch.float32).to(device)


# Note that Z represents the information bottleneck.
z=model.encoder(dataset)
z.shape




torch.Size([100, 10])

In [14]:
z

tensor([[-0.1847,  0.7545, -0.0255,  0.8533,  0.7990, -0.2958,  0.2425, -0.4816,
          0.2867, -0.7718],
        [-0.0182, -0.7187, -0.9162,  0.6112,  0.0914,  0.1578, -0.3172,  0.0565,
         -0.3359,  0.4568],
        [-0.2603,  0.7151,  0.7798, -0.7766, -0.5457,  0.4668,  0.8220, -0.4376,
         -0.0019, -0.0507],
        [-0.5305, -0.7485, -0.7766, -0.6160, -0.8040, -0.0783,  0.1439, -0.4136,
         -0.6561, -0.3489],
        [-0.4983,  0.0095,  0.2551,  0.5752,  0.8008, -0.1823, -0.8882,  0.8171,
         -0.4879,  0.1178],
        [ 0.5132, -0.1282, -0.7285,  0.1815,  0.3787,  0.5785, -0.5970, -0.7137,
         -0.6664, -0.2037],
        [-0.7706, -0.9260, -0.1852, -0.2273,  0.4207, -0.3290, -0.1416,  0.4638,
          0.4336, -0.4978],
        [-0.5233,  0.4132,  0.4026, -0.7450, -0.0940,  0.8234,  0.4945, -0.6663,
          0.8060,  0.0425],
        [-0.3922, -0.1671, -0.8019,  0.0794,  0.0905,  0.9276, -0.2214,  0.0660,
         -0.8610,  0.3744],
        [ 0.8292, -