In [2]:
# Standard library imports
import copy
import datetime
import gc
import glob
import heapq
import joblib
import math
import os
import random
import time
from collections import defaultdict
from itertools import chain

# Related third-party imports
import albumentations as A
import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from albumentations.pytorch import ToTensorV2
from colorama import Fore, Back, Style
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from PIL import Image
from skimage import io
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, f1_score, accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, tqdm_notebook

from torch.utils.data.sampler import WeightedRandomSampler


# Local application/library specific imports
# (Your local imports here, if any)

# Set up for colored terminal text
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

# Disable warnings
import warnings
warnings.filterwarnings('ignore')

# Additional note: The following import seems to be from a very specific context or not typically used.
# You might want to review if it's necessary or correct:
# from joblib.externals.loky.backend.context import get_context


## 1. Datasets & Preprocessing

In [1]:
from sklearn.model_selection import StratifiedGroupKFold

def create_data(WSI_DIRS, TMA_DIR, CONFIG, cancer_th=None):
    """
    Build the train/test frames for the UBC-OCEAN tile datasets.

    Reads the competition ``train.csv`` and builds one frame per tile directory with
    ``create_img_dataframe_from_directory`` (defined elsewhere). Each frame is joined with the
    csv on ``image_id``. Non-TMA rows are kept from ``WSI_DIRS`` and TMA rows from ``TMA_DIR``.
    Rows are assigned ``StratifiedGroupKFold`` folds (stratified by label, grouped by
    ``image_id``). Images that have supplemental masks are moved out of the test fold. Finally the
    rows of ``CONFIG["test_fold"]`` are split off as the test frame.

    Side effects: writes the fitted ``LabelEncoder`` to ``label_encoder.pkl`` in the working
    directory, and prints/displays summary tables.

    Args:
        WSI_DIRS: iterable of tile directories whose non-TMA rows are used.
        TMA_DIR: tile directory whose TMA rows are used.
        CONFIG: dict providing at least ``n_fold``, ``seed`` and ``test_fold``.
        cancer_th: unused; kept for backward compatibility.

    Returns:
        (df_train, df_test, encoder, df_orig, df_masks). ``kfold`` is an integer column.

    Raises:
        AssertionError: if an image with a supplemental mask ends up in the test frame.
    """
    df_orig = pd.read_csv("/kaggle/input/UBC-OCEAN/train.csv")
    df_orig = df_orig.rename(columns={"label": "subtype"})

    df_masks = create_img_dataframe_from_folder("/kaggle/input/ubc-ovarian-cancer-competition-supplemental-masks/")

    cols = ['image_path', 'image_id', 'subtype', 'image_width', 'image_height', 'is_tma']

    def _load_dir(dir_):
        # Tile-directory listing joined (left) with the train.csv metadata on image_id.
        df = create_img_dataframe_from_directory(dir_)
        df["image_id"] = df["image_id"].astype(int)
        return pd.merge(df, df_orig, on="image_id", how="left")

    dfs = []
    for dir_ in WSI_DIRS:
        df = _load_dir(dir_)
        # `!= True` keeps non-TMA rows, and also rows without a csv match (NaN is_tma).
        dfs.append(df.loc[df["is_tma"] != True, cols])
    df = _load_dir(TMA_DIR)
    dfs.append(df.loc[df["is_tma"] == True, cols])

    df_train = pd.concat(dfs, axis=0, ignore_index=True)
    df_train = df_train.rename(columns={"subtype": "label"})
    print(df_train.shape, df_train.image_id.nunique())

    encoder = LabelEncoder()
    df_train['target_label'] = encoder.fit_transform(df_train['label'])
    with open("label_encoder.pkl", "wb") as fp:
        joblib.dump(encoder, fp)

    # Stratified *group* K-fold: stratify on the label, keep all rows of one image_id in one fold.
    sgkf = StratifiedGroupKFold(n_splits=CONFIG['n_fold'], shuffle=True, random_state=CONFIG["seed"])
    kfold = np.full(len(df_train), -1, dtype=int)
    for fold, (_, val_) in enumerate(sgkf.split(X=df_train, y=df_train.target_label, groups=df_train.image_id.values)):
        kfold[val_] = fold  # split() yields positional indices, so assign positionally
    df_train["kfold"] = kfold
    display(df_train.head())

    # Move images with supplemental masks to a fold id outside 0..n_fold-1, so they stay out of
    # the test fold (avoids information leakage) as long as test_fold is a regular fold id.
    # NOTE(review): assumes df_masks.image_id has the same dtype as df_train.image_id (int);
    # isin() does not match "123" against 123.
    df_train.loc[df_train["image_id"].isin(df_masks["image_id"]), "kfold"] = CONFIG["n_fold"] + 1
    display(df_train["kfold"].value_counts())

    # separate train and test dataset
    df_test = df_train[df_train["kfold"] == CONFIG["test_fold"]].reset_index(drop=True)
    df_train = df_train[df_train["kfold"] != CONFIG["test_fold"]].reset_index(drop=True)
    assert not df_test["image_id"].isin(df_masks["image_id"]).any(), "masked images leaked into the test set"

    print(f"Shape df_train: {df_train.shape}, Shape df_test: {df_test.shape} ")
    display(df_train.label.value_counts())
    display(df_test.label.value_counts())
    return df_train, df_test, encoder, df_orig, df_masks

