In [None]:
import logging
import os

import matplotlib.pyplot as plt
import torch
import torchvision.transforms.v2 as tfms
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import DataloaderImg, GreyToRGB, NormalizeVolume

plt.ion()  # interactive plotting mode

# If CUDA is not detected, check that the installed PyTorch build matches the CUDA setup:
# https://discuss.pytorch.org/t/torch-cuda-is-not-available/74845/11
# Installing the NVIDIA driver on Ubuntu:
# https://ubuntu.com/server/docs/nvidia-drivers-installation

# Seed torch's RNG (CPU and all CUDA devices) so weight initialisation, DataLoader
# shuffling and torch-based random augmentation are repeatable between runs.
# NOTE(review): if DataloaderImg / the custom transforms draw from `random` or
# numpy, seed those too — TODO confirm. cuDNN kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
logging.basicConfig(level=logging.INFO)

In [None]:
# CUDA API reference: https://pytorch.org/docs/stable/cuda.html
# Report which GPU the notebook will use, if one is available.
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"Current cuda device: {gpu_name}")

In [None]:
# Reference U-Net brain-segmentation implementation this notebook follows:
# https://github.com/mateuszbuda/brain-segmentation-pytorch

In [None]:
# Loading a directory of NIfTI volumes through a custom Dataset:
# https://discuss.pytorch.org/t/how-to-load-all-the-nii-from-the-directory-without-augmentation-using-pytorch-dataloader/60938/3
# Dataset root; set the MEMOSLAP_ROOT_DIR environment variable to point
# elsewhere instead of editing this path.
root_dir = os.environ.get(
    "MEMOSLAP_ROOT_DIR",
    "/media/MeMoSLAP_Subjects/derivatives/automated_electrode_extraction",
)

custom_transforms = [NormalizeVolume(), GreyToRGB()]
# Random augmentation: rotation by an angle drawn from [-180, 180] degrees.
transforms = [tfms.RandomRotation(180)]
# Passed to DataloaderImg as `validation_cases` (presumably the number of
# held-out cases) — TODO confirm.
n_validation = 4

train_dataset = DataloaderImg(
    root_dir,
    subset="train",
    validation_cases=n_validation,
    custom_transforms=custom_transforms,
    transforms=transforms,
)
# NOTE(review): the random rotation is applied to the validation subset too,
# which makes validation metrics vary from run to run. Consider disabling
# augmentation here once DataloaderImg's handling of an empty/None
# `transforms` argument is confirmed.
validation_dataset = DataloaderImg(
    root_dir,
    subset="validation",
    validation_cases=n_validation,
    custom_transforms=custom_transforms,
    transforms=transforms,
)

# Small loader used only to preview a batch below.
loader = DataLoader(train_dataset, batch_size=10, shuffle=True)

In [None]:
def show_batch(img_batch, mask_batch, n_cols=3):
    """
    Plot all images of a batch in a grid with their masks overlaid in red (alpha 0.3).

    Args:
        img_batch: iterable of image tensors, each laid out as (C, H, W); the
            channel axis is moved last for ``imshow``. Presumably C == 3 after
            GreyToRGB — TODO confirm.
        mask_batch: iterable of masks, one per image; singleton dimensions are
            squeezed before plotting.
        n_cols: number of subplot columns.

    Returns:
        The Axes object(s) returned by ``plt.subplots`` (unused padding cells
        in the last row are hidden).
    """

    def _for_display(t):
        # imshow converts via numpy, which fails for CUDA or grad-tracking
        # tensors; hand it a detached CPU tensor instead.
        return t.detach().cpu() if isinstance(t, torch.Tensor) else t

    n_img = len(img_batch)
    n_rows = -(-n_img // n_cols)  # ceiling division: enough rows for every image

    _, axs = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
    # For a 1x1 grid plt.subplots returns a single Axes rather than an array.
    flat_axs = axs.flatten() if hasattr(axs, "flatten") else [axs]

    for img, mask, ax in zip(img_batch, mask_batch, flat_axs):
        ax.imshow(_for_display(img).permute(1, 2, 0))
        ax.imshow(_for_display(mask).squeeze(), alpha=0.3, cmap="Reds")

    # Hide the empty padding cells of the last row.
    for ax in flat_axs[n_img:]:
        ax.axis("off")

    return axs

In [None]:
# Preview one training batch (after the dataset's transforms) with its ground-truth masks.
first_batch = next(iter(loader))
train_img, train_mask = first_batch
show_batch(train_img, train_mask)

In [None]:
# U-Net from the brain-segmentation-pytorch torch.hub repo, randomly initialised
# (pretrained=False): 3 input channels, 1 output channel, 32 base features.
unet_config = {
    "in_channels": 3,
    "out_channels": 1,
    "init_features": 32,
    "pretrained": False,
}
unet = torch.hub.load("mateuszbuda/brain-segmentation-pytorch", "unet", **unet_config)

In [None]:
import torch.nn as nn


class DiceLoss(nn.Module):
    """
    Soft Dice loss, ``1 - DSC``, computed on the first channel of the inputs.

    Predictions and targets of the whole batch are flattened together, so the
    result is a single batch-level Dice score, not a mean of per-sample scores.

    Args:
        smooth: additive smoothing term in numerator and denominator; keeps the
            ratio defined when prediction and target are both empty. Defaults
            to 1.0.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: predictions with shape (N, C, ...); only channel 0 is used.
                Presumably probabilities in [0, 1] — TODO confirm.
            y_true: target mask with the same shape as ``y_pred``.

        Returns:
            Scalar tensor ``1 - DSC``.

        Raises:
            ValueError: if ``y_pred`` and ``y_true`` have different shapes.
        """
        if y_pred.size() != y_true.size():
            raise ValueError("y_pred and y_true must have the same shape.")
        # Flatten channel 0 across the batch; reshape also handles non-contiguous input.
        y_pred = y_pred[:, 0].reshape(-1)
        y_true = y_true[:, 0].reshape(-1)
        intersection = (y_pred * y_true).sum()
        dsc = (2.0 * intersection + self.smooth) / (
            y_pred.sum() + y_true.sum() + self.smooth
        )
        return 1.0 - dsc

In [None]:
from torchinfo import summary

# Layer-by-layer summary for a single 3-channel 288x288 input.
# NOTE(review): 288x288 is assumed to match the dataset's image size — confirm.
summary_input = (1, 3, 288, 288)
summary(unet, input_size=summary_input)

In [None]:
unet.to(device)

batch_size = 15  # TODO: increase
epochs = 1_000  # TODO: increase
vis_frequency = 10  # visualise validation predictions every `vis_frequency` epochs (stub below)
vis_images = 4  # number of validation images to visualise
# Phases run in order every epoch. Validation is currently disabled; add
# "valid" to evaluate on `validation_dataset` and track `best_validation_dsc`.
phases = ["train"]

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)

