In [None]:
import scipy.io
import numpy as np
from os import listdir
from os.path import isfile, join
from monai.utils import first, set_determinism
from monai.transforms import (
    Activations,
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
    LoadImage,
    RandAffined,
    PadListDataCollate,
    RandSpatialCropd,
    SpatialPadd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    RandRotate90d,
    NormalizeIntensityd
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import monai
import SimpleITK as sitk

In [None]:
# Pick the compute device: GPU when torch can see one, otherwise CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")

# Background data-loading workers are only requested when a GPU is present.
num_workers = 0
if cuda:
    num_workers = 4

print('You are using gpu if true, cpu if false:', cuda)

In [None]:
data_dir = ''  # enter location of 2P datasets ('' = current working directory)


def make_image_dicts(image_paths):
    """Wrap each path as ``{"image": path}``, the per-sample dict MONAI dictionary transforms read.

    Each path is stored as a plain string. Iterating over ``zip(paths)`` would
    instead yield 1-tuples such as ``("a.nii",)``, which LoadImaged treats as a
    sequence of files to stack rather than as a single filename.

    Args:
        image_paths: iterable of image file paths.

    Returns:
        list of ``{"image": path}`` dicts, in the same order as ``image_paths``.
    """
    return [{"image": path} for path in image_paths]


# Every file whose name contains "nii" is segmented. Sorted so the order is
# deterministic; the inference loop relies on this order to name outputs.
val_images = sorted(glob.glob(os.path.join(data_dir, "*nii*")))
val_files = make_image_dicts(val_images)

In [None]:
# Inference-time preprocessing: no augmentation and no intensity scaling,
# only loading plus channel/type bookkeeping on the "image" entry.
_val_steps = [
    LoadImaged(keys=["image"]),           # read the image file named by d["image"]
    EnsureChannelFirstd(keys=["image"]),  # make the channel axis the first axis
    EnsureTyped(keys=["image"]),          # ensure a tensor type for the model
]
val_transforms = Compose(_val_steps)

In [None]:
# Plain (non-caching) Dataset: each sample is loaded and transformed when it is
# indexed. Volumes are fed one at a time, loaded in the main process.
val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=0,
)

In [None]:
# Training hyperparameters; not read by the inference cells in this notebook
# (the inference loop defines its own roi_size).
roi = [256, 256, 224]
max_epochs = 1500
val_interval = 20
batchsz = 2

# 3-D UNet with residual units: 1 input channel -> 2 output channels,
# five feature levels with stride-2 steps between consecutive levels.
unet_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
model = UNet(**unet_kwargs).to(device)

# Loss and optimizer only matter for training; inference does not use them.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)

In [None]:
# Path to the trained checkpoint (.pt): a dict holding a 'model_state_dict' entry.
checkpoint_path = ''  # load .pt file

# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on a GPU also loads on the CPU fallback chosen earlier.
# torch.load unpickles the file: only load checkpoints from trusted sources.
temp = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(temp['model_state_dict'])

# Prefix for output mask paths, i.e. "<data_dir>/mask_". os.path.join inserts
# the separator even if data_dir lacks a trailing slash ('' -> "mask_").
savepath = os.path.join(data_dir, "mask_")

In [None]:
with torch.no_grad():
    # Switch BatchNorm/dropout to inference behaviour once, not per volume.
    model.eval()
    roi_size = (256, 256, 192)  # sliding-window patch size
    sw_batch_size = 1           # windows evaluated per forward pass
    for i, val_data in enumerate(val_loader):
        # Tile the whole volume with roi_size windows and stitch the logits.
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model
        )
        # Per-voxel label = argmax over the 2 output channels.
        im_final = torch.argmax(val_outputs, dim=1).detach().cpu().numpy()
        # Drop the batch axis (the loader uses batch_size=1).
        im_final = im_final[0, :, :, :].astype('int16')
        # NOTE(review): SimpleITK's GetImageFromArray reads array axes as
        # (z, y, x). If the MONAI-loaded array is (x, y, z) (TODO confirm for
        # the reader in use), transpose(2, 1, 0) would preserve orientation,
        # while (2, 0, 1) swaps the two in-plane axes. Kept as-is pending
        # confirmation against the input volumes.
        im_final = im_final.transpose(2, 0, 1)
        im_itk = sitk.GetImageFromArray(im_final)
        # NOTE(review): the new image has default spacing/origin/direction; the
        # input volume's spatial metadata is not copied onto the mask.
        # The loader is unshuffled, so item i corresponds to val_images[i].
        imagename = os.path.basename(val_images[i])
        savepath_new = savepath + imagename
        # Write to the per-volume path. Writing to the bare prefix overwrote
        # the same extension-less file on every iteration.
        sitk.WriteImage(im_itk, savepath_new)
        