In [1]:
import torch.nn as nn
from torch.utils.data import Dataset

# YOLO 3D: 3D ResNet backbone with a YOLO-style detection head

In [2]:
class BasicBlock3D(nn.Module):
    """Residual block made of two 3x3x3 convolutions, each followed by batch norm.

    The block computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``
    where ``shortcut`` is ``x`` itself, or ``downsample(x)`` when a
    ``downsample`` module is supplied.

    Parameters:
    - in_channels: Number of channels of the input tensor.
    - out_channels: Number of channels produced by both convolutions.
    - stride: Stride of the first convolution (the second always uses 1).
    - downsample: Optional module applied to the input before the addition;
      the caller is responsible for making its output shape match.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, bias=False)
        # Only the first convolution is allowed to change the resolution.
        self.conv1 = nn.Conv3d(in_channels, out_channels, stride=stride, **shared)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, stride=1, **shared)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Project the input only when a projection module was provided.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)

# Define the 3D ResNet model
class ResNet3D(nn.Module):
    """3D ResNet backbone followed by a 1x1x1 prediction head and ``Detect``.

    Parameters:
    - block: Residual block class; must expose an ``expansion`` attribute and
      accept ``(in_channels, out_channels, stride, downsample)``.
    - layers: Sequence with the number of blocks in each of the four stages.
    - num_classes: Number of object classes.
    - num_anchors: Number of anchors; the head emits
      ``num_anchors * (7 + num_classes)`` channels.

    ``forward`` returns whatever ``Detect`` returns: ``(detections, predictions)``.
    """

    def __init__(self, block, layers, num_classes=1, num_anchors=1):
        super().__init__()
        self.in_channels = 64

        # Stem: single-channel input, strided 7x7x7 convolution, then max-pooling.
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        # Four residual stages: (width, stride of the first block).
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for position, (width, stride) in enumerate(stage_plan):
            stage = self._make_layer(block, width, layers[position], stride=stride)
            setattr(self, f"layer{position + 1}", stage)

        # Prediction head. With kernel_size=1, padding=2 enlarges every spatial
        # side by 4 cells; the added border cells see only zero padding, so
        # they contain just the convolution bias.
        head_channels = num_anchors * (7 + num_classes)
        self.conv = nn.Conv3d(512 * block.expansion, head_channels, kernel_size=1, stride=1, padding=2)
        self.detect = Detect(num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.in_channels``."""
        expanded = out_channels * block.expansion

        # A projection shortcut is needed whenever the first block changes the
        # resolution or the channel count.
        downsample = None
        if stride != 1 or self.in_channels != expanded:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_channels, expanded, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(expanded),
            )

        stage = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = expanded
        for _ in range(1, blocks):
            stage.append(block(self.in_channels, out_channels))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the backbone and head, then decode with ``self.detect``."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        raw = self.conv(features)
        detections, predictions = self.detect(raw)
        return detections, predictions

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchio as tio
import torchvision.transforms as transforms
import torch.utils.data as data
# Define a 3D BasicBlock for ResNet


# Create a Random 3D Dataset

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

import torch
from torch.utils.data import Dataset
import numpy as np



# Example usage


# Iterate through the dataset
 #   print("Sample image shape (tensor):", img.shape)
 #   print("Sample target (x, y, z, w, h, d):", target)
#print("box1",np.shape(box1))
#print("boxes",np.shape(boxes))

class Detect(nn.Module):
    """Decode raw head outputs into 3D boxes and run per-class NMS.

    Each anchor carries ``7 + num_classes`` channels: six box parameters
    (x, y, z, w, h, d), one objectness logit, then the class logits.

    Axis convention used for decoding: for an input of shape
    ``(B, C, Z, Y, X)`` the x offset runs along the LAST spatial axis, y along
    the middle one and z along the FIRST one.
    NOTE(review): the target builders in this notebook index cells as
    ``[..., grid_x, grid_y, grid_z]`` (x first) -- confirm which convention is
    intended before comparing decoded boxes with targets.

    Parameters:
    - num_classes: Number of object classes.
    - conf_threshold: Minimum ``objectness * class score`` to keep a candidate.
    - nms_threshold: IoU threshold passed to ``custom_3d_nms``.
    """

    def __init__(self, num_classes, conf_threshold=0.5, nms_threshold=0.4):
        super(Detect, self).__init__()
        self.num_classes = num_classes
        self.conf_threshold = conf_threshold
        self.nms_threshold = nms_threshold

    @staticmethod
    def _cell_offsets(grid_shape, device):
        """Return broadcastable per-axis cell indices (x, y, z) for a (Z, Y, X) grid."""
        size_z, size_y, size_x = grid_shape
        offs_x = torch.arange(size_x, dtype=torch.float, device=device).view(1, 1, size_x)
        offs_y = torch.arange(size_y, dtype=torch.float, device=device).view(1, size_y, 1)
        # The z index must vary along the first spatial axis. The previous
        # implementation produced a copy of the x grid here.
        offs_z = torch.arange(size_z, dtype=torch.float, device=device).view(size_z, 1, 1)
        return offs_x, offs_y, offs_z

    def forward(self, predictions):
        """Decode ``predictions`` of shape (B, A * (7 + num_classes), Z, Y, X).

        Returns:
        - detections: one list per batch element; each entry is a dict with
          ``"boxes"`` (N, 6) as (x_min, y_min, z_min, x_max, y_max, z_max) in
          grid-normalised units, ``"scores"`` (N,) and ``"labels"`` (N,).
          Classes without any candidate above the threshold are omitted.
        - predictions: the raw input viewed as (B, A, 7 + num_classes, Z, Y, X).
        """
        batch_size = predictions.size(0)
        grid_shape = tuple(predictions.shape[2:])
        num_params_per_anchor = 7 + self.num_classes
        # Derive the anchor count from the channel dimension instead of
        # assuming a single anchor; view() still raises on a mismatch.
        num_anchors = predictions.size(1) // num_params_per_anchor

        predictions = predictions.view(batch_size, num_anchors, num_params_per_anchor, *grid_shape)

        bbox_pred = predictions[:, :, :6]
        obj_confidence = torch.sigmoid(predictions[:, :, 6])
        class_scores = torch.sigmoid(predictions[:, :, 7:])

        size_z, size_y, size_x = grid_shape
        offs_x, offs_y, offs_z = self._cell_offsets(grid_shape, predictions.device)

        # Every anchor uses the same fixed 12 x 12 x 12 prior (in grid cells).
        anchor_w = anchor_h = anchor_d = 12.0

        # Box centres: sigmoid offset inside the cell plus the cell index,
        # normalised by the grid size of the corresponding axis.
        centre_x = (torch.sigmoid(bbox_pred[:, :, 0]) + offs_x) / size_x
        centre_y = (torch.sigmoid(bbox_pred[:, :, 1]) + offs_y) / size_y
        centre_z = (torch.sigmoid(bbox_pred[:, :, 2]) + offs_z) / size_z
        # Box sizes: exponential scaling of the anchor prior.
        extent_x = torch.exp(bbox_pred[:, :, 3]) * anchor_w / size_x
        extent_y = torch.exp(bbox_pred[:, :, 4]) * anchor_h / size_y
        extent_z = torch.exp(bbox_pred[:, :, 5]) * anchor_d / size_z

        # Convert centre/size to corner coordinates, each of shape (B, A, Z, Y, X).
        corners = (
            centre_x - extent_x / 2,
            centre_y - extent_y / 2,
            centre_z - extent_z / 2,
            centre_x + extent_x / 2,
            centre_y + extent_y / 2,
            centre_z + extent_z / 2,
        )

        detections = []
        for batch_idx in range(batch_size):
            batch_detections = []
            for class_idx in range(self.num_classes):
                scores = obj_confidence[batch_idx] * class_scores[batch_idx, :, class_idx]
                mask = scores >= self.conf_threshold
                scores = scores[mask]
                if scores.size(0) == 0:
                    continue

                boxes = torch.stack([corner[batch_idx][mask] for corner in corners], dim=-1)

                keep = custom_3d_nms(boxes, scores, self.nms_threshold)

                boxes = boxes.view(-1, 6)[keep]
                scores = scores[keep]
                labels = torch.full((len(boxes),), class_idx, dtype=torch.long, device=predictions.device)

                batch_detections.append({"boxes": boxes, "scores": scores, "labels": labels})
            detections.append(batch_detections)

        return detections, predictions
