In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import os
import numpy as np
import h5py
from models.VAE3D import VAE3D,vae_loss,mse_loss,kld_loss
# torch.backends.cudnn.benchmark = True

# project 3D patch data loader (not referenced in the visible cells)
from datas.Patch3DLoader import Patch3DLoader

import utils
from datas.preprocess3d import TRAIN_AUGS_3D, TEST_AUGS_3D, TRAIN_AUGS_25D, TEST_AUGS_25D

from Logger import Logger

In [None]:

class Loss_PCC(torch.nn.Module):
    """Pearson-correlation loss, ``1 - mean(PCC)``.

    For inputs shaped ``(N, C, *spatial)`` (2D ``(N, C, H, W)``, 3D
    ``(N, C, D0, D1, D2)``, ...), the correlation coefficient is computed
    separately for every (sample, channel) pair, jointly over ALL trailing
    dimensions, and then averaged over samples and channels.

    Identical / positively linearly related inputs give ~0, uncorrelated
    inputs ~1 and perfectly anti-correlated inputs ~2.

    Args:
        eps: added to each standard deviation so constant inputs do not
            divide by zero.
        torch_device: stored on the instance for API compatibility; not used
            by ``forward``.
    """

    def __init__(self, eps = 1e-8, torch_device=None):
        super(Loss_PCC, self).__init__()
        self.torch_device = torch_device  # kept for callers that pass it; unused here
        self.eps = eps

    def forward(self, img1, img2):
        """Return ``1 - mean PCC`` between ``img1`` and ``img2``.

        Args:
            img1, img2: tensors of identical shape with at least 3 dims; dims 0
                and 1 are kept separate, every later dim is reduced.

        Returns:
            A 0-dim tensor.
        """
        reduce_dims = tuple(range(2, img1.dim()))

        # keepdim=True lets the statistics broadcast against the images, so no
        # unsqueeze/repeat copies are needed.
        mu1 = img1.mean(dim=reduce_dims, keepdim=True)
        mu2 = img2.mean(dim=reduce_dims, keepdim=True)
        # Population std (unbiased=False) matches the 1/N of the final mean;
        # the default sample std would scale every PCC by (N-1)/N and return
        # NaN when there is only one spatial element.
        sigma1 = img1.std(dim=reduce_dims, unbiased=False, keepdim=True)
        sigma2 = img2.std(dim=reduce_dims, unbiased=False, keepdim=True)

        img1_ = (img1 - mu1) / (sigma1 + self.eps)
        img2_ = (img2 - mu2) / (sigma2 + self.eps)

        PCC = img1_ * img2_
        return 1 - PCC.mean()
# `torch_device` is not created in any visible cell (it is also used when the
# sample input is moved to the device). Provide a default so the notebook
# survives Restart & Run All, without overriding a device chosen earlier.
if 'torch_device' not in globals():
    torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pcc_loss = Loss_PCC(eps=1e-6, torch_device=torch_device)


In [None]:
# Sample usage: load one TCF volume, crop it, rescale it to [0, 1] and move it to
# `torch_device` as a (1, 1, 384, 384, 32) tensor.
path_input = '/data02/gkim/stem_cell_jwshin/data/230502_TCF/00_train/H9_untreated/230425.142330.H9_untreated.003.Group1.A1.S003/230425.142330.H9_untreated.003.Group1.A1.S003.TCF'

# Crop window in the stored axis order: 32 planes x 384 x 384.
crop_axis0 = slice(17, 49)
crop_axis12 = slice(872, 1256)

# The `with` block closes the HDF5 handle as soon as the crop is in memory.
# Convert to float32 (not float16) before scaling: float16 steps are ~0.001 near
# 1.33-1.40, which would leave only ~70 distinct levels inside the cap window.
with h5py.File(path_input, 'r') as input_file:  # 220801 for -v7.3 mat files
    input_data = input_file['/Data/3D/000000'][crop_axis0, crop_axis12, crop_axis12]
    input_data = input_data.astype(np.float32) / 10000  # presumably stored as RI*10000 — confirm

# Linear rescale of the [cap_min, cap_max] window to [0, 1]; values above cap_max
# are clipped to 1. NOTE(review): values below cap_min are NOT clipped and become
# negative — confirm this is intended.
cap_min = 1.33
cap_max = 1.4
input_data = (input_data - cap_min) / (cap_max - cap_min)
input_data = np.minimum(input_data, 1.0)

# Reverse the axis order (the 32-plane axis becomes last), then add batch and
# channel axes -> (1, 1, 384, 384, 32).
input_data = np.transpose(input_data, (2, 1, 0))[np.newaxis, np.newaxis]
input_data = torch.as_tensor(np.ascontiguousarray(input_data), dtype=torch.float32)
print(input_data.shape)

# `torch_device` must already be defined; it is not created in this cell.
input_data = input_data.to(torch_device)
# Synthetic smoke-test input instead of the file:
# input_data = torch.rand(3, 1, 384, 384, 32).to(torch_device)


In [None]:
# Optimise the VAE on the single sample volume with an unweighted MSE + KLD + PCC loss.
# NOTE(review): `vae` and `optimizer` are not created in any visible cell; they must be
# defined (presumably a VAE3D instance and an optim.* optimizer over its parameters)
# before this cell runs.
N_ITERS = 10000
LOG_EVERY = 100

vae.train()  # ensure training-mode behaviour even if eval() was called earlier
for ii in range(N_ITERS):
    optimizer.zero_grad()
    recon, mu, logvar = vae(input_data)
    mse = mse_loss(recon, input_data)
    kld = kld_loss(mu, logvar)
    pcc = pcc_loss(recon, input_data)
    loss = mse + kld + pcc
    if ii % LOG_EVERY == 0:
        # float() prints plain numbers instead of tensor reprs with device/grad_fn noise
        print(f"epoch number {ii}: MSE= {float(mse):.6g}, PCC = {float(pcc):.6g}, KLD = {float(kld):.6g}")
    loss.backward()
    optimizer.step()

In [None]:
# Bring the last reconstruction and the input back to host memory as ndarrays and
# drop every size-1 axis (for the single-sample run: the batch and channel axes).
recon_save = np.squeeze(recon.detach().cpu().numpy())
input_save = np.squeeze(input_data.cpu().numpy())

path_output = '/data02/gkim/stem_cell_jwshin/230425.142330.H9_untreated.003.Group1.A1.S003_VAEout_v7.h5'
# Mode 'w' truncates any existing file at this path immediately, before anything is written.
output_file = h5py.File(path_output, 'w')

In [None]:
# Sanity check: shape of the array about to be written as the 'input' dataset.
np.shape(input_save)

In [None]:


# Store the reconstruction and the input side by side, and always release the file
# handle, even if a write fails.
try:
    # maxshape needs one entry per dimension; derive it from the data instead of
    # hard-coding three, so batched (4-D after squeeze) arrays also work.
    # None entries make the datasets resizable (chunked) along every axis.
    output_file.create_dataset('output',
                               shape=recon_save.shape,
                               maxshape=(None,) * recon_save.ndim,
                               data=recon_save)
    output_file.create_dataset('input',
                               shape=input_save.shape,
                               maxshape=(None,) * input_save.ndim,
                               data=input_save)
finally:
    output_file.close()