In [0]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

In [2]:
pip install memory_profiler



In [0]:
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
P = "/content/gdrive/My Drive/Colab Notebooks/Modal_VAE/DataP.dat"
PL = "/content/gdrive/My Drive/Colab Notebooks/Modal_VAE/DataPL.dat"

In [0]:
class VAE(nn.Module):
    def __init__(self, channels, latent_size):
        super(VAE, self).__init__()
        

        self.ls = latent_size
        self.channels = channels

        #
        self.conv3d_1 = nn.Conv3d(channels, 24, (4,4,4), stride=3, padding=1)
        self.conv3d_2 = nn.Conv3d(24, 26, (4,4,4), stride=3, padding=1)
        self.conv3d_3 = nn.Conv3d(26, 28, (4,4,4), stride=2, padding=1)
        self.conv3d_4 = nn.Conv3d(28, 32, (4,4,4), stride=2, padding=1)
        #self.pool_1 = nn.MaxPool3d((4,4,4), stride = 3)

        #
        self.convTrans3d_1 = nn.ConvTranspose3d(32, 28, (4,4,4), stride=2, padding=1)
        self.convTrans3d_2 = nn.ConvTranspose3d(28, 26, (4,4,4), stride=2, padding=1, output_padding=1)
        self.convTrans3d_3 = nn.ConvTranspose3d(26, 24, (4,4,4), stride=2, padding=1)
        self.convTrans3d_4 = nn.ConvTranspose3d(24, channels, (4,4,4), stride=2, padding=1)
        
        #
        
        self.bn_1 = nn.BatchNorm3d(32)
        self.bn_2 = nn.BatchNorm3d(64)
        self.bn_3 = nn.BatchNorm3d(1)
        
        #
        self.dropOut = nn.Dropout3d(0.5)
        self.mu = nn.Linear(256, latent_size)
        self.logvar = nn.Linear(256, latent_size)
        self.p = nn.Linear(latent_size, 32*6*6*6)

        #
        self.softmin = nn.Softmin(1)
    
    def encode(self, in_data):

        h = self.conv3d_1(in_data.float())
        h = F.leaky_relu(h)
        h = self.dropOut(h)

        h = self.conv3d_2(h)
        #h = self.bn_1(h)
        h = F.leaky_relu(h)
        h = self.dropOut(h)

        h = self.conv3d_3(h)
        h = F.leaky_relu(h)
        h = self.dropOut(h)

        h = self.conv3d_4(h)


        h = h.reshape(h.shape[0], -1)
        mu = self.mu(h)
        log_var = self.logvar(h)
          
        #sampling and reparameterization for backprop
        z = torch.randn(mu.shape, device = dev)
        z = mu + torch.exp(log_var) * z
        return z, mu, log_var
    
    def decode(self, h_data):
        z = self.p(h_data)
        z = z.reshape(z.shape[0], 32, 6, 6, 6)

        z = self.convTrans3d_1(z)
        #z = self.bn_1(z)
        z = F.leaky_relu(z)
        z = self.dropOut(z)

        z = self.convTrans3d_2(z)
        #z = self.bn_1(z)
        z = F.leaky_relu(z)
        z = self.dropOut(z)

        z = self.convTrans3d_3(z)
        z = F.leaky_relu(z)
        z = self.dropOut(z)  
        z = self.convTrans3d_4(z)
        z = self.softmin(z)
        return z
                
    def forward(self, in_data):
        z, mu, log_var = self.encode(in_data)
        z = self.decode(z)
        return z, mu, log_var

In [0]:
model_vae = VAE(21, 64).to(dev)

In [0]:
#Kullback–Leibler divergence for N(mu, var^2) and N(0, I)
def kl_divergence(mu, log_var):
    return 0.5*( torch.sum( log_var.exp(), 1) + torch.sum(mu*mu, 1) - mu.shape[1] - torch.sum(log_var, 1) )

In [0]:
#(Additional loss + Kullback–Leibler divergence) / batch
def vae_loss(mu, log_var, pred, target):
    bs = pred.shape[0]
    kl_loss = kl_divergence(mu, log_var).mean()
    ad_loss = F.l1_loss(pred, target, reduction = 'sum') / bs
    return (kl_loss + ad_loss)

In [0]:
lr = 0.001
epochs = 100

In [0]:
loss_func = vae_loss
opt = torch.optim.Adam(model_vae.parameters(), lr)

