In [None]:
import os
import numpy as np
import pandas as pd
import PIL.Image as Image
import glob

# Recognised file-name suffixes, grouped by format. The helpers in this file
# match suffixes case-sensitively (str.endswith), so the lower- and upper-case
# spellings are both listed, in "lower, UPPER" order.
FILE_EXTENSION = [
    variant
    for suffix in ('.png', '.jpg', '.dcm', '.raw', '.svs')
    for variant in (suffix, suffix.upper())
]
IMG_EXTENSION = [
    variant
    for suffix in ('.png', '.jpg', '.jpeg')
    for variant in (suffix, suffix.upper())
]
DCM_EXTENSION = [variant for variant in ('.dcm', '.dcm'.upper())]
RAW_EXTENSION = [variant for variant in ('.raw', '.raw'.upper())]
# Only the lower-case spelling is accepted for these two formats.
NIFTI_EXTENSION = ['.nii']
NP_EXTENSION = ['.npy']


def check_extension(filename, extension_ls=FILE_EXTENSION):
    """Tell whether `filename` ends with one of the suffixes in `extension_ls`.

    Parameters:
        filename (str) -- file name or path to test
        extension_ls (list) -- candidate suffixes (default: FILE_EXTENSION)

    Return:
        (bool) -- True as soon as one suffix matches, False if none does
                  (the comparison is case-sensitive)
    """
    for suffix in extension_ls:
        if filename.endswith(suffix):
            return True
    return False


def load_file_path(folder_path, extension_ls=FILE_EXTENSION, all_sub_folders=False):
    """Collect the paths of files in a folder whose names end with a given suffix.

    Parameters:
        folder_path (str) -- folder directory
        extension_ls (list) -- accepted file-name suffixes (default: FILE_EXTENSION)
        all_sub_folders (bool) -- if True, every sub-folder is searched as well;
            if False (default), only files directly inside 'folder_path' are returned

    Return:
        file_paths (list) -- matching file paths; folders are visited in sorted
            order and file names are sorted within each folder

    Raises:
        AssertionError -- if 'folder_path' is not an existing directory
    """
    assert os.path.isdir(folder_path), f'{folder_path} is not a valid directory'

    walker = os.walk(folder_path)
    if all_sub_folders:
        visited = sorted(walker)
    else:
        # os.walk is top-down and yields 'folder_path' itself first, so a single
        # step is enough; sorting the whole walk would traverse every sub-folder
        # only to discard the result.
        top_level = next(walker, None)
        visited = [] if top_level is None else [top_level]

    file_paths = []
    for root, _, fnames in visited:
        # Directory listing order is OS-dependent; sort for reproducible output.
        for fname in sorted(fnames):
            if check_extension(fname, extension_ls):
                file_paths.append(os.path.join(root, fname))

    return file_paths


def gen_new_dir(new_dir):
    """Create directory `new_dir` (and any missing parents) if it does not exist.

    This is best-effort: a failure is reported with print() and not raised, so
    callers that need the directory should check for it themselves.

    Parameters:
        new_dir (str) -- directory path to create
    """
    try:
        # exist_ok avoids the check-then-create race of testing os.path.exists first.
        os.makedirs(new_dir, exist_ok=True)
    except OSError as err:
        print(f"Error: Failed to create the directory. ({new_dir}: {err})")
        

def get_data_fname_label_in_split(data_df, mode='train'):
    """Select the ('name', 'good_outcome_3m') pairs belonging to one split.

    Parameters:
        data_df (pd.DataFrame) -- table with 'split', 'name' and 'good_outcome_3m' columns
        mode (str) -- value of the 'split' column to keep (default: 'train')

    Return:
        (np.ndarray) -- 2-D array with one [name, good_outcome_3m] row per selected record
    """
    in_split = data_df['split'] == mode
    selected = data_df.loc[in_split, ['name', 'good_outcome_3m']]
    return selected.values


