In [1]:
# Model performance before training 

In [2]:
# %pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# %pip install ensemble_boxes
# %pip install albumentations
# %pip install effdet
# %pip install natsort
#%pip install nibabel

### Imports

In [3]:
import torch
import gc
import cv2
import numpy as np
import nibabel as nib
import random
import albumentations as A
from albumentations.pytorch import ToTensorV2
from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain, DetBenchPredict
from torch.utils.data import Dataset,DataLoader
from effdet.efficientdet import HeadNet
from albumentations import Compose, Resize, Normalize


  from .autonotebook import tqdm as notebook_tqdm


# Checking for GPU

In [4]:
# Report the PyTorch build and, when a CUDA GPU is visible, which one will be used.
print("PyTorch version:", torch.__version__)
has_cuda = torch.cuda.is_available()
print("CUDA available:", has_cuda)
if has_cuda:
    print("Number of CUDA devices:", torch.cuda.device_count())
    print("CUDA device name:", torch.cuda.get_device_name(0))

PyTorch version: 2.0.1+cu118
CUDA available: True
Number of CUDA devices: 1
CUDA device name: NVIDIA GeForce RTX 3060 Laptop GPU


In [5]:
## Print all the cases available in the dataset

In [6]:
import os

# Dataset root: <parent of the notebook's working directory>/VALDO_Dataset/Task2.
# os.path.join replaces the old 'VALDO_Dataset\Task2' literal, whose "\T" is an
# invalid escape sequence and whose backslash only works on Windows.
testing_label_relative = os.path.join('VALDO_Dataset', 'Task2')
current_directory = os.getcwd()

# NOTE: "../" is ONE level up despite the variable name (name kept for compatibility).
two_directories_up = os.path.abspath(os.path.join(current_directory, "../"))

# Combine the parent directory with the relative dataset path
testing_label_absolute = os.path.join(two_directories_up, testing_label_relative)

# os.listdir order is arbitrary; sort so case order (and therefore dataset indices
# used later for plotting) is reproducible across machines.
folders = sorted(
    item for item in os.listdir(testing_label_absolute)
    if os.path.isdir(os.path.join(testing_label_absolute, item))
)

# Group case folders by cohort using the "sub-<cohort digit>" prefix.
# Folders that match no cohort prefix are ignored instead of being lumped into cohort3.
cases = {"cohort1": [], "cohort2": [], "cohort3": []}
for folder in folders:
    if folder.startswith("sub-1"):
        cases["cohort1"].append(folder)
    elif folder.startswith("sub-2"):
        cases["cohort2"].append(folder)
    elif folder.startswith("sub-3"):
        cases["cohort3"].append(folder)

cases

{'cohort1': ['sub-101',
  'sub-102',
  'sub-103',
  'sub-104',
  'sub-105',
  'sub-106',
  'sub-107',
  'sub-108',
  'sub-109',
  'sub-110',
  'sub-111'],
 'cohort2': ['sub-201',
  'sub-202',
  'sub-203',
  'sub-204',
  'sub-205',
  'sub-206',
  'sub-207',
  'sub-208',
  'sub-209',
  'sub-210',
  'sub-211',
  'sub-212',
  'sub-213',
  'sub-214',
  'sub-215',
  'sub-216',
  'sub-217',
  'sub-218',
  'sub-219',
  'sub-220',
  'sub-221',
  'sub-222',
  'sub-223',
  'sub-224',
  'sub-225',
  'sub-226',
  'sub-227',
  'sub-228',
  'sub-229',
  'sub-230',
  'sub-231',
  'sub-232',
  'sub-233',
  'sub-234'],
 'cohort3': ['sub-301',
  'sub-302',
  'sub-303',
  'sub-304',
  'sub-305',
  'sub-306',
  'sub-307',
  'sub-308',
  'sub-309',
  'sub-310',
  'sub-311',
  'sub-312',
  'sub-313',
  'sub-314',
  'sub-315',
  'sub-316',
  'sub-317',
  'sub-318',
  'sub-319',
  'sub-320',
  'sub-321',
  'sub-322',
  'sub-323',
  'sub-324',
  'sub-325',
  'sub-326',
  'sub-327']}

