In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define the 3D FFCM module
class FFCM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(FFCM, self).__init__()

        # Convolutional layers
        self.conv1 = nn.Conv3d(in_channels, out_channels // 2, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv2 = nn.Conv3d(out_channels // 2, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

        # Batch Normalization
        self.bn1 = nn.BatchNorm3d(out_channels // 2)
        self.bn2 = nn.BatchNorm3d(out_channels)

        # Activation
        self.act = nn.ReLU()

    def forward(self, x):
        # Apply first convolution
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)

        # Apply second convolution
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)

        return x

# Define the 3D Basic Block
class BasicBlock3D(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock3D, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

# Define the 3D ResNet model with reduced complexity (layer4 removed)
class ResNet3D(nn.Module):
    def __init__(self, block, layers, num_classes=1):
        super(ResNet3D, self).__init__()
        self.in_channels = 64

        self.ffcm = FFCM(1, 64)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        # Adjust MaxPool3d stride to 2
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        # Adjust strides in layers to achieve desired output size
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)  # Changed stride from 1 to 2
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # Keep stride as 2
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1)  # Changed stride from 2 to 1
        # Removed layer4

        # Adjusted the input channels to the final convolution layer
        self.conv = nn.Conv3d(256 * block.expansion, (7 + num_classes), kernel_size=1, stride=1, padding=0)

        # Activation functions
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        # Adaptive pooling to get desired grid size
        self.adaptive_pool = nn.AdaptiveAvgPool3d((20, 20, 20))

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.ffcm(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Removed layer4 forward pass

        x = self.conv(x)  # Output shape: (batch_size, 7 + num_classes, D, H, W)

        # Apply adaptive pooling to get grid size 20x20x20
        x = self.adaptive_pool(x)

        # Apply activation functions to appropriate channels
        bbox_pred = self.tanh(x[:, :6, :, :, :])  # Bounding box coordinates between -1 and 1
        obj_pred = x[:, 6:7, :, :, :]  # Objectness score (raw output for BCEWithLogitsLoss)
        class_pred = x[:, 7:, :, :, :]  # Class scores (raw output for CrossEntropyLoss)

        # Concatenate predictions
        predictions = torch.cat([bbox_pred, obj_pred, class_pred], dim=1)

        return predictions

# The rest of your code remains the same...

# Define the custom 3D NMS function
def custom_3d_nms(boxes, scores, threshold):
    if boxes.size(0) == 0:
        return torch.empty(0, dtype=torch.int64, device=boxes.device)

    # Sort scores in descending order
    scores, idxs = scores.sort(descending=True)
    boxes = boxes[idxs]

    keep = []
    while boxes.size(0) > 0:
        keep.append(idxs[0].item())
        if boxes.size(0) == 1:
            break

        iou = compute_3d_iou(boxes[0:1], boxes[1:])
        boxes = boxes[1:][iou <= threshold]
        idxs = idxs[1:][iou <= threshold]

    return torch.tensor(keep, dtype=torch.long, device=boxes.device)

# Define the 3D IoU computation
def compute_3d_iou(box1, boxes):
    inter_xmin = torch.max(box1[:, 0], boxes[:, 0])
    inter_ymin = torch.max(box1[:, 1], boxes[:, 1])
    inter_zmin = torch.max(box1[:, 2], boxes[:, 2])
    inter_xmax = torch.min(box1[:, 3], boxes[:, 3])
    inter_ymax = torch.min(box1[:, 4], boxes[:, 4])
    inter_zmax = torch.min(box1[:, 5], boxes[:, 5])

    inter_dims = (inter_xmax - inter_xmin).clamp(min=0) * \
                 (inter_ymax - inter_ymin).clamp(min=0) * \
                 (inter_zmax - inter_zmin).clamp(min=0)

    box1_vol = (box1[:, 3] - box1[:, 0]) * \
               (box1[:, 4] - box1[:, 1]) * \
               (box1[:, 5] - box1[:, 2])

    boxes_vol = (boxes[:, 3] - boxes[:, 0]) * \
                (boxes[:, 4] - boxes[:, 1]) * \
                (boxes[:, 5] - boxes[:, 2])

    union = box1_vol + boxes_vol - inter_dims
    iou = inter_dims / union

    return iou

# Define the Detect class for inference
class Detect(nn.Module):
    def __init__(self, num_classes, conf_threshold=0.5, nms_threshold=0.4):
        super(Detect, self).__init__()
        self.num_classes = num_classes
        self.conf_threshold = conf_threshold
        self.nms_threshold = nms_threshold

    def forward(self, predictions):
        batch_size = predictions.size(0)
        D, H, W = predictions.size(2), predictions.size(3), predictions.size(4)

        # Reshape predictions
        predictions = predictions.permute(0, 2, 3, 4, 1).contiguous()  # (batch_size, D, H, W, 7 + num_classes)
        predictions = predictions.view(batch_size, -1, 7 + self.num_classes)  # (batch_size, num_voxels, 7 + num_classes)

        # Extract components
        pred_bboxes = predictions[:, :, :6]  # (batch_size, num_voxels, 6)
        pred_obj_conf = torch.sigmoid(predictions[:, :, 6])  # Apply sigmoid to objectness score
        pred_class_scores = torch.sigmoid(predictions[:, :, 7:])  # Apply sigmoid to class scores

        # Multiply objectness confidence with class probabilities
        pred_scores = pred_obj_conf.unsqueeze(-1) * pred_class_scores  # (batch_size, num_voxels, num_classes)

        detections = []
        for batch_idx in range(batch_size):
            # Filter out low confidence predictions
            mask = pred_obj_conf[batch_idx] > self.conf_threshold
            if mask.sum() == 0:
                detections.append(None)
                continue

            boxes = pred_bboxes[batch_idx][mask]
            scores = pred_scores[batch_idx][mask]
            obj_conf = pred_obj_conf[batch_idx][mask]

            # For each class, perform NMS
            boxes_list = []
            scores_list = []
            labels_list = []
            for cls in range(self.num_classes):
                cls_scores = scores[:, cls]
                cls_mask = cls_scores > self.conf_threshold
                if cls_mask.sum() == 0:
                    continue
                cls_boxes = boxes[cls_mask]
                cls_scores = cls_scores[cls_mask]

                # Perform NMS
                keep = custom_3d_nms(cls_boxes, cls_scores, self.nms_threshold)
                cls_boxes = cls_boxes[keep]
                cls_scores = cls_scores[keep]
                cls_labels = torch.full((len(keep),), cls, dtype=torch.int64, device=predictions.device)

                boxes_list.append(cls_boxes)
                scores_list.append(cls_scores)
                labels_list.append(cls_labels)

            if boxes_list:
                detections.append({
                    'boxes': torch.cat(boxes_list),
                    'scores': torch.cat(scores_list),
                    'labels': torch.cat(labels_list)
                })
            else:
                detections.append(None)

        return detections

# Define the YOLO loss function
class YoloLoss(nn.Module):
    def __init__(self, num_classes, obj_weight=0.5, noobj_weight=1.0, class_weight=0.5):
        super(YoloLoss, self).__init__()
        self.num_classes = num_classes
        self.obj_weight = obj_weight
        self.noobj_weight = noobj_weight
        self.class_weight = class_weight
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([obj_weight]))
        self.bce_loss_noobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([noobj_weight]))
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, predictions, targets):
        batch_size = predictions.size(0)
        D, H, W = predictions.size(2), predictions.size(3), predictions.size(4)

        # Reshape predictions and targets
        predictions = predictions.permute(0, 2, 3, 4, 1).contiguous()  # (batch_size, D, H, W, 7 + num_classes)
        predictions = predictions.view(batch_size, -1, 7 + self.num_classes)  # (batch_size, num_voxels, 7 + num_classes)
        targets = targets.permute(0, 2, 3, 4, 1).contiguous()
        targets = targets.view(batch_size, -1, 7 + self.num_classes)

        # Extract components
        pred_bboxes = predictions[:, :, :6]  # (batch_size, num_voxels, 6)
        pred_obj_conf = predictions[:, :, 6]  # (batch_size, num_voxels)
        pred_class = predictions[:, :, 7:]    # (batch_size, num_voxels, num_classes)

        true_bboxes = targets[:, :, :6]
        true_obj_conf = targets[:, :, 6]
        true_class = targets[:, :, 7:].argmax(dim=2)  # Assuming one-hot encoded classes

        # Mask for cells containing objects
        obj_mask = true_obj_conf == 1
        noobj_mask = true_obj_conf == 0

        # Bounding box loss (only for cells with objects)
        if obj_mask.sum() > 0:
            box_loss = self.mse_loss(pred_bboxes[obj_mask], true_bboxes[obj_mask])
            class_loss = self.ce_loss(pred_class[obj_mask], true_class[obj_mask])
            obj_loss = self.bce_loss(pred_obj_conf[obj_mask], true_obj_conf[obj_mask])
        else:
            box_loss = torch.tensor(0.0, device=predictions.device)
            class_loss = torch.tensor(0.0, device=predictions.device)
            obj_loss = torch.tensor(0.0, device=predictions.device)

        # Objectness loss for cells without objects
        if noobj_mask.sum() > 0:
            noobj_loss = self.bce_loss_noobj(pred_obj_conf[noobj_mask], true_obj_conf[noobj_mask])
        else:
            noobj_loss = torch.tensor(0.0, device=predictions.device)

        # Total loss
        total_loss = box_loss + obj_loss + noobj_loss + class_loss

        return total_loss

