In [None]:
# Colab setup: install/upgrade the third-party packages this notebook needs.
# NOTE(review): versions are unpinned (-U), so a rerun may pick up different
# releases; consider pinning them (and %pip targets the kernel's environment).
!pip3 install pydicom -qU
!pip3 install wandb -qU
!pip3 install effdet -qU

import wandb
import os
import numpy as np
import torch
import random

In [None]:
# Mount Google Drive at /content/drive; later cells read the dataset and write
# checkpoints under drive/MyDrive/Licenta.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def set_seed(seed, deterministic=False):
    """Seed every random number generator used in this notebook.

    Args:
        seed: integer seed applied to Python's ``random`` module, NumPy's global
            legacy RNG and PyTorch (``torch.manual_seed`` seeds CPU and CUDA).
        deterministic: when True, also put cuDNN into deterministic mode and
            disable its autotuner (slower, but repeatable convolutions).
    """
    # Local import on purpose: a later cell runs ``from random import random``,
    # which rebinds the global name ``random`` to a function and would make
    # ``random.seed`` fail with AttributeError if this function is re-run.
    import random as _random

    # PYTHONHASHSEED only affects Python interpreters started after this call,
    # not the already-running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    _random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


set_seed(2020)

In [None]:
# Never hardcode a W&B API key in a notebook: wandb.login() reads the
# WANDB_API_KEY environment variable or ~/.netrc and otherwise prompts for it.
# NOTE(review): the key previously committed in this cell is exposed and should
# be revoked/rotated in the W&B account settings.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mioana-baciu4[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Hyper-parameter grid for the W&B sweep: with method "grid", every combination
# of the listed values is a separate run.
parameters: dict = {
    "learning_rate": {"values": [1e-3, 1e-4, 1e-5]},
    "number_of_epochs": {"values": [50, 75, 100]},
}

# Sweep definition: minimise the logged "loss" metric in project "licenta".
sweep_config: dict = {
    "project": "licenta",
    "metric": {"name": "loss", "goal": "minimize"},
    "method": "grid",  # grid/random
    "parameters": parameters,
}
sweep_id = wandb.sweep(sweep_config)

Create sweep with ID: smih9zqu
Sweep URL: https://wandb.ai/ioana-baciu4/licenta/sweeps/smih9zqu


In [None]:
!cp drive/MyDrive/Licenta/duke_dbt_data.py .
!cp drive/MyDrive/Licenta/sort_split_data.py .
!mkdir drive/MyDrive/Licenta/detection_checkpoints

mkdir: cannot create directory ‘drive/MyDrive/Licenta/detection_checkpoints’: File exists


In [None]:
import os
from collections import Counter
from random import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.utils.data
from torchvision import transforms
from torchvision.transforms import v2
from duke_dbt_data import dcmread_image, draw_box
from sort_split_data import find_image_path
from IPython.display import Image, display
from skimage.io import imread
import torchvision.transforms.functional as fn
import albumentations as A


def get_train_transforms():
    """Build the albumentations pipeline applied to every sample.

    All image augmentations are currently disabled, so the pipeline is empty.
    It still carries ``BboxParams`` so that boxes in Pascal VOC format
    (``[x_min, y_min, x_max, y_max]``) and their ``labels`` field pass through it.
    """
    box_handling = A.BboxParams(
        format='pascal_voc',
        min_area=0,
        min_visibility=0,
        label_fields=['labels'],
    )
    # Candidates that were tried: A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5),
    # brightness/contrast jitter.
    augmentations = []
    return A.Compose(augmentations, p=1.0, bbox_params=box_handling)


def get_nth_file_name(directory, n):
    """Return entry ``n`` of ``directory`` after sorting the names as strings.

    The order is plain lexicographic order (``"10.png"`` sorts before ``"2.png"``),
    not numeric order. Negative ``n`` indexes from the end.
    """
    entries = np.array(os.listdir(directory))
    entries.sort()
    return entries[n]


