# Stain normalization notebook
This notebook demonstrates how to train deep learning models (a content encoder and an attribute encoder) for stain normalization.
Changes compared to the baseline code:
<ol>
    <li>Add LeakyReLU to the content encoder to enforce a non-negativity constraint on its output.</li>
</ol>

Reference
1. [Yibo's color normalization exp](https://docs.google.com/presentation/d/1MttGX3S6nWdlUOzHGhyHuAIWgttYA2EzEsehvELFBks/edit#slide=id.p)

In [None]:
# Set the environment to run the workflow
import os
import sys
### setting up env variables before running the dynamic workflows
# NOTE(review): the three values below are tied to one user's account and
# branch; replace them with your own before running this notebook.
# Name of the ML environment, in the form "<username>/<env name>".
os.environ["MLENV_NAME"] = "tan.nguyen/overlay_test" # your_mle_username/your_mle_envname
# Presumably the ML platform branch the workflows run against -- TODO confirm.
os.environ["MLPLATFORM_BRANCH"] = "pdac_fibrosis"
# User name the workflows are attributed to.
os.environ["JUPYTERHUB_USER"] = "tan.nguyen"
%reload_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

In [None]:
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
from typing import Optional
from tempfile import TemporaryDirectory
import matplotlib.pyplot as plt
import itertools

from dataclasses import dataclass
from IPython.core.debugger import set_trace
import lpips
import matplotlib.pyplot as plt

from itertools import product

from pathai.dynamic import set_jabba
from pathai.dynamic import set_local
from pathai.dynamic import Jabba
from pathai.dynamic import jmap
from pathai.dynamic import AssetRef
from pathai.io.samples_flexible_io import SamplesWriter
from pathai.io.samples_flexible_io import SamplesReader
import pathai.api.slides.slides as slide_api

In [None]:
# Side length (in pixels) of the square patches that are sampled from the
# slides; used as the default patch size throughout this notebook.
INPUT_IMAGE_SIZE_PIXELS = 512
# Presumably the number of stain vectors used by the model; it is not
# referenced in the cells shown here -- TODO confirm.
NUM_STAIN_VECTORS = 8

## 1. Prepare the training slides

In [None]:
from pathai.api.monocle.slides import get_slide_info

In [None]:
@dataclass
class Samples:
    """Location of one sampling patch, identified by the pixel at its centre."""

    # Name of the core the patch belongs to; whole-slide samples created in
    # this notebook pass None here.
    core_name: Optional[str]
    # Row of the patch centre, in full-resolution slide pixels.
    row_idx: int
    # Column of the patch centre, in full-resolution slide pixels.
    col_idx: int
    # Identifier of the slide the patch is read from, when known.
    slide_id: Optional[int] = None

In [None]:
import numpy as np
import cv2
from pathai.dynamic import SlideReference
from pathai.dynamic import task

def _he_tissue_mask(slide_id: int, downsampling_factor) -> np.ndarray:
    """Computes a binary tissue mask for a slide from its gradient magnitude.

    The slide is read at ``downsampling_factor`` times its native MPP, converted
    to grayscale, and the Sobel gradient magnitude is thresholded with Otsu's
    method. A morphological closing then fills small holes in the mask.

    Args:
        slide_id: The identifier of the slide to read.
        downsampling_factor: The factor by which the native MPP is multiplied
            when reading the slide.

    Returns:
        A 2D uint8 mask at the downsampled resolution whose tissue pixels are 255
        and whose background pixels are 0.
    """
    closing_kernel_size_pixels = 6
    print(f"Computing the tissue mask for slide {slide_id}")
    slide_reference = SlideReference(int(slide_id))
    with slide_reference.read_object() as slide:
        im = slide.view_at_mpp(slide.mpp * downsampling_factor, mpp_tolerance=0.1)
        # Due to histogram adjustment, the background value may not be the same, need to use Otsu for this.
        # NOTE(review): COLOR_BGR2GRAY assumes BGR channel order -- confirm what the slide reader returns.
        im = cv2.cvtColor(im[:, :, :], cv2.COLOR_BGR2GRAY)
        sobelx = cv2.Sobel(im, cv2.CV_32F, 1, 0, ksize=5)
        sobely = cv2.Sobel(im, cv2.CV_32F, 0, 1, ksize=5)
        # NOTE(review): the float32 gradient magnitude can exceed 255, and it is cast to
        # uint8 without rescaling. The cast is kept as-is so that existing masks do not
        # change; consider normalising to [0, 255] first.
        grad_sqr = np.sqrt(sobelx ** 2 + sobely ** 2).astype(np.uint8)
        _, tissue_mask = cv2.threshold(grad_sqr, 0, 255, cv2.THRESH_OTSU)
        kernel = np.ones((closing_kernel_size_pixels, closing_kernel_size_pixels), np.uint8)
        return cv2.morphologyEx(tissue_mask, cv2.MORPH_CLOSE, kernel)

def compute_usable_tissue_fraction(
    center_coords: Tuple[int, int], patch_size_pixels: int, tissue_mask: np.ndarray
) -> float:
    """Computes the fraction of a square patch that is covered by tissue.

    Args:
        center_coords: The (row, column) of the patch centre in mask pixels.
        patch_size_pixels: The side length of the patch in mask pixels.
        tissue_mask: A 2D mask whose tissue pixels are 255 and background pixels are 0.

    Returns:
        The sum of the mask over the patch, divided by the patch area and by 255.

    Raises:
        ValueError: If the patch centred at ``center_coords`` does not fit inside the mask.
    """
    row_idx, col_idx = center_coords
    nrows, ncols = tissue_mask.shape
    half_patch_size_pixels = patch_size_pixels // 2
    if not half_patch_size_pixels <= row_idx <= nrows - half_patch_size_pixels:
        raise ValueError(
            f"row_idx {row_idx} must be between [{half_patch_size_pixels}, {nrows - half_patch_size_pixels}]"
        )
    if not half_patch_size_pixels <= col_idx <= ncols - half_patch_size_pixels:
        raise ValueError(
            f"col_idx {col_idx} must be between [{half_patch_size_pixels}, {ncols - half_patch_size_pixels}]"
        )

    patch = tissue_mask[
        row_idx - half_patch_size_pixels : row_idx + half_patch_size_pixels,
        col_idx - half_patch_size_pixels : col_idx + half_patch_size_pixels,
    ]
    # Tissue pixels are 255 in the mask, hence the final division by 255.
    return float(np.sum(patch)) / float(patch_size_pixels ** 2) / 255


@task(node_family="high-mem", slots=4)
def generate_wsi_sampling_points(
    slide_id: int,
    num_candidates_per_core: int = 100000,
    patch_size_pixels: int = INPUT_IMAGE_SIZE_PIXELS,
    minimum_tissue_fraction_to_be_selected: float = 0.8,
) -> List[Samples]:
    """Generates a list of Samples for one slide.

    Candidate patch centres are laid out on a regular grid over the downsampled
    tissue mask. A candidate is kept when the tissue fraction of its patch exceeds
    the requested minimum.

    Args:
        slide_id: The identifier of the slide to generate Samples for.
        num_candidates_per_core (optional): The approximate number of grid candidates to consider for
            sampling coordinates. Defaults to 100000.
        patch_size_pixels (optional): The size of each sampling patch in full-resolution pixels.
            Defaults to INPUT_IMAGE_SIZE_PIXELS.
        minimum_tissue_fraction_to_be_selected (optional): The tissue fraction a candidate must exceed
            to be selected as a sampling point. Defaults to 0.8.

    Returns:
        A list of Samples that contains the full-resolution coordinates of the center point of each
        selected sampling patch. The list is empty when the patch does not fit inside the slide.
    """
    downsampling_factor = 8
    print("Computing the tissue mask")
    tissue_mask = _he_tissue_mask(slide_id, downsampling_factor)
    ds_nrows, ds_ncols = tissue_mask.shape
    ds_patch_size_pixels = patch_size_pixels // downsampling_factor
    ds_half_patch_size_pixels = ds_patch_size_pixels // 2

    # Clamp the usable extent at zero and the spacing at one: a mask that is small
    # relative to the patch or to the candidate count would otherwise give a NaN
    # or zero grid step, which `range` rejects.
    usable_rows = max(ds_nrows - ds_patch_size_pixels, 0)
    usable_cols = max(ds_ncols - ds_patch_size_pixels, 0)
    grid_spacing = max(1, int(np.sqrt(usable_rows * usable_cols / num_candidates_per_core)))

    print("Generating sampling candidate the tissue mask")
    candidate_rows = range(ds_half_patch_size_pixels, ds_nrows - ds_half_patch_size_pixels, grid_spacing)
    candidate_cols = range(ds_half_patch_size_pixels, ds_ncols - ds_half_patch_size_pixels, grid_spacing)

    # Each candidate is scored and filtered in a single pass. The previous version
    # zipped a one-shot iterator with a lazy map over that same iterator, which
    # paired every coordinate with the tissue fraction of the *next* candidate
    # and silently dropped half of the candidates.
    selected_samples: List[Samples] = []
    for candidate in product(candidate_rows, candidate_cols):
        tissue_fraction = compute_usable_tissue_fraction(
            candidate, patch_size_pixels=ds_patch_size_pixels, tissue_mask=tissue_mask
        )
        if tissue_fraction > minimum_tissue_fraction_to_be_selected:
            selected_samples.append(
                Samples(
                    core_name=None,
                    # Map the mask coordinates back to full-resolution pixels.
                    row_idx=int(candidate[0]) * downsampling_factor,
                    col_idx=int(candidate[1]) * downsampling_factor,
                    slide_id=slide_id,
                )
            )
    return selected_samples

In [None]:
datasets_by_name: Dict[str, List[int]] ={
    'daphne': ([568263, 568260, 568257, 568254, 568251, 568248, 568245, 568242, 568239, 568236, 568233, 568230, 568227, 
                568224, 568218, 568215, 568212, 568209, 568206, 568203, 568200, 568197, 568194, 568191, 568188, 568185, 
                568182, 568179, 568176, 568170, 568167, 567988, 568164, 568161, 568158, 568155, 568152, 568149, 568146, 
                568143, 568140, 568137, 568134, 568131, 568128, 568125, 568122, 568119, 568116, 568113, 568110, 568107, 
                568101, 568098, 568095, 568092, 568089, 568086, 568083, 568080, 568077, 568074, 568068, 568065, 568062, 
                568059, 568056, 568053, 568050, 568047, 568044, 568041, 568038, 568035, 568032, 568029, 568026, 568023, 
                568020], [568017, 568012, 568009, 568006, 568003, 568000, 567997, 567994, 567991]),
    'p2_at2': ([336517, 336518, 338299, 338318, 338321, 338324, 338344, 340007, 340011, 340020, 340021, 340026, 340035, 
                340051, 340055, 340057, 340064, 340125, 340130, 340138, 340141, 340148, 340152, 340154, 340163, 342625,
                342632, 342635, 342641, 342653, 342662, 342666, 342743, 342745, 342750, 342752, 342760, 342764, 342769, 
                342782, 342784, 342795, 342856, 342857, 342865, 342868, 342869, 342873, 342879, 342906, 342918, 342924,
                342927, 342931, 342944, 342952, 342959, 343013, 343017, 343020, 343024, 343029, 343043, 343052, 343059,
                343062, 343064, 343079, 343080, 343088, 343094, 343129, 344180, 344188, 344194, 344198, 344206, 344210,
                344602, 344603, 344611, 344616, 344621, 344625, 344630, 344639, 344643, 344649, 344667, 344675, 344679, 
                344686, 344692, 344707, 344710, 344718, 344719, 344726, 344770, 344773, 344789, 344791, 344802, 344808,
                344811, 344816, 344821, 344827, 344842, 344847, 344851, 344854, 344986, 345091, 345094, 345098, 345100,
                345106, 345111, 345117, 345123, 345126, 345129, 345143, 345145, 345151, 345156, 345160, 345163, 345169,
                345749, 345755, 345761, 345763, 345778, 345782, 345804, 345815, 345829, 345836, 345840, 345872, 345877,
                345878, 345883, 345887, 345893, 345900, 345910, 345914, 345925, 345930, 345931, 345940, 346946, 346955,
                346961, 346964, 346967, 346970, 346976, 346983, 346986, 346990, 346993, 347000, 347004, 347012, 347019,
                347025, 347032, 347040, 347043, 347046, 347066, 347079, 347089, 347099, 347113, 347115, 347125, 347128,
                347143, 347145, 347149, 347156, 347162, 347166, 347172, 347177, 347189, 347196, 347203, 347219, 347225,
                347443, 347446, 347450, 347455, 349471, 349474, 349479, 349482, 349489, 349498, 349506, 349512, 349514,
                349517, 349525, 349528, 349545, 349547, 349550, 349566, 349570, 349573, 349580, 349603, 349613, 349619,
                349621, 349630, 349648, 349649, 349659, 349671, 349692, 349697, 349701, 349703, 349708, 349720, 349726,
                351671, 351675, 465745, 465746, 465747, 465748, 465749, 465750, 465751, 465752, 465753, 465766, 465767,
                465768, 465769, 465770, 465772, 465773, 465774, 465775, 465776, 465777, 465778, 465779, 465780, 465781,
                465782, 465783, 465784, 465785, 465786, 465787, 465788, 465789, 465790, 465791, 465792, 465793, 465794,
                465795, 465796, 465797, 465798, 465799, 465800, 465801, 465802, 465803, 465804, 465805, 465806, 465807,
                465808, 465809, 465832, 465833, 465834, 465835, 465836, 465837],
                [336527,338333,340069,340073,340076,340081,340086,340090,342676,342682,342685,342697,342700,342703,
                342819,342832,342886,342889,342968,342971,342973,343098,343105,343107,344250,344256,344257,344262,
                344264,344269,344282,344286,344289,344486,344494,344867,344869,344876,344898,344902,344913,344917,
                344919,344928,344936,344940,344941,344950,345954,345976,345983,345988,346009,346012,346016,346020,
                346026,346032,346037,346043,347227,347238,347242,347251,347254,347269,347274,347278,347286,347293,
                347296,347307,347314,347317,351681,351684,351690,351700,465754,465755,465756,465757,465758,465759,
                465760,465761,465762,465763,465764,465810,465811,465812,465813,465814,465815,465816,465817]),
    'p2_gt450': ([330608,330609,330722,330725,330727,330734,330801,331136,331144,331162,331164,331170,331179,
                331195,331199,331201,331208,331270,331275,331283,331286,331293,331297,331299,331308,336010,336017,
                336020,336026,336038,336047,336051,336128,336130,336135,336145,336149,336154,336167,336169,336180,
                336260,336262,336265,336276,336282,336284,336289,336292,336301,336304,336307,336313,336326,336334,
                336341,336396,336400,336403,336407,336412,336426,336436,336446,336448,336463,336464,336472,336478,
                336513,336634,336642,336648,336652,336660,336664,338433,338434,338442,338447,338452,338456,338462,
                338471,338475,338481,338499,338507,338511,338518,338524,338550,338551,338558,340169,340172,340188,
                340190,340201,340207,340210,340215,340220,340226,340241,340250,340253,340485,340488,340492,340494,
                340500,340505,340511,340517,340520,340537,340545,340550,340554,340557,340563,340584,340595,340601,
                340607,340609,340624,340628,340651,340662,340676,340683,340687,340693,340698,340699,340704,340708,
                340714,340721,340731,340735,340746,340752,340753,340762,341007,341017,341023,341026,341029,341032,
                341038,341045,341048,341052,341055,341062,341066,341074,341081,341087,341094,341102,341105,341108,
                341129,341141,341151,341161,341175,341177,341187,341190,341205,341207,341211,341218,341224,341228,
                341234,341239,341251,341258,341265,341281,341287,341504,341507,341511,341516,341520,341523,341528,
                341531,341538,341548,341556,341562,341564,341567,341575,341578,341595,341597,341600,341616,341620,
                341623,341630,341653,341664,341670,341682,341701,341702,341712,341724,341745,341750,341754,341756,
                341761,341773,341794,341798,343708,343713,343719,343728,343729,343737,343741,343748,343750,343754,
                343756,343760,343761,343765,343773,343774,343777,343778,343794,343805,343811,343819,343822,343825,
                343831,343834,343838,343841,343846,343850,343855,343858,343862,343866,343871,343876,343889,343890,
                343897,343901,343904,343911,343914,343919,344064,344067,344071,344073,344081,344084,344393,344397,
                344399,344407,344409,344423,344425,344434], 
                [330743,331213,331217,331220,331225,331231,331235,336061,336067,336070,336082,336085,336088,336204,
                336217,336227,336233,336350,336353,336355,336482,336489,336491,336685,336704,336710,336711,336716,
                336718,336723,336736,336740,336743,336755,340266,340268,340275,340297,340302,340313,340317,340319,
                340328,340336,340340,340341,340351,340776,340798,340805,340810,340831,340834,340838,340842,340848,
                340854,340859,340865,341289,341300,341304,341313,341316,341331,341336,341340,341348,341355,341358,
                341369,341376,341379,341804,341807,341813,341823,343926,343931,343933,343940,343971,343976,343979,
                343985,344438,344442,344448,344450,344457,344461,344463,344467,344473,344474,344483]),
    'p2_dp200': ([464595,464596,464597,464598,464599,464600,464601,464602,464603,464604,464605,464606,464607,464608,
                464609,464610,464612,464613,464614,464615,464616,464617,464618,464619,464620,464621,464622,464623,
                464624,464625,464626,464627,464628,464629,464630,464631,464632,464633,464634,464635,464636,464637,
                464638,464639,464640,464641,464642,464643,464644,464645,465511,465512,465513,465514,465515,465516,
                465517,465518,465530,465531,465532,465533,465534,465535,465536,465537,465538,465539,473544,473545,
                473546,473547,473548,473549,473550,473569,473570,473571,473572,473573,473574,473575,473576,473577,
                473578,473579,473580,473581,473582,473583,473585,473586,473587,473588,473589,473590,473591,474012,
                474020,474021,474022,474023,474025,474026,474035,474036,474037,474038,474039,474040,474041,474042,
                474043,474044,474045,474046,474047,474048,474049,474050,474051,474052,474053,474054,474067,474068,
                474069,474070,474071,474072,474073,474074,474075,474076,474077,474078,474079,474080,474081,474082,
                474083,474160,474162,474163,474178,474179,474180,474181,474182,474183,474184,474185,474186,474187,
                474188,474189,474190,474191,474192,474193,474194,474422,474423,474424,474425,474426,474439,474440,
                474441,474442,474443,474444,474445,474446,474447,474448,474450,479880,479884,479885,479886,479887,
                479888,479889,479890,479895,479897,479898,479899,479900,479901,479903,479918,479919,479920,479923,
                479924,479925,479926,479927,479929,479930,479931,479932,479936,479937,479938,479939,479940,479941,
                479942,479943,479944,479945,479946,479947,479948,479949,479950,479951,479952,479959,479960,479961,
                479962,479963,479964,479965,479966,479967,479968,479999,480000,480001,480002,480003,480004,480005,
                480006,480007,480008,480009,480010,480011,480012,480013,480014,480015,480016,480017,480018,480019,
                480020,480021,480022,480023,480024,480025,480026,480027,480028,480029,480892,480895,480896,480897,
                480900,480901,480902,480903,480904,480905,480906,480907,480908,480909,480910,480911,480916,480928,
                480929,480930,480931,480932,480933,480934,480935,480936,480948,484333],
                [465519,465520,465521,465522,465523,465524,465525,465526,465527,465528,465529,465540,465541,465542,
                465543,465544,465545,465546,470074,473551,473552,473553,473554,473555,473556,473557,473558,473559,
                473560,473561,474013,474014,474015,474027,474028,474029,474164,474165,474166,474167,474168,474169,
                474170,474171,474172,474173,474174,474175,474176,474427,474428,474429,474430,474431,474432,474433,
                474434,474435,474436,479881,479882,479891,479892,479893,479894,479896,479904,479905,479906,479911,
                479912,479933,479934,479953,479954,479955,479969,479970,479971,479972,479973,479974,479975,479976,
                479977,479978,479979,479980,480899,480912,480913,480914,480915,480917,480918,480919,480922,480939]),
    'p2_ufs': ([487962,487963,487964,487967,487968,487971,487972,487973,487974,487975,487976,487977,487979,487982,
                487983,487985,487986,487987,487989,487991,487992,487993,487994,487995,487997,487999,488000,488001,
                488002,488003,488004,488005,488006,488007,488009,488010,488012,488013,488015,488017,488018,488024,
                488025,488026,488029,488030,488032,488033,488034,488035,488037,488038,488041,488045,488049,488050,
                488052,488058,488060,488061,488065,488067,488069,488071,488072,488074,488075,488078,488079,488080,
                488081,488083,488084,488090,488091,488092,488094,488095,488097,488099,488100,488101,488102,488103,
                488106,488107,488108,488110,488111,488113,488116,488119,488120,488122,488123,488126,488128,488129,
                488130,488132,488133,488135,488136,488137,488138,488139,488140,488144,488149,488150,488151,488154,
                488155,488157,488159,488161,488162,488163,488164,488168,488171,488172,488173,488174,488175,488176,
                488178,488179,488180,488181,488185,488186,488190,488191,488192,488194,488196,488197,488199,488200,
                488201,488202,488206,488207,488208,488211,488213,488214,488215,488218,488222,488225,488227,488231,
                488233,488235,488236,488237,488238,488241,488242,488244,488245,488246,488247,488250,488252,488255,
                488256,488259,488260,488261,488263,488266,488267,488270,488271,488274,488275,488278,488279,488282,
                488284,488288,488290,488291,488293,488294,488297,488298,488299,488300,488302,488304,488306,488307,
                488309,488310,488311,488312,488313,488314,488315,488321,488323,488328,488329,488332,488333,488336,
                488340,488344,488348,488349,488352,488353,488357,488358,488359,488361,488362,488365,488369,488371,
                488375,488377,488379,488380,488382,488384,488385,488386,488392,488395,488397,488400,488402,488403,
                488405,488409,488411,488412,488413,488415,488418,488420,488422,488425,488430,488432,488434,488437,
                488438,488439,488441,488443,488445,488446,488448,488453,488454,488455,488457,488462,488464,488465,
                488468,488469,488470,488471,488472,488473,488474,488475,488479,488480,488481,488482,488483,488484,
                488489,488492,488493,488495,488496], 
                [487960,487961,487966,487969,487970,487984,487988,487990,487996,488008,488016,488019,488021,488023,
                488027,488040,488042,488046,488048,488051,488054,488055,488057,488059,488063,488064,488068,488073,
                488076,488089,488104,488105,488114,488115,488124,488127,488131,488143,488148,488156,488158,488166,
                488170,488184,488187,488188,488189,488198,488204,488205,488209,488216,488219,488221,488232,488234,
                488249,488254,488262,488268,488272,488277,488283,488285,488289,488316,488319,488322,488324,488325,
                488326,488339,488347,488350,488351,488364,488374,488376,488388,488391,488394,488398,488416,488424,
                488426,488428,488440,488442,488444,488449,488452,488463,488485,488486,488490,488491,488494]),
    'tmas': ([530691, 237449], [])
}

In [None]:
# Upper bound on the number of sampling points kept for each domain.
max_num_samples_per_domain = 100000
# Keys of `datasets_by_name` whose slides are used to build the training samples.
train_domain_names = ['daphne', 'p2_gt450', 'p2_ufs']
# Keys of `datasets_by_name` whose slides are used to build the validation samples.
val_domain_names = ['p2_at2', 'p2_dp200']

In [None]:
# Number of slides in the first slide list of every validation domain.
print([len(datasets_by_name[domain_name][0]) for domain_name in val_domain_names])

In [None]:
# Presumably selects the Jabba backend for the dynamic tasks below -- TODO confirm.
set_jabba()
# One entry per training domain, in the order of `train_domain_names`.
# NOTE(review): each entry is the ndarray returned by np.random.permutation,
# not a plain list as the annotation says.
train_lists: List[List[Samples]] = []
for name in train_domain_names:
    # Each dataset entry holds two slide-id lists; only the first one is used here.
    training_slide_ids, validation_slide_ids = datasets_by_name[name]
    # Run the sampling task on every slide of the domain; the cache key
    # presumably lets a rerun reuse earlier results -- TODO confirm.
    temp_samples_from_all_slides_of_one_domain = jmap(
        generate_wsi_sampling_points, 
        training_slide_ids,
        cache_key=f"breast_cancer_generate_training_samples_for_{name}_dataset_80_pct_patch_size_512").wait()
    # Merge the per-slide lists, shuffle them and keep at most
    # `max_num_samples_per_domain` samples. No random seed is set, so the
    # selected subset differs between runs.
    all_samples = np.random.permutation(
            list(itertools.chain.from_iterable(temp_samples_from_all_slides_of_one_domain))
        )[:max_num_samples_per_domain]
    train_lists.append(all_samples)

In [None]:
# Presumably selects the Jabba backend for the dynamic tasks below -- TODO confirm.
set_jabba()
# One entry per validation domain, in the order of `val_domain_names`.
# NOTE(review): each entry is the ndarray returned by np.random.permutation,
# not a plain list as the annotation says.
val_lists: List[List[Samples]] = []
for name in val_domain_names:
    # NOTE(review): the validation samples are drawn from the *training* slide
    # list of each validation domain and `validation_slide_ids` is unused --
    # confirm that this is intended.
    training_slide_ids, validation_slide_ids = datasets_by_name[name]
    # The same string serves as the printed note and as the task cache key.
    task_notes = f"breast_cancer_generate_training_samples_for_{name}_dataset_80_pct_patch_size_512"
    print(task_notes)
    temp_samples_from_all_slides_of_one_domain = jmap(
        generate_wsi_sampling_points, 
        training_slide_ids,
        cache_key=task_notes).wait()
    # Merge the per-slide lists, shuffle them and keep at most
    # `max_num_samples_per_domain` samples. No random seed is set, so the
    # selected subset differs between runs.
    all_samples = np.random.permutation(
            list(itertools.chain.from_iterable(temp_samples_from_all_slides_of_one_domain))
        )[:max_num_samples_per_domain]
    val_lists.append(all_samples)

## 2. Dataset preparation

In [None]:
import torch
from skimage.transform import rotate
from typing import Dict
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from torchvision.transforms import Normalize

Define transformations including RGB to transmittance and transmittance to absorbance.

In [None]:
class RandomFlip(object):
    """Mirrors the image left/right with probability 0.5.

    Only a horizontal flip is applied; the domain vector is passed through unchanged.
    """

    def __call__(self, sample: Tuple[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.Tensor]:
        image, domain_vect = sample
        should_flip = np.random.uniform(low=0.0, high=1.0) > 0.5
        flipped_or_original = np.fliplr(image) if should_flip else image
        return flipped_or_original, domain_vect
    
class RandomRotate(object):
    """Rotates the image by a random multiple of 90 degrees (0, 90, 180 or 270).

    The image is converted to float32 before it is handed to `rotate`; the
    domain vector is passed through unchanged.
    """

    def __call__(self, sample: Tuple[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.Tensor]:
        image, domain_vect = sample
        num_quarter_turns = np.random.randint(4)
        angle_degs = float(num_quarter_turns * 90.0)
        rotated = rotate(image.astype(np.float32), angle_degs)
        return rotated, domain_vect

class RGBToTransmittance(object):
    """Converts an RGB image to a transmittance image.

    Every channel is divided by its own 99th percentile, computed over the two
    leading image axes, and the result is clamped to the configured range.

    Args:
        min_transmittance (optional): The lower bound of the output. Defaults to 1e-4.
        clip_max_transmittance_to_one (optional): If True, the output is clipped at 1.0.
            If False, no upper bound is applied. Defaults to True.
    """
    def __init__(self, min_transmittance: float = 1e-4, clip_max_transmittance_to_one: bool = True):
        self._min_transmittance: float = min_transmittance
        # None means "no upper bound".
        self._max_transmittance: Optional[float] = 1.0 if clip_max_transmittance_to_one else None

    def __call__(self, sample: Tuple[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.Tensor]:
        rgb_im, domain_vect = sample
        # Per-channel estimate of the background (fully transmitting) intensity.
        white_point = np.percentile(rgb_im, q=99.0, axis=(0, 1))
        trans_im = np.maximum(rgb_im / white_point, self._min_transmittance)
        # np.minimum cannot take None, so the upper clamp is skipped when it is disabled.
        if self._max_transmittance is not None:
            trans_im = np.minimum(trans_im, self._max_transmittance)
        return trans_im, domain_vect
    
class TransmittanceToRGB(object):
    """Maps the input image through 10 ** (-x), so that an input of 0 becomes 1.0.

    NOTE(review): despite the class name, the formula is the inverse of a -log10
    mapping, so the input is presumably an absorbance image -- confirm with the callers.
    """
    def __call__(self, sample: Tuple[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.Tensor]:
        image, domain_vect = sample
        converted = 10 ** (-image)
        return converted, domain_vect
        
class TransmittanceToAbsorbance(object):
    """Converts a transmittance image to the absorbance image.

    The absorbance is the negative base-10 logarithm of the transmittance. This
    transform takes no constructor arguments and does no clipping of its own,
    so the input must already be strictly positive.
    """

    def __call__(self, sample: Tuple[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.Tensor]:
        transmittance_im, domain_vect = sample
        absorbance_im = -np.log10(transmittance_im)
        return absorbance_im, domain_vect
        
    
class ToTensor(object):
    """Converts the image to a float32 tensor and changes the order from (H, W, C) to (C, H, W).

    Returns a tuple of the image tensor and the untouched domain vector.
    """
    def __call__(self, sample: Tuple[np.ndarray, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        im, domain_vect = sample
        # One contiguous float32 copy is enough. It also removes the negative
        # strides that flipped arrays carry, which torch cannot wrap directly.
        channels_first = np.ascontiguousarray(im.transpose(2, 0, 1), dtype=np.float32)
        return torch.from_numpy(channels_first), domain_vect

In [None]:
import torch
import h5py
from tempfile import TemporaryDirectory
from torch.utils.data import Dataset

class MultiDomainDataset(Dataset):
    """A custom dataset that contains images from different domains.

    Every ``__getitem__`` call picks one of the allowed domains at random and returns the patch at the
    requested index of that domain, together with the one-hot encoding of the domain. The length of the
    dataset is the size of the largest domain, so indices wrap around inside the shorter domains.

    Extracted patches are cached as HDF5 files in a temporary directory owned by the dataset.

    Args:
        multi_domain_samples: A list with one list of samples per domain.
            NOTE(review): the original documentation says the zero-th list is the TMA dataset; nothing
            in this class enforces that.
        transforms: A list of transforms to be applied on the (image, domain vector) sample.
        input_patch_size_pixels (optional): The size of the input patch in pixels. Defaults to
            INPUT_IMAGE_SIZE_PIXELS.
        domain_indices (optional): A list that specifies which domains to sample from. Default to None,
            in which case, samples will be obtained from all domains.
    """
    TMA_MPP = 0.2522
    _HDF5_CHUNK_SIZE = 256
    _TMA_DOMAIN_INDEX = 0

    def __init__(self,
                 multi_domain_samples: List[List[Samples]],
                 transforms: List[object],
                 input_patch_size_pixels: int = INPUT_IMAGE_SIZE_PIXELS,
                 domain_indices: Optional[List[int]] = None,
                ) -> None:
        super().__init__()
        self._transforms = Compose(transforms)
        self._multi_domain_samples = multi_domain_samples
        # Patch cache; the directory is removed when this object is garbage collected.
        self._h5_folder = TemporaryDirectory()
        self._input_patch_size_pixels = input_patch_size_pixels
        self._num_domains = len(multi_domain_samples)
        self._domain_indices = domain_indices
        if self._domain_indices is None:
            self._domain_indices = list(range(self._num_domains))

    def __len__(self):
        """Returns the number of samples in the largest domain (0 when there are no domains)."""
        return max((len(x) for x in self._multi_domain_samples), default=0)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns: a tuple of torch.Tensor and the one-hot domain encoded vector."""
        domain_idx = np.random.choice(self._domain_indices)
        return self.get_item_with_domain_idx(idx=idx, domain_idx=domain_idx)

    def get_item_with_domain_idx(self, idx: int, domain_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the transformed sample at ``idx`` of the domain ``domain_idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))`` or the domain is empty.
        """
        num_items = len(self)
        # The explicit bounds check keeps plain `for x in dataset` iteration finite.
        if not -num_items <= idx < num_items:
            raise IndexError(f"Index {idx} is out of range for a dataset of length {num_items}")
        all_samples_in_selected_domain = self._multi_domain_samples[domain_idx]
        num_domain_samples = len(all_samples_in_selected_domain)
        if num_domain_samples == 0:
            raise IndexError(f"Domain {domain_idx} contains no samples")
        # `__len__` is the size of the largest domain, so wrap the index for shorter
        # domains instead of running past their end.
        sample = all_samples_in_selected_domain[idx % num_domain_samples]
        im = self._get_slide_platform_patch(sample)

        domain_vector = np.zeros((self._num_domains,))
        domain_vector[domain_idx] = 1.0
        return self._transforms((im, torch.FloatTensor(domain_vector)))

    def _get_slide_platform_patch(self, sample: Samples) -> np.ndarray:
        """Gets a patch from the slide platform, going through the HDF5 cache."""
        h5_file_name_for_slide_roi = self._h5_file_name_from_sample(sample)
        if os.path.exists(h5_file_name_for_slide_roi):
            with h5py.File(h5_file_name_for_slide_roi, "r") as file:
                return file['he_roi'][:, :, :]

        slide_reference = SlideReference(int(sample.slide_id))
        with slide_reference.read_object() as slide:
            slide_num_rows, slide_num_cols, _ = slide.shape
            # NOTE(review): this presumably reads the whole slide at native resolution
            # before cropping; a region read would be cheaper if the slide API has one.
            im = slide.view_at_mpp(slide.mpp, mpp_tolerance=0.05)

            # Make sure that we cover the same distance.
            row_slice = self._get_valid_slice_within_range(
                slide_num_rows, sample.row_idx, self._input_patch_size_pixels)
            col_slice = self._get_valid_slice_within_range(
                slide_num_cols, sample.col_idx, self._input_patch_size_pixels)
            patch = cv2.resize(
                im[row_slice, col_slice, :],
                (self._input_patch_size_pixels, self._input_patch_size_pixels),
            ) / 255.0

        # Write to a scratch file and rename it once complete, so that a failed or
        # interrupted extraction never leaves a partial cache entry behind.
        scratch_file_name = f"{h5_file_name_for_slide_roi}.{os.getpid()}.tmp"
        with h5py.File(scratch_file_name, "w") as file:
            he_dataset = file.create_dataset(
                name='he_roi',
                shape=patch.shape,
                dtype=patch.dtype)
            he_dataset[:, :, :] = patch
        os.replace(scratch_file_name, h5_file_name_for_slide_roi)
        return patch

    def _h5_file_name_from_sample(self, sample: Samples) -> str:
        """Generates an h5 file name from the sample."""
        return os.path.join(
            self._h5_folder.name,
            f"slide_{sample.slide_id}_roi_row_{sample.row_idx}_col_{sample.col_idx}.h5",
        )

    @staticmethod
    def _get_valid_slice_within_range(max_dimension_pixels: int, center_idx: int, slice_length: int) -> slice:
        """Returns a slice of ``slice_length`` around ``center_idx``, shifted to stay inside the axis.

        When the axis is shorter than ``slice_length`` the slice starts at 0.
        """
        # Clamp the upper bound at zero so the start can never be negative.
        last_valid_first_idx = max(max_dimension_pixels - slice_length, 0)
        first_idx = int(np.clip(center_idx - slice_length // 2, 0, last_valid_first_idx))
        return slice(first_idx, first_idx + slice_length)

In [None]:
# Test the data loader: fetch one training sample from domain 0 and display it.
dataset = MultiDomainDataset(
                multi_domain_samples=train_lists,
                transforms = [RandomFlip(), RandomRotate(), RGBToTransmittance(), TransmittanceToAbsorbance(), ToTensor()],
                domain_indices = [0],
            )
print(f"Dataset len = {len(dataset)}")
im, domain_vect = dataset[4]
print(f"Image size = {im.shape}")
print(f"Domain one-hot = {domain_vect}")
# `np.float` was removed in NumPy 1.24; `np.float64` is the type it aliased.
# The tensor is (C, H, W) and imshow expects (H, W, C).
im = np.transpose(im.numpy().astype(np.float64),(1,2,0))
plt.imshow(im)
plt.grid(None)
plt.show()

In [None]:
# Test the data loader: fetch one validation sample from domain 1 and display it.
dataset = MultiDomainDataset(
                multi_domain_samples=val_lists,
                transforms = [RandomFlip(), RandomRotate(), RGBToTransmittance(), TransmittanceToAbsorbance(), ToTensor()],
                domain_indices = [1],
            )
print(f"Dataset len = {len(dataset)}")
im, domain_vect = dataset[40]
print(f"Image size = {im.shape}")
print(f"Domain one-hot = {domain_vect}")
# `np.float` was removed in NumPy 1.24; `np.float64` is the type it aliased.
# The tensor is (C, H, W) and imshow expects (H, W, C).
im = np.transpose(im.numpy().astype(np.float64),(1,2,0))
plt.imshow(im)
plt.grid(None)
plt.show()

## 3. Initializer definition

In [None]:
from enum import Enum, auto
import torch.nn as nn
class InitilizationType(Enum):
    """The supported weight initialization schemes."""
    KAIMING = auto()
    XAVIER = auto()
    NORMAL = auto()


class Initializer:
    """A class that initializes the model weights.

    Calling the initializer on a module visits the module and all of its sub-modules:
      - convolutional and linear layers get their weights drawn with the selected scheme and
        their biases set to zero;
      - BatchNorm2d layers with affine parameters get weights drawn from N(1.0, init_gain) and
        biases set to zero.
    Layers without a weight tensor are left untouched.

    Args:
        init_type (optional): The type of the initialization. Defaults to InitilizationType.NORMAL.
        init_gain (optional): The standard deviation for NORMAL (and for the BatchNorm2d weights), or
            the gain for XAVIER. Not used by KAIMING. Defaults to 0.02.
    """
    def __init__(self, init_type: InitilizationType = InitilizationType.NORMAL, init_gain: float = 0.02) -> None:
        self._init_type = init_type
        self._init_gain = init_gain

    def __call__(self, m: nn.Module) -> nn.Module:
        """Initializes ``m`` in place and returns it."""
        m.apply(self._initialize_module)
        return m

    def _initialize_module(self, m: object) -> None:
        """Initializes a single module; modules without a weight tensor are skipped.

        Raises:
            ValueError: If the configured initialization type is not supported.
        """
        if not self._has_weights(m):
            return
        class_name = m.__class__.__name__
        weight = self._trainable_weight(m)
        bias = getattr(m, 'bias', None)
        if self._is_conv_layer(class_name) or self._is_linear_layer(class_name):
            if self._init_type == InitilizationType.KAIMING:
                nn.init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif self._init_type == InitilizationType.XAVIER:
                nn.init.xavier_normal_(weight.data, gain=self._init_gain)
            elif self._init_type == InitilizationType.NORMAL:
                nn.init.normal_(weight.data, mean=0, std=self._init_gain)
            else:
                raise ValueError("Unknown initialization type!")

            if bias is not None:
                nn.init.constant_(bias.data, val=0)
        if self._is_batchnorm2d_layer(class_name):
            #TODO: investigate why the mean of this is set to 1.0
            nn.init.normal_(weight.data, mean=1.0, std=self._init_gain)
            if bias is not None:
                nn.init.constant_(bias.data, val=0)

    @staticmethod
    def _trainable_weight(m: object):
        """Returns the tensor that holds the trainable weight of ``m``, or None.

        Layers wrapped with ``torch.nn.utils.spectral_norm`` keep the trainable tensor in
        ``weight_orig`` and recompute ``weight`` from it before every forward pass, so
        ``weight_orig`` takes precedence when it exists.
        """
        for attribute_name in ('weight_orig', 'weight'):
            candidate = getattr(m, attribute_name, None)
            if candidate is not None:
                return candidate
        return None

    @staticmethod
    def _has_weights(m: object) -> bool:
        # torch modules name the attribute `weight`, and normalization layers built with
        # affine=False set it to None; both cases must be ruled out.
        return Initializer._trainable_weight(m) is not None

    @staticmethod
    def _is_conv_layer(cls_name: str) -> bool:
        return cls_name.find('Conv') != -1

    @staticmethod
    def _is_linear_layer(cls_name: str) -> bool:
        return cls_name.find('Linear') != -1

    @staticmethod
    def _is_batchnorm2d_layer(cls_name: str) -> bool:
        return cls_name.find('BatchNorm2d') != -1
    

## 4. Network definition

In [None]:
import torch.nn as nn
from torch.nn.utils import spectral_norm

In [None]:
class LeakyReLUConv2d(nn.Module):
    """A reflection-padded Conv2d followed by a LeakyReLU.

    The layers are stacked as follows:

        [Conv2d with reflection padding, optionally spectrally normalized]
            -> [InstanceNorm2d, only if enabled] -> [LeakyReLU].

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kernel_size: The size of the 2D convolutional kernel.
        stride: The stride of the 2D convolutional kernel.
        padding: The size of the padding.
        enable_spectral_normalization (optional): If True, spectral normalization is applied to the
            convolution. Defaults to False. See https://arxiv.org/pdf/1802.05957.pdf for details.
        enable_instance_norm (optional): If True, a non-affine instance normalization is inserted
            between the convolution and the activation. Defaults to False.

    """
    def __init__(self, in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 padding: int,
                 enable_spectral_normalization: bool = False,
                 enable_instance_norm: bool = False,
                ) -> None:
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            padding_mode='reflect',
            bias=True,
        )

        # The convolution always comes first, wrapped by spectral norm on request.
        stack = [spectral_norm(conv) if enable_spectral_normalization else conv]
        if enable_instance_norm:
            stack += [nn.InstanceNorm2d(out_channels, affine=False)]
        stack += [nn.LeakyReLU(inplace=True)]

        self._model = Initializer(init_type=InitilizationType.NORMAL)(nn.Sequential(*stack))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to `x` and return the result."""
        return self._model(x)

In [None]:
class ReLUInstNorm2dConv2d(nn.Module):
    """A convolution block with instance normalization and a leaky activation.

    The layers are stacked as follows:

        [Conv2d with reflection padding] -> [InstanceNorm2d (non-affine)] -> [LeakyReLU]

    Despite the class name, the activation used by the code is LeakyReLU.

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kernel_size: The size of the 2D convolutional kernel.
        stride: The stride of the 2D convolutional kernel.
        padding: The size of the padding.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int) -> None:
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            padding_mode='reflect',
            bias=True,
        )
        norm = nn.InstanceNorm2d(out_channels, affine=False)
        activation = nn.LeakyReLU(inplace=True)

        block = nn.Sequential(conv, norm, activation)
        self._model = Initializer(init_type=InitilizationType.NORMAL)(block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to `x` and return the result."""
        return self._model(x)

In [None]:
class ResInstNorm2dConv2d(nn.Module):
    """A residual block built from two 3x3 convolutions with instance normalization.

        x -> [Conv2d(3x3)] -> [InstanceNorm2d] -> [LeakyReLU] -> [Conv2d(3x3)] -> [InstanceNorm2d] -> + ->
         |                                                                                            ^
         |--------------------------------------------------------------------------------------------|

    The number of input and output channels are the same.

    Args:
        in_channels: The number of input (and output) channels.
    """
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        residual_branch = nn.Sequential(
            self._conv3x3(in_channels),
            nn.InstanceNorm2d(in_channels),
            nn.LeakyReLU(inplace=True),
            self._conv3x3(in_channels),
            nn.InstanceNorm2d(in_channels),
        )
        self._model = Initializer(init_type=InitilizationType.NORMAL)(residual_branch)

    @staticmethod
    def _conv3x3(channels: int) -> nn.Conv2d:
        """Build a channel-preserving 3x3 convolution with reflection padding."""
        return nn.Conv2d(
            channels,
            channels,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='reflect',
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input plus the output of the residual branch."""
        residual = self._model(x)
        return x + residual

In [None]:
class ContentEncoder(nn.Module):
    """The content encoder network.

    The architecture of this model is:
        [LeakyReLUConv2d] -> 2 x [ReLUInstNorm2dConv2d] -> 4 x [ResInstNorm2dConv2d] -> [LeakyReLU]
    Args:
        in_channels: The number of input channels.
        num_stain_vectors (optional): The number of output channels. Defaults to NUM_STAIN_VECTORS.

    Returns:
        For each minibatch with N samples, it returns a tensor of size N x num_stain_vectors x H x W in which
        H and W are the input height and width after the two stride-2 convolutions.
    """
    def __init__(self, in_channels: int, num_stain_vectors: int=NUM_STAIN_VECTORS) -> None:
        super().__init__()
        # Full-resolution stem followed by two stride-2 convolutions.
        stem = [
            LeakyReLUConv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=1, padding=3),
            ReLUInstNorm2dConv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            ReLUInstNorm2dConv2d(in_channels=128, out_channels=num_stain_vectors, kernel_size=3, stride=2, padding=1),
        ]
        residual_blocks = [ResInstNorm2dConv2d(in_channels=num_stain_vectors) for _ in range(4)]
        self._model = nn.Sequential(*stem, *residual_blocks, nn.LeakyReLU(inplace=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the images `x` into the content tensor."""
        return self._model(x)

In [None]:
# Test ContentEncoder.
# Smoke test: push one random 3-channel batch of the configured input size
# through the encoder and print the input and output tensor shapes.
m = ContentEncoder(in_channels=3, num_stain_vectors = NUM_STAIN_VECTORS)
x = torch.randn(1, 3, INPUT_IMAGE_SIZE_PIXELS, INPUT_IMAGE_SIZE_PIXELS)
y = m(x)
print(f"Input shape = {x.shape}, output shape =  {y.shape}")

### Attribute encoder.
This network aims to extract the stain vector.

In [None]:
import torch.nn.functional as F
class Downsampling2xkWithSkipConnection(nn.Module):
    """A 2x downsampling block with a skip connection.

    It is the basic building block of the convolutional stack that estimates the attribute.
    The block implements the following architecture:

    x -> -> Conv2d(3x3) -> Conv2d(3x3) -> AvgPool2d -> + -> output
        |                                              ^
        |-------> AvgPool2d -> Conv2d(1x1) ------------|

    NOTE(review): no activation is applied between the convolutions, although the referenced
    implementation appears to use LeakyReLU there - confirm that this is intentional.

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
    Reference:
        https://github.com/HsinYingLee/MDMM/blob/18360fe3fa37dde28c70c5a945ec783e44eb72ed/networks.py#L334
    """
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        main_path = [
            self._conv3x3(in_channels, in_channels),
            self._conv3x3(in_channels, out_channels),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self._forward_block = nn.Sequential(*main_path)

        shortcut = [
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
        ]
        self._skip_block = nn.Sequential(*shortcut)

    @staticmethod
    def _conv3x3(num_in: int, num_out: int) -> nn.Conv2d:
        """Build a size-preserving 3x3 convolution with reflection padding."""
        return nn.Conv2d(num_in, num_out, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of the main path and the skip path, both at half resolution."""
        main = self._forward_block(x)
        skipped = self._skip_block(x)
        return main + skipped


class StainVectorEstimator(nn.Module):
    """A class that estimates the stain matrix of an image.

    The network takes an input image batch and returns a stain matrix of size
    N x (3k^2) x num_stain_vectors, where k is the downsampling factor.
    To generate an output absorbance image, the stain matrix is multiplied with a content tensor of
    size N x num_stain_vectors x H x W to obtain a tensor of size N x (3k^2) x H x W. Then, using
    pixel shuffling, this is rearranged to N x 3 x (kH) x (kW).

    Args:
        in_channels: The number of input channels.
        num_stain_vectors (optional): The number of stain vectors to estimate. Defaults to NUM_STAIN_VECTORS.
        downsampling_factor (optional): A factor that describes how much the content is downsampled
            with respect to the image. Defaults to 4.
    Reference:
        https://github.com/HsinYingLee/MDMM/blob/master/networks.py#L64
    """
    def __init__(self, in_channels: int, num_stain_vectors: int = NUM_STAIN_VECTORS, downsampling_factor: int = 4) -> None:
        super().__init__()
        self._num_stain_vectors = num_stain_vectors
        # One output pixel block of k x k pixels with 3 channels each.
        self._three_times_k_sqr = (3 * downsampling_factor**2)
        self._model = Initializer()(
            nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, padding_mode='reflect', bias=True),
                Downsampling2xkWithSkipConnection(in_channels=64, out_channels=128),
                Downsampling2xkWithSkipConnection(in_channels=128, out_channels=256),
                Downsampling2xkWithSkipConnection(in_channels=256, out_channels=512),
                Downsampling2xkWithSkipConnection(in_channels=512, out_channels=1024),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)), # Condense all the X, Y dimensions to 1 pixel.
                nn.Conv2d(
                    in_channels=1024,
                    out_channels=self._num_stain_vectors * self._three_times_k_sqr,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True),
                nn.LeakyReLU(inplace=False),
            )
        )

    def forward(self, ims: torch.Tensor) -> torch.Tensor:
        """Returns the estimated stain matrix of every image in the batch.

        Args:
            ims: The input image batch.

        Returns:
            A single tensor of size N x (3k^2) x num_stain_vectors.
        """
        x1 = self._model(ims)
        # The spatial axes were pooled down to 1 x 1, so the channel axis holds the whole matrix.
        x2 = x1.view(x1.size(0), self._three_times_k_sqr, self._num_stain_vectors)
        # NOTE(review): an earlier comment here spoke of normalizing the attribute along the band
        # dimension to prevent it from collapsing; no normalization is applied by this code.
        return x2

In [None]:
class AbsorbanceImGenerator(nn.Module):
    """A class that computes the product between the content and the attribute.

    This class generates a synthetic image G(z_c, z_a) from the content tensor z_c and the
    attribute tensor z_a. The module has no trainable parameters.

    Args:
        downsampling_factor (optional): The factor k by which the content tensor is downsampled
            with respect to the generated image. Defaults to 4.
    """
    _NUM_FEATURES_PER_FRACTION = 256
    def __init__(self,
                 downsampling_factor: int = 4,
                ) -> None:
        super().__init__()
        self._shuffle_layer = torch.nn.PixelShuffle(upscale_factor=downsampling_factor)

    def forward(self, z_c: torch.Tensor, z_a: torch.Tensor) -> torch.Tensor:
        """Generates an image from the content and the attribute tensors.

        Args:
            z_c: The content image tensor, which contains abundance information of the stain.
                It should be of size (N, num_stain_vectors, H, W).
            z_a: The attribute image tensor. It should be of size (N, (3k^2), num_stain_vectors).

        Returns:
            The generated image tensor of size (N, 3, k*H, k*W).

        Raises:
            ValueError: If the channel count of z_c differs from the last dimension of z_a.
        """
        if z_c.size(1) != z_a.size(2):
            raise ValueError("The number of elements in first dimension of z_c must match the number of elements in the 2nd dimension of z_a")
        num_rows, num_cols = z_c.size(2), z_c.size(3)
        # reshape (rather than view) also accepts non-contiguous content tensors,
        # e.g. permuted or channels-last ones; for contiguous input it is still a view.
        flat_content = z_c.reshape(z_c.size(0), z_c.size(1), -1)
        x = torch.bmm(z_a, flat_content)
        x = x.view(x.size(0), x.size(1), num_rows, num_cols)
        return self._shuffle_layer(x)

In [None]:
class RealFakeDiscriminator(nn.Module):
    """A discriminator that aims to classify if an image is a real or a generated image.

    The encoder halves the resolution six times with spectrally normalized convolutions and a
    final 1x1 convolution turns the 1024 features into a single real-vs-fake logit per location.

    Args:
        in_channels: The number of channels for the input image.

    """
    def __init__(
        self,
        in_channels: int,
    ) -> None:
        super().__init__()
        # NOTE(review): `_downsample` is created here but is not used by `forward`.
        self._downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
        self._encoder = self._make_network(in_channels=in_channels)
        self._real_fake_predictor = Initializer()(nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=1, stride=1, padding=0, bias=False))

    @staticmethod
    def _make_network(in_channels: int) -> nn.Module:
        """Make a single scale discriminator.

        Args:
            in_channels: The number of channels for the input image.

        Returns:
            A stack of six stride-2 LeakyReLUConv2d layers that ends with 1024 channels.
        """
        channel_widths = (in_channels, 64, 128, 256, 512, 1024, 1024)
        return nn.Sequential(*[
            LeakyReLUConv2d(
                in_channels=num_in,
                out_channels=num_out,
                kernel_size=3,
                padding=1,
                stride=2,
                enable_spectral_normalization=True,
                enable_instance_norm=False,
            )
            for num_in, num_out in zip(channel_widths[:-1], channel_widths[1:])
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the real-vs-fake logits of the images `x`.

        Args:
            x: The input image batch.

        Returns:
            A single tensor with one channel that holds the real-vs-fake logits.
        """
        x = self._encoder(x)
        return self._real_fake_predictor(x)
    

## 5. Datamodule definition

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader

In [None]:
# Domain indices handed to MultiDomainDataset (as `domain_indices`) by
# MultiDomainDataModule.setup for both the training and the validation dataset.
# NOTE(review): presumably None selects every domain - confirm against MultiDomainDataset.
_DOMAIN_INDICES = None

In [None]:
class MultiDomainDataModule(pl.LightningDataModule):
    """The Lightning data module that serves the samples for DRIT training and validation.

    Args:
        multi_domain_train_samples: A list of lists that contains the training samples. One inner list for 1 domain.
        multi_domain_val_samples: A list of lists that contains the validation samples. One inner list for 1 domain.
        batch_size (optional): The number of samples per batch. Defaults to 32.
        num_dataloading_workers (optional): The number of workers for data loading. Defaults to 4.
        use_pin_memory (optional): If True, pinned memory will be used. Defaults to True.

    """
    def __init__(
        self,
        multi_domain_train_samples: List[List[Samples]],
        multi_domain_val_samples: List[List[Samples]],
        batch_size: int = 32,
        num_dataloading_workers: int = 4,
        use_pin_memory: bool = True,
    ) -> None:
        super().__init__()
        # Samples, one inner list per domain.
        self._multi_domain_train_samples = multi_domain_train_samples
        self._multi_domain_val_samples = multi_domain_val_samples
        # Settings shared by the training and the validation data loaders.
        self._batch_size = batch_size
        self._num_dataloading_workers = num_dataloading_workers
        self._use_pin_memory = use_pin_memory

    def setup(self, stage: str) -> None:
        """Create the training and validation datasets.

        Args:
            stage: The stage that is being set up. Only "fit" is supported.

        Raises:
            ValueError: If `stage` is anything other than "fit".
        """
        if stage != "fit":
            raise ValueError("MultiDomainDataModule datamodule is not defined for non-fit stages!")

        # Training adds random flips and rotations in front of the color conversions.
        train_transforms = [RandomFlip(), RandomRotate(), RGBToTransmittance(), TransmittanceToAbsorbance(), ToTensor()]
        self._train_dataset = MultiDomainDataset(
            multi_domain_samples=self._multi_domain_train_samples,
            transforms=train_transforms,
            domain_indices=_DOMAIN_INDICES,
        )

        val_transforms = [RGBToTransmittance(), TransmittanceToAbsorbance(), ToTensor()]
        self._val_dataset = MultiDomainDataset(
            multi_domain_samples=self._multi_domain_val_samples,
            transforms=val_transforms,
            domain_indices=_DOMAIN_INDICES,
        )

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap `dataset` into a data loader that uses the settings of this module."""
        return DataLoader(
            dataset,
            batch_size=self._batch_size,
            shuffle=shuffle,
            num_workers=self._num_dataloading_workers,
            pin_memory=self._use_pin_memory,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled data loader over the training dataset."""
        return self._make_dataloader(self._train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the unshuffled data loader over the validation dataset."""
        return self._make_dataloader(self._val_dataset, shuffle=False)

## 6. Training module definition

In [None]:
import os
from collections import OrderedDict
import pytorch_lightning as pl
from pathlib import Path
from contextlib import ExitStack
from tempfile import TemporaryDirectory
from pathai.dynamic import NodeFamilies
from pathai.research_dev.modules.module_factory import torch_factory

In [None]:
class Loss:
    """A class that defines the loss for the DRIT.

    Args:
        real_fake_weight (optional): The weight for penalizing a real vs. fake image. Defaults to 1.0.
        recon_weight (optional): The weight for the reconstruction losses. Defaults to 1.0.
        content_consistency_weight (optional): The weight for the content consistency loss. Defaults to 1.0.
        attr_consistency_weight (optional): The weight for the latent regression loss. Defaults to 10.0.
        mode_seeking_loss_weight (optional): The weight for the mode seeking loss. Defaults to 1.0.
    """
    def __init__(self,
                 real_fake_weight: float = 1.0,
                 recon_weight: float = 1.0,
                 content_consistency_weight: float = 1.0,
                 attr_consistency_weight: float = 10.0,
                 mode_seeking_loss_weight: float = 1.0,
                ) -> None:

        self._real_fake_weight: float = real_fake_weight
        self._recon_weight: float = recon_weight
        self._content_consistency_weight: float = content_consistency_weight
        self._attr_consistency_weight: float = attr_consistency_weight
        self._mode_seeking_loss_weight: float = mode_seeking_loss_weight

    def compute_generator_and_encoder_losses(self,
                                             forward_outputs: Dict[str, torch.Tensor],
                                             real_vs_cross_translation_disc: nn.Module,
                                             num_permutations_for_fake_images: int
                                            ) -> Dict[str, Any]:
        """Computes a dictionary of the losses, keyed by the loss types.

        Args:
            forward_outputs: A dictionary of the output tensors from the forward pass, keyed by the name of the outputs.
            real_vs_cross_translation_disc: The discriminator that discriminates between the real and the fake cross-translation images.
            num_permutations_for_fake_images: The number of permutations used to generate fake images.

        Returns:
            A dictionary with the weighted total under 'loss', every individual loss term under its
            own name and the number of real images under 'batch_size' (an int).
        """
        # Call the module itself instead of `.forward` so that registered hooks are honored.
        fake_logits = real_vs_cross_translation_disc(forward_outputs['fake_one_time_cross_translation_ims'])

        # The generator wants the fakes to be classified as real, hence the target label 1.
        adversarial_loss = self._real_fake_loss_for_images(
            predicted_logits=fake_logits,
            target_disc_label=1  # Ref: https://github.com/HsinYingLee/MDMM/blob/master/model.py#L292.
        ) / num_permutations_for_fake_images

        self_recon_loss = self._consistency_loss(
            real_ims=forward_outputs['real_ims'],
            target_ims=forward_outputs['self_recon_ims'],
        )

        # One term per permutation of fake images plus one for the reconstructed images.
        num_terms = num_permutations_for_fake_images + 1

        cont_consistency_loss = (
            self._latent_regression_loss(forward_outputs['swapped_zc'], forward_outputs['z_c_from_fake_ims'])
            + self._latent_regression_loss(forward_outputs['z_cs'], forward_outputs['z_c_from_recon_ims'])
        ) / num_terms

        attribute_consistency_loss = (
            self._latent_regression_loss(forward_outputs['z_as'].repeat(num_permutations_for_fake_images, 1, 1), forward_outputs['z_a_from_fake_ims'])
            + self._latent_regression_loss(forward_outputs['z_as'], forward_outputs['z_a_from_recon_ims'])
        ) / num_terms

        mode_seeking_loss = self._mode_seeking_regularization_loss(zas=forward_outputs['z_as'], ims=forward_outputs['real_ims'])

        total_loss = self._real_fake_weight * adversarial_loss + \
            self._recon_weight * self_recon_loss + \
            self._content_consistency_weight * cont_consistency_loss + \
            self._attr_consistency_weight * attribute_consistency_loss + \
            self._mode_seeking_loss_weight * mode_seeking_loss
        return {
            'loss': total_loss,
            'real_fake_adv_loss_cross_trans_ims_with_real_vs_cross_trans_disc': adversarial_loss,
            'self_recon_loss': self_recon_loss,
            'cont_consistency_loss': cont_consistency_loss,
            'attribute_consistency_loss': attribute_consistency_loss,
            'mode_seeking_loss': mode_seeking_loss,
            'batch_size': len(forward_outputs['real_ims']),
        }

    @staticmethod
    def _consistency_loss(real_ims: torch.Tensor, target_ims: torch.Tensor) -> torch.Tensor:
        """Computes the consistency loss between the real images and the target images.

        The loss is the L1 distance, per-element averaged within an image and summed over the
        images of the mini-batch.

        Args:
            real_ims: The real images.
            target_ims: The target images to compare to.

        References:
            https://github.com/HsinYingLee/MDMM/blob/master/model.py#L254
        """
        return nn.L1Loss(reduction='sum')(real_ims, target_ims) / (real_ims.size(1) * real_ims.size(2) * real_ims.size(3))

    @staticmethod
    def _latent_regression_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Computes the latent regression loss: the mean L1 distance between `z1` and `z2`, scaled by `z1.size(1)`."""
        return nn.L1Loss(reduction='mean')(z1, z2) * z1.size(1)

    @staticmethod
    def _mode_seeking_regularization_loss(zas: torch.Tensor, ims: torch.Tensor) -> torch.Tensor:
        """Computes the ratio between the image distance and the attribute distance of sample pairs.

        The i-th sample of the first half of the batch is paired with the i-th sample of the
        second half. With an odd batch size the last sample has no partner and is left out.

        Args:
            zas: The attribute tensors, one per image.
            ims: The images the attributes were computed from.

        Returns:
            The image distance divided by the attribute distance, or zero when the batch has fewer
            than two samples (there is no pair to compare in that case).
        """
        half_batch_size = zas.size(0) // 2
        if half_batch_size == 0:
            # Slicing would give empty tensors and a 0 / 0 = NaN loss.
            return zas.new_zeros(())
        first_half = slice(0, half_batch_size)
        # Bound the second half explicitly so that both halves have the same length.
        second_half = slice(half_batch_size, 2 * half_batch_size)
        za_dist = nn.L1Loss(reduction='sum')(zas[first_half], zas[second_half]) / zas.size(1)
        im_dist = nn.L1Loss(reduction='sum')(ims[first_half], ims[second_half]) / (ims.size(1) * ims.size(2) * ims.size(3))
        return im_dist / za_dist

    def compute_real_fake_discriminator_losses(self,
                                               forward_outputs: Dict[str, torch.Tensor],
                                               real_vs_cross_translation_disc: nn.Module,
                                               num_permutations_for_fake_images: int,
                                              ) -> Dict[str, Any]:
        """Computes a dictionary of the losses for the real vs. fake discriminators, keyed by the loss types.

        Args:
            forward_outputs: A dictionary of the output tensors from the forward pass, keyed by the name of the outputs.
            real_vs_cross_translation_disc: The discriminator that discriminates between the real and the fake cross-translation images.
            num_permutations_for_fake_images: The number of permutations used to generate fake images.

        Returns:
            A dictionary with the total under 'loss', the loss on the real images, the loss on the
            fake images and the number of real images under 'batch_size' (an int).
        """
        # The detach() command below makes sure that we have the tensor available after the backward pass of the discriminator so that we can update the generator.
        # Call the module itself instead of `.forward` so that registered hooks are honored.
        pred_logits = real_vs_cross_translation_disc(
            torch.cat([forward_outputs['real_ims'].detach(), forward_outputs['fake_one_time_cross_translation_ims'].detach()], dim = 0)
        )

        num_real_samples = forward_outputs['real_ims'].size(0)

        real_loss = self._real_fake_loss_for_images(
            predicted_logits=pred_logits[:num_real_samples],
            target_disc_label=1
        )

        # The fakes follow the reals, one chunk of `num_real_samples` images per permutation.
        fake_loss = 0
        for fake_image_batch_idx in range(num_permutations_for_fake_images):
            chunk_start = num_real_samples * (fake_image_batch_idx + 1)
            fake_loss += self._real_fake_loss_for_images(
                predicted_logits=pred_logits[chunk_start:chunk_start + num_real_samples],
                target_disc_label=0
            )
        # We need to account for the fact that we have more fake images.
        fake_loss = fake_loss / num_permutations_for_fake_images

        total_loss = (real_loss + fake_loss)
        return {
            'loss': total_loss,
            'real_fake_loss_real_ims_with_real_vs_cross_trans_disc': real_loss,
            'real_fake_loss_cross_trans_ims_with_real_vs_cross_trans_disc': fake_loss,
            'batch_size': len(forward_outputs['real_ims']),
        }

    @staticmethod
    def _real_fake_loss_for_images(predicted_logits: torch.Tensor, target_disc_label: int) -> torch.Tensor:
        """Computes the adversarial loss when the discriminator is trying to discriminate between the real and the fake images.

        The loss is given as -{Sum_over_real_images [log(sigmoid(disc(real)))] + Sum_over_fake_images [log(1 - sigmoid(disc(fake)))]}
        The loss is per-pixel averaged and summed over instances in the mini-batch.

        Args:
            predicted_logits: The output logits predicted from a discriminator that is used to tell if an image is real or fake.
            target_disc_label: The target label that we want the output of the discriminator to be. This can be 0 for a fake and 1 for a real image.
        """
        if target_disc_label == 0:
            target_labels = torch.zeros_like(predicted_logits)
        else:
            target_labels = torch.ones_like(predicted_logits)
        return nn.BCEWithLogitsLoss(reduction='sum')(predicted_logits, target_labels) / (predicted_logits.size(2) * predicted_logits.size(3))

In [None]:
import torch.nn as nn
import itertools
import logging
import cv2
class MutliClassTrainingModule(pl.LightningModule):
    """A class that defines the training module for multi-class image translation.

    The module owns a content encoder (``_enc_c``), a stain-vector estimator (``_enc_a``), a generator (``_gen``) and
    one real/fake discriminator (``_disc1``). Two optimizers are used: index 0 updates the encoders and the generator,
    index 1 updates the discriminator.

    Args:
        num_input_channels: The number of the input channels. It is used to build the discriminator.
        train_hyperparams: A dictionary that defines different hyperparameters for the training. The keys read by
            this class are 'number_gen_optimization_steps_to_update_disc', 'pretrained_model_path',
            'loss_weights_by_name', 'gen_learning_rate', 'disc_learning_rate' and 'number_of_steps_to_update_lr'.
        test_dataset (optional): The dataset that contains images that is used to visualize the performance of the
            network over epochs of the training. Defaults to None, in which case no debug image is written.
        **kwargs: Forwarded to the constructor of ``pl.LightningModule``.
    """
    _RANDOM_SEED = 1
    _NUM_DISC_UPDATE_PER_ITERATION = 1
    # Number of shuffled copies of the content tensors generated per minibatch (i.e. fake images per real image).
    _NUM_PERMUTATION_FOR_FAKE_IMAGES = 3

    def __init__(self,
                 num_input_channels: int,
                 train_hyperparams: Dict[str, Any],
                 test_dataset: Optional[MultiDomainDataset] = None,
                 **kwargs
                 ) -> None:
        super().__init__(**kwargs)
        self._num_input_channels: int = num_input_channels
        self._train_hyperparams = train_hyperparams
        self._number_gen_optimization_steps_to_update_disc = train_hyperparams['number_gen_optimization_steps_to_update_disc']

        if self._number_gen_optimization_steps_to_update_disc > 1:
            print(f"[WARNING] Please make sure that you have at least {self._number_gen_optimization_steps_to_update_disc} minibatches so that the discrimonator are updated")
        self._enc_c = ContentEncoder(in_channels=3, num_stain_vectors = NUM_STAIN_VECTORS)
        self._enc_a = StainVectorEstimator(in_channels=3, num_stain_vectors=NUM_STAIN_VECTORS)
        self._gen = AbsorbanceImGenerator()
        self._disc1 = RealFakeDiscriminator(in_channels=num_input_channels)

        pretrained_model_path = train_hyperparams['pretrained_model_path']
        if pretrained_model_path is not None:
            # Each sub-network is restored from the checkpoint entries that carry its attribute name as prefix.
            self._enc_c = load_trained_model_from_checkpoint(pretrained_model_path, network=self._enc_c,
                                                            starts_str = "_enc_c.")
            self._enc_a = load_trained_model_from_checkpoint(pretrained_model_path, network=self._enc_a,
                                                            starts_str = "_enc_a.")
            self._gen = load_trained_model_from_checkpoint(pretrained_model_path, network=self._gen,
                                                            starts_str = "_gen.")
            self._disc1 = load_trained_model_from_checkpoint(pretrained_model_path, network=self._disc1,
                                                            starts_str = "_disc1.")

        # Materialised as lists (not one-shot iterators) so that `configure_optimizers` can be called more than once.
        self._encoders_gen_params = [*self._enc_c.parameters(), *self._enc_a.parameters(), *self._gen.parameters()]
        self._discs_params = list(self._disc1.parameters())
        self._loss = Loss(**train_hyperparams['loss_weights_by_name'])

        self._test_dataset = test_dataset
        if test_dataset is not None:
            # Debug images produced by `_generate_inference_results` are written into this folder.
            self._temp_save_folder: TemporaryDirectory = TemporaryDirectory()
            print(f"Training results will be saved to {self._temp_save_folder.name}")

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, optimizer_idx: int) -> Dict[str, torch.Tensor]:
        """Computes the loss dictionary for the optimizer selected by ``optimizer_idx``.

        Args:
            batch: A tuple whose first item is the minibatch of real images. The second item is ignored.
            batch_idx: The index of the minibatch (unused).
            optimizer_idx: 0 for the encoders/generator optimizer, 1 for the discriminator optimizer.

        Returns:
            The loss dictionary from ``Loss``. Nothing is returned for any other ``optimizer_idx``.
        """
        outputs = self._compute_network_forward_outputs(batch)
        if optimizer_idx == 0:
            # Generator step: freeze the discriminator.
            self._set_requires_gradients([self._enc_c, self._enc_a, self._gen], requires_grad=True)
            self._set_requires_gradients([self._disc1], requires_grad=False)
            return self._loss.compute_generator_and_encoder_losses(
                forward_outputs=outputs,
                real_vs_cross_translation_disc=self._disc1,
                num_permutations_for_fake_images=MutliClassTrainingModule._NUM_PERMUTATION_FOR_FAKE_IMAGES,
            )

        if optimizer_idx == 1:
            # Discriminator step: freeze the encoders and the generator.
            self._set_requires_gradients([self._enc_c, self._enc_a, self._gen], requires_grad=False)
            self._set_requires_gradients([self._disc1], requires_grad=True)
            return self._loss.compute_real_fake_discriminator_losses(
                forward_outputs=outputs,
                real_vs_cross_translation_disc=self._disc1,
                num_permutations_for_fake_images=MutliClassTrainingModule._NUM_PERMUTATION_FOR_FAKE_IMAGES,
            )

    def _compute_network_forward_outputs(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Runs the encoders and the generator on a minibatch and collects every tensor needed by the losses.

        The real images are encoded into content (``z_cs``) and attribute (``z_as``) tensors. The generator is then
        run on the original pairs (self-reconstruction) and on content tensors shuffled across the minibatch
        (cross-translation). Finally the generated images are re-encoded.

        Raises:
            ValueError: If the minibatch size is odd.
        """
        minibatch_size = batch[0].shape[0]
        self._ensure_minibatch_size_is_even(minibatch_size)
        real_ims, _ = batch
        z_cs = self._enc_c(real_ims)
        z_as = self._enc_a(real_ims)

        num_permutations = MutliClassTrainingModule._NUM_PERMUTATION_FOR_FAKE_IMAGES
        swapped_zc = self._generate_permutations(z_cs, num_permutations=num_permutations)
        # First time cross-translation: the first `minibatch_size` entries are the un-shuffled pairs.
        input_z_cs = torch.cat([z_cs, swapped_zc], dim=0)
        input_z_as = torch.cat([z_as] * (num_permutations + 1), dim=0)
        recon_and_fake_sizes = [minibatch_size, minibatch_size * num_permutations]

        out_fakes = self._gen(input_z_cs, input_z_as)
        self_recon_ims, fake_one_time_cross_translation_ims = torch.split(out_fakes, recon_and_fake_sizes, dim=0)

        # Extract the content tensors.
        z_c_from_recon_ims, z_c_from_fake_ims = torch.split(self._enc_c(out_fakes), recon_and_fake_sizes, dim=0)

        # Extract the attribute tensors.
        z_a_from_recon_ims, z_a_from_fake_ims = torch.split(self._enc_a(out_fakes), recon_and_fake_sizes, dim=0)
        return {
            'real_ims': real_ims,
            'z_cs': z_cs,
            'swapped_zc': swapped_zc,
            'z_c_from_fake_ims': z_c_from_fake_ims,
            'z_c_from_recon_ims': z_c_from_recon_ims,

            'self_recon_ims': self_recon_ims,
            'fake_one_time_cross_translation_ims': fake_one_time_cross_translation_ims,

            'z_as': z_as,
            'z_a_from_recon_ims': z_a_from_recon_ims,
            'z_a_from_fake_ims': z_a_from_fake_ims,
        }

    @staticmethod
    def _generate_permutations(minibatch_samples: torch.Tensor, num_permutations: int = 1) -> torch.Tensor:
        """Returns ``num_permutations`` random shuffles of the samples, concatenated along the batch dimension.

        Args:
            minibatch_samples: A tensor whose first dimension indexes the samples of the minibatch.
            num_permutations: The number of independently shuffled copies to generate.

        Returns:
            A tensor with ``num_permutations * minibatch_samples.size(0)`` entries along dimension 0.
        """
        num_samples = minibatch_samples.size(0)
        # Indexing dimension 0 with a permutation reorders whole samples, whatever the number of trailing dimensions.
        all_permutations: List[torch.Tensor] = [
            minibatch_samples[torch.randperm(num_samples, device=minibatch_samples.device)]
            for _ in range(num_permutations)
        ]
        return torch.cat(all_permutations, dim=0)

    @staticmethod
    def _sample_attribute_vectors_from_gaussian_distribution(mus: torch.Tensor, logvars: torch.Tensor, over_sampling_factor: int = 1) -> torch.Tensor:
        """Samples attribute vectors ``mu + std * eps`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``.

        Args:
            mus: The means; the first two dimensions are kept in the output.
            logvars: The log-variances, broadcastable against ``mus``.
            over_sampling_factor: Kept for interface compatibility; the returned samples always have the shape of
                ``mus`` (the value never influenced the result).

        Returns:
            A tensor of shape ``(mus.size(0), mus.size(1), 1, 1)``.
        """
        std = logvars.mul(0.5).exp()
        z = torch.randn(mus.shape, device=mus.device)
        z = z.mul(std).add(mus)
        return z.view(z.size(0), z.size(1), 1, 1)

    @staticmethod
    def _aggregate_losses_from_generator_and_discrimonator_loss_dicts(gen_loss_dict: Dict[str, torch.Tensor], disc_loss_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Returns a new dictionary of loss terms which is the combination of losses from two loss dictionaries.

        The 'loss' entries are stored as 'gen_loss' and 'disc_loss'; every other entry keeps its name.

        Raises:
            ValueError: If the names shared by the two dictionaries are not exactly 'loss' and 'batch_size'.
        """
        losses_by_name: Dict[str, torch.Tensor] = {
            'gen_loss': gen_loss_dict['loss'],
            'disc_loss': disc_loss_dict['loss'],
        }

        for loss_dict in [gen_loss_dict, disc_loss_dict]:
            for k, v in loss_dict.items():
                if k != 'loss':
                    losses_by_name[k] = v
        if set(gen_loss_dict.keys()).intersection(disc_loss_dict.keys()) != {"loss", "batch_size"}:
            raise ValueError("Generator loss dictionary and discriminator loss dictionary have overlapping names!")
        return losses_by_name

    def training_step_end(self, workers_outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Aggregates the results from the training steps across multiple workers in the same batch.

        This is required for gradient descent to work. Otherwise, we would get "RuntimeError: grad can be implicitly
        created only for scalar outputs" because the loss is not a scalar (i.e., tensor with one element); instead it
        would be a tensor with size equal to the number of workers.

        Args:
            workers_outputs: A dictionary that contains the outputs from multiple workers.
        """
        return {k: v.sum() for k, v in workers_outputs.items()}

    @staticmethod
    def _sum_over_batches(batch_outputs: List[Dict[str, torch.Tensor]], key: str) -> torch.Tensor:
        """Sums the entry ``key`` over all the per-batch output dictionaries."""
        return torch.stack([b[key] for b in batch_outputs]).sum()

    def training_epoch_end(self, batch_outputs: List[List[Dict[str,float]]]) -> None:
        """Combines the loss from all batches, prints and logs the per-sample averages.

        Also re-seeds the random generators and, every 10 epochs, writes a debug image when a test dataset was given.

        Args:
            batch_outputs: A list of list in which one item is for one optimizer idx. Each item in the inner
                lists is the loss dictionary returned by the training step for one batch.
        """
        self._set_random_seed()
        gen_outputs, disc_outputs = batch_outputs[0], batch_outputs[1]
        # Every average below is normalised by the number of samples seen by the generator optimizer.
        total_num_samples = self._sum_over_batches(gen_outputs, 'batch_size')
        avg_generator_and_encoder_loss = self._sum_over_batches(gen_outputs, 'loss') / total_num_samples
        avg_real_fake_adv_loss_cross_trans_ims_with_real_vs_cross_trans_disc = self._sum_over_batches(gen_outputs, 'real_fake_adv_loss_cross_trans_ims_with_real_vs_cross_trans_disc') / total_num_samples
        avg_self_recon_loss = self._sum_over_batches(gen_outputs, 'self_recon_loss') / total_num_samples
        avg_cont_consistency_loss = self._sum_over_batches(gen_outputs, 'cont_consistency_loss') / total_num_samples
        avg_attribute_consistency_loss = self._sum_over_batches(gen_outputs, 'attribute_consistency_loss') / total_num_samples
        avg_mode_seeking_loss = self._sum_over_batches(gen_outputs, 'mode_seeking_loss') / total_num_samples
        avg_train_disc_total_loss = self._sum_over_batches(disc_outputs, 'loss') / total_num_samples
        avg_real_fake_loss_real_ims_with_real_vs_cross_trans_disc = self._sum_over_batches(disc_outputs, 'real_fake_loss_real_ims_with_real_vs_cross_trans_disc') / total_num_samples
        avg_real_fake_loss_cross_trans_ims_with_real_vs_cross_trans_disc = self._sum_over_batches(disc_outputs, 'real_fake_loss_cross_trans_ims_with_real_vs_cross_trans_disc') / total_num_samples

        print(f"\n-> Epoch {self.current_epoch}: train_encoders_generators_total_loss: {avg_generator_and_encoder_loss:.3f}, " + \
              f"\n        BCE[D_1(cross_tran), 'real'): {avg_real_fake_adv_loss_cross_trans_ims_with_real_vs_cross_trans_disc:.3f}" + "  -> Expected value: -log(0.5) = 0.6931." \
              f"\n        self_recon_loss: {avg_self_recon_loss:.3f} / cont_consistency_loss: {avg_cont_consistency_loss:.3f} / attribute_consistency_loss: {avg_attribute_consistency_loss:.3f} / mode_seeking_loss: {avg_mode_seeking_loss:.3f}"
             )

        print(f"-> train_disc_total_loss: {avg_train_disc_total_loss:.3f}" + \
              f"\n        BCE[D_1(real_image), 'real'): {avg_real_fake_loss_real_ims_with_real_vs_cross_trans_disc:.3f} / BCEL[D_1(cross_trans), 'fake']: {avg_real_fake_loss_cross_trans_ims_with_real_vs_cross_trans_disc:.3f}" + " -> Expected value: -log(0.5) = 0.6931.")

        self.log("train_encoders_generators_total_loss", avg_generator_and_encoder_loss)
        self.log("train_disc_total_loss", avg_train_disc_total_loss)
        if self._test_dataset is not None and self.current_epoch % 10 == 0:
            self._generate_inference_results()

    @staticmethod
    def _set_random_seed():
        # Make sure that the same set of random tensors are initialized at each epochs to maintain the same dataset for the optimization of the loss function.
        torch.manual_seed(MutliClassTrainingModule._RANDOM_SEED)
        torch.cuda.manual_seed(MutliClassTrainingModule._RANDOM_SEED)

    def _generate_inference_results(self) -> None:
        """Generates the inference results for debugging.

        Writes one PNG into the temporary save folder with four panels side by side: test image 0, its
        self-reconstruction, test image 100, and image 0 regenerated with the attributes of image 100.
        Requires ``test_dataset`` to have been provided and a CUDA device "cuda:0" to be available.
        """
        tensor_im_0, _ = self._test_dataset.get_item_with_domain_idx(idx=0, domain_idx=0)
        tensor_im_1, _ = self._test_dataset.get_item_with_domain_idx(idx=100, domain_idx=0)
        im = tensor_im_0.numpy().transpose(1, 2, 0)
        gpu_device = torch.device("cuda:0")
        tensor_im_0 = tensor_im_0[None, ...].to(gpu_device)
        tensor_im_1 = tensor_im_1[None, ...].to(gpu_device)

        with torch.no_grad():
            # NOTE: `.to(gpu_device)` moves the networks themselves to cuda:0 as a side effect.
            z_c = self._enc_c.to(gpu_device)(tensor_im_0)
            z_a_0 = self._enc_a.to(gpu_device)(tensor_im_0)
            z_a_1 = self._enc_a.to(gpu_device)(tensor_im_1)

            recon_im = self._gen.to(gpu_device)(z_c, z_a_0)
            cross_recon_im = self._gen.to(gpu_device)(z_c, z_a_1)

        num_total_images = 4
        num_rows, num_cols, num_chans = im.shape
        # `np.float` was removed from NumPy; `np.float64` is the type it used to alias.
        out_ims = np.zeros((num_rows, num_total_images * num_cols, num_chans), dtype=np.float64)
        # 10**(-x) presumably converts absorbance back to transmittance in [0, 1] -- TODO confirm against the dataset transforms.
        out_ims[:, :num_cols] = 10**(-im)
        out_ims[:, num_cols: 2 * num_cols] = 10**(-recon_im.cpu().numpy()[0].transpose(1,2,0))
        out_ims[:, 2 * num_cols: 3 * num_cols] = 10**(-tensor_im_1[0].cpu().numpy().transpose(1, 2, 0))
        out_ims[:, 3 * num_cols:] = 10**(-cross_recon_im.cpu().numpy()[0].transpose(1,2,0))
        out_ims = (np.clip(out_ims, 0.0, 1.0) * 255.0).astype(np.uint8)
        image_name = os.path.join(self._temp_save_folder.name, f"training_res_epoch_{self.current_epoch}.png")
        # NOTE(review): cv2.imwrite expects BGR channel order; if the dataset yields RGB the saved colours are swapped -- confirm.
        cv2.imwrite(image_name, out_ims)
        print(f"Save debug image {image_name}")

    @staticmethod
    def _set_requires_gradients(networks: List[nn.Module], requires_grad: bool) -> None:
        """Sets the status for gradient calculation for the networks.

        Args:
            networks: A list of neural networks to set the gradient.
            requires_grad: The value of required gradient to set. If False, gradient will not be calculated.
        """
        for net in networks:
            for param in net.parameters():
                param.requires_grad = requires_grad

    def configure_optimizers(self):
        """Builds one SGD optimizer (with an exponential LR scheduler) for the encoders/generator and one for the discriminator.

        Returns:
            A tuple of two Lightning optimizer configurations; index 0 is the encoders/generator, index 1 the
            discriminator.
        """
        # Optimizer configuration
        # See the doc at https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        # Read the hyperparameters given to the constructor rather than a notebook-level global of the same name.
        hyperparams = self._train_hyperparams
        # NOTE(review): the weight decay is hard-coded; a 'weight_decay' hyperparameter, if any, is not read here.
        encoders_generator_opt = torch.optim.SGD(self._encoders_gen_params, lr=hyperparams['gen_learning_rate'], weight_decay=0.0001)
        discs_opt = torch.optim.SGD(self._discs_params, lr=hyperparams['disc_learning_rate'], weight_decay=0.0001)
        return {"optimizer": encoders_generator_opt, "lr_scheduler": {
            "scheduler": torch.optim.lr_scheduler.ExponentialLR(encoders_generator_opt, 0.95,  verbose=True),
            "interval": "epoch",
            "frequency": hyperparams["number_of_steps_to_update_lr"],
        }
        }, {"optimizer": discs_opt, "lr_scheduler": {
            "scheduler": torch.optim.lr_scheduler.ExponentialLR(discs_opt, 0.95,  verbose=True),
            "interval": "epoch",
            "frequency": hyperparams["number_of_steps_to_update_lr"],
        }}

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ) -> None:
        """Steps the optimizers: the discriminator on every batch, the encoders/generator on every N-th batch.

        N is the 'number_gen_optimization_steps_to_update_disc' hyperparameter. On the other batches only the
        closure is evaluated for optimizer 0, without applying an update.
        """
        # See https://pytorch-lightning.readthedocs.io/en/latest/common/optimization.html for more information on how to optimize this correctly.
        if optimizer_idx == 0:
            if (batch_idx + 1) % self._number_gen_optimization_steps_to_update_disc == 0:
                optimizer.step(closure=optimizer_closure)
            else:
                optimizer_closure()

        # Update the discriminator for each iteration
        if optimizer_idx == 1:
            optimizer.step(closure=optimizer_closure)

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, float]:
        """Computes the encoders/generator loss dictionary on a validation minibatch."""
        outputs = self._compute_network_forward_outputs(batch)
        return self._loss.compute_generator_and_encoder_losses(
                forward_outputs=outputs,
                real_vs_cross_translation_disc=self._disc1,
                num_permutations_for_fake_images=MutliClassTrainingModule._NUM_PERMUTATION_FOR_FAKE_IMAGES,
            )

    @staticmethod
    def _ensure_minibatch_size_is_even(minibatch_size: int) -> None:
        """Raises a ValueError when ``minibatch_size`` is odd."""
        if minibatch_size % 2 != 0:
            raise ValueError(f"The size of the minibatch must be even! The requested minibatch size is {minibatch_size}.")

    @staticmethod
    def _sample_normal_distrbution(num_samples: int, attr_dim: int, target_device: torch.device) -> torch.Tensor:
        """Returns standard-normal samples of shape ``(num_samples, attr_dim, 1, 1)`` on ``target_device``."""
        return torch.randn((num_samples, attr_dim, 1, 1), device=target_device)

    def validation_step_end(self, workers_outputs: Dict[str, torch.Tensor]):
        """Combines the loss from all workers."""
        return {k: v.sum() for k, v in workers_outputs.items()}

    def validation_epoch_end(self, batch_outputs: List[Dict[str,float]]):
        """Combines the loss from all batches, prints and logs the per-sample averages.

        Args:
            batch_outputs: A list in which each item is a Dictionary that contains information from each batch.
        """
        self._set_random_seed()
        total_num_samples = self._sum_over_batches(batch_outputs, 'batch_size')
        val_avg_generator_and_encoder_loss = self._sum_over_batches(batch_outputs, 'loss') / total_num_samples

        avg_real_fake_adv_loss_cross_trans_ims_with_real_vs_cross_trans_disc = self._sum_over_batches(batch_outputs, 'real_fake_adv_loss_cross_trans_ims_with_real_vs_cross_trans_disc') / total_num_samples

        avg_self_recon_loss = self._sum_over_batches(batch_outputs, 'self_recon_loss') / total_num_samples
        avg_cont_consistency_loss = self._sum_over_batches(batch_outputs, 'cont_consistency_loss') / total_num_samples
        avg_attribute_consistency_loss = self._sum_over_batches(batch_outputs, 'attribute_consistency_loss') / total_num_samples
        avg_mode_seeking_loss = self._sum_over_batches(batch_outputs, 'mode_seeking_loss') / total_num_samples

        print(f"\n-> Val: train_encoders_generators_total_loss: {val_avg_generator_and_encoder_loss:.3f}, " + \
             f"\n        real_fake_adv_loss_cross_trans_ims_with_real_vs_cross_trans_disc: {avg_real_fake_adv_loss_cross_trans_ims_with_real_vs_cross_trans_disc:.3f}" + \
             f"\n        self_recon_loss: {avg_self_recon_loss:.3f} / cont_consistency_loss: {avg_cont_consistency_loss:.3f} / attribute_consistency_loss: {avg_attribute_consistency_loss:.3f} / mode_seeking_loss: {avg_mode_seeking_loss:.3f}."
             )
        self.log("val_encoders_generators_total_loss", val_avg_generator_and_encoder_loss)

In [None]:
# Number of GPUs assumed to be present on the MLE environment.
NUM_GPUS_ON_MLE_ENV = 8


def _get_num_gpu_for_pytorch_training() -> int:
    """Returns ``NUM_GPUS_ON_MLE_ENV`` when CUDA is available on this machine, 0 otherwise."""
    if not torch.cuda.is_available():
        return 0
    return NUM_GPUS_ON_MLE_ENV


num_gpus = _get_num_gpu_for_pytorch_training()

## 7. Callback definition

In [None]:
def load_trained_model_from_checkpoint(checkpoint_url: str, network: nn.Module, starts_str: str) -> nn.Module:
    """Loads the parameters of one network from a training checkpoint on s3.

    Only the entries of the checkpoint's "state_dict" whose key starts with ``starts_str`` are used; the prefix is
    stripped so that the remaining key matches the parameter names of ``network``.

    Args:
        checkpoint_url: The path to the trained model on S3.
        network: A network to load the checkpoint parameters to. It is modified in place.
        starts_str: A first few letters of the network variable to load the checkpoint from (e.g. "_enc_c.").

    Returns:
        The loaded model, i.e. ``network`` itself.
    """
    local_model_checkpoint = sync_files_from_s3_url([checkpoint_url])[0]
    # Deserialize onto the CPU so that a checkpoint written on a GPU can also be read on a machine without CUDA.
    # `load_state_dict` copies the values into the existing parameters of `network`, whichever device they are on.
    full_state_dict = torch.load(local_model_checkpoint, map_location="cpu")["state_dict"]
    prefix_length = len(starts_str)
    network_state_dict = OrderedDict(
        (key[prefix_length:], value) for key, value in full_state_dict.items() if key.startswith(starts_str)
    )
    network.load_state_dict(network_state_dict)
    return network


In [None]:
# NOTE(review): this cell redefines the function of the same name from the previous cell; the later definition wins.
def load_trained_model_from_checkpoint(checkpoint_url: str, network: nn.Module, starts_str: str) -> nn.Module:
    """Loads the parameters of one network from a training checkpoint on s3.

    Only the entries of the checkpoint's "state_dict" whose key starts with ``starts_str`` are used; the prefix is
    stripped so that the remaining key matches the parameter names of ``network``.

    Args:
        checkpoint_url: The path to the trained model on S3.
        network: A network to load the checkpoint parameters to. It is modified in place.
        starts_str: A first few letters of the network variable to load the checkpoint from (e.g. "_enc_c.").

    Returns:
        The loaded model, i.e. ``network`` itself.
    """
    local_model_checkpoint = sync_files_from_s3_url([checkpoint_url])[0]
    # Deserialize onto the CPU so that a checkpoint written on a GPU can also be read on a machine without CUDA.
    # `load_state_dict` copies the values into the existing parameters of `network`, whichever device they are on.
    full_state_dict = torch.load(local_model_checkpoint, map_location="cpu")["state_dict"]
    prefix_length = len(starts_str)
    network_state_dict = OrderedDict(
        (key[prefix_length:], value) for key, value in full_state_dict.items() if key.startswith(starts_str)
    )
    network.load_state_dict(network_state_dict)
    return network


In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import Callback
from pathai.handlers.s3 import sync_folder_to_s3
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pathai.handlers.s3 import sync_files_from_s3
from pathai.handlers.s3 import sync_files_from_s3_url

def every_n_checkpoint_callback(save_dir: str, every_n_train_epoch: int = 1) -> ModelCheckpoint:
    """Builds the periodic checkpoint callback.

    The saving directory is created first when it does not exist yet.

    Args:
        save_dir: The path to the saving directory.
        every_n_train_epoch (optional): The epoch interval between two saved checkpoints; it is passed to
            ``ModelCheckpoint`` as ``period``. Defaults to 1.

    Returns:
        A ``ModelCheckpoint`` configured with ``save_top_k=-1`` and a "periodic_..." file name pattern.
    """
    _make_dir_if_not_exist(save_dir)
    checkpoint_options = {
        "dirpath": save_dir,
        "period": every_n_train_epoch,
        # The file name embeds the epoch and the three logged losses.
        "filename": "periodic_{epoch}_{val_encoders_generators_total_loss:.3f}_{train_encoders_generators_total_loss:.3f}_{train_disc_total_loss:.3f}",
        "save_top_k": -1,
    }
    return ModelCheckpoint(**checkpoint_options)

def _make_dir_if_not_exist(dir_name: str) -> None:
    """Creates the directory ``dir_name`` (and any missing parent) unless it already exists.

    Args:
        dir_name: The path of the directory to create.

    Raises:
        FileExistsError: If ``dir_name`` exists but is not a directory.
    """
    # `exist_ok=True` avoids the check-then-create race and, unlike `os.mkdir`, also creates missing parents.
    os.makedirs(dir_name, exist_ok=True)

def top_n_checkpoint_callback(save_dir: str, num_best_models: int = 5) -> ModelCheckpoint:
    """Builds the checkpoint callback that keeps the best-performing models.

    Models are ranked by the logged "val_encoders_generators_total_loss" (lower is better). The saving directory is
    created first when it does not exist yet.

    Args:
        save_dir: The path to the saving directory.
        num_best_models: The number of best performer to save.

    Returns:
        A ``ModelCheckpoint`` that monitors the validation loss with a "top_n_..." file name pattern.
    """
    _make_dir_if_not_exist(save_dir)
    checkpoint_options = {
        "dirpath": save_dir,
        "monitor": "val_encoders_generators_total_loss",
        "save_top_k": num_best_models,
        "mode": "min",
        "filename": "top_n_{epoch}_{val_encoders_generators_total_loss:.3f}_{train_encoders_generators_total_loss:.3f}_{train_disc_total_loss:.3f}",
    }
    return ModelCheckpoint(**checkpoint_options)

class SyncEveryNCheckpointsToS3Callback(Callback):
    """A PyTorch Lightning callback class for syncing a local folder to an S3 directory every n epochs.

    This callback is useful when we want to terminate a training process early (e.g., early stopping because model has
    converged and/or has taken too long to train) but still want to keep the model checkpoints.

    The sync is triggered from the end-of-validation hook, so it can only happen on epochs where validation runs.

    Args:
        local_dir: The local directory in which the checkpoints will be saved.
        s3_dir: The destination key on S3 that ``local_dir`` is synced to.
        every_n_train_epochs: A number that specifies the epoch interval to save the checkpoints. Defaults to 10.
    """

    def __init__(self, local_dir: str, s3_dir: str, every_n_train_epochs: int = 10) -> None:
        super().__init__()
        self._local_dir = local_dir
        self._s3_dir = s3_dir
        self._every_n_train_epochs = every_n_train_epochs

    def on_validation_epoch_end(self, trainner, pl_module: pl.LightningModule) -> None:
        """Syncs the local folder to S3 when the (1-based) epoch number is a multiple of the configured interval."""
        # `current_epoch` is 0-based, hence the +1.
        if (pl_module.current_epoch + 1) % self._every_n_train_epochs == 0:
            sync_folder_to_s3(local_folder=self._local_dir, destination_key=self._s3_dir)


In [None]:
@task(node_family=NodeFamilies.GPU, slots=num_gpus)
def train_model(num_input_channels: int, 
                experiment_name: str, 
                train_hyperparams: Dict[str, Any],
                multi_domain_train_samples: List[List[Samples]],
                multi_domain_val_samples: List[Tuple[Samples]],
               ):
    """Trains the stain-separation networks (``MutliClassTrainingModule``) with PyTorch Lightning.

    Checkpoints are written to a temporary local directory that is deleted when training ends; the periodic
    ("every_n") checkpoints are synced to S3 by ``SyncEveryNCheckpointsToS3Callback`` during training.

    Args:
        num_input_channels: The number of the input channels.
        experiment_name: The name of the experiment to run. It names the tensorboard logger and the S3 folders.
        train_hyperparams: The hyperparameters for the training. Keys read here: 'num_epochs', 'batch_size',
            'num_dataloaders' and 'periodically_save_training_results'; the dictionary is also passed to the
            training module.
        multi_domain_train_samples: A list of lists that contains the training samples. One inner list for 1 domain.
        multi_domain_val_samples:A list of lists that contains the validation samples. One inner list for 1 domain.
    """
    S3_MODEL_FOLDER = "tnguyen/breast_cancer/"
    print("Training model...")
    from pathai.parameters.config_manager import config_manager
    from pathai.research_dev.io.utils import get_tensorboard_logger

    config_manager.set_interactive_mode()
    config_manager.load_all_defaults()
    if not config_manager.get_task_configuration().env_name:
        config_manager.update_task_configuration(dict(env_name = "tan.nguyen/overlay_test"))

    tensorboard_logger = get_tensorboard_logger(experiment_name)
    print(f"Tensorboard save dir = {tensorboard_logger.save_dir}, log dir = {tensorboard_logger.log_dir}")

    with TemporaryDirectory() as save_dir:
        # NOTE(review): this value only sets the S3 sync interval; the trainer below validates every 20 epochs.
        check_val_n_epochs = 5
        every_n_epochs_dir = os.path.join(save_dir, "every_n")
        top_n_dir = os.path.join(save_dir, "top_n")
        every_n_s3_folder_dir = "imaging-team/" + str(Path(S3_MODEL_FOLDER) / "trained_models" / experiment_name / "every_n")
        # NOTE(review): not used below -- the top-n checkpoints are never synced to S3 and are lost together with
        # the temporary directory. Confirm whether a second sync callback was intended.
        top_n_s3_folder_dir = "imaging-team/" + str(Path(S3_MODEL_FOLDER) / "trained_models" / experiment_name / "top_n")
        print(f"Checkpoint location on s3 for every_n: {every_n_s3_folder_dir}")

        callbacks=[
            every_n_checkpoint_callback(save_dir=every_n_epochs_dir),
            top_n_checkpoint_callback(save_dir=top_n_dir),
            EarlyStopping(monitor="val_encoders_generators_total_loss", mode="min", min_delta=0.0001, patience=10),
            SyncEveryNCheckpointsToS3Callback(local_dir=every_n_epochs_dir,
                                              s3_dir=every_n_s3_folder_dir, 
                                              every_n_train_epochs=check_val_n_epochs),
        ]

        trainer = pl.Trainer(
            callbacks=callbacks,
            max_epochs = train_hyperparams['num_epochs'],
            progress_bar_refresh_rate = 5,
            num_sanity_val_steps=1,
            precision=16,
            logger = [tensorboard_logger],
            # The argument expects a checkpoint path or None; None means "start from scratch".
            resume_from_checkpoint=None,
            checkpoint_callback=True,
            check_val_every_n_epoch=20,
            log_every_n_steps=20,

            #profiler="advanced",

            profiler=None,

            # See: https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html for 'ddp' vs 'dp'
            accelerator="dp",
            # -1 to use all the available GPUs.
            gpus = -1,
            num_nodes = 1,
            accumulate_grad_batches=1,

            gradient_clip_val=5,

        )

        data_module = MultiDomainDataModule(
            multi_domain_train_samples=multi_domain_train_samples, 
            multi_domain_val_samples=multi_domain_val_samples,
            batch_size=train_hyperparams['batch_size'],
            num_dataloading_workers=train_hyperparams['num_dataloaders'],
        )

        data_module.setup("fit")

        training_module = torch_factory(MutliClassTrainingModule)(
            num_input_channels, 
            train_hyperparams,
            # NOTE(review): `val_lists` is a notebook-level global, not the `multi_domain_val_samples` argument --
            # confirm that it is available wherever this task executes.
            test_dataset = MultiDomainDataset(
                multi_domain_samples=val_lists,
                transforms = [RGBToTransmittance(), TransmittanceToAbsorbance(), ToTensor()],
                input_patch_size_pixels=INPUT_IMAGE_SIZE_PIXELS,
            ) if train_hyperparams['periodically_save_training_results'] else None
        )

        trainer.fit(training_module, data_module)

In [None]:
def _adjust_num_samples_so_that_each_gpu_has_an_even_number_of_samples(org_num_samples: int, num_gpu: int, batch_size_per_gpu: int) -> int:
    """Rounds the number of samples down to a multiple of the global batch size (``num_gpu * batch_size_per_gpu``).

    Args:
        org_num_samples: The requested number of samples.
        num_gpu: The number of GPUs the minibatch is split over.
        batch_size_per_gpu: The number of samples each GPU receives per minibatch.

    Returns:
        The largest multiple of ``num_gpu * batch_size_per_gpu`` that does not exceed ``org_num_samples``.

    Raises:
        ZeroDivisionError: If ``num_gpu * batch_size_per_gpu`` is 0.
    """
    # Use the `num_gpu` argument rather than a module-level variable so the result depends only on the inputs.
    global_batch_size = num_gpu * batch_size_per_gpu
    return (org_num_samples // global_batch_size) * global_batch_size

def _adjust_minibatch_size(org_batch_size: int) -> int:
    """Rounds the minibatch size down to the nearest even number."""
    remainder = org_batch_size % 2
    return org_batch_size - remainder

def _loss_string_from_weight_dict(loss_weights_by_name: Dict[str, float]) -> str:
    """Formats the loss weights as "<name>_<weight>" pieces joined by underscores, in dictionary order.

    Args:
        loss_weights_by_name: The weight of each loss term, keyed by its name.

    Returns:
        The joined string, e.g. "recon_weight_20.0_real_fake_weight_1.0"; empty for an empty dictionary.
    """
    return "_".join(f"{k}_{v}" for k, v in loss_weights_by_name.items())

# Optional warm start: point `pretrain_model_path` to a checkpoint on S3, or leave it to None to train from scratch.
#pretrain_model_path = "s3://imaging-team/tnguyen/breast_cancer/trained_models/Stain_separation_19888_bs_22_samples_w_real_fake_weight_1.0_recon_weight_20.0_content_consistency_weight_3.0_attr_consistency_weight_1.0_mode_seeking_loss_weight_1.0_patch_size_512_v1/every_n/periodic_epoch=39_val_encoders_generators_total_loss=1.468_train_encoders_generators_total_loss=1.421_train_disc_total_loss=1.329.ckpt" 
pretrain_model_path = None
batch_size_per_gpu = _adjust_minibatch_size(org_batch_size=16)  # 28 is the largest number that does not cause OOM.

# Weight of each loss term; the dictionary is expanded into the keyword arguments of `Loss`.
loss_weights_by_name = dict(
    real_fake_weight=1.0,
    recon_weight=20.0,
    content_consistency_weight=3.0,
    attr_consistency_weight=1.0,
    mode_seeking_loss_weight=1.0,
)

# Hyperparameters used when the training is dispatched to jabba.
train_params_jabba = dict(
    weight_decay=0.001,
    gen_learning_rate=1e-3,
    disc_learning_rate=5e-2,
    num_epochs=10000,
    batch_size=batch_size_per_gpu * NUM_GPUS_ON_MLE_ENV,
    num_dataloaders=20,
    loss_weights_by_name=loss_weights_by_name,
    pretrained_model_path=pretrain_model_path,
    number_gen_optimization_steps_to_update_disc=1,
    number_of_steps_to_update_lr=1,
    periodically_save_training_results=False,
)

# Hyperparameters used when the training runs locally.
train_params_local = dict(
    weight_decay=0.001,
    gen_learning_rate=1e-3,
    disc_learning_rate=5e-2,
    num_epochs=3000,
    batch_size=16,
    loss_weights_by_name=loss_weights_by_name,
    num_dataloaders=20,
    pretrained_model_path=pretrain_model_path,
    number_gen_optimization_steps_to_update_disc=1,
    number_of_steps_to_update_lr=1,
    periodically_save_training_results=True,
)

# False: run locally, True: run on jabba
run_on_jabba = True
if not run_on_jabba:
    num_train_samples, num_val_samples = 400, 16
    train_hyperparams = train_params_local
else:
    # Trim the requested sample counts (10000 train / 5000 val) to a multiple of the global batch size.
    num_train_samples, num_val_samples = (
        _adjust_num_samples_so_that_each_gpu_has_an_even_number_of_samples(requested_num_samples, num_gpus, batch_size_per_gpu)
        for requested_num_samples in (10000, 5000)
    )
    train_hyperparams = train_params_jabba

experiment_name = (
    f"Stain_separation_{NUM_STAIN_VECTORS}_stain_vectors_{num_train_samples}_bs_{batch_size_per_gpu}"
    f"_samples_w_{_loss_string_from_weight_dict(loss_weights_by_name)}_patch_size_512_v2"
)
print(f"experiment_name = {experiment_name}")

In [None]:
# Keep only the first `num_train_samples` / `num_val_samples` samples of every domain.
multi_domain_train_samples = [domain_samples[:num_train_samples] for domain_samples in train_lists]
multi_domain_val_samples = [domain_samples[:num_val_samples] for domain_samples in val_lists]

In [None]:
%matplotlib inline

In [None]:
# Select the execution backend, then launch the training task and block until it finishes.
(set_jabba if run_on_jabba else set_local)()
train_model.dispatch(
    [
        3,  # num_input_channels
        experiment_name,
        train_hyperparams,
        multi_domain_train_samples,
        multi_domain_val_samples,
    ],
    # A cache key (the experiment name) is only given for jabba runs.
    cache_key=experiment_name if run_on_jabba else None,
).wait()

## 8. Inference testing

In [None]:
# Checkpoint used for the inference tests below. `NUM_STAIN_VECTORS` must match the value the checkpoint was trained with.
# 32 channels
NUM_STAIN_VECTORS = 32
pretrain_model_path = (
    "s3://imaging-team/tnguyen/breast_cancer/trained_models/"
    "Stain_separation_19888_bs_22_samples_w_real_fake_weight_1.0_recon_weight_20.0"
    "_content_consistency_weight_3.0_attr_consistency_weight_1.0_mode_seeking_loss_weight_1.0"
    "_patch_size_512_v1/every_n/"
    "periodic_epoch=139_val_encoders_generators_total_loss=1.532"
    "_train_encoders_generators_total_loss=1.604_train_disc_total_loss=1.224.ckpt"
)

# 8 channels
#NUM_STAIN_VECTORS = 8
#pretrain_model_path = "s3://imaging-team/tnguyen/breast_cancer/trained_models/Stain_separation_8_stain_vectors_19968_bs_16_samples_w_real_fake_weight_1.0_recon_weight_20.0_content_consistency_weight_3.0_attr_consistency_weight_1.0_mode_seeking_loss_weight_1.0_patch_size_512_v1/every_n/periodic_epoch=99_val_encoders_generators_total_loss=3.202_train_encoders_generators_total_loss=2.093_train_disc_total_loss=1.156.ckpt" 


In [None]:

# Restore the three sub-networks from the same checkpoint. Each one is selected by its own
# prefix string and switched to eval mode.
enc_c = load_trained_model_from_checkpoint(
    pretrain_model_path,
    network=ContentEncoder(in_channels=3, num_stain_vectors=NUM_STAIN_VECTORS),
    starts_str="_enc_c.",
).eval()
enc_a = load_trained_model_from_checkpoint(
    pretrain_model_path,
    network=StainVectorEstimator(in_channels=3, num_stain_vectors=NUM_STAIN_VECTORS),
    starts_str="_enc_a.",
).eval()
gen = load_trained_model_from_checkpoint(
    pretrain_model_path,
    network=AbsorbanceImGenerator(),
    starts_str="_gen.",
).eval()

In [None]:
# Patch side length, in pixels, used for the test datasets and as the inferencer tile size below.
TEST_IM_SIZE_PIXELS = 512

In [None]:
def _make_absorbance_dataset(multi_domain_samples):
    """Builds a MultiDomainDataset that yields absorbance tensors of the test patch size."""
    return MultiDomainDataset(
        multi_domain_samples=multi_domain_samples,
        transforms=[RGBToTransmittance(), TransmittanceToAbsorbance(), ToTensor()],
        input_patch_size_pixels=TEST_IM_SIZE_PIXELS,
    )


# Both datasets use the full (untruncated) sample lists.
train_dataset = _make_absorbance_dataset(train_lists)
val_dataset = _make_absorbance_dataset(val_lists)

In [None]:
from typing import Generator
from pathai.research_dev.utilities.pytorch_utils import move_to_best_device

class StainNormalizationInferencer:
    """Runs tile-by-tile stain-normalization inference with trained networks.

    Every tile of the input absorbance image is encoded by the content encoder and decoded by the
    generator together with a caller-supplied attribute tensor.

    Args:
        content_encoder: The trained model of the content encoder.
        gen: The trained model of the generator.
        max_tile_size_pixels (optional): The side length in pixels of the largest tile passed to the
            networks, including the margin on both sides. The region written to the output for each
            tile is max_tile_size_pixels - 2 * tile_margin_pixels wide.
            Defaults to INPUT_IMAGE_SIZE_PIXELS.
        tile_margin_pixels (optional): The size of the outer region surrounding each tile in pixels.
            The network output in this region is discarded to avoid boundary artifacts. An outer band
            of this width around the whole image is never written and stays zero. Defaults to 0.
        max_polarization_val (optional): Stored on the instance; no method of this class reads it.
            Defaults to 0.8.

    Raises:
        ValueError: If tile_margin_pixels is negative or leaves no inner tile region.
    """
    def __init__(
        self,
        content_encoder: ContentEncoder,
        gen: AbsorbanceImGenerator,
        max_tile_size_pixels: int = INPUT_IMAGE_SIZE_PIXELS, 
        tile_margin_pixels: int = 0,
        max_polarization_val: float = 0.8,
    ) -> None:
        if tile_margin_pixels < 0:
            raise ValueError(f"tile_margin_pixels must be non-negative, got {tile_margin_pixels}.")
        if max_tile_size_pixels <= 2 * tile_margin_pixels:
            # A non-positive inner size would make the tiling range() step zero or negative.
            raise ValueError(
                f"max_tile_size_pixels ({max_tile_size_pixels}) must be larger than "
                f"2 * tile_margin_pixels ({2 * tile_margin_pixels})."
            )
        self._content_encoder: ContentEncoder = move_to_best_device(content_encoder)
        self._gen: AbsorbanceImGenerator = move_to_best_device(gen)
        self._max_tile_size_pixels: int = max_tile_size_pixels
        self._tile_margin_pixels: int = tile_margin_pixels
        self._max_polarization_val: float = max_polarization_val
        self._tile_size_pixels: int = self._max_tile_size_pixels
        self._inner_tile_size_pixels: int = self._tile_size_pixels - 2 * tile_margin_pixels

    def infer_one_image(self, image: Union[np.ndarray, torch.Tensor], 
                        z_a: torch.Tensor) -> np.ndarray:
        """Runs the inference on a single image.

        Args:
            image: A numpy or a Tensor absorbance image of dimension (H, W, C). uint8 tiles are
                scaled to [0.0, 1.0] before being passed to the networks.
            z_a: The attribute tensor for the domain to reconstruct the image.

        Returns:
            A numpy array with the same shape as the input that holds the inference result. The dtype
            matches a floating-point input; for any other input dtype the result is float32 so that
            the network output is not truncated.
        """
        z_a = move_to_best_device(z_a)
        if isinstance(image, torch.Tensor):
            # torch.from_numpy in _infer_one_tile only accepts numpy arrays.
            image = image.detach().cpu().numpy()
        out_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else np.float32
        out_image = np.zeros(image.shape, dtype=out_dtype)
        image_shape_2d = image.shape[:2]
        for r_slice, c_slice, r_ext_slice, c_ext_slice in self._tile_slices_iterator(image_shape_2d = image_shape_2d):
            print(f"Infering in row slice = {r_slice}, col slice = {c_slice}, r_ext_slice = {r_ext_slice}, c_ext_slice = {c_ext_slice}")
            tile = np.transpose(image[r_ext_slice, c_ext_slice, :], (2, 0, 1))  # To (C, H, W)
            if tile.dtype == np.uint8:
                tile = tile.astype(np.float32) / 255.0

            tile_output = self._infer_one_tile(tile, z_a = z_a)

            # Drop the margin so that only the inner region is written back.
            # NOTE(review): assumes the generator output has the same spatial size as its input tile.
            row_offset = r_slice.start - r_ext_slice.start
            col_offset = c_slice.start - c_ext_slice.start
            inner_output = tile_output[
                :,
                row_offset : row_offset + (r_slice.stop - r_slice.start),
                col_offset : col_offset + (c_slice.stop - c_slice.start),
            ]
            out_image[r_slice, c_slice] = inner_output.transpose((1, 2, 0))
        return out_image

    def _tile_slices_iterator(self, image_shape_2d: Tuple[int, int], ) -> Generator[Tuple[slice, slice, slice, slice], None, None]:
        """Yields the slices used to read each tile and to place its result, in row-major order.

        Args:
            image_shape_2d: The (number of rows, number of columns) of the image.

        Yields:
            The row slice of the inner region where the result is placed.
            The column slice of the inner region where the result is placed.
            The row slice of the extended (inner + margin) region that is read from the image.
            The column slice of the extended (inner + margin) region that is read from the image.
        """
        nrows, ncols = image_shape_2d
        col_slices = self._inner_slices_along_axis(ncols)
        for r_slice in self._inner_slices_along_axis(nrows):
            for c_slice in col_slices:
                r_ext_slice, c_ext_slice = self._create_extended_slices_to_cover_current_tile(
                    (r_slice, c_slice), 
                    image_shape_2d
                )
                yield (
                    r_slice, 
                    c_slice,
                    r_ext_slice,
                    c_ext_slice,
                )

    def _inner_slices_along_axis(self, length: int) -> List[slice]:
        """Returns the inner-region slices that tile one image axis of the given length."""
        margin = self._tile_margin_pixels
        inner_size = self._inner_tile_size_pixels
        last_stop = length - margin
        slices = []
        for start in range(margin, last_stop, inner_size):
            # Shift the last tile back so that it ends at the usable boundary, but never before the
            # first usable pixel (the axis can be shorter than one tile).
            start = max(margin, min(start, last_stop - inner_size))
            slices.append(slice(start, min(start + inner_size, last_stop)))
        return slices

    def _create_extended_slices_to_cover_current_tile(self, tile_slices: Tuple[slice, slice], image_shape: Tuple[int, int]) -> Tuple[slice, slice]:
        """Returns slices of extended rows and columns to cover the current tile.

        Each inner slice is grown by the tile margin on both sides and clipped to the image.
        """
        row_slice, col_slice = tile_slices
        nrows, ncols = image_shape
        left_limit_to_cover = max(0, col_slice.start - self._tile_margin_pixels)
        right_limit_to_cover = min(ncols, col_slice.stop + self._tile_margin_pixels)

        top_limit_to_cover = max(0, row_slice.start - self._tile_margin_pixels)
        bottom_limit_to_cover = min(nrows, row_slice.stop + self._tile_margin_pixels)

        return slice(top_limit_to_cover, bottom_limit_to_cover), slice(left_limit_to_cover, right_limit_to_cover)

    def _infer_one_tile(self, tile: np.ndarray, z_a: torch.Tensor) -> np.ndarray:
        """Performs the inference for a single tile.

        Args:
            tile: An input tile of dimensions (C, H, W). A batch dimension is added here before the
                tile is passed to the content encoder.
            z_a: The attribute tensor for the domain to reconstruct the image.
                NOTE(review): the original documentation states shape (1, #attrs, 1, 1) - confirm.

        Returns:
            The generator output with the batch dimension removed, as a numpy array on the CPU.
        """
        tile = torch.from_numpy(tile)
        tile = tile.float()[None, :, :, :]
        tile = move_to_best_device(tile)
        with torch.no_grad():
            z_c = self._content_encoder(tile)
            return torch.squeeze(self._gen(z_c, z_a), dim=0).to("cpu").numpy()

In [None]:
def absorbance_to_transmittance(im: np.ndarray) -> np.ndarray:
    """Converts absorbance values to transmittance.

    Negative absorbance values are clipped to 0 first, so the result never exceeds 1.

    Args:
        im: An array of absorbance values.

    Returns:
        An array with 10 ** (-absorbance) computed element-wise on the clipped input.
    """
    non_negative_absorbance = np.clip(im, 0.0, None)
    return np.power(10, -non_negative_absorbance)

In [None]:
def get_slide_platform_patch(sample: Samples, input_patch_size_pixels: int) -> np.ndarray:
    """Reads a patch centered on a sample from the slide platform.

    Args:
        sample: A sample object that provides the slide id and the (row, column) center of the patch.
        input_patch_size_pixels: The size of the patch to be extracted in pixels. The extracted
            window spans input_patch_size_pixels // 2 pixels on each side of the center.

    Returns:
        The result of indexing the slide view with the row and column windows and all channels.
        NOTE(review): the window is not clamped to the slide bounds - confirm callers only pass
        centers far enough from the border.
    """
    half_patch = input_patch_size_pixels // 2
    center_row, center_col = sample.row_idx, sample.col_idx
    with SlideReference(int(sample.slide_id)).read_object() as slide:
        # The shape is unpacked as in the original code; the values are not used afterwards.
        _slide_num_rows, _slide_num_cols, _ = slide.shape
        slide_view = slide.view_at_mpp(slide.mpp, mpp_tolerance=0.05)

        # Cover the same distance on both sides of the center.
        row_window = slice(center_row - half_patch, center_row + half_patch)
        col_window = slice(center_col - half_patch, center_col + half_patch)
        return slide_view[row_window, col_window, :]

In [None]:
print(f"Dataset len = {len(val_dataset)}")
# Source image: the tensor is converted to a channel-last numpy array.
im_0_tensor, domain_vect_0 = val_dataset.get_item_with_domain_idx(idx=12, domain_idx=1)
im_0 = im_0_tensor.numpy().transpose(1, 2, 0)

# Target image. Its domain vector gets its own name so that it no longer overwrites the source
# image's domain_vect_0.
im_1_tensor, domain_vect_1 = train_dataset.get_item_with_domain_idx(idx=28, domain_idx=2)
im_1 = im_1_tensor.numpy().transpose(1, 2, 0)


In [None]:
# Compute the attribute tensors of both images with the attribute encoder.
# A leading batch dimension is added before each image is encoded.
z_a0 = enc_a(im_0_tensor.unsqueeze(0))
z_a1 = enc_a(im_1_tensor.unsqueeze(0))

In [None]:
image_translation_inferencer = StainNormalizationInferencer(
    content_encoder=enc_c,
    gen=gen,
    max_tile_size_pixels=TEST_IM_SIZE_PIXELS,
)

# Decode the content of im_0 with its own attributes (self reconstruction) ...
self_reconstructed_im = image_translation_inferencer.infer_one_image(image=im_0, z_a=z_a0)

# ... and with the attributes of im_1 (cross-domain translation).
cross_domain_reconstruction_image = image_translation_inferencer.infer_one_image(image=im_0, z_a=z_a1)

In [None]:
print(f"Image size = {im_0.shape}")
print(f"Domain vector = {domain_vect_0}")

# Source, translated and target images side by side; each absorbance image is converted to
# transmittance for display.
plt.figure(figsize=(20, 20))
_translation_panels = (
    ('Source image', im_0),
    ('Translated image', cross_domain_reconstruction_image),
    ('Target image', im_1),
)
for _panel_idx, (_panel_title, _absorbance_im) in enumerate(_translation_panels, start=1):
    plt.subplot(1, 3, _panel_idx)
    plt.imshow(absorbance_to_transmittance(_absorbance_im))
    plt.title(_panel_title)
plt.show()


In [None]:
def visualize_content(im: torch.Tensor, content_encoder: torch.nn.Module) -> None:
    """Shows an input image and a mosaic of its content-encoder channels.

    Two figures are drawn: the first item of the batch converted from absorbance to transmittance,
    and a grid in which each content channel is scaled by its own maximum.

    Args:
        im: A batch of absorbance images; only the first item is visualized.
            NOTE(review): assumed to be a CPU tensor of shape (N, C, H, W) - confirm with callers.
        content_encoder: The content encoder used to compute the content channels.
    """
    zc = content_encoder(move_to_best_device(im))[0].detach().cpu().numpy()
    num_chans, im_num_rows, im_num_cols = zc.shape

    # Near-square grid: floor(sqrt(n)) rows and enough columns to hold every channel.
    num_row_disp = int(np.sqrt(num_chans))
    num_col_disp = int(np.ceil(num_chans / num_row_disp))
    combined_im = np.zeros((num_row_disp * im_num_rows, num_col_disp * im_num_cols), dtype=zc.dtype)
    for chan_idx, chan in enumerate(zc):
        row_idx, col_idx = divmod(chan_idx, num_col_disp)
        chan_max = np.max(chan)
        # A channel whose maximum is 0 is shown unscaled instead of being divided by zero.
        scaled_chan = chan / chan_max if chan_max != 0 else chan
        combined_im[
            row_idx * im_num_rows : (row_idx + 1) * im_num_rows,
            col_idx * im_num_cols : (col_idx + 1) * im_num_cols,
        ] = scaled_chan

    rgb_im = absorbance_to_transmittance(im[0].numpy().transpose(1, 2, 0))
    plt.figure(figsize=(8, 8))
    plt.imshow(rgb_im)
    plt.axis('off')
    plt.figure(figsize=(16, 16))
    plt.imshow(combined_im)
    plt.axis('off')
    plt.show()
    

In [None]:
# Show the content channels of the source image first and of the target image second.
for _image_tensor in (im_0_tensor, im_1_tensor):
    visualize_content(im=_image_tensor[None, :], content_encoder=enc_c)

### Illustration for presentation

In [None]:
print(f"Dataset len = {len(val_dataset)}")
im_0_tensor, domain_vect_0 = val_dataset.get_item_with_domain_idx(idx=158, domain_idx=1)
abs_im_0 = im_0_tensor.numpy().transpose(1, 2, 0)

num_target_im = 4

plt.figure(figsize=(20, 15))

# One column per target image: source (top), translation (middle), target (bottom).
for row_idx in range(num_target_im):
    # The target's domain vector gets its own name so the source's domain_vect_0 stays intact.
    im_1_tensor, domain_vect_1 = val_dataset.get_item_with_domain_idx(idx=row_idx + 10, domain_idx=0)
    z_a1 = enc_a(im_1_tensor[None, :])

    cross_domain_reconstruction_image = image_translation_inferencer.infer_one_image(
        image=abs_im_0,
        z_a=z_a1)

    # Report the image that is actually translated in this cell (abs_im_0).
    print(f"Image size = {abs_im_0.shape}")
    print(f"Domain vector = {domain_vect_0}")

    plt.subplot(3, num_target_im, 1 + row_idx)
    plt.imshow(absorbance_to_transmittance(abs_im_0))
    plt.axis('off')

    plt.subplot(3, num_target_im, 1 + row_idx + num_target_im)
    plt.imshow(absorbance_to_transmittance(cross_domain_reconstruction_image))
    plt.axis('off')

    plt.subplot(3, num_target_im, 1 + row_idx + 2 * num_target_im)
    plt.imshow(absorbance_to_transmittance(im_1_tensor.numpy().transpose(1, 2, 0)))
    plt.axis('off')

plt.show()


## Generate visualization of patch RGB data using t-SNE

### Raw color features before transformation.

In [None]:
# For a fair comparison, the images are normalized so that the maximum transmittance is 1.0.
# Each entry is (sample lists, index of the domain taken from those lists); the position in this
# tuple becomes the key of the dataset dictionary.
_RAW_VIS_SOURCES = (
    (train_lists, 1),
    (train_lists, 2),
    (val_lists, 0),
    (val_lists, 1),
)
vis_dataset_by_domain = {
    domain_key: MultiDomainDataset(
        multi_domain_samples=samples,
        transforms=[RGBToTransmittance(), TransmittanceToRGB()],
        domain_indices=[source_domain_idx],
    )
    for domain_key, (samples, source_domain_idx) in enumerate(_RAW_VIS_SOURCES)
}

In [None]:
from collections import defaultdict
from skimage.color import rgb2lab, rgb2hsv, rgb2hed, rgb2gray


In [None]:
def color_feature_vector_from_patch_data(im: np.ndarray) -> np.ndarray:
    """Computes the per-channel means of an RGB image in the RGB, LAB, HSV and HED colorspaces.

    Args:
        im: An input RGB image of size (H, W, 3). The channels must be last because that is the
            layout the skimage converters used here expect by default.
            NOTE(review): the value range is not checked here; float input is presumably expected
            in [0, 1] - confirm against the dataset transforms.

    Returns:
        A vector of 12 values: the spatial mean of each channel of the RGB, LAB, HSV and HED
        representations, in that order.
    """
    im = im.astype(np.float32)
    # Stack the colorspaces along the channel axis. Stacking along axis 0 would append them as
    # extra image rows and average the four colorspaces together channel by channel.
    combined_im = np.concatenate([im, rgb2lab(im), rgb2hsv(im), rgb2hed(im)], axis=-1)
    return np.mean(combined_im, axis=(0, 1))

In [None]:
# Number of patches per domain that enter the t-SNE projection.
num_points_per_domain = 500
point_vectors_by_domain = defaultdict(list)
for domain, dataset in vis_dataset_by_domain.items():
    print(f"Computing vectors for domain {domain}")
    domain_vectors = []
    for sample_idx in range(num_points_per_domain):
        # The first element of each dataset item is the image.
        domain_vectors.append(color_feature_vector_from_patch_data(dataset[sample_idx][0]))
    point_vectors_by_domain[domain] = np.array(domain_vectors)

# One row per patch, domains stacked in dictionary order.
all_features = np.concatenate(list(point_vectors_by_domain.values()), axis=0)

In [None]:
label_name_by_index = {0: 'p2_gt450', 1: 'p2_ufs', 2: 'p2_at2', 3: 'dp200'}
# One label per feature row: each domain name is repeated once for every point of that domain,
# in the same order in which the feature rows were stacked.
labels = np.repeat(
    [label_name_by_index[domain] for domain in point_vectors_by_domain],
    num_points_per_domain,
)

In [None]:
from sklearn.manifold import TSNE

In [None]:
# Project the per-patch color features to 2 dimensions with t-SNE; the random state is fixed.
tsne = TSNE(n_components=2, verbose=1, random_state=123)
projected_features = tsne.fit_transform(all_features) 

In [None]:
import pandas as pd
# Table with the domain label and the two projected coordinates of every patch.
df = pd.DataFrame(
    {
        "y": labels,
        "feature 1": projected_features[:, 0],
        "feature 2": projected_features[:, 1],
    }
)

In [None]:
import seaborn as sns
plt.figure(figsize=(10, 10))
sns.set(font_scale=1.6)
# NOTE(review): the return value of axes_style is discarded; presumably this call has no lasting
# effect outside a `with` block - confirm against the seaborn documentation.
sns.axes_style("darkgrid")
sns.set_style("white")
# Scatter plot of the projected features, colored by domain label.
raw_color_axes = sns.scatterplot(
    x="feature 1",
    y="feature 2",
    hue=df.y.tolist(),
    palette=sns.color_palette("hls", 4),
    data=df,
)
raw_color_axes.set(title="Color statistics (t-SNE projection)")
plt.grid(None)

### Perform pairwise translation for each point to a reference point in the 1st domain.

In [None]:
# Same four domains as the raw-color visualization, but yielding absorbance tensors so that the
# images can be passed to the networks. Each entry is (sample lists, index of the domain taken
# from those lists); the position in this tuple becomes the dictionary key.
_TO_TRANSFORM_VIS_SOURCES = (
    (train_lists, 1),
    (train_lists, 2),
    (val_lists, 0),
    (val_lists, 1),
)
to_transform_vis_dataset_by_domain = {
    domain_key: MultiDomainDataset(
        multi_domain_samples=samples,
        transforms=[RGBToTransmittance(), TransmittanceToAbsorbance(), ToTensor()],
        domain_indices=[source_domain_idx],
    )
    for domain_key, (samples, source_domain_idx) in enumerate(_TO_TRANSFORM_VIS_SOURCES)
}

In [None]:
# Compute color feature vectors after translating every non-reference domain towards domain 0.
# Domain 0 is the reference: its images are only converted from absorbance to transmittance.
# For every other domain, sample i is translated with the attributes of sample i of domain 0.
point_vectors_of_transformed_image_by_domain = defaultdict(list)
for domain, dataset in to_transform_vis_dataset_by_domain.items():
    print(f"Computing vectors for domain {domain}")
    if domain == 0:
        # NOTE(review): all_images is assigned but never read in this cell.
        all_images = []
        for idx in range(num_points_per_domain):
            # Tensor (C, H, W) -> channel-last numpy array -> transmittance.
            im = absorbance_to_transmittance(dataset[idx][0].numpy().transpose(1,2,0))
            point_vectors_of_transformed_image_by_domain[domain].append(color_feature_vector_from_patch_data(im))
    else:
        # Key of the reference domain whose attributes are used as the translation target.
        _TARGET_DOMAIN_IDX = 0
        for sample_idx in range(num_points_per_domain):
            # Progress report every 50 samples.
            if sample_idx % 50 == 0:
                print(f"sample_idx = {sample_idx}")
            target_im_tensor = to_transform_vis_dataset_by_domain[_TARGET_DOMAIN_IDX][sample_idx][0]
            # NOTE(review): target_im is computed but not read again in this cell.
            target_im = absorbance_to_transmittance(target_im_tensor.numpy().transpose(1,2,0))
            # Attribute tensor of the reference image (a batch dimension is added first).
            target_za = enc_a(target_im_tensor[None,:, :, :])
            source_absorbance_im = dataset[sample_idx][0].numpy().transpose(1,2,0)
            translated_im = image_translation_inferencer.infer_one_image(
                image = source_absorbance_im, 
                z_a = target_za)
            # The inferencer works on absorbance; convert back before extracting color features.
            translated_im = absorbance_to_transmittance(translated_im)
            point_vectors_of_transformed_image_by_domain[domain].append(color_feature_vector_from_patch_data(translated_im))
            

# One row per patch, domains stacked in dictionary order.
all_features_after_transformed = np.concatenate(list(point_vectors_of_transformed_image_by_domain.values()), axis=0)

In [None]:
# Fit the t-SNE object again, this time on the features of the translated images.
projected_features_after_transformed = tsne.fit_transform(all_features_after_transformed) 

In [None]:
# Table with the domain label and the two projected coordinates of every translated patch.
df_after_transformed = pd.DataFrame(
    {
        "y": labels,
        "feature 1": projected_features_after_transformed[:, 0],
        "feature 2": projected_features_after_transformed[:, 1],
    }
)

In [None]:
import seaborn as sns
plt.figure(figsize=(10, 10))
sns.set(font_scale=1.6)
# NOTE(review): the return value of axes_style is discarded; presumably this call has no lasting
# effect outside a `with` block - confirm against the seaborn documentation.
sns.axes_style("darkgrid")
sns.set_style("white")
# Scatter plot of the projected features of the translated images, colored by domain label.
translated_color_axes = sns.scatterplot(
    x="feature 1",
    y="feature 2",
    hue=df_after_transformed.y.tolist(),
    palette=sns.color_palette("husl", 4),
    data=df_after_transformed,
)
translated_color_axes.set(title="Color statistics (t-SNE projection)")
plt.grid(None)