In [1]:
from pathlib import Path
from datetime import datetime

import torch 
from torch.utils.data import ConcatDataset
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
from datetime import datetime 
#
from medical_diffusion.data.datamodules import SimpleDataModule
from medical_diffusion.data.datasets import CXPDataset_Loader, CXPDataset

In [4]:
# --------------- Settings --------------------
# Local import: only this cell needs `os`, for the data-root override below.
import os

# Every run writes into its own timestamped directory under ./runs.
current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
path_run_dir = Path.cwd() / 'runs' / current_time
path_run_dir.mkdir(parents=True, exist_ok=True)

# [0] when CUDA is available, otherwise None.
gpus = [0] if torch.cuda.is_available() else None

# Dataset location. Set the CXP_ROOT_DIR environment variable to point the
# notebook at another copy of the data; when it is unset, the original
# author's local path is used, so existing behaviour is unchanged.
cxp_root_dir = os.environ.get("CXP_ROOT_DIR", r"C:\Users\mhr_k\Data\CheXpert-Simp")

# NOTE(review): `windows=True` presumably switches the loader to Windows-style
# path handling -- confirm before running this on another operating system.
dataset = CXPDataset_Loader(root_dir = cxp_root_dir, windows=True)

train:134075|val:44134|test:44584


In [5]:
# Full-size datasets: wrap each split exposed by the loader in a CXPDataset.
# NOTE(review): the following cell rebinds ds_train / ds_val / ds_test / dm to
# small slices, so whichever of the two cells ran last decides what is trained on.
ds_train = CXPDataset(data_list = dataset.train_ds)
ds_val = CXPDataset(data_list = dataset.val_ds)
ds_test = CXPDataset(data_list = dataset.test_ds)

dm = SimpleDataModule(
    ds_train = ds_train,
    ds_val = ds_val,
    ds_test = ds_test,  # previously built but never handed to the datamodule
    batch_size=8, 
    # num_workers=0,
    pin_memory=True
) 

In [5]:
# Small debugging subset: the first 100 train / 10 val / 10 test entries,
# each dataset constructed with image_resize=64.
ds_train = CXPDataset(data_list=dataset.train_ds[:100], image_resize=64)
ds_val = CXPDataset(data_list=dataset.val_ds[:10], image_resize=64)
ds_test = CXPDataset(data_list=dataset.test_ds[:10], image_resize=64)

# Bundle the three splits into one datamodule (num_workers left at its default).
dm = SimpleDataModule(
    batch_size=8,
    pin_memory=True,
    ds_train=ds_train,
    ds_val=ds_val,
    ds_test=ds_test,
)

In [6]:
# Peek at the first training sample. In the recorded output this is a dict
# holding a 'source' tensor and an integer 'target'.
ds_train[0]

{'source': tensor([[[-0.9765, -0.5922, -0.6235,  ..., -0.9922, -0.9922, -0.9922],
          [-0.9451, -0.6078, -0.5686,  ..., -0.9922, -0.9922, -0.9922],
          [-1.0000, -0.6706, -0.6078,  ..., -1.0000, -1.0000, -1.0000],
          ...,
          [ 0.2706,  0.2941,  0.3333,  ...,  0.8745,  0.8588,  0.8431],
          [ 0.3333,  0.3569,  0.3882,  ...,  0.8588,  0.8510,  0.8353],
          [ 0.3804,  0.3961,  0.4275,  ...,  0.8431,  0.8353,  0.8275]]]),
 'target': 1}

In [7]:
# Sanity check: number of samples in the training dataset currently bound to ds_train.
len(ds_train)


100

In [8]:
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN

In [9]:
import numpy as np 
import torch
import torchvision.transforms.functional as tF
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import ImageFolder
from torch.utils.data import TensorDataset, Subset

from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.functional import multiscale_structural_similarity_index_measure as mmssim

from medical_diffusion.models.embedders.latent_embedders import VAE

