# A minimal example demonstrating how the trainers for FNet and wGaN GP, plus the callbacks, work together with the patched dataset

This example depends on the files produced by `1.illumination_correction/0.create_loaddata_csvs` in the ALSF pilot data repo: https://github.com/WayScience/pediatric_cancer_atlas_profiling

In [1]:
import sys
import pathlib

import pandas as pd
import torch.nn as nn
import torch.optim as optim

## Make the package two directories above this notebook importable
repo_root = str(pathlib.Path('.').absolute().parent.parent)
sys.path.append(repo_root)
print(repo_root)

## Dataset
from virtual_stain_flow.datasets.PatchDataset import PatchDataset
from virtual_stain_flow.datasets.CachedDataset import CachedDataset

## FNet training
from virtual_stain_flow.models.fnet import FNet
from virtual_stain_flow.trainers.Trainer import Trainer

## wGaN training
from virtual_stain_flow.models.unet import UNet
from virtual_stain_flow.models.discriminator import GlobalDiscriminator
from virtual_stain_flow.trainers.WGaNTrainer import WGaNTrainer

## wGaN losses
from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss
from virtual_stain_flow.losses.DiscriminatorLoss import DiscriminatorLoss
from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss

from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize

## Metrics
from virtual_stain_flow.metrics.PSNR import PSNR
from virtual_stain_flow.metrics.SSIM import SSIM

## callback
from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger
from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePatchPlot


/home/weishanli/Waylab


  check_for_updates()


## Specify train output paths

In [2]:
## All outputs of this example are written beneath ./example_train
EXAMPLE_DIR = pathlib.Path('.').absolute().joinpath('example_train')
EXAMPLE_DIR.mkdir(exist_ok=True)

In [3]:
## Clear outputs left over from a previous run of this example.
## Done in Python rather than through a `!rm -rf` shell escape so the cell
## also works outside IPython and on systems without a POSIX shell.
import shutil

for _stale_entry in EXAMPLE_DIR.iterdir():
    # The shell glob `example_train/*` never matched dotfiles; keep skipping them.
    if _stale_entry.name.startswith('.'):
        continue
    if _stale_entry.is_dir() and not _stale_entry.is_symlink():
        shutil.rmtree(_stale_entry, ignore_errors=True)
    else:
        # Regular files and symlinks (the link itself, not its target).
        _stale_entry.unlink()

In [4]:
## Sub-directories for the plot callback and the mlflow logger callback
PLOT_DIR = EXAMPLE_DIR / 'plot'
MLFLOW_DIR = EXAMPLE_DIR / 'mlflow'

for _output_dir in (PLOT_DIR, MLFLOW_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)

## Specify paths to loaddata and read a single loaddata csv

In [5]:
## REPLACE WITH YOUR OWN PATHS
## Root of the analysis repo; the loaddata csv directory is resolved beneath it
analysis_home_path = pathlib.Path(
    '/home/weishanli/Waylab/ALSF_pilot/ALSF_img2img_prototyping'
)
## Directory that holds the per-plate `<plate>_sc_normalized.parquet` files
sc_features_parquet_path = pathlib.Path('/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/preprocessed_profiles_SN0313537/single_cell_profiles')

In [6]:
## Directory of loaddata csvs beneath the analysis repo
loaddata_csv_path = (
    analysis_home_path / '0.data_analysis_and_preprocessing' / 'loaddata_csvs'
)

if not loaddata_csv_path.exists():
    raise ValueError("Incorrect loaddata csv path")

## Pick a single csv. Sorting makes the choice reproducible when the directory
## holds several csvs, and checking for "nothing matched" explicitly avoids a
## bare `except:` that would relabel unrelated errors as a missing file.
loaddata_csv = next(iter(sorted(loaddata_csv_path.glob('*.csv'))), None)
if loaddata_csv is None:
    raise FileNotFoundError("No loaddata csv found")

loaddata_df = pd.read_csv(loaddata_csv)
# subsample to reduce runtime; n is capped at the number of available rows
# because DataFrame.sample raises when asked for more rows than exist
loaddata_df = loaddata_df.sample(n=min(100, len(loaddata_df)), random_state=42)

## Only the plate/well/site identifiers and the cell-center coordinates are read
sc_feature_columns = [
    'Metadata_Plate',
    'Metadata_Well',
    'Metadata_Site',
    'Metadata_Cells_Location_Center_X',
    'Metadata_Cells_Location_Center_Y',
]

sc_feature_frames = []
for plate in loaddata_df['Metadata_Plate'].unique():
    sc_features_parquet = sc_features_parquet_path / f'{plate}_sc_normalized.parquet'
    if not sc_features_parquet.exists():
        print(f'{sc_features_parquet} does not exist, skipping...')
        continue
    sc_feature_frames.append(
        pd.read_parquet(sc_features_parquet, columns=sc_feature_columns)
    )

