In [None]:
# Colab/fresh-environment setup: uncomment to fetch the repository and install its dependencies.
# !git clone 'https://github.com/Lorenz92/SKADC1.git'
# %cd SKADC1
# !echo $PWD
# `-r` is required: without it pip looks for a package literally named "requirements.txt".
# %pip install -r requirements.txt

In [None]:
import numpy as np

import src.dataset as dataset
import src.config as config 
from src.utils import *
import src.models as models
import src.losses as loss

# NOTE(review): `path` is captured here, before the Colab cell may reassign
# config.TRAIN_PATCHES_FOLDER, so on Colab it can hold the stale local value.
path = config.TRAIN_PATCHES_FOLDER
# Automatically reload edited project modules (src.*) before each cell executes.
%load_ext autoreload
%autoreload 2

# Seed NumPy's global RNG so runs are reproducible.
np.random.seed(config.RANDOM_SEED)

In [None]:
import tensorflow as tf

# Report how many GPUs TensorFlow can see.
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"Num GPUs Available:  {len(gpu_devices)}")

In [None]:
# Detect whether the kernel is hosted on Google Colab by inspecting the IPython shell object.
use_colab = 'google.colab' in str(get_ipython())
print('Running on CoLab' if use_colab else 'Not running on CoLab')

In [None]:
# Default (local) folder holding the source patches.
source_dir = './data/training/patches'

if use_colab:
    # On Colab, data and weights live on the mounted Google Drive.
    from google.colab import drive
    drive.mount('/content/drive')

    colab_root = "/content/drive/My Drive/Colab Notebooks/SKADC1"
    config.MODEL_WEIGHTS = colab_root
    config.IMAGE_PATH = f"{colab_root}/asset/560Mhz_1000h.fits"
    config.TRAIN_DATA_FOLDER = f"{colab_root}/asset"
    config.TRAIN_PATCHES_FOLDER = f"{colab_root}/patches"
    source_dir = f"{colab_root}/asset/patches"

In [None]:
# Choose the feature extraction model: 'baseline_44', 'baseline_16' or 'vgg16'.
backbone = 'baseline_44'

if backbone in ('baseline_16', 'baseline_44'):
    # The two baselines share every setting except the RPN stride.
    config.patch_dim = 20
    config.resizePatch = True
    config.rpn_stride = 4 if backbone == 'baseline_16' else 8
    config.num_rois = 16
    config.anchor_box_scales = [4, 8, 16, 24, 32, 64]  # anchors in the original image size
    config.resizeFinalDim = 100
else:
    # Any other backbone (e.g. 'vgg16') uses the large-patch setup.
    config.patch_dim = 100
    config.resizePatch = True
    config.rpn_stride = 16
    config.num_rois = 16
    config.resizeFinalDim = 600
    config.anchor_box_scales = [32, 64, 128]
    config.in_out_img_size_ratio = config.rpn_stride

input_shape_1 = config.resizeFinalDim

import os

# One anchor per (ratio, scale) combination.
config.anchor_num = len(config.anchor_box_ratios) * len(config.anchor_box_scales)
# Presumably (variable number of boxes, 4 coordinates) — TODO confirm against models.get_train_model.
input_shape_2 = (None, 4)
# Selects the detector classification loss when compiling the models.
use_focal_loss = False

checkpoint = get_model_last_checkpoint(backbone)
print(f'Model last checkpoint: {checkpoint}')

# Record the chosen configuration next to this backbone's weights.
file_path = f'{config.MODEL_WEIGHTS}/{backbone}'
config_txt_path = f'{file_path}/config.txt'
# Report the path that is actually written (it includes the backbone folder).
print(f'Writing configuration on txt file: {config_txt_path}')

os.makedirs(file_path, exist_ok=True)

config_lines = [
    f'backbone = {backbone}',
    f'config.patch_dim = {config.patch_dim}',
    f'config.resizePatch = {config.resizePatch}',
    f'config.rpn_stride = {config.rpn_stride}',
    f'config.num_rois = {config.num_rois}',
    f'config.anchor_box_scales = {config.anchor_box_scales}',
    f'config.resizeFinalDim = {config.resizeFinalDim}',
    f'input_shape_1 = {input_shape_1}',
]
# `with` guarantees the file is closed even if writing fails.
with open(config_txt_path, 'w') as f:
    f.write('\n'.join(config_lines) + '\n')

