In [8]:
from imports.models import *
from imports.utils import *
import adabound as Adabound

import torch
import matplotlib.pyplot as plt
import numpy as np

In [1]:
def apply_denoise_model(x_noisy_1d, denoise_model):
    x_tensor = torch.Tensor(x_noisy_1d).unsqueeze(0).unsqueeze(0).cuda()  # (1, 1, T)
    x_denoised_tensor = denoise_model(x_tensor)
    return x_denoised_tensor.squeeze().detach().cpu().numpy()


In [2]:
def reshape_to_2D(x_1d, model_index):
    """
    model_index: (n_rows, image_width*2+1)
    x_1d: denoised 1D signal (length T)
    return: 2D array (n_rows, image_width*2+1)
    """
    return x_1d[model_index]


In [4]:
def predict_with_HPC(x_2d, hpc_model):
    x_flat = x_2d.flatten()  # HPC는 1D 입력 받음
    x_tensor = torch.Tensor(x_flat).unsqueeze(0).cuda()  # (1, T_flat)
    with torch.no_grad():
        hpc_model.eval()
        prediction = hpc_model(x_tensor).squeeze().cpu().numpy()
    return prediction


In [5]:
def full_pipeline(x_noisy_1d, model_index, denoise_model, hpc_model):
    # 1. Denoise
    x_denoised = apply_denoise_model(x_noisy_1d, denoise_model)

    # 2. 1D → 2D
    x_2d = reshape_to_2D(x_denoised, model_index)

    # 3. HPC 예측
    prediction = predict_with_HPC(x_2d, hpc_model)

    return x_denoised, x_2d, prediction


In [9]:
total_indices = np.load('../data/total_indices/total_indices_v4_full.npy', allow_pickle=True).item()

# AB_lists_dic: a dictionary file that contains nuclear spins (value: (A,B) pairs (Hz)) with a corresponding target period (key: A (Hz)).
AB_lists_dic = np.load('../data/AB_target_dic/AB_target_dic_v4_s0.npy', allow_pickle=True).item()
for i in range(1, 16):
    temp = np.load('../data/AB_target_dic/AB_target_dic_v4_s{}.npy'.format(i), allow_pickle=True).item()
    AB_lists_dic.update(temp)

In [10]:
N_SAMPLES_TRAIN = 8192*4
N_SAMPLES_VALID = 8192
data_size = 1024

N_PULSE_32 = 32
N_PULSE_256 = 256

time_data = np.arange(0, 60, 0.004)

X_valid = np.zeros((N_SAMPLES_VALID, data_size))
Y_valid = np.zeros((N_SAMPLES_VALID, data_size))
Y_valid_pure = np.zeros((N_SAMPLES_VALID, data_size))

# for i in range(int(N_SAMPLES_VALID/ABlists_valid_0.shape[0])):
#     for j in range(len(ABlists_valid_0)):
#         rand_idx = np.random.randint(11000)
#         time_data_temp = time_data[rand_idx:rand_idx+data_size]
#
#         X_valid[i*len(ABlists_valid_0)+j], \
#         Y_valid[i*len(ABlists_valid_0)+j], \
#         Y_valid_pure[i*len(ABlists_valid_0)+j], \
#         = Px_noise_data(time_data_temp, WL_VALUE, globals()['ABlists_valid_{}'.format(i)][j], N_PULSE_32, rand_idx, data_size, y_train_pure=True)

In [14]:
# 예제: 노이즈가 낀 실험 CPMG 시그널
x_noisy_example = X_valid[0].squeeze()

# 모델 인덱스 (TPk에 해당하는 A_index 기준)
A_index = 10050
model_index = get_model_index(total_indices, A_index, time_thres_idx=time_data.shape[0], image_width=10)

# 모델 불러오기
denoise_model = Denoise_Model().cuda()
denoise_model.load_state_dict(torch.load('../data/models/denoising_model.pt'))
denoise_model.eval()

cut_idx, image_width = model_index.shape[0], 10  # 예시
input_size = cut_idx * (2 * image_width + 1)
output_size = 3

hpc_model = HPC(input_size, output_size).cuda()
hpc_model.load_state_dict(torch.load('../data/models/hpc_model_504.pt'))
hpc_model.eval()



# 실행
x_denoised, x_2d, prediction = full_pipeline(x_noisy_example, model_index, denoise_model, hpc_model)

# 시각화
plt.figure(facecolor='w')
plt.subplot(1,3,1)
plt.plot(x_noisy_example)
plt.title('Noisy input')

plt.subplot(1,3,2)
plt.plot(x_denoised)
plt.title('Denoised')

plt.subplot(1,3,3)
plt.pcolor(x_2d)
plt.title('2D Image (HPC input)')

print("Prediction vector:", prediction)


RuntimeError: Error(s) in loading state_dict for HPC:
	size mismatch for linear1.weight: copying a param with shape torch.Size([2048, 1024]) from checkpoint, the shape in current model is torch.Size([2048, 1092]).
	size mismatch for linear4.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([3, 512]).
	size mismatch for linear4.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([3]).