In [None]:
import os
import math
import numpy as np
import pandas as pd
import argparse
from tqdm import tqdm

from skfda import FDataGrid
from skfda.representation.basis import BSpline
from skfda.preprocessing.smoothing import BasisSmoother
from skfda.preprocessing.registration import landmark_registration
from skfda.ml.classification import KNeighborsClassifier, NearestCentroid

from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, roc_curve, confusion_matrix


# Columns used together as the sort / group key for the intensity samples in to_fd.
agg_columns = ['patient_id', 'slice_id', 'img_type']

# Number of basis functions and spline order handed to the BSpline basis below.
n_basis = 18
order = 4

# Fraction trimmed from each end of an evaluated curve (default of cut_ends).
prc_rm = 0.05
# Number of equally spaced evaluation points on [0, 1] (default of cut_ends).
n_points = 111

# Shared basis on the unit interval and the smoother that projects raw curves onto it.
basis = BSpline(
	domain_range=(0, 1),
	n_basis=n_basis,
	order=order,
)
smoother = BasisSmoother(
	basis=basis,
	return_basis=True,
	method='svd',
)


def cut_ends(bsplined, order=0, prc_rm_start=prc_rm, prc_rm_end=prc_rm, n_points=n_points):
	"""Evaluate a derivative of `bsplined` on a regular grid and trim both ends.

	The `order`-th derivative is evaluated on `n_points` equally spaced
	points of [0, 1]; the leading `prc_rm_start` and trailing `prc_rm_end`
	fractions of those points are then discarded.

	Parameters:
		bsplined: functional object exposing `derivative(order=...)` whose
			result exposes `to_grid(...)`.
		order: derivative order (0 evaluates the function itself).
		prc_rm_start: fraction of grid points dropped at the start.
		prc_rm_end: fraction of grid points dropped at the end.
		n_points: size of the evaluation grid before trimming.

	Returns:
		FDataGrid holding the kept samples and their grid points.
	"""
	derivative = bsplined.derivative(order=order)
	evaluated = derivative.to_grid(np.linspace(0, 1, n_points))

	# Both bounds are truncated towards zero; the same window is applied to
	# the values and to the grid so they stay aligned.
	first_kept = int(n_points * prc_rm_start)
	stop_kept = int(n_points * (1 - prc_rm_end))
	window = slice(first_kept, stop_kept)

	# Index 0 on the last axis keeps a single output component
	# (presumably the only one for these scalar curves - TODO confirm).
	trimmed_values = evaluated.data_matrix[..., window, 0]
	trimmed_grid = evaluated.grid_points[0][window]
	return FDataGrid(data_matrix=trimmed_values, grid_points=trimmed_grid)


def get_landmark_registration(bsplined, order=0):
	"""Landmark-register the trimmed curves on the peak found in their first half.

	Parameters:
		bsplined: functional object accepted by `cut_ends`.
		order: derivative order forwarded to `cut_ends`.

	Returns:
		The result of `landmark_registration` applied to the trimmed curves,
		using one landmark per curve.
	"""
	trimmed = cut_ends(bsplined, order)
	# Same starting point as `trimmed` but stopping at 50% of the grid, so an
	# index into this window is also a valid index into the trimmed grid.
	first_half = cut_ends(bsplined, order, prc_rm_end=0.5)
	# argmax over the sample axis yields one row per curve; concatenating the
	# rows flattens them into a 1-D vector of indexes.
	peak_indexes = np.concatenate(first_half.data_matrix.argmax(axis=1))
	abscissae = trimmed.grid_points[0]
	landmarks = [abscissae[peak] for peak in peak_indexes]
	return landmark_registration(trimmed, landmarks)


def specificity(y_true, y_pred, zero_division=0):
	"""Return the true-negative rate TN / (TN + FP) of a binary prediction.

	Labels are interpreted as booleans (truthy = positive class), so both
	`bool` and 0/1 integer labels are accepted.

	Parameters:
		y_true: array-like of ground-truth labels.
		y_pred: array-like of predicted labels, same length as `y_true`.
		zero_division: value returned when `y_true` contains no negative
			sample, i.e. when the ratio is undefined.

	Returns:
		The specificity as a float, or `zero_division` when TN + FP == 0.
	"""
	truth = np.asarray(y_true).astype(bool)
	predicted = np.asarray(y_pred).astype(bool)

	# Count directly instead of unpacking a confusion matrix: that matrix
	# collapses to 1x1 when only one class is present, which breaks a
	# four-way unpack.
	negatives = ~truth
	tn = int((negatives & ~predicted).sum())
	fp = int((negatives & predicted).sum())

	# The fallback must apply for every `zero_division` value, including the
	# falsy default 0; otherwise 0 / 0 yields NaN.
	if tn + fp == 0:
		return zero_division
	return tn / (tn + fp)


# Patients kept for the analysis when the caller does not pass `patient_ids`.
DEFAULT_PATIENT_IDS = (2, 15, 28, 32, 35, 39, 40, 41, 45, 50, 52, 64, 66)


