In [1]:
# Automatically re-import edited project modules (e.g. ModelFusion.*) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
# Make the project package importable and work from the project root.
# The notebook's own directory is remembered on first execution, so re-running
# this cell always lands in the same directory instead of climbing one level
# higher each time.
import sys
import os

# Override with the TESTING_PYTHON_ROOT environment variable on other machines.
path = os.environ.get("TESTING_PYTHON_ROOT", "D:/testings/Python/TestingPython/")
if path not in sys.path:
    sys.path.append(path)

NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))
os.getcwd()

'D:\\testings\\Python\\TestingPython\\ModelFusion'

In [3]:
import torch
import pytorch_lightning as pl

from torch.utils.data import Subset
from ModelFusion.fusion.fuse_u_net import fuse_u_nets
from ModelFusion.helpers.load_data import load_data
from ModelFusion.helpers.load_model import reload_model
from ModelFusion.helpers.pl_helpers import TrainSeg, get_logger
from ModelFusion.helpers.pl_callbacks import ValVisualizationSeg
from pytorch_lightning.callbacks import ModelCheckpoint
from monai.utils import CommonKeys

In [4]:
args = {}
# input args
args["num_retrain_samples"] = 5  # size of the training subset used for retraining
args["num_retrain_epochs"] = 1
args["save_dir"] = "./"
args["vendor"] = "A"
args["seed"] = 42  # makes retraining runs reproducible

# fixed args
args["ds_name"] = "MNMS"
args["model_name"] = "UNet"
args["log_name"] = "u_net_eval_logs"
args["accelerator"] = "gpu" if torch.cuda.is_available() else "cpu"

# Seed python/numpy/torch RNGs (and dataloader workers) before models and data are created.
pl.seed_everything(args["seed"], workers=True);

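## Fuse two UNets

Load two trained UNet checkpoints and fuse them into a single UNet. The fused model is then retrained on a small subset of the training data and evaluated on the validation split.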

In [5]:
%%capture
model1_path = "./seg_logs/2022_11_22_02_40_15_864295/"
model2_path = "./seg_logs/2022_11_22_02_48_26_706789/"
model1 = reload_model(args["model_name"], model1_path)
model2 = reload_model(args["model_name"], model2_path)
model_fused = fuse_u_nets(model1_path, model2_path)

In [6]:
# model_fused

## Load Data

In [7]:
# Load the training and validation splits for the configured dataset and vendor.
train_ds, val_ds = (
    load_data(args["ds_name"], split, vendor=args["vendor"])
    for split in ("train", "val")
)

Loading dataset: 100%|███████████████████████████████████████| 1740/1740 [00:09<00:00, 180.03it/s]
Loading dataset: 100%|█████████████████████████████████████████| 104/104 [00:00<00:00, 145.94it/s]


In [8]:
# Retrain on the first `num_retrain_samples` training items.
# NOTE(review): these are taken in dataset order, not randomly sampled — confirm intended.
num_retrain = args["num_retrain_samples"]
if not 0 < num_retrain <= len(train_ds):
    # Subset does not bounds-check its indices, so an invalid request would
    # otherwise only fail later, inside the training DataLoader.
    raise ValueError(
        f"num_retrain_samples must be between 1 and {len(train_ds)}, got {num_retrain}"
    )
train_indices = list(range(num_retrain))
train_ds_subset = Subset(train_ds, train_indices)

In [9]:
# Subset's length is just the number of indices it was given, so guard against
# a subset that claims more items than the underlying training split holds.
assert len(train_ds_subset) <= len(train_ds), "retraining subset is larger than the training split"
len(train_ds_subset), len(train_ds)

(5, 1740)

In [10]:
# Number of items in the validation split.
num_val_items = len(val_ds)
num_val_items

104

In [11]:
# Datasets handed to the Lightning module, keyed by split name.
ds_dict = dict(train=train_ds_subset, val=val_ds)

## Models

In [12]:
# Choose which network to retrain: the fused model (default) or either parent UNet
# as a baseline. Set args["model_choice"] instead of (un)commenting lines.
candidate_models = {"fused": model_fused, "model1": model1, "model2": model2}
model = candidate_models[args.get("model_choice", "fused")]
lit_model = TrainSeg(ds_dict, model.to(torch.device("cpu")))

In [13]:
logger, time_stamp = get_logger(args["save_dir"], args["log_name"])

# Checkpoints go to <save_dir>/<log_name>/<time_stamp>/checkpoints/ and are
# ranked by the logged `val_dsc` metric (higher is better).
checkpoint_dir = os.path.join(args["save_dir"], args["log_name"], time_stamp, "checkpoints/")
callbacks = [
    ValVisualizationSeg(save_interval=1),
    ModelCheckpoint(dirpath=checkpoint_dir, monitor="val_dsc", mode="max"),
]

In [14]:
# trainer = pl.Trainer(
#     accelerator="gpu",
#     devices=1,
#     fast_dev_run=2
# )

In [15]:
# trainer.fit(lit_model)

In [16]:
# num_sanity_val_steps=-1 runs every validation batch once before training starts.
trainer_kwargs = dict(
    accelerator=args["accelerator"],
    devices=1,
    logger=logger,
    callbacks=callbacks,
    num_sanity_val_steps=-1,
    max_epochs=args["num_retrain_epochs"],
    # precision=16,
)
trainer = pl.Trainer(**trainer_kwargs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [17]:
# Retrain the selected model on the training subset.
trainer.fit(model=lit_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | model   | UNet       | 41.5 K
1 | loss_fn | DiceCELoss | 0     
---------------------------------------
41.5 K    Trainable params
0         Non-trainable params
41.5 K    Total params
0.166     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [18]:
# Passing the model without `ckpt_path` makes Lightning validate its current
# (last-epoch) weights, not the best checkpoint kept by ModelCheckpoint;
# pass `ckpt_path="best"` to score that checkpoint instead.
val_metrics = trainer.validate(model=lit_model)
val_metrics

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

───────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
───────────────────────────────────────────────────────────────────────────────────────────────────
         val_dsc            0.5675495266914368
        val_loss            0.3568689227104187
───────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.3568689227104187, 'val_dsc': 0.5675495266914368}]