def extract_indices(batch, index, size, class_indices):
    """Draw ``size`` distinct entries of ``class_indices[index]`` without replacement.

    The drawn values are appended to ``batch`` and removed from
    ``class_indices[index]``, so later draws cannot repeat them.

    Args:
        batch: list receiving the sampled dataset indices (mutated in place).
        index: key of ``class_indices`` selecting the class to sample from.
        size: number of entries to draw; 0 is allowed and draws nothing.
        class_indices: dict mapping class key -> 1-D array of still-available
            dataset indices (the entry for ``index`` is replaced). The values are
            assumed unique, as produced by ``np.where``.
    """
    available = class_indices[index]
    sampled = np.random.choice(available, size=size, replace=False)
    batch.extend(sampled.tolist())
    # A boolean mask keeps the remaining order and dtype. It also works when
    # nothing was drawn, whereas np.delete rejects the float64 empty array that
    # np.array([]) produces.
    class_indices[index] = available[~np.isin(available, sampled)]


def add_class(batch, index, amount, class_indices):
    """Append ``amount`` randomly drawn samples of class ``index`` to ``batch``.

    Thin alias of ``extract_indices``; the drawn samples are also removed from
    ``class_indices[index]``.
    """
    extract_indices(batch=batch, index=index, size=amount, class_indices=class_indices)


