In [None]:
! pip install "monai==0.5.3"

In [17]:
import math

import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
import pytorch_lightning as pl
from filelock import FileLock
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
import torchvision
import os
import monai
from monai.networks.layers.factories import Act, Norm
from losses import *
from hyperopt import hp
from ray.tune.suggest.hyperopt import HyperOptSearch
import numpy as np

# from source.ray_utils import * # create_search_space, create_test_search_space

# import source.transforms as transforms
# import source.transforms.oral_cavity_transforms as transforms
# import source.losses as losses
# import deepgrow
from monai.metrics.meandice import compute_meandice
# Root directory for downloaded datasets; override with the PATH_DATASETS env var.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")
# 1 when at least one CUDA device is visible, otherwise 0.
AVAIL_GPUS = 1 if torch.cuda.device_count() >= 1 else 0

In [3]:
# Load the TensorBoard notebook extension and start it on the `logs/` directory
# (port 7991).
%load_ext tensorboard
%tensorboard --logdir logs/ --port 7991

In [4]:
class MNISTDataModule(LightningDataModule):
    """LightningDataModule serving MNIST images resized to 32x32.

    Args:
        batch_size: Batch size used by all three dataloaders.
        data_dir: Directory the MNIST files are downloaded to / read from.
        num_workers: Worker processes per dataloader (0 loads in the main process).
        val_size: Number of training images held out for validation.
        split_seed: Seed for the train/val split so the split is identical on
            every ``setup`` call; pass ``None`` for an unseeded split.
    """

    def __init__(
        self,
        batch_size=256,
        data_dir=PATH_DATASETS,
        num_workers=0,
        val_size=5000,
        split_seed=42,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.split_seed = split_seed

        # Convert to a tensor first, then resize to 32x32.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((32, 32)),
            ]
        )

    def prepare_data(self):
        """Download both MNIST splits if they are not already on disk."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets required for ``stage`` (all of them when ``None``)."""
        # "validate" needs the validation subset too, so it shares the fit branch.
        if stage in ("fit", "validate", None):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            train_size = len(mnist_full) - self.val_size
            split_kwargs = {}
            if self.split_seed is not None:
                # A dedicated generator keeps the split reproducible without
                # touching the global RNG state.
                split_kwargs["generator"] = torch.Generator().manual_seed(self.split_seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [train_size, self.val_size], **split_kwargs
            )

        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def _loader(self, dataset, shuffle=False):
        """Build a DataLoader with the module-wide batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Training loader; reshuffled every epoch."""
        return self._loader(self.mnist_train, shuffle=True)

    def val_dataloader(self):
        """Validation loader (fixed order)."""
        return self._loader(self.mnist_val)

    def test_dataloader(self):
        """Test loader (fixed order)."""
        return self._loader(self.mnist_test)

In [7]:
class LightningVAE(pl.LightningModule):
    """LightningModule wrapping MONAI's ``VarAutoEncoder`` for 1x32x32 images.

    Args:
        config: Dict providing ``lr``, ``batch_size``, ``latent_dim``,
            ``kernel_size``, ``channel``, ``stride``, ``norm``, ``dropout_rate``,
            ``num_resnets``, ``alpha`` and ``beta``.
    """

    # Keys returned by every ``*_step`` and averaged by every ``*_epoch_end``.
    METRIC_KEYS = ("loss", "dice", "avg_kl")
    # Number of validation images kept for the reconstruction preview grid.
    NUM_PREVIEW_IMAGES = 8

    def __init__(self, config):
        super().__init__()

        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.latent_dim = config["latent_dim"]

        self.model = monai.networks.nets.VarAutoEncoder(
            dimensions=2,
            kernel_size=config["kernel_size"],
            in_shape=[1, 32, 32],
            out_channels=1,
            channels=config["channel"],
            strides=config["stride"],
            latent_size=config["latent_dim"],
            norm=config["norm"],
            dropout=config["dropout_rate"],
            num_res_units=config["num_resnets"],
        )

        self.vae_loss = KLLoss(alpha=config["alpha"], beta=config["beta"])
        self.dice = Dice()

        # Fixed batch used by ``on_epoch_end`` for the preview grid. It is filled
        # from the first validation batch unless the caller assigns one first.
        self.sample_random_batch = None

    def forward(self, x):
        """Run the wrapped VarAutoEncoder on ``x`` and return its raw output."""
        return self.model(x)

    def _shared_step(self, batch):
        """Forward one batch and return ``(inputs, metrics)``.

        ``metrics`` maps every key in ``METRIC_KEYS`` to a tensor; only ``"loss"``
        keeps its autograd graph.
        """
        # MNIST batches are (image, label); the labels are not used by the VAE.
        x = batch[0] if isinstance(batch, (list, tuple)) else batch

        # Per the MONAI docs the VAE forward returns (reconstruction, mu, logvar, z).
        recon, mu, logvar, _ = self(x)

        # NOTE(review): KLLoss and Dice come from the local `losses` module, which
        # is not visible here. The call signatures below (and that KLLoss returns
        # a single scalar tensor) are assumptions -- confirm against losses.py.
        loss = self.vae_loss(recon, x, mu, logvar)
        dice = self.dice(recon, x)

        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims and
        # averaged over the batch.
        avg_kl = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean()

        metrics = {
            "loss": loss,
            # Detached: these are only logged, and Lightning otherwise warns that
            # extra outputs carry a grad_fn.
            "dice": torch.as_tensor(dice).detach(),
            "avg_kl": avg_kl.detach(),
        }
        return x, metrics

    def _log_epoch_means(self, outputs, prefix):
        """Log the mean of each metric over ``outputs`` as ``{prefix}_{key}``.

        This is the unweighted mean of the per-batch values, so a smaller final
        batch counts as much as a full one.
        """
        if not outputs:
            return
        for key in self.METRIC_KEYS:
            per_batch = [
                torch.as_tensor(out[key]).detach().float().mean() for out in outputs
            ]
            self.log(f"{prefix}_{key}", torch.stack(per_batch).mean())

    def training_step(self, train_batch, batch_idx):
        """Compute loss, dice and average KL for one training batch."""
        _, metrics = self._shared_step(train_batch)
        return metrics

    def training_epoch_end(self, outputs):
        """Average the training metrics over the epoch and log them."""
        self._log_epoch_means(outputs, "train")

    def validation_step(self, val_batch, batch_idx):
        """Compute loss, dice and average KL for one validation batch."""
        x, metrics = self._shared_step(val_batch)
        if self.sample_random_batch is None and batch_idx == 0:
            # Keep a few fixed images so previews are comparable across epochs.
            self.sample_random_batch = x[: self.NUM_PREVIEW_IMAGES].detach().clone()
        return metrics

    def validation_epoch_end(self, outputs):
        """Average the validation metrics over the epoch and log them."""
        self._log_epoch_means(outputs, "val")

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def on_epoch_end(self):
        """Every 5th epoch, log a grid of reconstructions above their inputs."""
        if self.current_epoch % 5 != 0:
            return
        # Nothing to draw before a preview batch exists, or without a logger.
        if self.sample_random_batch is None or self.logger is None:
            return

        inputs = self.sample_random_batch.to(self.device)

        # Run without autograd and in eval mode (no dropout), then restore the
        # previous train/eval mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                recon = self(inputs)[0]
        finally:
            self.train(was_training)

        # First channel only. Row 1 holds reconstructions, row 2 the inputs.
        images = torch.cat([recon[:, :1], inputs[:, :1]], dim=0).float().cpu()
        grid = torchvision.utils.make_grid(
            images,
            normalize=True,
            scale_each=True,
            nrow=inputs.shape[0],
        )
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)


