In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
print("starting Training")
# Define the 3D Basic Block
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Define the 3D Basic Block
class FFCM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(FFCM, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv3d(in_channels, out_channels // 2, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv2 = nn.Conv3d(out_channels // 2, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        
        # Batch Normalization
        self.bn1 = nn.BatchNorm3d(out_channels // 2)
        self.bn2 = nn.BatchNorm3d(out_channels)
        
        # Activation
        self.act = nn.ReLU()

    def forward(self, x):
        # Apply first convolution
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        
        # Apply second convolution
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)
        
        return x
class BasicBlock3D(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock3D, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)  # Use 'out' instead of 'x'
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

# Define the 3D ResNet model
class ResNet3D(nn.Module):
    def __init__(self, block, layers, num_classes=1):
        super(ResNet3D, self).__init__()
        self.in_channels = 64
        
        self.ffcm = FFCM(1, 64)
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.conv = nn.Conv3d(512 * block.expansion, (7 + num_classes), kernel_size=1, stride=1, padding=2)
        self.detect = Detect(num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.ffcm(x)
      #  x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.conv(x)
        detections, predictions = self.detect(x)
        return detections, predictions

# Define the Detect and Loss classes
class Detect(nn.Module):
    def __init__(self, num_classes, conf_threshold=0.5, nms_threshold=0.4):
        super(Detect, self).__init__()
        self.num_classes = num_classes
        self.conf_threshold = conf_threshold
        self.nms_threshold = nms_threshold

    def forward(self, predictions):
        batch_size, _, d, h, w = predictions.shape
        predictions = predictions.view(batch_size, -1, self.num_classes + 7, d, h, w)
        
        # Calculate the coordinates and dimensions
        x_min = predictions[:, :, 0, :, :, :]
        y_min = predictions[:, :, 1, :, :, :]
        z_min = predictions[:, :, 2, :, :, :]
        x_max = predictions[:, :, 3, :, :, :]
        y_max = predictions[:, :, 4, :, :, :]
        z_max = predictions[:, :, 5, :, :, :]
        obj_conf = torch.sigmoid(predictions[:, :, 6, :, :, :])
        class_conf, class_pred = torch.max(predictions[:, :, 7:, :, :, :], dim=2)

        # Filter out low confidence objectness predictions
        mask = obj_conf > self.conf_threshold
        detections = []
        for batch_idx in range(batch_size):
            boxes = []
            scores = []
            for class_idx in range(self.num_classes):
                class_mask = mask[batch_idx, :, :, :, :] & (class_pred[batch_idx, :, :, :, :] == class_idx)
                if class_mask.sum() == 0:
                    continue
                x_min_class = x_min[batch_idx, :, :, :][class_mask]
                y_min_class = y_min[batch_idx, :, :, :][class_mask]
                z_min_class = z_min[batch_idx, :, :, :][class_mask]
                x_max_class = x_max[batch_idx, :, :, :][class_mask]
                y_max_class = y_max[batch_idx, :, :, :][class_mask]
                z_max_class = z_max[batch_idx, :, :, :][class_mask]
                obj_conf_class = obj_conf[batch_idx, :, :, :][class_mask]
                class_conf_class = class_conf[batch_idx, :, :, :][class_mask]
                scores_class = obj_conf_class * class_conf_class

                boxes_class = torch.stack((x_min_class, y_min_class, z_min_class, x_max_class, y_max_class, z_max_class), dim=-1)
                keep = custom_3d_nms(boxes_class, scores_class, self.nms_threshold)

                if len(keep) > 0:
                    boxes.append(boxes_class[keep])
                    scores.append(scores_class[keep])

            if len(boxes) > 0:
                boxes = torch.cat(boxes, dim=0)
                scores = torch.cat(scores, dim=0)
                detections.append((boxes, scores))
            else:
                detections.append((torch.zeros((0, 6)), torch.zeros((0,))))

        return detections, predictions.view(batch_size, -1, d, h, w)

def custom_3d_nms(boxes, scores, threshold):
    if boxes.size(0) == 0:
        return torch.empty(0, dtype=torch.int64, device=boxes.device)

    # Sort scores in descending order
    _, idxs = scores.sort(descending=True)
    boxes = boxes[idxs]
    scores = scores[idxs]

    keep = []
    max_iterations = 1000  # Set a maximum number of iterations to prevent infinite loops
    iteration = 0

    while boxes.size(0) > 0:
        pick = boxes.new_tensor([0], dtype=torch.long)
        keep.append(idxs[pick].item())

        if boxes.size(0) == 1:
            break

        iou = compute_3d_iou(boxes[pick].unsqueeze(0), boxes)
        mask = (iou <= threshold).squeeze(0)

        if mask.sum() == 0:
            break  # No boxes left after NMS

        boxes = boxes[mask]
        scores = scores[mask]
        idxs = idxs[mask]

        iteration += 1
   #     print(f"Remaining boxes: {boxes.size(0)}, Iteration: {iteration}")  # Debugging print statement

        if iteration >= max_iterations:
       #     print("Reached maximum iterations in NMS, breaking out to prevent infinite loop.")
            break

    return torch.tensor(keep, dtype=torch.int64, device=boxes.device)

def compute_3d_iou(box1, boxes):
    if box1.dim() == 1:
        box1 = box1.unsqueeze(0)
    elif box1.dim() == 3:
        box1 = box1.squeeze(0)

    inter_xmin = torch.max(box1[:, 0].unsqueeze(1), boxes[:, 0])
    inter_ymin = torch.max(box1[:, 1].unsqueeze(1), boxes[:, 1])
    inter_zmin = torch.max(box1[:, 2].unsqueeze(1), boxes[:, 2])
    inter_xmax = torch.min(box1[:, 3].unsqueeze(1), boxes[:, 3])
    inter_ymax = torch.min(box1[:, 4].unsqueeze(1), boxes[:, 4])
    inter_zmax = torch.min(box1[:, 5].unsqueeze(1), boxes[:, 5])

    inter_volume = torch.clamp(inter_xmax - inter_xmin, min=0) * torch.clamp(inter_ymax - inter_ymin, min=0) * torch.clamp(inter_zmax - inter_zmin, min=0)
    box1_volume = (box1[:, 3] - box1[:, 0]) * (box1[:, 4] - box1[:, 1]) * (box1[:, 5] - box1[:, 2])
    boxes_volume = (boxes[:, 3] - boxes[:, 0]) * (boxes[:, 4] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 2])

    union_volume = box1_volume.unsqueeze(1) + boxes_volume - inter_volume
    iou = inter_volume / union_volume

    return iou

class YoloLoss(nn.Module):
    def __init__(self, num_classes, obj_weight=1.0, noobj_weight=1.0, class_weight=1.0):
        super(YoloLoss, self).__init__()
        self.num_classes = num_classes
        self.bce_loss = nn.BCELoss()
        self.mse_loss = nn.MSELoss()
        self.obj_weight = obj_weight
        self.noobj_weight = noobj_weight
        self.class_weight = class_weight

    def forward(self, predictions, targets):
        # Slicing predictions to get individual components
        pred_bboxes = predictions[:, :6, ...]  # Assuming first 6 channels are bbox predictions
        pred_obj_conf = predictions[:, 6, ...]  # Assuming 7th channel is objectness confidence
        pred_class = predictions[:, 7:7+self.num_classes, ...]  # Following channels are class predictions

        # Slicing targets to get individual components
        true_bboxes = targets[:, :6, ...]
        true_obj_conf = targets[:, 6, ...]
        true_class = targets[:, 7:7+self.num_classes, ...]  # Assuming one-hot encoded class targets

        # Debug prints for shapes
        print("pred_class shape:", pred_class.shape)
        print("true_class shape:", true_class.shape)

        # Calculate MSE loss for bounding box predictions
        box_loss = self.mse_loss(pred_bboxes, true_bboxes)

        # Calculate BCE loss for objectness confidence
        obj_loss = self.bce_loss(pred_obj_conf, true_obj_conf)

        # Calculate BCE loss for class predictions
        class_loss = self.bce_loss(pred_class, true_class)

        # Total loss
        total_loss = box_loss + self.obj_weight * obj_loss + self.noobj_weight * (1 - obj_loss) + self.class_weight * class_loss
        return total_loss
# Example model and loss initialization
# Example model and loss initialization
model = ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes=1)  # Adjust num_classes based on your actual setup
loss_fn = YoloLoss(num_classes=101)  # Adjust num_classes here as well

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of epochs
num_epochs = 2

# Example training data loader (replace with your actual data loader)
# This is a placeholder; you should use your actual DataLoader providing (input_data, targets) pairs
# train_loader = DataLoader(dataset, batch_size=1, shuffle=True)

# For demonstration, using dummy data
input_data = torch.randn((1, 1, 204, 204, 204))  # Example input with 1 channel
targets = torch.zeros((1, 8, 17, 17, 17))  # Adjust target size to match predictions
targets[0, :6, 5, 5, 5] = 0.5  # Example bounding box target
targets[0, 6, 5, 5,5] = 1  # Object confidence target
targets[0, 7:7+1, 5, 5, 5] = 1  # Class target 

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode

    # Example loop over data (replace with your actual data loader loop)
    for batch_idx in range(1):  # replace with 'for input_data, targets in train_loader:'
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        detections, predictions = model(input_data)

        # Apply sigmoid to the predictions
        predictions_sigmoid = torch.sigmoid(predictions)

        # Calculate loss using the custom loss function
        loss = loss_fn(predictions_sigmoid, targets)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print loss for this batch
        print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {loss.item()}")

print("Training completed.")

starting Training
pred_class shape: torch.Size([1, 1, 17, 17, 17])
true_class shape: torch.Size([1, 1, 17, 17, 17])
Epoch [1/2], Batch [1], Loss: 2.0201175212860107


In [3]:
def convert_bboxes(bboxes):
    # Check if bboxes is a list of bounding boxes
    if isinstance(bboxes[0][0], list):
        # Convert each bounding box in the list
        return [convert_bboxes(bbox) for bbox in bboxes]
    else:
        # Convert a single bounding box
        x_min, x_max = bboxes[0]
        y_min, y_max = bboxes[1]
        z_min, z_max = bboxes[2]

        x_center = (x_min + x_max) / 2
        y_center = (y_min + y_max) / 2
        z_center = (z_min + z_max) / 2

        width = x_max - x_min
        height = y_max - y_min
        depth = z_max - z_min

        return [x_center, y_center, z_center, width, height, depth]
# Example usage
bounding_boxes_with_comments = {
    "#100140": [[[97, 106], [77, 93], [89, 105]]],
    "#100151": [[[105, 135], [83, 95], [78, 97]]],
    "#100152": [[[101, 112], [73, 86], [120, 134]]],
    "#100171": [[[51, 77], [65, 78], [98, 116]]],
    "#100205": [[[127, 148], [58, 82], [92, 121]]],
    "#100218": [[[109, 133], [56, 68], [105, 119]]],
    "#100231": [[[64, 82], [90, 103], [76, 90]]],
    "#100234": [[[71, 94], [80, 92], [81, 92]]],
    "#100242": [[[40, 61], [77, 98], [103, 125]]],
    "#100277": [[[113, 132], [77, 97], [123, 138]]],
    "#100305": [[[72, 90], [79, 93], [80, 96]]],
    "#100316": [[[176, 189], [76, 89], [105, 117]]],
    "#100324": [[[72, 91], [82, 104], [71, 99]]],
    "#100326": [[[94, 105], [76, 95], [101, 118]]],
    "#100394": [[[162, 173], [87, 99], [97, 109]]],
   # "#100248": [[[115, 129], [76, 94], [110, 124]]],
    "#100428": [[[115,126],[78,88],[110,123]]],
    "#100455": [[[46, 59], [75, 89], [98, 117]]],
    "#100460": [[[178, 194], [89, 101], [103, 119]]],
    "#100517": [[[142, 174], [77, 95], [92, 112]]],
    "#100536": [[[28, 37], [67, 94], [88, 103]]],
    "#100556": [[[159, 173], [83, 92], [94, 112]]],
    "#100601": [[[136, 158], [58, 71], [136, 149]]],
    "#100643": [[[9, 25], [72, 86], [95, 109]]],
    "#100649": [[[91, 104], [77, 91], [94, 111]]],
    "#100660": [[[67, 83], [77, 92], [81, 94]]],
    "#100673": [[[121, 148], [67, 86], [88, 118]]],
    "#100680": [[[68, 92], [70, 90], [78, 102]]],
    "#100700": [[[85, 112], [28, 56], [98, 123]]],
    "#100723": [[[70, 97], [81, 103], [66, 99]]],
    "#100739": [[[10, 24], [82, 97], [101, 115]]],
    "#100818": [[[144, 167], [73, 97], [106, 123]]],
    "#100873": [[[104, 117], [75, 95], [87, 103]]],
    "#100880": [[[112, 131], [58, 74], [94, 106]]],
    "#100915": [[[105, 119], [80, 96], [70, 86]]],
    "#100988": [[[30, 49], [87, 104], [84, 101]]],
    "#101026": [[[47, 60], [63, 81], [103, 123]]],
    "#101043": [[[157, 164], [81, 91], [104, 122]]],
    "#101054": [[[23, 34], [105, 117], [91, 104]]],
    "#101085": [[[65, 84], [88, 105], [85, 103]]],
    "#101118": [[[109, 131], [72, 84], [85, 102]]],
    "#101122": [[[122, 134], [59, 74], [98, 111]]],
    "#101140": [[[119, 131], [63, 77], [91, 104]]],
    "#101148": [[[102, 111], [48, 63], [95, 112]]],
    "#101157": [[[37, 56], [82, 91], [82, 96]]],
    "#101158": [[[161, 178], [72, 84], [105, 119]]],
    "#101194": [[[174, 194], [72, 85], [100, 116]]],
    "#101220": [[[124, 136], [61, 75], [99, 114]]],
    "#101234": [[[108, 133], [76, 96], [84, 98]]],
    "#101263": [[[113, 130], [80, 94], [80, 112]]],
    "#101272": [[[12, 27], [78, 86], [95, 105]]],
 #   "#101288": [[[35, 48], [86, 103], [90, 104]],[[142,154], [66,78], [100,113]]],
    "#101298": [[[149, 161], [89, 103], [109, 123]]],
    "#101302": [[[36, 47], [64, 79], [96, 113]]],
    "#101309": [[[91, 105], [72, 84], [107, 117]]],
    "#101331": [[[93, 102], [76, 89], [99, 114]]],
    "#101365": [[[75, 92], [74, 90], [114, 130]]],
    "#101393": [[[176, 189], [78, 104], [96, 116]]],
 #   "#101426": [[[12, 23], [65, 90], [88, 98]],[[106,115],[45,57],[104,115]]],  
    "#101493": [[[91, 115], [31, 51], [100, 120]]],
    "#101495": [[[146, 156], [65, 84], [109, 122]]],
    "#101508": [[[34, 47], [82, 101], [98, 111]]],
    "#101516": [[[7, 17], [91, 109], [124, 136]]],
    "#101562": [[[15, 32], [80, 97], [102, 116]]],
    "#101601": [[[164, 178], [84, 99], [92, 109]]],
    "#101675": [[[71, 90], [77, 91], [78, 95]]],
    "#101735": [[[33, 51], [68, 80],[116,126]]],
    "#101782": [[[115, 132], [76, 91], [84, 96]]],
    "#101842": [[[32, 52], [65, 88], [83, 103]]],
    "#101885": [[[94, 107], [73, 97], [107, 130]]],
    "#101892": [[[137, 153], [69, 81], [106, 119]]],
    "#101990": [[[101, 136], [75, 90], [72, 91]]],
    "#101991": [[[98, 114], [83, 98], [105, 121]]],
    "#101995": [[[114, 128], [64, 80], [95, 106]]],
    "#101998": [[[107, 133], [75, 92], [95, 111]]],
    "#102024": [[[66, 85], [58, 71], [93, 114]]],
    "#102084": [[[38, 58], [76, 101], [92, 104]]],
    "#102141": [[[96, 108], [79, 94], [110, 123]]],
    "#102158": [[[149, 166], [79, 95], [99, 117]]],
    "#102200": [[[101, 131], [82, 102], [72, 105]]],
    "#102201": [[[106, 129], [77, 93], [90, 102]]],
    "#102202": [[[95, 110], [71, 89], [103, 113]]],
    "#102228": [[[149, 163], [82, 91], [99, 112]]],
    "#102283": [[[169, 182], [76, 94], [90, 106]]],
    "#102333": [[[104, 124], [82, 102], [90, 110]]],
    "#102348": [[[71, 85], [74, 93], [90, 103]]],
    "#102391": [[[118, 139], [66, 83], [100, 116]]],
    "#102418": [[[117, 136], [66, 91], [86, 103]]],
    "#102427": [[[68, 89], [87, 100], [81, 98]]],
    "#102442": [[[67, 83], [67, 80], [92, 108]]],
    "#102455": [[[144, 160], [71, 83], [84, 101]]],
    "#102496": [[[160, 181], [74, 90], [100, 116]]],
    "#102564": [[[116, 139], [70, 91], [106, 121]]],
#    "#102590": [[[150, 164], [81, 90], [93, 109]],[[114,133],[78,92],[79,88]]],  # and [[118, 136], [82, 93], [77, 91]]
#    "#102595": [[[112, 127], [70, 91], [83, 93]],[[113, 131], [75, 86], [92, 105]]],  # and [[113, 131], [75, 86], [92, 105]]
    "#102601": [[[88, 104], [81, 96], [93, 104]]],
    "#102609": [[[20, 42], [80, 98], [94, 115]]],
    "#102638": [[[40, 60], [70, 86], [112, 130]]],
    "#102693": [[[209, 125], [82, 98], [75, 92]]],
    "#102720": [[[22, 36], [86, 100], [92, 103]]],
    "#102761": [[[68, 85], [80, 94], [80, 98]]],
    "#102764": [[[72, 91], [72, 92], [86, 97]]],
    "#102797": [[[73, 86], [86, 102], [84, 98]]],
    "#102835": [[[157, 178], [90, 104], [88, 101]]],
    "#102840": [[[111, 129], [82, 94], [78, 94]]],
  #  "#102864": [[[112, 132], [79, 100], [79, 93]],[[68, 88], [83, 97], [82, 95]],[[184, 201], [32, 48], [131, 143]]],  # and [[68, 88], [83, 97], [82, 95]], [[184, 201], [32, 48], [131, 143]]
    "#102886": [[[41, 49], [71, 94], [97, 107]]],
    "#102935": [[[143, 161], [71, 91], [104, 120]]],
    "#102959": [[[66, 85], [78, 93], [90, 99]]],
    "#102974": [[[110, 127], [76, 95], [71, 90]]]
}


converted_bboxes_with_comments = {}

for key, bboxes in bounding_boxes_with_comments.items():
    # Convert each list of bounding boxes or single bounding box
    converted_bboxes_with_comments[key] = convert_bboxes(bboxes)

print(converted_bboxes_with_comments)





{'#100140': [[101.5, 85.0, 97.0, 9, 16, 16]], '#100151': [[120.0, 89.0, 87.5, 30, 12, 19]], '#100152': [[106.5, 79.5, 127.0, 11, 13, 14]], '#100171': [[64.0, 71.5, 107.0, 26, 13, 18]], '#100205': [[137.5, 70.0, 106.5, 21, 24, 29]], '#100218': [[121.0, 62.0, 112.0, 24, 12, 14]], '#100231': [[73.0, 96.5, 83.0, 18, 13, 14]], '#100234': [[82.5, 86.0, 86.5, 23, 12, 11]], '#100242': [[50.5, 87.5, 114.0, 21, 21, 22]], '#100277': [[122.5, 87.0, 130.5, 19, 20, 15]], '#100305': [[81.0, 86.0, 88.0, 18, 14, 16]], '#100316': [[182.5, 82.5, 111.0, 13, 13, 12]], '#100324': [[81.5, 93.0, 85.0, 19, 22, 28]], '#100326': [[99.5, 85.5, 109.5, 11, 19, 17]], '#100394': [[167.5, 93.0, 103.0, 11, 12, 12]], '#100428': [[120.5, 83.0, 116.5, 11, 10, 13]], '#100455': [[52.5, 82.0, 107.5, 13, 14, 19]], '#100460': [[186.0, 95.0, 111.0, 16, 12, 16]], '#100517': [[158.0, 86.0, 102.0, 32, 18, 20]], '#100536': [[32.5, 80.5, 95.5, 9, 27, 15]], '#100556': [[166.0, 87.5, 103.0, 14, 9, 18]], '#100601': [[147.0, 64.5, 142.5

In [4]:
nifti_directory = '/Users/lucabernecker/Desktop/N128_local/aneu_det'
import os
import nibabel as nib
import numpy as np
# List all files in the folder
file_ids = [key.strip("#") for key in converted_bboxes_with_comments.keys()]

# Initialize a list to store the loaded nifti images as numpy arrays
nifti_images_list = []

# Desired output shape
desired_shape = (204, 204, 204)

# Function to pad the images
def pad_image(image, target_shape):
    pad_width = [(0, max(0, t - s)) for s, t in zip(image.shape, target_shape)]
    return np.pad(image, pad_width, mode='constant', constant_values=0)

# Load and preprocess the nifti images
for file_id in file_ids:
    file_name = f"cube_{file_id}.nii.gz"
    file_path = os.path.join(nifti_directory, file_name)
    
    # Load the nifti image
    nifti_image = nib.load(file_path)
    
    # Convert the nifti image to a numpy array
    image_data = nifti_image.get_fdata()
    
    # Pad the image to the desired shape
    padded_image = pad_image(image_data, desired_shape)
    
    # Append to the list
    nifti_images_list.append(padded_image)

# Stack all numpy arrays into a single numpy array
x_train = np.stack(nifti_images_list, axis=0)

###############

# Convert the list of images to a NumPy array
x_train = torch.tensor(x_train, dtype=torch.float32)
x_train = torch.unsqueeze(x_train, 1)
print("X_Train",np.shape(x_train))
num_classes = 1

grid_size = 17  # Grid size of 17x17x17
targets = torch.zeros((104, 8, grid_size, grid_size, grid_size))  # Correctly sized target tensor

# Sample data (example, use actual data for real case)

data = converted_bboxes_with_comments
# Calculate grid size
cell_size = 204 / grid_size

for sample_idx in range(targets.shape[0]):
    for key, boxes in data.items():
        for box in boxes:
            x, y, z, w, h, d = box
            
            # Convert center coordinates to grid cell index
            grid_x = int(x // cell_size)
            grid_y = int(y // cell_size)
            grid_z = int(z // cell_size)
            
            # Check if indices are within bounds
            if grid_x >= grid_size or grid_y >= grid_size or grid_z >= grid_size:
                continue
            
            # Normalize coordinates relative to the grid cell
            x_offset = (x % cell_size) / cell_size
            y_offset = (y % cell_size) / cell_size
            z_offset = (z % cell_size) / cell_size
            
            # Normalize width, height, and depth relative to the input size
            w_norm = w / 204
            h_norm = h / 204
            d_norm = d / 204
            
            # Fill in the targets
            try:
                targets[sample_idx, 0:3, grid_x, grid_y, grid_z] = torch.tensor([x_offset, y_offset, z_offset])
                targets[sample_idx, 3:6, grid_x, grid_y, grid_z] = torch.tensor([w_norm, h_norm, d_norm])
                targets[sample_idx, 6, grid_x, grid_y, grid_z] = 1  # Object confidence
                targets[sample_idx, 7, grid_x, grid_y, grid_z] = 0  # Class target, assuming a single class
            except IndexError as e:
                print(f"IndexError for box with center at ({x}, {y}, {z}) in sample {sample_idx}: {e}")
                continue
num_epochs = 10
batch_size = 1
num_samples = x_train.shape[0]
num_batches = int(np.ceil(num_samples / batch_size))
        
print(targets.shape,"TARGETS")
def loss_fn(predictions, targets):
    return torch.nn.functional.mse_loss(predictions, targets)
# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_samples)
        input_data = x_train[start_idx:end_idx]
        target_data = targets[start_idx:end_idx]
        input_data = torch.tensor(input_data, dtype=torch.float32)
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        detections, predictions = model(input_data)

        # Apply sigmoid to the predictions
        predictions_sigmoid = torch.sigmoid(predictions)

        # Calculate loss using the custom loss function
        loss = loss_fn(predictions_sigmoid, target_data)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print loss for this batch
        print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{num_batches}], Loss: {loss.item()}")

print("Training completed.")

X_Train torch.Size([104, 1, 204, 204, 204])
torch.Size([104, 8, 17, 17, 17]) TARGETS


  input_data = torch.tensor(input_data, dtype=torch.float32)


Epoch [1/10], Batch [1/104], Loss: 0.21572944521903992
Epoch [1/10], Batch [2/104], Loss: 0.1879815012216568
Epoch [1/10], Batch [3/104], Loss: 0.16486185789108276
Epoch [1/10], Batch [4/104], Loss: 0.1558379828929901
Epoch [1/10], Batch [5/104], Loss: 0.15065714716911316
Epoch [1/10], Batch [6/104], Loss: 0.14667336642742157
Epoch [1/10], Batch [7/104], Loss: 0.14377611875534058
Epoch [1/10], Batch [8/104], Loss: 0.14211443066596985
Epoch [1/10], Batch [9/104], Loss: 0.1415337473154068
Epoch [1/10], Batch [10/104], Loss: 0.14120949804782867
Epoch [1/10], Batch [11/104], Loss: 0.1408495306968689
Epoch [1/10], Batch [12/104], Loss: 0.14059139788150787
Epoch [1/10], Batch [13/104], Loss: 0.14039114117622375
Epoch [1/10], Batch [14/104], Loss: 0.1402612179517746
Epoch [1/10], Batch [15/104], Loss: 0.14007072150707245


KeyboardInterrupt: 