# Example model and loss initialization
model = ResNet3D(BasicBlock3D, [2, 2, 2], num_classes=1)  # Adjusted layers parameter
loss_fn = YoloLoss(num_classes=1)  # Adjust num_classes here as well

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define input dimensions
D_in, H_in, W_in = 204, 204, 204  # Input dimensions as per your requirement
input_dims = (D_in, H_in, W_in)

# Verify the output dimensions
if __name__ == "__main__":
    # Create a dummy input tensor
    input_tensor = torch.randn(1, 1, D_in, H_in, W_in)
    output = model(input_tensor)
    print(f"Output shape: {output.shape}")  # Should be (1, 8, 20, 20, 20)



Output shape: torch.Size([1, 8, 20, 20, 20])


In [7]:
def convert_bboxes(bboxes):
    # Check if bboxes is a list of bounding boxes
    if isinstance(bboxes[0][0], list):
        # Convert each bounding box in the list
        return [convert_bboxes(bbox) for bbox in bboxes]
    else:
        # Convert a single bounding box
        x_min, x_max = bboxes[0]
        y_min, y_max = bboxes[1]
        z_min, z_max = bboxes[2]

        x_center = (x_min + x_max) / 2
        y_center = (y_min + y_max) / 2
        z_center = (z_min + z_max) / 2

        width = x_max - x_min
        height = y_max - y_min
        depth = z_max - z_min

        return [x_center, y_center, z_center, width, height, depth]