In [3]:
class UBCDataset(Dataset):
    """
    Image-level dataset returning one RGB image and its encoded label per DataFrame row.

    Args:
        df: DataFrame with a ``file_path`` column (image paths read with ``cv2.imread``) and a
            ``target_label`` column (integer class ids).
        transforms: optional albumentations-style callable used as ``transforms(image=img)["image"]``.
        apply_vertical_crop: if True, ``crop_vertical`` is applied to the RGB image before the
            transforms.
    """
    # NOTE(review): create_data() builds an ``image_path`` column, not ``file_path``; confirm
    # which frame is fed to this dataset.

    def __init__(self, df, transforms=None, apply_vertical_crop=True):
        self.df = df
        self.filenames = df.file_path.values
        self.labels = df.target_label.values
        self.transforms = transforms
        self.apply_vertical_crop = apply_vertical_crop

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"image": image, "label": scalar LongTensor}`` for row ``idx``.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        img_path = self.filenames[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for missing/unreadable files; report the path instead of
            # letting cvtColor fail with an opaque error.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.apply_vertical_crop:
            img = crop_vertical(img)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return {
            "image": img,
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }

def crop_vertical(image, min_width=200):
    """
    Crop an image holding several tissue slices separated by black vertical bands.

    Columns whose pixel sum is zero are treated as separators. The image is split into runs of
    consecutive non-black columns. Runs not strictly wider than ``min_width`` are discarded, and
    the first remaining run is returned with all fully-black rows removed. Separator columns are
    not included in the crop, and a run touching the right edge keeps the last column.

    Args:
        image: 3-D array (H, W, C); the sum is taken over axes 0 and 2.
        min_width: a run must be wider than this many columns to be kept.

    Returns:
        - No black column at all: the whole image with fully-black rows removed.
        - No run wide enough (including an all-black image): the original ``image`` object.
        - Otherwise: the first wide-enough run with fully-black rows removed.
    """
    column_is_black = np.sum(image, axis=(0, 2)) == 0

    if not column_is_black.any():
        crop = image
    else:
        # Pad with "black" on both sides so every non-black run has a start edge and an
        # (exclusive) end edge; edges alternate start, end, start, end, ...
        padded = np.concatenate(([True], column_is_black, [True]))
        edges = np.flatnonzero(padded[1:] != padded[:-1])
        runs = [(start, end) for start, end in zip(edges[0::2], edges[1::2]) if end - start > min_width]
        if not runs:
            return image
        start, end = runs[0]
        crop = image[:, start:end]

    # Remove black bars above/below the slice.
    row_is_black = np.sum(crop, axis=(1, 2)) == 0
    return crop[~row_is_black]


def custom_center_crop_or_resize(image, crop_size):
    """
    Bring ``image`` to ``crop_size`` = (height, width).

    A centre crop is used when the image is at least that large in both dimensions. Otherwise
    the image is resized to exactly that size, which can change its aspect ratio.
    """
    target_h, target_w = crop_size[0], crop_size[1]
    large_enough = image.shape[0] >= target_h and image.shape[1] >= target_w
    op = A.CenterCrop(target_h, target_w) if large_enough else A.Resize(target_h, target_w)
    return op(image=image)["image"]

In [4]:
def _color_means(img_path):
    """
    Per-channel mean and standard deviation of one image, with black pixels counted as white.

    Pixels whose first three channels sum to 0 are set to 255 in every channel. Values are then
    divided by 255 if the array maximum exceeds 1.5.

    Returns:
        (means, stds): two dicts keyed by channel index 0, 1 and 2.
    """
    pixels = np.array(Image.open(img_path))
    is_black = pixels[..., :3].sum(axis=2) == 0
    pixels[is_black, :] = 255
    if pixels.max() > 1.5:
        pixels = pixels / 255.0

    means, stds = {}, {}
    for channel in range(3):
        values = pixels[..., channel]
        means[channel] = np.mean(values)
        stds[channel] = np.std(values)
    return means, stds

"""
ls_images = glob.glob(os.path.join(TRAIN_DIR, "*", "*.png"))
clr_mean_std = Parallel(n_jobs=os.cpu_count())(delayed(_color_means)(fn) for fn in tqdm(ls_images[:9000]))

img_color_mean = pd.DataFrame([c[0] for c in clr_mean_std]).describe()
display(img_color_mean.T)
img_color_std = pd.DataFrame([c[1] for c in clr_mean_std]).describe()
display(img_color_std.T)

img_color_mean = list(img_color_mean.T["mean"])
img_color_std = list(img_color_std.T["mean"])
print(f"{img_color_mean=}\n{img_color_std=}")
"""

## histogram matching 
#from skimage.exposure import match_histograms
#ref_img = np.array(Image.open("/kaggle/input/tiles-of-cancer-2048px-scale-0-25/10077/000067_16-3.png"))
#bef_img = np.array(Image.open("/kaggle/input/tiles-of-cancer-2048px-scale-0-25/12522/000028_6-2.png"))
#start = time.time()
#aft_img = match_histograms(bef_img, ref_img, channel_axis=-1)
#print(time.time()-start)


"""        
A.Normalize(
    mean=[0.485, 0.456, 0.406], 
    std=[0.229, 0.224, 0.225], 
    max_pixel_value=255.0, 
    p=1.0
),
"""


'        \nA.Normalize(\n    mean=[0.485, 0.456, 0.406], \n    std=[0.229, 0.224, 0.225], \n    max_pixel_value=255.0, \n    p=1.0\n),\n'

In [5]:
# Dataset-specific per-channel statistics (channel 0, 1, 2), on the [0, 1] pixel scale.
# Presumably produced by averaging `_color_means` results over training tiles — TODO confirm.
img_color_mean = [
    0.8661704276539922,
    0.7663107094675368,
    0.8574260897185548,
]
img_color_std = [
    0.08670629753900036,
    0.11646580094195522,
    0.07164169171856792,
]


In [6]:
class CancerTilesDataset(Dataset):
    """
    Tile-level dataset: every item is one tile taken from the tile folder of a slide.

    For every row of the frame, ``get_img_dir`` lists the slide's ``*.png`` tiles (TMA and WSI
    tiles come from different hard-coded Kaggle folders). The row list is repeated ``n_tiles``
    times, so each slide contributes ``n_tiles`` items.

    ``__getitem__`` tries the slide's tiles in shuffled order and returns the first one whose
    near-white fraction is below ``thr_max_bg``. In "train" mode the order is freshly random on
    every call. In other modes it is seeded with ``seed + idx // len(self.data)``, so the same
    item always yields the same tile.
    """

    @staticmethod
    def get_img_dir(data_row):
        """Return the ``*.png`` tile paths of one slide, chosen by its ``is_tma`` flag."""
        if data_row.is_tma == True:
            return glob.glob(os.path.join("/kaggle/input/ubc-tma-tiles-512-05scale/UBC_TMA_tiles_1024p_scale05", str(data_row.image_id), "*.png"))
        return glob.glob(os.path.join("/kaggle/input/tiles-of-cancer-2048px-scale-0-25", str(data_row.image_id), "*.png"))

    def __init__(
        self,
        df_data,
        path_img_dir: str = '',
        transforms=None,
        mode: str = 'train',
        labels_lut=None,
        white_thr: int = 225,
        thr_max_bg: float = 0.2,
        train_val_split: float = 0.90,
        n_tiles: int = 1,
        tma_weight: float = 1.0,
        seed: int = None,
    ):
        """
        Args:
            df_data: frame with ``label``, ``target_label``, ``image_id`` and ``is_tma`` columns.
                It is copied, not modified.
            path_img_dir: must be an existing directory (asserted). Tile paths themselves come
                from ``get_img_dir``.
            transforms: optional albumentations-style callable, ``transforms(image=...)["image"]``.
            mode: "train"/"test" keep the first ``train_val_split`` fraction of the shuffled
                rows; any other value keeps the remainder.
            labels_lut: optional label -> index mapping (defaults to sorted unique labels).
            white_thr: a pixel whose channel mean exceeds this counts as background.
            thr_max_bg: a tile is accepted if its background fraction is below this.
            train_val_split: split fraction in [0, 1].
            n_tiles: how many times each slide is repeated.
            tma_weight: sample weight of TMA rows (other rows get 1).
            seed: base seed for the deterministic tile choice outside "train" mode. When None,
                the global ``CONFIG["seed"]`` is used.
        """
        assert os.path.isdir(path_img_dir)
        self.path_img_dir = path_img_dir
        self.transforms = transforms
        self.mode = mode
        self.white_thr = white_thr
        self.thr_max_bg = thr_max_bg
        self.train_val_split = train_val_split
        self.n_tiles = n_tiles
        self.tma_weight = tma_weight
        self.seed = seed

        # Copy so the dtype cast below does not modify the caller's frame
        # (get_dataloaders passes the same frame to the train and valid datasets).
        self.data = df_data.copy()
        self.labels_unique = sorted(self.data["label"].unique())
        self.labels_lut = labels_lut or {lb: i for i, lb in enumerate(self.labels_unique)}

        self.data["is_tma"] = self.data["is_tma"].astype(bool)
        # Fixed seed: train and valid instances built from the same frame split it identically.
        self.data = self.data.sample(frac=1, random_state=42).reset_index(drop=True)

        # split dataset
        assert 0.0 <= self.train_val_split <= 1.0
        n_first = int(self.train_val_split * len(self.data))
        self.data = self.data.iloc[:n_first] if mode in ["train", "test"] else self.data.iloc[n_first:]

        # One tile list per slide, repeated n_tiles times (the repeats share the same list objects).
        self.img_dirs = [CancerTilesDataset.get_img_dir(row) for _, row in self.data.iterrows()] * self.n_tiles
        self.img_paths = []  # accepted tile paths, appended by __getitem__
        self.labels = np.array(self.data.target_label.values.tolist() * self.n_tiles)

        weights = [self.tma_weight if is_tma == True else 1 for is_tma in self.data["is_tma"]]
        self.sample_weights = np.array(weights * self.n_tiles)

    def __getitem__(self, idx: int) -> dict:
        """
        Return ``{"image": tile, "label": scalar LongTensor}`` for item ``idx``.

        If no tile passes the background filter, the last tile tried is returned.

        Raises:
            RuntimeError: if the slide has no tiles at all.
        """
        nth_iteration = idx // len(self.data)
        if self.mode == "train":
            rng = random.Random()
        else:
            base_seed = CONFIG["seed"] if self.seed is None else self.seed
            rng = random.Random(base_seed + nth_iteration)
        # Shuffle a private copy with a local RNG. Shuffling the stored list in place (it is
        # shared by the n_tiles repeats) made the seeded order depend on how often it had been
        # shuffled before, and random.seed() reset the global RNG used by other code.
        candidates = list(self.img_dirs[idx])
        if not candidates:
            raise RuntimeError(f"No tiles found for item {idx}")
        rng.shuffle(candidates)

        for img_path in candidates:
            assert os.path.isfile(img_path), f"missing: {img_path}"
            tile = cv2.imread(img_path)
            tile = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB)
            # Paint pure-black pixels white, then measure the near-white (background) fraction.
            black_bg = np.sum(tile, axis=2) == 0
            tile[black_bg, :] = 255
            mask_bg = np.mean(tile, axis=2) > self.white_thr
            if np.sum(mask_bg) < (np.prod(mask_bg.shape) * self.thr_max_bg):
                self.img_paths.append(img_path)
                break

        # augmentation
        if self.transforms:
            tile = self.transforms(image=tile)["image"]
        return {
            "image": tile,
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }

    def __len__(self) -> int:
        return len(self.img_dirs)

    def get_sample_weights(self):
        """Per-item sampling weights as a float64 tensor (``tma_weight`` for TMA rows, else 1)."""
        return torch.from_numpy(self.sample_weights).double()

