In [1]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.10 (you have 1.4.8). Upgrade using: pip install --upgrade albumentations


In [2]:
from torch.utils.data import Dataset
import pandas as pd
from typing import List, Optional, Callable
import h5py
# from utils.helpers import get_transforms, seed_everything
# from utils.dataset import load_success_ids
from random import sample
from tqdm import tqdm
from PIL import Image
import sqlite3

In [3]:
def get_transforms(train=False):
    """
    Takes a list of images and applies the same augmentations to all of them.
    This is completely overengineered but it makes it easier to use in our pipeline
    as drop-in replacement for torchvision transforms.
    ## Example
    ``` python
    imgs = [Image.open(f”image{i}.png”) for i in range(1, 4)]
    t = get_albumentations_transforms(train=True)
    t_imgs = t(imgs) # List[torch.Tensor]
    ```
    For the single image case:
    ```python
    img = Image.open(f”image{0}.png”)
    # or img = np.load(some_bytes)
    t = get_albumentations_transforms(train=True)
    t_img = t(img) # torch.Tensor
    ```
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    _data_transform = None
    def _get_transform(n: int = 3):
        if train:
            data_transforms = A.Compose(
                [
                    A.Resize(224, 224),
                    A.OneOf(
                        [
                            A.Rotate(limit=0, p=1),
                            A.Rotate(limit=90, p=1),
                            A.Rotate(limit=180, p=1),
                            A.Rotate(limit=270, p=1),
                        ],
                        p=0.5,
                    ),
                    A.Compose(
                        [
                            A.OneOf(
                                [
                                    A.ColorJitter(
                                        brightness=(0.9, 1),
                                        contrast=(0.9, 1),
                                        saturation=(0.9, 1),
                                        hue=(0, 0.1),
                                        p=1.0,
                                    ),
                                    A.Affine(
                                        scale=(0.5, 1.5),
                                        translate_percent=(0.0, 0.0),
                                        shear=(0.5, 1.5),
                                        p=1.0,
                                    ),
                                ],
                                p=0.5,
                            ),
                            A.GaussianBlur(
                                blur_limit=(1, 3), sigma_limit=(0.1, 3), p=1.0
                            ),
                        ]
                    ),
                    A.OneOf(
                        [
                            A.HorizontalFlip(p=0.5),
                            A.VerticalFlip(p=0.5),
                        ],
                        p=0.5,
                    ),
                    A.Normalize(mean=mean, std=std),
                    ToTensorV2(),
                ],
                additional_targets={f"image{i}": "image" for i in range(1, n)},
            )
        else:
            data_transforms = A.Compose(
                [
                    A.Resize(224, 224),
                    A.Normalize(mean=mean, std=std),
                    ToTensorV2(),
                ],
                additional_targets={f"image{i}": "image" for i in range(1, n)},

            )
        return data_transforms
    def transform_images(images: any):
        nonlocal _data_transform
        if not isinstance(images, list):
            n = 1
            images = [images]
        else:
            n = len(images)
        if _data_transform is None:
            # instantiate once
            _data_transform = _get_transform(n)
        # accepts both lists of np.Array and PIL.Image
        if isinstance(images[0], Image.Image):
            images = [np.array(img) for img in images]
        image_dict = {"image": images[0]}
        for i in range(1, n):
            image_dict[f"image{i}"] = images[i]
        transformed = _data_transform(**image_dict)
        transformed_images = [
            transformed[key] for key in transformed.keys() if "image" in key
        ]
        if len(transformed_images) == 1:
            return transformed_images[0]
        return transformed_images
    return transform_images

In [4]:
def load_success_ids(feat_folder: str):
    """
    Backwards-compatible loading of success IDs.
    We either load the available slide ids from the deprecated success.txt file
    or from the success.db sqlite database.
    If both files exist, we always prefer the database.
    """
    success_ids = set()
    success_txt = f"{feat_folder}/success.txt"
    success_db = f"{feat_folder}/success.db"
    if os.path.exists(success_txt):
        print("Warning: Loading success IDs from deprecated success.txt.")
        with open(success_txt, "r") as f:
            for line in f:
                success_ids.add(line.strip())
    if os.path.exists(success_db):
        print("Loading success IDs from database.")
        conn = sqlite3.connect(success_db)
        cursor = conn.cursor()
        cursor.execute("SELECT slide_id FROM success")
        success_ids = set([row[0] for row in cursor.fetchall()])
        conn.close()
    return success_ids

In [5]:
def seed_everything(seed: int):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [6]:
import torch
import numpy as np
import random

In [7]:
SEED = 42
seed_everything(SEED)

In [22]:
class IdilDataSet(Dataset):
    """
    Only for single dataset classes!
    """
    def __init__(
        self,
        csv_path: str,
        folder: str,
        magnification: int,
        transform: Optional[Callable] = get_transforms(),
        n_patches: int = 250,
        random_selection=False,
        limit: Optional[int] = None,
        wsi_type: str = "frozen"
    ):
        super().__init__()
        self.csv = pd.read_csv(csv_path)
        self.csv = self.csv[self.csv["wsi_type"] == wsi_type]

        # Filter out unwanted tumor types
        self.csv = self.csv[self.csv["Tumor Type"] != "Oligoastrocytoma"]
        
        # Replace grade values
        self.csv["Neoplasm Histologic Grade"] = self.csv["Neoplasm Histologic Grade"].replace({"G2": "low grade glioma", "G3": "high grade glioma"})
        
        # Replace IDH status values
        self.csv["Subtype"] = self.csv["Subtype"].replace({
            "LGG_IDHmut-non-codel": "IDH mutation",
            "LGG_IDHmut-codel": "IDH mutation",
            "LGG_IDHwt": "wild-type IDH"
        })

        
        self.folder = folder
        self.magnification = magnification
        self.transform = transform
        self.n_patches = n_patches
        self.random_selection = random_selection
        self.slide_ids = self.csv["uuid"].unique()
        success_ids = load_success_ids(self.folder)
        self.slide_ids = [x for x in self.slide_ids if x in success_ids]
        if limit:
            self.slide_ids = self.slide_ids[:limit]
        self.labels = []
        self.patches = []
        self.load_patches()
        self.compute_weights()
        
    def load_patches(self):
        """
        Load n_patches into memory.
        """
        for slide_id in tqdm(self.slide_ids, desc="Prefetch patches"):
            # TODO: adjust `_features.h5` once we renamed it on the storage server
            file = f"{self.folder}/{slide_id}_features.h5"
            try:
                with h5py.File(file, "r") as h5f:
                    n_patches = min(self.n_patches, len(h5f[str(self.magnification)]))
                    # select random indices
                    if self.random_selection:
                        indices = sample(range(n_patches), n_patches)
                    else:
                        indices = list(range(n_patches))
                    imgs = [
                        Image.fromarray(h5f[str(self.magnification)][i]) for i in indices
                    ]
                    self.patches.append((imgs, slide_id))
                    self.labels.append(self.get_label(slide_id))
            except Exception as e:
                pass
    def __len__(self):
        return len(self.patches)
    def get_label(self, slide_id):
        return self.csv.loc[self.csv["uuid"] == slide_id, "label"].values[0]
    def get_metadata(self, slide_id):
        return self.csv.loc[self.csv["uuid"] == slide_id]
    def compute_weights(self):
        """
        Compute weights for WeightedRandomSampler.
        """
        class_counts = {}
        for label in self.labels:
            if label in class_counts:
                class_counts[label] += 1
            else:
                class_counts[label] = 1
        class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
        self.weights = [class_weights[label] for label in self.labels]
    def __getitem__(self, idx):
        imgs, slide_id = self.patches[idx]
        imgs = [self.transform(img) for img in imgs]
        label = self.get_label(slide_id)
        metadata = self.get_metadata(slide_id)
        age = metadata["Diagnosis Age"].values[0]
        race = metadata["Race Category"].values[0]
        sex = metadata["Sex"].values[0]
        grade = metadata["Neoplasm Histologic Grade"].values[0]
        IDHstatus = metadata["Subtype"].values[0]
        tumortype = metadata["Tumor Type"].values[0]
        prompt = f"a frozen brain histopathology slide of a {race.lower()}, {sex.lower()}, age {age}, has {IDHstatus.lower()}, {grade.lower()} of {tumortype.lower()}"
        return slide_id, imgs, label, prompt
    

In [23]:
folder = "/n/data2/hms/dbmi/kyu/lab/che099/data/tcga_lgg/frozen_patches_20x"
csv = "/n/data2/hms/dbmi/kyu/lab/che099/data/idil_tcga_lgg_merge_idh.csv"

assert os.path.exists(folder)
assert os.path.exists(csv)
dataset = IdilDataSet(csv, folder=folder, magnification=20, random_selection=True, limit=None, wsi_type="frozen")



Prefetch patches: 100%|███████████████████████| 536/536 [01:06<00:00,  8.04it/s]


In [24]:
#slide_id, imgs, label, age, race, sex, grade, IDHstatus, tumortype = dataset[0]
#slide_id, len(imgs), age, race, sex, grade, IDHstatus, tumortype

slide_id, imgs, label, prompt = dataset[0]
slide_id, len(imgs), label, prompt

('021D115A-2E27-4E28-9723-0AE434E869CC',
 250,
 1,
 'a frozen brain histopathology slide of a white, male, age 40.0, has idh mutation, high grade glioma of oligodendroglioma')

In [18]:
load_success_ids(folder)



{'00620C3D-01C2-4487-B556-44697751DCCE',
 '00e68225-1fd3-48e2-92f0-9ad0c3b8302c',
 '021AF7FB-B3EA-4722-86FB-47E3D11F82AB',
 '021D115A-2E27-4E28-9723-0AE434E869CC',
 '02876b00-a460-4a00-af72-f227e199b73f',
 '0346F513-0EC9-45BE-9B9E-B3E84D5C7F66',
 '03F3FB8A-2F69-4FF1-AD60-F1DAFD4DE88F',
 '04314e91-e88b-4bee-b824-59ec10b7ae22',
 '044a24cf-c213-4381-a4da-9b2519e0c923',
 '0450182b-587e-4cb5-afec-5029859ef0b8',
 '047aaa50-428a-410f-a09c-1aeed00676ea',
 '04fc974d-27d6-4a50-8363-7958c3f59144',
 '051c4545-4fa8-4741-859c-0c5faad7dfe4',
 '0546D5AC-B409-442F-866B-09A293AAC0E5',
 '054903B8-8FC1-4302-BAFD-032FE13606B6',
 '058D1B01-62DB-417D-B758-359E80C801BE',
 '05C13AC5-F5DC-4815-BA94-64C239508D57',
 '05F11061-E005-4602-834F-D71A77E163ED',
 '0609DB24-3F98-4B1F-B675-DC0B94C08BCF',
 '062e4970-750f-441d-a06c-6bbfa2b8c9cc',
 '06A061DF-AC98-4751-8EF5-034F2D8C840D',
 '06a0f502-568c-4930-925b-0dd04eed26a7',
 '07D06BDE-21EF-4EE8-86FF-1CAE6981B357',
 '088E4193-6D27-44F6-8549-83935DE0D184',
 '08b3ffb6-38e5-