In [None]:
# Parse and load the dataset. Set "subset" in the config file to load only a small
# portion of the data for development/debugging purposes.
ska_dataset = dataset.SKADataset(
    show_plot=True,
    print_info=False,
)

In [None]:
# Summary statistics of the object sizes in the cleaned training catalogue.
size_columns = ['width', 'height', 'area_orig', 'area_cropped']
ska_dataset.cleaned_train_df[size_columns].describe()

In [None]:
# Size quantiles, with finer resolution in the upper tail.
quantile_levels = [.1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .98, .99, 1.]
ska_dataset.cleaned_train_df[['width', 'height', 'area_orig']].quantile(quantile_levels)

In [None]:
# IDs of objects (presumably catalogue IDs) passed to generate_patches as objects_to_ignore.
# TODO(review): document why these two objects are excluded.
objects_to_ignore=[20167150, 27514971]

In [None]:
# Cut the image into training patches, skipping the objects listed above.
ska_dataset.generate_patches(
    limit=10000,
    plot_patches=False,
    objects_to_ignore=objects_to_ignore,
    source_dir=source_dir,
    rgb_norm=True,
)

In [None]:
# Train/validation split. For `random_state`, use 5 for the 20_100 set and 15 for the 50_100 set.
ska_dataset.split_train_val(
    random_state=5,
    val_portion=0.2,
    balanced=False,
    size=350,
)

# Training

### Get FRCNN model

In [None]:
# Build the trainable Faster R-CNN models. num_classes adds one slot to the dataset
# classes (presumably background — confirm in src.models).
rpn_model, detector_model, total_model = models.get_train_model(
    input_shape_1=input_shape_1,
    input_shape_2=input_shape_2,
    anchor_num=config.anchor_num,
    pooling_regions=config.pooling_regions,
    num_rois=config.num_rois,
    num_classes=len(ska_dataset.class_list) + 1,
    backbone=backbone,
    use_expander=False,
)

for train_model in (rpn_model, detector_model, total_model):
    train_model.summary()

### Load weights

In [None]:
# Resume from the last checkpoint of the *selected* backbone instead of a hard-coded
# file name, so the loaded weights always match the architecture built above.
checkpoint = get_model_last_checkpoint(backbone)
models.load_weigths(rpn_model, detector_model, backbone, resume_train=True, checkpoint=checkpoint)

# Only the detector classification loss depends on the focal-loss switch.
if use_focal_loss:
    detector_cls_loss = loss.categorical_focal_loss(config.alpha, config.gamma)
else:
    detector_cls_loss = loss.detector_loss_cls

models.compile_models(
    rpn_model,
    detector_model,
    total_model,
    rpn_losses=[loss.rpn_loss_cls, loss.rpn_loss_regr],
    detector_losses=[detector_cls_loss, loss.detector_loss_regr],
    class_list=ska_dataset.class_list,
)


In [None]:
# To inspect specific backbone weights, slice the weight tensors, for example:
# total_model.weights[24:25][0][0][0][0]

In [None]:
# Check that the pretrained weights have been loaded: weights paired up between
# total_model and rpn_model must be numerically equal. zip() stops at the shorter
# list, so only min(len(total_model.weights), len(rpn_model.weights)) pairs are compared.
# (numpy is already imported in the setup cell.)
for total_w, rpn_w in zip(total_model.weights, rpn_model.weights):
    assert np.allclose(total_w, rpn_w), f"Weights don't match: {total_w.name} vs {rpn_w.name}"

### Train

In [None]:
# Build the evaluation models used by the validation step at the end of each epoch.
rpn_model_eval, detector_model_eval, total_model_eval = models.get_eval_model(
    input_shape_1=input_shape_1,
    input_shape_2=input_shape_2,
    input_shape_fmap=None,
    anchor_num=config.anchor_num,
    pooling_regions=config.pooling_regions,
    num_rois=config.num_rois,
    num_classes=len(ska_dataset.class_list) + 1,
    backbone=backbone,
    use_expander=False,
)

for eval_model in (rpn_model_eval, detector_model_eval, total_model_eval):
    eval_model.summary()

In [None]:
# Star import kept on purpose: later cells also use names from src.train (e.g. evaluate_model).
from src.train import *

