In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import math
from itertools import combinations
from pathlib import Path
from typing import List, Tuple

import airlab as al
import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from einops import rearrange
from ipywidgets import fixed, interact
from monai import transforms as tfm
from monai.data import DataLoader, Dataset
from monai.data.meta_tensor import MetaTensor
from monai.networks.nets import DenseNet
from torch import nn
from tqdm import tqdm
from wbpetct.data import FDG_PET_CT_Dataset, get_train_valid_loaders
import wandb
from meddist.data import get_dataloaders, get_transformations
from meddist.dist import (
    euclidean_dist,
    get_bbox_centers,
    get_cropped_bboxes,
    pairwise_comparisons,
)

In [24]:
class NN(torch.nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        
        self.linear1 = torch.nn.Linear(10, 5)
        self.linear2 = torch.nn.Linear(5, 1)
        
    
    def forward(inp):
        return self.linear2(self.linear1(inp))

In [31]:
nn = NN()
adam = torch.optim.Adam(nn.parameters())

In [33]:
dir(adam)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_hook_for_profile',
 '_zero_grad_profile_name',
 'add_param_group',
 'defaults',
 'load_state_dict',
 'param_groups',
 'state',
 'state_dict',
 'step',
 'zero_grad']

In [3]:
np.set_printoptions(precision=2, suppress=True)
torch.set_printoptions(precision=2, sci_mode=False)

In [4]:
def plot_means(volume, batch=0, show=False):

    if len(volume.shape) == 4:
        volume = volume[batch]

    fig, axs = plt.subplots(1, 3)

    for i in range(3):
        axs[i].imshow(np.mean(volume, axis=i), cmap="gray")
        axs[i].set_axis_off()

    if show:
        plt.show()
        return

    return fig, axs

In [5]:
raw_data_path = "/sc-scratch/sc-scratch-gbm-radiomics/tcia/manifest-1654187277763/nifti/FDG-PET-CT-Lesions"

data_path = Path(
    "/sc-scratch/sc-scratch-gbm-radiomics/tcia/manifest-1654187277763/nifti/FDG-PET-CT-Lesions-data2"
)
registered_dir = data_path / "registered"
processed_dir = data_path / "processed"

In [9]:
train_loader, valid_loader = get_dataloaders(registered_dir, 5, add_intensity_augmentation=True)

In [10]:
batch = next(iter(train_loader))
# model = DenseNet(spatial_dims=3, in_channels=1, out_channels=512).to("cuda")

In [11]:
for i, batch in enumerate(tqdm(train_loader)):
    if i == 20:
        break

  3%|▎         | 20/585 [00:48<22:58,  2.44s/it]


In [None]:
import math

def max_crop_distance(image_size, crop_size):
    height, width = image_size
    crop_height, crop_width = crop_size, crop_size
    diagonal_distance = math.sqrt((height - crop_height)**2 + (width - crop_width)**2)
    return diagonal_distance / 2

In [None]:
import math

def max_crop_distance_3d(image_volume, crop_volume):
    depth, height, width = image_volume
    crop_depth, crop_height, crop_width = crop_volume, crop_volume, crop_volume
    diagonal_distance = math.sqrt((depth - crop_depth)**2 + (height - crop_height)**2 + (width - crop_width)**2)
    return diagonal_distance / 2


In [None]:
max_crop_distance_3d((400, 400, 319), 128)

In [16]:
lr = 0.001
for i in range(50):
    lr *= 0.9
    print(lr)
    

0.0009000000000000001
0.0008100000000000001
0.000729
0.0006561000000000001
0.00059049
0.000531441
0.0004782969
0.00043046721
0.000387420489
0.0003486784401
0.00031381059609000004
0.00028242953648100003
0.00025418658283290005
0.00022876792454961005
0.00020589113209464906
0.00018530201888518417
0.00016677181699666576
0.0001500946352969992
0.0001350851717672993
0.00012157665459056936
0.00010941898913151243
9.847709021836118e-05
8.862938119652506e-05
7.976644307687256e-05
7.17897987691853e-05
6.461081889226677e-05
5.81497370030401e-05
5.233476330273609e-05
4.7101286972462485e-05
4.239115827521624e-05
3.8152042447694614e-05
3.433683820292515e-05
3.090315438263264e-05
2.7812838944369376e-05
2.503155504993244e-05
2.2528399544939195e-05
2.0275559590445276e-05
1.8248003631400748e-05
1.6423203268260675e-05
1.4780882941434607e-05
1.3302794647291146e-05
1.1972515182562031e-05
1.0775263664305828e-05
9.697737297875246e-06
8.727963568087722e-06
7.85516721127895e-06
7.069650490151056e-06
6.36268544113

In [None]:
def train_distance_model(model, optimizer, loss_fn, dataloader, num_epochs):

    for epoch in range(num_epochs):
        for iteration, batch in enumerate(dataloader):

            bboxes = get_cropped_bboxes(batch["image"], "RandSpatialCropSamples")
            centers = get_bbox_centers(bboxes)

            gt_dist_mat = torch.cdist(
                torch.tensor(centers), torch.tensor(centers), p=2.0
            ).float()

            # Forward pass
            image = batch["image"].to("cuda")
            embeddings = model(image).cpu()

            pred_dist_mat = torch.cdist(embeddings, embeddings, p=2)

            loss = loss_fn(pred_dist_mat, gt_dist_mat)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log
            print(f"Epoch/Iteration {epoch:03}/{iteration:03} Loss, {loss.item():.3f}")

In [None]:
# Define the model and optimizer
model = DenseNet(spatial_dims=3, in_channels=1, out_channels=512).to("cuda")
optimizer = torch.optim.Adam(model.parameters())

# Define the loss function
loss_fn = nn.MSELoss()

# Define the DataLoader
registered_dir = data_path / "registered"
train_loader, valid_loader = get_dataloaders(registered_dir)

In [None]:
# Train the model

losses = []

train_distance_model(model, optimizer, loss_fn, train_loader, num_epochs=10)