# A minimal example demonstrating how the trainers for FNet and wGAN-GP, together with the callbacks, work with the patched dataset

This notebook depends on the files produced by `1.illumination_correction/0.create_loaddata_csvs` in the ALSF pilot data repo: https://github.com/WayScience/pediatric_cancer_atlas_profiling

In [1]:
import sys
import pathlib

import pandas as pd
import torch.nn as nn
import torch.optim as optim

# Put the directory two levels above the current working directory on the
# import path (so the `virtual_stain_flow` imports below can resolve) and
# echo it for reference.
repo_root = str(pathlib.Path('.').absolute().parent.parent)
sys.path.append(repo_root)
print(repo_root)

## Dataset
from virtual_stain_flow.datasets.PatchDataset import PatchDataset
from virtual_stain_flow.datasets.CachedDataset import CachedDataset

## FNet training
from virtual_stain_flow.models.fnet import FNet
from virtual_stain_flow.trainers.Trainer import Trainer

## wGaN training
from virtual_stain_flow.models.unet import UNet
from virtual_stain_flow.models.discriminator import GlobalDiscriminator
from virtual_stain_flow.trainers.WGANTrainer import WGANTrainer

## wGaN losses
from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss
from virtual_stain_flow.losses.DiscriminatorLoss import WassersteinLoss
from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss

from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize

## Metrics
from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper
from virtual_stain_flow.metrics.PSNR import PSNR
from virtual_stain_flow.metrics.SSIM import SSIM

## callback
from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger
from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePlot


/home/weishanli/Waylab


  check_for_updates()


## Specify train output paths

In [2]:
# Root output folder for this example, created inside the current working
# directory if it is not already there.
EXAMPLE_DIR = pathlib.Path.cwd() / 'example_train'
EXAMPLE_DIR.mkdir(exist_ok=True)