def custom_3d_nms(boxes, scores, threshold):
    """Greedy non-maximum suppression for axis-aligned 3D boxes.

    Parameters:
    - boxes: Tensor of shape (N, 6) as (x_min, y_min, z_min, x_max, y_max, z_max).
    - scores: Tensor of shape (N,) with one confidence per box.
    - threshold: A remaining box is discarded when its IoU with the box just
      selected is greater than this value.

    Returns:
    - int64 tensor with the indices (into the ORIGINAL ``boxes``/``scores``) of
      the boxes that were kept, ordered by decreasing score.
    """
    if boxes.size(0) == 0:
        return torch.empty(0, dtype=torch.int64, device=boxes.device)

    # Visit candidates from most to least confident. ``order`` maps positions
    # in the sorted tensors back to indices of the caller's tensors.
    _, order = scores.sort(descending=True)
    boxes = boxes[order]

    keep = []
    while order.numel() > 0:
        keep.append(order[0].item())
        if order.numel() == 1:
            break

        # Always drop the selected box explicitly. Relying on its self-IoU
        # exceeding the threshold never terminates for zero-volume boxes
        # (self-IoU of 0) or for thresholds >= 1.
        iou = compute_3d_iou(boxes[0], boxes[1:])
        survivors = (iou <= threshold).reshape(-1)

        boxes = boxes[1:][survivors]
        order = order[1:][survivors]

    return torch.tensor(keep, dtype=torch.int64, device=boxes.device)


def compute_3d_iou(box1, boxes):
    """Intersection-over-union between reference box(es) and a set of 3D boxes.

    Parameters:
    - box1: Reference box(es) as (x_min, y_min, z_min, x_max, y_max, z_max).
      A 1-D tensor gains a leading axis; a 3-D tensor has its leading axis
      squeezed; a 2-D tensor of shape (M, 6) is used as is.
    - boxes: Tensor of shape (N, 6) in the same corner format.

    Returns:
    - Tensor of shape (M, N) with the pairwise IoU values. The union is clamped
      to at least 1e-6 to avoid division by zero.
    """
    rank = box1.dim()
    if rank == 3:
        box1 = box1.squeeze(0)
    elif rank == 1:
        box1 = box1.unsqueeze(0)

    def _volume(corner_boxes):
        # Product of the three side lengths, in x, y, z order.
        return (corner_boxes[:, 3] - corner_boxes[:, 0]) \
            * (corner_boxes[:, 4] - corner_boxes[:, 1]) \
            * (corner_boxes[:, 5] - corner_boxes[:, 2])

    # Accumulate the overlap one axis at a time: x, then y, then z.
    inter_volume = None
    for axis in range(3):
        lower = torch.max(box1[:, axis].unsqueeze(1), boxes[:, axis])
        upper = torch.min(box1[:, axis + 3].unsqueeze(1), boxes[:, axis + 3])
        side = torch.clamp(upper - lower, min=0)
        inter_volume = side if inter_volume is None else inter_volume * side

    union_volume = _volume(box1).unsqueeze(1) + _volume(boxes) - inter_volume

    return inter_volume / union_volume.clamp(min=1e-6)
# YOLO Loss Function
import torch
import torch.nn as nn