## Concatenate once instead of re-copying the accumulated frame per plate;
## fall back to an empty frame when no plate had a parquet file
sc_features = pd.concat(sc_feature_frames) if sc_feature_frames else pd.DataFrame()

print(loaddata_df.head())
print(sc_features.head())

            FileName_OrigBrightfield  \
2079  r06c22f01p01-ch1sk1fk1fl1.tiff   
668   r05c09f03p01-ch1sk1fk1fl1.tiff   
2073  r05c22f04p01-ch1sk1fk1fl1.tiff   
1113  r06c13f07p01-ch1sk1fk1fl1.tiff   
788   r06c10f06p01-ch1sk1fk1fl1.tiff   

                               PathName_OrigBrightfield  \
2079  /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
668   /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
2073  /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
1113  /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
788   /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   

                     FileName_OrigER  \
2079  r06c22f01p01-ch2sk1fk1fl1.tiff   
668   r05c09f03p01-ch2sk1fk1fl1.tiff   
2073  r05c22f04p01-ch2sk1fk1fl1.tiff   
1113  r06c13f07p01-ch2sk1fk1fl1.tiff   
788   r06c10f06p01-ch2sk1fk1fl1.tiff   

                                        PathName_OrigER  \
2079  /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...   
668   /home/weishanli/Waylab/ALSF_pilot/data/

## Configure Patch size and channels

In [7]:
## Side length passed to the patch dataset as `patch_size`
PATCH_SIZE = 256

## The single channel used as model input
input_channel_name = "OrigBrightfield"

## All channels, input channel first
channel_names = ["OrigBrightfield", "OrigDNA", "OrigER", "OrigMito", "OrigRNA", "OrigAGP"]

## Every channel other than the input channel, in the original order
target_channel_names = list(
    filter(lambda name: name != input_channel_name, channel_names)
)

## Prep Patch dataset and Cache

In [8]:
## Input and target share the same normalization settings, but each side
## still receives its own transform instance
normalize_kwargs = dict(_normalization_factor=(2 ** 16) - 1, _always_apply=True)

pds = PatchDataset(
    _loaddata_csv=loaddata_df,
    _sc_feature=sc_features,
    # channel keys are left unset here and assigned right after construction
    _input_channel_keys=None,
    _target_channel_keys=None,
    _input_transform=MinMaxNormalize(**normalize_kwargs),
    _target_transform=MinMaxNormalize(**normalize_kwargs),
    patch_size=PATCH_SIZE,
    verbose=True,
    patch_generation_method="random_cell",
    patch_generation_random_seed=42,
)

## Set input and target channels
pds.set_input_channel_keys([input_channel_name])
pds.set_target_channel_keys('OrigDNA')

## Cache for faster training
cds = CachedDataset(pds, prefill_cache=True)

2025-02-16 19:32:37,676 - DEBUG - Dataframe supplied for loaddata_csv, using as is
2025-02-16 19:32:37,676 - DEBUG - Dataframe supplied for sc_feature, using as is
2025-02-16 19:32:37,676 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers
2025-02-16 19:32:37,676 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes
2025-02-16 19:32:37,676 - DEBUG - Merge fields inferred: ['Metadata_Plate', 'Metadata_Site', 'Metadata_Well']
2025-02-16 19:32:37,676 - DEBUG - Dataframe supplied for sc_feature, using as is
2025-02-16 19:32:37,699 - DEBUG - Inferring channel keys from loaddata csv
2025-02-16 19:32:37,699 - DEBUG - Channel keys: {'OrigRNA', 'OrigMito', 'OrigER', 'OrigBrightfield', 'OrigAGP', 'OrigDNA'} inferred from loaddata csv
2025-02-16 19:32:37,699 - DEBUG - Setting input channel(s) ...
2025-02-16 19:32:37,699 - DEBUG 

# FNet trainer

## Train model without callback and check logs

In [9]:
## Model and optimizer
lr = 3e-4
model = FNet(depth=4)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))

## Metrics passed to the trainer alongside the backprop loss
fnet_metrics = {
    'psnr': PSNR(_metric_name="psnr"),
    'ssim': SSIM(_metric_name="ssim"),
}

## No callbacks in this first run; results are read from trainer.log afterwards
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    backprop_loss=nn.L1Loss(),
    dataset=cds,
    batch_size=16,
    epochs=10,
    patience=5,
    callbacks=None,
    metrics=fnet_metrics,
    device='cuda',
)

trainer.train()

In [10]:
## Show the trainer's log as a table
pd.DataFrame(data=trainer.log)