def get_dataset(data_df, data_dir, mode='train'):
    """Pair every record of one split with its .npy file on disk.

    Parameters:
        data_df (pd.DataFrame) -- table with 'split', 'name' and 'good_outcome_3m' columns
        data_dir (str) -- folder that directly contains the '.npy' files
        mode (str) -- split to select (default: 'train')

    Return:
        (list) -- [path, label] pairs; a record is dropped when no '.npy' file
            in 'data_dir' has a base name (without extension) equal to its 'name'
    """
    fname_label_arr = get_data_fname_label_in_split(data_df, mode=mode)

    # Map "file name without extension" -> full path for quick lookup.
    np_path_dict = {
        os.path.splitext(os.path.basename(path))[0]: path
        for path in sorted(load_file_path(data_dir, NP_EXTENSION))
    }

    dataset = []
    # The original iterated over an undefined name (data_label_arr) and raised NameError.
    for fname, label in fname_label_arr:
        path = np_path_dict.get(fname)
        if path:
            dataset.append([path, label])
    return dataset

In [None]:
def normalize(arr):
    """Min-max scale an array to the range [0, 1] and return it as float32.

    Parameters:
        arr (array-like) -- numeric input; must not be empty

    Return:
        (np.ndarray) -- float32 array of the same shape. A constant input
            (max == min) gives all zeros instead of NaN from a 0/0 division.
    """
    arr = np.asarray(arr)
    if not np.issubdtype(arr.dtype, np.floating):
        # Integer arithmetic can wrap around (e.g. int8: 127 - (-128)); use floats.
        arr = arr.astype(np.float64)

    lowest = arr.min()
    span = arr.max() - lowest
    if span == 0:
        return np.zeros(arr.shape, dtype=np.float32)
    return ((arr - lowest) / span).astype(np.float32)

def img_loader(imgpath):
    """Load a '.npy' array, reverse the order of its three axes and min-max normalise it.

    Parameters:
        imgpath (str) -- path of the array file passed to np.load

    Return:
        the result of normalize() applied to the axis-reversed array
    """
    volume = np.load(imgpath)
    # Axes (0, 1, 2) become (2, 1, 0).
    reordered = np.transpose(volume, (2, 1, 0))
    return normalize(reordered)

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2


def get_training_augmentation(params=None):
    """Build the albumentations pipeline used to augment training images.

    `params` is accepted but not used. No transform is enabled at the moment,
    so the returned A.Compose wraps an empty list.
    """
    # Candidates that are currently switched off:
    #   A.HorizontalFlip(p=.5)
    #   A.VerticalFlip(p=.5)
    #   A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=5, shift_limit=0.2, border_mode=0, p=.5)
    #   A.ShiftScaleRotate(scale_limit=0.01, rotate_limit=5, shift_limit=0., border_mode=0, p=.5)
    enabled_transforms = []
    return A.Compose(enabled_transforms)


def get_preprocessing(params=None, convert=True):
    """Build the albumentations pipeline applied to every image before batching.

    Parameters:
        params -- accepted but not used
        convert (bool) -- if True (default) the pipeline holds a single
            ToTensorV2(transpose_mask=True) step; if False it is empty

    Return:
        (A.Compose) -- the composed pipeline
    """
    # A.Normalize (mean/std scaling) was tried here and is currently disabled.
    steps = [ToTensorV2(transpose_mask=True)] if convert else []
    return A.Compose(steps)


