In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from src.classification import ClassificationDataSet
from src.classification import SimpleDataset
from src.classification import training_loop
from src.classification import test
from src.classification import get_all_probabilities
from src.classification import optimise_thresholds

from src.settings import settings
from sklearn.model_selection import train_test_split
from typing import Optional


# Example Condition Classification Pipeline

###### This can be used to tag clinical trials with conditions where those trials lack annotation (clinicaltrials.gov only adds this section to completed trials), or to expand existing annotation.

In [None]:
# Build the dataset, split it 60/20/20 into train/validation/test, and wrap
# each split in a DataLoader.
SPLIT_SEED = 42  # shared by both splits so the partition is reproducible
BATCH_SIZE = 32

# Forward slashes resolve on Windows, Linux and macOS alike, and match how the
# checkpoint path is built further down this notebook; a hard-coded backslash
# separator only resolves on Windows.
data_set = ClassificationDataSet(
    embedding_path=f'{settings.data_dir}/core_data_embedding.jsonl',
    annotation_path=f'{settings.data_dir}/core_data_annotation.jsonl'
)

# First hold out 20% of all samples as the test split ...
train_features, test_features, train_labels, test_labels = train_test_split(
    data_set.features, data_set.labels, test_size=0.2, random_state=SPLIT_SEED
)
# ... then take 25% of the remaining 80% (= 20% of the original data) for validation.
train_features, val_features, train_labels, val_labels = train_test_split(
    train_features, train_labels, test_size=0.25, random_state=SPLIT_SEED
)

train_dataset = SimpleDataset(train_features, train_labels)
test_dataset = SimpleDataset(test_features, test_labels)
val_dataset = SimpleDataset(val_features, val_labels)

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
validation_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Create the model, loss and optimiser used by the training cell below.
# NOTE(review): MultiLabelModel is not among the imports in the first cell; unless it is
# brought into scope elsewhere, this cell raises NameError on a fresh kernel — TODO add the import.

# Seed torch's global RNG so weight initialisation and the training loader's
# shuffle order are repeatable across "Restart & Run All" (the data splits are
# already seeded with random_state=42).
torch.manual_seed(42)

# Input width is the number of feature columns; output width is the number of classes.
simple_model = MultiLabelModel(data_set.features.shape[1], data_set.num_classes)
# BCELoss expects probabilities in [0, 1] — assumes the model ends in a sigmoid; TODO confirm.
criterion = nn.BCELoss()  # Binary Cross-Entropy for multi-label classification
optimizer = optim.Adam(simple_model.parameters(), lr=0.001)

In [None]:
# Fit the model for at most 10 epochs. A patience of 3 is handed to the loop
# (presumably early stopping on the validation loader — confirm in training_loop).
training_loop(
    model=simple_model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader,
    validation_dataloader=validation_dataloader,
    num_epochs=10,
    patience=3,
)

In [None]:
# Restore the weights stored at <data_dir>/best_model.pth (presumably written
# by the training loop — confirm), switch the model to inference mode and
# evaluate it on the held-out test split.
checkpoint_path = f'{settings.data_dir}/best_model.pth'
best_state = torch.load(checkpoint_path)
simple_model.load_state_dict(best_state)
simple_model.eval()
test(simple_model, test_dataloader)

##### We can see the quality of the output by looking at some examples from the test dataset


In [None]:
# Print the model's predictions next to the ground truth for the first few test samples.
NUM_EXAMPLES = 4
# Slicing yields the first NUM_EXAMPLES feature rows and their label rows
# (relies on SimpleDataset supporting slice indexing).
subset_data_features, subset_data_labels = test_dataset[:NUM_EXAMPLES]

threshold = 0.1  # minimum predicted value for a label to be reported

simple_model.eval()  # no-op if already in eval mode; keeps this cell safe to re-run on its own
with torch.no_grad():
    for sample_data, sample_labels in zip(subset_data_features, subset_data_labels):
        predictions = simple_model(sample_data.unsqueeze(0))  # Add batch dimension
        # Drop only the batch dimension again. squeeze(0), unlike a bare
        # squeeze(), keeps the class axis even when there is a single class.
        sample_predictions = predictions.squeeze(0)

        # Indices and values of the classes whose prediction reaches the threshold
        predicted_indices = (sample_predictions >= threshold).nonzero(as_tuple=True)[0]
        predicted_values = sample_predictions[predicted_indices]
        # Map indices to label strings
        predicted_labels = [data_set.index_to_label[idx.item()] for idx in predicted_indices]

        # Indices of the true labels, mapped to label strings the same way
        ground_truth_indices = sample_labels.nonzero(as_tuple=True)[0]
        ground_truth_labels = [data_set.index_to_label[idx.item()] for idx in ground_truth_indices]

        print(f"Prediction Values {[round(x,4) for x in predicted_values.tolist()]}")
        print(f"Predictions ids: {predicted_labels}")
        print(f"Predictions terms: {[data_set.mesh_name_dict[x] for x in predicted_labels]}")

        print(f"Ground Truth ids: {ground_truth_labels}")
        print(f"Ground Truth terms: {[data_set.mesh_name_dict[x] for x in ground_truth_labels]}")

##### We can also optimise the thresholds on a per-class basis, either for all labels or for a subset, using ghostml


In [None]:
# Tune a per-class decision threshold for a chosen subset of conditions.
conditions = ['Sclerosis', 'Multiple Sclerosis']
# Condition name -> MeSH id -> class index in the label matrix
mesh_ids = [data_set.name_mesh_dict[x] for x in conditions]
condition_indexes = [data_set.label_to_index[x] for x in mesh_ids]

# train_dataloader was created with shuffle=True, so iterating it yields the
# samples in a random order that would not line up row-for-row with
# train_labels. Use an unshuffled loader over the same dataset instead
# (assumes get_all_probabilities returns rows in iteration order — confirm).
ordered_train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
train_probabilities = get_all_probabilities(model=simple_model,
                                            dataloader=ordered_train_dataloader,
                                            )

# Candidate thresholds 0.05, 0.10, ..., 0.50. linspace fixes the number of
# points, whereas arange with a float step can gain or lose the end point
# through rounding error.
thresholds = [float(round(x, 2)) for x in np.linspace(0.05, 0.5, 10)]
min_positive_count = 3
optimal_thresholds = optimise_thresholds(train_labels=train_labels,
                                         train_probabilities=train_probabilities,
                                         thresholds=thresholds,
                                         subset=condition_indexes,
                                         # pass the variable rather than a duplicated literal
                                         min_positive_count=min_positive_count
                                         )
# Report the chosen threshold for each class by its human-readable name
for index, value in optimal_thresholds.items():
    print(f'{data_set.mesh_name_dict[data_set.index_to_label[index]]}: {value}')