# Setup

In [1]:
import itertools
from collections import defaultdict

import torch
import numpy as np
from monai.metrics import compute_hausdorff_distance
import torch.nn.functional as F
import matplotlib.pyplot as plt
import einops as E
from torch.utils.data import DataLoader

import utils.dataset as example_data
from utils.visualization import visualize_tensors
from utils.utils import seed_everything
from models.original_universeg import universeg
from models.original_universeg import SimpleUniverSeg
import math
import itertools

from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt

import einops as E
import pathlib
import subprocess
from dataclasses import dataclass
from utils.const import DATA_FOLDER
from typing import Literal, Tuple

import numpy as np
import nibabel as nib
import PIL
import torch
from torch.utils.data import Dataset
import monai.losses


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = universeg(pretrained=False).to(device)

In [3]:
def visualize_tensors(tensors, col_wrap=8, col_names=None, title=None):
    M = len(tensors)
    N = len(next(iter(tensors.values())))

    cols = col_wrap
    rows = math.ceil(N/cols) * M

    d = 2.5
    fig, axes = plt.subplots(rows, cols, figsize=(d*cols, d*rows))
    if rows == 1:
      axes = axes.reshape(1, cols)

    for g, (grp, tensors) in enumerate(tensors.items()):
        for k, tensor in enumerate(tensors):
            col = k % cols
            row = g + M*(k//cols)
            x = tensor.detach().cpu().numpy().squeeze()
            ax = axes[row,col]
            if len(x.shape) == 2:
                ax.imshow(x,vmin=0, vmax=1, cmap='gray')
            else:
                ax.imshow(E.rearrange(x,'C H W -> H W C'))
            if col == 0:
                ax.set_ylabel(grp, fontsize=16)
            if col_names is not None and row == 0:
                ax.set_title(col_names[col])

    for i in range(rows):
        for j in range(cols):
            ax = axes[i,j]
            ax.grid(False)
            ax.set_xticks([])
            ax.set_yticks([])

    if title:
        plt.suptitle(title, fontsize=20)

    plt.tight_layout()

In [69]:
def dice_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    y_pred = y_pred.long()
    y_true = y_true.long()
    score = 2*(y_pred*y_true).sum() / (y_pred.sum() + y_true.sum())
    return score.item()

# Data Set

In [4]:
# supportloader = DataLoader(d_support, batch_size=32, shuffle=False, num_workers=1)
# testloader = DataLoader(d_test, batch_size=32, shuffle=False, num_workers=1)
# devloader = DataLoader(d_dev, batch_size=32, shuffle=False, num_workers=1)


In [73]:
lr = 10e-4
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
dice_loss = monai.losses.DiceLoss()
max_epoch = 100
support_set_sizes = [1, 2, 4, 8, 16, 32, 64]
val_interval = 1
support_set_size = 32
BATCH_SIZE = 2

In [75]:
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
for epoch in range(max_epoch):
    model.train()
    epoch_loss = 0
    step = 0
    task_idx = np.random.randint(20)
    d_support = example_data.My_OASISDataset("support", task_idx)
    supportloader = DataLoader(d_support, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)
    for batch_data, batch_indices in zip(supportloader, supportloader.batch_sampler):
        step += 1
        image, labels = batch_data[0].to(device), batch_data[1].to(device)
        support_set = [element for idx, element in enumerate(d_support) if idx not in batch_indices]

        support_images, support_labels = zip(*itertools.islice(support_set, support_set_size*BATCH_SIZE))
        support_images = torch.stack(support_images).view(BATCH_SIZE, support_set_size, 1, 128, 128).to(device)
        support_labels = torch.stack(support_labels).view(BATCH_SIZE, support_set_size, 1, 128, 128).to(device)
        optimizer.zero_grad()

        logits = model(image, support_images, support_labels)
        pred = torch.sigmoid(logits)
        loss = dice_loss(pred, labels)
        loss.backward()
        epoch_loss += loss.item()        
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    
    if (epoch + 1) % val_interval == 0:
        model.eval()
        d_dev = example_data.My_OASISDataset("dev", task_idx)
        devloader = DataLoader(d_dev, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)
        score = []
        with torch.no_grad():
            y_pred = torch.tensor([], dtype = torch.float32, device = device)
            y = torch.tensor([], dtype = torch.long, device = device)
            for dev_data, dev_indices in zip(devloader, devloader.batch_sampler):
                dev_image, dev_label = (dev_data[0].to(device), dev_data[1].to(device))
                dev_support = [element for idx, element in enumerate(d_dev) if idx not in dev_indices]

                dev_support_images, dev_support_labels = zip(*itertools.islice(dev_support, support_set_size*BATCH_SIZE))
                dev_support_images = torch.stack(dev_support_images).view(BATCH_SIZE, support_set_size, 1, 128, 128).to(device)
                dev_support_labels = torch.stack(dev_support_labels).view(BATCH_SIZE, support_set_size, 1, 128, 128).to(device)

                dev_logits = model(dev_image, dev_support_images, dev_support_labels)
                dev_soft_pred = torch.sigmoid(dev_logits)
                dev_hard_pred = dev_soft_pred.round().clip(0,1)
                score.append(dice_score(dev_hard_pred, dev_label))
                # y_pred = torch.cat([y_pred, dev_hard_pred], dim=0)
                # y = torch.cat([y, dev_label], dim = 0)
            result = sum(score) / len(score)
            # if result > best_metric:
            #     best_metric = result
            #     best_metric_epoch = epoch + 1
            #     torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
            #     print("saved new best metric model")
            print(f"current epoch: {epoch + 1}" f" current dice score: {result:.4f}")




epoch 1 average loss: 0.9976
current epoch: 1 current dice score: 0.0004
epoch 2 average loss: 0.9974
current epoch: 2 current dice score: 0.0014
epoch 3 average loss: 0.9872
current epoch: 3 current dice score: 0.0051
epoch 4 average loss: 0.9967
current epoch: 4 current dice score: 0.0011
epoch 5 average loss: 0.9817
current epoch: 5 current dice score: 0.0142
epoch 6 average loss: 0.9859
current epoch: 6 current dice score: 0.0056
epoch 7 average loss: 0.9816
current epoch: 7 current dice score: 0.0033
epoch 8 average loss: 0.9993
current epoch: 8 current dice score: 0.0000
epoch 9 average loss: 0.9816
current epoch: 9 current dice score: 0.0033
epoch 10 average loss: 0.9816
current epoch: 10 current dice score: 0.0033
epoch 11 average loss: 0.9976
current epoch: 11 current dice score: 0.0004
epoch 12 average loss: 0.9816
current epoch: 12 current dice score: 0.0033
epoch 13 average loss: 0.9934
current epoch: 13 current dice score: 0.0035
epoch 14 average loss: 0.9985
current epoch