In [None]:
import torch
import torch.nn as nn
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
)
import matplotlib.pyplot as plt

# my implementation
from V_NAS import Network, get_device



from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
)

# Compute device used for the model, loss weights and batches below, obtained
# from the project helper `get_device` (imported from V_NAS).
# NOTE(review): the argument 2 is presumably a GPU index - confirm against
# V_NAS.get_device before running on a machine with fewer devices.
device = get_device(2)


# Training pipeline for the "image"/"label" dictionary entries: deterministic
# preprocessing first, then random patch sampling and light augmentation.
train_transforms = Compose(
    [
        # --- deterministic preprocessing ---
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample both entries to pixdim (1.5, 1.5, 2.0); the image is
        # interpolated bilinearly, the label with nearest-neighbour.
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        # Map image intensities from [-175, 250] to [0, 1], clipping outside.
        ScaleIntensityRanged(
            keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # --- random patch sampling: 4 crops of 96x96x96 per volume ---
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        # --- augmentation ---
        # One independent flip per spatial axis (0, 1, 2), each with prob 0.10.
        *[
            RandFlipd(keys=["image", "label"], spatial_axis=[axis], prob=0.10)
            for axis in (0, 1, 2)
        ],
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        # Intensity shift is applied to the image only, never to the label.
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
        ToTensord(keys=["image", "label"]),
    ]
)


# Validation pipeline: deterministic preprocessing only (load, add channel,
# reorient, resample, rescale intensities, crop to foreground). No random
# cropping or augmentation is applied here.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Image is resampled bilinearly, label with nearest-neighbour.
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        # Map image intensities from [-175, 250] to [0, 1], clipping outside.
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ]
)


# Load the "training" and "val" sections of the Decathlon-style dataset
# description file as lists of image/label entries (segmentation mode).
Train_datalist, Val_datalist = (
    load_decathlon_datalist(
        data_list_file_path="./data/dataset.json",
        is_segmentation=True,
        data_list_key=section,
    )
    for section in ("training", "val")
)



# Training set: items are preprocessed with `train_transforms` and cached
# (at most `cache_num` items). Each item yields several random crops
# (num_samples=4 in train_transforms), so with batch_size=1 a loaded batch is
# expected to hold 4 patches.
train_ds = CacheDataset(
    data=Train_datalist,
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=8,
)
train_loader = DataLoader(
    train_ds, batch_size=1, shuffle=True, num_workers=8, pin_memory=True
)

# Validation set.
# NOTE(review): cache_num=1 caches a single validation item; the rest are
# transformed on the fly at every pass - confirm this is intended.
val_ds = CacheDataset(
    data=Val_datalist,
    transform=val_transforms,
    cache_num=1,
    cache_rate=1.0,
    num_workers=4,
)
# Validation is not shuffled: the order has no effect on aggregate metrics,
# and a fixed order keeps per-case results comparable between runs.
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True
)
print(f"num of train_ds {len(train_ds)}, num of val_ds {len(val_ds)}")


In [None]:
# Pull a single batch from the training loader to sanity-check what the
# pipeline produces before running the model.
batch = next(iter(train_loader))

x = batch["image"]
y = batch["label"]

# Image and label shapes, printed on one line.
print(x.shape, y.shape)


In [None]:

# Show one 2-D slice of the label tensor `y`: first sample, first channel,
# a fixed index along the last axis.
slice_idx = 65

fig, ax = plt.subplots()
# `.cpu()` is a no-op for a CPU tensor; it keeps this cell working if `y` has
# been moved to the compute device by another cell run out of order.
im = ax.imshow(y[0, 0, :, :, slice_idx].cpu())
ax.set_title(f"label: sample 0, channel 0, slice {slice_idx} (last axis)")
# Bind the colorbar to this image explicitly instead of pyplot's "current"
# mappable; the trailing semicolon suppresses the Colorbar repr in the output.
fig.colorbar(im, ax=ax);

In [None]:
# Combined Dice + cross-entropy loss, weighted 0.2 (Dice) / 0.8 (CE).
# The label is one-hot encoded inside the loss (to_onehot_y=True) and the
# background channel is included in the Dice term. The CE class weights have
# three entries, so three classes are expected; class index 0 gets zero CE
# weight.
# NOTE(review): neither `softmax=True` nor `sigmoid=True` is passed, so the
# Dice term appears to be computed directly on whatever `Network` returns. If
# that is raw logits, the Dice term is probably not what is intended - confirm.
loss_function = DiceCELoss(
    include_background=True,
    to_onehot_y=True,
    ce_weight=torch.tensor([0., 0.2, 0.8]).to(device),
    lambda_ce=0.8,
    lambda_dice=0.2
)

model = Network().to(device)

# Forward-only smoke test: run every training batch through the untrained
# model and print the loss. No backward pass or optimizer step happens here,
# so gradient tracking is disabled to avoid building (and holding) an autograd
# graph for each batch of 3-D patches. The printed values are unchanged.
with torch.no_grad():
    for step, batch in enumerate(train_loader):
        x, y = batch["image"].to(device), batch["label"].to(device)

        pred = model(x)
        loss = loss_function(pred, y)
        print(loss.item())


In [None]:
import torch
from torch.nn import CrossEntropyLoss


# Toy example showing how per-class weights enter CrossEntropyLoss.
# With the default mean reduction, each element's loss is scaled by the weight
# of its target class and the sum is divided by the sum of those weights.
lf = CrossEntropyLoss(weight=torch.tensor([0.9, 0.1, 0.5]))

# Integer class targets, shape (1, 2, 2): one sample on a 2x2 grid.
y = torch.tensor([[[0, 1], [2, 2]]])

# Per-class scores, shape (1, 3, 2, 2): one 2x2 score map per class.
pred = torch.tensor(
    [[
        [[0.005, 1.0], [2.0, 0.0]],    # class 0
        [[1.0, 5.0], [1.0, 1.0]],      # class 1
        [[1.0, 1.0], [20.0, 20.0]],    # class 2
    ]],
    dtype=torch.float32,
)

# Last expression of the cell: the scalar weighted loss.
lf(pred, y)