In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import os
import numpy as np
import h5py

from models.VAE3D import VAE3D,vae_loss,mse_loss,kld_loss

# Expose only physical GPUs 1 and 2 to this process. This has to be set before
# the first CUDA call so the device selection below honours it.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
torch_device = torch.device("cuda")

# Build the single-channel 3-D VAE, move it to the GPU and wrap it in
# DataParallel; Adam then optimises the wrapped model's parameters.
# NOTE(review): latent_dim=8096 is not a power of two -- confirm that 8192 was
# not the intended value.
vae = VAE3D(input_channels=1, latent_dim=8096).to(torch_device)
vae = nn.DataParallel(vae)
optimizer = optim.Adam(vae.parameters(), lr=5e-4)


In [13]:

class Loss_PCC(torch.nn.Module):
    """Pearson-correlation loss, ``1 - PCC``, between two image batches.

    Each input is standardised independently per (batch, channel) entry over
    all remaining axes (axis 2 onwards), so for 3-D volumes the statistics
    are pooled over the whole volume. The returned scalar is one minus the
    mean of the element-wise product of the two standardised inputs: about 0
    for perfectly correlated inputs, 1 for uncorrelated inputs and 2 for
    perfectly anti-correlated inputs.

    Args:
        eps: Small constant added to the standard deviation to avoid a
            division by zero on constant images.
        torch_device: Stored on the instance for backward compatibility; it
            is not used in the computation, which runs on the inputs' device.
    """

    def __init__(self, eps = 1e-8, torch_device=None):
        super().__init__()
        self.torch_device = torch_device
        self.eps = eps

    def _standardize(self, img, dims):
        """Return ``img`` with zero mean and unit population std over ``dims``."""
        # keepdim=True leaves size-1 axes that broadcast against ``img``, so no
        # full-size copies of the statistics are materialised.
        mu = img.mean(dim=dims, keepdim=True)
        # Population std (unbiased=False) matches the divide-by-N ``.mean()``
        # in ``forward``; the Bessel-corrected default would scale the result
        # by (N-1)/N and yields NaN when there is a single spatial element.
        sigma = img.std(dim=dims, unbiased=False, keepdim=True)
        return (img - mu) / (sigma + self.eps)

    def forward(self, img1, img2):
        """Compute ``1 - PCC`` averaged over batch, channel and spatial axes.

        Args:
            img1: Tensor whose first two axes are treated as batch and
                channel and whose remaining axes are reduced over.
            img2: Tensor with the same layout as ``img1``.

        Returns:
            A zero-dimensional tensor holding the loss.
        """
        spatial_dims = tuple(range(2, img1.dim()))
        correlation = self._standardize(img1, spatial_dims) * self._standardize(img2, spatial_dims)
        return 1 - correlation.mean()
# Loss instance with eps=1e-6 instead of the constructor default of 1e-8.
pcc_loss = Loss_PCC(eps=1e-6, torch_device=torch_device)


In [14]:
# Sample usage: load one cropped volume, normalise it and move it to the GPU.
path_input = '/data02/gkim/stem_cell_jwshin/data/230502_TCF/00_train/H9_untreated/230425.142330.H9_untreated.003.Group1.A1.S003/230425.142330.H9_untreated.003.Group1.A1.S003.TCF'

# The context manager closes the HDF5 handle as soon as the crop has been read;
# slicing the dataset copies just that sub-volume into a NumPy array.
with h5py.File(path_input, 'r') as input_file: # 220801 for -v7.3 mat files
    input_data = input_file['/Data/3D/000000'][17:49,872:1256,872:1256]

# Convert to float32 *before* scaling. float16 has an 11-bit significand, so
# values in [8192, 16384) are rounded to multiples of 8; presumably the raw
# values are ~13300-14000 (RI x 10000, given the /10000 and the 1.33-1.4
# window below) -- TODO confirm against the file.
input_data = input_data.astype(np.float32)/10000