class DetectionDataset(torch.utils.data.Dataset):
    """Slice-level lesion-detection dataset built from the BCS-DBT CSV files.

    Each row of ``bounding_boxes/BCS-DBT boxes-{split_name}.csv`` (the first
    ``threshold`` rows) is expanded into ``number_of_slices`` samples. The samples
    point at slices ``0 .. number_of_slices - 1`` of the image folder matched to
    that row.

    With ``batch_size > 1`` the two classes are balanced by random undersampling.
    The samples are then re-ordered so that each consecutive chunk of
    ``batch_size`` items holds up to ``batch_size // 2`` samples of each class.
    The DataLoader must therefore not shuffle.

    ``__getitem__`` returns ``(image, target)``:
    - ``image`` is a normalised ``3 x 512 x 512`` float tensor.
    - ``target`` holds ``'bbox'`` (``N x 4`` float32, columns
      ``[y_min, x_min, y_max, x_max]``) and ``'cls'`` (``N`` float32 labels,
      1 = benign, 2 = any other class).
    """

    def __init__(self, base_folder, split_name, number_of_slices, batch_size,
                 transfm=None, threshold=None):
        super().__init__()
        # Build the default pipeline per instance, not once when the class is
        # defined, so different datasets never share one transform object.
        self.transfm = transfm if transfm is not None else get_train_transforms()
        self.base_folder = base_folder
        self.batch_size = batch_size
        self.number_of_slices = number_of_slices
        self.threshold = threshold
        self.images_file = os.path.join(self.base_folder, "images/")
        self.label_file = os.path.join(self.base_folder, f"labels/BCS-DBT labels-{split_name}.csv")
        self.data_paths = []
        self.breastCancerData = pd.read_csv(os.path.join(self.base_folder, f"BCS-DBT file-paths-{split_name}.csv"))
        self.breastCancerBoxes = pd.read_csv(os.path.join(self.base_folder,
                                                          f"bounding_boxes/BCS-DBT boxes-{split_name}.csv"))
        self.breastCancerLabel = pd.read_csv(self.label_file)
        if threshold is None:
            # Default: use every bounding-box row.
            self.threshold = self.breastCancerBoxes.shape[0]
        self.targets = []
        # ImageNet mean/std (presumably chosen to match the pretrained backbone).
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.load_data()

    def calculate_class_indices(self):
        """Map key 0 -> positions with target 1 (benign) and key 1 -> positions with target 2."""
        targets = np.asarray(self.targets)
        return {class_label: np.where(targets == class_label + 1)[0] for class_label in range(2)}

    def load_image(self, path, file_name):
        """Read one slice image and return it as a normalised 3-channel tensor.

        The grey image is copied into all 3 channels and normalised with
        ``self.mean`` / ``self.std``.
        """
        slices = np.zeros(shape=[3, 512, 512], dtype=np.float32)
        # NOTE(review): assumes the stored slice is already 512x512 and 8-bit
        # (it is divided by 255 below) -- confirm against the preprocessing step.
        slices[0] = imread(fname=os.path.join(path, file_name), as_gray=True)
        slices[0] = slices[0] / 255.0  # Scale pixel values to [0, 1]
        slices = torch.tensor(slices)
        slices[1] = slices[0]
        slices[2] = slices[0]
        slices = fn.normalize(slices, mean=self.mean, std=self.std)
        return slices

    def _drop_random_samples(self, candidates, count):
        """Remove ``count`` randomly chosen positions (drawn from ``candidates``) from the dataset."""
        dropped = np.random.choice(candidates, size=count, replace=False)
        self.data_paths = np.delete(self.data_paths, dropped)
        self.targets = np.delete(self.targets, dropped)

    def load_data(self):
        """Fill ``data_paths``/``targets``; for batched use, balance and interleave the classes."""
        for idx in range(0, self.threshold):
            view_series = self.breastCancerBoxes.iloc[idx]
            label = 1 if view_series["Class"] == 'benign' else 2
            # NOTE(review): matches on PatientID only and takes the first file row,
            # so boxes on a patient's other views/studies reuse that first image.
            index_fisier = self.breastCancerData.index[
                self.breastCancerData["PatientID"] == view_series["PatientID"]].tolist()[0]
            image_path = find_image_path(self.breastCancerData.iloc[index_fisier])
            separator = "" if image_path[-1] == "/" else "/"
            for i in range(self.number_of_slices):
                self.targets.append(label)
                self.data_paths.append(f"{image_path}{separator}{i}")

        if self.batch_size == 1:
            return

        # Undersample whichever class is larger so both classes have equal counts.
        class_indices = self.calculate_class_indices()
        surplus_malign = len(class_indices[1]) - len(class_indices[0])
        if surplus_malign > 0:
            self._drop_random_samples(class_indices[1], surplus_malign)
            class_indices = self.calculate_class_indices()
        surplus_benign = len(class_indices[0]) - len(class_indices[1])
        if surplus_benign > 0:
            self._drop_random_samples(class_indices[0], surplus_benign)
            class_indices = self.calculate_class_indices()

        # Lay samples out as consecutive batches, each holding up to half of
        # each class in random order.
        batches_number = (len(self) + self.batch_size - 1) // self.batch_size  # ceil
        half = self.batch_size // 2
        order = []
        for _ in range(batches_number):
            batch = []
            add_class(batch, 1, min(half, len(class_indices[1])), class_indices)
            add_class(batch, 0, min(half, len(class_indices[0])), class_indices)
            np.random.shuffle(batch)
            order.extend(batch)
        # dtype=int keeps the fancy index valid even when nothing was drawn.
        order = np.array(order, dtype=int)
        self.data_paths = np.asarray(self.data_paths)[order]
        # Keep the labels aligned with the re-ordered paths.
        self.targets = np.asarray(self.targets)[order]

    def __len__(self):
        return len(self.data_paths)

    def scale_coordinates(self, index):
        """Return the factors that map original-image coordinates onto 512x512.

        The volume shape is parsed from the first (sorted) entry of the image
        folder, named like ``"(a,b,c)"``. ``b`` is treated as the width and ``c``
        as the height.
        """
        # NOTE(review): if the shape is (slices, rows, cols), "width" here is the
        # row count; load_bounding_box compensates by scaling x with the "height"
        # factor -- confirm and consider renaming.
        image_path = find_image_path(self.breastCancerData.iloc[index])
        shape = get_nth_file_name(image_path, 0)
        original_width = int(shape[shape.index(",") + 1:shape.rindex(",")])
        original_height = int(shape[shape.rindex(",") + 1:shape.rindex(")")])

        new_width = 512
        new_height = 512

        scale_width = new_width / original_width
        scale_height = new_height / original_height

        return scale_width, scale_height

    def load_bounding_box(self, idx, index_fisier):
        """Return ``{'bbox': 1x4 float32 tensor, 'cls': 1-element tensor}`` for box row ``idx``.

        The box comes from the X/Y/Width/Height columns, rescaled to the 512x512
        image of file row ``index_fisier``, in ``[x_min, y_min, x_max, y_max]`` order.
        """
        scaled_width, scaled_height = self.scale_coordinates(index_fisier)
        box_row = self.breastCancerBoxes.iloc[idx]
        x, y, width, height = box_row["X"], box_row["Y"], box_row["Width"], box_row["Height"]
        # X uses the factor scale_coordinates names "height" and Y the "width"
        # one (see the NOTE in scale_coordinates).
        boxes = [x * scaled_height, y * scaled_width,
                 (x + width) * scaled_height, (y + height) * scaled_width]
        label = 1. if box_row["Class"] == 'benign' else 2.
        boxes = torch.tensor(boxes, dtype=torch.float32).unsqueeze(dim=0)
        label = torch.tensor(label).unsqueeze(dim=0)
        return {'bbox': boxes, 'cls': label}

    def __getitem__(self, index):
        slice_path = self.data_paths[index]
        last_sep = slice_path.rindex("/")
        directory_path = slice_path[:last_sep + 1]
        # The suffix after the last "/" is the slice number written by load_data.
        # Take the WHOLE suffix: it has two digits once number_of_slices > 10.
        slice_number = int(slice_path[last_sep + 1:])

        # Rebuild the CSV "descriptive_path" of this volume from the local path.
        # Steps: drop "<base_folder>" plus 7 more characters (presumably
        # len("images/")), replace the slice number with "1-1.dcm", then remove the
        # last "NA" token together with the character before it.
        path_in_file = slice_path[len(self.base_folder) + 7:]
        path_in_file = path_in_file[:path_in_file.rindex("/") + 1] + "1-1.dcm"
        file_name_1 = path_in_file[:path_in_file.rindex("NA") - 1]
        path_in_file = file_name_1 + path_in_file[path_in_file.rindex("NA") + 2:]
        index_fisier = self.breastCancerData.index[
            self.breastCancerData["descriptive_path"] == path_in_file].tolist()[0]
        # NOTE(review): uses the patient's first box row even when the patient has
        # several boxes/views.
        index_fisier_boxes = self.breastCancerBoxes.index[
            self.breastCancerBoxes["PatientID"] == self.breastCancerData.iloc[index_fisier]["PatientID"]].tolist()[0]
        # Entry 0 of the sorted folder listing is the shape descriptor parsed by
        # scale_coordinates, so slice i is entry i + 1.
        file_name = get_nth_file_name(directory_path, slice_number + 1)
        image = self.load_image(directory_path, file_name)
        bounding_box = self.load_bounding_box(index_fisier_boxes, index_fisier)

        sample = self.transfm(
            image=np.array(image.permute(1, 2, 0)),
            bboxes=bounding_box['bbox'],
            labels=bounding_box['cls'],
        )
        image = torch.tensor(sample['image']).permute(2, 0, 1)
        # Build an (N, 4) tensor; reshape keeps the shape valid if a transform
        # drops every box. Then swap the columns [x_min, y_min, x_max, y_max] ->
        # [y_min, x_min, y_max, x_max].
        boxes = torch.as_tensor(np.asarray(sample['bboxes'], dtype=np.float32)).reshape(-1, 4)
        target = {
            'bbox': boxes[:, [1, 0, 3, 2]].contiguous(),
            'cls': torch.tensor([float(value) for value in sample['labels']], dtype=torch.float32),
        }
        return image, target

