# 3. Application module

In this notebook, we demonstrate how to train immunotherapy response models with the HiST application module, based on the spatial gene profiles predicted by the HiST prediction module.

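`xx` in the model paths below is a placeholder for the directory that the function creates and names based on the date; replace it with the actual directory name before running.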

Note: This part is built on the HCC immunotherapy response H&E data from OMIX (NGDC), accession ID: OMIX009369.

## Import

In [None]:
import glob
import os
import pickle
import sys
from datetime import date

import numpy as np
import pandas as pd
import PIL.Image as Image
import torch

# Make the project's `src/` packages importable before the local imports below.
sys.path.append('../src/')

from util.predict import *
from util.cluster import *
from util.survival_plot import *
from util.seed import seed_torch
from util.patch import tile_HE, reconstruct
from ApplicationModule.solver import ICB_Solver
from FeatureExtraction.feature import extract_features, load_features


# Raise Pillow's pixel-count limit so that very large slide images can be opened
# without tripping its decompression-bomb check.
Image.MAX_IMAGE_PIXELS = 100_000_000_000

# Single seed reused by the feature-extraction and prediction calls below.
seed = 42
seed_torch(seed)  # project helper from util.seed; presumably seeds the RNGs -- confirm

In [None]:
# Read the header-less gene list file and keep its first column as a Python list.
gene_table = pd.read_csv('../resource/HCC_SVG448_list.txt', header=None)
gene_list = list(gene_table.iloc[:, 0])

## Preprocess metadata

In [8]:
# Per-slide metadata table; the columns used below are patient_id, HE_path and
# response (values NR / R in the displayed rows).
metadata_path = '../resource/HCC_ICB_metadata.xlsx'
metadata = pd.read_excel(metadata_path)
metadata

Unnamed: 0,patient_id,HE_path,response
0,patient1,patient1.tif,NR
1,patient2-2,patient2-2.tif,NR
2,patient2-1,patient2-1.tif,NR
3,patient3,patient3.tif,NR
4,patient4,patient4.tif,NR
...,...,...,...
126,patient117,patient117.tif,R
127,patient118,patient118.tif,NR
128,patient119,patient119.tif,R
129,patient120,patient120.jpg,NR


In [None]:
# Sample identifiers, in the same row order as `metadata`.
sample_list = metadata.loc[:, 'patient_id'].to_list()

## Tile

In [None]:
# Tile every H&E image listed in the metadata table.
# Tiles are written under ../output/HCC_ICB/tiles/, which is the directory the
# feature-extraction step reads from (its `tile_path`).
# NOTE(review): HE_path values in the metadata are bare file names, so they are
# resolved relative to the notebook's working directory -- confirm where the
# slide files actually live.
for patient_id, HE_path in zip(metadata['patient_id'], metadata['HE_path']):
    tile_HE(
        sample_id=patient_id,
        HE_path=HE_path,
        out_path='../output/HCC_ICB/tiles/',
        target_size=224,  # presumably the tile edge length in pixels -- confirm
    )

## Feature extraction

In [None]:
# Extract features for the tiles of every sample in `sample_list`, using the
# weights file ../resource/ctranspath.pth.
# save=True together with `file` presumably writes the result to that pickle -- confirm.
ICB_sample_features = extract_features(
    tile_path='../output/HCC_ICB/tiles/',
    img_ids=sample_list,
    model_weight_path='../resource/ctranspath.pth',
    save=True,
    seed=seed,
    file='../output/features/ICB_sample_features.pkl',
)

In [None]:
# # if you have already run wsi_clean_up
# # NOTE(review): `file` below is the same path as in the extract_features call above,
# # so enabling this variant would presumably overwrite that pickle; the later cells
# # also consume `ICB_sample_features`, not `ICB_clean_features` -- adjust both first.
# ICB_clean_features = extract_features(
#                             tile_path = '../output/HCC_ICB/clean_tiles_75/',
#                             img_ids = sample_list,
#                             model_weight_path = '../resource/ctranspath.pth',
#                             save = True,
#                             seed = seed,
#                             file = '../output/features/ICB_sample_features.pkl')

## Get prediction for all samples

In [None]:
# Predict the gene matrices for every sample with the HCC-based gene model.
# Replace `xx` with the date-named checkpoint directory.
# NOTE(review): gene predictions must load the gene-model checkpoint (not the
# tumor one); confirm the directory name against your ../output/model/ layout.
predict_gene_matrix_list = GetPredictGMList(
    sample_list=sample_list,
    gene_list=gene_list,
    all_sample_features=ICB_sample_features,
    model_path='../output/model/gene/checkpoint_all/xx/200_model.pth',
    seed=seed,
)

In [None]:
# Predict the mask matrices for every sample with the HCC-based tumor model.
# Replace `xx` with the date-named checkpoint directory.
# NOTE(review): mask predictions must load the tumor-model checkpoint (not the
# gene one); confirm the directory name against your ../output/model/ layout.
predict_mask_matrix_list = GetPredictTMList(
    sample_list=sample_list,
    all_sample_features=ICB_sample_features,
    model_path='../output/model/tumor/checkpoint_all/xx/200_model.pth',
    seed=seed,
)

In [None]:
# Persist both prediction results as pickles so later runs can reuse them.
predict_dir = '../output/predict_matrix/'
os.makedirs(predict_dir, exist_ok=True)

gene_pkl = '../output/predict_matrix/predict_gene_matrix_list.pkl'
with open(gene_pkl, 'wb') as gene_file:
    pickle.dump(predict_gene_matrix_list, gene_file)

mask_pkl = '../output/predict_matrix/predict_mask_matrix_list.pkl'
with open(mask_pkl, 'wb') as mask_file:
    pickle.dump(predict_mask_matrix_list, mask_file)

## Train

In [None]:
# One-hot encode the response labels: one integer column per distinct value in
# metadata['response'], with rows in the same order as `metadata`.
labels = pd.get_dummies(metadata['response'], dtype=int)
# The solver below is created with num_classes=2, so fail early if the metadata
# holds anything other than two response categories (e.g. a misspelled label).
assert labels.shape[1] == 2, (
    f"expected 2 response classes, found {list(labels.columns)}"
)
labels_tensor = torch.tensor(labels.to_numpy(), dtype=torch.float32)

# Configure the immunotherapy-response solver for a two-class problem.
# NOTE(review): this uses its own seeds (22 for the solver, 123 for the k-fold
# split) rather than the notebook-wide `seed` -- presumably intentional; confirm.
ICBsolver = ICB_Solver(
    seed=22,
    num_classes=2,
    drop_path_rate=0.4,
    depths=[2, 2, 8, 2],
    dims=[16, 24, 32, 40],
    epochs=200,
    lr=1e-4,
    kfold_seed=123,
    verbose=True,
)

# Run 5-fold training on the predicted matrices, the H&E features and the
# one-hot labels, passing the fold checkpoint directory as `out_dir`.
checkpoint_dir = '../output/model/HCC_ICB/checkpoint_5fold/'
ICBsolver.train_kfold(
    labels_tensor=labels_tensor,
    gene_matrix_list=predict_gene_matrix_list,
    mask_matrix_list=predict_mask_matrix_list,
    he_features=ICB_sample_features,
    method='gene',
    kfold_splits=5,
    batch_size=20,
    out_dir=checkpoint_dir,
)