In [None]:
import torch
from V_NAS import Network, get_device
from monai.metrics import DiceMetric, compute_meandice
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.data import (
    DataLoader,
    Dataset,
    load_decathlon_datalist,
    
)
import matplotlib.pyplot as plt
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
)

# Validation-time preprocessing. Every step is deterministic (no Rand* transforms):
# load image/label, add a channel axis, reorient to RAS, resample, window and
# rescale the image intensities to [0, 1], crop to the image foreground, and
# convert to tensors.
IMAGE_LABEL_KEYS = ["image", "label"]

val_transforms = Compose(
    [
        LoadImaged(keys=IMAGE_LABEL_KEYS),
        AddChanneld(keys=IMAGE_LABEL_KEYS),
        Orientationd(keys=IMAGE_LABEL_KEYS, axcodes="RAS"),
        # Image is resampled with bilinear interpolation, label with nearest so
        # class ids stay integral. NOTE(review): pixdim presumably in mm — confirm.
        Spacingd(
            keys=IMAGE_LABEL_KEYS,
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        # Clip the image to [-175, 250] and map that window linearly onto [0, 1].
        # NOTE(review): window presumably in CT Hounsfield units — confirm.
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=IMAGE_LABEL_KEYS, source_key="image"),
        ToTensord(keys=IMAGE_LABEL_KEYS),
    ]
)

# NOTE(review): get_device comes from V_NAS; the argument presumably selects
# GPU index 1 — confirm against V_NAS.get_device.
device = get_device(1)

In [None]:
# Rebuild the searched network and restore the trained weights.
# NOTE(review): Network((64, 64, 64), 3) presumably means a 64^3 input patch and
# 3 output classes, matching the sliding-window ROI and AsDiscrete(to_onehot=3)
# used for evaluation — confirm against V_NAS.Network.
model = Network((64, 64, 64), 3)
# Load onto the CPU first so the checkpoint restores no matter which device it was
# saved from (another GPU index, or a CPU-only machine). The model is moved to the
# evaluation device afterwards. torch.load unpickles, so only load trusted files.
state_dict = torch.load("model/one_shot_with_test.ptb", map_location="cpu")
model.load_state_dict(state_dict)
# nn.Module.to moves parameters in place; ';' suppresses the large module repr.
model.to(device);

In [None]:
# Validation split of the Decathlon-style datalist (entries under the "val" key).
Val_datalist = load_decathlon_datalist(
    "data/dataset.json", is_segmentation=True, data_list_key="val"
)
val_ds = Dataset(data=Val_datalist, transform=val_transforms)
# shuffle=False: evaluation gains nothing from shuffling. A fixed order makes runs
# reproducible and keeps the i-th saved Dice row aligned with Val_datalist[i].
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True
)

In [None]:
# Labels: one-hot over 3 classes. Predictions: argmax over channels, then one-hot,
# so both sides have the same channel layout for the Dice computation.
post_label = AsDiscrete(to_onehot=3)
post_pred = AsDiscrete(argmax=True, to_onehot=3)


def validation(model, val_loader, device=None):
    """Run sliding-window inference over ``val_loader`` and collect per-case Dice.

    Args:
        model: segmentation network; switched to eval mode here.
        val_loader: yields dicts with "image" and "label" tensors. Only the first
            item of each batch is scored (index ``[0]``), so a loader with
            batch_size=1 is expected.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        A list with one tensor per case, as returned by ``compute_meandice``, moved
        to the CPU. NOTE(review): the first column is presumably background,
        because include_background is left at its default — confirm.
    """
    if device is None:
        device = next(model.parameters()).device
    scores = []
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(val_loader):
            val_inputs = batch["image"].to(device)
            val_labels = batch["label"].to(device)
            # 64^3 sliding windows, 4 windows per forward pass.
            val_outputs = sliding_window_inference(val_inputs, (64, 64, 64), 4, model)

            val_labels_converted = post_label(val_labels[0]).unsqueeze(0)
            val_output_converted = post_pred(val_outputs[0]).unsqueeze(0)
            dice = compute_meandice(val_output_converted, val_labels_converted).cpu()
            # Keeping results on the CPU stops GPU memory growing with the case count.
            scores.append(dice)
            # Report only the current case. Re-printing the whole accumulated list
            # every step made the log grow quadratically.
            print(f"case {step}: {dice.squeeze(0).tolist()}")

    return scores


score_list = validation(model, val_loader)
score_list

In [None]:
# Stack the per-case Dice rows into one (n_cases, n_classes) tensor and persist it.
# A single torch.cat over a list avoids re-allocating inside a loop (quadratic) and
# the dummy ones-row that previously had to be sliced off afterwards.
if score_list:
    x = torch.cat([score.cpu() for score in score_list], dim=0)
else:
    # Same empty (0, 3) float tensor the loop-and-slice version produced.
    x = torch.empty(0, 3)

torch.save(x, "./score_one_shot.pt")

In [None]:
# Per-class spread and average of Dice across validation cases (column-wise).
x.std(dim=0), x.mean(dim=0)

In [None]:
# Per-class best / worst / median Dice across cases; each call returns a
# (values, indices) pair, so the index identifies the case.
x.max(dim=0), x.min(dim=0), x.median(dim=0)

tensor([[0.9981, 0.6119, 0.6945],
        [0.9993, 0.6715, 0.6865],
        [0.9992, 0.6079, 0.3120],
        [0.9992, 0.6791, 0.7500],
        [0.9988, 0.7256, 0.4642],
        [0.9986, 0.4109, 0.0000],
        [0.9990, 0.7824, 0.0000],
        [0.9997, 0.8194, 0.0000],
        [0.9988, 0.5011, 0.0231],
        [0.9990, 0.6497, 0.3705],
        [0.9996, 0.5929, 0.2356],
        [0.9995, 0.7610, 0.0000],
        [0.9993, 0.6204, 0.1426],
        [0.9984, 0.5909, 0.4971],
        [0.9991, 0.6210, 0.1695],
        [0.9992, 0.4618, 0.2081],
        [0.9989, 0.7259, 0.4295],
        [0.9988, 0.5617, 0.6697],
        [0.9987, 0.7401, 0.1002],
        [0.9993, 0.6343, 0.5715],
        [0.9994, 0.4161, 0.7172],
        [0.9990, 0.5279, 0.1993],
        [0.9991, 0.6963, 0.0358],
        [0.9995, 0.6834, 0.0407],
        [0.9993, 0.6133, 0.2447],
        [0.9987, 0.6333, 0.0485],
        [0.9992, 0.6690, 0.6087],
        [0.9992, 0.7514, 0.0000]])