In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import os
import numpy as np
import h5py

import matplotlib.pyplot as plt

import scipy.io
import itertools

from models.AE3D import AE3D,ae_loss,mse_loss
import datas.preprocess3d
from datas.Patch3DLoader import Patch3DLoader
#from datas.preprocess3d import TRAIN_AUGS_3D, TEST_AUGS_3D, TRAIN_NOAUGS_3D


class Loss_PCC(torch.nn.Module):
    """Pearson-correlation loss, ``1 - mean(PCC)``.

    The two leading axes of the inputs (presumably batch and channel) are kept.
    Every remaining axis (H, W for 2D input; the whole volume for 3D input) is
    used to standardise each (axis-0, axis-1) slice. The loss is one minus the
    mean product of the standardised values, averaged over everything.
    Identical (or positively affine-related) inputs give ~0. Perfectly
    anti-correlated inputs give ~2.

    Args:
        eps: added to each standard deviation to avoid division by zero; a
            constant slice is therefore treated as having zero correlation.
        torch_device: stored on the instance but not used by this module;
            kept so existing callers keep working.
    """

    def __init__(self, eps=1e-8, torch_device=None):
        super().__init__()
        self.torch_device = torch_device
        self.eps = eps

    def forward(self, img1, img2):
        """Return ``1 - mean(PCC)`` between ``img1`` and ``img2``.

        Both tensors are expected to have the same shape with at least one
        axis after the first two. TODO confirm at call sites.
        """
        # Reduce over every axis after the first two.
        dims = tuple(range(2, img1.dim()))

        # keepdim=True keeps the statistics broadcastable against the inputs,
        # so no explicit unsqueeze/repeat (and no full-size copies) is needed.
        centered1 = img1 - img1.mean(dim=dims, keepdim=True)
        centered2 = img2 - img2.mean(dim=dims, keepdim=True)

        # Population (divide-by-N) std, to match the population mean of the
        # product below. The previous unbiased std made identical inputs
        # score (N-1)/N instead of 1, so the loss could never reach 0.
        sigma1 = centered1.pow(2).mean(dim=dims, keepdim=True).sqrt()
        sigma2 = centered2.pow(2).mean(dim=dims, keepdim=True).sqrt()

        img1_ = centered1 / (sigma1 + self.eps)
        img2_ = centered2 / (sigma2 + self.eps)
        return 1 - (img1_ * img2_).mean()


In [None]:
# Seed torch's RNG before building the model so weight initialisation is
# reproducible across kernel restarts (torch.manual_seed also seeds CUDA).
SEED = 0
torch.manual_seed(SEED)
net = AE3D(input_channels=1, latent_dim=16384)

In [None]:
# Pin this process to physical GPU 6. CUDA_VISIBLE_DEVICES is only honoured if
# CUDA has not been initialised yet in this kernel; presumably nothing in the
# earlier cells touches the GPU. Verify if the model constructor ever does.
os.environ.update({"CUDA_VISIBLE_DEVICES": '6'})
torch_device = torch.device("cuda")

# Single-GPU setup. For multiple GPUs, wrap instead with
# nn.DataParallel(net.to(torch_device)).
net = net.to(torch_device)
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
pcc_loss = Loss_PCC(torch_device=torch_device, eps=1e-6)

In [None]:
# Training split. Per the original notes, the size_z variant (32 vs 12) is
# chosen inside datas.preprocess3d; both variants read the same directory.
data_path = '/data02/gkim/stem_cell_jwshin/data/230811+230502_3DH5_wider_v3_allh_onRA/'  # size_z = 32 in preprocess3d
# data_path = '/data02/gkim/stem_cell_jwshin/data/230811+230502_3DH5_wider_v3_allh_onRA/'  # size_z = 12 in preprocess3d
train_dir = data_path + "/train"
train_loader = Patch3DLoader(
    train_dir,
    4,  # second positional argument: presumably the batch size. Confirm in Patch3DLoader.
    transform=datas.preprocess3d.TRAIN_AUGS_3D,
    aug_rate=0.0,
    num_workers=4,
    shuffle=False,
    drop_last=False,
)
# Endless iterator over training batches. itertools.cycle keeps every item it
# yields during the first pass and replays those stored items afterwards.
enum_train = itertools.cycle(train_loader)


