In [None]:
# Optional one-time environment setup: uncomment the lines you need, run once, then restart the kernel.
# NOTE(review): versions are unpinned. The recorded run used PyTorch 2.3.0+cu118 (see the GPU check output),
# and the code mixes effdet APIs (DetBenchEval / HeadNet(norm_kwargs=...)) — pin the effdet version actually used.
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# %pip install ensemble_boxes
# %pip install albumentations
# %pip install effdet
# %pip install natsort

## Imports

In [3]:
import sys
from ensemble_boxes import *
import torch
import numpy as np
import pandas as pd
from glob import glob
from torch.utils.data import Dataset,DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import gc
from matplotlib import pyplot as plt
import torch.nn as nn
import os
from datetime import datetime
import time
import random
import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
import natsort as ns
import re
from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain
from effdet.efficientdet import HeadNet

## Checking for GPU

In [3]:
# Report the PyTorch build and whether a CUDA GPU is visible to it.
cuda_available = torch.cuda.is_available()
print("PyTorch version:", torch.__version__)
print("CUDA available:", cuda_available)
if cuda_available:
    # Only the first device's name is shown.
    print("Number of CUDA devices:", torch.cuda.device_count())
    print("CUDA device name:", torch.cuda.get_device_name(0))

PyTorch version: 2.3.0+cu118
CUDA available: True
Number of CUDA devices: 1
CUDA device name: NVIDIA GeForce GTX 1650


## Setting the seed

In [None]:
SEED = 42  # any constant; change it to get a different but still reproducible run


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and every CUDA device), exports ``PYTHONHASHSEED`` and configures
    cuDNN for deterministic behaviour.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes started after this call; the running
    # interpreter's hash seed is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly non-deterministic
    # algorithms per run, which defeats `deterministic` above.
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

## Select the compute device (a torch device string such as `'cuda:0'` or `'cpu'` — not the GPU's model name)

In [None]:
# torch.device() and torch.load(map_location=...) expect a device string such as
# 'cuda', 'cuda:0' or 'cpu' — not the model name printed by get_device_name().
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device_num = 'cuda:0' if torch.cuda.is_available() else 'cpu'

## Loading a trained EfficientDet checkpoint for inference

In [None]:
def load_net(checkpoint_path, model_name='tf_efficientdet_d3'):
    """Build a single-class EfficientDet, load trained weights and wrap it for inference.

    Args:
        checkpoint_path: file containing a dict with a 'model_state_dict' entry.
            Loaded with torch.load, which unpickles — only load trusted files.
        model_name: effdet config name; must match the architecture the checkpoint
            was trained with. NOTE(review): the default stays D3 for backward
            compatibility, but get_net() trains 'tf_efficientdet_d7' — pass that
            name for checkpoints produced by run_training().

    Returns:
        The inference bench module in eval mode, moved to `device_num`.
    """
    config = get_efficientdet_config(model_name)
    net = EfficientDet(config, pretrained_backbone=False)

    # Single class (CMB): swap in a 1-class head before loading so the
    # checkpoint's head weights fit.
    config.num_classes = 1
    config.image_size = 512
    # NOTE(review): norm_kwargs belongs to the older effdet HeadNet signature, and
    # get_net() builds its head with default batch-norm settings — confirm both
    # match the installed effdet version.
    net.class_net = HeadNet(config, num_outputs=config.num_classes, norm_kwargs=dict(eps=.001, momentum=.01))

    device = torch.device(device_num)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])

    del checkpoint
    gc.collect()

    # DetBenchEval is not part of the notebook's top-level imports, so import it here.
    try:
        from effdet import DetBenchEval  # older effdet API: DetBenchEval(model, config)
    except ImportError:
        # Newer effdet ships DetBenchPredict instead, which takes only the model
        # (NOTE(review): verify against the installed effdet version).
        from effdet import DetBenchPredict
        net = DetBenchPredict(net)
    else:
        net = DetBenchEval(net, config)

    net.eval()
    return net.to(device)

## Loading the axial CMB markings from the Excel label files

In [None]:
def get_axial_marking(label_path, x_offset=44, y_offset=76, box_size=20):
    """Load the axial CMB annotations of every patient in `label_path` into one table.

    Each non-hidden file in `label_path` is read as a header-less Excel sheet whose
    four columns are slice, x, y, class. Rows repeating an (x, y) position are the
    same CMB and only the first is kept; rows are then sorted by slice, and each CMB
    becomes one row image_id='<patient>_<slice>', x, y, w, h.

    Args:
        label_path: folder holding one '<patient>.xlsx' file per patient
            (with or without a trailing slash).
        x_offset, y_offset: subtracted from x / y to convert coordinates from the
            512x448 annotation frame to the 360x360 image frame.
        box_size: fixed box width and height assigned to every CMB.

    Returns:
        DataFrame with columns image_id, x, y, w, h (x, y, w, h as float).

    Raises:
        ValueError: if a non-empty sheet does not have exactly four columns.
    """
    columns = ['image_id', 'x', 'y', 'w', 'h']
    records = []

    # Build each path from its own file name so names and paths cannot drift apart.
    label_files = sorted(f for f in os.listdir(label_path) if not f.startswith('.'))  # skip hidden files
    for label_file in label_files:
        sheet = pd.read_excel(os.path.join(label_path, label_file), header=None)
        if sheet.empty:
            continue
        if sheet.shape[1] != 4:
            raise ValueError(
                f'{label_file}: expected 4 columns (slice, x, y, class), got {sheet.shape[1]}'
            )
        cmbs = sheet.copy()
        cmbs.columns = ['slice', 'x', 'y', 'class']
        cmbs = cmbs.drop_duplicates(['x', 'y'], keep='first')  # same (x, y) = same CMB
        # Stable sort so CMBs on the same slice keep their sheet order.
        cmbs = cmbs.sort_values(by='slice', kind='mergesort', ignore_index=True)

        patient_id = label_file.replace('.xlsx', '')
        for slice_num, x, y in zip(cmbs['slice'], cmbs['x'], cmbs['y']):
            records.append(
                (f'{patient_id}_{slice_num}', x - x_offset, y - y_offset, box_size, box_size)
            )

    # Build the frame once (no per-row concat) with numeric box columns.
    return pd.DataFrame(records, columns=columns).astype({c: float for c in columns[1:]})

## Build the marking table for every test-set image (slices without a CMB get a placeholder row)

In [None]:
def make_whole_marking_axial(label_path, IMAGE_ROOT_PATH, marking_test):
    """Build a marking table covering every slice image of every labelled patient.

    Slices with ground-truth CMBs get one row per CMB copied from `marking_test`.
    Every other slice gets one placeholder row with x = y = w = h = 1; downstream
    code identifies placeholders by that sentinel.

    Which slice a CMB belongs to is decided by image_id, never by comparing
    coordinates to the sentinel, so a real CMB at x == 1 is no longer
    overwritten. CMB rows of one slice keep their order in `marking_test`; the
    previous insertion logic reversed it.

    Args:
        label_path: folder with one '<patient>.xlsx' per patient (hidden files ignored).
        IMAGE_ROOT_PATH: folder containing '<patient>_<slice>.png' images.
        marking_test: DataFrame with columns image_id, x, y, w, h (one row per CMB).

    Returns:
        DataFrame with columns image_id, x, y, w, h, ordered by patient and then
        by natural slice order.

    Raises:
        ValueError: if a CMB's image_id has no matching PNG under IMAGE_ROOT_PATH.
    """
    columns = ['image_id', 'x', 'y', 'w', 'h']
    label_files = sorted(f for f in os.listdir(label_path) if not f.startswith('.'))  # skip hidden files

    # Every slice image of every patient, in natural (numeric-aware) order.
    all_image_ids = []
    for label_file in label_files:
        patient_name = label_file.replace('.xlsx', '')
        pattern = os.path.join(IMAGE_ROOT_PATH, f'{patient_name}_*.png')
        slice_ids = [os.path.splitext(os.path.basename(p))[0] for p in glob(pattern)]
        all_image_ids.extend(ns.natsorted(slice_ids))

    # Group ground-truth CMB rows by image, preserving their order.
    cmbs_by_image = {}
    cmb_rows = marking_test[columns].itertuples(index=False, name=None) if len(marking_test) else ()
    for row in cmb_rows:
        cmbs_by_image.setdefault(row[0], []).append(list(row))

    missing = sorted(set(cmbs_by_image) - set(all_image_ids))
    if missing:
        raise ValueError(f'No PNG found under {IMAGE_ROOT_PATH!r} for labelled slices: {missing}')

    records = []
    for image_id in all_image_ids:
        # pop(): if an id is listed twice, only its first occurrence receives the CMBs.
        rows = cmbs_by_image.pop(image_id, None)
        if rows is None:
            records.append([image_id, 1, 1, 1, 1])  # placeholder: slice without CMB
        else:
            records.extend(rows)

    # Build once instead of DataFrame.append (removed in pandas 2.0) / repeated concat.
    return pd.DataFrame(records, columns=columns)

## Validation transforms: resize images to 512×512 and declare the bounding-box format (Pascal VOC)

In [None]:
def get_valid_transforms_axial():
    """Validation/test pipeline: resize to 512x512 and convert to a torch tensor.

    Boxes are passed in pascal_voc format (x_min, y_min, x_max, y_max) and their
    class ids travel separately under the 'labels' key.
    """
    bbox_params = A.BboxParams(
        format='pascal_voc',
        label_fields=['labels'],
        min_visibility=0,
        min_area=0,
    )
    steps = [
        A.Resize(height=512, width=512, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, bbox_params=bbox_params, p=1.0)

In [None]:
class DatasetRetriever_cmbs(Dataset):
    """Map-style dataset of axial slices and their CMB bounding boxes.

    Each item is ``(image, target, image_id)``. ``target`` holds:
    ``'boxes'``; ``'labels'``, all 1 since there is a single CMB class; and
    ``'image_id'``, the integer dataset index as a 1-element tensor.

    Args:
        marking: DataFrame with columns image_id, x, y, w, h, where (x, y) is
            the box centre (see load_image_and_boxes).
        image_ids: sequence of image ids; one dataset item per entry.
        image_root_path: folder containing ``<image_id>.png`` files.
        transforms: optional albumentations-style callable taking ``image=``,
            ``bboxes=`` (pascal_voc) and ``labels=`` keyword arguments.
        test: stored as ``self.test``; not read inside this class.
    """

    # How many times a random transform is re-drawn when it removes every box.
    transform_retries = 10

    def __init__(self, marking, image_ids, image_root_path, transforms=None, test=False):
        super().__init__()

        self.image_ids = image_ids
        self.marking = marking
        self.transforms = transforms
        self.test = test
        self.image_root_path = image_root_path

    def __getitem__(self, index: int):
        image_id = self.image_ids[index]
        image, boxes = self.load_image_and_boxes(index)

        # There is only one class (CMB), so every box gets label 1.
        labels = torch.ones((boxes.shape[0],), dtype=torch.int64)
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([index]),
        }

        if self.transforms:
            # Random augmentations can push every box out of frame: retry a few
            # times and otherwise keep the last (box-free) draw, so the returned
            # types are the same in every case.
            for _ in range(self.transform_retries):
                sample = self.transforms(image=image, bboxes=target['boxes'], labels=labels)
                if len(sample['bboxes']) > 0:
                    break
            image = sample['image']
            kept = np.asarray(
                [tuple(box)[:4] for box in sample['bboxes']], dtype=np.float32
            ).reshape(-1, 4)
            # x_min, y_min, x_max, y_max -> y_min, x_min, y_max, x_max.
            # NOTE(review): yxyx is the order the (older) effdet bench expects — confirm
            # for the installed version.
            target['boxes'] = torch.from_numpy(kept[:, [1, 0, 3, 2]])
            # Transforms may drop some boxes; keep labels the same length as boxes.
            target['labels'] = torch.ones((target['boxes'].shape[0],), dtype=torch.int64)
        else:
            # NOTE(review): without transforms the image stays HWC and boxes stay
            # xyxy, unlike the transformed path — confirm this path is only used
            # for inspection.
            image = torch.tensor(image)
            target['boxes'] = torch.tensor(boxes)

        return image, target, image_id

    def __len__(self) -> int:
        return len(self.image_ids)

    def load_image_and_boxes(self, index):
        """Load one image as float32 RGB in [0, 1] and its boxes as pascal_voc corners.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``<image_root_path>/<image_id>.png``.
        """
        image_id = self.image_ids[index]
        path = f'{self.image_root_path}/{image_id}.png'
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # UNCHANGED keeps 16-bit depth
        if image is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f'Could not read image: {path}')
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)  # single-channel PNG
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        # Normalise by the full range of the stored integer type (65535 for 16-bit).
        if np.issubdtype(image.dtype, np.integer):
            max_value = np.iinfo(image.dtype).max
        else:
            max_value = 65535.0
        image = image.astype(np.float32) / np.float32(max_value)

        records = self.marking[self.marking['image_id'] == image_id]
        # Marking holds box centre (x, y) and size (w, h); convert to corners.
        # float32 copy: avoids object-dtype arrays and read-only views.
        boxes = records[['x', 'y', 'w', 'h']].to_numpy(dtype=np.float32, copy=True)
        boxes[:, 0] -= boxes[:, 2] / 2  # x_min
        boxes[:, 1] -= boxes[:, 3] / 2  # y_min
        boxes[:, 2] += boxes[:, 0]      # x_max = x_min + w
        boxes[:, 3] += boxes[:, 1]      # y_max = y_min + h
        return image, boxes

In [None]:
def collate_fn(batch):
    """Transpose a list of samples into one tuple per field (e.g. images, targets, ids).

    Keeps variable-size targets as tuples instead of stacking them.
    """
    fields = zip(*batch)
    return tuple(fields)


def euclid_dist(t1, t2):
    """Row-wise Euclidean distance between two equally shaped 2-D arrays."""
    squared = (t1 - t2) ** 2
    return np.sqrt(squared.sum(axis=1))


## Replaces occurrences of a substring from the right side of the original string with a new substring, up to a specified count.

In [None]:
def replaceRight(original, old, new, count_right):
    """Replace the right-most `count_right` occurrences of `old` in `original` with `new`.

    Occurrences are located in `original` only, so text introduced by `new` is
    never matched again. If `count_right` exceeds the number of occurrences,
    all of them are replaced.

    Args:
        original: string to edit.
        old: non-empty substring to replace.
        new: replacement text.
        count_right: maximum number of replacements, counted from the right.

    Returns:
        The edited string. `original` is returned unchanged when `count_right <= 0`
        (previously a negative count looped forever) or when `old` does not occur.

    Raises:
        ValueError: if `old` is empty (raised by str.rsplit).
    """
    if count_right <= 0:
        return original
    return new.join(original.rsplit(old, count_right))

# Training

## File paths for training and validation labels

In [None]:
# Folders with one '<patient>.xlsx' label file per patient.
# NOTE(review): these are absolute ('/data/...') while the test paths in the Testing
# section are relative ('data/...') — confirm which layout is intended.
train_label_path, val_label_path = '/data/labels/train/', '/data/labels/validation/'

## Loading the label markings for each split

In [None]:
# One marking table (image_id, x, y, w, h) per split, read from that split's label folder.
marking_train, marking_val = (get_axial_marking(path) for path in (train_label_path, val_label_path))

## Setting up the dataset that will be used for the training

In [None]:
# DatasetRetriever_cmbs requires image_root_path; omitting it raised a TypeError.
# NOTE(review): assumes training/validation PNGs live in the same axial image folder
# as the test images — confirm and adjust.
TRAIN_IMAGE_ROOT_PATH = 'data/images/axial/'

train_dataset_aug = DatasetRetriever_cmbs(
    # NOTE(review): marking_train has one row per CMB, so a slice with several CMBs is
    # sampled several times per epoch; use marking_train['image_id'].unique() if unintended.
    image_ids=np.array(marking_train['image_id']),
    marking=marking_train,
    image_root_path=TRAIN_IMAGE_ROOT_PATH,
    # TODO: get_train_transforms() is not defined in the cells shown here; define the
    # training augmentation pipeline before running this cell.
    transforms=get_train_transforms(),
    test=False,
)

validation_dataset = DatasetRetriever_cmbs(
    image_ids=np.array(marking_val['image_id']),
    marking=marking_val,
    image_root_path=TRAIN_IMAGE_ROOT_PATH,
    # The validation pipeline defined in this notebook is get_valid_transforms_axial();
    # get_valid_transforms() does not exist.
    transforms=get_valid_transforms_axial(),
    test=True,
)

## Model Configuration

In [None]:
class TrainGlobalConfig:
    """Training hyper-parameters; run_training() passes this class to Fitter as `config`."""

    # NOTE(review): 20 workers is a lot for batch_size=1 on a single GPU — tune to the machine.
    num_workers = 20
    batch_size = 1
    n_epochs = 10
    lr = 0.0001

    folder = 'Model_Save(Axial)_D7'  # presumably the output folder Fitter saves into

    # -------------------
    verbose = True
    verbose_step = 1
    # -------------------

    # When scheduler.step() is called (flags presumably read by Fitter):
    step_scheduler = False  # after every optimizer.step()
    epoch_scheduler = False  # once per epoch
    validation_scheduler = True  # after validation, with the validation loss (needed by ReduceLROnPlateau)

    # Alternatives tried earlier: OneCycleLR, CosineAnnealingWarmRestarts, ExponentialLR.
    SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
    scheduler_params = dict(
        mode='min',
        factor=0.1,
        patience=1,
        # `verbose` is omitted: PyTorch >= 2.2 deprecates it for LR schedulers and
        # warns whenever it is passed (False was already the default behaviour).
        threshold=0.0001,
        threshold_mode='abs',
        cooldown=0,
        min_lr=0,
        eps=1e-08,
    )

## Training process of the model

In [None]:
def run_training():
    """Train on `train_dataset_aug` and validate on `validation_dataset`.

    Loader settings come from TrainGlobalConfig. The training loop itself is
    delegated to Fitter, which is defined elsewhere.

    Returns:
        (best_val_loss, summary_loss_over_itr_train, summary_loss_over_itr_val)
        as produced by Fitter.fit.
    """
    net = get_net()
    device = torch.device(device_num)
    net.to(device)

    cfg = TrainGlobalConfig
    # Shared loader settings; collate_fn keeps variable-size targets as tuples.
    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=False,
        collate_fn=collate_fn,
    )
    # drop_last=False: the final, possibly smaller, batch is still used.
    train_loader = DataLoader(
        train_dataset_aug,
        sampler=RandomSampler(train_dataset_aug),
        drop_last=False,
        **loader_kwargs,
    )
    # Validation visits every sample once, in a fixed order.
    val_loader = DataLoader(
        validation_dataset,
        shuffle=False,
        sampler=SequentialSampler(validation_dataset),
        **loader_kwargs,
    )

    fitter = Fitter(model=net, device=device, config=cfg)
    best_val_loss, loss_history_train, loss_history_val = fitter.fit(train_loader, val_loader)
    return best_val_loss, loss_history_train, loss_history_val

## Building the EfficientDet-D7 training model (`get_net`)

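### To get the pretrained weights:

1. Go to: https://github.com/rwightman/efficientdet-pytorch/releases
2. Open the release that lists the EfficientDet weights and scroll down to its **Assets** section.
3. Find "efficientdet_d7-f05bf714.pth" (the `f05bf714` hash suffix identifies the D7 file; the asset may be listed with a `tf_` prefix).
4. Download it into the notebook's working directory — `get_net` loads the checkpoint from a relative path.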

In [5]:
def get_net(checkpoint_path='efficientdet_d7-f05bf714.pth'):
    """Build the single-class EfficientDet-D7 model wrapped for training.

    Args:
        checkpoint_path: pretrained D7 weights saved as a plain state_dict. The
            default matches the file named in the download instructions; the old
            hard-coded 'iefficientdet_d7-f05bf714.pth' had a stray leading 'i'.

    Returns:
        DetBenchTrain wrapping the network. The model is still on the CPU;
        run_training() moves it to the GPU.
    """
    config = get_efficientdet_config('tf_efficientdet_d7')
    net = EfficientDet(config, pretrained_backbone=False)
    # Load on the CPU regardless of where the weights were saved.
    # torch.load unpickles, so only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint)
    del checkpoint

    # Replace the pretrained classification head with a fresh single-class (CMB) head.
    config.num_classes = 1
    # NOTE(review): overrides the D7 config's default input size with 512 (the old
    # comment said "D0") — confirm this is intended.
    config.image_size = 512

    net.class_net = HeadNet(config, num_outputs=config.num_classes)  # default batch-norm settings

    return DetBenchTrain(net, config)

# best_val_loss, summary_loss_over_itr_train, summary_loss_over_itr_val = run_training()

# Testing

## File Paths

In [None]:
# Test-split inputs, relative to the notebook's working directory.
IMAGE_ROOT_PATH_AXIAL = 'data/images/axial/'
test_label_path = 'data/labels/test/'

In [None]:
def make_marking_cd_gt_axial(marking_test_axial, src_size=360, dst_size=512):
    """Convert per-slice CMB markings into (patient_id, slice, x, y) ground-truth rows.

    The slice number is the last run of digits in image_id. patient_id is image_id
    with its right-most '_<slice>' removed. x and y are rescaled from a
    `src_size` to a `dst_size` coordinate frame.

    Args:
        marking_test_axial: DataFrame with at least columns image_id, x, y.
            Any index is accepted (the old version required a 0..n-1 RangeIndex).
        src_size: side length of the coordinate frame of the input x, y.
        dst_size: side length of the output coordinate frame.

    Returns:
        DataFrame with columns patient_id, s, x, y in the same row order as the input.
    """
    records = []
    for image_id, x, y in marking_test_axial[['image_id', 'x', 'y']].itertuples(index=False, name=None):
        slice_num = int(re.findall(r'\d+', image_id)[-1])  # raw string: '\d' is an invalid escape
        # Drop the right-most '_<slice>' occurrence (same as replaceRight(..., '', 1)).
        patient_id = ''.join(image_id.rsplit('_' + str(slice_num), 1))
        records.append([patient_id, slice_num, dst_size * x / src_size, dst_size * y / src_size])

    # Build once instead of concatenating a one-row frame per CMB.
    return pd.DataFrame(records, columns=['patient_id', 's', 'x', 'y'])

# Ground-truth tables for the test split. marking_test_all_axial was previously used
# without being defined in any cell, so Restart & Run All failed with a NameError.
marking_test_axial = get_axial_marking(test_label_path)
marking_test_all_axial = make_whole_marking_axial(test_label_path, IMAGE_ROOT_PATH_AXIAL, marking_test_axial)
# Count the labelled CMBs directly. The old `y != 1` placeholder test would miss a
# real CMB whose y coordinate happens to be 1.
num_cmbs = len(marking_test_axial)