In [1]:
# import modules
import torch
import numpy as np
import tifffile


from src.utils.utils_dataprocessing import get_all_files, split_and_save_tensor, normalize, resample, fourier_interpolate
from src.SN2N_2D.constants_2d import rawdataFolder, datasetsFolder, visualizationFolder, model, paramsFolder
from src.utils.utils_train_predict import try_all_gpus, loss_channels

In [2]:
# Read one raw image (index 1 of the raw-data listing), pass it through
# normalize(..., mode='2d') and write the result out for visual inspection.
raw_map_list = get_all_files(rawdataFolder)
raw_map_path = raw_map_list[1]
print(raw_map_path)

raw_map = np.asarray(tifffile.imread(raw_map_path))
normalized_map = normalize(raw_map, mode='2d')

normalized_map_path = f'{visualizationFolder}/normalized_maps/normalized_map.tif'
tifffile.imwrite(normalized_map_path, normalized_map)


/home/tyche/paddle_SN2N/data/data_2d/raw_data/10.tif


In [3]:
# Load one stored chunk from the datasets folder and export it as a TIFF
# so it can be inspected visually.
datasets = get_all_files(datasetsFolder)
chunk_path = datasets[1000]  # fixed sample index; needs more than 1000 dataset files
print(chunk_path)

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# The context manager closes it as soon as the array has been read.
with np.load(chunk_path) as archive:
    chunk = archive['arr_0']
print(chunk.shape)

tifffile.imwrite(f'{visualizationFolder}/maps/chunk.tif', chunk)

/home/tyche/paddle_SN2N/data/data_2d/datasets/102_0_64_320.npz
(128, 128)


In [4]:

# Two 2x2 kernels whose non-zero entries are 0.5: kernel[0] weights the main
# diagonal, kernel[1] weights the anti-diagonal.
kernel = torch.tensor([[[1, 0], [0, 1]],
                      [[0, 1], [1, 0]]], dtype=torch.float32) / 2

# Stack the two kernels vertically into a single 2-D array
vertical_kernel = np.concatenate((kernel[0].numpy(), kernel[1].numpy()), axis=0)

# Save as a TIF file
kernel_path = f'{visualizationFolder}/kernel_vertical.tif'
tifffile.imwrite(kernel_path, vertical_kernel)

In [5]:

# Fixed 2x2 kernels applied with stride 2: output channel 0 is the mean of the
# main diagonal of every 2x2 patch, channel 1 the mean of the anti-diagonal.
kernel = torch.tensor([[[1, 0], [0, 1]],
                      [[0, 1], [1, 0]]]).float() / 2
out_channels, *spatial_dims = kernel.shape
kernel = kernel.view(out_channels, 1, *spatial_dims)
conv_layer = torch.nn.Conv2d(in_channels=1, out_channels=out_channels,
                             kernel_size=tuple(spatial_dims), stride=tuple(spatial_dims),
                             padding=0, bias=False)
# Copy the fixed kernel into the layer without going through `.data`.
with torch.no_grad():
    conv_layer.weight.copy_(kernel)

# Close the .npz archive once the array has been read into memory.
with np.load(datasets[0]) as archive:
    chunk = torch.from_numpy(archive['arr_0'])
print(chunk.shape)
print(torch.mean(chunk))

# Add a leading channel axis: (H, W) -> (1, H, W), an unbatched conv input.
down_chunk = conv_layer(chunk.unsqueeze(0))
print(down_chunk.shape)
print(torch.mean(down_chunk))

# Save both down-sampled channels in one image, separated by a single zero row
# whose width follows the data instead of a hard-coded 64.
down_np = down_chunk.detach().numpy()
down_separator = np.zeros((1, down_np.shape[-1]), dtype=np.float32)
vertical_chunk = np.vstack((down_np[0], down_separator, down_np[1]))
tifffile.imwrite(f'{visualizationFolder}/down_sampled.tif', vertical_chunk)

