In [6]:
import torch
import json
from copy import deepcopy
from pathlib import Path
from timeit import default_timer as timer
from tqdm import tqdm
import mrcfile
from models.map_splitter import reconstruct_maps, create_cube_list
from models.unet import UNetRes
from utils.utils import load_data
import numpy as np
import skimage

In [2]:
with open("configs/inference.json", "r") as f:
    conf = json.load(f)

# Always a torch.device on both branches (previously the CPU fallback was the
# plain string "cpu"), so attributes such as device.type work either way.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(conf["checkpoint"]["model_config"], "r") as f:
    model_conf = json.load(f)

# Shallow merge: any top-level key present in the model config overrides the
# same key from the inference config.
conf = {**conf, **model_conf}
dataloader = load_data(conf, training=False)

In [22]:
# Build the network from the merged config, restore its trained weights and move
# it to the selected device.
model = UNetRes(
    n_blocks=conf["model"]["n_blocks"],
    act_mode=conf["model"]["act_mode"],
)
# map_location="cpu" lets the checkpoint load even without the GPU it was saved on.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (newer torch releases accept weights_only=True for plain state dicts).
checkpoint = torch.load(conf["checkpoint"]["trained_weights"], map_location="cpu")
model.load_state_dict(checkpoint)
model = model.to(device)

In [3]:
example_map_path = "example_data/emd_23099/resampled_map.mrc"
# Read-only handle; it is not closed here because map_resample (a later cell)
# reads voxel_size, header and data from this object.
input_map = mrcfile.open(example_map_path, mode="r")

In [31]:
def map_resample(input_map, target_voxel_size=1.0):
    """Resample a density map to a uniform voxel size, min-max normalize it and cut it into cubes.

    Args:
        input_map: An open ``mrcfile`` map object. Its ``voxel_size``, ``header``
            and ``data`` attributes are read; the object itself is not modified.
        target_voxel_size: Edge length of one voxel in the resampled grid, in the
            same units as ``input_map.voxel_size``. Defaults to 1.0, the value that
            was previously hard-coded.

    Returns:
        ``(cubes, meta_data)`` where ``cubes`` is a ``torch.float`` tensor built from
        the list returned by ``create_cube_list`` for the normalized volume, and
        ``meta_data`` is a deep copy of ``input_map.header``.

    Raises:
        ValueError: If ``target_voxel_size`` is not positive.
    """
    if target_voxel_size <= 0:
        raise ValueError(f"target_voxel_size must be positive, got {target_voxel_size}")

    # voxel_size is reported as (x, y, z), but (with the standard axis mapping)
    # mrcfile stores ``data`` with axes ordered (z, y, x). Pair every array axis
    # with its own voxel size so anisotropic maps are scaled along the right axes.
    voxel_size_zyx = (
        float(input_map.voxel_size.z),
        float(input_map.voxel_size.y),
        float(input_map.voxel_size.x),
    )
    meta_data = deepcopy(input_map.header)
    volume = deepcopy(input_map.data)

    scale_factor = [size / target_voxel_size for size in voxel_size_zyx]
    output_shape = [
        round(dim * scale) for dim, scale in zip(volume.shape, scale_factor)
    ]

    volume = skimage.transform.resize(
        volume,
        output_shape,
        order=3,
        mode="reflect",
        cval=0,
        clip=True,
        preserve_range=False,
        anti_aliasing=True,
        anti_aliasing_sigma=None,
    )

    # Min-max normalize to [0, 1]. A constant map has zero range; return zeros
    # instead of the NaNs that a 0/0 division would produce.
    v_min, v_max = volume.min(), volume.max()
    if v_max > v_min:
        volume = (volume - v_min) / (v_max - v_min)
    else:
        volume = np.zeros_like(volume)

    input_cube_list = np.array(create_cube_list(volume))
    return torch.tensor(input_cube_list, dtype=torch.float), meta_data

In [32]:
# Resampled, normalized cubes plus a copy of the original header for reconstruction.
(input_data, meta_data) = map_resample(input_map=input_map)

In [37]:
batch_size = 16
torch.backends.cudnn.benchmark = True
model.eval()
with torch.no_grad():
    # Gather each batch's prediction and concatenate once after the loop;
    # calling torch.cat on a growing tensor every iteration re-copies all
    # previous results (quadratic in the number of batches).
    pred_batches = []
    for indx in tqdm(range(0, input_data.shape[0], batch_size)):
        # Add a channel axis at dim 1 for the model, then drop it from the output.
        x_partial = input_data[indx : indx + batch_size].unsqueeze(dim=1).to(device)
        y_pred_partial = model(x_partial)
        pred_batches.append(y_pred_partial.squeeze(dim=1).cpu())
    # With no input cubes, fall back to the same empty tensor as before.
    y_pred = torch.cat(pred_batches, dim=0) if pred_batches else torch.tensor(())

    # cella holds the unit-cell lengths. map_resample resamples to 1.0 per voxel
    # and *rounds* each axis length, so round here too: int() would truncate a
    # value such as 239.9999 to 239 and disagree with the resampled grid.
    # NOTE(review): assumes voxel_size == cella / grid size, as mrcfile computes it.
    # NOTE(review): this tuple is (x, y, z) while mrcfile arrays are (z, y, x);
    # confirm reconstruct_maps expects x-first order (only matters for non-cubic maps).
    original_shape = tuple(
        int(round(float(length)))
        for length in (meta_data.cella.x, meta_data.cella.y, meta_data.cella.z)
    )
    y_pred_recon = reconstruct_maps(
        y_pred.numpy(),
        original_shape,
    )
    # TODO: before re-enabling, replace `id[0]` with a real map identifier -- `id`
    # is not defined in the visible cells and would resolve to Python's builtin.
    # if conf["test_data"]["save_output"]:
    #     with mrcfile.new(conf["output_path"] + "/pred_{}.mrc".format(id[0])) as mrc:
    #         mrc.set_data(y_pred_recon)
    #         mrc.header.cella.x = meta_data.cella.x
    #         mrc.header.cella.y = meta_data.cella.y
    #         mrc.header.cella.z = meta_data.cella.z
    #         mrc.header.nxstart = meta_data.nxstart
    #         mrc.header.nystart = meta_data.nystart
    #         mrc.header.nzstart = meta_data.nzstart

100%|████████████████████████████████████████████████| 8/8 [00:10<00:00,  1.28s/it]