In [7]:
def delete_tiles(directory_path):
    """
    Remove every regular file directly inside ``directory_path``.

    Sub-directories and their contents are left alone. A path that is not a directory is a no-op.
    """
    if not os.path.isdir(directory_path):
        return
    for entry in os.listdir(directory_path):
        full_path = os.path.join(directory_path, entry)
        if os.path.isfile(full_path):
            os.remove(full_path)

def extract_image_tiles(
    p_img, img_id, tmp_dir, size: int = 2048, scale: float = 0.5,
    drop_thr: float = 0.8, white_thr: int = 245, max_samples: int = 50
) -> list:
    """
    Cut a slide into ``size`` x ``size`` tiles, drop mostly-background tiles, and save the rest
    (resized by ``scale``) as PNGs in ``tmp_dir``.

    Tile positions are visited in a fixed pseudo-random order (seed 42), without touching the
    global ``random`` state. Edge tiles are zero-padded to full size. Pure-black pixels (padding
    included) are painted white so they count as background.

    Args:
        p_img: slide path, opened with ``pyvips``.
        img_id: unused; kept for interface compatibility.
        tmp_dir: output directory. Existing files in it are deleted; it is created if missing.
        size: tile edge length in source pixels.
        scale: resize factor applied before saving.
        drop_thr: tiles whose background-pixel fraction is >= this are skipped.
        white_thr: a pixel whose channel mean exceeds this counts as background.
        max_samples: int -> stop once this many tiles are saved; otherwise a fraction of the
            number of tile positions. At least one tile is saved if any qualifies.

    Returns:
        List of saved tile paths.
    """
    delete_tiles(tmp_dir)  # empty directory from previous images
    os.makedirs(tmp_dir, exist_ok=True)
    # NOTE(review): `pyvips` is not imported in the visible cells; confirm it is imported elsewhere.
    im = pyvips.Image.new_from_file(p_img)
    w = h = size
    # https://stackoverflow.com/a/47581978/4521646
    idxs = [(y, x) for y in range(0, im.height, h) for x in range(0, im.width, w)]
    # random subsample
    max_samples = max_samples if isinstance(max_samples, int) else int(len(idxs) * max_samples)
    # Private RNG: same order as random.seed(42); random.shuffle(...), without the global side effect.
    random.Random(42).shuffle(idxs)
    new_size = (int(size * scale), int(size * scale))

    images = []
    for i, (y, x) in enumerate(idxs, start=1):
        img_path = f"{tmp_dir}/{i}.png"
        # https://libvips.github.io/pyvips/vimage.html#pyvips.Image.crop
        tile = im.crop(x, y, min(w, im.width - x), min(h, im.height - y)).numpy()[..., :3]
        if tile.shape[:2] != (h, w):
            # Zero-pad edge tiles to full size; the padding is turned white just below.
            padded = np.zeros((h, w) if tile.ndim == 2 else (h, w, tile.shape[2]), dtype=tile.dtype)
            padded[:tile.shape[0], :tile.shape[1], ...] = tile
            tile = padded
        black_bg = np.sum(tile, axis=2) == 0
        tile[black_bg, :] = 255
        mask_bg = np.mean(tile, axis=2) > white_thr
        if np.sum(mask_bg) >= (np.prod(mask_bg.shape) * drop_thr):
            continue  # almost empty tile
        Image.fromarray(tile).resize(new_size, Image.LANCZOS).save(img_path)
        images.append(img_path)
        # Counter check after saving, since skipped tiles do not count.
        if len(images) >= max_samples:
            break
    return images


