In [1]:
from pathlib import Path

import submitit
import torch
from diffdrr.drr import DRR
from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d
from pytorch_transformers.optimization import WarmupCosineSchedule
from timm.utils.agc import adaptive_clip_grad as adaptive_clip_grad_
from tqdm import tqdm
import ipdb
# import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T

import tifffile as tiff

import matplotlib.pyplot as plt
from diffdrr.drr import DRR
from tqdm import tqdm
import random


from pathlib import Path
from PIL import Image

import statistics
import gc
import cv2


In [2]:
import wandb
# wandb.login(key = api)
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbananabond[0m ([33mkneedeeppose[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
#| export
from pathlib import Path
from typing import Optional, Union

import h5py
import numpy as np
import torch
from beartype import beartype
import nibabel as nib
import pandas as pd
import tifffile as tiff

import os
import sys

# Function to find the root directory of the project
def find_project_root(current_path, marker_file):
    while current_path != os.path.dirname(current_path):
        if marker_file in os.listdir(current_path):
            return current_path
        current_path = os.path.dirname(current_path)
    return None

# Get the current working directory
cwd = os.getcwd()

marker_file = '.git'

# Find the root directory
project_root = find_project_root(cwd, marker_file)

sys.path.append(project_root)

from diffpose.calibration import RigidTransform, perspective_projection, convert
from diffpose.kneefit import Transforms, get_random_offset
from diffpose.metrics import DoubleGeodesic, GeodesicSE3
from diffpose.registration import PoseRegressor

device = 'cuda'

# `KneeFit`

In [4]:
#| export
@beartype
class KneeFitDataset(torch.utils.data.Dataset):
    """
    Get X-ray projections and poses from specimens in the `DeepFluoro` dataset.

    Given a specimen ID and projection index, returns the projection and the camera matrix for DiffDRR.
    """

    def __init__(
        self,
        id_number: int,  # Specimen number 
        filename: Optional[Union[str, Path]] = None,  # Path to Sipla dataset file
        preprocess: bool = True,  # Preprocess X-rays,
        volume_type: str = "femur"
    ):
        # Load the volume
        (
            self.projections,
            self.segnet_input,
            self.volume,
            self.spacing,
            self.lps2volume,
            self.intrinsic,
            self.extrinsic,
            self.focal_len,
            self.x0,
            self.y0,
            self.delx
        ) = load_kneefit_dataset(id_number, filename, volume_type)
        self.preprocess = preprocess
        self.volume_type = volume_type

        # Get the isocenter pose (AP viewing angle at volume isocenter)
        # TODO: probably have to change the viewing angle here to LAT viewing angle

        # isocenter_xyz = torch.tensor([100 , -175 , 100], dtype=torch.float32).unsqueeze(0)  ## Left Leg

        # isocenter_rot = torch.tensor([[0,0,0]], dtype=torch.float32) ## LEFT Leg
        isocenter_rot = torch.tensor([[-torch.pi / 2, -torch.pi, 0]])
        isocenter_xyz = torch.tensor(self.volume.shape) * self.spacing / 2
        isocenter_xyz = isocenter_xyz.unsqueeze(0)
        shift = torch.tensor([[-972/2, 0, 0]], dtype=torch.float32)
        # isocenter_xyz += shift


        self.isocenter_pose = RigidTransform(
            isocenter_rot, isocenter_xyz, "euler_angles", "XYZ"
        )

        # Miscellaneous transformation matrices for wrangling SE(3) poses
        self.flip_xz = RigidTransform(
            torch.tensor([[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]),
            torch.zeros(3),
        )
        self.translate = RigidTransform(
            torch.eye(3),
            torch.tensor([-self.focal_len / 2, 0, 0]),
        )
        self.flip_180 = RigidTransform(
            torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]),
            torch.zeros(3),
        )

    def __len__(self):
        return len(self.projections)

    def __iter__(self):
        return iter(self[idx] for idx in range(len(self)))

    def __getitem__(self, idx):
        """
        (1) Swap the x- and z-axes
        (2) Reverse the x-axis to make the matrix E(3) -> SE(3)
        (3) Move the camera to the origin
        (4) Rotate the detector plane by 180, if offset
        (5) Form the full SE(3) transformation matrix
        """
        # print("get item called")
        projection = self.projections.iloc[idx]
        img = self.process_image(projection["img_path"])

        # img_raw = tiff.imread(projection["img_path"])
        segnet_ip = self.segnet_input[idx]
        print("segnet_ip", segnet_ip)
        segnet_path = os.path.join("D:\kneefit_femur_synthetic", segnet_ip)
        print("segnet_path", segnet_path)
        segnet_img = self.process_image(segnet_path, png=True)
        
        world2volume = self.get_gt_pose(projection)
        # world2volume[:3, 3] = torch.tensor([120,90,250],dtype=torch.float32)
        world2volume = RigidTransform(world2volume[:3, :3], world2volume[:3, 3])
        pose = convert_kneefit_to_diffdrr(self, world2volume)
        # print(pose)

        # Handle rotations in the imaging dataset
        # if self._rot_180_for_up(idx):
        #     img = torch.rot90(img, k=2)
        #     pose = self.flip_180.compose(pose)

        # Optionally, preprocess the images
        img = img.unsqueeze(0).unsqueeze(0)
        if self.preprocess:
            img = preprocess(img)
            segnet_img = preprocess(segnet_img)



        return img, pose, segnet_img
    
    def get_gt_pose(self, projection):
        # # print(projection)
        ## tibia local machine
        # gt_pose = torch.tensor([[projection[f"{self.volume_type}_rxx"], projection[f"{self.volume_type}_rxy"], projection[f"{self.volume_type}_rxz"], -116.402],
        #                         [projection[f"{self.volume_type}_ryx"], projection[f"{self.volume_type}_ryy"], projection[f"{self.volume_type}_ryz"], -196.095],
        #                         [projection[f"{self.volume_type}_rzx"], projection[f"{self.volume_type}_rzy"], projection[f"{self.volume_type}_rzz"], -947.326],
        #                         [0, 0, 0, 1]], dtype=torch.float32)
        

        ## femur local machine 
        # gt_pose = torch.tensor([[projection[f"{self.volume_type}_rxx"], projection[f"{self.volume_type}_rxy"], projection[f"{self.volume_type}_rxz"], -217.279],
        #                         [projection[f"{self.volume_type}_ryx"], projection[f"{self.volume_type}_ryy"], projection[f"{self.volume_type}_ryz"], -34.6134],
        #                         [projection[f"{self.volume_type}_rzx"], projection[f"{self.volume_type}_rzy"], projection[f"{self.volume_type}_rzz"], -973.889],
        #                         [0, 0, 0, 1]], dtype=torch.float32)
        gt_pose = torch.tensor([[projection[f"{self.volume_type}_rxx"], projection[f"{self.volume_type}_rxy"], projection[f"{self.volume_type}_rxz"], projection[f"{self.volume_type}_tx"]],
                        [projection[f"{self.volume_type}_ryx"], projection[f"{self.volume_type}_ryy"], projection[f"{self.volume_type}_ryz"], projection[f"{self.volume_type}_ty"]],
                        [projection[f"{self.volume_type}_rzx"], projection[f"{self.volume_type}_rzy"], projection[f"{self.volume_type}_rzz"], projection[f"{self.volume_type}_tz"]],
                        [0, 0, 0, 1]], dtype=torch.float32)
        return gt_pose
    
    def process_image(self, img_path, png=False):

        if png:
            image = Image.open(img_path)
            xray_img = np.asarray(image)
        else:
            xray_img = tiff.imread(img_path)

        # print(type(xray_img))
        # print(xray_img.shape)
        xray_tensor = torch.tensor(xray_img, dtype=torch.float32, device='cuda')
        plt.imshow(xray_tensor.cpu().numpy())
        xray_tensor = xray_tensor.unsqueeze(0).unsqueeze(0)

        # print(xray_tensor.shape)
        return xray_tensor


In [5]:
from pathlib import Path

import submitit
import torch
from diffdrr.drr import DRR
from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d
from pytorch_transformers.optimization import WarmupCosineSchedule
from timm.utils.agc import adaptive_clip_grad as adaptive_clip_grad_
from tqdm import tqdm
import ipdb
# import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T

import tifffile as tiff

import matplotlib.pyplot as plt
from diffdrr.drr import DRR
from tqdm import tqdm
import random


from pathlib import Path
from PIL import Image

import statistics
import gc
import cv2


In [6]:
from pathlib import Path

import submitit
import torch
from diffdrr.drr import DRR
from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d
from pytorch_transformers.optimization import WarmupCosineSchedule
from timm.utils.agc import adaptive_clip_grad as adaptive_clip_grad_
from tqdm import tqdm
import ipdb
# import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T

import tifffile as tiff

import matplotlib.pyplot as plt
from diffdrr.drr import DRR
from tqdm import tqdm
import random


from pathlib import Path
from PIL import Image

import statistics
import gc
import cv2


In [7]:
#| export
def convert_kneefit_to_diffdrr(specimen, pose: RigidTransform):
    """Transform the camera coordinate system used in Kneefit to the convention used by DiffDRR."""
    extrinsicc = (
            specimen.flip_xz.inverse().compose(specimen.translate.inverse()).compose(pose)
        )
    return pose
    return (
        specimen.lps2volume.inverse().compose(pose).compose(specimen.flip_xz).compose(specimen.translate)
    )

def convert_diffdrr_to_deepfluoro(specimen, pose: RigidTransform):
    """Transform the camera coordinate system used in DiffDRR to the convention used by DeepFluoro."""
    return pose
    return (
        # specimen.lps2volume.inverse()
        specimen.compose(pose.inverse())
        # .compose(specimen.translate)
        .compose(specimen.flip_xz)
    )

In [8]:
from torch.nn.functional import pad

from diffpose.calibration import perspective_projection


class Evaluator:
    def __init__(self, specimen, idx):
        # Save matrices to device
        self.translate = specimen.translate
        self.flip_xz = specimen.flip_xz
        self.intrinsic = specimen.intrinsic
        self.intrinsic_inv = specimen.intrinsic.inverse()

        # Get gt fiducial locations
        self.specimen = specimen
        self.fiducials = specimen.fiducials
        gt_pose = specimen[idx][1]
        self.true_projected_fiducials = self.project(gt_pose)

    def project(self, pose):
        extrinsic = convert_diffdrr_to_deepfluoro(self.specimen, pose)
        x = perspective_projection(extrinsic, self.intrinsic, self.fiducials)
        x = -self.specimen.focal_len * torch.einsum(
            "ij, bnj -> bni",
            self.intrinsic_inv,
            pad(x, (0, 1), value=1),  # Convert to homogenous coordinates
        )
        extrinsic = (
            self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)
        )
        return extrinsic.transform_points(x)

    def __call__(self, pose):
        pred_projected_fiducials = self.project(pose)
        registration_error = (
            (self.true_projected_fiducials - pred_projected_fiducials)
            .norm(dim=-1)
            .mean()
        )
        registration_error *= 0.194  # Pixel spacing is 0.194 mm / pixel isotropic
        return registration_error

In [9]:
#| export
from diffdrr.utils import parse_intrinsic_matrix, get_principal_point


def load_kneefit_dataset(id_number, filename, volume_type):
    # Open the H5 file for the dataset
    if filename is None:
        # root = Path(__file__).parent.parent.absolute()
        filename = Path.cwd().parent.parent.absolute() / "data/sipla_bone_local_preprocessed.csv"

    print(filename)
    f = pd.read_csv(filename)
    assert id_number in range(1, f["patient_id"].nunique())
    f = f[f["patient_id"] == id_number]
    (
        intrinsic,
        extrinsic,
        num_cols,
        num_rows,
        proj_col_spacing,
        proj_row_spacing,
    ) = parse_proj_params(f)

    # uncomment if focal length is in unit length
    focal_len = intrinsic[0,0]
    x0, y0 = intrinsic[0,2], intrinsic[1,2]

    # uncomment if focal length is in pixels
    # focal_len, x0, y0 = parse_intrinsic_matrix(
    #     intrinsic,
    #     num_rows,
    #     num_cols,
    #     proj_row_spacing,
    #     proj_col_spacing,
    # )


    # Try to load the particular specimen
    projections = f[f["patient_id"] == id_number]
    segnet_input = [img for img in os.listdir(r"D:\\kneefit_femur_synthetic")]
    segnet_input = segnet_input[projections.first_valid_index():projections.first_valid_index() + projections.shape[0]]

    print("Segnet_Input", segnet_input)
    print(len(segnet_input))

    assert projections.shape[0] == len(segnet_input)

    # Parse the volume
    volume, spacing, lps2volume = parse_volume(projections, volume_type)
    return (
        projections,
        segnet_input,
        volume,
        spacing,
        lps2volume,
        intrinsic,
        extrinsic,
        focal_len,
        x0,
        y0,
        proj_row_spacing
    )

def get_volume_data(volume):
    nii_data = nib.load(volume)
    # Access the affine transformation matrix
    # affine = nii_data.affine
    # Access the image data as a NumPy array
    data = nii_data.get_fdata()
    # Access the affine transformation matrix
    affine = nii_data.affine
    return data, affine


def parse_volume(specimen, volume_type):
    # Parse the volume
    #TODO: use pitch as spacing if using stl converted to voxel grid OR use code from nii library nibabel from data_analysis notebook
    # spacing = specimen["vol/spacing"][:].flatten()
    print(specimen["spacing"].unique()[0])
    spacing = np.array(eval(specimen["spacing"].unique()[0])).astype(np.float32)
    #TODO: check shape of 3d array
    # import ipdb; ipdb.set_trace()
    volume, affine = get_volume_data(specimen[f"{volume_type}_nii"].unique()[0])
    print("affine transformation matrix", affine)

    volume = volume.astype(np.float32)
    # volume = specimen["vol/pixels"][:].astype(np.float32)

    #TODO: is swapping the axis same as changing lateral to AP view?
    # volume = np.swapaxes(volume, 0, 2)[::-1].copy()

    # Parse the translation matrix from LPS coordinates to volume coordinates
    origin = torch.tensor(affine[:3, 3])
    # origin = torch.tensor([39, 39, -140], dtype=torch.float32)
    lps2volume = RigidTransform(torch.eye(3), origin)
    return volume, spacing, lps2volume


def parse_proj_params(f):
    # proj_params = f["proj-params"]
    # extrinsic parameters here do not matter, set to zeros
    try:
        fx = f["cal_focal_length"].unique()[0]
        px = f["cal_principalp_x"].unique()[0]
        py = f["cal_principalp_y"].unique()[0]
    except Exception as err:
        print("camera instrinsics must be the same for all frames of 1 patient")
        sys.exit()
    extrinsic = torch.eye(4, dtype=torch.float32)
    extrinsic[0, 3] = px
    extrinsic[1, 3] = py
    extrinsic[2, 3] = fx

    

    # TODO: understand why extrinsic goes through RigidTransform
    extrinsic = RigidTransform(extrinsic[..., :3, :3], extrinsic[:3, 3])
    intrinsic = torch.tensor([[fx, 0, px],
                              [0, fx, py],
                              [0, 0, 1]], dtype=torch.float32)
    num_cols = num_rows = 1000
    proj_col_spacing = proj_row_spacing = float(f["cal_mm_per_pxl"].unique()[0])
    return intrinsic, extrinsic, num_cols, num_rows, proj_col_spacing, proj_row_spacing


In [10]:
#| export
from torchvision.transforms.functional import center_crop, gaussian_blur


def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)):
    """
    Recover the line integral: $L[i,j] = \log I_0 - \log I_f[i,j]$

    (1) Remove edge due to collimator
    (2) Smooth the image to make less noisy
    (3) Subtract the log initial energy for each ray
    (4) Recover the line integral image
    (5) Rescale image to [0, 1]
    """
    img = center_crop(img, (1436, 1436))
    img = gaussian_blur(img, (5, 5), sigma=1.0)
    img = initial_energy.log() - img.log()
    img = (img - img.min()) / (img.max() - img.min())
    return img

