In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and echo the full path of every file found.
for folder, _, files in os.walk('/kaggle/input'):
    for fname in files:
        print(os.path.join(folder, fname))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install git+https://github.com/Mahmoodlab/CONCH.git

In [2]:
import os
from pathlib import Path
import json

from conch.open_clip_custom import create_model_from_pretrained
from conch.downstream.zeroshot_path import zero_shot_classifier, run_zeroshot

import torch 
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Display the value of every top-level expression in a cell, not only the last
# one. Consequence: unwanted reprs have to be silenced explicitly (this
# notebook does so by assigning to `_`, e.g. `_ = model.eval()`).
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [8]:
# Print the kernel's current working directory (pure-Python equivalent of `pwd`).
print(os.getcwd())

/kaggle


In [3]:
# Work from the Kaggle root so that relative paths such as `./input/...`
# resolve. An absolute target keeps this cell idempotent: the previous
# `Path('../').resolve()` climbed one directory higher on every re-run
# (e.g. /kaggle/working -> /kaggle -> /).
root = Path('/kaggle')
os.chdir(root)

In [4]:
# --- Model configuration ---
model_cfg = 'conch_ViT-B-16'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Absolute path: the former './input/...' form only resolved while the working
# directory was exactly /kaggle, i.e. after the chdir cell had run exactly once.
checkpoint_path = '/kaggle/input/pytorch-model/pytorch_model.bin'
force_image_size = 224

# Build the model together with the image transform that is later handed to
# the dataset as `transform=preprocess`.
model, preprocess = create_model_from_pretrained(
    model_cfg,
    checkpoint_path,
    device=device,
    force_image_size=force_image_size,
)

# Put the model in eval mode; binding the return value to `_` keeps the module
# repr out of the cell output.
_ = model.eval()

In [5]:
# Test split, laid out as one sub-folder per class.
data_source = '/kaggle/input/train-tcga-coad-msi-mss/tcga_coad_msi_mss/test'
dataset = ImageFolder(data_source, transform=preprocess)
# shuffle=False keeps the sample order fixed across runs.
dataloader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

# The prompt lookup further down needs label index -> class name, which is the
# inverse of the dataset's `class_to_idx`. Fail early, naming the attribute
# that is actually checked (the old message referred to `label_map`).
if not hasattr(dataset, 'class_to_idx'):
    raise ValueError('Dataset does not have class_to_idx attribute')
idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}

print("num samples: ", len(dataset))
print(idx_to_class)

num samples:  19233
{0: 'MSIMUT', 1: 'MSS'}


In [9]:
# Prompt definition for zero-shot classification, stored under the key "0":
#   - "classnames": for each dataset class, the phrases used to describe it
#   - "templates":  sentence patterns containing the placeholder CLASSNAME
#
# Alternatives previously kept here as commented-out code, now dropped:
#   * morphology terms: mucus / mucin (pool) for MSIMUT versus smooth muscle /
#     muscularis propria / muscularis mucosa for MSS
#   * longer synonym-only lists ("microsatellite instability high", "msi low",
#     "mss cancer", ...)
#
# NOTE(review): the MSIMUT list mixes MSI synonyms with tumour-type names, and
# the run below reports bacc ~0.50 with this prompt set - worth revisiting.
json_file = {
    "0": {
        "classnames": {
            "MSIMUT": [
                "msi-h",
                "msi high",
                "msi mutant",
                "high msi",
                "msi high cancer",
                "msi-high tumor",
                "colorectal adenocarcinoma",
                "endometrioid carcinoma",
                "gastric adenocarcinoma",
                "ovarian carcinoma",
                "small intestine adenocarcinoma",
                "cholangiocarcinoma",
                "hepatocellular carcinoma",
                "urothelial carcinoma",
                "glioblastoma",
                "medulloblastoma",
                "sebaceous gland carcinoma",
            ],
            "MSS": [
                "microsatellite stable",
                "non msi-h",
                "stable msi",
                "mss tumor",
                "microsatellite stable tumor",
                "cutaneous squamous cell carcinoma",
            ],
        },
        "templates": [
            "CLASSNAME.",
            "a photomicrograph showing CLASSNAME.",
            "a photomicrograph of CLASSNAME.",
            "an image of CLASSNAME.",
            "an image showing CLASSNAME.",
            "an example of CLASSNAME.",
            "CLASSNAME is shown.",
            "this is CLASSNAME.",
            "there is CLASSNAME.",
            "a histopathological image showing CLASSNAME.",
            "a histopathological image of CLASSNAME.",
            "a histopathological photograph of CLASSNAME.",
            "a histopathological photograph showing CLASSNAME.",
            "shows CLASSNAME.",
            "presence of CLASSNAME.",
            "CLASSNAME is present.",
            "an H&E stained image of CLASSNAME.",
            "an H&E stained image showing CLASSNAME.",
            "an H&E image showing CLASSNAME.",
            "an H&E image of CLASSNAME.",
            "CLASSNAME, H&E stain.",
            "CLASSNAME, H&E.",
        ],
    },
}

In [10]:
# Pull the prompt set stored under key '0' and split it into its two parts.
prompt_file = json_file
prompts = prompt_file['0']
classnames, templates = prompts['classnames'], prompts['templates']
n_classes = len(classnames)

# Arrange the per-class prompt lists by dataset label index (0, 1, ...), going
# through idx_to_class to find the class name for each index.
classnames_text = [
    classnames[str(idx_to_class[idx])]
    for idx in range(n_classes)
]

# Show which prompts ended up at which label index.
for class_idx, classname in enumerate(classnames_text):
    print(f'{class_idx}: {classname}')

0: ['msi-h', 'msi high', 'msi mutant', 'high msi', 'msi high cancer', 'msi-high tumor', 'colorectal adenocarcinoma', 'endometrioid carcinoma', 'gastric adenocarcinoma', 'ovarian carcinoma', 'small intestine adenocarcinoma', 'cholangiocarcinoma', 'hepatocellular carcinoma', 'urothelial carcinoma', 'glioblastoma', 'medulloblastoma', 'sebaceous gland carcinoma']
1: ['microsatellite stable', 'non msi-h', 'stable msi', 'mss tumor', 'microsatellite stable tumor', 'cutaneous squamous cell carcinoma']


In [11]:
# Turn the class prompts and templates into zero-shot classifier weights and
# report the shape of the resulting tensor.
zeroshot_weights = zero_shot_classifier(
    model,
    classnames_text,
    templates,
    device=device,
)
print(zeroshot_weights.shape)

torch.Size([512, 2])


In [12]:
# Evaluate the zero-shot classifier over the whole dataloader, requesting
# balanced accuracy and weighted F1, and keep the dumped results as well.
results, dump = run_zeroshot(
    model,
    zeroshot_weights,
    dataloader,
    device,
    dump_results=True,
    metrics=['bacc', 'weighted_f1'],
)

  self.pid = os.fork()
  self.pid = os.fork()
100%|██████████| 151/151 [03:59<00:00,  1.58s/it]


In [14]:
# One line per metric, formatted to eight decimal places.
for metric_name, score in results.items():
    print(f'{metric_name}: {score:.8f}')

bacc: 0.50053113
weighted_f1: 0.36316994


# RUN MIZERO

In [16]:
import os
from pathlib import Path
import json

from conch.open_clip_custom import create_model_from_pretrained
from conch.downstream.zeroshot_path import zero_shot_classifier, run_mizero
from conch.downstream.wsi_datasets import WSIEmbeddingDataset

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pandas as pd 
import numpy as np

# Display the value of every top-level expression in a cell, not only the last
# one. Consequence: unwanted reprs have to be silenced explicitly (this
# notebook does so by assigning to `_`, e.g. `_ = model.eval()`).
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [17]:
# Work from the Kaggle root. An absolute target keeps this cell idempotent:
# the previous `Path('../').resolve()` climbed one directory higher each time
# it ran, so running it after the identical cell earlier in the notebook moved
# the working directory from /kaggle up to /.
root = Path('/kaggle')
os.chdir(root)

In [19]:
# Pick the GPU when one is available and load the CONCH checkpoint again.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
checkpoint_path = '/kaggle/input/pytorch-model/pytorch_model.bin'

# Only the model is kept here; the second return value is discarded.
model, _ = create_model_from_pretrained(
    model_cfg='conch_ViT-B-16',
    checkpoint_path=checkpoint_path,
    device=device,
)

# Put the model in eval mode; binding to `_` keeps the module repr out of the
# cell output.
_ = model.eval()

In [26]:
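# TODO: the MI-Zero run is currently disabled.
# NOTE(review): as written, this would pass `dataloader`, which is still the
# patch-level ImageFolder loader from the zero-shot section; WSIEmbeddingDataset
# is imported above but never used. Confirm what run_mizero expects as input
# before re-enabling.
# zeroshot_weights = zero_shot_classifier(model, classnames_text, templates, device=device)
# print(zeroshot_weights.shape)
# results, dump = run_mizero(model, zeroshot_weights, dataloader, device, 
#                     dump_results=True, metrics=['bacc', 'weighted_f1'])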