In [None]:
import os
import random
from math import ceil
import numpy as np
import torch
import torch.fft as fft
import torch.nn as nn
from pytorch_msssim import ssim
from scunet import SCUNet
from utils import *
from torch import FloatTensor as FT
from torch.autograd import Variable as V


# Dataset locations. Defaults are the original machine-specific paths; set the
# DEPO_FOLDER / SIMU_FOLDER environment variables to run elsewhere.
depoFolder = os.environ.get("DEPO_FOLDER", "/home/ty/training_and_validation_sets/depoFiles")
simuFolder = os.environ.get("SIMU_FOLDER", "/home/ty/training_and_validation_sets/simuFiles")
save_dir = "datasets"  # directory the preprocessing cell writes chunks into and training reads from
batch_size = 8
apix = 1  # NOTE(review): not used by the visible cells; presumably Å/pixel — confirm
num_epochs = 30

# Seed every RNG this notebook touches (Python, NumPy, torch) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# Data preprocessing: pair each deposited map with a simulated map by sorted
# filename order, skip oversized pairs, and split the rest into chunks saved
# under save_dir. n_chunks counts the chunks written (used for progress below).
MAX_MAP_BYTES = 512 * 1024 * 1024  # skip any pair where either file exceeds 512 MiB

depoList = sorted(get_all_files(depoFolder))
simuList = sorted(get_all_files(simuFolder))

# Pairing is purely positional, so both folders must hold the same number of
# files; zip() would otherwise silently truncate and mis-pair maps.
if len(depoList) != len(simuList):
    raise ValueError(
        f"depo/simu file count mismatch: {len(depoList)} vs {len(simuList)}"
    )

n_chunks = 0
for depoFile, simuFile in zip(depoList, simuList):
    if os.path.getsize(depoFile) > MAX_MAP_BYTES or os.path.getsize(simuFile) > MAX_MAP_BYTES:
        continue
    n_chunks += split_and_save_tensor(depoFile, simuFile, save_dir)

In [None]:

# SCUNet denoiser for single-channel 3-D chunks. The training loop feeds it
# tensors reshaped to (batch, 1, 48, 48, 48), consistent with
# input_resolution=48 below.
model = SCUNet(
    in_nc=1,
    config=[2] * 7,
    dim=32,
    drop_path_rate=0.0,
    input_resolution=48,
    head_dim=16,
    window_size=3,
)
# Release unused cached CUDA memory, then move the model onto the GPU.
torch.cuda.empty_cache()
model = model.cuda()

# Adam over all model parameters; the training loop refers to it as `trainer`.
trainer = torch.optim.Adam(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
    amsgrad=False,
)



In [None]:
import numpy as np

def align(depoMap, simuMap, verbose=True):
    """Translate ``depoMap`` onto ``simuMap`` by FFT cross-correlation.

    Both arrays are zero-padded to their common per-axis maximum shape, with the
    original data anchored at index 0 on every axis. The circular
    cross-correlation ``corr[k] = sum_n depo[n + k] * simu[n]`` is computed in
    Fourier space. The padded depo array is then rolled by ``-argmax(corr)`` so
    that its best integer-shift match lines up with the padded simu array.

    Works for arrays of any (equal) dimensionality; the notebook uses 3-D maps.

    Args:
        depoMap: array to be moved.
        simuMap: reference array; it is only zero-padded, never shifted.
        verbose: if True (the default, matching earlier behaviour), print the
            detected per-axis shift.

    Returns:
        ``(depo_aligned, simu_padded)``, both of the common padded shape, each
        keeping the dtype of its input.

    Raises:
        ValueError: if the two arrays have different numbers of dimensions.

    Note:
        The correlation is circular: content rolled past one edge wraps to the
        opposite edge, and shifts are reported in ``[0, size)`` per axis.
    """
    if depoMap.ndim != simuMap.ndim:
        raise ValueError(
            f"align expects arrays of equal ndim, got {depoMap.ndim} and {simuMap.ndim}"
        )
    padded_shape = tuple(max(a, b) for a, b in zip(depoMap.shape, simuMap.shape))

    # Zero-pad both inputs, placing the data in the low-index corner.
    depo_padded = np.zeros(padded_shape, dtype=depoMap.dtype)
    depo_padded[tuple(slice(0, n) for n in depoMap.shape)] = depoMap
    simu_padded = np.zeros(padded_shape, dtype=simuMap.dtype)
    simu_padded[tuple(slice(0, n) for n in simuMap.shape)] = simuMap

    # Cross-correlation theorem: IFFT(F(depo) * conj(F(simu))) is the circular
    # cross-correlation; only its real part is meaningful for real inputs.
    corr = np.fft.ifftn(np.fft.fftn(depo_padded) * np.conj(np.fft.fftn(simu_padded))).real

    shift = np.unravel_index(np.argmax(corr), corr.shape)
    if verbose:
        print(*shift)

    # Rolling by -k gives rolled[n] = depo[n + k], which best matches simu[n].
    axes = tuple(range(depo_padded.ndim))
    depo_aligned = np.roll(depo_padded, shift=tuple(-int(s) for s in shift), axis=axes)
    return depo_aligned, simu_padded