In [0]:
def fit1(epochs, model, loss_func, opt, paths):
    part_size = 5
    total_size = 25
    valid_size = 5
    bs = 5
    i = 0 #for offset
    for epoch in range(epochs):
      while (i <= (total_size - valid_size) * 100**3):
        table_data = np.fromfile(paths[0],  dtype=np.int8, count=part_size * 100**3, offset = i)
        table_data = table_data.reshape(part_size,100,100,100)
        table_target = np.fromfile(paths[1],  dtype=np.int8, count=part_size * 100**3, offset = i)
        table_target = table_target.reshape(part_size,100,100,100)

        i += part_size * 100**3

        train_data = torch.tensor(table_data, device=dev)
        train_data = F.one_hot(train_data.long(), model.channels).to(dev)
        train_data = torch.transpose(train_data, 1,4)

        train_target = torch.tensor(table_target, dtype=torch.long, device=dev)
        train_target = F.one_hot(train_target.long(), model.channels).to(dev)
        train_target = torch.transpose(train_target, 1,4)

        train_ds = TensorDataset(train_data, train_target)
        del train_data, train_target
        train_dl = DataLoader(train_ds, bs)

  
        model = model.train()
        for data, target in train_dl:
            pred, mu, log_var = model(data)
            loss = loss_func(mu, log_var, pred, target)
            opt.zero_grad()
            loss.backward()
            opt.step()

      i = (total_size - valid_size) * 100**3
      while (i < total_size * 100**3):
        table_data = np.fromfile(paths[0],  dtype=np.int8, count=part_size * 100**3, offset = i)
        table_data = table_data.reshape(part_size,100,100,100)
        table_target = np.fromfile(paths[1],  dtype=np.int8, count=part_size * 100**3, offset = i)
        table_target = table_target.reshape(part_size,100,100,100)

        i += part_size * 100**3

        valid_data = torch.tensor(table_data, device=dev)
        valid_data = F.one_hot(valid_data.long(), model.channels).to(dev)
        valid_data = torch.transpose(valid_data, 1,4)

        valid_target = torch.tensor(table_target, dtype=torch.long, device=dev)
        valid_target = F.one_hot(valid_target.long(), model.channels).to(dev)
        valid_target = torch.transpose(valid_target, 1,4)

        valid_ds = TensorDataset(valid_data, valid_target)
        del valid_data, valid_target      
        valid_dl = DataLoader(valid_ds, bs)
      
        model = model.eval()
        with torch.no_grad():
            average_valid_loss = torch.zeros(1)
            #ad_loss = torch.zeros(1)
            num_batch = 0
            for data, target in valid_dl:
                pred, mu, log_var = model(data)
                average_valid_loss += loss_func(mu, log_var, pred, target)
                num_batch += 1
                #ad_loss += F.cross_entropy(pred, target, reduction='sum')
            average_valid_loss /= num_batch
            #ad_loss /= (bs * num_batch)
            #if (epoch % 100 == 1):
            print("Epoch: " + str(epoch) + "  Cross entropy + KL-divergence: " + str(average_valid_loss)) #+ "  Cross entropy: " + str(ad_loss))

      #del train_dl, valid_dl, train_ds, valid_ds, train_data, valid_data, train_target, valid_target

In [13]:
fit1(epochs, model_vae, loss_func, opt, (P, PL))

Epoch: 0  Cross entropy + KL-divergence: tensor([1899151.5000])
Epoch: 1  Cross entropy + KL-divergence: tensor([1899306.8750])
Epoch: 2  Cross entropy + KL-divergence: tensor([1899281.5000])
Epoch: 3  Cross entropy + KL-divergence: tensor([1899398.])
Epoch: 4  Cross entropy + KL-divergence: tensor([1899378.8750])
Epoch: 5  Cross entropy + KL-divergence: tensor([1899061.8750])
Epoch: 6  Cross entropy + KL-divergence: tensor([1899258.6250])
Epoch: 7  Cross entropy + KL-divergence: tensor([1899199.5000])
Epoch: 8  Cross entropy + KL-divergence: tensor([1899036.])
Epoch: 9  Cross entropy + KL-divergence: tensor([1898974.6250])
Epoch: 10  Cross entropy + KL-divergence: tensor([1899261.])
Epoch: 11  Cross entropy + KL-divergence: tensor([1899375.])
Epoch: 12  Cross entropy + KL-divergence: tensor([1899481.6250])
Epoch: 13  Cross entropy + KL-divergence: tensor([1899436.5000])
Epoch: 14  Cross entropy + KL-divergence: tensor([1899484.6250])
Epoch: 15  Cross entropy + KL-divergence: tensor([1