In [1]:
# Re-import edited modules automatically before each cell runs, so changes to
# project packages imported below (e.g. meddist) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [40]:
import math
from itertools import combinations
from pathlib import Path
from typing import List, Tuple

import airlab as al
import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from einops import rearrange
from ipywidgets import fixed, interact
from monai import transforms as tfm
from monai.data import DataLoader, Dataset
from monai.data.meta_tensor import MetaTensor
from monai.networks.nets import DenseNet
from torch import nn
from tqdm import tqdm
from wbpetct.data import FDG_PET_CT_Dataset, get_train_valid_loaders

from meddist.data import get_dataloaders
from meddist.dist import (
    euclidean_dist,
    get_bbox_centers,
    get_cropped_bboxes,
    pairwise_comparisons,
)
from meddist.model import ContrastiveDistanceDenseNet

In [3]:
# Compact numeric display: two decimals, no scientific notation, for NumPy arrays and torch tensors alike.
np.set_printoptions(suppress=True, precision=2)
torch.set_printoptions(sci_mode=False, precision=2)

In [4]:
def plot_means(volume, batch=0, show=False):
    """Plot the mean-intensity projection of a volume along each of its three axes.

    Args:
        volume: 3D array-like, or 4D array-like whose leading axis is indexed
            by ``batch`` before plotting.
        batch: index into the leading axis; only used when ``volume`` is 4D.
        show: when True, display the figure with ``plt.show()`` and return None.

    Returns:
        ``(fig, axs)`` with one grayscale panel per axis when ``show`` is False,
        otherwise None.
    """
    if len(volume.shape) == 4:
        volume = volume[batch]

    fig, axs = plt.subplots(1, 3)

    # Panel k shows the mean taken along axis k of the (3D) volume.
    for axis_idx, ax in enumerate(axs):
        ax.imshow(np.mean(volume, axis=axis_idx), cmap="gray")
        ax.set_axis_off()

    if not show:
        return fig, axs

    plt.show()
    return None

In [5]:
raw_data_path = "/sc-scratch/sc-scratch-gbm-radiomics/tcia/manifest-1654187277763/nifti/FDG-PET-CT-Lesions"

# Working data sits next to the raw export: same directory name plus a "-data2" suffix.
data_path = Path(f"{raw_data_path}-data2")

# Sub-directories of the working data used by later cells.
registered_dir = data_path.joinpath("registered")
processed_dir = data_path.joinpath("processed")

In [6]:
# Build train/validation loaders from the registered volumes.
# NOTE(review): the training-setup cell further down calls get_dataloaders(registered_dir)
# again and rebinds both names, so in the visible notebook these loaders only feed the
# scratch cell that follows; `valid_loader` is not used in the visible cells.
train_loader, valid_loader = get_dataloaders(registered_dir)

In [34]:
# Scratch cell: pull a single batch for inspection and instantiate a 3D DenseNet
# (1 input channel -> 512 outputs) on the GPU.
# NOTE(review): `model` is re-created in the training-setup cell, so this instance is not the one trained.
batch = next(iter(train_loader))
model = DenseNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=512,
).to(device="cuda")

