In [1]:
import os
import yaml
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import torch
from torchvision.utils import save_image
# from torchvision import transforms
from monai import transforms as mt

from data_loaders_l import SynthesisDataModule, NiftiSynthesisDataset
from model_architectures import DDPM, UNet, MonaiDDPM 

torch.set_float32_matmul_precision('medium')
# Load configuration
with open('config_l.yaml', 'r') as f:
    config = yaml.safe_load(f)

pl.seed_everything(config.get('seed', 42), workers=True)
# Extract parameters from config
batch_size = config['batch_size']
learning_rate = config['learning_rate']
num_epochs = config['num_epochs']
label_dim = config.get('label_dim', 4)
experiment_name = config['experiment_name']
model_type = config['model_type']
resize_dim = config.get('resize_dim', False) #set false for no resizing
# Prepare output directories
experiment_path = os.path.join('experiments', experiment_name)
os.makedirs(experiment_path, exist_ok=True)
# Save a copy of the config for reproducibility
with open(os.path.join(experiment_path, 'config.yaml'), 'w') as out_f:
    yaml.dump(config, out_f)

# Define the root directory
root_dir = config['root_dir']
data_dir = config['data_dir']
full_data_path = os.path.join(root_dir, data_dir)

Error importing huggingface_hub.file_download: 'Version'


Seed set to 42


Loading data: embed


# Load transforms and the dataset

In [2]:

train_transforms = mt.Compose(
    [
        mt.LoadImaged(keys=["image"]),
        mt.SqueezeDimd(keys=["image"], dim=-1), # (H,W,1) → (H,W)
        mt.EnsureChannelFirstd(keys=["image"]), # (1,H,W)
        mt.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        mt.ToTensord(keys=["image"]),

    ]
)

dataset = NiftiSynthesisDataset(full_data_path, transform=train_transforms)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
img_batch, _ = next(iter(train_loader))
print(img_batch.shape, img_batch.min().item(), img_batch.max().item())

torch.Size([8, 1, 256, 256]) 0.0 1.0


In [7]:
train_transforms = mt.Compose(
    [
        mt.LoadImaged(keys=["image"]),
        mt.SqueezeDimd(keys=["image"], dim=-1), # (H,W,1) → (H,W)
        mt.Resized(keys=["image"], spatial_size=[64, 64], mode="bilinear"),
        mt.EnsureChannelFirstd(keys=["image"]), # (1,H,W)
        mt.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        mt.ToTensord(keys=["image"]),

    ]
)


dataset = NiftiSynthesisDataset(full_data_path, transform=train_transforms)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
img_batch, _ = next(iter(train_loader))
print(img_batch.shape, img_batch.min().item(), img_batch.max().item())

torch.Size([8, 1, 256, 64, 64]) 0.0 1.0


In [8]:
train_transforms = mt.Compose(
    [
        mt.LoadImaged(keys=["image"]),
        mt.SqueezeDimd(keys=["image"], dim=-1), # (H,W,1) → (H,W)
        mt.EnsureChannelFirstd(keys=["image"]), # (1,H,W)
        mt.Resized(keys=["image"], spatial_size=[64, 64], mode="bilinear"),
        mt.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        mt.ToTensord(keys=["image"]),

    ]
)


dataset = NiftiSynthesisDataset(full_data_path, transform=train_transforms)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
img_batch, _ = next(iter(train_loader))
print(img_batch.shape, img_batch.min().item(), img_batch.max().item())

torch.Size([8, 1, 64, 64]) 0.0 1.0


In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(17, 17))
for i in range(16):
    plt.subplot(9, 9, 1 + i)
    plt.axis('off')
    plt.imshow(samples[i].squeeze(0).clip(0, 1).data.cpu().numpy(),
               cmap='gray')

# Load the model!!

In [3]:
model = MonaiDDPM(lr=learning_rate, T=1000)

# Set up lightning functionalities

In [4]:
tb_logger = pl_loggers.TensorBoardLogger('logs/', name=experiment_name)

# Set up callbacks
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(experiment_path, "checkpoints"),
    filename=f"{model_type}-{{epoch:02d}}-{{step}}",
    auto_insert_metric_name=True,
    save_top_k=1,
    monitor="train_loss",
    mode="min",
)
early_stopping = EarlyStopping(
    monitor="train_loss",
    patience=15,
    mode="min",
    check_on_train_epoch_end=True,
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Set up Trainer
trainer = pl.Trainer(
    fast_dev_run=True, #set to true for tests
    max_epochs=num_epochs,
    accelerator="auto",
    precision=16,
    logger=tb_logger,
    callbacks=[checkpoint_callback, early_stopping, lr_monitor],
    enable_progress_bar=True,
    num_sanity_val_steps=0,
    gradient_clip_val=1.0,
)


/home/locolinux2/miniconda3/envs/U24/lib/python3.9/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


# Do the training

In [5]:
trainer.fit(model, train_dataloaders=train_loader)
# trainer.fit(model, data_module) # train_dataloaders=train_loader)

print('Training complete!')


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | unet      | DiffusionModelUNet | 18.5 M | train
1 | scheduler | DDPMScheduler      | 0      | train
---------------------------------------------------------
18.5 M    Trainable params
0         Non-trainable params
18.5 M    Total params
74.004    Total estimated model params size (MB)
203       Modules in train mode
0         Modules in eval mode
/home/locolinux2/miniconda3/envs/U24/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

`Trainer.fit` stopped: `max_steps=1` reached.


Training complete!