In [10]:
# Build the VAE from medical_diffusion's latent embedders.
model = VAE(
    # Data layout: 2-D inputs, one channel in and one channel out.
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    emb_channels=8,
    # Per-stage settings; the three lists have four entries each and are
    # presumably aligned by position -- TODO confirm against the VAE class.
    hid_chs=[64, 128, 256, 512],
    kernel_sizes=[3, 3, 3, 3],
    strides=[1, 2, 2, 2],
    use_attention='none',
    deep_supervision=1,
    # Loss: the MSELoss class itself is passed (not an instance).
    loss=torch.nn.MSELoss,
    embedding_loss_weight=1e-6,
    # optimizer_kwargs={'lr':1e-6},
)

Setting up [baseline] perceptual loss: trunk [vgg], v[0.1], spatial [off]


In [12]:
# ---------------- Monitoring / checkpoint settings ----------------
# One metric key shared by both callbacks so they cannot drift apart.
# "val/loss" is the key the checkpoint callback was already monitoring;
# the early-stopping callback previously pointed at a differently spelled
# key ("val_loss") and this variable was not used at all.
to_monitor = "val/loss"  # alternative: "train/L1"
min_max = "min"  # lower is better for the monitored loss
save_and_sample_every = 4  # in training steps; also used as the logging interval

# Built for convenience but NOT active: it only takes effect when added to
# the Trainer's `callbacks` list below.
early_stopping = EarlyStopping(
    monitor=to_monitor,
    min_delta=0.0, # minimum change in the monitored quantity to qualify as an improvement
    patience=30, # number of checks with no improvement
    mode=min_max
)

# Checkpoints go to the run directory: the 5 best by `to_monitor`, plus the last one.
checkpointing = ModelCheckpoint(
    dirpath=str(path_run_dir), # dirpath
    monitor=to_monitor,
    every_n_train_steps=save_and_sample_every,
    save_last=True,
    save_top_k=5,
    mode=min_max,
)

# NOTE(review): training is pinned to a single CPU device here even though the
# datamodule is created with pin_memory=True -- confirm whether CPU-only is intended.
trainer = Trainer(
    accelerator='cpu',
    devices=1,
    # precision=16,
    # amp_backend='apex',
    # amp_level='O2',
    # gradient_clip_val=0.5,
    default_root_dir=str(path_run_dir),
    callbacks=[checkpointing],
    # callbacks=[checkpointing, early_stopping],
    enable_checkpointing=True,
    check_val_every_n_epoch=1,
    log_every_n_steps=save_and_sample_every, 
    auto_lr_find=False,
    # limit_train_batches=1000,
    limit_val_batches=50, # 0 = disable validation - Note: Early Stopping no longer available 
    min_epochs=2,
    max_epochs=5,
    num_sanity_val_steps=2,
)

# ---------------- Execute Training ----------------
trainer.fit(model, datamodule=dm)

# ------------- Save path to best model -------------
# Hands the logger directory and the best checkpoint path to the model's own helper.
model.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name      | Type                         | Params
-----------------------------------------------------------
0 | loss_fct  | MSELoss                      | 0     
1 | perceiver | LPIPS                        | 14.7 M
2 | inc       | UnetResBlock                 | 38.0 K
3 | encoders  | ModuleList                   | 7.7 M 
4 | out_enc   | Sequential                   | 74.0 K
5 | quantizer | DiagonalGaussianDistribution | 0     
6 | inc_dec   | UnetResBlock                 | 2.4 M 
7 | decoders  | ModuleList                   | 3.1 M 
8 | outc      | BasicBlock                   | 65    
9 | outc_ver  | ModuleList                   | 129   
-----------------------------------------------------------
13.4 M    Trainable params
14.7 M    Non-trainable params
28.

Epoch 0:   0%|          | 0/12 [00:00<?, ?it/s]

Epoch 4: 100%|██████████| 12/12 [10:02<00:00, 50.24s/it, loss=2.2e+04, v_num=1]   
