# 1. Installation of required packages

In [None]:
! pip install transformers==4.32.0 torch==1.11.0 scikit-learn==1.1.3 pytorch_lightning==1.7.2 dotmap==1.3.30 ir-measures==0.3.1 torchmetrics==0.11.4 torchtext==0.12.0 redis==5.0.0

In [None]:
! git clone https://github.com/ProjectDossier/patient-trial-matching.git
%cd patient-trial-matching
! git checkout refactoring/repo
! git submodule update --init --recursive
! pip install -e clinical-trials
! pip install -e .
! python prepare_data.py
%cd ..

In [None]:
! cp -r patient-trial-matching/trec_cds .

# 2. Download the required data

In [None]:
! wget https://owncloud.tuwien.ac.at/index.php/s/cyPIaxzXoo2Czqa/download -O trials.tar.gz   # smaller subset of data
# ! wget https://owncloud.tuwien.ac.at/index.php/s/NBsBSX3ch8RSUx1/download -O trials.tar.gz  # full dataset - requires at least 16GB of RAM

! tar -xf trials.tar.gz

In [None]:
import sys
# Display the interpreter version. NOTE(review): the pinned torch==1.11.0 presumably only
# ships wheels for older CPython releases (<= 3.10) — check this if the install fails.
sys.version

In [None]:
! wget https://owncloud.tuwien.ac.at/index.php/s/Y0T0f65EKAdUytx -O run2022

# 3. Download pretrained models and preprocessed topics

In [None]:
! wget https://owncloud.tuwien.ac.at/index.php/s/hwuQ7IvaBbFUNkd/download -O bertbase-trained.ckpt
# ! wget https://owncloud.tuwien.ac.at/index.php/s/58KQznKpFSXYCOX/download -O bluebert-trained.ckpt
# ! wget https://owncloud.tuwien.ac.at/index.php/s/gPZXdhS3j6ggrPr/download -O clinicalbert-trained.ckpt
# ! wget https://owncloud.tuwien.ac.at/index.php/s/Dh4aIXxw3mg7JpS/download -O biobert-trained.ckpt

In [None]:
! wget https://owncloud.tuwien.ac.at/index.php/s/NxmMOJcKI3etFQR/download -O topics2021.jsonl
! wget https://owncloud.tuwien.ac.at/index.php/s/5A5XlZtB6qSXCfo/download -O topics2022.jsonl

# 4. Import packages

In [None]:
import transformers

In [None]:
from transformers import AutoModel, AutoConfig, get_linear_schedule_with_warmup

In [None]:
from transformers import AdamW

In [None]:
import pytorch_lightning as pl
import yaml
from dotmap import DotMap
from pytorch_lightning.loggers import TensorBoardLogger

In [None]:
import ir_measures
import numpy as np
from ir_measures import *

In [None]:
from abc import ABC

In [None]:
import torch
from torch import nn, split

In [None]:
from trec_cds.neural.data.ClinicalTrialsDataModule import ClinicalTrialsDataModule

In [None]:
from trec_cds.neural.models.crossencoder import CrossEncoder

In [None]:
from trec_cds.neural.utils.evaluator import Evaluator

In [None]:
from trec_cds.neural.utils.loss import PairwiseHingeLoss

# 5. Load configuration

In [None]:
# Short alias -> Hugging Face model id; the id is passed as `model_name` to the data
# module and the cross-encoder below.
models = dict(
    bertbase="bert-base-uncased",
    biobert="seiya/oubiobert-base-uncased",
    bluebert="bionlp/bluebert_pubmed_uncased_L-12_H-768_A-12",
    clinicalbert="Tsubasaz/clinical-pubmed-bert-base-512",
)

In [None]:
# Which entry of `models` to run; its checkpoint must have been downloaded in section 3.
MODEL_NAME = "bertbase"

# Fail here with a readable message rather than with a bare KeyError several cells later.
if MODEL_NAME not in models:
    raise ValueError(f"MODEL_NAME must be one of {sorted(models)}, got {MODEL_NAME!r}")

In [None]:
# Names derived from MODEL_NAME. The checkpoint file name matches the wget targets in
# section 3. MODEL_ALIAS and LOGGER_NAME are not referenced by the cells below.
MODEL_CHECKPOINT_PATH = MODEL_NAME + "-trained.ckpt"
MODEL_ALIAS = LOGGER_NAME = MODEL_NAME + "_crossencoder"

In [None]:
# Input files consumed by the data module and the evaluator.
(
    PATH_TO_RUN,       # BM25 run (per its file name) used as the base run to re-rank
    PATH_TO_QRELS,     # relevance judgements used for evaluation
    PATH_TO_TRIALS,    # presumably extracted from trials.tar.gz; change path if using all trials
    PATH_TO_PATIENTS,  # downloaded in section 3
) = (
    "patient-trial-matching/data/submissions/bm25_postprocessed_2022",
    "patient-trial-matching/data/external/qrels2022.txt",
    "essir_trials_subset.jsonl",
    "topics2022.jsonl",
)