def to_fd(segmentation_id, prefix='fixed', is_registered=False, patient_ids=None):
	"""Load one segmentation CSV and build smoothed functional data per patient.

	Reads `./segmentations/{prefix}-{segmentation_id}.csv`, keeps the selected
	patients that have at least one positive label, smooths every intensity
	curve with the module-level `smoother`, and returns the first derivative
	of the smoothed curves, trimmed and optionally landmark-registered.

	Parameters:
		segmentation_id: identifier interpolated into the CSV file name.
		prefix: file-name prefix of the segmentation variant.
		is_registered: when True the curves go through
			`get_landmark_registration`, otherwise through `cut_ends`.
		patient_ids: iterable of patient ids to keep; defaults to
			`DEFAULT_PATIENT_IDS`.

	Returns:
		DataFrame with one row per patient and the columns `patient_id`,
		`label` (list with one label per curve) and `fd_smooth` (the
		functional data object holding that patient's curves).
	"""
	if patient_ids is None:
		patient_ids = DEFAULT_PATIENT_IDS

	file = f'./segmentations/{prefix}-{segmentation_id}.csv'
	samples = (
		pd.read_csv(
			file,
			dtype={
				'img_type': int,
				'patient_id': int,
				'cycle_id': int,
				'slice_id': int,
				'label': bool,
				'mask_int_mean': float,
				'segment': int,
			},
		)
		.drop_duplicates()
		# Sorting by cycle inside each curve key fixes the sample order that
		# the list aggregation below relies on.
		.sort_values(agg_columns + ['cycle_id'])
	)

	# Inner merge on patient_id: keeps only patients with a positive label.
	positive_patients = samples.query('label').patient_id.drop_duplicates()
	samples = samples.merge(positive_patients)
	samples = samples[samples.patient_id.isin(list(patient_ids))]

	# One normalised time grid per patient, sized by its number of distinct
	# cycles. After reset_index the grid lives in the `cycle_id` column.
	time_grids = (
		samples[['patient_id', 'cycle_id']].drop_duplicates()
			.groupby('patient_id').cycle_id.count()
			.apply(lambda n_cycles: np.linspace(0, 1, int(n_cycles)))
			.reset_index()
	)

	# One row per curve: the intensity samples collected into a list.
	curves = samples.groupby(agg_columns + ['label']).mask_int_mean.apply(list).reset_index()

	# One row per patient: all of its curves plus its time grid.
	# NOTE(review): this assumes every curve of a patient has exactly one
	# sample per distinct cycle - confirm against the data.
	per_patient = curves.groupby('patient_id').mask_int_mean.apply(list).reset_index().merge(time_grids)
	smoothed = per_patient.apply(
		lambda row: smoother.fit_transform(
			FDataGrid(data_matrix=row['mask_int_mean'], grid_points=row['cycle_id'])
		),
		axis='columns',
	)

	result = curves[['patient_id', 'label']].groupby('patient_id').label.apply(list).reset_index()
	# `result` and `smoothed` both come from a groupby on patient_id, so their
	# rows are in the same patient order and can be paired positionally.
	post_process = get_landmark_registration if is_registered else cut_ends
	result['fd_smooth'] = [post_process(fd_smooth, 1) for fd_smooth in smoothed]
	return result


def train_model(patient_data, segmentation_id):
	"""Fit a NearestCentroid classifier on one patient's curves and score it.

	The classifier is fitted and evaluated on the same curves, so the
	returned values are training-set metrics.

	Parameters:
		patient_data: mapping-like row providing `patient_id`, `label`
			and `fd_smooth`.
		segmentation_id: identifier copied into the result.

	Returns:
		dict with the patient and segmentation ids followed by precision,
		recall, F1, balanced accuracy and specificity.
	"""
	curves = patient_data['fd_smooth']
	y_true = patient_data['label']

	classifier = NearestCentroid().fit(curves, y_true)
	y_pred = classifier.predict(curves)

	scores = {
		'patient_id': patient_data['patient_id'],
		'segmentation_id': segmentation_id,
	}
	scores['precision'] = precision_score(y_true, y_pred, zero_division=0)
	scores['recall'] = recall_score(y_true, y_pred, zero_division=0)
	scores['f1'] = f1_score(y_true, y_pred, zero_division=0)
	scores['balanced_accuracy'] = balanced_accuracy_score(y_true, y_pred)
	scores['specificity'] = specificity(y_true, y_pred, zero_division=0)
	return scores



## Training

In [None]:
# Run configuration. The output file name is derived from these values so the
# saved CSV always matches the settings that produced it.
SEGMENTATION_PREFIX = 'proportionate'  # the other variant used in this notebook is 'fixed'
IS_REGISTERED = True  # True: landmark-registered curves, False: trimmed only
SEGMENTATION_IDS = range(0, 25)

OUTPUT_DIR = './show/functional_data/nearestCentroid'
registration_tag = 'registered' if IS_REGISTERED else 'unregistered'
output_file = f'{OUTPUT_DIR}/{SEGMENTATION_PREFIX}_{registration_tag}.csv'

# Create the destination before the long loop so a missing folder cannot
# discard the results at the very end.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# One metrics dict per (segmentation, patient), flattened into a single array.
dataset = np.concatenate(
	[
		to_fd(segmentation_id, SEGMENTATION_PREFIX, IS_REGISTERED).apply(
			lambda row: train_model(row, segmentation_id), axis='columns'
		)
		for segmentation_id in tqdm(SEGMENTATION_IDS)
	]
)
pd.DataFrame(list(dataset)).to_csv(output_file, index=False)