# Build the image and label file paths for every case, per cohort

In [7]:
import os


def build_case_paths(case_ids, root):
    """Return ``(label_paths, image_paths)`` for the given case folders under ``root``.

    Each case folder ``<root>/<case>`` is expected to contain
    ``<case>_space-T2S_CMB.nii.gz`` (the annotation volume) and
    ``<case>_space-T2S_desc-masked_T2S.nii.gz`` (the image volume).
    Both returned lists are index-aligned with ``case_ids``.

    os.path.join is used instead of hard-coded "\\" separators so the paths are
    also valid outside Windows.
    """
    label_paths = [os.path.join(root, case, f"{case}_space-T2S_CMB.nii.gz") for case in case_ids]
    image_paths = [os.path.join(root, case, f"{case}_space-T2S_desc-masked_T2S.nii.gz") for case in case_ids]
    return label_paths, image_paths


cohort1_labels, cohort1_ids = build_case_paths(cases["cohort1"], testing_label_absolute)
cohort2_labels, cohort2_ids = build_case_paths(cases["cohort2"], testing_label_absolute)
cohort3_labels, cohort3_ids = build_case_paths(cases["cohort3"], testing_label_absolute)

# Images and annotations stay index-aligned: all_ids[k] pairs with all_labels[k]
all_labels = cohort1_labels + cohort2_labels + cohort3_labels
all_ids = cohort1_ids + cohort2_ids + cohort3_ids

print(all_labels[0])
print(all_ids[0])

c:\Users\nigel\Documents\VALDO_Dataset\Task2\sub-101\sub-101_space-T2S_CMB.nii.gz
c:\Users\nigel\Documents\VALDO_Dataset\Task2\sub-101\sub-101_space-T2S_desc-masked_T2S.nii.gz


In [8]:
# Collate function for batching volumes, plus a distance helper

In [9]:
def collate_fn(batch):
    """Flatten a batch of per-volume samples into per-slice lists.

    Args:
        batch: iterable of ``(slices, targets, img_path)`` triples, one per
            volume, as returned by ``VALDODataset.__getitem__``.

    Returns:
        ``(slices, targets, img_paths)`` where ``slices`` and ``targets`` hold
        every slice/target of every volume in order, and ``img_paths`` holds
        one path per volume.
    """
    all_slices, all_targets, all_paths = [], [], []
    for vol_slices, vol_targets, vol_path in batch:
        all_slices += vol_slices
        all_targets += vol_targets
        all_paths.append(vol_path)

    # Re-stacking the rows of each slice rebuilds it as a fresh tensor copy.
    stacked_slices = list(map(lambda one_slice: torch.stack(tuple(one_slice)), all_slices))

    return stacked_slices, all_targets, all_paths
    
def euclid_dist(t1, t2):
    """Row-wise Euclidean distance between two arrays of points.

    Subtracts ``t2`` from ``t1`` (broadcasting as numpy does), squares, sums
    along axis 1 and takes the square root, giving one distance per row.
    """
    squared = (t1 - t2) ** 2
    return np.sqrt(squared.sum(axis=1))

In [10]:
# Custom Dataset for VALDO

