# HistoMIL Multiple Instance Learning Notebook

This Jupyter notebook demonstrates how to train a multiple instance learning (MIL) model on histopathology whole-slide images using HistoMIL. It covers three main steps: model definition, parameter definition, and model initialisation and training.

## Getting Started

Before proceeding with this notebook, please make sure that you have followed the setup instructions provided in the project's README file. This includes creating a conda environment and installing the required dependencies.

In [2]:
%load_ext autoreload
%autoreload 2



In [3]:
#--------------------------> base env setting
# avoid pandas warning
import pandas as pd
pd.options.mode.chained_assignment = None  # default='warn'
# avoid multiprocessing problem
import torch
import torch.nn as nn
torch.multiprocessing.set_sharing_strategy('file_system')
#--------------------------> logging setup
import logging
logging.basicConfig(
    level=logging.INFO,
    format='|%(asctime)s.%(msecs)03d| [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d|%H:%M:%S',
    handlers=[
        logging.StreamHandler()
    ]
)



# Change the working directory so HistoMIL can be imported, since it is not installed via pip

In [4]:
import os
os.getcwd()
os.chdir('/Users/awxlong/Desktop/my-studies/hpc_exps/') # path to parent dir of HistoMIL
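
Changing the working directory affects every relative path used later in the notebook. An alternative, sketched below under the assumption that the `HistoMIL` folder sits directly inside the parent directory, is to add that directory to `sys.path` instead. The helper name `add_to_sys_path` and the placeholder path are illustrative and not part of HistoMIL. Whether HistoMIL itself relies on the working directory is not confirmed here, so keep `os.chdir` if relative paths stop resolving.

```python
import sys
from pathlib import Path


def add_to_sys_path(parent_dir):
    """Prepend parent_dir to sys.path once so packages inside it can be imported.

    Hypothetical helper for illustration; it is not part of HistoMIL.
    """
    resolved = str(Path(parent_dir).expanduser().resolve())
    if resolved not in sys.path:
        sys.path.insert(0, resolved)
    return resolved


# Placeholder path only; point this at the directory that contains HistoMIL/.
histomil_parent = add_to_sys_path("/path/to/parent_of_HistoMIL")
```

Calling the helper a second time with the same path leaves `sys.path` unchanged, so the cell can be re-run safely.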

In [5]:
from HistoMIL.MODEL.Image.MIL.TransMIL.paras import TransMILParas
from HistoMIL.MODEL.Image.MIL.DSMIL.paras import DSMILParas
from HistoMIL.EXP.paras.env import EnvParas

import pickle
import wandb
from dotenv import load_dotenv




## Model Definition

This section defines the MIL model architectures. It creates a parameter object for each model (TransMIL and DSMIL) and selects one by name; the next section passes the selected object to the trainer.

In [6]:
#--------------------------> model setting

# for transmil
model_para_transmil = TransMILParas()
model_para_transmil.feature_size=512
model_para_transmil.n_classes=2
model_para_transmil.norm_layer=nn.LayerNorm
# for dsmil
model_para_dsmil = DSMILParas()
model_para_dsmil.feature_dim = 224 # embedding size produced by the feature extractor (here prov-gigapath); verify it matches your stored features
model_para_dsmil.p_class = 2
model_para_dsmil.b_class = 2
model_para_dsmil.dropout_r = 0.5

model_name = "TransMIL"  # or "DSMIL"; must be a key of model_para_settings below

model_para_settings = {"TransMIL":model_para_transmil,
                       "DSMIL":model_para_dsmil} 

In [7]:
model_para_transmil.encoder_name # if you already ran preprocessing and stored the feature vectors, ENSURE this is set as 'pre-calculated'

'pre-calculated'
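
Because `model_name` is later used as a key into `model_para_settings`, a typo or an unregistered model only surfaces as a bare `KeyError`. A minimal sketch of a lookup with a clearer message follows. `select_model_paras` is a hypothetical helper, and plain strings stand in for the parameter objects.

```python
def select_model_paras(model_name, settings):
    """Return the parameter object registered under model_name.

    Hypothetical helper for illustration; it is not part of HistoMIL.
    """
    if model_name not in settings:
        available = ", ".join(sorted(settings))
        raise KeyError(f"Unknown model '{model_name}'. Available: {available}")
    return settings[model_name]


# Stand-in values; in the notebook these would be the TransMILParas and DSMILParas objects.
demo_settings = {"TransMIL": "transmil-paras", "DSMIL": "dsmil-paras"}
chosen = select_model_paras("DSMIL", demo_settings)
```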

## Parameter Definition

This section defines the parameters used in the MIL training process: experiment logging, the patient cohort and task, the dataset, and the trainer. You can modify these parameters to customize the training process for your specific needs.

In [13]:
gene2k_env = EnvParas()
precomputed = True

#--------------------------> task setting
task_name = "g0_arrest" # Coincides with column name of target label

#--------------------------> parameters
# logging information
gene2k_env.exp_name = f"{model_name}_{task_name}"
gene2k_env.project = "g0_arrest" 
gene2k_env.entity = "cell-x"    # make sure it's initialized to an existing wandb entity

#----------------> cohort
gene2k_env.cohort_para.localcohort_name = "CRC" # name of patient cohort 
gene2k_env.cohort_para.task_name = task_name
gene2k_env.cohort_para.cohort_file = f'local_cohort_{gene2k_env.cohort_para.localcohort_name}.csv' # e.g. local_cohort_CRC.csv, this is created automatically, and contains folder, filename, slide_nb, tissue_nb, etc. 
gene2k_env.cohort_para.task_file = f'{gene2k_env.cohort_para.localcohort_name}_{gene2k_env.cohort_para.task_name}.csv' # e.g. CRC_g0_arrest.csv, which has PatientID matched with g0_arrest labels. This is SUPPLIED by the user and assumed to be stored in the EXP/Data/ directory
gene2k_env.cohort_para.pid_name = "PatientID"
gene2k_env.cohort_para.targets = ['g0_arrest']  # e.g. "g0_arrest"  # the column name of interest; supply as a list
gene2k_env.cohort_para.targets_idx = 0
gene2k_env.cohort_para.label_dict = {'negative':0,'positive':1}  # SINGLE quotations for the keys, converts strings objects to binary values
#gene2k_env.cohort_para.update_localcohort = True
#----------------> pre-processing
#----------------> dataset
gene2k_env.dataset_para.dataset_name = f"CRC_{task_name}"
gene2k_env.dataset_para.concepts = ["slide","patch","feature"] # default ['slide', 'tissue', 'patch', 'feature'] in this ORDER
gene2k_env.dataset_para.split_ratio = [0.8,0.2]                # dataset split ratio which must sum to one, and training ratio is greater than testing
#----------------> model
gene2k_env.trainer_para.model_name = model_name
gene2k_env.trainer_para.model_para = model_para_settings[model_name]
#----------------> trainer or analyzer
if precomputed:
    gene2k_env.trainer_para.use_pre_calculated = True ### FOR LOADING COMPUTED FEATURES
else:
    gene2k_env.trainer_para.backbone_name = "resnet18"
    gene2k_env.trainer_para.additional_pl_paras.update({"accumulate_grad_batches":8})
    gene2k_env.trainer_para.label_format = "int"#"one_hot" 


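#k_fold = None
#--------------------------> init machine and person

# NOTE: this assignment runs unconditionally, so it overrides the "resnet18" set in the else branch above
gene2k_env.trainer_para.backbone_name = "prov-gigapath"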
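
The comments in the cell above state three constraints: the task file must contain the patient ID and target columns, every label must be a key of `label_dict`, and `split_ratio` must sum to one with the larger share used for training. A minimal sketch of checking these before starting an experiment is shown below. The helper names and the two demo rows are made up for illustration, and HistoMIL may perform its own validation internally.

```python
import csv
import io
import math


def check_task_rows(rows, pid_name, target, label_dict):
    """Map string labels to integers and fail early on missing columns or labels."""
    encoded = {}
    for row in rows:
        if pid_name not in row or target not in row:
            raise ValueError(f"Task file needs columns '{pid_name}' and '{target}'")
        label = row[target]
        if label not in label_dict:
            raise ValueError(f"Label '{label}' is not a key of label_dict")
        encoded[row[pid_name]] = label_dict[label]
    return encoded


def check_split_ratio(split_ratio):
    """Require ratios that sum to one with a training share above the test share."""
    train, test = split_ratio
    if not math.isclose(train + test, 1.0):
        raise ValueError("split_ratio must sum to one")
    if train <= test:
        raise ValueError("training ratio should exceed the testing ratio")
    return True


# Made-up rows standing in for a user-supplied file such as CRC_g0_arrest.csv.
demo_csv = io.StringIO("PatientID,g0_arrest\nP001,positive\nP002,negative\n")
encoded = check_task_rows(csv.DictReader(demo_csv), "PatientID", "g0_arrest",
                          {"negative": 0, "positive": 1})
split_ok = check_split_ratio([0.8, 0.2])
```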

In [14]:
machine_cohort_loc = "/Users/awxlong/Desktop/my-studies/hpc_exps/User/CRC_machine_config.pkl"
with open(machine_cohort_loc, "rb") as f:   # Unpickling
    [data_locs,exp_locs,machine,user] = pickle.load(f)
gene2k_env.data_locs = data_locs
gene2k_env.exp_locs = exp_locs
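
The cell above expects the pickle to hold a four-element list in the order `[data_locs, exp_locs, machine, user]`. The sketch below round-trips a stand-in list of that shape through a temporary file and checks its length before unpacking. The dictionaries are placeholders rather than HistoMIL's real objects, and `load_machine_config` is a hypothetical helper. Only unpickle files you created yourself, since `pickle.load` can execute arbitrary code.

```python
import os
import pickle
import tempfile


def load_machine_config(path):
    """Load a pickled [data_locs, exp_locs, machine, user] list and validate its length."""
    with open(path, "rb") as f:
        loaded = pickle.load(f)
    if not isinstance(loaded, (list, tuple)) or len(loaded) != 4:
        raise ValueError("Expected a 4-item list: data_locs, exp_locs, machine, user")
    return tuple(loaded)


# Placeholder objects; they only mimic the four-item layout of the real config file.
demo = [{"root": "data/"}, {"root": "exp/"}, {"name": "machine"}, {"name": "user"}]
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo_machine_config.pkl")
    with open(demo_path, "wb") as f:
        pickle.dump(demo, f)
    data_locs, exp_locs, machine, user = load_machine_config(demo_path)
```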

## Initialize wandb (once is enough)

In [15]:
# # api_dir = 'path/to API.env/'                    # We assume you store your API keys in a .env file
# # load_dotenv(dotenv_path=f"{api_dir}API.env")
# # user.wandb_api_key = os.getenv("WANDB_API_KEY") # We assume your wandb API key is named WANDB_API_KEY in the API.env file
# user.wandb_api_key                                # should already hold the API key if machine_config.ipynb ran without issues

# wandb.setup(settings=wandb.Settings(
#     _disable_stats=True,
#     disable_git=True,
#     api_key=user.wandb_api_key  
# ))

# wandb.init(project=gene2k_env.project, 
#            entity=gene2k_env.entity)
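
If the key is read from the environment as in the commented lines above, a missing variable makes `os.getenv` return `None` silently. A small sketch of failing early with a descriptive error follows. `get_wandb_api_key` is a hypothetical helper, and the demo passes an explicit mapping so no real credential is involved.

```python
import os


def get_wandb_api_key(env=None, var_name="WANDB_API_KEY"):
    """Return the key from the environment mapping or raise a descriptive error."""
    env = os.environ if env is None else env
    key = env.get(var_name, "").strip()
    if not key:
        raise RuntimeError(f"{var_name} is not set; load your .env file first")
    return key


# Demo with an explicit mapping so no real credential is needed.
demo_key = get_wandb_api_key({"WANDB_API_KEY": "dummy-key"})
```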

## Model initialisation and Training

This final section initialises the experiment and trains the model. It builds the cohort and dataloaders from the parameters defined above, then fits the selected MIL model.

After training is complete, the notebook will also demonstrate how to evaluate the trained model on a validation set and make predictions on new whole-slide images.

In [18]:
logging.info("setup experiment")
from HistoMIL.EXP.workspace.experiment import Experiment
exp = Experiment(env_paras=gene2k_env)
exp.setup_machine(machine=machine,user=user)
logging.info("setup data")
exp.init_cohort()
logging.info("setup trainer..")
exp.setup_experiment(main_data_source="slide",
                    need_train=True)

exp.exp_worker.train()

|2024-06-21|16:42:34.465| [INFO] setup experiment
|2024-06-21|16:42:34.466| [INFO] Exp:: Start Environment TransMIL_g0_arrest
|2024-06-21|16:42:34.466| [INFO] Exp:: Set up machine
|2024-06-21|16:42:34.466| [INFO] setup data


|2024-06-21|16:42:34.467| [INFO] Exp:: Initialise slide-based data cohort
|2024-06-21|16:42:34.467| [INFO] Cohort::Set up local cohort for slides at /Users/awxlong/Desktop/my-studies/temp_data/CRC/TCGA-CRC/
|2024-06-21|16:42:34.469| [INFO] Cohort::Set up task cohort for file local_cohort_CRC.csv
|2024-06-21|16:42:34.470| [INFO] Cohort::Build task cohort use local_cohort_CRC.csv
|2024-06-21|16:42:34.474| [INFO] Cohort::Done and task cohort saved as /Users/awxlong/Desktop/my-studies/hpc_exps/Data/Task_g0_arrest.csv
|2024-06-21|16:42:34.474| [INFO] setup trainer..
|2024-06-21|16:42:34.475| [INFO]  Cohort::Show Task stat with {'negative': 0, 'positive': 1}:
|2024-06-21|16:42:34.475| [INFO]  Cohort::Category: positive include 2 slides,
|2024-06-21|16:42:34.475| [INFO]                include 31036  patch, 
|2024-06-21|16:42:34.475| [INFO]  Cohort::Category: negative include 2 slides,
|2024-06-21|16:42:34.476| [INFO]                include 32986  patch, 
|2024-06-21|16:42:34.476| [INFO] Cohor

|2024-06-21|16:42:38.833| [INFO] Trainer:: Best model will be saved at /Users/awxlong/Desktop/my-studies/hpc_exps/SavedModels/ as TransMIL_g0_arrest_{epoch:02d}-{auroc:.2f}
|2024-06-21|16:42:38.877| [INFO] GPU available: True (mps), used: True
|2024-06-21|16:42:38.878| [INFO] TPU available: False, using: 0 TPU cores
|2024-06-21|16:42:38.879| [INFO] IPU available: False, using: 0 IPUs
|2024-06-21|16:42:38.880| [INFO] HPU available: False, using: 0 HPUs
|2024-06-21|16:42:38.882| [INFO] Trainer:: Start training....


> /Users/awxlong/Desktop/my-studies/hpc_exps/HistoMIL/EXP/trainer/base.py(97)train()
     95         valloader = self.data_pack["testloader"]
     96         pdb.set_trace()
---> 97         self.trainer.fit(model=self.pl_model, 
     98                 train_dataloaders=trainloader,
     99                 val_dataloaders=valloader)

[tensor([[[1.4635e-02, 0.0000e+00, 1.3146e-01,  ..., 2.2182e-01,
          3.0558e-02, 6.4946e-01],
         [1.6679e-01, 4.2147e-03, 1.2958e-02,  ..., 5.9249e-01,
          5.1140e-04, 5.9399e-01],
         [3.8960e-01, 0.0000e+00, 3.0880e-02,  ..., 7.9045e-01,
          2.6434e-02, 1.0822e+00],
   