# Validation notebook for ResNet binary classifier


In [None]:
from torch.utils.data import DataLoader
import torch
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn

from torchvision import transforms

from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score, accuracy_score

import sys
import pandas as pd
import numpy as np
import os
import random
from matplotlib import pyplot as plt
# import seaborn as sns
sys.path.append('..')
from data_utils.dataset import BoneSlicesDatasetPrev
from training.validation_metrics import get_true_and_predicted_labels, get_predicted_labels

<br><br><br>
---
## Parameters

In [None]:
# Pick the compute device: CUDA when available, otherwise Apple MPS,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = 'cpu'

# Saved model weights that are loaded into the ResNet-18 below.
MODEL_PATH = 'training/resnet_18_6_all/saved_models/Iteration_1/model_20240520_093559_14'
# Batch size used by the evaluation DataLoader.
BATCH_SIZE = 64
# NOTE(review): NUM_WORKERS is not passed to any DataLoader in the visible
# notebook -- confirm whether it is still needed.
NUM_WORKERS = 4
# CSV file whose 'Image Name' column lists the images to evaluate.
#VALIDATION_EXAMPLES_FILE = 'validation_examples.csv'
VALIDATION_EXAMPLES_FILE = 'test.csv'
#TRAINING_EXAMPLES_FILE = 'training_examples.csv'

In [None]:
# Display the selected compute device.
DEVICE

In [None]:
# Move the working directory one level up and display it; the relative paths
# used by the later cells are resolved against this directory.
# NOTE: re-running this cell moves up one more level each time.
os.chdir("..")
os.getcwd()

<br><br><br>
---
## Model

In [None]:
# Build the architecture only. Every parameter and buffer is overwritten by
# the (strict) load_state_dict() call below, so fetching the pretrained
# ImageNet weights first would be wasted work and needs network access.
resnet = resnet18(weights=None)

# Replace the final classifier layer: 1000 ImageNet classes -> 2 classes.
resnet.fc = nn.Linear(512, 2)

# Replace the stem convolution: 3 input channels -> 1 (monochromatic image).
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Load the trained weights. map_location remaps the stored tensors onto
# DEVICE, so a checkpoint saved on a GPU can still be loaded on a machine
# that only has a CPU (or MPS).
resnet.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

# eval() switches layers such as batch norm to inference behaviour; it does
# not disable gradient tracking by itself.
resnet.eval()
resnet.to(DEVICE)


<br><br><br>
---
## Validation dataset and dataloader

In [None]:
# Names of the images to evaluate, read from the 'Image Name' column of the CSV.
validation_examples = list(pd.read_csv(VALIDATION_EXAMPLES_FILE)['Image Name'])
#training_examples = list(pd.read_csv(TRAINING_EXAMPLES_FILE)['Image Name'])

#train_ds = BoneSlicesDatasetPrev(json_config_filepath = 'data_utils/config_binary_z.json', transform=transforms)
# Dataset built from the test JSON config (no transform is passed here).
valid_ds = BoneSlicesDatasetPrev(json_config_filepath = 'data_utils/config_binary_test.json')


#train_ds.subset_by_image_name(training_examples)
# Restrict the dataset to the images listed in the CSV.
valid_ds.subset_by_image_name(validation_examples)
# Both lines print a size under the same label: the first counts metadata
# rows, the second uses the dataset's own __len__.
print(f"Size of the validation dataset: {len(valid_ds.metadata['Image Name'])}")
print(f"Size of the validation dataset: {len(valid_ds)}")
#print(f"Size of the training dataset: {len(train_ds.metadata['Image Name'])}")

#train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# No shuffling, so batches follow the dataset order.
# NOTE(review): NUM_WORKERS is defined in the parameters cell but not passed
# here -- confirm whether that is intentional.
val_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE)

<br><br><br>

# Results (binary classification)

In [None]:
import time

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed intervals; time.time() can jump when the system clock is adjusted.
start = time.perf_counter()
#true_labels, predicted_labels = get_true_and_predicted_labels(resnet, val_dl, DEVICE)
# Only predictions are collected here; the variant that also returns the true
# labels is commented out above.
predicted_labels = get_predicted_labels(resnet, val_dl, DEVICE)
end = time.perf_counter()
print(f"Execution time: {end - start}")

In [None]:
#true_labels

In [None]:
# Display the predictions returned by get_predicted_labels().
predicted_labels

In [None]:
# Count the positions where the prediction differs from the true label.
# NOTE(review): true_labels is not assigned anywhere in the visible notebook
# (the get_true_and_predicted_labels call that returns it is commented out),
# so this cell only runs once that call is restored.
n_wrongly_classifed = (predicted_labels != true_labels).sum().item()
print(f"Number of wrongly classified slices: {n_wrongly_classifed}")

In [None]:
# Plot every misclassified slice in a grid, five panels per row.
# NOTE(review): true_labels is not assigned anywhere in the visible notebook
# (the call that returns it is commented out); n_wrongly_classifed comes from
# the previous cell.
indices = (predicted_labels != true_labels).nonzero()