loaders = {"train": train_loader, "valid": valid_loader}

dsc_loss = DiceLoss()
best_validation_dsc = 0.0

optimizer = optim.Adam(unet.parameters(), lr=1e-4)  # learning rate: TODO tune

loss_train = []  # one entry per training batch, across all epochs
loss_valid = []  # one entry per validation batch, across all epochs

step = 0  # number of training batches processed so far

for epoch in tqdm(range(epochs)):
    for phase in phases:
        # train()/eval() switch layers such as batch norm and dropout between
        # training and inference behaviour.
        if phase == "train":
            unet.train()
        else:
            unet.eval()

        validation_pred = []
        validation_true = []
        epoch_valid_losses = []

        for i, (x, y_true) in enumerate(loaders[phase]):
            if phase == "train":
                step += 1

            x = x.to(device, dtype=torch.float)
            y_true = y_true.to(device, dtype=torch.float)

            optimizer.zero_grad()

            # Record the autograd graph only while training.
            with torch.set_grad_enabled(phase == "train"):
                y_pred = unet(x)
                loss = dsc_loss(y_pred, y_true)

                if phase == "valid":
                    loss_valid.append(loss.item())
                    epoch_valid_losses.append(loss.item())

                    # Keep per-sample predictions/targets as numpy arrays.
                    y_pred_np = y_pred.detach().cpu().numpy()
                    validation_pred.extend(list(y_pred_np))
                    y_true_np = y_true.detach().cpu().numpy()
                    validation_true.extend(list(y_true_np))

                    # Test the current `epoch`, not the constant `epochs`;
                    # the latter made this condition true on every epoch.
                    if (epoch % vis_frequency == 0) or (epoch == epochs - 1):
                        if i * batch_size < vis_images:
                            tag = f"image/{i}"
                            num_images = vis_images - i * batch_size
                            ...  # TODO: visualise `num_images` predictions under `tag`

                if phase == "train":
                    loss_train.append(loss.item())
                    loss.backward()
                    optimizer.step()

            if phase == "train" and (step + 1) % 10 == 0:
                ...  # TODO: periodic training logging

        if phase == "valid" and epoch_valid_losses:
            # DiceLoss returns 1 - DSC per batch, so this epoch's mean batch DSC
            # is 1 - mean validation loss.
            mean_dsc = 1.0 - sum(epoch_valid_losses) / len(epoch_valid_losses)
            best_validation_dsc = max(best_validation_dsc, mean_dsc)

if "valid" in phases:
    print("Best validation mean DSC: {:.4f}".format(best_validation_dsc))
else:
    print("Validation phase disabled; no validation DSC computed.")

In [None]:
# Shape of the last input batch seen by the training loop (`x` leaks out of that loop).
x.shape

In [None]:
# Pickles the whole nn.Module; loading it later requires the same model class to be importable.
# NOTE(review): saving unet.state_dict() would be the more portable choice.
checkpoint_path = f"unet_epochs_{epochs}_batchsize_{batch_size}.pt"
torch.save(unet, checkpoint_path)

In [None]:
# Overlay the model's predictions on the last batch processed by the training loop.
# NOTE(review): casting to IntTensor truncates float images (e.g. values in [0, 1]
# collapse to 0/1), unlike the uncast preview of the same data earlier — confirm
# the value range NormalizeVolume produces before trusting this display.
last_imgs = x.detach().cpu().type(torch.IntTensor)
last_preds = y_pred.detach().cpu()
show_batch(last_imgs, last_preds)