# HistoMIL Multiple Instance Learning Notebook

This Jupyter notebook demonstrates how to train a multiple instance learning (MIL) model on histopathology whole-slide images with HistoMIL. It is divided into three main sections: model definition, parameter definition (experiment, cohort, dataset, and trainer settings), and model initialisation and training.

## Getting Started

Before proceeding with this notebook, please make sure that you have followed the setup instructions provided in the project's README file. This includes creating a conda environment and installing the required dependencies.

In [None]:
#--------------------------> base env setting
# silence pandas' SettingWithCopyWarning
import pandas as pd
pd.options.mode.chained_assignment = None  # default='warn'
# share tensors between worker processes via the file system (avoids "too many open files" errors)
import torch
import torch.nn as nn
torch.multiprocessing.set_sharing_strategy('file_system')
#--------------------------> logging setup
import logging
logging.basicConfig(
    level=logging.INFO,
    format='|%(asctime)s.%(msecs)03d| [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d|%H:%M:%S',
    handlers=[
        logging.StreamHandler()
    ]
)

## Change the working directory

HistoMIL is not installed with pip, so change the working directory to the folder that contains `HistoMIL` before importing it.

In [None]:
import os
print(os.getcwd())                              # where the notebook is currently running
os.chdir('path/to/parent-dir of HistoMIL')      # the folder that CONTAINS the HistoMIL package
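
If you would rather not change the working directory, a common alternative is to put the parent directory of `HistoMIL` on `sys.path`. The sketch below uses only the standard library; the helper names `add_to_import_path` and `can_import` are hypothetical, and this is not an approach documented by HistoMIL itself. HistoMIL may also resolve relative paths against the working directory, in which case `os.chdir` is still needed.

```python
import importlib
import importlib.util
import os
import sys


def add_to_import_path(parent_dir: str) -> None:
    """Make packages inside `parent_dir` importable without changing the working directory."""
    parent_dir = os.path.abspath(parent_dir)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)  # searched before installed packages
    importlib.invalidate_caches()      # pick up directories created after start-up


def can_import(package_name: str) -> bool:
    """Return True if `package_name` can be found on the current import path."""
    return importlib.util.find_spec(package_name) is not None


# Usage (replace the placeholder with your own directory):
# add_to_import_path('path/to/parent-dir of HistoMIL')
# assert can_import('HistoMIL')
```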

In [None]:
from HistoMIL.MODEL.Image.MIL.TransMIL.paras import TransMILParas
from HistoMIL.MODEL.Image.MIL.DSMIL.paras import DSMILParas
from HistoMIL.EXP.paras.env import EnvParas

import pickle
import wandb
from dotenv import load_dotenv


## Model Definition

This section configures the MIL model architectures. Each supported model has its own parameter class (`TransMILParas` for TransMIL, `DSMILParas` for DSMIL); set the fields below and choose the model to train with `model_name`.

In [None]:
#--------------------------> model setting

# for transmil
model_para_transmil = TransMILParas()
model_para_transmil.feature_size=512
model_para_transmil.n_classes=2
model_para_transmil.norm_layer=nn.LayerNorm
# for dsmil
model_para_dsmil = DSMILParas()
model_para_dsmil.feature_dim = 512  # size of the resnet18 patch features
model_para_dsmil.p_class = 2
model_para_dsmil.b_class = 2
model_para_dsmil.dropout_r = 0.5

model_name = "TransMIL"  # or "DSMIL"; must be a key of model_para_settings

model_para_settings = {"TransMIL":model_para_transmil,
                       "DSMIL":model_para_dsmil} 

In [None]:
# If you already ran preprocessing and stored the feature vectors, this must be 'pre-calculated'
print(model_para_transmil.encoder_name)
# model_para_transmil.encoder_name = 'pre-calculated'   # uncomment to set it explicitly
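
Whatever the architecture, an MIL model treats each slide as a bag of patch feature vectors (512-dimensional here, matching `feature_size` and `feature_dim` above) and predicts a single slide-level label. The pure-Python sketch below illustrates one common aggregation, attention pooling in the spirit of ABMIL, with a plain dot-product score standing in for ABMIL's small tanh network and fixed weights standing in for learned ones. It is not how HistoMIL implements TransMIL (transformer self-attention over patches) or DSMIL (a dual-stream aggregator), and `attention_pool` and `classify` are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple


def softmax(scores: Sequence[float]) -> List[float]:
    m = max(scores)                                   # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(bag: Sequence[Sequence[float]],
                   v: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Pool a bag of N instance (patch) feature vectors of length D into one slide vector.

    `v` is an attention vector of length D; a real model learns it, here it is fixed.
    Returns (pooled vector of length D, attention weights of length N).
    """
    scores = [sum(f * w for f, w in zip(x, v)) for x in bag]   # one score per patch
    weights = softmax(scores)                                    # non-negative, sum to 1
    dim = len(bag[0])
    pooled = [sum(a * x[d] for a, x in zip(weights, bag)) for d in range(dim)]
    return pooled, weights


def classify(pooled: Sequence[float], W: Sequence[Sequence[float]]) -> List[float]:
    """Linear slide-level classifier: one weight row per class, returns class probabilities."""
    logits = [sum(p * w for p, w in zip(pooled, row)) for row in W]
    return softmax(logits)
```

With `n_classes=2`, `W` would have two rows of 512 weights each, and the returned probabilities describe the whole slide rather than individual patches. The attention weights also indicate which patches drove the prediction.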

## Parameter Definition

This section defines the experiment parameters stored in `EnvParas`: wandb logging (project and entity), the cohort and task files with their label settings, the dataset concepts and train/test split, and trainer settings such as the backbone, gradient accumulation, and label format. Modify these parameters to fit your own data and task.
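
The cohort settings in the next cell rely on a user-supplied task CSV (stored in `EXP/Data/`) with a patient ID column (`pid_name`) and a label column (`targets`), plus `label_dict`, a *string* describing how label values map to integer classes. Before running the experiment it can be worth checking that the two agree. The sketch below is a hypothetical helper, not part of HistoMIL; it assumes the string can be parsed with `ast.literal_eval` (HistoMIL's own parsing is not shown here) and that labels in the CSV exactly match the dict keys.

```python
import ast
import csv
import io
from typing import Dict, List


def parse_label_dict(label_dict_str: str) -> Dict[str, int]:
    """Turn a string such as "{'HRD':0,'HRP':1}" into a dict without calling eval()."""
    parsed = ast.literal_eval(label_dict_str)
    if not isinstance(parsed, dict):
        raise ValueError("label_dict must be the string form of a dict")
    return parsed


def check_task_file(csv_text: str, pid_name: str, target: str,
                    label_map: Dict[str, int]) -> List[int]:
    """Validate a task CSV and return the integer class of each row.

    Raises ValueError if a required column is missing or a label is not in label_map.
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    missing = {pid_name, target} - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    classes = []
    for row in reader:
        label = row[target]
        if label not in label_map:
            raise ValueError(f"patient {row[pid_name]}: unknown label {label!r}")
        classes.append(label_map[label])
    return classes
```

To use it, read your task file as text and pass it together with the `pid_name`, `targets`, and `label_dict` values you set in the next cell.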

In [None]:
gene2k_env = EnvParas()


#--------------------------> task setting
task_name = "example_mil"

#--------------------------> parameters
# logging information
gene2k_env.exp_name = f"{model_name}_{task_name}"
gene2k_env.project = "g0_arrest"
gene2k_env.entity = "cell-x"    # must be an existing wandb entity (your username or team name)

#----------------> cohort
gene2k_env.cohort_para.localcohort_name = "BRCA" 
gene2k_env.cohort_para.task_name = task_name
gene2k_env.cohort_para.cohort_file = f'local_cohort_{gene2k_env.cohort_para.localcohort_name}.csv' # e.g. local_cohort_BRCA.csv; created automatically and lists folder, filename, slide_nb, tissue_nb, etc.
gene2k_env.cohort_para.task_file = f'{gene2k_env.cohort_para.localcohort_name}_{gene2k_env.cohort_para.task_name}.csv' # here BRCA_example_mil.csv; maps each Patient_ID to its label, is SUPPLIED by the user, and is expected in the EXP/Data/ directory
gene2k_env.cohort_para.pid_name = "Patient_ID"   # name of the patient ID column in the task file
gene2k_env.cohort_para.targets = 'name of target_label column'  # the label column of interest, e.g. "g0_arrest"
gene2k_env.cohort_para.targets_idx = 0
gene2k_env.cohort_para.label_dict = "{'HRD':0,'HRP':1}"  # string form of a dict mapping each label value to an integer class; use SINGLE quotes for the keys
# gene2k_env.cohort_para.update_localcohort = True   # uncomment to update the local cohort file
#----------------> pre-processing
#----------------> dataset
gene2k_env.dataset_para.dataset_name = f"BRCA_{task_name}"
gene2k_env.dataset_para.concepts = ["slide","patch","feature"] # default ['slide', 'tissue', 'patch', 'feature'] in this ORDER
gene2k_env.dataset_para.split_ratio = [0.8,0.2]                # train/test split; must sum to 1, with the training ratio larger than the testing ratio
#----------------> model
gene2k_env.trainer_para.model_name = model_name
gene2k_env.trainer_para.model_para = model_para_settings[model_name]
#----------------> trainer or analyzer
gene2k_env.trainer_para.backbone_name = "resnet18"
gene2k_env.trainer_para.additional_pl_paras.update({"accumulate_grad_batches":8})
gene2k_env.trainer_para.label_format = "int"  # or "one_hot"
gene2k_env.trainer_para.use_pre_calculated = True  # load the pre-computed feature vectors

#k_fold = None
#--------------------------> init machine and person
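
Two of the settings above have constraints stated in their comments: `concepts` must follow the default order `['slide', 'tissue', 'patch', 'feature']`, and `split_ratio` must sum to one with the training share larger than the testing share. The hypothetical helpers below check both, and `split_patients` shows what a 0.8/0.2 split means for a list of patients. Splitting by patient rather than by slide or patch keeps all of a patient's data on one side of the split; HistoMIL's actual splitting logic (seeding, stratification) may differ from this sketch.

```python
import math
import random
from typing import List, Sequence, Tuple

DEFAULT_CONCEPTS = ["slide", "tissue", "patch", "feature"]   # default order given in the comment above


def check_concepts(concepts: Sequence[str]) -> bool:
    """True if every concept is a known default, none repeats, and the default order is kept."""
    if any(c not in DEFAULT_CONCEPTS for c in concepts):
        return False
    positions = [DEFAULT_CONCEPTS.index(c) for c in concepts]
    return len(set(positions)) == len(positions) and positions == sorted(positions)


def check_split_ratio(ratio: Sequence[float]) -> bool:
    """True if there are at least two ratios, they sum to 1, and the first (training) one is largest."""
    return (len(ratio) >= 2
            and math.isclose(sum(ratio), 1.0)
            and ratio[0] > max(ratio[1:]))


def split_patients(patient_ids: Sequence[str], train_ratio: float,
                   seed: int = 0) -> Tuple[List[str], List[str]]:
    """Shuffle patient IDs reproducibly and cut them into train and test lists."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * train_ratio)
    return ids[:n_train], ids[n_train:]
```

With 10 patients and `split_ratio = [0.8, 0.2]`, `split_patients(ids, 0.8)` puts 8 patients in the training set and 2 in the test set.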



In [None]:
machine_cohort_loc = "Path/to/BRCA_machine_config.pkl"   # created by the machine_config.ipynb notebook
with open(machine_cohort_loc, "rb") as f:   # unpickle the saved locations, machine, and user settings
    [data_locs,exp_locs,machine,user] = pickle.load(f)
gene2k_env.data_locs = data_locs
gene2k_env.exp_locs = exp_locs
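
The configuration pickle holds a four-element list `[data_locs, exp_locs, machine, user]`. A slightly more defensive loader, sketched below with a hypothetical `load_machine_config` helper, gives a clearer error when the file is missing or has an unexpected shape. Because the pickled objects are HistoMIL classes, HistoMIL must be importable (working directory set as above) before loading, and since unpickling can run arbitrary code, only load files you created yourself.

```python
import pickle
from pathlib import Path
from typing import Any, Tuple


def load_machine_config(path: str) -> Tuple[Any, Any, Any, Any]:
    """Load the [data_locs, exp_locs, machine, user] list saved by machine_config.ipynb.

    Only unpickle trusted files: pickle can execute arbitrary code while loading.
    """
    config_path = Path(path)
    if not config_path.is_file():
        raise FileNotFoundError(f"{config_path} not found; run machine_config.ipynb first")
    with config_path.open("rb") as f:
        contents = pickle.load(f)
    if len(contents) != 4:
        raise ValueError(f"expected 4 objects, found {len(contents)}")
    data_locs, exp_locs, machine, user = contents
    return data_locs, exp_locs, machine, user


# Usage:
# data_locs, exp_locs, machine, user = load_machine_config(machine_cohort_loc)
```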

## Initialize wandb (once is enough)

In [None]:
# api_dir = 'path/to API.env/'                    # we assume your API keys are stored in a .env file
# load_dotenv(dotenv_path=f"{api_dir}API.env")
# user.wandb_api_key = os.getenv("WANDB_API_KEY") # we assume the key is named WANDB_API_KEY in API.env
assert user.wandb_api_key, "no wandb API key found; re-run machine_config.ipynb or load it from API.env above"

wandb.login(key=user.wandb_api_key)               # wandb.init() does not take an API key; authenticate first
wandb.init(project=gene2k_env.project,
           entity=gene2k_env.entity)
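
The commented-out lines above read the key from `API.env` with python-dotenv's `load_dotenv`. If python-dotenv is not installed, a minimal standard-library reader such as the hypothetical `read_env_file` below covers simple `KEY=VALUE` files; unlike python-dotenv, it does not handle `export` prefixes, multi-line values, or variable expansion.

```python
from typing import Dict


def read_env_file(path: str) -> Dict[str, str]:
    """Parse simple KEY=VALUE lines; blank lines, '#' comments, and lines without '=' are skipped."""
    values = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)             # split on the first '=' only
            values[key.strip()] = value.strip().strip("'\"")  # drop surrounding quote characters
    return values


# Usage:
# user.wandb_api_key = read_env_file(f"{api_dir}API.env")["WANDB_API_KEY"]
```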

## Model initialisation and Training

This final section builds the experiment from the parameters defined above: it sets up the machine and user, initialises the cohort, prepares the trainer and data, and trains the selected MIL model.

After training is complete, the notebook will also demonstrate how to evaluate the trained model on a validation set and make predictions on new whole-slide images.

In [None]:
logging.info("setup experiment")
from HistoMIL.EXP.workspace.experiment import Experiment
exp = Experiment(env_paras=gene2k_env)
exp.setup_machine(machine=machine,user=user)
logging.info("setup data")
exp.init_cohort()
logging.info("setup trainer..")
exp.setup_experiment(main_data_source="slide",
                    need_train=True)

exp.exp_worker.train()
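
The trainer settings pass `accumulate_grad_batches=8` to PyTorch Lightning, which runs 8 forward/backward passes and adds up their gradients before each optimizer step. The toy sketch below (a one-parameter linear model with a mean-squared-error loss; all names hypothetical) shows why this works: when each micro-batch's gradient is scaled by 1/k, the accumulated gradient equals the gradient of one batch k times larger. If each batch holds a single slide, as is common in MIL training (an assumption about HistoMIL's dataloaders), the effective batch is 8 slides.

```python
from typing import Sequence


def grad_mse(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """d/dw of mean((w * x - y) ** 2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: Sequence[float], ys: Sequence[float], k: int) -> float:
    """Split the data into k equal micro-batches and sum their gradients, each scaled by 1/k.

    This mirrors gradient accumulation: k backward passes, then one optimizer step.
    """
    n = len(xs)
    if n % k:
        raise ValueError("this sketch assumes equal-sized micro-batches")
    size = n // k
    total = 0.0
    for start in range(0, n, size):
        total += grad_mse(w, xs[start:start + size], ys[start:start + size]) / k
    return total
```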