In [1]:
import argparse
import datetime
import os
import sys

In [2]:
import numpy as np

In [25]:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset

In [26]:
cd D:\Lung_Cancer_Diagnostic

D:\Lung_Cancer_Diagnostic


In [31]:
cd D:\Lung_Cancer_Diagnostic\LunaDataset

D:\Lung_Cancer_Diagnostic\LunaDataset


In [36]:
import import_ipynb
import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from dsets import LunaDataset
from util.logconf import logging
from model import LunaModel

In [40]:
log = logging.getLogger(__name__)
# Verbosity for this notebook's logger; switch to logging.INFO or logging.WARN
# for quieter output.
log.setLevel(logging.DEBUG)

In [41]:
# Row indices into the (METRICS_SIZE, num_samples) metrics tensors that
# computeBatchLoss fills: label, predicted probability of class 1, loss.
METRICS_LABEL_NDX, METRICS_PRED_NDX, METRICS_LOSS_NDX = range(3)
METRICS_SIZE = 3  # number of metric rows listed above

In [None]:
class LunaTrainingApp:
    def __int__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]
        
        parser = argparse.ArgumentParser()
        parser.add_argument('--num-workers',
                           help='Number of worker processes for background data loding',
                           default=8,
                           type=int,
                           )
        parser.add_argument('--batch-size',
                            help='Batch size to use for training',
                            default=32,
                            type=int,
                           )
        parser.add_argument('--epochs',
                           help='Number of epochs to train for',
                           default=1,
                           type=int,
                           )
        parser.add_argument('--tb-prefix',
                           default='p2ch11',
                           help="Data prefix to use for Tensorboard run. Defaults to chapter.",
                           )
        parser.add_argument('comment',
                           help='Comment suffix for Tensorboard run.',
                           nargs='?',
                           default='dwlpt',
                           )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
        
        self.trn_writer = None
        self.val_writer = None
        self.totalTrainingSamples_count = 0
        
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        
        self.model = self.initModel()
        self.optimizer = self.initOptimizer()
        
    def initModel(self):
        model = LunaModel()
        if self.use_cuda:
            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)
            model = model.to(self.device)
        return model
    
    def initOptimizer(self):
        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
    
    def initTrainDl(self):
        train_ds = LunaDataset(
            val_stride = 10,
            isValSet_bool = False,
        )
        
        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()
            
        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )
        
        return train_dl
    
    def initValDl(self):
        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )
        
        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()
            
        val_dl = DataLoader(
            val_ds,
            batch_size = batch_size,
            num_workers = self.cli_args.num_workers,
            pin_memory = self.use_cuda,
        )
        
        return val_dl
    
    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
            
            self.trn_writer = SummaryWriter(
                log_dir = log_dir + '-trn_cls-' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir = log_dir + '-val_cls-' + self.cli_args.comment)
            
    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
        
        train_dl = self.initTrainDl()
        val_dl = self.initValDl()
        
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            
            log.info('Epoch {} of {}, {}/{} batches of size {}*{}'.format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))
            
            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
            
            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
            self.logMetrics(epoch_ndx, 'val', valMetrics_t)
            
        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.val_writer.close()
            
    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        trnMetrics_g = torch.zeros(
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )
        
        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx = train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()
            
            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                train_dl.batch_size,
                trnMetrics_g
            )
            
            loss_var.backward()
            self.optimizer.step()
            
            # This is for adding the model graph to TensorBoard.
            # if epoch_ndx == 1 and batch_ndx == 0:
            #      with torch.no_grad():
            #            model = LunaModel()
            #            self.trn_writer.add_graph(model, batch_tup[0], verbose=True)
            #            self.trn_writer.close()
            
        self.totalTrainingSamples_count += len(train_dl.dataset)
        
        return trnMetrics_g.to('cpu')
    
    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            self.model.eval()
            valMetrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )
            
            batch_iter = enumerateWithEstimate(
                    val_dl,
                    "E{} Validation ".format(epoch_ndx),
                    start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
                
        return valMetrics_g.to("cpu")

    
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
        input_t, label_t, _series_list, _center_list = batch_tup
        
        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)
        
        logits_g, probability_g = self.model(input_g)
        
        loss_func = nn.CrossEntropyLoss(reduction='none')
        loss_g = loss_func(
            logits_g,
            label_g[:,1],
        )
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)
        
        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = \
            label_g[:,1].detach()
        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = \
            probability_g[:,1].detach()
        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = \
            loss_g.detach()
        
        return loss_g.mean()
    
    def logMetrics(
        ## 작성하기
    )