Unnamed: 0,epoch,L1Loss,val_L1Loss,psnr,ssim,val_psnr,val_ssim
0,1,0.511466,0.470702,5.717022,0.025058,6.503,0.033165
1,2,0.422323,0.440381,7.361481,0.060039,7.077345,0.035204
2,3,0.387546,0.398983,8.151957,0.03616,7.929383,0.038308
3,4,0.336513,0.346264,9.332579,0.042898,9.15672,0.043285
4,5,0.296291,0.301103,10.399074,0.047979,10.390893,0.049095
5,6,0.262401,0.257461,11.397511,0.057491,11.808893,0.057339
6,7,0.240838,0.23097,12.086881,0.066947,12.89053,0.064594
7,8,0.213393,0.196471,12.998819,0.071901,14.243113,0.072355
8,9,0.183583,0.170808,14.289798,0.113711,15.495139,0.083101
9,10,0.163348,0.174614,15.204335,0.106222,14.943835,0.077652


## Train with mlflow logger callbacks

In [11]:
## Logger callback writing to the mlflow directory of this example.
## The logged learning rate is taken from `lr`, the same variable that
## configured the optimizer, so the recorded parameter cannot drift from the
## value actually used.
mlflow_logger_callback = MlflowLogger(
    name='mlflow_logger',
    mlflow_uri=MLFLOW_DIR / 'mlruns',
    mlflow_experiment_name='Default',
    mlflow_start_run_args={'run_name': 'example_train', 'nested': True},
    mlflow_log_params_args={
        'lr': lr
    },
)

## Drop the previous trainer before building its replacement
del trainer

## NOTE: `model` and `optimizer` are the objects already trained in the
## previous run, so this run continues from that state rather than from a
## fresh initialization.
trainer = Trainer(
    model = model,
    optimizer = optimizer,
    backprop_loss = nn.L1Loss(),
    dataset = cds,
    batch_size = 16,
    epochs = 10,
    patience = 5,
    callbacks=[mlflow_logger_callback],
    metrics={'psnr': PSNR(_metric_name="psnr"), 'ssim': SSIM(_metric_name="ssim")},
    device = 'cuda'
)

trainer.train()

# wGaN GP example with mlflow logger callback and plot callback

In [12]:
## Learning rates, defined once so the optimizers and the mlflow params
## logged below always agree
gen_lr = 0.0002
disc_lr = 0.00002

## Generator: single channel in, single channel out
generator = UNet(
    n_channels=1,
    n_classes=1
)

## Discriminator takes 2 input channels
## (presumably the input and target/generated channel stacked; confirm against GlobalDiscriminator)
discriminator = GlobalDiscriminator(
    n_in_channels = 2,
    n_in_filters = 64,
    _conv_depth = 4,
    _pool_before_fc = True
)

generator_optimizer = optim.Adam(generator.parameters(),
                                 lr=gen_lr,
                                 betas=(0., 0.9))
discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                     lr=disc_lr,
                                     betas=(0., 0.9),
                                     weight_decay=0.001)

## wGaN losses; the gradient penalty is given the discriminator and a weight of 10
gp_loss = GradientPenaltyLoss(
    _metric_name='gp_loss',
    discriminator=discriminator,
    weight=10.0,
)

gen_loss = GeneratorLoss(
    _metric_name='gen_loss'
)

disc_loss = DiscriminatorLoss(
    _metric_name='disc_loss'
)

## Callbacks: mlflow logging plus intermediate patch plots
mlflow_logger_callback = MlflowLogger(
    name='mlflow_logger',
    mlflow_uri=MLFLOW_DIR / 'mlruns',
    mlflow_experiment_name='Default',
    mlflow_start_run_args={'run_name': 'example_train_wgan', 'nested': True},
    mlflow_log_params_args={
        'gen_lr': gen_lr,
        'disc_lr': disc_lr
    },
)

plot_callback = IntermediatePatchPlot(
    name='plotter',
    path=PLOT_DIR,
    dataset=pds, # give it the patch dataset as opposed to the cached dataset
    plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],
    figsize=(20, 25),
    show_plot=False,
)

wgan_trainer = WGaNTrainer(
    dataset=cds,
    batch_size=16,
    epochs=20,
    patience=20,
    device='cuda',
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=generator_optimizer,
    disc_optimizer=discriminator_optimizer,
    generator_loss_fn=gen_loss,
    discriminator_loss_fn=disc_loss,
    gradient_penalty_fn=gp_loss,
    discriminator_update_freq=1,
    generator_update_freq=2,
    callbacks=[mlflow_logger_callback, plot_callback],
    metrics={'ssim': SSIM(_metric_name='ssim'),
             'psnr': PSNR(_metric_name='psnr')},
)

wgan_trainer.train()