# Train the models; the *_eval models and the validation patches are passed in as well.
train_frcnn(
    rpn_model, detector_model, total_model,
    ska_dataset.train_patch_list,
    rpn_model_eval, detector_model_eval, total_model_eval,
    ska_dataset.val_patch_list,
    ska_dataset.class_list,
    num_epochs=30,
    patches_folder_path=config.TRAIN_PATCHES_FOLDER,
    backbone=backbone,
    resume_train=True,
)

# Validation

In [None]:
# (Re)build the evaluation models so this section can run without the training section.
rpn_model_eval, detector_model_eval, total_model_eval = models.get_eval_model(
    input_shape_1=input_shape_1,
    input_shape_2=input_shape_2,
    input_shape_fmap=None,
    anchor_num=config.anchor_num,
    pooling_regions=config.pooling_regions,
    num_rois=config.num_rois,
    num_classes=len(ska_dataset.class_list) + 1,
    backbone=backbone,
    use_expander=False,
)

for eval_model in (rpn_model_eval, detector_model_eval, total_model_eval):
    eval_model.summary()

In the following cell, select the checkpoint whose weights will be used for model evaluation.

In [None]:
# Checkpoint file to load into the evaluation models.
# NOTE(review): presumably must belong to the selected `backbone` — verify before changing either.
cp = 'map_63_frcnn_baseline_44.h5'

In [None]:
# Load the selected checkpoint into the models used for mAP evaluation, then compile them.
models.load_weigths(rpn_model_eval, detector_model_eval, backbone, checkpoint=cp)

models.compile_models(
    rpn_model_eval,
    detector_model_eval,
    total_model_eval,
    rpn_losses=[loss.rpn_loss_cls, loss.rpn_loss_regr],
    detector_losses=[loss.detector_loss_cls, loss.detector_loss_regr],
    class_list=ska_dataset.class_list,
)

In [None]:
# Quantitative evaluation on the entire validation set.
preds, mAP, mPrecision, mRecall = evaluate_model(
    rpn_model_eval, detector_model_eval, backbone,
    ska_dataset.val_patch_list, ska_dataset.class_list,
    map_threshold=.5,
    acceptance_treshold=.5,
    save_eval_results=False,
)

In [None]:
# Same evaluation on the training set. Results are stored under train_* names so the
# validation results (preds, mAP, mPrecision, mRecall) are not overwritten.
train_preds, train_mAP, train_mPrecision, train_mRecall = evaluate_model(
    rpn_model_eval, detector_model_eval, backbone,
    ska_dataset.train_patch_list, ska_dataset.class_list,
    map_threshold=.5,
    acceptance_treshold=.5,
    save_eval_results=False,
)

In [None]:
# Qualitative evaluation: show a patch image with its ground-truth and predicted bounding boxes.
patch = '92_17296_16729_20'
print_img(
    config.TRAIN_PATCHES_FOLDER,
    patch,
    config.EVAL_RESULTS,
    show_data=False,
)

# Plotting

### Loss plot

In [None]:
# Read the loss history from the configured weights folder rather than a hard-coded ./model,
# which breaks when the Colab cell redirects config.MODEL_WEIGHTS.
# NOTE(review): assumes training saves it under config.MODEL_WEIGHTS/<backbone> — confirm in src.train.
loss_history = np.load(f"{config.MODEL_WEIGHTS}/{backbone}/loss_history.npy")
print(loss_history.shape)
plot_loss(loss_history)

### Evaluation metrics plot

In [None]:
# Read the evaluation-score history from the configured weights folder rather than a hard-coded ./model.
# NOTE(review): assumes training saves it under config.MODEL_WEIGHTS/<backbone> — confirm in src.train.
scores_history = np.load(f"{config.MODEL_WEIGHTS}/{backbone}/scores_history.npy")
print(scores_history.shape)
plot_scores(scores_history)


In [None]:
# Smoothed RPN classification loss: drop the first 100 rows of column 2 of the loss
# history, then smooth with moving_average (window argument 200).
rpn_cls_column = loss_history[100:, 2]
lsma_0 = moving_average(rpn_cls_column, 200)

plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(np.arange(len(lsma_0)), lsma_0, 'r')
plt.title('rpn cls')