# Example usage
bounding_boxes_with_comments = {
    "#100140": [[[97, 106], [77, 93], [89, 105]]],
    "#100151": [[[105, 135], [83, 95], [78, 97]]],
    "#100152": [[[101, 112], [73, 86], [120, 134]]],
    "#100171": [[[51, 77], [65, 78], [98, 116]]],
    "#100205": [[[127, 148], [58, 82], [92, 121]]],
    "#100218": [[[109, 133], [56, 68], [105, 119]]],
    "#100231": [[[64, 82], [90, 103], [76, 90]]],
    "#100234": [[[71, 94], [80, 92], [81, 92]]],
    "#100242": [[[40, 61], [77, 98], [103, 125]]],
    "#100277": [[[113, 132], [77, 97], [123, 138]]],
    "#100305": [[[72, 90], [79, 93], [80, 96]]],
    "#100316": [[[176, 189], [76, 89], [105, 117]]],
    "#100324": [[[72, 91], [82, 104], [71, 99]]],
    "#100326": [[[94, 105], [76, 95], [101, 118]]],
    "#100394": [[[162, 173], [87, 99], [97, 109]]],
   # "#100248": [[[115, 129], [76, 94], [110, 124]]],
    "#100428": [[[115,126],[78,88],[110,123]]],
    "#100455": [[[46, 59], [75, 89], [98, 117]]],
    "#100460": [[[178, 194], [89, 101], [103, 119]]],
    "#100517": [[[142, 174], [77, 95], [92, 112]]],
    "#100536": [[[28, 37], [67, 94], [88, 103]]],
    "#100556": [[[159, 173], [83, 92], [94, 112]]],
    "#100601": [[[136, 158], [58, 71], [136, 149]]],
    "#100643": [[[9, 25], [72, 86], [95, 109]]],
    "#100649": [[[91, 104], [77, 91], [94, 111]]],
    "#100660": [[[67, 83], [77, 92], [81, 94]]],
    "#100673": [[[121, 148], [67, 86], [88, 118]]],
    "#100680": [[[68, 92], [70, 90], [78, 102]]],
    "#100700": [[[85, 112], [28, 56], [98, 123]]],
    "#100723": [[[70, 97], [81, 103], [66, 99]]],
    "#100739": [[[10, 24], [82, 97], [101, 115]]],
    "#100818": [[[144, 167], [73, 97], [106, 123]]],
    "#100873": [[[104, 117], [75, 95], [87, 103]]],
    "#100880": [[[112, 131], [58, 74], [94, 106]]],
    "#100915": [[[105, 119], [80, 96], [70, 86]]],
    "#100988": [[[30, 49], [87, 104], [84, 101]]],
    "#101026": [[[47, 60], [63, 81], [103, 123]]],
    "#101043": [[[157, 164], [81, 91], [104, 122]]],
    "#101054": [[[23, 34], [105, 117], [91, 104]]],
    "#101085": [[[65, 84], [88, 105], [85, 103]]],
    "#101118": [[[109, 131], [72, 84], [85, 102]]],
    "#101122": [[[122, 134], [59, 74], [98, 111]]],
    "#101140": [[[119, 131], [63, 77], [91, 104]]],
    "#101148": [[[102, 111], [48, 63], [95, 112]]],
    "#101157": [[[37, 56], [82, 91], [82, 96]]],
    "#101158": [[[161, 178], [72, 84], [105, 119]]],
    "#101194": [[[174, 194], [72, 85], [100, 116]]],
    "#101220": [[[124, 136], [61, 75], [99, 114]]],
    "#101234": [[[108, 133], [76, 96], [84, 98]]],
    "#101263": [[[113, 130], [80, 94], [80, 112]]],
    "#101272": [[[12, 27], [78, 86], [95, 105]]],
    "#101288": [[[35, 48], [86, 103], [90, 104]],[[142,154], [66,78], [100,113]]],
    "#101298": [[[149, 161], [89, 103], [109, 123]]],
    "#101302": [[[36, 47], [64, 79], [96, 113]]],
    "#101309": [[[91, 105], [72, 84], [107, 117]]],
    "#101331": [[[93, 102], [76, 89], [99, 114]]],
    "#101365": [[[75, 92], [74, 90], [114, 130]]],
    "#101393": [[[176, 189], [78, 104], [96, 116]]],
    "#101426": [[[12, 23], [65, 90], [88, 98]],[[106,115],[45,57],[104,115]]],  
    "#101493": [[[91, 115], [31, 51], [100, 120]]],
    "#101495": [[[146, 156], [65, 84], [109, 122]]],
    "#101508": [[[34, 47], [82, 101], [98, 111]]],
    "#101516": [[[7, 17], [91, 109], [124, 136]]],
    "#101562": [[[15, 32], [80, 97], [102, 116]]],
    "#101601": [[[164, 178], [84, 99], [92, 109]]],
    "#101675": [[[71, 90], [77, 91], [78, 95]]],
    "#101735": [[[33, 51], [68, 80],[116,126]]],
    "#101782": [[[115, 132], [76, 91], [84, 96]]],
    "#101842": [[[32, 52], [65, 88], [83, 103]]],
    "#101885": [[[94, 107], [73, 97], [107, 130]]],
    "#101892": [[[137, 153], [69, 81], [106, 119]]],
    "#101990": [[[101, 136], [75, 90], [72, 91]]],
    "#101991": [[[98, 114], [83, 98], [105, 121]]],
    "#101995": [[[114, 128], [64, 80], [95, 106]]],
    "#101998": [[[107, 133], [75, 92], [95, 111]]],
    "#102024": [[[66, 85], [58, 71], [93, 114]]],
    "#102084": [[[38, 58], [76, 101], [92, 104]]],
    "#102141": [[[96, 108], [79, 94], [110, 123]]],
    "#102158": [[[149, 166], [79, 95], [99, 117]]],
    "#102200": [[[101, 131], [82, 102], [72, 105]]],
    "#102201": [[[106, 129], [77, 93], [90, 102]]],
    "#102202": [[[95, 110], [71, 89], [103, 113]]],
    "#102228": [[[149, 163], [82, 91], [99, 112]]],
    "#102283": [[[169, 182], [76, 94], [90, 106]]],
    "#102333": [[[104, 124], [82, 102], [90, 110]]],
    "#102348": [[[71, 85], [74, 93], [90, 103]]],
    "#102391": [[[118, 139], [66, 83], [100, 116]]],
    "#102418": [[[117, 136], [66, 91], [86, 103]]],
    "#102427": [[[68, 89], [87, 100], [81, 98]]],
    "#102442": [[[67, 83], [67, 80], [92, 108]]],
    "#102455": [[[144, 160], [71, 83], [84, 101]]],
    "#102496": [[[160, 181], [74, 90], [100, 116]]],
    "#102564": [[[116, 139], [70, 91], [106, 121]]],
    "#102590": [[[150, 164], [81, 90], [93, 109]],[[114,133],[78,92],[79,88]]],  # and [[118, 136], [82, 93], [77, 91]]
    "#102595": [[[112, 127], [70, 91], [83, 93]],[[113, 131], [75, 86], [92, 105]]],  # and [[113, 131], [75, 86], [92, 105]]
    "#102601": [[[88, 104], [81, 96], [93, 104]]],
    "#102609": [[[20, 42], [80, 98], [94, 115]]],
    "#102638": [[[40, 60], [70, 86], [112, 130]]],
    "#102693": [[[209, 125], [82, 98], [75, 92]]],
    "#102720": [[[22, 36], [86, 100], [92, 103]]],
    "#102761": [[[68, 85], [80, 94], [80, 98]]],
    "#102764": [[[72, 91], [72, 92], [86, 97]]],
    "#102797": [[[73, 86], [86, 102], [84, 98]]],
    "#102835": [[[157, 178], [90, 104], [88, 101]]],
    "#102840": [[[111, 129], [82, 94], [78, 94]]],
    "#102864": [[[112, 132], [79, 100], [79, 93]],[[68, 88], [83, 97], [82, 95]],[[184, 201], [32, 48], [131, 143]]],  # and [[68, 88], [83, 97], [82, 95]], [[184, 201], [32, 48], [131, 143]]
    "#102886": [[[41, 49], [71, 94], [97, 107]]],
    "#102935": [[[143, 161], [71, 91], [104, 120]]],
    "#102959": [[[66, 85], [78, 93], [90, 99]]],
    "#102974": [[[110, 127], [76, 95], [71, 90]]]
}