class AIHUB_GoodOutcomPredDataset(torch.utils.data.Dataset):
    """Dataset of (image, label) pairs for one split of the outcome table.

    The table is read from a CSV file and must have 'split', 'name' and
    'good_outcome_3m' columns; each 'name' is matched to a '.npy' file in
    `dataset_dir` (see get_dataset). Records without a matching file are dropped.

    Args:
        dataset_dir (str): folder containing the '.npy' files.
        dataset_df (str): path of the CSV file passed to pd.read_csv.
        img_loader (callable): maps a file path to an image array.
        augmentation (callable or None): called as augmentation(image=...) and
            expected to return a mapping with an 'image' entry. Only kept when
            mode == 'train'.
        preprocessing (callable or None): called the same way as `augmentation`.
        mode (str): value of the 'split' column to load.
    """

    def __init__(self,
                 dataset_dir,
                 dataset_df,
                 img_loader=img_loader,
                 augmentation=None,
                 preprocessing=None,
                 mode='train'
    ):
        self.dataset_dir = dataset_dir
        self.dataset_df = pd.read_csv(dataset_df)
        self.img_loader = img_loader
        self.mode = mode
        # Augmentation is a training-only step; other splits never use it.
        self.augmentation = augmentation if mode == 'train' else None
        self.preprocessing = preprocessing
        self.dataset = get_dataset(self.dataset_df, self.dataset_dir, self.mode)

    def __getitem__(self, index):
        """Return (image, label) for the record at position `index`."""
        image_path, label = self.dataset[index]
        # Use the loader given to the constructor; the original called the
        # module-level img_loader and silently ignored a custom one.
        image = self.img_loader(image_path)
        if self.augmentation:
            image = self.augmentation(image=image)['image']
        if self.preprocessing:
            image = self.preprocessing(image=image)['image']
            # Add a leading axis of size 1.
            # assumes the preprocessing output is a torch tensor -- TODO confirm
            image = torch.unsqueeze(image, 0)

        return image, label

    def __len__(self):
        """Number of records that have a matching file."""
        return len(self.dataset)

In [None]:
# Training split. Augmentation is left at the constructor default (None).
train_dataset = AIHUB_GoodOutcomPredDataset(
    '/home/ncp/workspace/blocks1/3D_CNN_for_PRED/data_np_resampled',
    '/home/ncp/workspace/blocks1/3D_CNN_for_PRED/aihub_df.csv',
    preprocessing=get_preprocessing(),
    mode='train',
)

In [None]:
# Validation split. Augmentation is left at the constructor default (None).
val_dataset = AIHUB_GoodOutcomPredDataset(
    '/home/ncp/workspace/blocks1/3D_CNN_for_PRED/data_np_resampled',
    '/home/ncp/workspace/blocks1/3D_CNN_for_PRED/aihub_df.csv',
    preprocessing=get_preprocessing(),
    mode='val',
)

In [None]:
# Test split. Augmentation is left at the constructor default (None).
test_dataset = AIHUB_GoodOutcomPredDataset(
    '/home/ncp/workspace/blocks1/3D_CNN_for_PRED/data_np_resampled',
    '/home/ncp/workspace/blocks1/3D_CNN_for_PRED/aihub_df.csv',
    preprocessing=get_preprocessing(),
    mode='test',
)

In [None]:
import time
import torch

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Only used by save_checkpoint.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss and update `counter` / `early_stop`.

        `model` is accepted for use with save_checkpoint, but checkpoint saving
        is currently disabled here, so it is not touched.
        """
        # Higher score = better, so the comparisons below read as "improvement".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            #self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least `delta`.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            #self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model state_dict to `path` and records `val_loss` as the new minimum.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
class AverageMeter(object):
    """
    Tracks the latest value of a quantity together with its weighted running mean.
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py

    Attributes:
        val   -- value passed to the most recent update()
        sum   -- sum of val * n over all updates
        count -- sum of n over all updates
        avg   -- sum / count
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
class FocalLoss(torch.nn.Module):
    """Focal loss on class logits: -(1 - pt) ** gamma * alpha_t * log(pt).

    `pt` is the softmax probability of the target class; it is detached, so the
    (1 - pt) ** gamma factor acts as a weight and is not differentiated.

    Args:
        gamma (float): focusing exponent; 0 with alpha=None reduces to cross-entropy.
        alpha (float, int, list, Tensor or None): per-class weights. A scalar `a`
            becomes [a, 1 - a]; a list or tensor is indexed by class id;
            None disables weighting.
        size_average (bool): return the mean (True) or the sum (False) of the
            per-element losses.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.tensor([alpha, 1 - alpha], dtype=torch.float32)
        if isinstance(alpha, list):
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the loss for logits `input` (N, C) or (N, C, ...) and integer class ids `target`."""
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        # The class axis is 1 after the reshape above; pass it explicitly
        # instead of relying on the deprecated implicit-dim behaviour.
        logpt = torch.nn.functional.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Match the logits' device and dtype without mutating the module.
            alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            logpt = logpt * alpha.gather(0, target.view(-1))

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

