In [None]:
# Mount Google Drive. The relative "drive/MyDrive/..." paths used in this
# notebook assume the working directory is /content.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


In [None]:
!pip install -q uncertainty-calibration open_clip_torch

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for uncertainty-calibration (setup.py) ... [?25l[?25hdone


In [None]:
!pip install -q git+https://github.com/openai/CLIP.git

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for clip (setup.py) ... [?25l[?25hdone


In [None]:
!pip install -q datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m388.9/388.9 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import glob
import os
import pathlib
import random
from dataclasses import dataclass
from typing import Optional

from tqdm import tqdm
from PIL import Image
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from torch.nn import functional as F
from torchvision import transforms
import torchvision
from torchvision.datasets import ImageFolder
from transformers import AutoProcessor, CLIPModel, AutoTokenizer
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode, Lambda
import calibration as cal
import open_clip
import clip
from datasets import load_dataset

import matplotlib.pyplot as plt
from matplotlib import style
style.use('seaborn-v0_8')  # `style` is matplotlib.style, the same module as plt.style

In [None]:
# Dataset roots on the mounted Google Drive (relative to the working directory).
data_root = "drive/MyDrive/CV2_project/data/"

MHIST_DIR = f"{data_root}mhist/"
PCAM_DIR = f"{data_root}pcam/"
LCLUNG_DIR = f"{data_root}lung_colon_image_set/lung_image_sets/"
LCCOLON_DIR = f"{data_root}lung_colon_image_set/colon_image_sets/"
BACH_DIR = f"{data_root}ICIAR2018_BACH_Challenge/Photos/"
NCK_DIR = f"{data_root}NCT-CRC-HE-100K/"
SICAPv2_DIR = f"{data_root}SICAPv2/"
SKIN_DIR = f"{data_root}SkinCancer_Files/data/"
OSTEO_DIR = f"{data_root}Osteosarcoma_Tumor_Assessment/Separated/"
RENAL_DIR = f"{data_root}tissue_classification/"
DATABIOX_DIR = f"{data_root}databiox/"
TUMOR_DIR = f"{data_root}SkinTumor/"


In [None]:
# Text class names per dataset. Position i in each list is the class scored
# against integer label i, so the order must match each dataset's label encoding.
all_dataset_class_labels = dict(
    pcam=[
        "lymph node",
        "lymph node containing metastatic tumor tissue",
    ],
    nck=[
        "adipose",
        "debris",
        "lymphocytes",
        "mucus",
        "smooth muscle",
        "normal colon mucosa",
        "cancer-associated stroma",
        "colorectal adenocarcinoma epithelium",
    ],
    lc25000_lung=[
        "benign lung",
        "lung adenocarcinoma",
        "lung squamous cell carcinoma",
    ],
    lc25000_colon=[
        "colon adenocarcinoma",
        "benign colonic tissue",
    ],
    mhist=[
        "hyperplastic polyp",
        "sessile serrated adenoma",
    ],
    sicap=[
        "benign glands",
        "atrophic dense glands",
        "cribriform ill-formed fused papillary patterns",
        "isolated nest cells without lumen roseting patterns",
    ],
    idc_grade=[
        "well differentiated bloom richardson grade one",
        "moderately differentiated bloom richardson grade two",
        "poorly differentiated grade three",
    ],
    databiox=[
        "well differentiated bloom richardson grade one",
        "moderately differentiated bloom richardson grade two",
        "poorly differentiated grade three",
    ],
    osteo=[
        "non-tumor",
        "non-viable necrotic osteosarcoma tumor",
        "viable osteosarcoma tumor",
    ],
    bach=[
        "breast non-malignant benign tissue",
        "breast malignant in-situ carcinoma",
        "breast malignant invasive carcinoma",
        "breast normal breast tissue",
    ],
    renal_cell=[
        "red blood cells",
        "renal cancer",
        "normal renal tissue",
        "torn adipose necrotic tissue",
        "muscle fibrous stroma blood vessels",
    ],
    skin=[
        "necrosis",
        "skeletal muscle",
        "eccrine sweat glands",
        "vessels",
        "elastosis",
        "chondral tissue",
        "hair follicle",
        "epidermis",
        "nerves",
        "subcutis",
        "dermis",
        "sebaceous glands",
        "squamous-cell carcinoma",
        "melanoma in-situ",
        "basal-cell carcinoma",
        "naevus",
    ],
    skin_tumor=[
        "squamous-cell carcinoma",
        "melanoma in-situ",
        "basal-cell carcinoma",
        "naevus",
    ],
)

In [None]:
class MhistDataset(torch.utils.data.Dataset):
    """MHIST colorectal-polyp tiles described by an annotations CSV.

    Every row of ``root/csv_file`` is one sample. The ``Image Name`` column is a
    file under ``root/image_dir``, and ``Majority Vote Label`` is ``HP`` or
    ``SSA``, mapped to 0 and 1 respectively.

    Note: ``train`` is only stored on the instance. No partition filtering is
    applied, so every row of the CSV is part of the dataset.
    """

    def __init__(self, root, csv_file, image_dir, transform=None, train=True):
        annotations = pd.read_csv(os.path.join(root, csv_file))

        self.data = annotations
        self.image_paths = annotations['Image Name'].values
        self.labels = annotations['Majority Vote Label'].values
        self.image_dir = os.path.join(root, image_dir)
        self.transform = transform
        self.train = train
        self.cat_to_num_map = {'HP': 0, 'SSA': 1}
        self.classes = ["hyperplastic polyp", "sessile serrated adenoma"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is RGB, passed through ``transform`` when one is set."""
        full_path = os.path.join(self.image_dir, self.image_paths[index])
        sample = Image.open(full_path).convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample, self.cat_to_num_map[self.labels[index]]


class SicapDataset(torch.utils.data.Dataset):
    """SICAPv2 prostate tiles with one Gleason-pattern label per tile.

    The split is read from ``root/partition/Test/Train.xlsx`` (``train=True``) or
    ``root/partition/Test/Test.xlsx`` (``train=False``). Each row has an
    ``image_name`` and one column per class (``NC``, ``G3``, ``G4``, ``G5``). The
    label is the column holding the row's largest value, encoded as NC=0, G3=1,
    G4=2, G5=3. Any ``G4C`` column is ignored.
    """

    def __init__(self, root, image_dir, transform=None, train=True):
        split_file = "Train.xlsx" if train else "Test.xlsx"
        csv_file = os.path.join(root, "partition/Test", split_file)

        # Keep only the image name and the label columns. ``.copy()`` makes this an
        # independent frame, so adding the ``labels`` column below does not write
        # into a slice of the spreadsheet frame (SettingWithCopyWarning).
        label_columns = ['NC', 'G3', 'G4', 'G5']  # , 'G4C']
        self.data = pd.read_excel(csv_file)[['image_name'] + label_columns].copy()

        # Label = name of the column with the row's maximum (idxmax picks the first
        # column on ties), then mapped to its integer id.
        self.cat_to_num_map = label_map = {'NC': 0, 'G3': 1, 'G4': 2, 'G5': 3}  # , 'G4C': 4}
        self.data['labels'] = self.data[label_columns].idxmax(axis=1).map(label_map)

        self.image_paths = self.data['image_name'].values
        self.labels = self.data['labels'].values
        self.image_dir = os.path.join(root, image_dir)
        self.transform = transform
        self.train = train
        self.classes = ["non-cancerous well-differentiated glands",
                        "gleason grade 3 with atrophic well differentiated and dense glandular regions",
                        "gleason grade 4 with cribriform, ill-formed, large-fused and papillary glandular patterns",
                        "gleason grade 5 with nests of cells without lumen formation, isolated cells and pseudo-roseting patterns",
                        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is RGB, passed through ``transform`` when one is set."""
        image_path = os.path.join(self.image_dir, self.image_paths[index])
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.labels[index]


class ArchCsvDataset(torch.utils.data.Dataset):
    """Image/caption pairs listed in a delimited text file.

    ``img_key`` and ``caption_key`` name the columns holding image paths and
    captions. An ``ids`` column is also required.

    NOTE(review): ``__getitem__`` uses the sorted ``ids`` values as *positional*
    indices into the image/caption lists. This assumes ``ids`` are integers in
    ``range(len(df))`` — confirm against the CSV.
    """

    def __init__(self, csv_file, transforms, img_key='image_path', caption_key='caption', sep=","):
        frame = pd.read_csv(csv_file, sep=sep)
        self.images = frame[img_key].tolist()
        self.captions = frame[caption_key].tolist()
        self.transforms = transforms
        self.ids = sorted(frame['ids'].tolist())

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(transformed_image, [caption])`` for the ``idx``-th sorted id."""
        row = self.ids[idx]
        pixels = self.transforms(Image.open(str(self.images[row])))
        return pixels, [str(self.captions[row])]


class OsteoDataset(torch.utils.data.Dataset):
    """Osteosarcoma tiles listed in a CSV with ``image.name`` and ``classification`` columns.

    Rows classified as the ambiguous ``"viable: non-viable"`` are dropped.
    Labels: Non-Tumor=0, Non-Viable-Tumor=1, Viable=2.
    """

    def __init__(self, root, csv_file, image_dir, transform=None):
        csv_file = os.path.join(root, csv_file)
        image_dir = os.path.join(root, image_dir)

        self.data = pd.read_csv(csv_file)
        self.data = self.data[self.data['classification'] != "viable: non-viable"] #53 samples removed from 1144
        self.image_paths = self.data['image.name'].values
        self.labels = self.data['classification'].values
        self.image_dir = image_dir
        self.transform = transform
        self.cat_to_num_map = {'Non-Tumor': 0, 'Non-Viable-Tumor': 1, 'Viable': 2}
        self.classes = ["non-tumor", "necrotic tumor", "viable tumor"]

    def __len__(self):
        return len(self.data)

    def _resolve_image_path(self, index):
        """Return the on-disk file for CSV row ``index``.

        The CSV names are normalised (`` - `` -> ``-``, then remaining spaces ->
        ``-``) and used as a file-name prefix, because the files on disk carry
        an extra suffix (presumably the extension) not present in the CSV.

        Raises:
            FileNotFoundError: if no file starts with the normalised name.
        """
        prefix = os.path.join(self.image_dir, self.image_paths[index])
        prefix = prefix.replace(' - ', '-').replace(' ', '-')
        # Escape so characters such as '[' are matched literally; sort so the
        # choice is deterministic when several files share the prefix.
        matches = sorted(glob.glob(glob.escape(prefix) + "*"))
        if not matches:
            raise FileNotFoundError(f"No image file matching {prefix!r}*")
        # The match is a real file name: use it unchanged (rewriting spaces in
        # it would point to a file that does not exist).
        return matches[0]

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is RGB, passed through ``transform`` when one is set."""
        image = Image.open(self._resolve_image_path(index)).convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.cat_to_num_map[self.labels[index]]

        return image, label


class SkinDataset(torch.utils.data.Dataset):
    """Skin tiles listed in a CSV with ``set``, ``malignicy``, ``file`` and ``class`` columns.

    Args:
        root: directory containing ``csv_file``.
        csv_file: CSV file name, relative to ``root``.
        transform: optional callable applied to the RGB image.
        train: select the ``Train`` split. Otherwise ``Validation`` is used if
            ``val`` is true, else ``Test``.
        val: see ``train``.
        tumor: keep only rows with ``malignicy == 'tumor'`` and use the 4-class
            ``tumor_map`` label space instead of the 16-class ``cat_to_num_map``.
            ``classes`` always lists the names of the active label space.

    NOTE(review): ``file`` values are opened as-is (not joined with ``root``), so
    they must be absolute or valid relative to the working directory — confirm.
    """

    def __init__(self, root, csv_file, transform=None, train=True, val=False,
                 tumor=False):
        csv_file = os.path.join(root, csv_file)
        self.data = pd.read_csv(csv_file)

        if train:
            split = 'Train'
        elif val:
            split = "Validation"
        else:
            split = 'Test'
        self.data = self.data[self.data['set'] == split]

        if tumor:
            self.data = self.data[self.data['malignicy'] == 'tumor']
        self.tumor = tumor

        self.image_paths = self.data['file'].values
        self.labels = self.data['class'].values

        self.transform = transform
        self.train = train

        self.cat_to_num_map = {'nontumor_skin_necrosis_necrosis': 0,
                               'nontumor_skin_muscle_skeletal': 1,
                               'nontumor_skin_sweatglands_sweatglands': 2,
                               'nontumor_skin_vessel_vessel': 3,
                               'nontumor_skin_elastosis_elastosis': 4,
                               'nontumor_skin_chondraltissue_chondraltissue': 5,
                               'nontumor_skin_hairfollicle_hairfollicle': 6,
                               'nontumor_skin_epidermis_epidermis': 7,
                               'nontumor_skin_nerves_nerves': 8,
                               'nontumor_skin_subcutis_subcutis': 9,
                               'nontumor_skin_dermis_dermis': 10,
                               'nontumor_skin_sebaceousglands_sebaceousglands': 11,
                               'tumor_skin_epithelial_sqcc': 12,
                               'tumor_skin_melanoma_melanoma': 13,
                               'tumor_skin_epithelial_bcc': 14,
                               'tumor_skin_naevus_naevus': 15
                               }

        self.tumor_map = {'tumor_skin_epithelial_sqcc': 0,
                          'tumor_skin_melanoma_melanoma': 1,
                          'tumor_skin_epithelial_bcc': 2,
                          'tumor_skin_naevus_naevus': 3
                          }

        # ``classes`` must describe the label space actually returned by __getitem__.
        self.classes = list(self.tumor_map if tumor else self.cat_to_num_map)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)``; the label uses ``tumor_map`` when ``tumor`` is set."""
        image_path = self.image_paths[index]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        label_map = self.tumor_map if self.tumor else self.cat_to_num_map
        label = label_map[self.labels[index]]

        return image, label

In [None]:
def get_dataset(preprocess, args):
    """Build the evaluation dataset selected by ``args.dataset``.

    Args:
        preprocess: image transform applied to every sample.
        args: needs ``args.dataset``. ``args.data_dir`` is set to the dataset
            root for every dataset except ``lc25000_lung``, which is loaded from
            the Hugging Face hub.

    Returns:
        A map-style dataset yielding ``(image, label)`` tuples. The exception is
        ``lc25000_lung``, which yields dicts with ``"image"`` and ``"label"`` keys.

    Raises:
        ValueError: if ``args.dataset`` is not a known dataset name.
    """

    def iloader(path):
        return Image.open(path)

    # Datasets stored as one sub-folder per class. ImageFolder assigns labels in
    # sorted folder-name order, which must match all_dataset_class_labels.
    image_folder_roots = {
        "bach": BACH_DIR,
        "nck": NCK_DIR,
        "osteo": OSTEO_DIR,
        "renal_cell": RENAL_DIR,
        "databiox": DATABIOX_DIR,
    }

    name = args.dataset

    if name in image_folder_roots:
        args.data_dir = pathlib.Path(image_folder_roots[name])
        return torchvision.datasets.ImageFolder(
            root=args.data_dir,
            transform=preprocess,
        )

    if name == 'pcam':
        args.data_dir = pathlib.Path(PCAM_DIR)
        return torchvision.datasets.PCAM(
            root=args.data_dir,
            transform=preprocess,
            download=True,
            split="test",
        )

    if name == 'lc25000_colon':
        args.data_dir = pathlib.Path(LCCOLON_DIR)
        return torchvision.datasets.DatasetFolder(
            root=args.data_dir,
            loader=iloader,
            transform=preprocess,
            extensions='jpeg',
        )

    if name == 'lc25000_lung':
        # NOTE(review): assumes organ == 0 selects the lung tiles and that the hub's
        # label ids follow all_dataset_class_labels["lc25000_lung"] — confirm.
        hub = load_dataset("1aurent/LC25000")
        lung = hub["train"].filter(lambda example: example["organ"] == 0)
        lung = lung.with_format("torch")

        def ds_transforms(examples):
            examples["image"] = [preprocess(image) for image in examples["image"]]
            return examples

        return lung.with_transform(ds_transforms)

    if name == 'sicap':
        args.data_dir = pathlib.Path(SICAPv2_DIR)
        return SicapDataset(
            root=args.data_dir,
            image_dir="images",
            transform=preprocess,
            train=False,
        )

    if name == 'skin':
        args.data_dir = pathlib.Path(SKIN_DIR)
        return SkinDataset(
            root=args.data_dir,
            csv_file="tiles-v2.csv",
            transform=preprocess,
            train=False,
            tumor=False,
        )

    if name == 'skin_tumor':
        args.data_dir = pathlib.Path(TUMOR_DIR)
        return SkinDataset(
            root=args.data_dir,
            csv_file="tiles-v3.csv",
            transform=preprocess,
            train=False,
            tumor=True,
        )

    if name == 'mhist':
        args.data_dir = pathlib.Path(MHIST_DIR)
        return MhistDataset(
            root=args.data_dir,
            csv_file="annotations.csv",
            image_dir="images",
            transform=preprocess,
        )

    raise ValueError(f"Unknown dataset: {name!r}")


def get_label_texts(labels_to_classname, args):
    """Build the prompt texts used to embed each class.

    Args:
        labels_to_classname: iterable of class names. Only used when
            ``args.descriptors`` is None.
        args: needs ``args.descriptors`` (None, "sentence" or "feature") and
            ``args.dataset``.

    Returns:
        dict mapping class name -> list of prompt strings, in class order.
        With descriptors, the class names come from
        ``all_dataset_class_labels[args.dataset]``. The prompts are read from the
        pre-generated CSV ``.../descriptors/{args.descriptors}/{args.dataset}.csv``.

    Raises:
        ValueError: for an unknown ``args.descriptors`` value.
    """
    label_to_texts = dict()

    if args.descriptors is not None:
        if args.descriptors not in ("sentence", "feature"):
            raise ValueError(f"Unknown descriptors: {args.descriptors!r}")

        save_path = "drive/MyDrive/CV2_project/code/med_vlm_cal/descriptors/{}/{}.csv".format(
            args.descriptors, args.dataset)
        df = pd.read_csv(save_path)
        for label in all_dataset_class_labels[args.dataset]:
            row = df[df["label"] == label].iloc[0]
            # NOTE(review): eval() executes arbitrary code stored in the CSV. If the
            # responses are plain dict reprs, ast.literal_eval is a safe drop-in.
            response = eval(row["response"])

            if args.descriptors == "sentence":
                # One prompt per returned choice, without surrounding quotes.
                texts = [c["message"]["content"].strip().strip('\"') for c in response["choices"]]
            else:
                # A single numbered list ("1. feature.\n2. feature."), one prompt per item.
                content = response["choices"][0]["message"]["content"].strip()
                texts = ["histopathology image of " + label + " with " + line.split(". ")[1].strip(".")
                         for line in content.split("\n") if line.strip()]
            label_to_texts[label] = texts

    else:
        templates = ["a histopathology slide showing {c}",
                     "histopathology image of {c}",
                     "pathology tissue showing {c}",
                     "presence of {c} tissue on image"]
        for classname in labels_to_classname:
            label_to_texts[classname] = [template.format(c=classname) for template in templates]

    print(label_to_texts)

    return label_to_texts


def compute_label_encodings(model, tokenizer, label_to_texts, args):
    """Embed each class as the renormalised mean of its unit-normalised prompt embeddings.

    Args:
        model: Hugging Face ``CLIPModel`` when ``args.model`` is "plip" or
            "pubmed"; otherwise a model exposing ``encode_text``.
        tokenizer: the matching tokenizer.
        label_to_texts: dict class name -> list of prompt strings. Dict order
            defines the row order of the result.
        args: needs ``args.model`` and ``args.device``.

    Returns:
        Tensor with one unit-L2-norm row per class, on ``args.device``. Assumes the
        text encoders return ``(num_prompts, embed_dim)``, giving
        ``(num_classes, embed_dim)``.
    """
    class_embeddings = []
    for classname, texts in label_to_texts.items():
        print(classname, texts)
        if args.model in ["plip", "pubmed"]:
            # Move tokens to args.device (not a hard-coded "cuda") to match the model.
            tokens = tokenizer(texts, padding=True, return_tensors="pt").to(args.device)
            text_embeddings = model.get_text_features(**tokens)
        else:
            tokens = tokenizer(texts).to(args.device)
            text_embeddings = model.encode_text(tokens)

        # Average the unit-normalised prompt embeddings, then renormalise the mean.
        class_embedding = F.normalize(text_embeddings, dim=-1).mean(dim=0)
        class_embeddings.append(class_embedding / class_embedding.norm())

    return torch.stack(class_embeddings, dim=0).to(args.device)


# CLIP-style preprocessing, used as `preprocess` for the Hugging Face CLIP
# models ("plip" / "pubmed") in get_model_and_tokenizer.
clip_transform = Compose([
            Lambda(lambda img: img.convert("RGB")),                     # force 3 channels (e.g. RGBA / grayscale input)
            Resize(size=224, interpolation=InterpolationMode.BICUBIC),  # resize the shorter side to 224 px, keeping aspect ratio
            CenterCrop(224),                                            # crop the central 224x224 square
            ToTensor(),                                                 # PIL image -> float CHW tensor in [0, 1]
            Normalize(mean=[0.48145466, 0.4578275, 0.40821073],         # per-channel normalisation with the
                      std=[0.26862954, 0.26130258, 0.27577711]),        # OpenAI CLIP mean / std
        ])


def get_model_and_tokenizer(args):
    """Load the vision-language model selected by ``args.model``.

    Args:
        args: needs ``args.model`` and ``args.device``.

    Returns:
        ``(model, preprocess, tokenizer)``. For "plip" / "pubmed" the model is a
        Hugging Face ``CLIPModel``, the tokenizer is its HF tokenizer and
        ``preprocess`` is ``clip_transform``. All other models come from open_clip.

    Raises:
        ValueError: if ``args.model`` is not a known model name.
    """
    # open_clip models built with create_model_and_transforms: name -> (model id, pretrained tag).
    open_clip_models = {
        "quilt": ("hf-hub:wisdomik/QuiltNet-B-32", None),
        "quilt16": ("hf-hub:wisdomik/QuiltNet-B-16", None),
        "quiltbert": ("hf-hub:wisdomik/QuiltNet-B-16-PMB", None),
        # A bare architecture name is randomly initialised unless a pretrained tag is given.
        "vitb32": ("ViT-B-32", "openai"),
    }
    # Hugging Face CLIP checkpoints: name -> hub id.
    hf_clip_models = {
        "plip": "vinid/plip",
        "pubmed": "flaviagiammarino/pubmed-clip-vit-base-patch32",
    }

    if args.model in open_clip_models:
        model_path, pretrained = open_clip_models[args.model]
        model, _, preprocess = open_clip.create_model_and_transforms(
            model_path, pretrained=pretrained, device=args.device)
        tokenizer = open_clip.get_tokenizer(model_path)

    elif args.model == "biomed":
        model_path = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        model, preprocess = open_clip.create_model_from_pretrained(model_path, device=args.device)
        tokenizer = open_clip.get_tokenizer(model_path)

    elif args.model in hf_clip_models:
        model_path = hf_clip_models[args.model]
        # Use args.device (not a hard-coded "cuda") so CPU runs work too.
        model = CLIPModel.from_pretrained(model_path).to(args.device)
        tokenizer = AutoProcessor.from_pretrained(model_path).tokenizer
        preprocess = clip_transform

    else:
        raise ValueError(f"Unknown model: {args.model!r}")

    return model, preprocess, tokenizer

In [None]:
def extract_features(args):
    """Run zero-shot classification for one (model, dataset) pair and save the results.

    Loads the model via ``get_model_and_tokenizer`` and the dataset via
    ``get_dataset``. It then encodes the class prompts and scores every image
    against them. The arrays are written to
    ``output/{args.model}/{args.dataset}/``; file names get a
    ``{args.descriptors}_desc_`` prefix when descriptors are used:

    * ``features.npy`` -- L2-normalised image embeddings, one row per image.
    * ``logits.npy``   -- ``100 * cosine similarity`` to each class embedding.
    * ``labels.npy``   -- integer ground-truth labels.

    At most ``args.max_items`` images are kept. Accuracy is printed.

    Args:
        args: an ``Args``-like object (``seed``, ``device``, ``model``,
            ``dataset``, ``descriptors``, ``batch_size``, ``max_items``).

    Raises:
        RuntimeError: if the dataset yields no images.
    """
    torch.manual_seed(args.seed)

    print("Loading model...")
    device = torch.device(args.device)
    model, preprocess, tokenizer = get_model_and_tokenizer(args)
    model.eval()

    dataset = get_dataset(preprocess, args)
    print("n data:", len(dataset))

    save_path = "drive/MyDrive/CV2_project/code/med_vlm_cal/output/{}/{}/".format(args.model, args.dataset)
    if args.descriptors is not None:
        # Descriptor runs share the directory; their files get a "<descriptors>_desc_" prefix.
        save_path += "{}_desc_".format(args.descriptors)
    print("save to:", save_path)

    # Re-seed right before building the loader, so the shuffle order depends only
    # on args.seed and not on RNG use during model/dataset loading.
    torch.manual_seed(args.seed)
    dataloader = DataLoader(
        dataset, args.batch_size,
        num_workers=2, pin_memory=True,
        shuffle=True
    )

    label_to_texts = get_label_texts(all_dataset_class_labels[args.dataset], args)

    all_features, all_logits, all_labels = [], [], []
    with torch.no_grad():
        label_encodings = compute_label_encodings(model, tokenizer, label_to_texts, args)

        for batch in tqdm(dataloader):
            # The HF-hub dataset yields dicts; all others yield (image, label) tuples.
            if args.dataset == "lc25000_lung":
                images, labels = batch["image"], batch["label"]
            else:
                images, labels = batch

            images = images.to(device)
            if args.model in ["plip", "pubmed"]:
                image_encodings = F.normalize(model.get_image_features(images))
            else:
                image_encodings = model.encode_image(images, normalize=True)
            # Image encodings are unit-norm, so this is 100 * cosine similarity.
            logits = 100 * image_encodings @ label_encodings.T

            all_features.append(image_encodings.cpu())
            all_logits.append(logits.cpu())
            all_labels.extend(labels.cpu().tolist())

            if len(all_labels) >= args.max_items:
                break

    print("Done loop")

    if not all_labels:
        raise RuntimeError(f"No images were loaded for dataset {args.dataset!r}")

    # The last batch may overshoot max_items; trim all three arrays consistently.
    n_keep = args.max_items
    all_features = torch.vstack(all_features).numpy()[:n_keep]
    all_logits = torch.vstack(all_logits).numpy()[:n_keep]
    all_labels = np.array(all_labels)[:n_keep]

    print(all_features.shape, all_logits.shape, all_labels.shape)

    acc = np.mean(np.argmax(all_logits, -1) == all_labels)
    print("acc", acc)

    os.makedirs(save_path, exist_ok=True)

    np.save(save_path + "features.npy", all_features)
    np.save(save_path + "logits.npy", all_logits)
    np.save(save_path + "labels.npy", all_labels)

In [None]:
@dataclass
class Args:
    """Configuration for one ``extract_features`` run (one model on one dataset)."""

    dataset: str                       # dataset name understood by get_dataset / all_dataset_class_labels
    model: str = "quilt"               # model name understood by get_model_and_tokenizer
    seed: int = 0                      # torch seed; also fixes the DataLoader shuffle order
    device: str = "cuda"
    batch_size: int = 32
    max_items: int = 100000            # extract_features stops once this many images are scored
    descriptors: Optional[str] = None  # None (templates), "sentence" or "feature"; see get_label_texts

In [None]:
# Zero-shot feature/logit extraction for every (dataset, model) pair.
# Loop variables use distinct names, so the module-level names `dataset` / `model`
# are never bound here and no function can accidentally read them as globals.
EVAL_DATASETS = [
    "lc25000_lung",
    "lc25000_colon",
    "mhist",
    "pcam",
    "bach",
    "nck",
    "osteo",
    "renal_cell",
    "skin",
    "skin_tumor",
    "sicap",
    "databiox",
]
EVAL_MODELS = [
    "plip",
    "biomed",
    "quilt",
]

for dataset_name in EVAL_DATASETS:
    for model_name in EVAL_MODELS:
        run_args = Args(dataset=dataset_name, model=model_name)
        extract_features(run_args)