In [3]:
!rm -rf example_train/*

In [4]:
# Sub-folders of the example output directory: one for the plots written by
# the plot callback, one for the mlflow tracking data.
PLOT_DIR = EXAMPLE_DIR / 'plot'
MLFLOW_DIR = EXAMPLE_DIR / 'mlflow'

for out_dir in (PLOT_DIR, MLFLOW_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

## Specify paths to loaddata and read single cell features

In [None]:
## REPLACE WITH YOUR OWN PATHS
# Directory that contains the loaddata csv file(s)
loaddata_csv_path = pathlib.Path(
    '/REPLACE/WITH/YOUR/PATH'
    )
# Directory that contains the per-plate `<plate>_sc_normalized.parquet` files
sc_features_parquet_path = pathlib.Path(
    '/REPLACE/WITH/YOUR/PATH'
    )

if not loaddata_csv_path.exists():
    raise ValueError("Incorrect loaddata csv path")

# Take the first csv found in the directory. Only an exhausted glob is reported
# as "no csv found"; any other error (e.g. permissions) propagates unchanged
# instead of being masked by a bare `except`.
loaddata_csv = next(loaddata_csv_path.glob('*.csv'), None)
if loaddata_csv is None:
    raise FileNotFoundError("No loaddata csv found")

loaddata_df = pd.read_csv(loaddata_csv)
# subsample to reduce runtime; cap the sample size at the number of available
# rows so that small loaddata files do not make `sample` raise
loaddata_df = loaddata_df.sample(n=min(100, len(loaddata_df)), random_state=42)

# Only the merge keys and the cell-center coordinates are read from the
# single cell feature files
sc_feature_columns = [
    'Metadata_Plate',
    'Metadata_Well',
    'Metadata_Site',
    'Metadata_Cells_Location_Center_X',
    'Metadata_Cells_Location_Center_Y',
]

# Collect one frame per plate and concatenate once at the end, instead of
# re-copying the growing frame on every iteration
sc_feature_frames = []
for plate in loaddata_df['Metadata_Plate'].unique():
    sc_features_parquet = sc_features_parquet_path / f'{plate}_sc_normalized.parquet'
    if not sc_features_parquet.exists():
        print(f'{sc_features_parquet} does not exist, skipping...')
        continue
    sc_feature_frames.append(
        pd.read_parquet(sc_features_parquet, columns=sc_feature_columns)
    )

# `pd.concat` rejects an empty list, so fall back to an empty frame when no
# plate had a matching parquet file
sc_features = pd.concat(sc_feature_frames) if sc_feature_frames else pd.DataFrame()

print(loaddata_df.head())
print(sc_features.head())

            FileName_OrigBrightfield  \
2079  r06c22f01p01-ch1sk1fk1fl1.tiff   
668   r05c09f03p01-ch1sk1fk1fl1.tiff   
2073  r05c22f04p01-ch1sk1fk1fl1.tiff   
1113  r06c13f07p01-ch1sk1fk1fl1.tiff   
788   r06c10f06p01-ch1sk1fk1fl1.tiff   

                               PathName_OrigBrightfield  \
2079  /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
668   /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
2073  /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
1113  /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
788   /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   

                     FileName_OrigER  \
2079  r06c22f01p01-ch2sk1fk1fl1.tiff   
668   r05c09f03p01-ch2sk1fk1fl1.tiff   
2073  r05c22f04p01-ch2sk1fk1fl1.tiff   
1113  r06c13f07p01-ch2sk1fk1fl1.tiff   
788   r06c10f06p01-ch2sk1fk1fl1.tiff   

                                        PathName_OrigER  \
2079  /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
668   /home/weishanli/Waylab/ALSF_pilot/data/

## Configure Patch size and channels

In [6]:
# Patch size handed to the patch dataset as `patch_size`
PATCH_SIZE = 256

# The brightfield channel is the model input; every other channel is a
# candidate prediction target.
input_channel_name = "OrigBrightfield"
target_channel_names = ["OrigDNA", "OrigER", "OrigMito", "OrigRNA", "OrigAGP"]

# Full channel list: the input channel first, followed by the targets
channel_names = [input_channel_name, *target_channel_names]

## Prep Patch dataset and Cache

In [7]:
# Largest value of an unsigned 16-bit integer, used as the normalization factor
# (presumably so that 16-bit intensities end up in [0, 1] — confirm against
# MinMaxNormalize)
uint16_max = (2 ** 16) - 1

# Patch dataset built from the loaddata table and the single cell coordinates.
# The channel keys are left unset here and assigned through the setters below.
pds = PatchDataset(
    _loaddata_csv=loaddata_df,
    _sc_feature=sc_features,
    _input_channel_keys=None,
    _target_channel_keys=None,
    _input_transform=MinMaxNormalize(
        _normalization_factor=uint16_max,
        _always_apply=True,
    ),
    _target_transform=MinMaxNormalize(
        _normalization_factor=uint16_max,
        _always_apply=True,
    ),
    patch_size=PATCH_SIZE,
    verbose=True,
    patch_generation_method="random_cell",
    patch_generation_random_seed=42,
)

## Select the input channel and a single target channel
pds.set_input_channel_keys([input_channel_name])
pds.set_target_channel_keys('OrigDNA')

## Wrap in a cached dataset (cache filled up front) for faster training
cds = CachedDataset(pds, prefill_cache=True)

2025-03-03 20:31:52,602 - DEBUG - Dataframe supplied for loaddata_csv, using as is
2025-03-03 20:31:52,603 - DEBUG - Dataframe supplied for sc_feature, using as is
2025-03-03 20:31:52,604 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers
2025-03-03 20:31:52,605 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes
2025-03-03 20:31:52,605 - DEBUG - Merge fields inferred: ['Metadata_Plate', 'Metadata_Well', 'Metadata_Site']
2025-03-03 20:31:52,605 - DEBUG - Dataframe supplied for sc_feature, using as is
2025-03-03 20:31:52,673 - DEBUG - Inferring channel keys from loaddata csv
2025-03-03 20:31:52,674 - DEBUG - Channel keys: {'OrigBrightfield', 'OrigMito', 'OrigAGP', 'OrigER', 'OrigRNA', 'OrigDNA'} inferred from loaddata csv
2025-03-03 20:31:52,674 - DEBUG - Setting input channel(s) ...
2025-03-03 20:31:52,674 - DEBUG 

2025-03-03 20:31:53,018 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells
2025-03-03 20:31:53,018 - DEBUG - Generating patches that contain cells
2025-03-03 20:31:53,052 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False
2025-03-03 20:31:53,681 - DEBUG - Generated 461 patches for 93 site/view
2025-03-03 20:31:53,682 - DEBUG - Set input channel(s) as ['OrigBrightfield']
2025-03-03 20:31:53,682 - DEBUG - Set target channel(s) as ['OrigDNA']


# FNet trainer

## Train model without callback and check logs

In [8]:
# Fresh FNet model with its Adam optimizer
model = FNet(depth=4)
lr = 3e-4
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))

# L1 loss is back-propagated; PSNR and SSIM are tracked as additional metrics
l1_loss = nn.L1Loss()
eval_metrics = {
    'psnr': PSNR(_metric_name="psnr"),
    'ssim': SSIM(_metric_name="ssim"),
}

# No callbacks in this run; early termination is keyed on the L1 loss
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    backprop_loss=l1_loss,
    dataset=cds,
    batch_size=16,
    epochs=10,
    patience=5,
    callbacks=None,
    early_termination_metric='L1Loss',
    metrics=eval_metrics,
    device='cuda',
)

trainer.train()

In [9]:
# Display the log collected by the trainer (`trainer.log`) as a table
pd.DataFrame(trainer.log)

Unnamed: 0,epoch,L1Loss,val_L1Loss,psnr,ssim,val_psnr,val_ssim
0,1,0.307582,0.367112,10.074805,0.026604,8.649314,0.035401
1,2,0.158626,0.15945,14.711766,0.044222,15.713264,0.067737
2,3,0.099793,0.092746,17.329845,0.07211,20.035673,0.118198
3,4,0.068925,0.050483,19.144039,0.105142,24.395761,0.256631
4,5,0.048414,0.042193,24.528973,0.305629,25.807829,0.34059
5,6,0.033416,0.027642,27.669621,0.455102,28.555059,0.503599
6,7,0.026435,0.025302,28.809948,0.520874,29.219507,0.563964
7,8,0.021219,0.017991,29.607342,0.582654,30.165979,0.627546
8,9,0.018385,0.01797,29.834793,0.600629,29.763931,0.582465
9,10,0.015869,0.018145,30.26502,0.622541,29.08075,0.488713


## Train model with alternative early termination metric

In [10]:
# Fresh FNet model with its Adam optimizer
model = FNet(depth=4)
lr = 3e-4
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))

# Same setup as the previous run, except for the early termination metric
l1_loss = nn.L1Loss()
eval_metrics = {
    'psnr': PSNR(_metric_name="psnr"),
    'ssim': SSIM(_metric_name="ssim"),
}

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    backprop_loss=l1_loss,
    dataset=cds,
    batch_size=16,
    epochs=10,
    patience=5,
    callbacks=None,
    metrics=eval_metrics,
    device='cuda',
    # psnr is used as the early termination metric purely for demonstration
    early_termination_metric='psnr',
)

trainer.train()

Early termination at epoch 6 with best validation metric 9.887849807739258


## Train with mlflow logger callbacks

In [11]:
# Learning rate shared by the optimizer and the parameters logged to mlflow,
# so the logged value cannot drift from the one actually used
lr = 3e-4

# Callback that logs this training run to the mlflow store under MLFLOW_DIR
mlflow_logger_callback = MlflowLogger(
        name='mlflow_logger',
        mlflow_uri=MLFLOW_DIR / 'mlruns',
        mlflow_experiment_name='Default',
        mlflow_start_run_args={'run_name': 'example_train', 'nested': True},
        mlflow_log_params_args={
            'lr': lr
        },
    )

# Drop the trainer left over from the previous cell before building a new one
del trainer

# Re-create the model and optimizer: reusing the objects from the previous
# cell would continue training an already-trained network (with stale Adam
# state), so the logged run would not reflect training from scratch as in the
# other FNet examples.
model = FNet(depth=4)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))

trainer = Trainer(
    model = model,
    optimizer = optimizer,
    backprop_loss = nn.L1Loss(),
    dataset = cds,
    batch_size = 16,
    epochs = 10,
    patience = 5,
    callbacks=[mlflow_logger_callback],
    metrics={'psnr': PSNR(_metric_name="psnr"), 'ssim': SSIM(_metric_name="ssim")},
    device = 'cuda',
    early_termination_metric = 'L1Loss'
)

trainer.train()

# wGaN GP example with mlflow logger callback and plot callback

In [None]:
# Learning rates and Adam betas, named once so the optimizers and the logged
# parameters use the same values
gen_lr = 0.0002
disc_lr = 0.00002
adam_betas = (0., 0.9)

# Generator: single-channel in, single-channel out
generator = UNet(n_channels=1, n_classes=1)

# Discriminator taking 2 input channels
# NOTE(review): batch norm in the critic is commonly discouraged when a
# gradient penalty is used — confirm this setting is intended
discriminator = GlobalDiscriminator(
    n_in_channels=2,
    n_in_filters=64,
    _conv_depth=4,
    _pool_before_fc=True,
    batch_norm=True,
)

generator_optimizer = optim.Adam(
    generator.parameters(),
    lr=gen_lr,
    betas=adam_betas,
)
discriminator_optimizer = optim.Adam(
    discriminator.parameters(),
    lr=disc_lr,
    betas=adam_betas,
    weight_decay=0.001,
)

# Loss terms: gradient penalty (tied to the discriminator), generator loss and
# Wasserstein discriminator loss
gp_loss = GradientPenaltyLoss(
    _metric_name='gp_loss',
    discriminator=discriminator,
    weight=10.0,
)
gen_loss = GeneratorLoss(_metric_name='gen_loss')
disc_loss = WassersteinLoss(_metric_name='disc_loss')

# Callback that logs this run to the mlflow store under MLFLOW_DIR
mlflow_logger_callback = MlflowLogger(
    name='mlflow_logger',
    mlflow_uri=MLFLOW_DIR / 'mlruns',
    mlflow_experiment_name='Default',
    mlflow_start_run_args={'run_name': 'example_train_wgan', 'nested': True},
    mlflow_log_params_args={
        'gen_lr': gen_lr,
        'disc_lr': disc_lr,
    },
)

# Callback that writes intermediate plots to PLOT_DIR
plot_callback = IntermediatePlot(
    name='plotter',
    path=PLOT_DIR,
    # the patch dataset is supplied here, not the cached dataset
    dataset=pds,
    # five selected patches from the dataset are plotted
    indices=[1, 3, 5, 7, 9],
    plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],
    figsize=(20, 25),
    show_plot=False,
)

wgan_metrics = {
    'ssim': SSIM(_metric_name='ssim'),
    'psnr': PSNR(_metric_name='psnr'),
}

wgan_trainer = WGANTrainer(
    dataset=cds,
    batch_size=16,
    epochs=20,
    # patience equals the epoch count to prevent unwanted early termination
    patience=20,
    device='cuda',
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=generator_optimizer,
    disc_optimizer=discriminator_optimizer,
    generator_loss_fn=gen_loss,
    discriminator_loss_fn=disc_loss,
    gradient_penalty_fn=gp_loss,
    discriminator_update_freq=1,
    generator_update_freq=2,
    callbacks=[mlflow_logger_callback, plot_callback],
    metrics=wgan_metrics,
    early_termination_metric='GeneratorLoss',
)

wgan_trainer.train()

# Drop the references to the generator and the trainer once training is done
del generator, wgan_trainer

## wGaN GP example with mlflow logger callback and alternative early termination metric

In [None]:
# Learning rates and Adam betas, named once so the optimizers and the logged
# parameters use the same values
gen_lr = 0.0002
disc_lr = 0.00002
adam_betas = (0., 0.9)

# Generator: single-channel in, single-channel out
generator = UNet(n_channels=1, n_classes=1)

# Discriminator taking 2 input channels
# NOTE(review): batch norm in the critic is commonly discouraged when a
# gradient penalty is used — confirm this setting is intended
discriminator = GlobalDiscriminator(
    n_in_channels=2,
    n_in_filters=64,
    _conv_depth=4,
    _pool_before_fc=True,
    batch_norm=True,
)

generator_optimizer = optim.Adam(
    generator.parameters(),
    lr=gen_lr,
    betas=adam_betas,
)
discriminator_optimizer = optim.Adam(
    discriminator.parameters(),
    lr=disc_lr,
    betas=adam_betas,
    weight_decay=0.001,
)

# Loss terms: gradient penalty (tied to the discriminator), generator loss and
# Wasserstein discriminator loss
gp_loss = GradientPenaltyLoss(
    _metric_name='gp_loss',
    discriminator=discriminator,
    weight=10.0,
)
gen_loss = GeneratorLoss(_metric_name='gen_loss')
disc_loss = WassersteinLoss(_metric_name='disc_loss')

# Callback that logs this run to the mlflow store under MLFLOW_DIR
mlflow_logger_callback = MlflowLogger(
    name='mlflow_logger',
    mlflow_uri=MLFLOW_DIR / 'mlruns',
    mlflow_experiment_name='Default',
    mlflow_start_run_args={'run_name': 'example_train_wgan_mae_early_term', 'nested': True},
    mlflow_log_params_args={
        'gen_lr': gen_lr,
        'disc_lr': disc_lr,
    },
)

wgan_metrics = {
    'ssim': SSIM(_metric_name='ssim'),
    'psnr': PSNR(_metric_name='psnr'),
    # torch's nn.L1Loss is wrapped so it can be supplied as the 'mae' metric
    'mae': MetricsWrapper(_metric_name='mae', module=nn.L1Loss()),
}

wgan_trainer = WGANTrainer(
    dataset=cds,
    batch_size=16,
    epochs=20,
    # lower patience than in the previous wGaN example
    patience=5,
    device='cuda',
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=generator_optimizer,
    disc_optimizer=discriminator_optimizer,
    generator_loss_fn=gen_loss,
    discriminator_loss_fn=disc_loss,
    gradient_penalty_fn=gp_loss,
    discriminator_update_freq=1,
    generator_update_freq=2,
    callbacks=[mlflow_logger_callback],
    metrics=wgan_metrics,
    # early termination is keyed on the supplied L1Loss/mae metric instead of
    # the default GaN generator loss
    early_termination_metric='mae',
)

wgan_trainer.train()