def normalized_cross_correlation(vol1, vol2):
    """Zero-mean normalized cross-correlation of two arrays with equal element counts.

    Both inputs are flattened and centred on their own mean. The result is
    sum(d1 * d2) / sqrt(sum(d1**2) * sum(d2**2)), which lies in [-1, 1].
    1 means identical up to a positive affine rescale. A constant input makes
    the denominator zero.
    """
    flat1, flat2 = vol1.flatten(), vol2.flatten()
    dev1 = flat1 - np.mean(flat1)
    dev2 = flat2 - np.mean(flat2)
    covariance_sum = np.sum(dev1 * dev2)
    norm_product = np.sqrt(np.sum(dev1 ** 2) * np.sum(dev2 ** 2))
    return covariance_sum / norm_product

# Example usage: shift a random volume by a known (dx, dy, dz) offset and check
# that align() recovers the reference exactly.
if __name__ == "__main__":
    # Reference volume and a copy displaced along all three axes.
    np.random.seed(42)
    size = (10, 10, 10)
    vol_ref = np.random.rand(*size).astype(np.float32)
    dx, dy, dz = 5, -3, 2  # simulated displacement
    vol_float = np.roll(vol_ref, shift=(dx, dy, dz), axis=(0, 1, 2))
    print(vol_float.shape)
    print(normalized_cross_correlation(vol_ref, vol_float))

    vol_registered, vol_ref = align(vol_float, vol_ref)
    print(normalized_cross_correlation(vol_ref, vol_registered))
    # One summary boolean instead of an element-wise comparison of 1000 voxels.
    print(np.array_equal(vol_ref, vol_registered))

In [None]:
import numpy as np

# Example 1: roll a 1-D array two positions to the right; elements wrap around.
arr1 = np.arange(1, 6)
shifted_arr1 = np.roll(arr1, shift=2)
print("一维数组滚动后的结果：", shifted_arr1)

# Example 2: roll a 3x3 array along a single axis.
arr2 = np.arange(1, 10).reshape(3, 3)
# axis=0: rows move down one step; the last row wraps to the top.
shifted_arr2_axis0 = np.roll(arr2, shift=1, axis=0)
print("二维数组沿行方向滚动后的结果：", shifted_arr2_axis0, sep="\n")
# axis=1: columns move right one step; the last column wraps to the front.
shifted_arr2_axis1 = np.roll(arr2, shift=1, axis=1)
print("二维数组沿列方向滚动后的结果：", shifted_arr2_axis1, sep="\n")

In [None]:
# print(depoList[1])
# depoMap = mrc2map(depoList[1], 1.0)
# # 降采样
# depoMap = depoMap[::20, ::20, ::20]
# print(f"shape: {depoMap.shape}")
# x, y, z = np.where(depoMap)  # 筛选高值区域
# print(x.shape)
# values = depoMap[x, y, z]
# print(f"shape of values: {values.shape}")

# # fig = go.Figure(data=go.Scatter3d(
# #     x=x, y=y, z=z, mode='markers',
# #     marker=dict(size=5, color=values, colorscale='Viridis', opacity=0.8)
# # ))
# # fig.update_layout(scene=dict(aspectmode='cube'))
# # fig.show()


# # plt.hist(values, bins=30, density=True, color='skyblue', edgecolor='black', alpha=0.7)
# # plt.xlabel('interval')
# # plt.ylabel('likelihood')
# # plt.title('frequency histogram')
# # plt.show()