In [11]:
def check_threshold(drr_values, threshold):
    negative_map = torch.tensor(drr_values < threshold).nonzero(as_tuple=True)[0]
    return negative_map

In [12]:
# offsets = get_random_offset(2, device)
# print(offsets[0][0])

In [13]:
def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.
    
    Usage: Plug this function in Trainer class after loss.backwards() as 
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
    ave_grads = []
    max_grads= []
    layers = []
    for n, p in named_parameters:
        if p.grad is None:
            print('this layer grad is none', n)

        elif(p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.abs().mean().detach().cpu())
            max_grads.append(p.abs().max().detach().cpu())
            # print("ok")

    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom = -0.001, top=max(max_grads)) # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")

    return plt

In [14]:
def gen_random_offset(batch_size, device):

    # rotation along x is good
    # r1 = torch.distributions.Normal(0, 0.3*torch.pi).sample((batch_size,))
    # r2 = torch.distributions.Normal(0, 0.01*torch.pi).sample((batch_size,))
    # r3 = torch.distributions.Normal(0, 0.1*torch.pi).sample((batch_size,))
    
    # rot =  torch.stack([r1, r2, r3], dim=1).to(device)


    # rand_T = RigidTransform(rot, trans, "euler_angles", "XYZ")


    r1_se3 = torch.distributions.Normal(0, 2.1).sample((batch_size,))
    r2_se3 = torch.distributions.Normal(0, 1.5).sample((batch_size,))
    r3_se3 = torch.distributions.Normal(0, 0.5).sample((batch_size,))
    # r1_se3, r2_se3, r3_se3 = torch.tensor([2, -2, 2, -2]), torch.tensor([0.0, -0.0, 1, -1]), torch.tensor([0.1, -0.1,0.1, -0.1])


    
    t1 = torch.distributions.Normal(-50, 1).sample((batch_size,))
    t2 = torch.distributions.Normal(0, 50).sample((batch_size,))
    t3 = torch.distributions.Normal(0, 50).sample((batch_size,))

    rot_se3 =  torch.stack([r1_se3, r2_se3, r3_se3], dim=1).to(device)
    trans_se3 = torch.stack([t1, t2, t3], dim=1).to(device)
    rand_T_se3 =  convert(
        [rot_se3, trans_se3],
        "se3_log_map",
        "se3_exp_map",
    )

    # return rand_T
    return rand_T_se3


