In [26]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from typing import *
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from tqdm import tqdm
import h5py
import random
from PIL import Image

def get_transforms(train=False):
    """
    Takes a list of images and applies the same augmentations to all of them.
    This is completely overengineered but it makes it easier to use in our pipeline
    as drop-in replacement for torchvision transforms.
    ## Example
    ``` python
    imgs = [Image.open(f”image{i}.png”) for i in range(1, 4)]
    t = get_albumentations_transforms(train=True)
    t_imgs = t(imgs) # List[torch.Tensor]
    ```
    For the single image case:
    ```python
    img = Image.open(f”image{0}.png”)
    # or img = np.load(some_bytes)
    t = get_albumentations_transforms(train=True)
    t_img = t(img) # torch.Tensor
    ```
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    _data_transform = None
    def _get_transform(n: int = 3):
        data_transforms = A.Compose(
            [
                A.Resize(224, 224),
                A.Normalize(mean=mean, std=std),
                ToTensorV2(),
            ],
            additional_targets={f"image{i}": "image" for i in range(1, n)},

        )
        return data_transforms
    def transform_images(images: any):
        nonlocal _data_transform
        if not isinstance(images, list):
            n = 1
            images = [images]
        else:
            n = len(images)
        if _data_transform is None:
            # instantiate once
            _data_transform = _get_transform(n)
        # accepts both lists of np.Array and PIL.Image
        if isinstance(images[0], Image.Image):
            images = [np.array(img) for img in images]
        image_dict = {"image": images[0]}
        for i in range(1, n):
            image_dict[f"image{i}"] = images[i]
        transformed = _data_transform(**image_dict)
        transformed_images = [
            transformed[key] for key in transformed.keys() if "image" in key
        ]
        if len(transformed_images) == 1:
            return transformed_images[0]
        return transformed_images
    return transform_images


    
class IdilDataSet(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes prompts.
    """
    def __init__(
        self,
        csv_path: str,
        folder: str,
        magnification: int,
        transform: Optional[Callable] = None,
        n_patches: int = 250,
        random_selection=False,
        limit: Optional[int] = None,
        wsi_type: str = "frozen"
    ):
        super().__init__()
        self.csv = pd.read_csv(csv_path)
        self.csv = self.csv[self.csv["wsi_type"] == wsi_type]

        # Filter out unwanted tumor types
        self.csv = self.csv[self.csv["Tumor Type"] != "Oligoastrocytoma"]
        
        # Replace grade values
        self.csv["Neoplasm Histologic Grade"] = self.csv["Neoplasm Histologic Grade"].replace({"G2": "low grade glioma", "G3": "high grade glioma"})
        
        # Replace IDH status values
        self.csv["Subtype"] = self.csv["Subtype"].replace({
            "LGG_IDHmut-non-codel": "IDH mutation",
            "LGG_IDHmut-codel": "IDH mutation",
            "LGG_IDHwt": "wild-type IDH"
        })

        self.folder = folder
        self.magnification = magnification
        self.transform = transform
        self.n_patches = n_patches
        self.random_selection = random_selection
        self.slide_ids = self.csv["uuid"].unique()
        success_ids = self.load_success_ids(self.folder)
        self.slide_ids = [x for x in self.slide_ids if x in success_ids]
        if limit:
            self.slide_ids = self.slide_ids[:limit]
        self.labels = []
        self.patches = []
        self.load_patches()
        self.compute_weights()
        
    def load_success_ids(self, feat_folder: str):
        success_ids = set()
        success_txt = f"{feat_folder}/success.txt"
        success_db = f"{feat_folder}/success.db"
        if os.path.exists(success_txt):
            print("Warning: Loading success IDs from deprecated success.txt.")
            with open(success_txt, "r") as f:
                for line in f:
                    success_ids.add(line.strip())
        if os.path.exists(success_db):
            print("Loading success IDs from database.")
            conn = sqlite3.connect(success_db)
            cursor = conn.cursor()
            cursor.execute("SELECT slide_id FROM success")
            success_ids = set([row[0] for row in cursor.fetchall()])
            conn.close()
        return success_ids

    def load_patches(self):
        """
        Load n_patches into memory.
        """
        for slide_id in tqdm(self.slide_ids, desc="Prefetch patches"):
            file = f"{self.folder}/{slide_id}_features.h5"
            try:
                with h5py.File(file, "r") as h5f:
                    n_patches = min(self.n_patches, len(h5f[str(self.magnification)]))
                    if self.random_selection:
                        indices = random.sample(range(n_patches), n_patches)
                    else:
                        indices = list(range(n_patches))
                    imgs = [
                        Image.fromarray(h5f[str(self.magnification)][i]) for i in indices
                    ]
                    self.patches.append((imgs, slide_id))
                    self.labels.append(self.get_label(slide_id))
            except Exception as e:
                pass

    def __len__(self):
        return len(self.patches)

    def get_label(self, slide_id):
        return self.csv.loc[self.csv["uuid"] == slide_id, "label"].values[0]

    def get_metadata(self, slide_id):
        return self.csv.loc[self.csv["uuid"] == slide_id]

    def compute_weights(self):
        """
        Compute weights for WeightedRandomSampler.
        """
        class_counts = {}
        for label in self.labels:
            if label in class_counts:
                class_counts[label] += 1
            else:
                class_counts[label] = 1
        class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
        self.weights = [class_weights[label] for label in self.labels]

    def __getitem__(self, idx):
        imgs, slide_id = self.patches[idx]
        imgs = torch.stack([self.transform(img) for img in imgs])
        label = self.get_label(slide_id)
        metadata = self.get_metadata(slide_id)
        age = metadata["Diagnosis Age"].values[0]
        race = metadata["Race Category"].values[0]
        sex = metadata["Sex"].values[0]
        grade = metadata["Neoplasm Histologic Grade"].values[0]
        IDHstatus = metadata["Subtype"].values[0]
        tumortype = metadata["Tumor Type"].values[0]
        prompt = f"a frozen brain histopathology slide of a {race.lower()}, {sex.lower()}, age {age}, has {IDHstatus.lower()}, {grade.lower()} of {tumortype.lower()}"
        return slide_id, imgs, label, prompt


In [53]:

csv_path = "/n/data2/hms/dbmi/kyu/lab/che099/data/idil_tcga_lgg_merge_idh.csv"
folder = "/n/data2/hms/dbmi/kyu/lab/che099/data/tcga_lgg/frozen_patches_20x"
magnification = 20
transform = get_transforms()
n_patches = 2
wsi_type = "frozen"
ds = IdilDataSet(
    csv_path,
    folder=folder,
    magnification=magnification,   
    transform=get_transforms(),
    n_patches=250,
    random_selection=True,
    limit=None,
    wsi_type="frozen"
)
    



Prefetch patches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 536/536 [00:36<00:00, 14.66it/s]


In [54]:
dl = torch.utils.data.DataLoader(ds, num_workers=16)

In [57]:
%timeit
batch = next(iter(dl))

In [58]:
torch.cuda.is_available()

  return torch._C._cuda_getDeviceCount() > 0


False

In [29]:
len(ds)

2

In [31]:
slide_id, imgs, label, prompt = ds[0]

In [36]:
import matplotlib.pyplot as plt

In [49]:
def denorm(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    for t,m,s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor

In [52]:
from ipywidgets import interact

@interact(i=(0, len(imgs)-1))
def show(i):
    plt.axis("off")
    plt.imshow(
        denorm(imgs[i]).numpy().transpose(1,2,0)
    )

interactive(children=(IntSlider(value=124, description='i', max=249), Output()), _dom_classes=('widget-interac…