In [11]:
class VALDODataset(Dataset):
    """Slice-wise view of VALDO 3D volumes for 2D box detection.

    Each item is one whole volume. The image volume is split along its third
    axis into 2D slices, and each slice is paired with bounding boxes derived
    from the matching slice of the (binarised) annotation volume.

    Args:
        img_paths: paths to the image volumes (loaded with nibabel).
        ann_paths: paths to the annotation volumes, index-aligned with img_paths.
        transform: albumentations-style callable invoked as
            ``transform(image=..., bboxes=..., labels=...)`` and returning a
            dict with ``'image'``, ``'bboxes'`` and ``'labels'``. If None,
            slices are returned as uint8 CHW tensors and boxes are unchanged.
        box_size: side length in pixels of the fixed-size box anchored at the
            top-left corner of each connected component of the mask.
    """

    def __init__(self, img_paths, ann_paths, transform=None, box_size=20):
        self.img_paths = img_paths
        self.ann_paths = ann_paths
        self.transform = transform
        self.box_size = box_size

        assert len(self.img_paths) == len(self.ann_paths), "Mismatch between number of images and annotations"

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(slices, targets, img_path)`` for volume ``idx``.

        ``slices[i]`` is slice i as a 3-channel image (after ``transform``, if
        any) and ``targets[i]`` is ``{'boxes': float32 (k, 4) pascal_voc
        tensor, 'labels': int64 (k,) tensor}``. Errors are reported with the
        index and re-raised.
        """
        try:
            img_path = self.img_paths[idx]
            ann_path = self.ann_paths[idx]

            # Load 3D image scaled to uint8, and 3D annotation as a binary mask
            img = self._to_uint8(nib.load(img_path).get_fdata())
            ann = (nib.load(ann_path).get_fdata() > 0).astype(np.uint8)

            slices = []
            targets = []

            for i in range(img.shape[2]):
                img_slice = cv2.merge([img[:, :, i]] * 3)  # single channel -> three channels
                boxes = self.extract_bounding_boxes(ann[:, :, i])
                labels = [1] * len(boxes)  # single foreground class

                if self.transform is not None:
                    augmented = self.transform(image=img_slice, bboxes=boxes, labels=labels)
                    img_slice = augmented['image']
                    boxes = augmented['bboxes']
                    labels = augmented['labels']
                else:
                    # No transform: HWC uint8 array -> CHW uint8 tensor
                    img_slice = torch.from_numpy(np.ascontiguousarray(img_slice.transpose(2, 0, 1)))

                targets.append({
                    # reshape keeps an empty slice at shape (0, 4) instead of (0,)
                    'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
                    'labels': torch.tensor(labels, dtype=torch.int64),
                })
                slices.append(img_slice)

            return slices, targets, img_path

        except Exception as e:
            print(f"Error processing index {idx}: {e}")
            raise

    @staticmethod
    def _to_uint8(volume):
        """Scale ``volume`` by its maximum into 0-255 and cast to uint8.

        Negative values are clipped to 0 (instead of wrapping on the cast), and
        a volume whose maximum is <= 0 maps to all zeros rather than NaNs from
        a division by zero.
        """
        peak = np.max(volume)
        if peak <= 0:
            return np.zeros_like(volume, dtype=np.uint8)
        return np.clip(volume / peak * 255, 0, 255).astype(np.uint8)

    @staticmethod
    def _fixed_size_box(x, y, width, height, box_size):
        """Return ``[x_min, y_min, x_max, y_max]`` of a ``box_size`` square at (x, y), clipped to the slice."""
        return [x, y, min(x + box_size, width), min(y + box_size, height)]

    def extract_bounding_boxes(self, mask):
        """Return one fixed-size pascal_voc box per external contour of ``mask``.

        The box is anchored at the contour's bounding-rect top-left corner and
        has side ``self.box_size``; it is clipped to the mask so boxes near the
        border stay valid for albumentations' pascal_voc checks.
        """
        height, width = mask.shape[:2]
        boxes = []
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            x, y, _w, _h = cv2.boundingRect(cnt)
            # The true extent would be [x, y, x + _w, y + _h]; a fixed size is used instead.
            boxes.append(self._fixed_size_box(x, y, width, height, self.box_size))
        return boxes

# Transform

In [12]:
# Resize each slice to the detector's 256x256 input and convert it to a CHW tensor.
# Boxes are absolute [x_min, y_min, x_max, y_max] (pascal_voc) and are resized alongside.
bbox_settings = A.BboxParams(
    format='pascal_voc',
    min_area=0,
    min_visibility=0,
    label_fields=['labels'],
)
transform = Compose(
    [A.Resize(height=256, width=256, p=1.0), ToTensorV2(p=1.0)],
    p=1.0,
    bbox_params=bbox_settings,
)


def _zip_batch(batch):
    """Transpose a list of samples into a tuple of per-field tuples."""
    return tuple(zip(*batch))


dataset = VALDODataset(img_paths=all_ids, ann_paths=all_labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=_zip_batch)

# Setting the Seed

In [13]:
SEED = 42  # any constant


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.

    Note: setting PYTHONHASHSEED here only affects child processes; the
    running interpreter's hash seed is fixed at startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not just the current one
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose algorithms by timing them, which is
    # non-deterministic and defeats `deterministic = True` above.
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

# Record the name of the GPU in use

In [14]:
# Name of GPU 0, or None when CUDA is unavailable (get_device_name raises on CPU-only machines)
device_num = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None

# Clear existing cached GPU memory

In [15]:
# Release cached, unused GPU memory back to the driver (no-op without CUDA)
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [16]:

# Loading the model
# Fall back to CPU so the cell still runs on machines without a GPU.
test_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(test_device)

SCORE_THRESHOLD = 0.1  # detections scoring at or below this are discarded

config = get_efficientdet_config('tf_efficientdet_d7')
config.update({'num_classes': 1})
config.update({'image_size': (256, 256)})  # must match the Resize in `transform`
config.update({"url": "https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d7-f05bf714.pth"})
print(config)

# Only the backbone is pretrained; the class head is freshly initialised, so this
# measures the detector before any training.
# NOTE(review): config "url" is presumably only consumed by effdet's pretrained-model
# loading helpers, not by this constructor — confirm if full detector weights were intended.
net = EfficientDet(config, pretrained_backbone=True)
net.class_net = HeadNet(config, num_outputs=config.num_classes)

gc.collect()

EffDet = DetBenchPredict(net).to(test_device)
EffDet.eval()
device = test_device  # kept for any later cell that references `device`

pre_test_dataset = VALDODataset(
    img_paths=all_ids, ann_paths=all_labels, transform=transform
)

pre_test_dataloader_axial = DataLoader(
    pre_test_dataset,
    batch_size=1,  # one volume per batch; all its slices go through the model together
    drop_last=False,
    num_workers=0,
    collate_fn=collate_fn,
    shuffle=False,  # keep prediction_list aligned with all_ids
)

prediction_list = []
for j, (images_axial, targets_axial, image_ids_axial) in enumerate(pre_test_dataloader_axial):
    # NOTE(review): `transform` has no Normalize step, so pixels reach the model as raw
    # 0-255 values — confirm this matches what the backbone expects.
    images_axial = torch.stack(images_axial).to(test_device).float()
    with torch.no_grad():
        det = EffDet(images_axial)

    preds = []
    # One row per detection: [x_min, y_min, x_max, y_max, score, class]
    for slice_det in det.detach().cpu().numpy():
        keep = slice_det[:, 4] > SCORE_THRESHOLD
        preds.append({
            # Kept as [x_min, y_min, x_max, y_max], the same format as the ground-truth targets
            'boxes': slice_det[keep, :4],
            'scores': slice_det[keep, 4],
        })
    prediction_list.append({"predictions": preds, "id": image_ids_axial})
    print(f'Batch {j} prediction done')

    # Drop references before clearing the cache so the memory can actually be released
    del det, images_axial
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))

# Concise summary instead of dumping every box array
num_kept = sum(len(p['boxes']) for entry in prediction_list for p in entry['predictions'])
print(f"{len(prediction_list)} volumes predicted, {num_kept} detections above {SCORE_THRESHOLD}")

cuda
{'name': 'tf_efficientdet_d7', 'backbone_name': 'tf_efficientnet_b6', 'backbone_args': {'drop_path_rate': 0.2}, 'backbone_indices': None, 'image_size': [256, 256], 'num_classes': 1, 'min_level': 3, 'max_level': 7, 'num_levels': 5, 'num_scales': 3, 'aspect_ratios': [[1.0, 1.0], [1.4, 0.7], [0.7, 1.4]], 'anchor_scale': 5.0, 'pad_type': 'same', 'act_type': 'swish', 'norm_layer': None, 'norm_kwargs': {'eps': 0.001, 'momentum': 0.01}, 'box_class_repeats': 5, 'fpn_cell_repeats': 8, 'fpn_channels': 384, 'separable_conv': True, 'apply_resample_bn': True, 'conv_bn_relu_pattern': False, 'downsample_type': 'max', 'upsample_type': 'nearest', 'redundant_bias': True, 'head_bn_level_first': False, 'head_act_type': None, 'fpn_name': 'bifpn_sum', 'fpn_config': None, 'fpn_drop_path_rate': 0.0, 'alpha': 0.25, 'gamma': 1.5, 'label_smoothing': 0.0, 'legacy_focal': False, 'jit_loss': False, 'delta': 0.1, 'box_loss_weight': 50.0, 'soft_nms': False, 'max_detection_points': 5000, 'max_det_per_image': 100,

Unexpected keys (bn2.bias, bn2.num_batches_tracked, bn2.running_mean, bn2.running_var, bn2.weight, classifier.bias, classifier.weight, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


Batch 0 prediction done
Batch 1 prediction done
Batch 2 prediction done
Batch 3 prediction done
Batch 4 prediction done
Batch 5 prediction done
Batch 6 prediction done
Batch 7 prediction done
Batch 8 prediction done
Batch 9 prediction done
Batch 10 prediction done
Batch 11 prediction done
Batch 12 prediction done
Batch 13 prediction done
Batch 14 prediction done
Batch 15 prediction done
Batch 16 prediction done
Batch 17 prediction done
Batch 18 prediction done
Batch 19 prediction done
Batch 20 prediction done
Batch 21 prediction done
Batch 22 prediction done
Batch 23 prediction done
Batch 24 prediction done
Batch 25 prediction done
Batch 26 prediction done
Batch 27 prediction done
Batch 28 prediction done
Batch 29 prediction done
Batch 30 prediction done
Batch 31 prediction done
Batch 32 prediction done
Batch 33 prediction done
Batch 34 prediction done
Batch 35 prediction done
Batch 36 prediction done
Batch 37 prediction done
Batch 38 prediction done
Batch 39 prediction done
Batch 40 p

# Put all the predicted boxes in a list

In [None]:
# Boxes for every slice of the first volume (the prediction loader is unshuffled)
predicted_boxes = [slice_pred['boxes'] for slice_pred in prediction_list[0]['predictions']]
predicted_boxes

# Plot each slice with the predicted boxes in green and the ground-truth boxes in blue

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import numpy

# Shift applied to every drawn box; value carried over from the original plot.
# NOTE(review): the reason for 8.5 px is not evident from the code — confirm it is still needed.
BOX_DRAW_OFFSET = 8.5
NUM_COLS = 5

# Ground truth for volume 0, the same volume whose predictions are in `predicted_boxes`
slices, targets, img_path = dataset[0]

num_slices = len(slices)
num_rows = max(1, (num_slices + NUM_COLS - 1) // NUM_COLS)  # ceil division, at least one row

# squeeze=False keeps `axes` 2-D even when there is only a single row
fig, axes = plt.subplots(num_rows, NUM_COLS, figsize=(15, num_rows * 3), squeeze=False)


def _draw_boxes(ax, boxes, color):
    """Draw [x_min, y_min, x_max, y_max] boxes on ``ax`` as unfilled rectangles."""
    for x_min, y_min, x_max, y_max in numpy.asarray(boxes, dtype=float).reshape(-1, 4):
        ax.add_patch(plt.Rectangle((x_min - BOX_DRAW_OFFSET, y_min - BOX_DRAW_OFFSET),
                                   x_max - x_min, y_max - y_min,
                                   linewidth=2, edgecolor=color, facecolor='none'))


for idx, (slice_base, target) in enumerate(zip(slices, targets)):
    ax = axes[idx // NUM_COLS, idx % NUM_COLS]

    # Average the (identical) channels back to one grayscale image
    heatmap_data = torch.mean(slice_base.float(), dim=0).numpy()
    sns.heatmap(heatmap_data, ax=ax)
    ax.set_title(f"slice {idx}")

    _draw_boxes(ax, predicted_boxes[idx], 'g')  # predictions
    _draw_boxes(ax, target['boxes'], 'b')       # ground truth

# Hide the unused panels in the last row
for ax in axes.flat[num_slices:]:
    ax.axis('off')

plt.tight_layout()
plt.show()