def log_memory_usage(epoch, prefix=""):
    """Logs GPU memory usage to Wandb."""
    memory_allocated = torch.cuda.memory_allocated()
    memory_reserved = torch.cuda.memory_reserved()
    
    wandb.log({
        f"{prefix}memory_allocated": memory_allocated,
        f"{prefix}memory_reserved": memory_reserved,
        "step": epoch
    })

def normalize_ncc(ncc_loss):
    return (ncc_loss - ncc_loss.min()) / (ncc_loss.max() - ncc_loss.min() + 1e-8)

def normalize_geodesic(geodesic_loss):
    return (geodesic_loss - geodesic_loss.min()) / (geodesic_loss.max() - geodesic_loss.min() + 1e-8)

# DRR for knee pose

In [15]:
def  synthetic_train(
    run_id,
    id_number,
    model,
    optimizer,
    scheduler,
    drr,
    transforms,
    isocenter_pose,
    device,
    batch_size,
    n_epochs,
    n_batches_per_epoch,
    model_params,
    volume_type,
):
    # initialize losses
    print("ISO = ", isocenter_pose.get_translation(), isocenter_pose.get_rotation("euler_angles", "XYZ"))
    metric = MultiscaleNormalizedCrossCorrelation2d(patch_sizes=[16, 64], patch_weights=[0.75, 0.25], eps=1e-4)
    # geodesic = GeodesicSE3(geo_weights=torch.tensor([1/2.1, 1/1.5, 1/0.5], dtype=torch.float))
    # double = DoubleGeodesic(drr.detector.sdr, geo_weights=torch.tensor([1/2.1, 1/1.5, 1/0.5], dtype=torch.float))
    double = DoubleGeodesic(drr.detector.sdr)
    geodesic = GeodesicSE3()

    # contrast_distribution = torch.distributions.Uniform(1.0, 10.0)
    # pixel_threshold = 4000
    transito = RigidTransform(torch.eye(3, dtype=torch.float32, device=device), torch.tensor([972/2, 0,0], dtype=torch.float32, device=device))
    best_loss = torch.inf
    # pixel_check = False
    model.train()
    #  visualization part
    # generate random R and t for CT volume
    print("Debug", batch_size, device)

    # ctr_break_ctr = 0

    for epoch in range(n_epochs + 1):
        # print("CTR Break triggered = ", ctr_break_ctr)
        losses_ep = []
        geodesic_ep = []
        ncc_ep = []
        double_geo_ep = []
        geo_xyz_ep = []
        geo_rot_ep = []


        for _ in (itr := tqdm(range(n_batches_per_epoch), leave=False)):
            # bone_attenuation is chosen randomly from a uniform distribution
            contrast = 2.0
            # generate random pose using normal distribution
            # drr_pixels = torch.zeros(batch_size)
            # ctr = 0
            # negative_map = check_threshold(drr_pixels, pixel_threshold)
            # while (not all(pixel > pixel_threshold for pixel in drr_pixels)):
                    # negative_map = check_threshold(drr_pixels, pixel_threshold)
                    # print("Negative map = ", negative_map)
            new_offsets = gen_random_offset(batch_size, device)
            poses = transito.compose(new_offsets.compose(isocenter_pose))

            # offsets[negative_map] = new_offsets[:len(negative_map)]
            # generate artificial xray image (DRR) given a random pose 
            # pose = isocenter_pose
            imgs = drr(None, None, None, pose=poses, bone_attenuation_multiplier=contrast)
            imgs = transforms(imgs)
            imgs = imgs.float()

                # if not pixel_check:
                #     break

                # drr_pixels = sum(sum(imgs.squeeze().cpu().numpy()>0))
                # ctr += 1
                

                # if ctr >= 2:
                #     ctr_break_ctr+=1
                #     print("exceeded tries, ctr break = ", ctr_break_ctr)
                #     break
            # apply predicted pose and generate predicted DRR
            pred_offset = model(imgs)
            pred_pose = transito.compose(pred_offset.compose(isocenter_pose))
            pred_img = drr(None, None, None, pose=pred_pose)
            pred_img = transforms(pred_img)
            pred_img = pred_img.float()

            # compute image reconstruction and numerical losses
            ncc = metric(pred_img, imgs)
            log_geodesic = geodesic(pred_pose, poses)
            geodesic_rot, geodesic_xyz, double_geodesic = double(pred_pose, poses)
            geodesic_weight = 1e-3

            # Normalize NCC values to [0, 1]
            # normalized_ncc = (ncc + 1) / 2

            # print('normalized_ncc', normalized_ncc)
            # print(double_geodesic.max())

            # Compute weights as the inverse of the normalized loss
            # weight_ncc = 1.0 / (normalized_ncc.mean())  # Add epsilon to avoid division by zero
            # weight_geodesic = 1.0 / (geodesic_weight*(double_geodesic.max() + log_geodesic.max()))  # Add epsilon to avoid division by zero

            # print('before normalization')
            # print("weight_ncc", weight_ncc)
            # print("weight_geodesic", weight_geodesic)

            # Normalize weights to ensure they sum to 1
            # total_weight = weight_ncc + weight_geodesic
            # weight_ncc /= total_weight
            # weight_geodesic /= total_weight

            # print('after normalization')
            # print("weight_ncc", weight_ncc)
            # print("weight_geodesic", weight_geodesic)

            # Calculate the final weighted loss
            loss = ( geodesic_weight * (log_geodesic + double_geodesic)) + (1 - ncc)
            # loss = 1 - ncc + geodesic_weight*double_geodesic
            # loss = (1 - ncc) + geodesic_weight * (0.6*geodesic_rot + 0.4*geodesic_xyz)

            
            if loss.isnan().any():
                print("Aaaaaaand we've crashed...")
                # print(ncc)
                print(geodesic_rot)
                print(geodesic_xyz)
                print(geodesic_ep)
                print(double_geodesic)
                print(poses.get_matrix())
                print(pred_pose.get_matrix())
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "height": drr.detector.height,
                        "epoch": epoch,
                        "batch_size": batch_size,
                        "n_epochs": n_epochs,
                        "n_batches_per_epoch": n_batches_per_epoch,
                        "pose": poses.get_matrix().cpu(),
                        "pred_pose": pred_pose.get_matrix().cpu(),
                        "img": imgs.cpu(),
                        # "pred_img": pred_img.cpu()
                        **model_params,
                    },
                    f"checkpoints/{run_id}_specimen_{id_number:02d}_{volume_type}_crashed.ckpt",
                )
                raise RuntimeError("NaN loss")

            optimizer.zero_grad()
            loss.mean().backward()
            adaptive_clip_grad_(model.parameters())
            optimizer.step()
            scheduler.step()

            ncc_ep.append(ncc.mean().item())
            geodesic_ep.append(log_geodesic.mean().item())
            double_geo_ep.append(double_geodesic.mean().item())
            geo_xyz_ep.append(geodesic_xyz.mean().item())
            geo_rot_ep.append(geodesic_rot.mean().item())
            losses_ep.append(loss.mean().item())

            # Update progress bar
            itr.set_description(f"Epoch [{epoch}/{n_epochs}]")
            itr.set_postfix(
                geodesic_rot=geodesic_rot.mean().item(),
                geodesic_xyz=geodesic_xyz.mean().item(),
                geodesic_dou=double_geodesic.mean().item(),
                loss=loss.mean().item(),
                ncc=ncc.mean().item(),
            )

    
            # prev_pose = poses
            # prev_pred_pose = pred_pose

            log_memory_usage(int(epoch*batch_size + _))
        
        
        # WANDB
        # print(model)
        # plt = plot_grad_flow(model.named_parameters())
        # plt.show()
        pred_img_log = wandb.Image(pred_img[0].detach().cpu().numpy())
        ip_img_log = wandb.Image(imgs[0].detach().cpu().numpy())
        wandb.log({"pred_img": wandb.Image(pred_img_log),
                    "ip_img" : wandb.Image(ip_img_log),
                    # "gradient_chart": plt,
                    })
        wandb.log({"epoch": epoch,
                   "ncc": statistics.mean(ncc_ep),
                   "log_geodesic": statistics.mean(geodesic_ep),
                   "final_loss":statistics.mean(losses_ep),
                   "geodesic_rot":statistics.mean(geo_rot_ep),
                   "geodesic_xyz":statistics.mean(geo_xyz_ep),
                   "double_geodesic":statistics.mean(double_geo_ep)})
        


        losses = torch.tensor(losses_ep)
        tqdm.write(f"Epoch {epoch + 1:04d} | Loss {losses.mean().item():.4f}")
        # if losses.mean() < best_loss and not losses.isnan().any():
        #     best_loss = losses.mean().item()
        #     torch.save(
        #         {
        #             "model_state_dict": model.state_dict(),
        #             "optimizer_state_dict": optimizer.state_dict(),
        #             "height": drr.detector.height,
        #             "epoch": epoch,
        #             "loss": losses.mean().item(),
        #             "batch_size": batch_size,
        #             "n_epochs": n_epochs,
        #             "n_batches_per_epoch": n_batches_per_epoch,
        #             **model_params,
        #         },
        #         f"checkpoints/{run_id}_specimen_{id_number:02d}_{volume_type}_best.ckpt",
        #     )

        if epoch % 50 == 0:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "height": drr.detector.height,
                    "epoch": epoch,
                    "loss": losses.mean().item(),
                    "batch_size": batch_size,
                    "n_epochs": n_epochs,
                    "n_batches_per_epoch": n_batches_per_epoch,
                    **model_params,
                },
                f"checkpoints/{run_id}_specimen_{id_number:02d}_{volume_type}_epoch{epoch:03d}.ckpt",
            )

        torch.cuda.empty_cache()
        gc.collect()