class YoloLoss(nn.Module):
    """Sum of box (MSE), objectness (BCE) and class (cross-entropy) losses.

    Both ``predictions`` and ``targets`` are laid out as
    ``(batch, anchors, 7 + num_classes, X, Y, Z)``: channels 0-5 hold the box,
    channel 6 the objectness and channels 7+ the per-class values.

    Parameters:
    - num_classes: Number of object classes.
    """

    def __init__(self, num_classes):
        super(YoloLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.num_classes = num_classes

    def forward(self, predictions, targets):
        """Return the scalar total loss.

        Parameters:
        - predictions: Raw (pre-activation) 6-D network output.
        - targets: Target tensor reshapeable to the same layout, with one-hot
          class channels.
        """
        batch_size = targets.size(0)
        num_anchors = predictions.size(1)
        grid_shape = tuple(predictions.shape[3:6])
        layout = (batch_size, num_anchors, 7 + self.num_classes) + grid_shape

        # reshape() also accepts non-contiguous inputs, unlike view().
        predictions = predictions.reshape(layout)
        # Match dtype so that e.g. float64 targets do not break the losses.
        targets = targets.reshape(layout).to(dtype=predictions.dtype)

        # Box regression on the raw box channels (x, y, z, w, h, d).
        box_loss = self.mse_loss(predictions[:, :, :6, ...], targets[:, :, :6, ...])

        # Objectness: BCELoss expects probabilities, hence the sigmoid.
        pred_obj = torch.sigmoid(predictions[:, :, 6, ...])
        obj_loss = self.bce_loss(pred_obj, targets[:, :, 6, ...])

        # Classification. The class axis (dim 2) must be moved last BEFORE
        # flattening to (cells, num_classes); flattening directly would mix
        # values from different grid cells into one "class vector" whenever
        # num_classes > 1.
        pred_cls = predictions[:, :, 7:, ...].permute(0, 1, 3, 4, 5, 2).reshape(-1, self.num_classes)
        true_cls = targets[:, :, 7:, ...].permute(0, 1, 3, 4, 5, 2).reshape(-1, self.num_classes)
        # NOTE(review): cells without an object have an all-zero one-hot, so
        # argmax labels them as class 0 -- consider masking by objectness.
        class_loss = self.ce_loss(pred_cls, true_cls.argmax(dim=-1))

        total_loss = box_loss + obj_loss + class_loss
        return total_loss

# Example usage




In [6]:
class Yolo3DCustomDataset(Dataset):
    """Synthetic dataset: random cubic volumes with one random box per anchor.

    Every call to ``__getitem__`` draws fresh random data from NumPy's global
    random state; the index is ignored.

    Parameters:
    - num_samples: Value reported by ``len()``.
    - num_classes: Number of object classes.
    - img_size: Side length of the cubic volume.
    - grid_size: Number of grid cells along each axis.
    - num_anchors: Number of anchors (one object is generated for each).
    """

    def __init__(self, num_samples, num_classes = 1, img_size=146, grid_size=7, num_anchors=1):
        self.num_samples, self.num_classes = num_samples, num_classes
        self.grid_size, self.num_anchors = grid_size, num_anchors
        self.img_size = img_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(img, targets)``.

        ``img`` is float32 in [0, 1] with shape (1, S, S, S); ``targets`` has
        shape (num_anchors, 7 + num_classes, G, G, G) and is indexed as
        ``[anchor, channel, grid_x, grid_y, grid_z]``.
        """
        side, cells = self.img_size, self.grid_size

        # Random uint8 volume, channel moved first and scaled to [0, 1].
        voxels = np.random.randint(0, 256, (side, side, side, 1), dtype=np.uint8)
        img = torch.tensor(voxels, dtype=torch.float32).permute(3, 0, 1, 2) / 255.0

        targets = torch.zeros((self.num_anchors, 7 + self.num_classes, cells, cells, cells))
        cell_extent = side / cells

        for anchor in range(self.num_anchors):
            # Draw order matters for reproducibility: centre, size, then class.
            centre = [np.random.uniform(0, side) for _ in range(3)]
            extent = [np.random.uniform(10, 50) for _ in range(3)]
            class_label = np.random.randint(0, self.num_classes)

            class_one_hot = np.zeros(self.num_classes)
            class_one_hot[class_label] = 1

            # Cell that contains the centre, and the centre's offset inside it.
            cell_x, cell_y, cell_z = [int(value // cell_extent) for value in centre]
            offsets = [(value % cell_extent) / cell_extent for value in centre]
            # Box sizes are expressed as a fraction of the whole volume.
            sizes = [value / side for value in extent]

            targets[anchor, :6, cell_x, cell_y, cell_z] = torch.tensor(offsets + sizes)
            targets[anchor, 6, cell_x, cell_y, cell_z] = 1  # an object is always present
            targets[anchor, 7:, cell_x, cell_y, cell_z] = torch.tensor(class_one_hot)

        return img, targets
# Smoke test: train the detector for a few epochs on the synthetic dataset.
num_classes = 1
# grid_size=9 matches the 9 x 9 x 9 grid reported in the targets printed below
# for 146^3 inputs.
dataset = Yolo3DCustomDataset(num_samples=4, num_classes=num_classes, img_size=146, grid_size=9, num_anchors=1)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

# Inspect one batch to confirm the tensor layouts.
for imgs, targets in dataloader:
    print("Images:", imgs.shape)       # Should be [batch_size, 1, img_size, img_size, img_size]
    print("Targets:", targets.shape)   # Should be [batch_size, num_anchors, 7 + num_classes, grid_size, grid_size, grid_size]
    break

model = ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes=num_classes)
# Instantiate the loss function
criterion = YoloLoss(num_classes=num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training Loop
# Single source of truth for the epoch count, so the progress messages always
# report the real total (they previously printed a hard-coded "/5").
num_epochs = 2
for epoch in range(num_epochs):
    model.train()
    total_batches = len(dataloader)
    for batch_idx, (imgs, targets) in enumerate(dataloader):
        print("IMAGE SHAPE",imgs.shape)
        print("Image tensor dtype:", imgs.dtype)
        print("TARGET",targets.shape)
        optimizer.zero_grad()
        # The model returns (detections, raw predictions); the loss uses the raw output.
        _, predictions = model(imgs)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        # Print progress within epoch
        if batch_idx % 10 == 0:  # Adjust frequency of printing as needed
            print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{total_batches}], Loss: {loss.item():.4f}")

    # Loss of the last batch of the epoch (not an epoch average).
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

Images: torch.Size([2, 1, 146, 146, 146])
Targets: torch.Size([2, 1, 8, 9, 9, 9])
IMAGE SHAPE torch.Size([2, 1, 146, 146, 146])
Image tensor dtype: torch.float32
TARGET torch.Size([2, 1, 8, 9, 9, 9])
Epoch [1/5], Batch [0/2], Loss: 0.7369
IMAGE SHAPE torch.Size([2, 1, 146, 146, 146])
Image tensor dtype: torch.float32
TARGET torch.Size([2, 1, 8, 9, 9, 9])
Epoch [1/5], Loss: 3.5355
IMAGE SHAPE torch.Size([2, 1, 146, 146, 146])
Image tensor dtype: torch.float32
TARGET torch.Size([2, 1, 8, 9, 9, 9])
Epoch [2/5], Batch [0/2], Loss: 1.3154
IMAGE SHAPE torch.Size([2, 1, 146, 146, 146])
Image tensor dtype: torch.float32
TARGET torch.Size([2, 1, 8, 9, 9, 9])
Epoch [2/5], Loss: 0.8432
IMAGE SHAPE torch.Size([2, 1, 146, 146, 146])
Image tensor dtype: torch.float32
TARGET torch.Size([2, 1, 8, 9, 9, 9])
Epoch [3/5], Batch [0/2], Loss: 0.6268
IMAGE SHAPE torch.Size([2, 1, 146, 146, 146])
Image tensor dtype: torch.float32
TARGET torch.Size([2, 1, 8, 9, 9, 9])
Epoch [3/5], Loss: 0.6007


In [7]:
def _flatten_bboxes(node):
    """Collect every 6-value box found at any depth of a nested list/tuple."""
    if not isinstance(node, (list, tuple)):
        return []
    # A leaf box is a sequence of six non-sequence values.
    if len(node) == 6 and not any(isinstance(value, (list, tuple)) for value in node):
        return [list(node)]
    found = []
    for child in node:
        found.extend(_flatten_bboxes(child))
    return found


def create_yolo_target_tensor_3d(bboxes_dict, img_width, img_height, img_depth, grid_size, num_classes):
    """
    Transforms 3D bounding boxes and class labels into a YOLO target tensor for non-cubic volumes.

    Every box is written into the grid cell that contains its centre. A single
    anchor is used, so a later box falling into the same cell overwrites an
    earlier one.

    Parameters:
    - bboxes_dict: Dictionary mapping an image ID of the form '#<number>' to its
      bounding boxes. Each box is [x_center, y_center, z_center, width, height, depth]
      in voxel units; boxes may be nested in lists at any depth.
    - img_width: Width of the input image (extent along x).
    - img_height: Height of the input image (extent along y).
    - img_depth: Depth of the input image (extent along z).
    - grid_size: Number of grid cells along each dimension (e.g., 13 or 19).
    - num_classes: Total number of classes. The class label is the number in the
      image ID modulo num_classes.

    Returns:
    - target_tensor: YOLO target tensor of shape
      (num_images, 1, 7 + num_classes, grid_size, grid_size, grid_size), indexed as
      [image, anchor, channel, grid_x, grid_y, grid_z]. Channels 0-2 hold the centre
      offset inside the cell (in cell units), channels 3-5 the box size as a
      fraction of the image size, channel 6 the objectness and channels 7+ the
      one-hot class.
    """
    num_images = len(bboxes_dict)
    num_anchors = 1  # Single anchor for simplicity
    anchor_idx = 0

    # Initialize the target tensor
    target_tensor = torch.zeros((num_images, num_anchors, 7 + num_classes, grid_size, grid_size, grid_size))

    # Compute the cell sizes
    cell_width = img_width / grid_size
    cell_height = img_height / grid_size
    cell_depth = img_depth / grid_size
    last_cell = grid_size - 1

    for img_idx, (img_id, bboxes) in enumerate(bboxes_dict.items()):
        aggregated_bboxes = _flatten_bboxes(bboxes)
        if not aggregated_bboxes:
            continue

        # Extract the class from the image ID, assuming the format '#class_label'
        class_label = int(img_id.split('#')[1]) % num_classes

        # Report the size of the first bounding box as the (single) anchor
        anchors = [aggregated_bboxes[0][3:6]]
        print(f"Image {img_id} Anchor: {anchors}")  # Print the single anchor

        for bbox in aggregated_bboxes:
            x_center, y_center, z_center, width, height, depth = bbox

            # Grid cell containing the box centre, clamped so that a centre lying
            # exactly on (or beyond) the image border still maps to a valid cell.
            # The previous implementation only wrote targets for cell (9, 9, 9),
            # which does not exist for grid_size=9, so nothing was ever written.
            grid_x = min(max(int(x_center // cell_width), 0), last_cell)
            grid_y = min(max(int(y_center // cell_height), 0), last_cell)
            grid_z = min(max(int(z_center // cell_depth), 0), last_cell)

            # Relative coordinates within the grid cell and sizes relative to the image
            box_values = [
                (x_center % cell_width) / cell_width,
                (y_center % cell_height) / cell_height,
                (z_center % cell_depth) / cell_depth,
                width / img_width,
                height / img_height,
                depth / img_depth,
            ]

            target_tensor[img_idx, anchor_idx, :6, grid_x, grid_y, grid_z] = torch.tensor(
                box_values, dtype=target_tensor.dtype
            )
            target_tensor[img_idx, anchor_idx, 6, grid_x, grid_y, grid_z] = 1  # Object confidence score

            one_hot_class = torch.zeros(num_classes)
            one_hot_class[class_label] = 1
            target_tensor[img_idx, anchor_idx, 7:, grid_x, grid_y, grid_z] = one_hot_class

    return target_tensor
# Build the training targets and load the annotated NIfTI volumes.
# NOTE(review): `converted_bboxes_with_comments` is not defined in any cell
# above this point; it has to exist in the kernel already (hidden notebook
# state) -- confirm where it is created.

# Extent of the original volumes along x, y and z, in voxels.
img_width = 204
img_height = 146
img_depth = 156
# NOTE(review): `tf` is not referenced anywhere in the visible code.
import tensorflow as tf
grid_size = 9
num_classes = 1
y_train = create_yolo_target_tensor_3d(converted_bboxes_with_comments, img_width, img_height, img_depth, grid_size, num_classes)

import os
import nibabel as nib
import numpy as np


#################

# Define the folder containing the NIfTI images
nifti_directory = '/Users/lucabernecker/Desktop/N128_local/aneu_det'

# Derive the file IDs from the annotation keys (leading/trailing '#' removed),
# so the images are loaded in the same order as the targets in y_train.
file_ids = [key.strip("#") for key in converted_bboxes_with_comments.keys()]

# Initialize a list to store the loaded nifti images as numpy arrays
nifti_images_list = []

# Load the nifti images in the specified order
for file_id in file_ids:
    file_name = f"cube_{file_id}.nii.gz"
    file_path = os.path.join(nifti_directory, file_name)
    
    # Load the nifti image
    nifti_image = nib.load(file_path)
    
    # Convert the nifti image to a numpy array and append to the list
    nifti_images_list.append(nifti_image.get_fdata())

# Stack all numpy arrays into a single numpy array
# (np.stack requires every volume to have the same shape).
x_train = np.stack(nifti_images_list, axis=0)

###############

# Convert the stacked array to a tensor and insert a channel axis at dim 1.
# The dtype of the array is kept here; the training loop casts each batch
# with .float().
print("X_train initial",np.shape(x_train))
x_train = torch.tensor(x_train)
x_train = torch.unsqueeze(x_train, 1)
print("X_Train",np.shape(x_train))
num_classes = 1
# Fresh, untrained model for the real-data run.
model = ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes = num_classes)
#dataset = Yolo3DCustomDataset(num_samples=4, num_classes=num_classes,grid_size =8)
#dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
class CustomDataset(Dataset):
    def __init__(self, x_train, y_train, new_shape=(146, 146, 146), targets_normalized=True):
        """
        Custom dataset for YOLO-style object detection in 3D volumes.

        Parameters:
        - x_train (Tensor): Input images tensor of shape (num_images, channels, depth, height, width).
        - y_train (Tensor): YOLO target tensor of shape (num_images, num_anchors, 7 + num_classes, grid_size, grid_size, grid_size).
        - new_shape (tuple): Desired shape for resizing images (depth, height, width).
        - targets_normalized (bool): True (default) when the box channels are
          already relative (centre offset inside the grid cell, size as a
          fraction of the image), as produced by create_yolo_target_tensor_3d.
          Such values do not change when the image is resized, so the targets
          are returned untouched. Set to False only for targets whose box
          channels are in absolute voxel units.
        """
        self.x_train = x_train
        self.y_train = y_train
        self.new_shape = new_shape
        self.targets_normalized = targets_normalized

        assert len(self.x_train) == len(self.y_train), \
            "Number of images in x_train must match number of target tensors in y_train."

    def __len__(self):
        return len(self.x_train)

    def __getitem__(self, idx):
        """Return the resized image and its (possibly rescaled) target tensor."""
        x = self.x_train[idx]
        y = self.y_train[idx]

        # Reshape the image
        x_resized = self.reshape_image(x, self.new_shape)

        # Rescale the bounding boxes in the targets; x.shape[1:] drops the channel axis.
        y_rescaled = self.rescale_bounding_boxes(y, x.shape[1:], self.new_shape)

        return x_resized, y_rescaled

    def reshape_image(self, image, new_shape):
        """
        Reshape image tensor to new dimensions using trilinear interpolation.

        Parameters:
        - image (Tensor): Input image tensor of shape (channels, depth, height, width).
        - new_shape (tuple): Desired shape for resizing (depth, height, width).

        Returns:
        - resized_image (Tensor): Resized image tensor of shape (channels, new_depth, new_height, new_width).
        """
        # interpolate() expects a batch axis, which is added and removed again here.
        resized_image = torch.nn.functional.interpolate(image.unsqueeze(0), size=new_shape, mode='trilinear', align_corners=False).squeeze(0)
        return resized_image

    def rescale_bounding_boxes(self, targets, original_shape, new_shape):
        """
        Rescale bounding boxes in the targets according to the new image dimensions.

        Parameters:
        - targets (Tensor): YOLO target tensor of shape (num_anchors, 7 + num_classes, grid_size, grid_size, grid_size).
        - original_shape (tuple): Original shape of the image (depth, height, width).
        - new_shape (tuple): New shape of the image (depth, height, width).

        Returns:
        - rescaled_targets (Tensor): A copy of the targets. With
          targets_normalized=True it is identical to the input; otherwise box
          channels 0-5 (x, y, z, w, h, d) are multiplied by the per-axis scale
          factors. Objectness and class channels are never modified.
        """
        targets_rescaled = targets.clone()
        if self.targets_normalized:
            # Relative coordinates are invariant under resizing of the image.
            return targets_rescaled

        # Calculate scaling factors for depth, height, and width
        scale_z = new_shape[0] / original_shape[0]
        scale_y = new_shape[1] / original_shape[1]
        scale_x = new_shape[2] / original_shape[2]

        # Scale whole CHANNELS (dim 1 of the target tensor). The previous loop
        # indexed grid slices of every channel instead, which also corrupted
        # the objectness and class channels.
        targets_rescaled[:, 0] *= scale_x
        targets_rescaled[:, 1] *= scale_y
        targets_rescaled[:, 2] *= scale_z
        targets_rescaled[:, 3] *= scale_x
        targets_rescaled[:, 4] *= scale_y
        targets_rescaled[:, 5] *= scale_z

        return targets_rescaled
    
# Train on the annotated NIfTI volumes.
dataset = CustomDataset(x_train,y_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

# Inspect one batch to confirm the tensor layouts.
for imgs, targets in dataloader:
    print("Images:", imgs.shape)       # Should be [batch_size, 1, img_size, img_size, img_size]
    print("Targets:", targets.shape)   # Should be [batch_size, num_anchors, 7 + num_classes, grid_size, grid_size, grid_size]
    break

# Instantiate the loss function
criterion = YoloLoss(num_classes=num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.005)

# Training Loop
# Single source of truth for the epoch count; the per-batch message previously
# reported a hard-coded "/5" while the loop ran for 500 epochs.
num_epochs = 500
for epoch in range(num_epochs):
    model.train()
    total_batches = len(dataloader)
    for batch_idx, (imgs, targets) in enumerate(dataloader):
        # Cast the batch to float32 before feeding it to the model.
        imgs = imgs.float()
        optimizer.zero_grad()
        # The model returns (detections, raw predictions); the loss uses the raw output.
        _, predictions = model(imgs)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        # Print progress within epoch
        if batch_idx % 10 == 0:  # Adjust frequency of printing as needed
            print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{total_batches}], Loss: {loss.item():.4f}")

    # Loss of the last batch of the epoch (not an epoch average).
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

KeyboardInterrupt: 

In [None]:
def evaluate_model(model, image, device='cpu'):
    """Run ``model`` on ``image`` in evaluation mode without tracking gradients.

    Parameters:
    - model: Module whose forward pass returns a pair; only the second element
      is returned.
    - image: Input tensor; it is moved to ``device`` before the forward pass.
    - device: Device the input is moved to. The model itself is NOT moved, so
      the caller must make sure it already lives on the same device.

    Returns:
    - The second element of the model output.

    The model is left in evaluation mode.
    """
    model.eval()
    inputs = image.to(device)
    with torch.no_grad():
        _detections, raw_output = model(inputs)
    return raw_output

# Run the trained model on a single volume from the healthy set.
file_name = '/Users/lucabernecker/Desktop/N128_local/healthy_train/cube_100100.nii.gz'
# IDs of the annotated training volumes (not used by the inference below).
file_ids = [key.strip("#") for key in converted_bboxes_with_comments.keys()]
# Load the volume as a numpy array and convert it to a tensor
nifti_image = nib.load(file_name).get_fdata()
image_tensor = torch.tensor(nifti_image)

# Ensure the tensor is of type float32
image_tensor = image_tensor.float()
# Add the channel axis
image_tensor = image_tensor.unsqueeze(0)

print(image_tensor.shape,"image tensor")
# interpolate() needs a batch axis; it is added for the call and removed again.
resized_image = torch.nn.functional.interpolate(image_tensor.unsqueeze(0), size=(146,146,146), mode='trilinear', align_corners=False).squeeze(0)
# Add the batch axis expected by the model
resized_image = resized_image.unsqueeze(0)
print(resized_image.shape,"resized")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Pass the device explicitly: evaluate_model defaults to 'cpu', which would
# leave the input on the CPU while the model sits on the GPU when CUDA is
# available.
prediction = evaluate_model(model, resized_image, device)
print("Prediction on new image:", prediction)

torch.Size([1, 204, 146, 156]) image tensor
torch.Size([1, 1, 146, 146, 146]) resized


In [2]:
def convert_bboxes(bboxes):
    """Convert min/max boxes to centre/size form.

    Args:
        bboxes: Either one box ``[[x_min, x_max], [y_min, y_max], [z_min, z_max]]``
            or a list of such boxes. Lists and tuples are both accepted.

    Returns:
        ``[x_center, y_center, z_center, width, height, depth]`` for a single
        box, a list of those for a list of boxes, and ``[]`` for empty input.
        The min/max order is not validated: a pair given as (max, min) yields
        a negative size.
    """
    # Nothing to convert; also avoids indexing into an empty sequence below.
    if len(bboxes) == 0:
        return []

    # A collection of boxes is recognised by its first box starting with a
    # (min, max) pair rather than a number.
    if isinstance(bboxes[0][0], (list, tuple)):
        return [convert_bboxes(bbox) for bbox in bboxes]

    centers = []
    sizes = []
    for low, high in (bboxes[0], bboxes[1], bboxes[2]):
        centers.append((low + high) / 2)
        sizes.append(high - low)
    return centers + sizes
# Example usage
# Manually entered bounding boxes, keyed by a '#'-prefixed identifier (the
# inference cell strips the '#' to obtain file ids).
# Each value is a list of boxes; each box is
# [[x_min, x_max], [y_min, y_max], [z_min, z_max]], the layout convert_bboxes reads.
# Units are presumably voxel indices of the source volume (TODO confirm).
# Some entries are commented out, mostly those listing more than one box.
bounding_boxes_with_comments = {
    "#100140": [[[97, 106], [77, 93], [89, 105]]],
    "#100151": [[[105, 135], [83, 95], [78, 97]]],
    "#100152": [[[101, 112], [73, 86], [120, 134]]],
    "#100171": [[[51, 77], [65, 78], [98, 116]]],
    "#100205": [[[127, 148], [58, 82], [92, 121]]],
    "#100218": [[[109, 133], [56, 68], [105, 119]]],
    "#100231": [[[64, 82], [90, 103], [76, 90]]],
    "#100234": [[[71, 94], [80, 92], [81, 92]]],
    "#100242": [[[40, 61], [77, 98], [103, 125]]],
    "#100277": [[[113, 132], [77, 97], [123, 138]]],
    "#100305": [[[72, 90], [79, 93], [80, 96]]],
    "#100316": [[[176, 189], [76, 89], [105, 117]]],
    "#100324": [[[72, 91], [82, 104], [71, 99]]],
    "#100326": [[[94, 105], [76, 95], [101, 118]]],
    "#100394": [[[162, 173], [87, 99], [97, 109]]],
   # "#100248": [[[115, 129], [76, 94], [110, 124]]],
    "#100428": [[[115,126],[78,88],[110,123]]],
    "#100455": [[[46, 59], [75, 89], [98, 117]]],
    "#100460": [[[178, 194], [89, 101], [103, 119]]],
    "#100517": [[[142, 174], [77, 95], [92, 112]]],
    "#100536": [[[28, 37], [67, 94], [88, 103]]],
    "#100556": [[[159, 173], [83, 92], [94, 112]]],
    "#100601": [[[136, 158], [58, 71], [136, 149]]],
    "#100643": [[[9, 25], [72, 86], [95, 109]]],
    "#100649": [[[91, 104], [77, 91], [94, 111]]],
    "#100660": [[[67, 83], [77, 92], [81, 94]]],
    "#100673": [[[121, 148], [67, 86], [88, 118]]],
    "#100680": [[[68, 92], [70, 90], [78, 102]]],
    "#100700": [[[85, 112], [28, 56], [98, 123]]],
    "#100723": [[[70, 97], [81, 103], [66, 99]]],
    "#100739": [[[10, 24], [82, 97], [101, 115]]],
    "#100818": [[[144, 167], [73, 97], [106, 123]]],
    "#100873": [[[104, 117], [75, 95], [87, 103]]],
    "#100880": [[[112, 131], [58, 74], [94, 106]]],
    "#100915": [[[105, 119], [80, 96], [70, 86]]],
    "#100988": [[[30, 49], [87, 104], [84, 101]]],
    "#101026": [[[47, 60], [63, 81], [103, 123]]],
    "#101043": [[[157, 164], [81, 91], [104, 122]]],
    "#101054": [[[23, 34], [105, 117], [91, 104]]],
    "#101085": [[[65, 84], [88, 105], [85, 103]]],
    "#101118": [[[109, 131], [72, 84], [85, 102]]],
    "#101122": [[[122, 134], [59, 74], [98, 111]]],
    "#101140": [[[119, 131], [63, 77], [91, 104]]],
    "#101148": [[[102, 111], [48, 63], [95, 112]]],
    "#101157": [[[37, 56], [82, 91], [82, 96]]],
    "#101158": [[[161, 178], [72, 84], [105, 119]]],
    "#101194": [[[174, 194], [72, 85], [100, 116]]],
    "#101220": [[[124, 136], [61, 75], [99, 114]]],
    "#101234": [[[108, 133], [76, 96], [84, 98]]],
    "#101263": [[[113, 130], [80, 94], [80, 112]]],
    "#101272": [[[12, 27], [78, 86], [95, 105]]],
 #   "#101288": [[[35, 48], [86, 103], [90, 104]],[[142,154], [66,78], [100,113]]],
    "#101298": [[[149, 161], [89, 103], [109, 123]]],
    "#101302": [[[36, 47], [64, 79], [96, 113]]],
    "#101309": [[[91, 105], [72, 84], [107, 117]]],
    "#101331": [[[93, 102], [76, 89], [99, 114]]],
    "#101365": [[[75, 92], [74, 90], [114, 130]]],
    "#101393": [[[176, 189], [78, 104], [96, 116]]],
 #   "#101426": [[[12, 23], [65, 90], [88, 98]],[[106,115],[45,57],[104,115]]],  
    "#101493": [[[91, 115], [31, 51], [100, 120]]],
    "#101495": [[[146, 156], [65, 84], [109, 122]]],
    "#101508": [[[34, 47], [82, 101], [98, 111]]],
    "#101516": [[[7, 17], [91, 109], [124, 136]]],
    "#101562": [[[15, 32], [80, 97], [102, 116]]],
    "#101601": [[[164, 178], [84, 99], [92, 109]]],
    "#101675": [[[71, 90], [77, 91], [78, 95]]],
    "#101735": [[[33, 51], [68, 80],[116,126]]],
    "#101782": [[[115, 132], [76, 91], [84, 96]]],
    "#101842": [[[32, 52], [65, 88], [83, 103]]],
    "#101885": [[[94, 107], [73, 97], [107, 130]]],
    "#101892": [[[137, 153], [69, 81], [106, 119]]],
    "#101990": [[[101, 136], [75, 90], [72, 91]]],
    "#101991": [[[98, 114], [83, 98], [105, 121]]],
    "#101995": [[[114, 128], [64, 80], [95, 106]]],
    "#101998": [[[107, 133], [75, 92], [95, 111]]],
    "#102024": [[[66, 85], [58, 71], [93, 114]]],
    "#102084": [[[38, 58], [76, 101], [92, 104]]],
    "#102141": [[[96, 108], [79, 94], [110, 123]]],
    "#102158": [[[149, 166], [79, 95], [99, 117]]],
    "#102200": [[[101, 131], [82, 102], [72, 105]]],
    "#102201": [[[106, 129], [77, 93], [90, 102]]],
    "#102202": [[[95, 110], [71, 89], [103, 113]]],
    "#102228": [[[149, 163], [82, 91], [99, 112]]],
    "#102283": [[[169, 182], [76, 94], [90, 106]]],
    "#102333": [[[104, 124], [82, 102], [90, 110]]],
    "#102348": [[[71, 85], [74, 93], [90, 103]]],
    "#102391": [[[118, 139], [66, 83], [100, 116]]],
    "#102418": [[[117, 136], [66, 91], [86, 103]]],
    "#102427": [[[68, 89], [87, 100], [81, 98]]],
    "#102442": [[[67, 83], [67, 80], [92, 108]]],
    "#102455": [[[144, 160], [71, 83], [84, 101]]],
    "#102496": [[[160, 181], [74, 90], [100, 116]]],
    "#102564": [[[116, 139], [70, 91], [106, 121]]],
#    "#102590": [[[150, 164], [81, 90], [93, 109]],[[114,133],[78,92],[79,88]]],  # and [[118, 136], [82, 93], [77, 91]]
#    "#102595": [[[112, 127], [70, 91], [83, 93]],[[113, 131], [75, 86], [92, 105]]],  # and [[113, 131], [75, 86], [92, 105]]
    "#102601": [[[88, 104], [81, 96], [93, 104]]],
    "#102609": [[[20, 42], [80, 98], [94, 115]]],
    "#102638": [[[40, 60], [70, 86], [112, 130]]],
    "#102693": [[[209, 125], [82, 98], [75, 92]]],  # NOTE(review): x_min (209) > x_max (125) gives a negative width; possibly a typo - confirm against the annotation
    "#102720": [[[22, 36], [86, 100], [92, 103]]],
    "#102761": [[[68, 85], [80, 94], [80, 98]]],
    "#102764": [[[72, 91], [72, 92], [86, 97]]],
    "#102797": [[[73, 86], [86, 102], [84, 98]]],
    "#102835": [[[157, 178], [90, 104], [88, 101]]],
    "#102840": [[[111, 129], [82, 94], [78, 94]]],
  #  "#102864": [[[112, 132], [79, 100], [79, 93]],[[68, 88], [83, 97], [82, 95]],[[184, 201], [32, 48], [131, 143]]],  # and [[68, 88], [83, 97], [82, 95]], [[184, 201], [32, 48], [131, 143]]
    "#102886": [[[41, 49], [71, 94], [97, 107]]],
    "#102935": [[[143, 161], [71, 91], [104, 120]]],
    "#102959": [[[66, 85], [78, 93], [90, 99]]],
    "#102974": [[[110, 127], [76, 95], [71, 90]]]
}


# Centre/size version of every annotation, stored under the same key as the
# min/max original.
converted_bboxes_with_comments = {
    key: convert_bboxes(bboxes)
    for key, bboxes in bounding_boxes_with_comments.items()
}

print(converted_bboxes_with_comments)





{'#100140': [[101.5, 85.0, 97.0, 9, 16, 16]], '#100151': [[120.0, 89.0, 87.5, 30, 12, 19]], '#100152': [[106.5, 79.5, 127.0, 11, 13, 14]], '#100171': [[64.0, 71.5, 107.0, 26, 13, 18]], '#100205': [[137.5, 70.0, 106.5, 21, 24, 29]], '#100218': [[121.0, 62.0, 112.0, 24, 12, 14]], '#100231': [[73.0, 96.5, 83.0, 18, 13, 14]], '#100234': [[82.5, 86.0, 86.5, 23, 12, 11]], '#100242': [[50.5, 87.5, 114.0, 21, 21, 22]], '#100277': [[122.5, 87.0, 130.5, 19, 20, 15]], '#100305': [[81.0, 86.0, 88.0, 18, 14, 16]], '#100316': [[182.5, 82.5, 111.0, 13, 13, 12]], '#100324': [[81.5, 93.0, 85.0, 19, 22, 28]], '#100326': [[99.5, 85.5, 109.5, 11, 19, 17]], '#100394': [[167.5, 93.0, 103.0, 11, 12, 12]], '#100428': [[120.5, 83.0, 116.5, 11, 10, 13]], '#100455': [[52.5, 82.0, 107.5, 13, 14, 19]], '#100460': [[186.0, 95.0, 111.0, 16, 12, 16]], '#100517': [[158.0, 86.0, 102.0, 32, 18, 20]], '#100536': [[32.5, 80.5, 95.5, 9, 27, 15]], '#100556': [[166.0, 87.5, 103.0, 14, 9, 18]], '#100601': [[147.0, 64.5, 142.5

# Anchor-Free Object Detection

In [5]:
import torch
import torch.nn as nn

class Detect(nn.Module):
    """Anchor-free 3D detection head that decodes a dense prediction volume.

    The input has ``6 + 1 + num_classes`` channels per grid cell: six box
    values (x, y, z, w, h, d), one objectness logit and one logit per class.
    Tensor dimensions 2, 3 and 4 are treated as the z, y and x grid axes.

    Args:
        num_classes: Number of class channels.
        conf_threshold: Minimum ``objectness * class score`` for a cell to
            produce a candidate box.
        nms_threshold: IoU above which a lower-scoring box is suppressed.
    """

    def __init__(self, num_classes, conf_threshold=0.5, nms_threshold=0.4):
        super(Detect, self).__init__()
        self.num_classes = num_classes
        self.conf_threshold = conf_threshold
        self.nms_threshold = nms_threshold

    def forward(self, predictions):
        """Decode boxes, filter them by score and apply per-class NMS.

        Args:
            predictions: Tensor of shape
                ``(batch, 6 + 1 + num_classes, size_z, size_y, size_x)``.

        Returns:
            ``(detections, predictions)``. ``detections`` holds one list per
            batch item with a dict per class that has surviving boxes:
            ``"boxes"`` as ``(n, 6)`` rows of
            ``(x_min, y_min, z_min, x_max, y_max, z_max)`` in grid-normalised
            coordinates, ``"scores"`` and ``"labels"``. ``predictions`` is the
            input viewed with the shape above.
        """
        batch_size = predictions.size(0)
        size_z, size_y, size_x = predictions.shape[2:5]

        # Channel layout for the anchor-free head; the view doubles as a
        # check that the channel count matches num_classes.
        num_params = 6 + 1 + self.num_classes
        predictions = predictions.view(batch_size, num_params, size_z, size_y, size_x)

        bbox_pred = predictions[:, :6]
        obj_confidence = torch.sigmoid(predictions[:, 6])
        class_scores = torch.sigmoid(predictions[:, 7:])

        # Cell index along each axis, shaped to broadcast over (z, y, x).
        # Every axis gets its OWN index; previously the z offsets were a copy
        # of the x offsets, which shifted every z centre by the x cell index.
        device = predictions.device
        dtype = predictions.dtype
        grid_x = torch.arange(size_x, dtype=dtype, device=device).view(1, 1, size_x)
        grid_y = torch.arange(size_y, dtype=dtype, device=device).view(1, size_y, 1)
        grid_z = torch.arange(size_z, dtype=dtype, device=device).view(size_z, 1, 1)

        # Centres: sigmoid offset inside the cell plus the cell index,
        # normalised by the grid extent of that axis.
        center_x = (torch.sigmoid(bbox_pred[:, 0]) + grid_x) / size_x
        center_y = (torch.sigmoid(bbox_pred[:, 1]) + grid_y) / size_y
        center_z = (torch.sigmoid(bbox_pred[:, 2]) + grid_z) / size_z
        # Sizes: exponential keeps them positive; same normalisation.
        width = torch.exp(bbox_pred[:, 3]) / size_x
        height = torch.exp(bbox_pred[:, 4]) / size_y
        depth = torch.exp(bbox_pred[:, 5]) / size_z

        x_min = center_x - width / 2
        y_min = center_y - height / 2
        z_min = center_z - depth / 2
        x_max = center_x + width / 2
        y_max = center_y + height / 2
        z_max = center_z + depth / 2

        detections = []
        for batch_idx in range(batch_size):
            batch_detections = []
            for class_idx in range(self.num_classes):
                scores = obj_confidence[batch_idx] * class_scores[batch_idx, class_idx]
                mask = scores >= self.conf_threshold
                scores = scores[mask]
                if scores.size(0) == 0:
                    continue

                boxes = torch.stack(
                    (
                        x_min[batch_idx][mask],
                        y_min[batch_idx][mask],
                        z_min[batch_idx][mask],
                        x_max[batch_idx][mask],
                        y_max[batch_idx][mask],
                        z_max[batch_idx][mask],
                    ),
                    dim=-1,
                )

                keep = custom_3d_nms(boxes, scores, self.nms_threshold)
                boxes = boxes[keep]
                scores = scores[keep]

                batch_detections.append(
                    {
                        "boxes": boxes,
                        "scores": scores,
                        "labels": torch.tensor([class_idx] * len(boxes), device=device),
                    }
                )
            detections.append(batch_detections)

        return detections, predictions

def custom_3d_nms(boxes, scores, threshold):
    """Greedy non-maximum suppression for axis-aligned 3D boxes.

    Args:
        boxes: ``(N, 6)`` tensor of
            ``(x_min, y_min, z_min, x_max, y_max, z_max)`` rows.
        scores: ``(N,)`` tensor of box scores.
        threshold: A box is suppressed when its IoU with an already kept,
            higher-scoring box is greater than this value.

    Returns:
        1-D int64 tensor of indices into ``boxes`` for the kept boxes, in
        descending score order.
    """
    if boxes.size(0) == 0:
        return torch.empty(0, dtype=torch.int64, device=boxes.device)

    _, order = scores.sort(descending=True)
    boxes = boxes[order]

    keep = []
    while boxes.size(0) > 0:
        # The first remaining box is the best-scoring one still in play.
        keep.append(order[0].item())

        iou = compute_3d_iou(boxes[:1], boxes)
        survivors = (iou <= threshold).squeeze(0)
        # Always retire the box just kept. A zero-volume box has IoU 0 with
        # itself, so relying on the IoU test alone would never remove it and
        # the loop would not terminate.
        survivors[0] = False

        boxes = boxes[survivors]
        order = order[survivors]

    return torch.tensor(keep, dtype=torch.int64, device=boxes.device)

def compute_3d_iou(box1, boxes):
    """Intersection-over-union between reference boxes and a set of boxes.

    Args:
        box1: Reference box(es) as ``(x_min, y_min, z_min, x_max, y_max, z_max)``.
            A 1-D tensor gets a leading axis added; a 3-D tensor has its
            leading axis squeezed; a 2-D ``(M, 6)`` tensor is used as is.
        boxes: ``(N, 6)`` tensor in the same layout.

    Returns:
        ``(M, N)`` tensor of IoU values. The union is clamped to at least
        ``1e-6`` so degenerate boxes give 0 rather than a division by zero.
    """
    rank = box1.dim()
    if rank == 1:
        box1 = box1.unsqueeze(0)
    elif rank == 3:
        box1 = box1.squeeze(0)

    def _volume(b):
        # Product of the three side lengths of every row.
        return (b[:, 3] - b[:, 0]) * (b[:, 4] - b[:, 1]) * (b[:, 5] - b[:, 2])

    # Accumulate the overlap one axis at a time: x, then y, then z.
    overlap = None
    for axis in range(3):
        lower = torch.max(box1[:, axis].unsqueeze(1), boxes[:, axis])
        upper = torch.min(box1[:, axis + 3].unsqueeze(1), boxes[:, axis + 3])
        extent = torch.clamp(upper - lower, min=0)
        overlap = extent if overlap is None else overlap * extent

    union = _volume(box1).unsqueeze(1) + _volume(boxes) - overlap
    return overlap / union.clamp(min=1e-6)

class YoloLoss(nn.Module):
    """Sum of box, objectness and classification losses over a dense 3D grid.

    Predictions and targets share the channel layout
    ``6 box values + 1 objectness + num_classes`` along dimension 1.

    * Box loss: MSE between the raw box channels, averaged over every cell.
    * Objectness loss: BCE between ``sigmoid(prediction)`` and the target
      objectness label.
    * Class loss: cross-entropy over every cell, using the argmax of the
      target class channels as the label (an all-zero class vector therefore
      counts as class 0).

    Args:
        num_classes: Number of class channels.
    """

    def __init__(self, num_classes):
        super(YoloLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.num_classes = num_classes

    def forward(self, predictions, targets):
        """Compute the total loss.

        Args:
            predictions: Raw network output, reshaped to
                ``(batch, 6 + 1 + num_classes, g0, g1, g2)`` using its own
                trailing three sizes.
            targets: Tensor with the same number of elements, viewed to the
                same shape. Objectness targets are labels in ``[0, 1]``.

        Returns:
            Scalar tensor ``box_loss + obj_loss + class_loss``.
        """
        batch_size = targets.size(0)
        num_params = 6 + 1 + self.num_classes  # 6 bbox, 1 obj confidence, num_classes
        grid_shape = tuple(predictions.shape[2:5])

        predictions = predictions.view(batch_size, num_params, *grid_shape)
        targets = targets.view(batch_size, num_params, *grid_shape)

        box_loss = self.mse_loss(predictions[:, :6, ...], targets[:, :6, ...])

        pred_obj = torch.sigmoid(predictions[:, 6, ...])
        # Targets are already labels, so they must NOT go through a sigmoid:
        # that would turn 0 into 0.5 and 1 into ~0.73 and train empty cells
        # towards 50% objectness. The clamp only guards BCELoss's [0, 1]
        # requirement against stray out-of-range values.
        true_obj = targets[:, 6, ...].clamp(0.0, 1.0)
        obj_loss = self.bce_loss(pred_obj, true_obj)

        # Move channels last so every grid cell becomes one row of class logits.
        cls_logits = predictions[:, 7:, ...].permute(0, 2, 3, 4, 1).reshape(-1, self.num_classes)
        cls_labels = targets[:, 7:, ...].argmax(dim=1).reshape(-1)
        class_loss = self.ce_loss(cls_logits, cls_labels)

        return box_loss + obj_loss + class_loss

# Example usage: smoke-test the detection head and the loss on random data.
# The head gets its own name so the notebook-level `model` (the network used by
# the training and evaluation cells) is not replaced by a parameter-free head.
detect_head = Detect(num_classes=80)
loss_fn = YoloLoss(num_classes=80)

# Dedicated generator: the printed loss is reproducible and the global RNG
# state used elsewhere in the notebook is left untouched.
demo_rng = torch.Generator().manual_seed(0)
predictions = torch.randn((1, 7 + 80, 13, 13, 13), generator=demo_rng)  # Example predictions

# Example targets with values within the expected range
targets = torch.zeros((1, 7 + 80, 13, 13, 13))
targets[0, :6, 5, 5, 5] = 0.5  # Example bounding box target
targets[0, 6, 5, 5, 5] = 1  # Object confidence target
targets[0, 7, 5, 5, 5] = 1  # Class target

detections, _ = detect_head(predictions)
loss = loss_fn(predictions, targets)

print("Loss:", loss.item())



Loss: 6.715826511383057


hi


NameError: name 'converted_bboxes_with_comments' is not defined