In [None]:
def train_epoch(model, loader, optimizer, epoch, n_epochs, print_freq=100):
    """Run one optimisation pass over `loader` using cross-entropy loss.

    Args:
        model: network mapping a batch of inputs to per-class scores (N, C).
        loader: iterable of (input, target) batches.
        optimizer: optimizer stepped once per batch.
        epoch (int): zero-based epoch index, used for printing only.
        n_epochs (int): total number of epochs, used for printing only.
        print_freq (int): print statistics every `print_freq` batches.

    Returns:
        (avg_batch_time, avg_loss, avg_error): means over the epoch; the error
        is the fraction of samples whose top-1 prediction differs from the target.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    error = AverageMeter()

    use_cuda = torch.cuda.is_available()

    # Model on train mode. Only move it to the GPU when one exists, matching
    # how the batches are handled below.
    if use_cuda:
        model.cuda()
    model.train()

    end = time.time()
    for batch_idx, (inputs, target) in enumerate(loader):
        # Create variables
        if use_cuda:
            inputs = inputs.cuda()
            target = target.cuda()

        # compute output
        output = model(inputs)
        #loss = criterion(output, target)
        loss = torch.nn.functional.cross_entropy(output, target)

        # measure accuracy and record loss
        batch_size = target.size(0)
        _, pred = output.detach().cpu().topk(1, dim=1)
        # squeeze(1) drops only the top-k axis, so the shape stays (N,) for any N.
        n_wrong = torch.ne(pred.squeeze(1), target.cpu()).float().sum().item()
        error.update(n_wrong / batch_size, batch_size)
        losses.update(loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print stats
        if batch_idx % print_freq == 0:
            res = '\t'.join([
                'Epoch: [%d/%d]' % (epoch + 1, n_epochs),
                'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                'Time %.3f (%.3f)' % (batch_time.val, batch_time.avg),
                'Loss %.4f (%.4f)' % (losses.val, losses.avg),
                'Error %.4f (%.4f)' % (error.val, error.avg),
            ])
            print(res)

    # Return summary statistics
    return batch_time.avg, losses.avg, error.avg

In [None]:
def test_epoch(model, loader, print_freq=10, is_test=True):
    """Evaluate `model` on `loader` with cross-entropy loss, without gradients.

    Args:
        model: network mapping a batch of inputs to per-class scores (N, C).
        loader: iterable of (input, target) batches.
        print_freq (int): print statistics every `print_freq` batches.
        is_test (bool): selects the 'Test' or 'Valid' tag in the printed lines.

    Returns:
        (avg_batch_time, avg_loss, avg_error): means over the loader; the error
        is the fraction of samples whose top-1 prediction differs from the target.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    error = AverageMeter()

    use_cuda = torch.cuda.is_available()

    # Model on eval mode. Only move it to the GPU when one exists, matching
    # how the batches are handled below.
    if use_cuda:
        model.cuda()
    model.eval()

    end = time.time()
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            # Create variables
            if use_cuda:
                inputs = inputs.cuda()
                target = target.cuda()

            # compute output
            output = model(inputs)
            #loss = criterion(output, target)
            loss = torch.nn.functional.cross_entropy(output, target)

            # measure accuracy and record loss
            batch_size = target.size(0)
            _, pred = output.detach().cpu().topk(1, dim=1)
            # squeeze(1) drops only the top-k axis, so the shape stays (N,) for any N.
            n_wrong = torch.ne(pred.squeeze(1), target.cpu()).float().sum().item()
            error.update(n_wrong / batch_size, batch_size)
            losses.update(loss.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print stats
            if batch_idx % print_freq == 0:
                res = '\t'.join([
                    'Test' if is_test else 'Valid',
                    'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                    'Time %.3f (%.3f)' % (batch_time.val, batch_time.avg),
                    'Loss %.4f (%.4f)' % (losses.val, losses.avg),
                    'Error %.4f (%.4f)' % (error.val, error.avg),
                ])
                print(res)

    # Return summary statistics
    return batch_time.avg, losses.avg, error.avg

In [None]:
def train(model, train_set, valid_set, test_set, save, n_epochs=300,
          batch_size=64, lr=0.0001, patience=10, save_epoch=10, seed=None):
    """Train `model` with Adam, early stopping and a step-wise learning-rate decay.

    Writes into directory `save`:
        results.csv     -- one row per epoch plus a final row with the test error
        model_best.dat  -- weights with the lowest validation error (only when
                           `valid_set` is given)
        model_final.dat -- weights at the end of training

    Args:
        model: network to train.
        train_set, valid_set, test_set: datasets; `valid_set` may be None, in
            which case the test set is used for the per-epoch evaluation.
        save (str): existing output directory.
        n_epochs (int): maximum number of epochs.
        batch_size (int): batch size of all loaders.
        lr (float): initial Adam learning rate; multiplied by 0.1 at 50% and 75%
            of `n_epochs`.
        patience (int): epochs without validation-loss improvement before stopping.
        save_epoch (int): currently unused (periodic checkpointing is disabled).
        seed (int or None): if given, passed to torch.manual_seed.
    """
    if seed is not None:
        torch.manual_seed(seed)

    use_cuda = torch.cuda.is_available()
    results_path = os.path.join(save, 'results.csv')

    # Data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size, drop_last=True, shuffle=True,
                                               pin_memory=use_cuda, num_workers=0)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
                                              pin_memory=use_cuda, num_workers=0)
    if valid_set is None:
        valid_loader = None
    else:
        valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False,
                                                   pin_memory=use_cuda, num_workers=0)

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    # Model on cuda
    if use_cuda:
        model = model.cuda()

    # Wrap model for multi-GPUs, if necessary
    model_wrapper = model
    if use_cuda and torch.cuda.device_count() > 1:
        model_wrapper = torch.nn.DataParallel(model).cuda()

    # Optimizer
    optimizer = torch.optim.Adam(model_wrapper.parameters(), lr=lr)
    # Milestones are compared with the integer epoch counter, so fractional
    # values such as 0.5 * 5 = 2.5 would never trigger a decay. Truncate to
    # whole epochs, and drop 0 and duplicates that tiny n_epochs would produce.
    milestones = sorted({m for m in (int(0.5 * n_epochs), int(0.75 * n_epochs)) if m > 0})
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones,
                                                     gamma=0.1)

    # Start log
    with open(results_path, 'w') as f:
        f.write('epoch,train_loss,train_error,valid_loss,valid_error,test_error\n')

    # Train model
    best_error = 1
    for epoch in range(n_epochs):
        _, train_loss, train_error = train_epoch(
            model=model_wrapper,
            loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            n_epochs=n_epochs,
        )
        scheduler.step()
        _, valid_loss, valid_error = test_epoch(
            model=model_wrapper,
            loader=valid_loader if valid_loader else test_loader,
            is_test=(not valid_loader)
        )

        # Determine if model is the best
        if valid_loader and valid_error < best_error:
            best_error = valid_error
            print('New best error: %.4f' % best_error)
            torch.save(model.state_dict(), os.path.join(save, 'model_best.dat'))

        # Log results (the last column, test_error, is only filled in the final row)
        with open(results_path, 'a') as f:
            f.write('%03d,%0.6f,%0.6f,%0.5f,%0.5f,\n' % (
                (epoch + 1),
                train_loss,
                train_error,
                valid_loss,
                valid_error,
            ))

        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    torch.save(model.state_dict(), os.path.join(save, 'model_final.dat'))

    # Final test of model on test set. Note that this evaluates the final
    # weights, not the best-validation weights stored in model_best.dat.
    model.load_state_dict(torch.load(os.path.join(save, 'model_final.dat')))
    if use_cuda and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()
    test_results = test_epoch(
        model=model,
        loader=test_loader,
        is_test=True
    )
    _, _, test_error = test_results
    with open(results_path, 'a') as f:
        f.write(',,,,,%0.5f\n' % (test_error))
    print('Final test error: %.4f' % test_error)