In [None]:
class AverageMeter(object):
    """Running, count-weighted mean of a scalar such as a per-batch loss.

    Attributes:
        val: most recent value passed to ``update``.
        sum: weighted total ``sum(val_i * n_i)``.
        count: total weight ``sum(n_i)``.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded value."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


In [None]:
import os
import numpy as np
import torch
import wandb
from tqdm import tqdm

class Fitter:
    """Train and validate an effdet training bench, keeping its best checkpoint.

    ``config`` is a class such as ``TrainGlobalConfig`` exposing ``folder``,
    ``n_epochs``, ``lr``, ``SchedulerClass``, ``scheduler_params``,
    ``step_scheduler`` and ``validation_scheduler``.

    When ``sweep_config`` is given, a W&B run is started and per-batch losses are
    logged. ``learning_rate`` / ``number_of_epochs`` values present in the run
    config (as set by a W&B sweep agent) override ``config.lr`` /
    ``config.n_epochs``.

    Checkpoints are written to ``./{config.folder}`` as
    ``loss_{train_loss}_epoch_{epoch}.pth`` whenever the epoch's mean training loss
    improves. Every 10 epochs the model is rolled back to the lowest-loss file and
    every other file in the folder is deleted.
    """

    def __init__(self, model, config, sweep_config=None, device="cpu", reload=None):
        self.epoch = 0
        self.config = config

        self.base_dir = f'./{config.folder}'
        os.makedirs(self.base_dir, exist_ok=True)

        # Sentinel "worse than any real loss".
        self.best_summary_loss = 10 ** 5

        self.model = model.to(device)
        self.device = device

        self.sweep_config = sweep_config
        learning_rate = config.lr
        self.n_epochs = config.n_epochs
        if sweep_config is not None:
            run = wandb.init(config=sweep_config)
            print(run)
            # Under a sweep agent the sampled hyper-parameters are in the run
            # config; outside an agent the keys are absent and config's values stay.
            learning_rate = run.config.get("learning_rate", learning_rate)
            self.n_epochs = run.config.get("number_of_epochs", self.n_epochs)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params)

        if reload is not None:
            self._restore_best_checkpoint()

    @staticmethod
    def _checkpoint_sort_key(file_name):
        """Sort key for ``loss_{loss}_epoch_{n}.pth`` names: numerically lowest loss first.

        Names that do not follow this pattern sort after every parsable one.
        """
        prefix, marker = 'loss_', '_epoch_'
        if file_name.startswith(prefix) and marker in file_name:
            try:
                return float(file_name[len(prefix):file_name.rindex(marker)]), file_name
            except ValueError:
                pass  # not a checkpoint name; falls through to "sort last"
        return float('inf'), file_name

    def _restore_best_checkpoint(self):
        """Load the lowest-loss checkpoint in ``base_dir`` and delete every other file there.

        Losses are compared numerically: lexicographic order would rank
        "loss_0.15" ahead of "loss_0.1". Does nothing when the folder has no files.
        """
        files = sorted(
            (name for name in os.listdir(self.base_dir)
             if os.path.isfile(os.path.join(self.base_dir, name))),
            key=self._checkpoint_sort_key,
        )
        if not files:
            return
        best, *others = files
        self.load(os.path.join(self.base_dir, best))
        for name in others:
            os.remove(os.path.join(self.base_dir, name))

    def _build_target(self, targets, with_image_meta=False):
        """Turn the per-image target dicts of a batch into the dict the bench consumes.

        Boxes and classes are passed as per-image lists, so images may carry
        different numbers of boxes. With ``with_image_meta`` one ``img_size`` row
        and one ``img_scale`` entry per image are added (the eval-mode bench
        reads them).
        """
        target = {
            'bbox': [element['bbox'].to(self.device) for element in targets],
            'cls': [element['cls'].to(self.device) for element in targets],
        }
        if with_image_meta:
            batch_size = len(targets)
            target['img_size'] = torch.tensor([[512, 512]] * batch_size, device=self.device)
            target['img_scale'] = torch.ones(batch_size, device=self.device)
        return target

    def fit(self, train_loader, validation_loader):
        """Run ``n_epochs`` epochs, each a training pass followed by a validation pass."""
        for e in range(self.n_epochs):
            summary_loss = self.train_one_epoch(train_loader)

            # NOTE(review): "best" is judged on the TRAINING loss, not on the
            # validation loss computed below -- confirm this is intended.
            if summary_loss.avg < self.best_summary_loss:
                path = f'{self.base_dir}/loss_{summary_loss.avg}_epoch_{self.epoch}.pth'
                # Clears the folder when the integer part of the loss changes digit
                # count. In particular this happens on the first save of a fresh run
                # (best starts at 10**5), which discards checkpoints left over from
                # earlier runs.
                if len(str(int(summary_loss.avg))) != len(str(int(self.best_summary_loss))):
                    for name in os.listdir(self.base_dir):
                        os.remove(os.path.join(self.base_dir, name))
                # Update first so the checkpoint records its own loss, not the
                # previous best.
                self.best_summary_loss = summary_loss.avg
                self.save(path)

            self.validation(validation_loader)
            if self.config.validation_scheduler:
                self.scheduler.step()

            self.epoch += 1
            if e % 10 == 0:
                # Periodically roll back to the best checkpoint and prune the rest.
                self._restore_best_checkpoint()

    def validation(self, val_loader):
        """Return the sample-weighted mean loss over ``val_loader`` (no weight updates)."""
        self.model.eval()
        summary_loss = AverageMeter()
        with torch.no_grad():
            for data, targets in tqdm(val_loader):
                images = torch.stack(data).to(self.device)
                batch_size = images.shape[0]
                target = self._build_target(targets, with_image_meta=True)

                loss = self.model(images, target)["loss"]
                if self.sweep_config is not None:
                    wandb.log({"loss_test": loss.item()})

                summary_loss.update(loss.detach().item(), batch_size)

        return summary_loss

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader`` and return its loss meter."""
        self.model.train()
        summary_loss = AverageMeter()
        for data, targets in tqdm(train_loader):
            images = torch.stack(data).to(self.device)
            batch_size = images.shape[0]

            self.optimizer.zero_grad()
            loss = self.model(images, self._build_target(targets))["loss"]
            loss.backward()
            self.optimizer.step()
            if self.config.step_scheduler:
                self.scheduler.step()  # per-batch LR schedule
            summary_loss.update(loss.detach().item(), batch_size)

            if self.sweep_config is not None:
                wandb.log({"loss": loss.item()})

        return summary_loss

    def save(self, path):
        """Write model, optimizer and scheduler state plus bookkeeping to ``path``."""
        self.model.eval()
        torch.save({
            'model_state_dict': self.model.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_summary_loss': self.best_summary_loss,
            'epoch': self.epoch,
        }, path)

    def load(self, path):
        """Restore the state written by ``save``; training resumes at the following epoch."""
        # map_location lets a checkpoint saved on GPU be restored on CPU (and vice versa).
        checkpoint = torch.load(path, map_location=self.device)
        self.model.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_summary_loss = checkpoint['best_summary_loss']
        self.epoch = checkpoint['epoch'] + 1