In [None]:
# Inference settings; each is passed to the data module (GPUS: GPU device indices).
BATCH_SIZE, N_SAMPLES = 16, 50              # eval_batch_size, n_test_samples
MODE, VERSION = "predict_w_labels", 2022    # data-module mode, dataset_version
FIELDS = ["criteria"]                       # passed as `fields`
QUERY_REPR = "description"                  # passed as `query_repr`
RELEVANT_LABELS = [1, 2]                    # passed as `relevant_labels`
GPUS = [0]

In [None]:
import os

# Output directory for the evaluator CSVs and the TensorBoard logs configured below.
os.makedirs(name="reports", exist_ok=True)

# 6. Create a data module

In [None]:
data_module = ClinicalTrialsDataModule(
    # encoder / tokenizer
    model_name=models[MODEL_NAME],
    # how examples are built
    mode=MODE,
    fields=FIELDS,
    query_repr=QUERY_REPR,
    relevant_labels=RELEVANT_LABELS,
    # input files
    path_to_run=PATH_TO_RUN,
    path_to_qrels=PATH_TO_QRELS,
    path_to_trials_jsonl=PATH_TO_TRIALS,
    path_to_patients=PATH_TO_PATIENTS,
    dataset_version=VERSION,
    # batching
    eval_batch_size=BATCH_SIZE,
    n_test_samples=N_SAMPLES,
)

# 7. Create the evaluator

In [None]:
# Re-ranks the base run with the model scores and writes CSV results under reports/.
evaluator = Evaluator(
    mode="predict",
    run_id=models[MODEL_NAME],
    re_rank=True,
    path_to_base_run=PATH_TO_RUN,
    qrels_file=PATH_TO_QRELS,
    output_path="reports/",
    write_csv=True,
)

# 8. Load the model from the checkpoint

In [None]:
# Remove the `position_ids` buffer from the checkpoint and save a patched copy.
# NOTE(review): presumably needed because the pinned transformers version no longer keeps
# `embeddings.position_ids` in the model state dict, so loading an older checkpoint
# would report it as an unexpected key — confirm.
# map_location="cpu" lets this step run without a GPU and avoids occupying GPU memory
# just to rewrite the file.
checkpoint = torch.load(MODEL_CHECKPOINT_PATH, map_location="cpu")

# The state dict is modified in place, so no reassignment into `checkpoint` is needed.
checkpoint["state_dict"].pop("transformer.embeddings.position_ids", None)

torch.save(checkpoint, "updated_checkpoint.ckpt")
# Free the loaded tensors; the model is loaded from the patched file in the next cell.
del checkpoint

In [None]:
# Restore the cross-encoder from the patched checkpoint, wiring in the evaluator.
model = CrossEncoder.load_from_checkpoint(
    checkpoint_path='updated_checkpoint.ckpt',
    evaluator=evaluator,
    mode="predict",
    model_name=models[MODEL_NAME],
    num_labels=2,
)

In [None]:
# TensorBoard logs go to reports/<model id>_pred_logs/; model ids containing "/"
# (e.g. biobert) produce nested directories.
logger = TensorBoardLogger(
    name=models[MODEL_NAME],
    save_dir="reports/" + models[MODEL_NAME] + "_pred_logs",
)

In [None]:
# Reuse the device indices from the configuration cell instead of a second
# hard-coded [0], so there is one place to change the GPU selection.
gpus = list(GPUS)

# 9. Create the PyTorch Lightning Trainer

In [None]:
# `Trainer(gpus=...)` is deprecated since pytorch_lightning 1.7 in favour of
# `accelerator`/`devices`. Fall back to a single CPU process when CUDA is unavailable
# so the notebook still runs on a machine without a GPU.
if torch.cuda.is_available():
    trainer = pl.Trainer(logger=logger, accelerator="gpu", devices=gpus)
else:
    trainer = pl.Trainer(logger=logger, accelerator="cpu", devices=1)

# 10. Run the prediction

In [None]:
# Inference only: collect the outputs of the model's prediction step (inspected below).
model_predictions = trainer.predict(model, data_module.predict_dataloader())

# 11. Inspect model predictions

In [None]:
# First element of the list returned by `trainer.predict` (the next cell indexes it by key).
model_predictions[0]

In [None]:
# The 'prediction' entry of the first prediction output.
model_predictions[0]['prediction']

In [None]:
# TODO: add code to inspect the model predictions

# 12. Evaluate predictions

In [None]:
# TODO: evaluate the predictions against the qrels

# 13. Compare with baseline model

In [None]:
# TODO: load and evaluate the BM25 baseline run

In [None]:
# TODO: plot the queries on which the neural re-ranker improves over BM25