class TilesInferenceDataset(Dataset):
    """
    Inference dataset over the tiles of one image; items are RGB images (no labels).

    * ``is_submission=True``: ``img_path`` is a slide file. Tiles are cut on the fly into
      ``tmp_dir`` by ``extract_image_tiles``, using this dataset's ``white_thr``/``thr_max_bg``.
    * ``is_submission=False``: ``img_path`` is a root folder with pre-extracted tiles in
      ``<img_path>/<img_id>/*.png``. Tiles are visited in sorted path order and the first
      ``max_samples`` with a background fraction <= ``thr_max_bg`` are kept.
    """

    def __init__(
        self,
        img_path: str,
        img_id: str = None,
        tmp_dir: str = None,
        size: int = 2048,
        scale: float = 0.25,
        white_thr: int = 225,
        thr_max_bg: float = 0.6,
        max_samples: int = 30,
        transforms = None,
        is_submission: bool = True,
    ):
        self.max_samples = max_samples
        self.white_thr = white_thr
        self.thr_max_bg = thr_max_bg
        self.is_submission = is_submission
        self.transforms = transforms

        if self.is_submission:
            assert os.path.isfile(img_path)
            # white_thr is forwarded so both modes share one background definition
            # (previously the extractor silently used its own default of 245).
            self.imgs = extract_image_tiles(
                img_path, img_id, tmp_dir, size=size, scale=scale,
                drop_thr=self.thr_max_bg, white_thr=self.white_thr, max_samples=max_samples)
        else:  # test
            self.imgs = self._select_tiles(img_path, img_id)

    def _select_tiles(self, tiles_root, img_id):
        """Return up to ``max_samples`` tile paths (sorted order) that pass the background filter."""
        tile_paths = sorted(glob.glob(os.path.join(tiles_root, str(img_id), "*.png")))
        selected = []
        for tile_path in tile_paths:
            # Stop early instead of reading every tile and truncating afterwards.
            if self.max_samples is not None and len(selected) >= self.max_samples:
                break
            img = cv2.imread(tile_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            black_bg = np.sum(img, axis=2) == 0
            img[black_bg, :] = 255
            mask_bg = np.mean(img, axis=2) > self.white_thr
            if np.sum(mask_bg) <= (np.prod(mask_bg.shape) * self.thr_max_bg):
                selected.append(tile_path)
        return selected

    def __getitem__(self, idx: int):
        img = cv2.imread(self.imgs[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # filter background: paint pure-black pixels white
        mask = np.sum(img, axis=2) == 0
        img[mask, :] = 255
        # Rescale to 0..255 if the image looks like it is on a [0, 1] scale.
        if np.max(img) < 1.5:
            img = np.clip(img * 255, 0, 255).astype(np.uint8)
        if self.transforms:
            img = self.transforms(image=img)["image"]
        return img

    def __len__(self) -> int:
        return len(self.imgs)


In [8]:
def convert_dict_to_tensor(dict_):
    """Return a 1-D tensor (default float dtype) holding ``dict_``'s values in iteration order."""
    values = list(dict_.values())
    result = torch.empty(len(values))
    for position, value in enumerate(values):
        result[position] = value
    return result

def get_class_weights(df_train):
    """
    Inverse-frequency class weights, normalised to sum to 1.

    Classes are ordered by ascending ``target_label``. Only labels present in ``df_train`` get
    an entry. Returns a 1-D tensor of the default float dtype.
    """
    n_rows = df_train.shape[0]
    label_counts = df_train.target_label.value_counts().sort_index().to_dict()
    inverse_freqs = [1 / (count / n_rows) for count in label_counts.values()]
    total = sum(inverse_freqs)
    weights = torch.empty(len(inverse_freqs))
    for position, inv in enumerate(inverse_freqs):
        weights[position] = inv / total
    return weights

def get_dataloaders(df, TRAIN_DIR, CONFIG, data_transforms, n_tiles=1, train_val_split=0.9, apply_sampler=True, tma_weight=1, sample_fac=1, num_workers=2):
    """
    Build train and validation DataLoaders over ``CancerTilesDataset``.

    Both datasets are built from the same ``df`` and split by ``train_val_split``: "train" gets
    the first part and "valid" the remainder.

    Args:
        df: frame passed unchanged to both datasets.
        TRAIN_DIR: ``path_img_dir`` for the datasets.
        CONFIG: dict with ``train_batch_size`` and ``valid_batch_size``.
        data_transforms: dict with ``"train"`` and ``"valid"`` transforms.
        n_tiles, train_val_split, tma_weight: forwarded to ``CancerTilesDataset``.
        apply_sampler: if True, each loader draws ``int(len(dataset) * sample_fac)`` items per
            epoch with replacement (``WeightedRandomSampler``, weights from the dataset).
        sample_fac: epoch-length multiplier for the samplers; may be fractional.
        num_workers: DataLoader worker processes.

    Returns:
        (train_loader, valid_loader, df) — ``df`` is returned as given.
    """
    # NOTE(review): with apply_sampler=True the validation loader also samples with replacement,
    # so validation items can repeat or be missed and metrics vary between epochs.

    def _make_loader(dataset, batch_size):
        sampler = None
        if apply_sampler:
            weights = dataset.get_sample_weights()
            # WeightedRandomSampler requires an int num_samples; sample_fac may be a float.
            sampler = WeightedRandomSampler(weights, int(len(weights) * sample_fac))
        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=False, pin_memory=True)

    train_dataset = CancerTilesDataset(df, TRAIN_DIR, transforms=data_transforms["train"], mode="train", n_tiles=n_tiles, train_val_split=train_val_split, tma_weight=tma_weight)
    train_loader = _make_loader(train_dataset, CONFIG['train_batch_size'])

    valid_dataset = CancerTilesDataset(df, TRAIN_DIR, transforms=data_transforms["valid"], mode="valid", n_tiles=n_tiles, train_val_split=train_val_split, tma_weight=tma_weight)
    valid_loader = _make_loader(valid_dataset, CONFIG['valid_batch_size'])

    print(f"Len Train Dataset: {len(train_dataset)}, Len Validation Dataset: {len(valid_dataset)}" )
    return train_loader, valid_loader, df