converted_bboxes_with_comments = {}

for key, bboxes in bounding_boxes_with_comments.items():
    # Convert each list of bounding boxes or single bounding box
    converted_bboxes_with_comments[key] = convert_bboxes(bboxes)

print(converted_bboxes_with_comments)





{'#100140': [[101.5, 85.0, 97.0, 9, 16, 16]], '#100151': [[120.0, 89.0, 87.5, 30, 12, 19]], '#100152': [[106.5, 79.5, 127.0, 11, 13, 14]], '#100171': [[64.0, 71.5, 107.0, 26, 13, 18]], '#100205': [[137.5, 70.0, 106.5, 21, 24, 29]], '#100218': [[121.0, 62.0, 112.0, 24, 12, 14]], '#100231': [[73.0, 96.5, 83.0, 18, 13, 14]], '#100234': [[82.5, 86.0, 86.5, 23, 12, 11]], '#100242': [[50.5, 87.5, 114.0, 21, 21, 22]], '#100277': [[122.5, 87.0, 130.5, 19, 20, 15]], '#100305': [[81.0, 86.0, 88.0, 18, 14, 16]], '#100316': [[182.5, 82.5, 111.0, 13, 13, 12]], '#100324': [[81.5, 93.0, 85.0, 19, 22, 28]], '#100326': [[99.5, 85.5, 109.5, 11, 19, 17]], '#100394': [[167.5, 93.0, 103.0, 11, 12, 12]], '#100428': [[120.5, 83.0, 116.5, 11, 10, 13]], '#100455': [[52.5, 82.0, 107.5, 13, 14, 19]], '#100460': [[186.0, 95.0, 111.0, 16, 12, 16]], '#100517': [[158.0, 86.0, 102.0, 32, 18, 20]], '#100536': [[32.5, 80.5, 95.5, 9, 27, 15]], '#100556': [[166.0, 87.5, 103.0, 14, 9, 18]], '#100601': [[147.0, 64.5, 142.5

In [8]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Assuming the ResNet3D model and YoloLoss class are already defined as per the previous code.

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory containing the NIfTI files
nifti_directory = '/Users/lucabernecker/Desktop/N128_local/aneu_det'

# List all file IDs from bounding box keys
file_ids = [key.strip("#") for key in converted_bboxes_with_comments.keys()]

# Initialize a list to store the loaded NIfTI images as numpy arrays
nifti_images_list = []

# Desired output shape (consistent with the input dimensions expected by the model)
desired_shape = (204, 204, 204)

# Function to pad the images
def pad_image(image, target_shape):
    pad_width = [(0, max(0, t - s)) for s, t in zip(image.shape, target_shape)]
    return np.pad(image, pad_width, mode='constant', constant_values=0)

# Load and preprocess the NIfTI images
for file_id in file_ids:
    file_name = f"cube_{file_id}.nii.gz"
    file_path = os.path.join(nifti_directory, file_name)
    
    # Load the NIfTI image
    nifti_image = nib.load(file_path)
    
    # Convert the NIfTI image to a numpy array
    image_data = nifti_image.get_fdata()
    
    # Pad the image to the desired shape
    padded_image = pad_image(image_data, desired_shape)
    
    # Append to the list
    nifti_images_list.append(padded_image)

# Stack all numpy arrays into a single numpy array
x_train = np.stack(nifti_images_list, axis=0)

# Convert the list of images to a Torch tensor and add a channel dimension
x_train = torch.tensor(x_train, dtype=torch.float32)
x_train = torch.unsqueeze(x_train, 1)  # Shape: (batch_size, 1, 204, 204, 204)
print("X_Train Shape:", x_train.shape)

# Number of classes (assuming 1 class)
num_classes = 1

# Grid size for 3D detection (e.g., 13x13x13 grid)
grid_size = 13

# Initialize the target tensor
targets = torch.zeros((len(file_ids), 7 + num_classes, grid_size, grid_size, grid_size))  # Shape: (batch_size, 8, 13, 13, 13)

# Prepare the targets
for idx, (file_id, boxes) in enumerate(converted_bboxes_with_comments.items()):
    for box in boxes:
        # Extract center coordinates (x, y, z) and dimensions (width, height, depth)
        center_x, center_y, center_z, width, height, depth = box
        
        # Normalize coordinates to be between 0 and 1
        norm_center_x = center_x / desired_shape[0]
        norm_center_y = center_y / desired_shape[1]
        norm_center_z = center_z / desired_shape[2]
        
        # Compute which grid cell the center falls into
        grid_x = int(norm_center_x * grid_size)
        grid_y = int(norm_center_y * grid_size)
        grid_z = int(norm_center_z * grid_size)
        
        # Ensure grid indices are within bounds
        grid_x = min(grid_size - 1, max(0, grid_x))
        grid_y = min(grid_size - 1, max(0, grid_y))
        grid_z = min(grid_size - 1, max(0, grid_z))
        
        # Encode bounding box coordinates relative to the grid cell
        # Compute the offset within the cell
        cell_x = norm_center_x * grid_size - grid_x
        cell_y = norm_center_y * grid_size - grid_y
        cell_z = norm_center_z * grid_size - grid_z
        
        # Normalize the bounding box dimensions to be relative to the entire image size
        norm_width = width / desired_shape[0]
        norm_height = height / desired_shape[1]
        norm_depth = depth / desired_shape[2]
        
        # Add bounding box info to the target tensor (grid cell)
        targets[idx, 0:6, grid_x, grid_y, grid_z] = torch.tensor(
            [cell_x, cell_y, cell_z, norm_width, norm_height, norm_depth]
        )
        # Set objectness score to 1
        targets[idx, 6, grid_x, grid_y, grid_z] = 1
        # Set class label (one-hot encoding)
        class_label = 0  # Assuming only one class
        targets[idx, 7 + class_label, grid_x, grid_y, grid_z] = 1

print("Targets Shape:", targets.shape)

# Define the model
model = ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes=num_classes)
model.to(device)

# Define the loss function
loss_fn = YoloLoss(num_classes=num_classes)
loss_fn.to(device)

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Move data to device
x_train = x_train.to(device)
targets = targets.to(device)

# Training parameters
num_epochs = 50
batch_size = 1
num_samples = x_train.shape[0]
num_batches = int(np.ceil(num_samples / batch_size))

print("Targets Shape:", targets.shape)

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_samples)
        input_data = x_train[start_idx:end_idx]
        target_data = targets[start_idx:end_idx]
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        predictions = model(input_data)
        
        # Print shapes for debugging
        print(f"Batch {batch_idx+1}")
        print("Input data shape:", input_data.shape)
        print("Predictions shape:", predictions.shape)
        print("Target data shape:", target_data.shape)
        
        # Compute loss
        loss = loss_fn(predictions, target_data)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Print loss for this batch
        print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{num_batches}], Loss: {loss.item()}")
        