In [31]:
def get_deeplab_model(device):
    # Initialize DeepLabV3 with a ResNet backbone
    # segmentation_model = torch.hub.load('pytorch/vision:v0.7.0', 'deeplabv3_resnet50', pretrained=True)
    model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
    # model = torchvision.models.segmentation.deeplabv3_resnet50(weights='COCO_WITH_VOC_LABELS_V1', num_classes =3)
    model.backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.classifier[4] = torch.nn.Conv2d(in_channels=256, out_channels=3, kernel_size=1, stride=1)
    
    model = model.to(device)  # Move model to GPU

    return model
    

In [46]:
m = get_deeplab_model(device)
test = torch.rand(2, 1,256,256, dtype=torch.float32, device=device)
op = m(test)
op["out"]


Using cache found in C:\Users\Public Admin/.cache\torch\hub\pytorch_vision_v0.10.0


tensor([[[ 0.1659,  0.1659,  0.1659,  ...,  0.1224,  0.1224,  0.1224],
         [ 0.1659,  0.1659,  0.1659,  ...,  0.1224,  0.1224,  0.1224],
         [ 0.1659,  0.1659,  0.1659,  ...,  0.1224,  0.1224,  0.1224],
         ...,
         [-0.0268, -0.0268, -0.0268,  ..., -0.1881, -0.1881, -0.1881],
         [-0.0268, -0.0268, -0.0268,  ..., -0.1881, -0.1881, -0.1881],
         [-0.0268, -0.0268, -0.0268,  ..., -0.1881, -0.1881, -0.1881]],

        [[ 0.9364,  0.9364,  0.9364,  ...,  0.5916,  0.5916,  0.5916],
         [ 0.9364,  0.9364,  0.9364,  ...,  0.5916,  0.5916,  0.5916],
         [ 0.9364,  0.9364,  0.9364,  ...,  0.5916,  0.5916,  0.5916],
         ...,
         [ 0.3930,  0.3930,  0.3930,  ...,  0.6624,  0.6624,  0.6624],
         [ 0.3930,  0.3930,  0.3930,  ...,  0.6624,  0.6624,  0.6624],
         [ 0.3930,  0.3930,  0.3930,  ...,  0.6624,  0.6624,  0.6624]],

        [[-0.2432, -0.2432, -0.2432,  ..., -0.2748, -0.2748, -0.2748],
         [-0.2432, -0.2432, -0.2432,  ..., -0

In [16]:
import torchvision.transforms as transforms

def finetune_train(
        run_id,
    id_number,
    model,
    optimizer,
    scheduler,
    drr,
    transforms,
    specimen,
    isocenter_pose,
    device,
    batch_size,
    n_epochs,
    n_batches_per_epoch,
    model_params,
    volume_type,
):
    


    print("model loaded")

    # Define the loss function and optimizer
    criterion = torch.nn.BCEWithLogitsLoss()  # Binary Cross-Entropy with logits (useful for binary classification)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    #  visualization part

    # TODO - Get num XRAYS in here
    num_xrays = len(specimen.projections)
    print(num_xrays)
    ids = [i for i in range(num_xrays)]

    print("number of xrays is {num_xrays} for patient {id_number}")

    split_index = int(len(ids) * 0.8)

    # Split the IDs into training and testing sets
    train_ids = ids[:split_index]
    test_ids = ids[split_index:]
    
    random.shuffle(ids)
    true_xray_tensor_batch, iso_gt_pose_batch, output_xray_batch = process_specimens(specimen, isocenter_pose, train_ids, batch_size)

    print("inputs and targets loaded")
    # print(len(true_xray_tensor_batch), len(true_xray_tensor_batch))
    
    # visualize(drr=drr, pose=pose,device=device)
    for epoch in range(n_epochs + 1):
        running_loss = 0.0
        # print('epoch entered')
        for itr, (true_xray, seg_xray) in enumerate(tqdm(zip(true_xray_tensor_batch, output_xray_batch), total=len(true_xray_tensor_batch), leave=False)):

            true_xray = true_xray.to(device)
            seg_xray = seg_xray.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(true_xray)['out']  # DeepLabV3 returns a dictionary with 'out' as the segmentation map

            # Compute the loss
            loss = criterion(outputs, seg_xray)
            
            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            itr.set_description(f"Epoch [{epoch}/{n_epochs}]")
            itr.set_postfix(
                loss=loss.item()
            )
        tqdm.write(f"Epoch {epoch + 1:04d} | Loss {running_loss/len(true_xray_tensor_batch):.4f}")
        if epoch % 10 == 0:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "height": drr.detector.height,
                    "epoch": epoch,
                    "loss": running_loss/len(true_xray_tensor_batch),
                    # "batch_size": batch_size,
                    # "n_epochs": n_epochs,
                    # "n_batches_per_epoch": n_batches_per_epoch,
                    **model_params,
                },
                f"checkpoints/segnet_{volume_type}_epoch{epoch:03d}.ckpt",
            )



In [17]:
def apply_xavier_weights(model):
    for m in model.modules():
        if isinstance(m, torch.nn.Linear):
            # Apply Xavier Uniform initialization to Linear layers
            torch.nn.init.xavier_uniform_(m.weight)
            # print(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)  # Initialize bias with a small constant value

In [18]:
#| eval: false
def load(id_number, height, device, volume_type="femur"):
    if volume_type not in ["femur", "tibia"]:
        raise Exception("Specify volume type between femur and tibia")
    sipla_path = Path.cwd().parent.parent.absolute() / "data/sipla_bone_local_preprocessed.csv"
    specimen = KneeFitDataset(id_number, filename=sipla_path, preprocess=False, volume_type=volume_type)
    isocenter_pose = specimen.isocenter_pose.to(device)

    # Take the height of xray sample and divide by image dimensions required by ResNet
    subsample = 1000 / height
    # TODO: delx to be retreived from combined csv file
    # delx is different for different samples, below hard-coded for SUB2
    delx = specimen.delx * subsample
    drr = DRR(
        specimen.volume,
        specimen.spacing,
        float(specimen.focal_len),
        height,
        delx,
        x0=specimen.x0,
        y0=specimen.y0,
        reverse_x_axis=False,
        bone_attenuation_multiplier=2.0,
        patch_size=64
    ).to(device)
    transforms = Transforms(height)

    return specimen, isocenter_pose, transforms, drr

In [19]:
def process_specimens(specimen, isocenter_pose, ids, batch_size):
    # specimen, isocenter_pose, transforms, drr = load(5, height=256, device=device, volume_type="tibia")
    curr = 0
    while curr <= len(ids):
        batch_ids = ids[curr:curr+batch_size]
        curr = curr+batch_size

        true_xray_batch = []
        iso_gt_pose_batch = []
        output_xray_batch = []
        for id in batch_ids:
            true_xray_tensor, pose, segnet_input = specimen[id]
            # print(true_xray_tensor.shape)

            projection = specimen.projections.iloc[id]

            base_name = Path(projection["img_path"]).stem

            im = Image.open(projection["img_path"])
            im = im.resize((256, 256))
            if not os.path.exists(project_root+rf"\data\example\{base_name}-256_resize.tif"):
                im.save(project_root+rf"\data\example\{base_name}-256_resize.tif")
            
            segnet_input = os.path.join("D:\kneefit_femur_synthetic",specimen.segnet_input[id])
            base_name = Path(segnet_input).stem
            seg_im = Image.open(segnet_input)
            seg_im = seg_im.resize((256, 256))
            if not os.path.exists(project_root+rf"\data\segnet_resize\{base_name}-256_resize.tif"):
                seg_im.save(project_root+rf"\data\segnet_resize\{base_name}-256_resize.tif")
            seg_im = np.asarray(seg_im)
            iso_gt_pose = isocenter_pose.compose(pose.to(device="cuda"))
            binary_mask = ( seg_im> 0)
            seg_im = seg_im * binary_mask
            seg_tensor = torch.tensor(seg_im, dtype=torch.float32, device = device)
            # print(iso_gt_pose)

            true_xray_batch.append(true_xray_tensor)
            iso_gt_pose_batch.append(iso_gt_pose)
            output_xray_batch.append(seg_tensor)

            # print(len(true_xray_batch), len(iso_gt_pose_batch))



    return true_xray_batch, iso_gt_pose_batch, output_xray_batch

In [20]:
def visualize( drr, pose , device):
    
    pred_xray = drr(None, None, None, pose=pose.to(device))
    print(pose.get_translation())
    print(pose.get_rotation())
    xray = pred_xray[0,:,:,:]
    plt.figure(constrained_layout=False)
    plt.subplot(121)
    plt.title("DRR")
    plt.imshow(xray.squeeze().detach().cpu().numpy(), cmap="gray")
    plt.show()

In [21]:
def vis_drr(drr, pose, device):
    pred_xray = drr(None, None, None, pose=pose.to(device))

    #| label: bone_attenuation_multiplier=1.0
#| code-fold: true
#| eval: false
    plt.figure(constrained_layout=True)
    plt.subplot(121)
    plt.title("DRR: " + str(pose.get_rotation("euler_angles", "XYZ")))
    plt.imshow(pred_xray.squeeze().cpu().numpy(), cmap="gray")

    plt.show()

In [22]:
def main(
    run_id,
    id_number,
    device,
    height=256,
    volume_type="femur",
    restart=None,
    skip_training = False,
    skip_finetuning = False,
    model_name="resnet18",
    parameterization="se3_log_map",
    convention=None,
    lr=1e-3,
    batch_size=3,
    n_epochs=5,
    n_batches_per_epoch=10,
    pretrained = True
):
    id_number = int(id_number)

    # load patient CT volume
    specimen, isocenter_pose, transforms, drr = load(id_number, height, device, volume_type=volume_type)

    # setup the model
    model_params = {
        "model_name": model_name,
        "parameterization": parameterization,
        "convention": convention,
        "norm_layer": "groupnorm",
        "pretrained": pretrained
    }
    model = PoseRegressor(**model_params)
    # apply_xavier_weights(model)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    if restart is not None:
        print("Loading from checkpoint = " + str(restart))
        ckpt = torch.load(restart)
        model.load_state_dict(ckpt["model_state_dict"])
        model = model.to(device)
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    scheduler = WarmupCosineSchedule(
        optimizer,
        5 * n_batches_per_epoch,
        n_epochs * n_batches_per_epoch - 5 * n_batches_per_epoch,
    )
    # scheduler = None

    Path("checkpoints").mkdir(exist_ok=True)
    if not skip_training:
        synthetic_train(run_id=run_id,
                        id_number=id_number,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        drr=drr,
                        transforms=transforms,
                        isocenter_pose=isocenter_pose,
                        device=device,
                        batch_size=batch_size,
                        n_epochs=n_epochs,
                        n_batches_per_epoch=n_batches_per_epoch,
                        model_params=model_params,
                        volume_type=volume_type)

    if not skip_finetuning:
        finetune_train(
            run_id,
            id_number,
            model,
            optimizer,
            scheduler,
            drr,
            transforms,
            specimen,
            isocenter_pose,
            device,
            batch_size,
            n_epochs,
            n_batches_per_epoch,
            model_params,
            volume_type,
        )
        
        

    return model, specimen, drr, isocenter_pose



In [125]:

from random import randint, randrange
import torchvision

device = 'cuda'
# best_model_path = r"C:\Users\Public Admin\Desktop\Gitlab\kneedeeppose\notebooks\api\checkpoints\specimen_01_best.ckpt"

run_id = randint(100, 9999) 

lr=1e-3
vol_type = "femur"
n_batches_per_epoch = 10
specimen_id = 10
batch_size = 8
ckpt_train = r"checkpoints\\4405_specimen_10_femur_epoch500.ckpt"
epochs = 5
run = wandb.init(
    # Set the project where this run will be logged
    project="KneeDeepPose",
    # Track hyperparameters and run metadata
    config={
        "vol_type" : vol_type,
        "specimen" : specimen_id,
        "batch_size" : batch_size,
        "ckpt":ckpt_train,
        "n_batches_per_epoch":n_batches_per_epoch,
        "learning_rate": lr,
        "optimizer" : "Adam",
        "PoseRegressor":"default",
        "description": "segnet "
    },
    name = "segnet_" + str(vol_type)+"_" +str(specimen_id) + "_" + str(run_id)
)


main(run_id=run_id,
     id_number=specimen_id,
     device=device, height=256,  
     volume_type=vol_type, 
     restart=ckpt_train, 
     skip_training=True, 
     skip_finetuning=False,
     model_name="resnet18",
     parameterization="se3_log_map", 
     convention=None, 
     lr=lr, 
     batch_size=batch_size, 
     n_epochs=epochs,  
     n_batches_per_epoch=n_batches_per_epoch,
     pretrained=False)


torch.cuda.empty_cache()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

c:\Users\Public Admin\Desktop\Gitlab\kneedeeppose\data\sipla_bone_local_preprocessed.csv
Segnet_Input ['kneefit_femur_18846_syn.png', 'kneefit_femur_18847_syn.png', 'kneefit_femur_18848_syn.png', 'kneefit_femur_18849_syn.png', 'kneefit_femur_1884_syn.png', 'kneefit_femur_18850_syn.png', 'kneefit_femur_18851_syn.png', 'kneefit_femur_18852_syn.png', 'kneefit_femur_18853_syn.png', 'kneefit_femur_18854_syn.png', 'kneefit_femur_18855_syn.png', 'kneefit_femur_18856_syn.png', 'kneefit_femur_18857_syn.png', 'kneefit_femur_18858_syn.png', 'kneefit_femur_18859_syn.png', 'kneefit_femur_1885_syn.png', 'kneefit_femur_18860_syn.png', 'kneefit_femur_18861_syn.png', 'kneefit_femur_18862_syn.png', 'kneefit_femur_18863_syn.png', 'kneefit_femur_18864_syn.png', 'kneefit_femur_18865_syn.png', 'kneefit_femur_18866_syn.png', 'kneefit_femur_18867_syn.png', 'kneefit_femur_18868_syn.png', 'kneefit_femur_18869_syn.png', 'kneefit_femur_1886_syn.png', 'kneefit_femur_18870_syn.png', 'kneefit_femur_18871_syn.png', '

Using cache found in C:\Users\Public Admin/.cache\torch\hub\pytorch_vision_v0.10.0


model loaded
1173
number of xrays is {num_xrays} for patient {id_number}
segnet_ip kneefit_femur_18846_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femur_18846_syn.png
segnet_ip kneefit_femur_18847_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femur_18847_syn.png
segnet_ip kneefit_femur_18848_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femur_18848_syn.png
segnet_ip kneefit_femur_18849_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femur_18849_syn.png
segnet_ip kneefit_femur_1884_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femur_1884_syn.png
segnet_ip kneefit_femur_18850_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femur_18850_syn.png
segnet_ip kneefit_femur_18851_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femur_18851_syn.png
segnet_ip kneefit_femur_18852_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femur_18852_syn.png
segnet_ip kneefit_femur_18853_syn.png
segnet_path D:\kneefit_femur_synthetic\kneefit_femu

                                     

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 1, 1, 1, 1000, 1000]

Error in callback <function _draw_all_if_interactive at 0x0000020FAC302980> (for post_execute), with arguments args (),kwargs {}:


KeyboardInterrupt: 