In [None]:
# print(simuList[1])
# simuPadded = mrc2map(simuList[1], 1.0)
# # 降采样
# simuPadded = simuPadded[::20, ::20, ::20]
# print(f"shape: {simuPadded.shape}")
# x, y, z = np.where(simuPadded)  # 筛选高值区域
# values = simuPadded[x, y, z]
# print(f"shape of values: {values.shape}")
# # fig = go.Figure(data=go.Scatter3d(
# #     x=x, y=y, z=z, mode='markers',
# #     marker=dict(size=5, color=values, colorscale='Viridis', opacity=0.8)
# # ))
# # fig.update_layout(scene=dict(aspectmode='cube'))
# # fig.show()


# # plt.hist(values, bins=30, density=True, color='skyblue', edgecolor='black', alpha=0.7)
# # plt.xlabel('interval')
# # plt.ylabel('likelihood')
# # plt.title('frequency histogram')
# # plt.show()


# 输入numpy张量, 返回torch张量
#def align(depoMap, simuMap, device):
#    depoMap = torch.from_numpy(depoMap)
#    simuMap = torch.from_numpy(simuMap)
#    xyz_max = [max(depoMap.shape[0], simuMap.shape[0]), 
#               max(depoMap.shape[1], simuMap.shape[1]),
#               max(depoMap.shape[2], simuMap.shape[2])]
#    depoPadded = torch.zeros((xyz_max)).to(device)
#    simuMap = simuMap.to(device)
#    depoPadded[xyz_max[0] - depoMap.shape[0] : xyz_max[0],
#               xyz_max[1] - depoMap.shape[1] : xyz_max[1],
#               xyz_max[2] - depoMap.shape[2] : xyz_max[2],
#               ] = depoMap
#    corr = torch.nn.functional.conv3d(depoPadded.unsqueeze(0).unsqueeze(0), simuMap.unsqueeze(0).unsqueeze(0))
#    max_idx = torch.argmax(corr)
#    dx, dy, dz = np.unravel_index(max_idx.cpu().numpy(), corr.shape[2:])
#    cropped = depoPadded[dx : dx + simuMap.shape[0],
#                         dy : dy + simuMap.shape[1],
#                         dz : dz + simuMap.shape[2]]
#    print(f"cropped.shape: {cropped.shape}")
#    print(f"simuMap.shape: {simuMap.shape}")
#    return cropped.cpu().numpy()


In [None]:
# Count the depo/simu pairs in which either file exceeds 512 MiB, i.e. the
# pairs the preprocessing cell skips. The final bare `i` displays the count.
i = 0
for depoFile, simuFile in zip(depoList, simuList):
    i += int(os.path.getsize(depoFile) > 512 * 1024 * 1024
             or os.path.getsize(simuFile) > 512 * 1024 * 1024)

i

In [None]:
# Training loop: learn to map deposited-map chunks to the matching simulated-map
# chunks produced by the preprocessing cell.
# NOTE(review): assumes data_iter yields (depo, simu) NumPy arrays of shape
# (B, 48, 48, 48) — TODO confirm against utils.data_iter.
chunk_size = 48  # must agree with SCUNet(input_resolution=48)
model.train()

for epoch in range(num_epochs):
    train_loss = 0.0
    n_batches = 0
    cur_steps = 0
    for batch in data_iter(save_dir=save_dir, batch_size=batch_size):
        depo_chunks, simu_chunks = batch[0], batch[1]
        if depo_chunks.shape[0] == 0 or simu_chunks.shape[0] == 0:
            continue
        depo_chunks = torch.from_numpy(depo_chunks)
        simu_chunks = torch.from_numpy(simu_chunks)
        # Apply the identical transform to both maps so densities stay in voxel correspondence.
        depo_chunks, simu_chunks = transform(depo_chunks, simu_chunks)
        # The input needs no gradient (only model parameters are optimized);
        # reshape (not view) tolerates non-contiguous tensors from transform.
        X = torch.as_tensor(depo_chunks, dtype=torch.float32)
        X = X.reshape(-1, 1, chunk_size, chunk_size, chunk_size).cuda()
        simu_chunks = simu_chunks.cuda()
        y_pred = model(X).view(-1, chunk_size, chunk_size, chunk_size)
        l = loss(y_pred, simu_chunks)
        trainer.zero_grad()
        l.backward()
        trainer.step()
        # .item() yields a plain float; accumulating the tensor itself would keep
        # every batch's autograd history alive for the whole epoch.
        train_loss += l.item()
        n_batches += 1
        cur_steps += len(depo_chunks)
        print(f"processing: {cur_steps} / {n_chunks}")
    avg_loss = train_loss / max(n_batches, 1)
    print(f"epoch:{epoch} train_loss:{train_loss} avg_batch_loss:{avg_loss}")