In [None]:
def demo(save, model,
         n_epochs=300,
         batch_size=64,
         lr=0.0001,
         patience=10,
         seed=None,
         dataset_dir='/home/ncp/workspace/blocks1/3D_CNN_for_PRED/data_np_resampled',
         dataset_df='/home/ncp/workspace/blocks1/3D_CNN_for_PRED/aihub_df.csv'):
    """Build the train/val/test datasets and train `model` on them.

    Args:
        save (str): output directory; created if missing.
        model: network to train.
        n_epochs, batch_size, lr, patience, seed: forwarded to train().
        dataset_dir (str): folder containing the '.npy' files.
        dataset_df (str): path of the CSV file describing the records and splits.

    Raises:
        Exception: if `save` exists and is not a directory.
    """
    # Datasets: one per split, all built with the same settings.
    train_dataset, val_dataset, test_dataset = (
        AIHUB_GoodOutcomPredDataset(
            dataset_dir=dataset_dir,
            dataset_df=dataset_df,
            augmentation=None,
            preprocessing=get_preprocessing(),
            mode=split)
        for split in ('train', 'val', 'test')
    )

    # Print number of parameters
    num_params = sum(p.numel() for p in model.parameters())
    print("Total parameters: ", num_params)

    # Make save directory
    if os.path.exists(save) and not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)
    os.makedirs(save, exist_ok=True)

    # Train the model
    train(model=model, train_set=train_dataset, valid_set=val_dataset, test_set=test_dataset, save=save,
          n_epochs=n_epochs, batch_size=batch_size, lr=lr, patience=patience, seed=seed)
    print('Done!')

