In [1]:
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

import ct_utils
from unet import UNet
from volume_dataloader import CTScanDataModule, CTDataSet

# Root directory of the dataset handed to CTScanDataModule.
# Set the CT_DATA_PATH environment variable to point the notebook at a
# different location; the original machine-specific path stays the default.
DATA_PATH = os.environ.get('CT_DATA_PATH', '/media/gaetano/DATA/DATA_NIFTI_JAWS/')

In [2]:
# Data module built from the files under DATA_PATH, using batches of 5.
dataset = CTScanDataModule(DATA_PATH, batch_size=5)

# U-Net configured for 3-dimensional input (dim=3) with one input channel
# and four output channels.
# NOTE(review): loss_alpha / loss_beta / loss_gamma are forwarded to the
# model's loss; their exact meaning is defined in `unet` — confirm there.
model = UNet(in_channels=1,
             out_channels=4,
             n_blocks=4,
             start_filters=32,
             activation='relu',
             normalization='batch',
             conv_mode='same',
             dim=3,
             loss_alpha=.7,
             loss_beta=.3,
             loss_gamma=3/4,
             learning_rate=1e-3)

# Callbacks: log the learning rate once per epoch and checkpoint on `val_loss`.
lr_monitor = LearningRateMonitor(logging_interval='epoch')
checkpoint = ModelCheckpoint(monitor='val_loss')
wandb_logger = WandbLogger()

# The LR monitor must be registered with the trainer to have any effect;
# previously it was instantiated but left out of `callbacks`.
trainer = Trainer(gpus=-1,
                  log_every_n_steps=1,
                  max_epochs=50,
                  auto_lr_find=False,
                  callbacks=[checkpoint, lr_monitor],
                  logger=wandb_logger)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [4]:
# Run the training/validation loop; the data module is passed by keyword so
# it is explicit that `dataset` is a DataModule rather than a dataloader.
trainer.fit(model, datamodule=dataset)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | down_convs | ModuleList | 3.5 M 
1 | up_convs   | ModuleList | 2.1 M 
2 | conv_final | Conv3d     | 132   
------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.412    Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  f"DataModule.{name} has already been called, so it will not be called again. "


In [5]:
# Path of the best checkpoint recorded by the ModelCheckpoint callback.
# NOTE(review): this path is only stored here; the weights used below are the
# ones currently held by `model`, not reloaded from this file — confirm that
# is intended.
check_model_path = checkpoint.best_model_path

# Switch to evaluation mode before predicting.
model.eval()

# First sample of the validation split.
x, y = dataset.ct_val[0]

# Forward pass without autograd bookkeeping: gradients are not needed for a
# prediction, and skipping the graph saves memory on volumetric input.
# unsqueeze(0) adds the leading batch axis.
with torch.no_grad():
    y_hat = model(x.unsqueeze(0))

# Convert both volumes to NumPy for plotting.
x = x.squeeze(0).numpy()
# squeeze(0) drops the batch axis; sum(axis=0) then collapses the next axis
# (presumably the output-channel axis — TODO confirm) into a single volume.
y_hat = y_hat.squeeze(0).sum(axis=0).detach().numpy()

In [6]:
%matplotlib qt
ct_utils.plot_3d_with_labels(x, y_hat, threshold=1400, transpose=[2, 1, 0], step_size=2)

ValueError: Surface level must be within volume data range.

In [3]:
# NOTE(review): this cell carries execution count In [3], i.e. it was run
# before the training cell, but it sits at the end of the notebook. Move it
# above `trainer.fit` so that Restart & Run All reproduces the recorded run.

# trainer.tune(model, dataset)
# Run learning rate finder
lr_finder = trainer.tuner.lr_find(model, dataset)

# The raw sweep (learning rates and losses) is available in `lr_finder.results`.

# Plot the sweep with the suggested learning rate marked.
fig = lr_finder.plot(suggest=True)
fig.show()

# Pick point based on plot, or get suggestion
new_lr = lr_finder.suggestion()

# Update hparams of the model, but only when a suggestion was produced:
# `suggestion()` can return None when it fails, and writing None into the
# hyperparameters would break the optimizer setup later on.
if new_lr is not None:
    model.hparams.learning_rate = new_lr