# Treat each down-sampled channel as its own single-channel sample:
# (C, h, w) -> (C, 1, h, w).
# NOTE(review): the factor 16 appears to compensate for amplitude scaling in
# fourier_interpolate (the printed means before and after match) - confirm
# against its implementation.
up_chunk = fourier_interpolate(down_chunk.unsqueeze(1)) * 16
print(up_chunk.shape)
print(torch.mean(up_chunk))

up_np = up_chunk[:, 0].detach().numpy()
up_separator = np.zeros((1, up_np.shape[-1]), dtype=np.float32)
vertical_chunk1 = np.vstack((up_np[0], up_separator, up_np[1]))
tifffile.imwrite(f'{visualizationFolder}/up_sampled.tif', vertical_chunk1)

torch.Size([128, 128])
tensor(0.6428)
torch.Size([2, 64, 64])
tensor(0.6428, grad_fn=<MeanBackward0>)
torch.Size([2, 1, 128, 128])
tensor(0.6429, grad_fn=<MeanBackward0>)


In [6]:
# training setup: initialise or restore the model weights, then run one forward pass as a sanity check


def init_weights(m):
    """Xavier-uniform initialise the weight of a Linear, Conv2d or Conv3d module.

    Written to be passed to ``model.apply``. Only modules whose class is
    exactly one of the listed types are re-initialised; subclasses and all
    other modules are left untouched, and biases are never modified.
    """
    xavier_targets = (torch.nn.Linear, torch.nn.Conv3d, torch.nn.Conv2d)
    if type(m) not in xavier_targets:
        return
    torch.nn.init.xavier_uniform_(m.weight)


# Restore the most recent checkpoint if there is one, otherwise start from
# freshly initialised weights. The number of files in paramsFolder is used as
# the epoch counter, so checkpoint N-1 is expected to be the newest one.
current_epoch = len(get_all_files(paramsFolder))
print(f"current_epoch:{current_epoch}")
if current_epoch != 0:
    checkpoint_path = f'{paramsFolder}/checkPoint_{current_epoch - 1}.pth'
    # Load onto the CPU first so a checkpoint saved from a GPU can still be
    # restored on a machine without one; the model is moved to its device below.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    # With the default strict=True, load_state_dict raises on any key mismatch,
    # so both lists are empty whenever this call returns.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict)
    if missing_keys:
        print(f"missing_keys: {missing_keys}")
    if unexpected_keys:
        print(f"unused_keys: {unexpected_keys}")

    print(f'load {paramsFolder}/checkPoint_{current_epoch - 1}')
else:
    model.apply(init_weights)
    print('no params found, randomly init model')


devices = try_all_gpus()
model = model.to(devices[0])

total_params = sum(p.numel() for p in model.parameters())
print(f"num_parameters: {total_params}")

# Run the model once and reuse that single result for both the loss and the
# saved image. `up_chunk` stays the model input, so re-running this cell does
# not feed a previous prediction back into the model.
up_chunk = up_chunk.to(device=devices[0])
predicted = model(up_chunk)

# Fold the N single-channel samples into one N-channel sample,
# (N, 1, H, W) -> (1, N, H, W), using the actual tensor sizes.
stacked_shape = (1, up_chunk.shape[0] * up_chunk.shape[1], *up_chunk.shape[2:])
print(loss_channels(up_chunk.view(stacked_shape), predicted.view(stacked_shape)))

# Save both predicted channels in one image, separated by a single zero row.
predicted_np = predicted[:, 0].detach().cpu().numpy()
predicted_separator = np.zeros((1, predicted_np.shape[-1]), dtype=np.float32)
vertical_chunk2 = np.vstack((predicted_np[0], predicted_separator, predicted_np[1]))
tifffile.imwrite(f'{visualizationFolder}/predicted.tif', vertical_chunk2)

current_epoch:20
load /home/tyche/paddle_SN2N/data/data_2d/params/checkPoint_19
num_parameters: 401288
tensor(0.0008, device='cuda:0', grad_fn=<DivBackward0>)


In [7]:
# predicting