bs = 16  # also the DataLoader batch size: the datasets pre-arrange balanced batches of this size

# Train and test splits share every setting except the split name.
_shared_dataset_args = dict(base_folder="drive/MyDrive/Licenta/", number_of_slices=22, batch_size=bs)
train_dataset = DetectionDataset(split_name="train", **_shared_dataset_args)
test_dataset = DetectionDataset(split_name="test", **_shared_dataset_args)


In [None]:
import torch


class TrainGlobalConfig:
    """Training settings, passed to ``Fitter`` as a class and read as class attributes."""

    # --- data loading / optimisation ---------------------------------------
    num_workers = 0
    batch_size = 16
    n_epochs = 50
    lr = 0.0002

    # Checkpoint folder (a relative path).
    folder = 'drive/MyDrive/Licenta/detection_checkpoints'

    # --- logging --------------------------------------------------------------
    verbose = True
    verbose_step = 1

    # --- learning-rate schedule ------------------------------------------------
    step_scheduler = True  # call scheduler.step() after every optimizer.step()
    validation_scheduler = False  # call scheduler.step() after each validation pass

    # Cosine decay over (dataset length / batch size) * 60 scheduler steps.
    # NOTE(review): that horizon is 60 epochs while n_epochs is 50, so the LR
    # presumably never reaches eta_min -- confirm this is intended.
    SchedulerClass = torch.optim.lr_scheduler.CosineAnnealingLR
    scheduler_params = dict(
        T_max=int(len(train_dataset)/batch_size*60),
        eta_min=1e-6,
    )