print("Training completed.")


X_Train Shape: torch.Size([109, 1, 204, 204, 204])
Targets Shape: torch.Size([109, 8, 13, 13, 13])
Targets Shape: torch.Size([109, 8, 13, 13, 13])
Batch 1
Input data shape: torch.Size([1, 1, 204, 204, 204])
Predictions shape: torch.Size([1, 8, 13, 13, 13])
Target data shape: torch.Size([1, 8, 13, 13, 13])


KeyboardInterrupt: 

### Control

In [15]:
nifti_directory = '/Users/lucabernecker/Desktop/N128_local/aneu_det'
import os
import nibabel as nib
import numpy as np
import torch

# List all file IDs directly from converted_bboxes_with_comments keys
file_ids = list(converted_bboxes_with_comments.keys())  # Keep the '#' in place

# Initialize a list to store the loaded nifti images as numpy arrays
nifti_images_list = []

# Desired output shape
desired_shape = (204, 204, 204)

# Function to pad the images
def pad_image(image, target_shape):
    pad_width = [(0, max(0, t - s)) for s, t in zip(image.shape, target_shape)]
    return np.pad(image, pad_width, mode='constant', constant_values=0)

# Load and preprocess the nifti images
for file_id in file_ids:
    file_name = f"cube_{file_id.strip('#')}.nii.gz"  # Strip the '#' here when loading the file
    file_path = os.path.join(nifti_directory, file_name)
    
    # Load the nifti image
    nifti_image = nib.load(file_path)
    
    # Convert the nifti image to a numpy array
    image_data = nifti_image.get_fdata()
    
    # Pad the image to the desired shape
    padded_image = pad_image(image_data, desired_shape)
    
    # Append to the list
    nifti_images_list.append(padded_image)