# Map the window [cap_min, cap_max] to [0, 1] and clip the upper end.
# NOTE(review): values below cap_min become negative and are not clipped --
# confirm this is intended.
cap_min = 1.33
cap_max = 1.4
input_data = (input_data - cap_min)/(cap_max-cap_min)
input_data[input_data>1.0] = 1.0

# Reverse the axis order, then prepend two singleton axes (batch and channel
# for the network input).
input_data = np.transpose(input_data, (2,1,0))
input_data = input_data[np.newaxis, np.newaxis]

# The transpose yields a strided view; make it contiguous before wrapping it
# in a float32 tensor.
input_data = torch.from_numpy(np.ascontiguousarray(input_data))
print(input_data.shape)

input_data = input_data.to(torch_device)
#input_data = torch.rand(3, 1, 384, 384, 32).to(torch_device)  # Replace with your actual data


torch.Size([1, 1, 384, 384, 32])


In [15]:
# Optimise the VAE on the single volume `input_data` for 10000 iterations.
for ii in range(10000):
    optimizer.zero_grad()

    # Forward pass: reconstruction plus the latent mean and log-variance.
    recon, mu, logvar = vae(input_data)

    # The objective is the unweighted sum of the three terms below; the
    # combined `vae_loss` helper imported from models.VAE3D is not used here.
    mse = mse_loss(recon, input_data)
    kld = kld_loss(mu, logvar)
    pcc = pcc_loss(recon, input_data)
    loss = mse + kld + pcc

    # Report the individual terms on every 100th iteration.
    if ii % 100 == 0:
        str_print = f"epoch number {ii}: MSE= {mse}, PCC = {pcc}, KLD = {kld}"
        print(str_print)

    loss.backward()
    optimizer.step()

epoch number 0: MSE= 0.11495178192853928, PCC = 0.9996468424797058, KLD = 0.39124250411987305
epoch number 100: MSE= 0.10603643208742142, PCC = 0.5546607971191406, KLD = 0.0033164024353027344
epoch number 200: MSE= 0.09663715958595276, PCC = 1.0172041654586792, KLD = 0.014543116092681885
epoch number 300: MSE= 0.09080241620540619, PCC = 0.6466313004493713, KLD = 0.015829771757125854
epoch number 400: MSE= 0.07726911455392838, PCC = 0.44301706552505493, KLD = 0.02111157774925232
epoch number 500: MSE= 0.06503959745168686, PCC = 0.6870583295822144, KLD = 0.003907591104507446
epoch number 600: MSE= 0.04498858377337456, PCC = 0.3366681933403015, KLD = 0.011818140745162964
epoch number 700: MSE= 0.024553507566452026, PCC = 0.37849462032318115, KLD = 0.00888088345527649
epoch number 800: MSE= 0.012108737602829933, PCC = 0.25339239835739136, KLD = 0.0632685124874115
epoch number 900: MSE= 0.009580385871231556, PCC = 0.5711879730224609, KLD = 0.033891141414642334
epoch number 1000: MSE= 0.0148

In [16]:
# Bring the most recent reconstruction and the network input back to host
# memory as NumPy arrays and drop their singleton axes.
recon_save = np.squeeze(recon.detach().cpu().numpy())
input_save = np.squeeze(input_data.cpu().numpy())

# Open the destination file in write mode ('w' creates it, truncating any
# existing file). The handle stays open here and is closed in a later cell.
path_output = '/data02/gkim/stem_cell_jwshin/230425.142330.H9_untreated.003.Group1.A1.S003_VAEout_v7.h5'
output_file = h5py.File(path_output, 'w')

In [17]:
# Display the shape of the squeezed input array that will be written to disk.
input_save.shape

(384, 384, 32)

In [18]:


# Write the reconstruction and the network input as two datasets. The
# try/finally guarantees the file handle is released even if a write fails
# (for example because a dataset with that name already exists).
try:
    for dataset_name, volume in (('output', recon_save), ('input', input_save)):
        # The dataset shape is taken from `data`; maxshape of None on every
        # axis keeps the dataset resizable, whatever the array's rank.
        output_file.create_dataset(dataset_name,
                                   maxshape = (None,) * volume.ndim,
                                   data = volume)
finally:
    output_file.close()