In [None]:
import torch
import torch.utils.data
from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain
from effdet import create_model_from_config
from effdet.efficientdet import HeadNet
from torch import nn
import timm


def collate_fn(batch):
    """Transpose a list of samples into per-field tuples.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``.
    The images are left unstacked and the target dicts stay per sample.
    """
    transposed = zip(*batch)
    return tuple(transposed)


def get_net():
    """Build a pretrained EfficientDet-D0 training bench for 2 classes at 512x512.

    The bench is created with ``bench_task='train'`` and ``bench_labeler=True``,
    so it is fed raw boxes/classes (as ``Fitter`` does) rather than precomputed
    anchor targets.
    """
    model_config = get_efficientdet_config('tf_efficientdet_d0')
    model_config.image_size = (512, 512)
    model_config.norm_kwargs = dict(eps=.001, momentum=.01)
    return create_model_from_config(
        config=model_config,
        bench_task='train',
        num_classes=2,
        pretrained=True,
        bench_labeler=True,
    )


net = get_net()


def run_training(sweep_config=None, reload=None, model=None, device=None):
    """Train on ``train_dataset`` and validate on ``test_dataset`` with a ``Fitter``.

    Args:
        sweep_config: forwarded to ``Fitter``; when not None a W&B run is started.
        reload: forwarded to ``Fitter``; when not None training resumes from the
            best checkpoint in ``TrainGlobalConfig.folder``.
        model: bench to train. Defaults to a freshly built ``get_net()``, so
            repeated calls (e.g. one per sweep run) do not continue from the
            weights left behind by an earlier call.
        device: torch device. Defaults to the first CUDA GPU when one is
            available, otherwise CPU.
    """
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if model is None:
        model = get_net()

    # shuffle must stay off: DetectionDataset already orders its samples into
    # class-balanced runs of exactly `bs` items.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=False,
        num_workers=TrainGlobalConfig.num_workers,
        collate_fn=collate_fn,
    )

    val_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=bs,
        shuffle=False,
        num_workers=TrainGlobalConfig.num_workers,
        collate_fn=collate_fn,
    )

    fitter = Fitter(model=model, config=TrainGlobalConfig, reload=reload,
                    sweep_config=sweep_config, device=device)
    fitter.fit(train_loader, val_loader)