In [None]:
# Index of the preferred GPU. Falls back to the first GPU when that index does
# not exist on this machine, and to the CPU when CUDA is unavailable.
GPU_NUM = 1
if torch.cuda.is_available():
    device = torch.device(f'cuda:{GPU_NUM}' if GPU_NUM < torch.cuda.device_count() else 'cuda:0')
else:
    device = torch.device('cpu')

In [None]:
from models import resnet, wide_resnet, resnext, densenet

In [None]:
# Output directory for this run.
save_path = './3DResNet101'
gen_new_dir(save_path)

# Training hyper-parameters. N_EPOCHS is an upper bound: train() stops earlier
# once the validation loss has not improved for PATIENCE epochs.
N_EPOCHS, BATCH_SIZE = 10000, 32
LR, PATIENCE = 1e-05, 10

model = resnet.resnet101(
    num_classes=2,
    shortcut_type='A',
    spatial_size=256,
    sample_duration=36,
)

In [None]:
# Train with the configuration above; seed=None leaves the RNG unseeded.
demo(
    save_path,
    model,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    patience=PATIENCE,
    seed=None,
)

In [None]:
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn
from tqdm import tqdm

In [None]:
def test_acc(testloader, model, threshold=0.5):
    """Run `model` over `testloader` and collect accuracy, probabilities and labels.

    Args:
        testloader: iterable of (images, labels) batches.
        model: network mapping a batch of images to per-class scores (N, C).
        threshold (float): accepted but not used; predictions are the argmax.

    Returns:
        acc (float): fraction of correct argmax predictions (NaN for an empty loader).
        output_arr (np.ndarray): float64 softmax probabilities, one row per sample.
        label_arr (np.ndarray): float64 true labels.
        pred_arr (np.ndarray): float64 argmax predictions.
    """
    correct = 0
    total = 0
    prob_chunks, label_chunks, pred_chunks = [], [], []

    # Put the model on the same device as the batches below; calling .cuda()
    # would use the default GPU, which can differ from the global `device`.
    model.to(device)
    model.eval()

    with torch.no_grad():
        for images, labels in tqdm(testloader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # argmax
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            prob_chunks.append(outputs.softmax(1).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
            pred_chunks.append(predicted.cpu().numpy())

    def _stack(chunks, empty_shape):
        # One concatenation at the end instead of re-copying the array per batch.
        if not chunks:
            return np.empty(empty_shape, dtype=np.float64)
        return np.concatenate(chunks, axis=0).astype(np.float64)

    output_arr = _stack(prob_chunks, (0, 2))
    label_arr = _stack(label_chunks, (0,))
    pred_arr = _stack(pred_chunks, (0,))

    if total == 0:
        acc = float('nan')
        print('Accuracy on the test images: ', acc)
    else:
        acc = correct/total
        print('Accuracy on the test images: ', (100*correct/total))
    return acc, output_arr, label_arr, pred_arr

In [None]:
save_path = './3DResNet101'

# Rebuild the same architecture as used for training, then restore the weights
# stored in model_best.dat.
test_model = resnet.resnet101(num_classes=2,
                              shortcut_type='A',
                              spatial_size=256,
                              sample_duration=36)
# map_location='cpu' lets a checkpoint saved from a GPU load on any machine;
# the model is moved to the evaluation device later.
test_model.load_state_dict(
    torch.load(os.path.join(save_path, 'model_best.dat'), map_location='cpu')
)

In [None]:
test_dataset = AIHUB_GoodOutcomPredDataset(
        dataset_dir='/home/ncp/workspace/blocks1/3D_CNN_for_PRED/data_np_resampled',
        dataset_df='/home/ncp/workspace/blocks1/3D_CNN_for_PRED/aihub_df.csv',
        augmentation=None,
        preprocessing=get_preprocessing(),
        mode='test')

# Evaluation loader over the dataset built above; the original referenced an
# undefined name (intest_dataset). shuffle=False keeps a fixed sample order.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=16,
                                          shuffle=False,
                                          pin_memory=(torch.cuda.is_available()),
                                          num_workers=0)

In [None]:
# Evaluate the restored model on the test loader; the original referenced an
# undefined name (intest_loader).
acc, output_arr, label_arr, pred_arr = test_acc(test_loader, test_model)

In [None]:
# Bias the two class scores in opposite directions before taking the argmax:
# column 0 is lowered and column 1 is raised by (threshold - 0.5) / 2.
# NOTE(review): if each row sums to 1, column 1 wins as soon as its probability
# exceeds 0.5 - (threshold - 0.5) / 2, so `threshold` is not itself the
# probability cutoff -- confirm this is intended.
threshold = 0.962
half_margin = (threshold - 0.5) / 2
tmp = output_arr.copy()
tmp += np.array([-half_margin, half_margin])

In [None]:
# Predicted class per sample = index of the larger shifted score.
pred_ = [np.argmax(scores) for scores in tmp]

In [None]:
# Labels come from the 'good_outcome_3m' column, so class 0 is a bad outcome
# and class 1 a good outcome. The previous names ('survived', 'not_survived')
# did not describe these classes.
print(metrics.classification_report(label_arr, pred_, target_names=['bad_outcome', 'good_outcome']))

In [None]:
# Display names for the two classes, in label order (position 0 -> label 0,
# position 1 -> label 1); used as row and column names of the confusion matrix.
class_names = ['bad_outcome', 'good_outcome']

In [None]:
# Confusion matrix of true labels (rows) against thresholded predictions (columns).
cm = metrics.confusion_matrix(label_arr, pred_)
df_cm = pd.DataFrame(cm, index=list(class_names), columns=list(class_names))
plt.figure(figsize=(10, 7))

# seaborn is imported under its full name in this notebook; `sn` was undefined.
# fmt='d' prints the counts as plain integers.
ax = seaborn.heatmap(df_cm, annot=True, fmt='d')
ax.set(xlabel='Pred label', ylabel='True label');

In [None]:
# Show the confusion matrix transposed, i.e. with its rows and columns swapped.
cm.T

In [None]:
# One-vs-rest ROC curve and AUC for each of the two classes, keyed by class
# index: class k is the positive class and output_arr[:, k] is its score.
fpr, tpr, roc_auc = {}, {}, {}
for class_idx in (0, 1):
    fpr[class_idx], tpr[class_idx], _ = metrics.roc_curve(
        label_arr == class_idx, output_arr[:, class_idx]
    )
    roc_auc[class_idx] = metrics.auc(fpr[class_idx], tpr[class_idx])
roc_auc[0]

In [None]:
# ROC curve of class 0, with the diagonal of a random classifier for reference.
fig, ax = plt.subplots()
ax.plot(fpr[0], tpr[0], label=f'class 0 (AUC = {roc_auc[0]:.3f})')
ax.plot([0, 1], [0, 1], linestyle='--', color='grey', label='chance')
ax.set(xlabel='False positive rate', ylabel='True positive rate', title='ROC curve')
ax.legend();