In [77]:
def train_distance_model(model, optimizer, loss_fn, dataloader, num_epochs, device="cuda"):
    """Train ``model`` so pairwise embedding distances match pairwise crop-center distances.

    For each batch, the target is the Euclidean distance matrix between the
    crop centers obtained from ``get_cropped_bboxes``/``get_bbox_centers``.
    The prediction is the Euclidean distance matrix between the model outputs
    for the same batch. ``loss_fn`` compares the two matrices.

    Args:
        model: network mapping ``batch["image"]`` to one embedding row per sample.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(pred_dist_mat, gt_dist_mat)`` returning a scalar tensor.
        dataloader: iterable of dicts with an ``"image"`` entry.
        num_epochs: number of passes over ``dataloader``.
        device: device the images are moved to before the forward pass. Defaults
            to ``"cuda"``, the value this function previously hard-coded.

    Returns:
        List of per-iteration loss values (Python floats), in training order.
    """
    # Make sure BatchNorm/Dropout-style layers are in training mode even if the
    # model was switched to eval() earlier (e.g. for validation).
    model.train()
    losses = []

    for epoch in range(num_epochs):
        for iteration, batch in enumerate(dataloader):
            bboxes = get_cropped_bboxes(batch["image"], "RandSpatialCropSamples")
            # Convert once, directly to float32: torch.cdist rejects integer inputs.
            centers = torch.as_tensor(get_bbox_centers(bboxes), dtype=torch.float32)
            gt_dist_mat = torch.cdist(centers, centers, p=2.0)

            # Forward pass
            image = batch["image"].to(device)
            embeddings = model(image)
            pred_dist_mat = torch.cdist(embeddings, embeddings, p=2.0)

            # Move the (grad-free) target to the prediction instead of copying
            # the embeddings back to the CPU.
            gt_dist_mat = gt_dist_mat.to(device=pred_dist_mat.device, dtype=pred_dist_mat.dtype)

            loss = loss_fn(pred_dist_mat, gt_dist_mat)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log
            loss_value = loss.item()
            losses.append(loss_value)
            print(f"Epoch/Iteration {epoch:03}/{iteration:03} Loss, {loss_value:.3f}")

    return losses

In [78]:
# Model: 3D DenseNet mapping a single-channel volume to a 512-dimensional output, on the GPU.
model = DenseNet(spatial_dims=3, in_channels=1, out_channels=512).to(device="cuda")
# Adam with its learning rate written out; 1e-3 is PyTorch's documented default,
# so this is the same optimizer as a bare Adam(model.parameters()).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Mean squared error between predicted and ground-truth distance matrices.
loss_fn = nn.MSELoss()

# Re-create the data loaders from the registered volumes (same call as the earlier loader cell).
registered_dir = data_path.joinpath("registered")
train_loader, valid_loader = get_dataloaders(registered_dir)

In [79]:
# Train the model for 10 epochs on the training loader.
# NOTE(review): nothing in this cell appends to `losses`, so it stays empty unless filled elsewhere.
losses = []

train_distance_model(
    model,
    optimizer,
    loss_fn,
    train_loader,
    num_epochs=10,
)

Epoch/Iteration 000/000 Loss, 35352.285
Epoch/Iteration 000/001 Loss, 6343.036
Epoch/Iteration 000/002 Loss, 32113.518
Epoch/Iteration 000/003 Loss, 16899.178
Epoch/Iteration 000/004 Loss, 24254.955
Epoch/Iteration 000/005 Loss, 13105.867
Epoch/Iteration 000/006 Loss, 2248.387
Epoch/Iteration 000/007 Loss, 17576.955
Epoch/Iteration 000/008 Loss, 1191.772
Epoch/Iteration 000/009 Loss, 13198.685
Epoch/Iteration 000/010 Loss, 800.718
Epoch/Iteration 000/011 Loss, 9826.324
Epoch/Iteration 000/012 Loss, 12160.454
Epoch/Iteration 000/013 Loss, 3000.873
Epoch/Iteration 000/014 Loss, 3254.417
Epoch/Iteration 000/015 Loss, 5987.673
Epoch/Iteration 000/016 Loss, 6317.909
Epoch/Iteration 000/017 Loss, 5574.994
Epoch/Iteration 000/018 Loss, 8782.736
Epoch/Iteration 000/019 Loss, 10295.872
Epoch/Iteration 000/020 Loss, 2144.821
Epoch/Iteration 000/021 Loss, 2438.458
Epoch/Iteration 000/022 Loss, 1574.512
Epoch/Iteration 000/023 Loss, 2387.370
Epoch/Iteration 000/024 Loss, 2304.122
Epoch/Iteration 0


KeyboardInterrupt

