<a href="https://colab.research.google.com/github/44REAM/radianet/blob/master/main_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [1]:

!pip install optuna
!pip install pytorch_lightning
!pip install efficientnet_pytorch

%cd '/content/drive/My Drive/deeplearning'
!rm -r radianet
!git clone https://github.com/44REAM/radianet.git
%cd radianet


/content/drive/My Drive/deeplearning
Cloning into 'radianet'...
remote: Enumerating objects: 270, done.[K
remote: Counting objects: 100% (270/270), done.[K
remote: Compressing objects: 100% (186/186), done.[K
remote: Total 270 (delta 115), reused 206 (delta 63), pack-reused 0[K
Receiving objects: 100% (270/270), 3.22 MiB | 4.56 MiB/s, done.
Resolving deltas: 100% (115/115), done.
/content/drive/My Drive/deeplearning/radianet


In [7]:
%load_ext tensorboard

import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
from torch.optim import Adam
import optuna
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping

from sklearn.metrics import confusion_matrix

from radianet import get_dataloader
from radianet.datasets import SampleDataset2D, LIDCDataset, Transforms, lidc_dataloader
from radianet.models import Simple3DCNN, MyEfficientNet
from radianet.callbacks import MetricsCallback
from radianet.metrics import binary_accuracy


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [0]:

class LightningNet(pl.LightningModule):
    def __init__(self, trial, config, model, dataloader, loss):
        super().__init__()
        self.model = model(trial, config)
        self.config = config
        self.loss = loss

        self.dataloader = dataloader
        self.hparams = self._get_hparams(trial)
        self.train_dataset = self.dataloader['train']
        self.val_dataset = self.dataloader['val']
        self.test_dataset = self.dataloader['test']

    def _get_hparams(self, trial):
        learning_rate = trial.suggest_loguniform('lr', self.config.MIN_LR, self.config.MAX_LR)
        hparams = {
            'lr': learning_rate
        }

        return hparams

    def prepare_data(self):

        try:
            for batch, _ in self.train_dataset:
                test_data = batch
                break
            self.model.calculate_linear_input(test_data)
        except AttributeError:
            print('linear input not have been calculate')

    def forward(self, data):

        return self.model(data)

    def training_step(self, batch, _):
        data, target = batch
        output = self.forward(data)
        output = output.reshape(-1)
        return {"loss": self.loss(output, target)}

    def validation_step(self, batch, _):
        data, target = batch

        output = self.forward(data)
        output = output.reshape(-1)

        acc = binary_accuracy(output, target, pos_weight=0.30769)

        return {"val_loss": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        # Pass the accuracy to the `DictLogger` via the `'log'` key.
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return Adam(self.model.parameters(), lr=self.hparams['lr'])

    def test_step(self, batch, _):
        data, target = batch
        output = self.forward(data)
        output = output.reshape(-1)

        acc = binary_accuracy(output, target, pos_weight=0.307, theshold=0.48)
        binary_confusion_matrix(output, target)

        return {"test_acc": acc}

    def test_epoch_end(self, outputs):
        test_acc_mean = torch.stack([x['test_acc'] for x in outputs]).mean()
        return {'test_acc': test_acc_mean}

    def train_dataloader(self):
        return self.train_dataset

    def val_dataloader(self):
        return self.val_dataset

    def test_dataloader(self):
        return self.test_dataset


In [0]:
class Config():
    BATCHSIZE = 32
    N_TRIALS = 1
    EPOCHS = 50
    N_CLASSES = 1
    BORDER_MODE = 4
    NAME = 'homework'
    IMAGE_SIZE = 224
    MIN_LR = 1e-6
    MAX_LR = 1e-4
    LIDC_PATH = '/content/drive/My Drive/dataset/radiology/TCIA_LIDC-IDRI/preprocess/'
    EFFICIENTNET_B0_LAYER = 0

    CHECKPOINT_DIR = '/content/drive/My Drive/deeplearning/checkpoints/'
    CHECKPOINT = CHECKPOINT_DIR + NAME
    TENSORBOARD = CHECKPOINT_DIR + NAME
    DB_NAME = 'sqlite:///' + NAME + '.db'


In [0]:
def get_logger():
    return TensorBoardLogger(Config.TENSORBOARD)

def binary_confusion_matrix(output, target, theshold = 0.48):
    output[output > theshold] = 1
    output[output <= theshold] = 0

    print(confusion_matrix(output.cpu(), target.cpu()))

def get_model_checkpoint(trial):
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        os.path.join(Config.CHECKPOINT, "trial_{}".format(trial.number), "{epoch}"), 
        monitor="val_loss", save_top_k = -1
    )
    return checkpoint_callback



In [37]:

def objective(trial):

    logger = get_logger()
    checkpoint_callback = get_model_checkpoint(trial)
    metrics_callback = MetricsCallback()

    loss = nn.BCEWithLogitsLoss(pos_weight  = torch.tensor([0.307]))
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='max'
    )
    
    # callbacks DO NOT replace the explicit callbacks (loggers, EarlyStopping or ModelCheckpoint)
    trainer = pl.Trainer(
        max_epochs=Config.EPOCHS,
        
        # all callbacks
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        early_stop_callback = None,
        callbacks=[metrics_callback],

        gpus=[0] if torch.cuda.is_available() else None
    )
    

    dataloader = lidc_dataloader(Config.LIDC_PATH, Config)

    lightning_model = LightningNet(trial, Config, MyEfficientNet, dataloader, loss)

    trainer.fit(lightning_model)
    trainer.test()

    return metrics_callback.metrics[-1]["val_loss"]

study = optuna.create_study(study_name = Config.NAME, storage = Config.DB_NAME, 
                            load_if_exists = True, direction="maximize")
study.optimize(objective, n_trials= Config.N_TRIALS)


[32m[I 2020-05-16 07:29:49,449][0m Using an existing study with name 'homework' instead of creating a new one.[0m
GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]

    | Name                                              | Type                    | Params
------------------------------------------------------------------------------------------
0   | model                                             | MyEfficientNet          | 5 M   
1   | model.head                                        | Sequential              | 928   
2   | model.head.0                                      | Conv2dStaticSamePadding | 864   
3   | model.head.0.static_padding                       | ZeroPad2d               | 0     
4   | model.head.1                                      | BatchNorm2d             | 64    
5   | model.body_no_grad                                | Sequential              | 0     
6   | model.body_grad                  

Loaded pretrained weights for efficientnet-b0
linear input not have been calculate



The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` in the `DataLoader` init to improve performance.



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…




The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` in the `DataLoader` init to improve performance.



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…





The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` in the `DataLoader` init to improve performance.



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

[[ 6 12]
 [ 9 38]]
--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': tensor(58.2076)}
--------------------------------------------------------------------------------



[32m[I 2020-05-16 07:32:24,168][0m Finished trial#11 with value: 71.3131332397461 with parameters: {'lr': 1.0012466029993057e-05}. Best is trial#7 with value: 74.65685272216797.[0m