In [None]:
# Single direct training run. run_training is not launched through
# wandb.agent(sweep_id, ...), so W&B does not iterate over the sweep grid;
# sweep_config is only forwarded to wandb.init as the run's config.
# NOTE(review): to actually run the sweep, use something like
#   wandb.agent(sweep_id, function=lambda: run_training(sweep_config))
run_training(sweep_config)
# for i in range(4):
#     run_training(sweep_config, True)

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,███▇▆▆▆▆▆▅▅▅▅▅▄▄▄▃▃▄▄▄▄▄▃▃▃▃▂▃▂▁▂▁▁▂▁▁▂▁

0,1
loss,0.75532


<wandb.sdk.wandb_run.Run object at 0x791aa5f2d4e0>


100%|██████████| 88/88 [12:21<00:00,  8.42s/it]
100%|██████████| 14/14 [01:02<00:00,  4.45s/it]
100%|██████████| 88/88 [12:32<00:00,  8.55s/it]
100%|██████████| 14/14 [00:32<00:00,  2.30s/it]
100%|██████████| 88/88 [12:27<00:00,  8.49s/it]
100%|██████████| 14/14 [00:30<00:00,  2.19s/it]
100%|██████████| 88/88 [12:21<00:00,  8.43s/it]
100%|██████████| 14/14 [00:33<00:00,  2.40s/it]
100%|██████████| 88/88 [12:00<00:00,  8.19s/it]
100%|██████████| 14/14 [00:32<00:00,  2.32s/it]
100%|██████████| 88/88 [11:54<00:00,  8.12s/it]
100%|██████████| 14/14 [00:31<00:00,  2.22s/it]
100%|██████████| 88/88 [11:52<00:00,  8.10s/it]
100%|██████████| 14/14 [00:30<00:00,  2.21s/it]
100%|██████████| 88/88 [11:52<00:00,  8.10s/it]
100%|██████████| 14/14 [00:31<00:00,  2.24s/it]
100%|██████████| 88/88 [11:50<00:00,  8.07s/it]
100%|██████████| 14/14 [00:32<00:00,  2.30s/it]
100%|██████████| 88/88 [11:51<00:00,  8.08s/it]
100%|██████████| 14/14 [00:32<00:00,  2.31s/it]
  0%|          | 0/88 [00:01<?, ?it/s]


KeyboardInterrupt: 