In [8]:
from pytorch_lightning.loggers import TensorBoardLogger
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback,
    TuneReportCheckpointCallback,
)

In [10]:
# single run
def train_vae_single(config, num_epochs=1, num_gpus=1):
    """Train one ``LightningVAE`` on MNIST, logging to TensorBoard under ``./logs``.

    Args:
        config: Hyperparameter dict passed to ``LightningVAE``; its
            ``"batch_size"`` entry also sets the data module's batch size.
        num_epochs: Maximum number of training epochs.
        num_gpus: Number of GPUs requested. Capped at the number of visible CUDA
            devices, so the default of 1 falls back to CPU on a GPU-less machine
            instead of raising.
    """
    model = LightningVAE(config)
    data_module = MNISTDataModule(
        batch_size=config["batch_size"]
    )

    # Never ask Lightning for more GPUs than exist; 0 selects CPU training.
    available_gpus = torch.cuda.device_count()
    gpus = min(num_gpus, available_gpus) if available_gpus else 0

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=gpus,
        logger=TensorBoardLogger(save_dir="./logs"),
    )

    trainer.fit(model, data_module)

In [11]:
# Hyperparameters for one fixed (non-tuned) training run.
param = {
    "lr": 1e-5,
    "latent_dim": 256,
    "kernel_size": 3,
    "dropout_rate": 0.1,
    "alpha": 1,
    "beta": 0.01,
    "norm": Norm.INSTANCE,
    "batch_size": 256,
    "val": 3,
    "channel": (32, 64, 64),
    "stride": (1, 2, 4),
    "num_resnets": 0,
}

train_vae_single(param, num_epochs=20, num_gpus=1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type           | Params
--------------------------------------------
0 | model    | VarAutoEncoder | 899 K 
1 | vae_loss | KLLoss         | 0     
2 | dice     | Dice           | 0     
--------------------------------------------
899 K     Trainable params
0         Non-trainable params
899 K     Total params
3.598     Total estimated model params size (MB)


                                                                      

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Epoch 0:   1%|          | 2/235 [00:00<00:23,  9.85it/s, loss=1.75, v_num=28]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Epoch 0:  91%|█████████▏| 215/235 [00:13<00:01, 16.31it/s, loss=1.22, v_num=28]
Validating: 0it [00:00, ?it/s][A
Epoch 0:  92%|█████████▏| 217/235 [00:13<00:01, 16.37it/s, loss=1.22, v_num=28]
Validating:  10%|█         | 2/20 [00:00<00:00, 18.58it/s][A
Epoch 0:  94%|█████████▍| 221/235 [00:13<00:00, 16.49it/s, loss=1.22, v_num=28]
Epoch 0:  96%|█████████▌| 225/235 [00:13<00:00, 16.62it/s, loss=1.22, v_num=28]
Epoch 0:  97%|█████████▋| 229/235 [00:13<00:00, 16.74it/s, loss=1.22, v_num=28]
Validating:  70%|███████   | 14/20 [00:00<00:00, 27.50it/s][A
Epoch 0: 100%|██████████| 235/235 [00:13<00:00, 16.90it/s, loss=1.22, v_num=28]
Epoch 1:  30%|███       | 71/235 [00:04<00:10, 16.24it/s, loss=1.13, v_num=28] 

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