# Stack all numpy arrays into a single numpy array
x_train = np.stack(nifti_images_list, axis=0)

# Convert the list of images to a NumPy array and then to a Torch tensor
x_train = torch.tensor(x_train, dtype=torch.float32)
x_train = torch.unsqueeze(x_train, 1)  # Add a channel dimension
print("X_Train Shape:", np.shape(x_train))

# Function to "paint" the bounding box
def paint_bounding_box(image, bbox, value=1):
    """Paint the region inside the bounding box in the image with the given intensity value."""
    center_x, center_y, center_z, width, height, depth = bbox

    # Calculate the min and max coordinates from the center and dimensions
    x_min = max(0, int(center_x - width / 2))
    x_max = min(image.shape[0] - 1, int(center_x + width / 2))
    
    y_min = max(0, int(center_y - height / 2))
    y_max = min(image.shape[1] - 1, int(center_y + height / 2))
    
    z_min = max(0, int(center_z - depth / 2))
    z_max = min(image.shape[2] - 1, int(center_z + depth / 2))

    # Paint the bounding box region in the image with the maximum intensity
    image[x_min:x_max + 1, y_min:y_max + 1, z_min:z_max + 1] = value
    
    return image

# Assuming the first file ID corresponds to the first image in x_train
first_file_id = file_ids[0]  # Use file_id with '#'

# Check if the file_id exists in converted_bboxes_with_comments
if first_file_id in converted_bboxes_with_comments:
    bounding_box = converted_bboxes_with_comments[first_file_id][0]  # Use the first bounding box for this example

    # Use the bounding box from the targets for the first image
    first_image = x_train[0].squeeze().numpy()

    # Paint the bounding box with a higher intensity (representing "red" in grayscale)
    painted_image = paint_bounding_box(first_image, bounding_box, value=255)

    # Create a new NIfTI image with the painted bounding box
    painted_nifti = nib.Nifti1Image(painted_image, affine=np.eye(4))

    # Save the modified image as a NIfTI file
    output_path = os.path.join(nifti_directory, "painted_first_image.nii.gz")
    nib.save(painted_nifti, output_path)

    print(f"Annotated image with bounding box saved to: {output_path}")
else:
    print(f"File ID {first_file_id} not found in bounding box data.")


X_Train Shape: torch.Size([109, 1, 204, 204, 204])
Annotated image with bounding box saved to: /Users/lucabernecker/Desktop/N128_local/aneu_det/painted_first_image.nii.gz
