# CenterNet


In [None]:
import albumentations as A
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from ROI.centernet import CenterNet, init_weights
from ROI.dataset import RoiDataset
from ROI.train import fit
from ROI.helpers import *
from ROI.losses import CenterLoss

## Prepare dataset for training

- Pad high resolution images to square shapes
- Resize padded images to 512x512
- Generate a CSV file with bounding boxes for optic discs
- Split dataset into training and validation sets


In [None]:
# Root folder of the ORIGA dataset, relative to the notebook's working directory.
ORIGA_DIR = '../data/ORIGA'
# Derived datasets are written next to the originals, so both roots coincide.
DATA_DIR = ORIGA_DIR
# Bounding-box table produced by generate_bbox_csv and read back with pandas.
CSV_FILE = f'{ORIGA_DIR}/origa.csv'

In [None]:
# Step 1 of dataset preparation: pad the high-resolution images (and their
# masks) to square shapes, writing the results into the *_Padded folders.
_padding_dirs = {
    'src_images_dir': f'{ORIGA_DIR}/Images',
    'src_masks_dir': f'{ORIGA_DIR}/Masks',
    'dst_images_dir': f'{DATA_DIR}/Images_Padded',
    'dst_masks_dir': f'{DATA_DIR}/Masks_Padded',
}
generate_padded_dataset(**_padding_dirs)

In [None]:
# Step 2 of dataset preparation: resize the padded images and masks to
# 512x512 and store them in the *_512x512 folders.
_padded_images_dir = DATA_DIR + '/Images_Padded'
_padded_masks_dir = DATA_DIR + '/Masks_Padded'
_resized_images_dir = DATA_DIR + '/Images_512x512'
_resized_masks_dir = DATA_DIR + '/Masks_512x512'

generate_resized_dataset(
    src_images_dir=_padded_images_dir,
    src_masks_dir=_padded_masks_dir,
    dst_images_dir=_resized_images_dir,
    dst_masks_dir=_resized_masks_dir,
    size=512,
)

In [None]:
# Step 3 of dataset preparation: write the optic-disc bounding boxes of the
# 512x512 images/masks to CSV_FILE. No extra margin is requested (margin=0).
_bbox_csv_args = dict(
    images_dir=f'{DATA_DIR}/Images_512x512',
    masks_dir=f'{DATA_DIR}/Masks_512x512',
    csv_file=CSV_FILE,
    margin=0,
)
generate_bbox_csv(**_bbox_csv_args)

In [None]:
# Load the bounding-box table from CSV_FILE.
df = pd.read_csv(CSV_FILE)

# Split on distinct image ids (80 % train / 20 % validation) with a fixed
# seed so the partition is reproducible between runs.
image_ids = df['image_id'].unique()
train_ids, val_ids = train_test_split(
    image_ids,
    test_size=0.2,
    random_state=411,
)

# Report how many images ended up in each subset.
for _label, _ids in (('Training size:', train_ids), ('Validation size:', val_ids)):
    print(_label, len(_ids))

# Last expression of the cell: show the first rows of the table.
df.head()

## Train model

In [None]:
# Hyper-parameters

# Checkpoint whose weights are loaded before training; a falsy value makes the
# model-setup cell initialise fresh weights instead.
# NOTE(review): absolute, machine-specific Windows path — will not resolve on
# other machines; consider a path relative to the notebook.
MODEL_PATH = r'C:\Users\ASUS\PycharmProjects\DP-GlaucomaSegmentation\notebooks\models\centernet_resnet18_margin8.pth'
# Backbone identifier.
MODEL_NAME = 'resnet18'

# Side length the transforms resize every image to.
INPUT_SIZE = 512
# Integer ratio between 512 and INPUT_SIZE.
# NOTE(review): 512 presumably mirrors the size of the stored images — confirm.
IN_SCALE = 512 // INPUT_SIZE
# Scale factor handed to RoiDataset and detect_objects.
# NOTE(review): presumably the input-to-heatmap downscale — confirm in ROI.dataset.
MODEL_SCALE = 4
BATCH_SIZE = 2

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [None]:
# Per-channel normalisation statistics shared by both pipelines.
_normalize_kwargs = {
    'mean': (0.9400, 0.6225, 0.3316),
    'std': (0.1557, 0.1727, 0.1556),
}


def _bbox_params():
    """Return a fresh COCO-format bbox config that reads labels from 'labels'."""
    return A.BboxParams(format='coco', label_fields=['labels'])


# Training pipeline: resize, random flips / 90-degree rotations, photometric
# jitter and blur (each applied with probability 0.5), then normalise and
# convert to a tensor.
_train_steps = [
    A.Resize(INPUT_SIZE, INPUT_SIZE, interpolation=cv.INTER_AREA),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.RandomGamma(p=0.5),
    A.GaussianBlur(p=0.5, blur_limit=(5, 15)),
    A.Normalize(**_normalize_kwargs),
    ToTensorV2(),
]
train_transform = A.Compose(_train_steps, bbox_params=_bbox_params())