n_cols = 5
# Ceiling division, but always at least one row: subplots() rejects nrows=0.
n_rows = max(1, -(-n_wrongly_classifed // n_cols))
# squeeze=False keeps `ax` two-dimensional even when there is a single row,
# so the ax[row, col] indexing below also works for fewer than six panels.
fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(40, 40), squeeze=False)
for i, idx in enumerate(indices):
    idx = idx.item()
    data, label = valid_ds[idx]
    panel = ax[i // n_cols, i % n_cols]
    panel.imshow(data.permute(2, 1, 0), cmap='gray')
    # Fetch the metadata row once instead of once per field.
    row = valid_ds.metadata.iloc[idx]
    image_name = row['Image Name']
    slice_index = row['Slice Index']
    growth_plate_index = row['Growth Plate Index']
    panel.title.set_text(f'image: {image_name}, slice: {slice_index}, growth plate {growth_plate_index}')
    

In [None]:
# Per-class precision / recall / F1 report, printed with four decimals.
# NOTE(review): true_labels is not assigned anywhere in the visible notebook
# (the call that returns it is commented out), so this cell only runs once
# that call is restored.
report = classification_report(true_labels, predicted_labels, digits=4)
print(report)

<br><br><br>

# Competition score

In [None]:
# Work on a copy: assigning a column to valid_ds.metadata itself would
# silently add 'predicted_labels' to the dataset's own metadata table.
result = valid_ds.metadata.copy()
# One prediction per metadata row, in dataset order.
result['predicted_labels'] = predicted_labels
#result['true_labels'] = true_labels
result

In [None]:
# Julka
predicted = []
for img in result.groupby(['Image Name']):
    predicted.append(np.asarray(img[1]['predicted_labels']))
    #growth_plate_index = img[1]['Growth Plate Index']

In [None]:
import cv2
# Pick the prediction vector of a single image; the index 8 is hard-coded.
pred = predicted[8]
# Dilate it with a 5x1 all-ones uint8 kernel. pred_new is not used anywhere
# else in the visible notebook.
# NOTE(review): confirm that pred has a dtype OpenCV accepts.
pred_new = cv2.dilate(pred, np.ones((5, 1), np.uint8))

In [None]:
# Display the prediction vector selected in the previous cell.
pred

In [None]:
from scipy.stats import norm
def _calculate_score(pred_slice_num, gt_slice_num):
    """Returns the survival function a single-sided normal distribution with stddev=3."""
    diff = abs(pred_slice_num - gt_slice_num)
    return 2 * norm.sf(diff, 0, 3)

In [None]:
import cv2

In [None]:
# For every image, find the index of the first slice predicted as class 0,
# both on the raw predictions and after a morphological closing.
#scores = []
#prediction_pair = []
all_pred = []
all_pred_filtr = []

# Vertical 5x1 structuring element (uint8, matching the dilation cell above).
closing_kernel = np.ones((5, 1), np.uint8)

for _, group in result.groupby(['Image Name']):
    # Cast to uint8 so OpenCV gets a pixel type it supports.
    # Assumes the labels are small non-negative integers (0/1) -- TODO confirm.
    predicted = np.asarray(group['predicted_labels']).astype(np.uint8)

    # Morphological closing of the label sequence. OpenCV treats the 1-D
    # array as an N x 1 image and returns a 2-D result, so flatten it back;
    # otherwise the filtered index below would be a length-1 array while the
    # unfiltered one is a scalar.
    predicted_filter = cv2.morphologyEx(predicted, cv2.MORPH_CLOSE, closing_kernel).ravel()

    # argmax of a boolean mask is the position of its first True value.
    # NOTE: it is also 0 when the mask has no True at all, so an image with
    # no slice predicted as 0 is indistinguishable from "first slice is 0".
    predicted_filter_index = int((predicted_filter == 0).argmax())
    predicted_index = int((predicted == 0).argmax())

    all_pred.append(predicted_index)
    all_pred_filtr.append(predicted_filter_index)

    # Scoring against the ground truth is disabled in this run:
    # true_index = group['Growth Plate Index'].iloc[0]
    # scores.append(_calculate_score(predicted_index, true_index))
    # prediction_pair.append((predicted_index, true_index))
    # print(f"{predicted_index}  |  {true_index}")

# scores = np.array(scores)
# scores

In [None]:
# Display the per-image indices computed from the raw predictions.
all_pred

In [None]:
# Display the per-image indices computed after the morphological closing.
all_pred_filtr

In [None]:
# Scratch arithmetic; the value is only displayed and not used elsewhere in
# the visible notebook.
7*24

In [None]:
# NOTE(review): prediction_pair is only built in commented-out code in the
# visible notebook, so this cell only runs once that code is restored.
prediction_pair

In [None]:
# Mean score over all images.
# NOTE(review): scores is only built in commented-out code in the visible
# notebook, so this cell only runs once that code is restored.
scores.mean()

In [None]:
# Same display as a few cells above.
# NOTE(review): prediction_pair is only built in commented-out code in the
# visible notebook, so this cell only runs once that code is restored.
prediction_pair

In [None]:
from sklearn.metrics import mean_absolute_error

# Split the pairs into two parallel arrays: first elements and second elements.
# NOTE(review): prediction_pair is only built in commented-out code in the
# visible notebook; presumably each pair is (predicted_index, true_index).
columns = list(zip(*prediction_pair))
y_hat = np.array(columns[0])
y = np.array(columns[1])

mae = mean_absolute_error(y_hat, y)
print(f"Mean absolute error: {mae}")

# Histogram of the signed differences between the two arrays.
errors = y_hat - y
plt.hist(errors, bins=15)
plt.title("Distribution of the results")
plt.xlabel("predicted_index - true_index")
plt.ylabel("Count")
plt.show()