In [None]:
import os
from glob import glob
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader
#from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

import nibabel as nib
import numpy as np
from monai.data import Dataset
from monai.networks import nets
from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImaged,
    Activations,
    AsDiscrete,
    Compose, 
    ScaleIntensityd,
    ToTensord,
    SaveImage,
    Orientationd,
    FillHoles,
    EnsureChannelFirstd,
)


In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(device)

### Load the input data and check

In [None]:

# Directory holding the input NIfTI volumes, and directory the predicted masks are written to.
work_root = '...input data path...'
# sorted(): glob order is filesystem-dependent; sorting makes the dataset (and output) order
# reproducible across machines and runs.
valid_brain = sorted(glob(f'{work_root}/*.nii.gz'))

save_dir = '...save path...'
os.makedirs(save_dir, exist_ok=True)

# Fail fast with a clear message instead of an IndexError on valid_brain[0] further down.
if not valid_brain:
    raise FileNotFoundError(f'No .nii.gz files found in {work_root!r}')

#for check
a=nib.load(valid_brain[0])
affine = a.affine 
axcodes = nib.aff2axcodes(affine)
print(axcodes)
original_orientation = axcodes
a=a.get_fdata()
zs = int(a.shape[2]/2)
plt.imshow(a[:,:,zs],cmap='gray')

# One dict per input volume, keyed by "image", as MONAI dictionary transforms expect.
valid_dicts = [{"image": path} for path in valid_brain]

# Preprocessing chain applied to every volume:
#   load NIfTI -> make channel-first -> reorient to PRS axis codes ->
#   ScaleIntensityd (library defaults) -> convert to a torch tensor.
pre_transforms = Compose([
    LoadImaged(keys="image"),
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys="image", axcodes="PRS"),
    ScaleIntensityd(keys="image"),
    ToTensord(keys="image"),
])

# Sanity check the preprocessing on the first volume. After pre_transforms it should be
# channel-first and in PRS orientation.
check_ds = Dataset(data=valid_dicts, transform=pre_transforms)
img = check_ds[0]['image']
print(img.shape)

# Take the slice index from the *reoriented* tensor. Orientationd may permute the spatial axes,
# so the `zs` computed from the on-disk array can select the wrong plane or fall out of range.
check_z = img.shape[-1] // 2
plt.figure(dpi=128)
plt.imshow(img[0, :, :, check_z], cmap='gray')
plt.title(f'Preprocessed (PRS) - slice {check_z} of last axis')
plt.show()

## Run inference

In [None]:
# Inference dataset/loader over all input volumes. No shuffling: the inference loop pairs
# each prediction with its source file by position in `valid_dicts`.
dataset = Dataset(data=valid_dicts, transform=pre_transforms)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

# 3D BasicUNet with one input and one output channel. The output is turned into a mask by
# post_trans below. `device` is the one selected in the device cell above.
model = nets.BasicUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    features=(16, 32, 64, 128, 256, 32),
).to(device)

# Post-processing of one channel-first prediction: sigmoid -> binarize at 0.5 -> fill holes.
# NOTE(review): for 3D data, connectivity values above 3 appear to act as full (3) connectivity
# in MONAI's fill_holes - confirm 4 is intended.
post_trans = Compose([
    Activations(sigmoid=True),
    AsDiscrete(threshold=0.5),
    FillHoles(connectivity=4),
])

# Trained weights. map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
load_path = 'dir/QSMmask_net_parameters.pth'
model.load_state_dict(torch.load(load_path, map_location=device))
model.eval()

# Axis codes the volumes were reoriented to by pre_transforms (its Orientationd step).
# This must stay in sync with that transform so masks can be mapped back onto the input grid.
PIPELINE_AXCODES = 'PRS'


def _restore_orientation(mask, target_affine):
    """Reorder/flip a 3D mask from PIPELINE_AXCODES back to the voxel layout of `target_affine`.

    This inverts the Orientationd step. Without it, a mask saved with the input file's affine is
    misaligned whenever the input is not already stored in PRS orientation.
    """
    to_target = nib.orientations.ornt_transform(
        nib.orientations.axcodes2ornt(PIPELINE_AXCODES),
        nib.orientations.io_orientation(target_affine),
    )
    return nib.orientations.apply_orientation(mask, to_target)


def _mask_path(src_path):
    """Return `<save_dir>/<input stem>_QSMmask-net_mask.nii.gz` for the input file `src_path`."""
    name = os.path.basename(src_path)
    if name.endswith('.nii.gz'):
        stem = name[:-len('.nii.gz')]
    else:
        stem = os.path.splitext(name)[0]
    return os.path.join(save_dir, f'{stem}_QSMmask-net_mask.nii.gz')


# The dataloader is built from `valid_dicts` without shuffling, so items arrive in that order.
# `case_idx` pairs every prediction with the file it came from.
case_idx = 0
with torch.no_grad():
    for batch in tqdm(dataloader):
        images = batch["image"].to(device)
        _, _, xs, ys, zs = images.size()
        # Sliding-window ROI is a fixed fraction of each volume: 70% of the first two spatial
        # axes and 50% of the last.
        roi_size = (int(0.7 * xs), int(0.7 * ys), int(0.5 * zs))
        pred_outputs = sliding_window_inference(
            inputs=images, roi_size=roi_size, sw_batch_size=2, predictor=model, overlap=0.5
        )

        # Post-process and save EVERY item in the batch. Previously only item 0 was kept, and only
        # the last batch's result survived the loop.
        for pred in pred_outputs:
            src_path = valid_dicts[case_idx]["image"]
            case_idx += 1

            input_data = nib.load(src_path)
            # post_trans keeps the channel dim; [0] drops it, giving a 3D 0/1 volume.
            mask_pipeline = np.array(post_trans(pred)[0].cpu()).astype(np.uint8)
            pred_outputs_cpu = _restore_orientation(mask_pipeline, input_data.affine)
            # Orientation changes only permute/flip axes, so the grid must match the input's.
            assert pred_outputs_cpu.shape == tuple(input_data.shape[:3]), (
                f'{src_path}: mask shape {pred_outputs_cpu.shape} != input {input_data.shape}'
            )

            # Save with this case's own affine/header; store the binary mask as uint8.
            header = input_data.header.copy()
            header.set_data_dtype(np.uint8)
            save_output_data = nib.Nifti1Image(pred_outputs_cpu, input_data.affine, header)
            nib.save(save_output_data, _mask_path(src_path))

# Quick visual check: middle slice (third axis) of the last saved mask, if anything was processed.
if case_idx:
    plt.imshow(pred_outputs_cpu[:, :, pred_outputs_cpu.shape[2] // 2])
    plt.title(f'Last saved mask ({case_idx} case(s) written to {save_dir})')
    plt.show()