# Validation pipeline: the same resize and normalisation, no random augmentation.
_val_steps = [
    A.Resize(INPUT_SIZE, INPUT_SIZE, interpolation=cv.INTER_AREA),
    A.Normalize(**_normalize_kwargs),
    ToTensorV2(),
]
val_transform = A.Compose(_val_steps, bbox_params=_bbox_params())

# Visual sanity check of the training pipeline: fetch the first training sample
# four times, draw its bounding boxes and show the results in a 2x2 grid.
# NOTE(review): every draw uses index 0 — presumably intentional, so the four
# panels differ only through the random augmentations; confirm.
res = []
dataset = RoiDataset(train_ids, df, INPUT_SIZE, IN_SCALE, MODEL_SCALE, train_transform)
for _draw in range(4):
    img, heatmap, regression, bboxes, *_ = dataset[0]
    # Undo the normalisation for display: min-max rescale to [0, 1] ...
    img = (img - img.min()) / (img.max() - img.min())
    # ... then channels-first tensor -> channels-last uint8 array.
    img = img.permute(1, 2, 0).numpy()
    # The permuted view may not be C-contiguous, and OpenCV drawing functions
    # can reject such arrays, so force a contiguous buffer before drawing.
    img = np.ascontiguousarray((img * 255).astype(np.uint8))
    for bbox in bboxes:
        # Boxes are unpacked as (x, y, width, height), matching the 'coco'
        # format configured for the transforms.
        x, y, w, h = (int(v) for v in bbox)
        img = cv.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
    res.append(img)

_, ax = plt.subplots(2, 2, figsize=(8, 8))
ax = ax.flatten()
for i, img in enumerate(res):
    ax[i].imshow(img)
plt.tight_layout()
plt.show()

In [None]:
def _make_dataset(ids, transform):
    """Build a RoiDataset over `ids` using the shared table and scale settings."""
    return RoiDataset(ids, df, INPUT_SIZE, IN_SCALE, MODEL_SCALE, transform)


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a single-process DataLoader with the configured batch size."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=0)


# Only the training set uses the augmenting transform; the validation set and
# the full set (all image ids) use the validation transform.
train_dataset = _make_dataset(train_ids, train_transform)
val_dataset = _make_dataset(val_ids, val_transform)
total_dataset = _make_dataset(image_ids, val_transform)

# Shuffle only during training.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
total_loader = _make_loader(total_dataset, shuffle=False)

In [None]:
# Build the single-class detector together with its optimiser, loss and
# learning-rate schedule. The backbone is taken from MODEL_NAME (set in the
# hyper-parameter cell) rather than repeating the literal here, so the two
# cannot drift apart.
model = CenterNet(n_classes=1, base=MODEL_NAME, custom=True)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = CenterLoss()
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm it
# is still accepted by the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)

if MODEL_PATH:
    # Resume from an existing checkpoint, mapping tensors onto the active device.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    print('Model loaded from', MODEL_PATH)
else:
    # No checkpoint configured: start from freshly initialised weights.
    init_weights(model)
model = model.to(DEVICE)

In [None]:
# Train for at most 100 epochs, passing the LR scheduler and an
# early-stopping patience of 10. `fit` returns the metric history and the
# best weights.
hist, best_weights = fit(
    model,
    optimizer,
    criterion,
    DEVICE,
    train_loader,
    val_loader,
    epochs=100,
    scheduler=scheduler,
    early_stopping_patience=10,
)
# List which metrics were recorded in the history.
print(list(hist.keys()))

In [None]:
# Plot the training and validation loss curves on one figure.
_loss_keys = ('train_loss', 'val_loss')

# Despite its name this is the SUM of the two mean losses; twice that value
# is used as the upper y-limit of the plot.
avg_loss = sum(np.mean(hist[_key]) for _key in _loss_keys)

plt.figure(figsize=(8, 8))
for _key in _loss_keys:
    plt.plot(hist[_key], label=_key)
plt.ylim(0, avg_loss * 2)
plt.legend()
plt.show()

In [None]:
# Persist the weights the model currently holds and the best weights
# returned by fit, both into the current working directory.
for _state, _filename in (
    (model.state_dict(), 'model.pth'),
    (best_weights, 'best_model.pth'),
):
    torch.save(_state, _filename)

## Make predictions

In [None]:
# Run the detector over the full dataset loader with a margin of 16; the
# results are also written to centernet.csv inside DATA_DIR.
_detections_csv = DATA_DIR + '/centernet.csv'
disc_df = detect_objects(
    model,
    total_loader,
    DEVICE,
    INPUT_SIZE,
    MODEL_SCALE,
    margin=16,
    out_file=_detections_csv,
)
# Last expression of the cell: display the detection table.
disc_df

In [None]:
# Crop the padded images and masks using the detections in disc_df and store
# the crops (size=512, no extra margin) in the *_CenterNet_Cropped folders.
_crop_args = {
    'src_images_dir': f'{DATA_DIR}/Images_Padded',
    'src_masks_dir': f'{DATA_DIR}/Masks_Padded',
    'dst_images_dir': f'{DATA_DIR}/Images_CenterNet_Cropped',
    'dst_masks_dir': f'{DATA_DIR}/Masks_CenterNet_Cropped',
    'size': 512,
    'margin': 0,
}
generate_cropped_dataset(disc_df, **_crop_args)