In [None]:

# Train the autoencoder. One pass over `train_loader` is one epoch.
# The loader is iterated directly rather than through `itertools.cycle`. cycle
# keeps a copy of every batch from the first pass in host memory and replays
# those cached batches (and their transforms) in every later epoch.
# Assumes Patch3DLoader, like a torch DataLoader, can be re-iterated each
# epoch. TODO confirm.
NUM_EPOCHS = 300
MSE_WEIGHT = 0.01
PCC_WEIGHT = 1.0

for ii in range(NUM_EPOCHS):
    net.train()
    print('starting to train epoch[%05d]' % ii)
    batch_current = 0
    pcc_sum = 0.0
    mse_sum = 0.0
    for (input_, target_, path) in train_loader:
        # Autoencoder: the reconstruction target is the input itself, so
        # target_ is not used (and not copied to the GPU).
        input_ = input_.to(torch_device)
        optimizer.zero_grad()
        recon, z = net(input_)
        mse = mse_loss(recon, input_)
        pcc = pcc_loss(recon, input_)
        loss = MSE_WEIGHT * mse + PCC_WEIGHT * pcc
        loss.backward()
        optimizer.step()
        # .item() gives plain floats, so logging does not keep tensors
        # (and their autograd graph) alive.
        pcc_sum += pcc.item()
        mse_sum += mse.item()
        batch_current += 1

    # Report epoch means, not just the last batch's values.
    n_batches = max(batch_current, 1)
    print(f"epoch number {ii}: pcc loss = {pcc_sum / n_batches:.4f}, "
          f"mse loss = {mse_sum / n_batches:.4f}")



In [None]:
# Held-out split used for visual inspection of reconstructions.
data_path = '/data02/gkim/stem_cell_jwshin/data/230811+230502_3DH5_wider_v3_allh_onRA/'
test_dir = data_path + "/test"
test_loader = Patch3DLoader(
    test_dir,
    1,  # second positional argument: presumably the batch size. Confirm in Patch3DLoader.
    transform=datas.preprocess3d.TEST_AUGS_3D,
    aug_rate=0.0,
    num_workers=1,
    shuffle=False,
    drop_last=False,
)
# Endless iterator so repeated `next(enum_test)` calls keep returning batches.
enum_test = itertools.cycle(test_loader)

In [None]:
# Reconstruct one test batch for visual inspection.
(input_, target_, path) = next(enum_test)
input_, target_ = input_.to(torch_device), target_.to(torch_device)

# eval() switches mode-dependent layers (e.g. dropout/batch-norm, if AE3D has
# any) to inference behaviour, so this pass does not update training
# statistics. The training cell calls net.train() again each epoch.
# no_grad() skips building an autograd graph; no optimizer step happens here,
# so the former optimizer.zero_grad() call is not needed.
net.eval()
with torch.no_grad():
    recon, z = net(input_)


In [None]:
# Which sample of the test batch, and which index along the last volume axis,
# to display. The test loader is built with 1 as its (presumed) batch size,
# so idx_batch = 0 is the only valid sample index.
idx_batch, idx_z = 0, 4

In [None]:
# Show slice idx_z (last axis) of the input volume for sample idx_batch.
# np.squeeze(axis=0) drops the leading axis of the per-sample tensor and
# raises if that axis is not size 1.
# The display range [0, 1] assumes normalised inputs. TODO confirm.
input_slice = np.squeeze(input_[idx_batch].detach().cpu().numpy(), axis=0)[:, :, idx_z]
fig, ax = plt.subplots()
im = ax.imshow(input_slice, vmin=0, vmax=1)
ax.set_title(f"input: sample {idx_batch}, z = {idx_z}")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Show the same slice of the reconstruction, on the same [0, 1] display range
# as the input plot, so the two can be compared directly.
recon_slice = np.squeeze(recon[idx_batch].detach().cpu().numpy(), axis=0)[:, :, idx_z]
fig, ax = plt.subplots()
im = ax.imshow(recon_slice, vmin=0, vmax=1)
ax.set_title(f"reconstruction: sample {idx_batch}, z = {idx_z}")
fig.colorbar(im, ax=ax)
plt.show()