In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
%matplotlib inline
import numpy as np
import pandas as pd
import os, shutil, glob, sys, math, cv2, re

import albumentations as albu
from torchsummary import summary

from tqdm import tqdm
# import normalizeStaining

AttributeError: module 'torch.nn' has no attribute 'MultiheadAttention'

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List
from torchvision import models

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

from sklearn.model_selection import train_test_split

# Data processing

In [None]:
# Directory holding the tumor-tile record files (.npy arrays of image names).
tumor_record_folder = '/nfs/Shared/data/tcga/tumor'
cohort_tumor_npy = os.listdir(tumor_record_folder)
cohort_tumor_npy.sort()  # deterministic order, independent of filesystem listing order

In [None]:
# Group image paths by cohort. The first 12 characters of each record file name are the
# cohort key (presumably the TCGA patient barcode — confirm); a cohort may own several
# record files, whose contents are concatenated in file order.
cohort_images_dict = {}
for record_file in cohort_tumor_npy:
    record_path = os.path.join(tumor_record_folder, record_file)
    cohort_images_dict.setdefault(record_file[:12], []).extend(np.load(record_path))

In [None]:
# The table has 500+ entries; some of those cohorts have no images.
df = pd.read_csv('./data/coad_Mutation_Count.txt', delimiter="\t")
# Map column 1 (presumably the patient ID matching the 12-character cohort keys — confirm)
# to column 3 (mutation count). Select columns by position with .iloc: integer keys on a
# label-indexed row Series (row[1]) rely on a positional fallback deprecated in pandas 2.x.
# Zipping whole columns also avoids the slow per-row iterrows loop. Later duplicates win,
# as before.
cohort_count_dict = dict(zip(df.iloc[:, 1], df.iloc[:, 3]))

In [None]:
# Split at the cohort level (not per tile) so all tiles of a cohort stay in one split;
# a third of the cohorts are held out for validation with a fixed seed.
train_cohorts, valid_cohorts = train_test_split(
    list(cohort_images_dict),
    test_size=0.33,
    random_state=42,
)
print(len(train_cohorts), len(valid_cohorts))

In [None]:
def get_images_labels(cohorts, cohort_images_dict, cohort_count_dict):
    """Flatten per-cohort image lists into tile-level arrays.

    Args:
        cohorts: iterable of cohort IDs to include, in output order.
        cohort_images_dict: cohort ID -> list of image paths for that cohort.
        cohort_count_dict: cohort ID -> that cohort's count value.

    Returns:
        Tuple ``(images, labels, counts)`` of numpy arrays: every image of every
        cohort; the owning cohort's count repeated once per image; and one count
        per cohort, in ``cohorts`` order.

    Raises:
        KeyError: if a cohort is missing from either dictionary.
    """
    image_paths, per_image_counts, per_cohort_counts = [], [], []
    for cohort_id in cohorts:
        cohort_tiles = cohort_images_dict[cohort_id]
        cohort_count = cohort_count_dict[cohort_id]
        image_paths += cohort_tiles
        per_image_counts += [cohort_count] * len(cohort_tiles)
        per_cohort_counts.append(cohort_count)
    return np.array(image_paths), np.array(per_image_counts), np.array(per_cohort_counts)

In [None]:
def cvt2percentile(counts, labels, q=(25, 75)):
    """Bin raw counts into ordinal classes using percentile thresholds of ``counts``.

    The thresholds are the ``q`` percentiles of ``counts`` (printed for reference).
    Each entry of ``labels`` is mapped to the index of the first threshold it is
    strictly below, or to ``len(q)`` if it is below none. With the default
    ``q=(25, 75)``: class 0 is ``label < p25``, class 1 is ``p25 <= label < p75`` and
    class 2 is ``label >= p75``. NaN labels are never below a threshold, so they land
    in the top class.

    Args:
        counts: 1-D array the percentile thresholds are computed from.
        labels: 1-D array of values to classify.
        q: increasing percentiles in [0, 100] used as class boundaries.

    Returns:
        np.ndarray of int64 class indices with the same length as ``labels``
        (integer targets, as class-index cross-entropy losses expect).
    """
    percentile = np.percentile(counts, q)
    print(percentile)
    labels = np.asarray(labels)
    # For non-decreasing thresholds, "index of the first threshold the label is below"
    # equals "number of thresholds the label is NOT below". Counting ~(label < p) rather
    # than (label >= p) keeps NaN labels in the top class, as the original loop did.
    not_below = ~(labels[..., None] < percentile)
    return not_below.sum(axis=-1).astype(np.int64)

In [None]:
# Expand each cohort into one entry per image; *_counts keep one count per cohort.
train_images, train_labels, train_counts = get_images_labels(
    train_cohorts, cohort_images_dict, cohort_count_dict)
valid_images, valid_labels, valid_counts = get_images_labels(
    valid_cohorts, cohort_images_dict, cohort_count_dict)

print(*map(len, (train_images, train_labels, train_counts,
                 valid_images, valid_labels, valid_counts)))

In [None]:
# Class thresholds come from the TRAINING cohorts only and are applied to both splits, so a
# given count maps to the same class in train and validation. Binning the validation split
# by its own percentiles would give its classes different boundaries than the model learned.
train_labels = cvt2percentile(train_counts, train_labels)
valid_labels = cvt2percentile(train_counts, valid_labels)

# Dataset

In [None]:
def get_augmentation(height=256, width=256):
    """Build the training-time augmentation pipeline.

    Resizes every image to ``height`` x ``width`` pixels, then applies random
    horizontal and vertical flips, each with probability 0.5. Both sizes default
    to 256 so the output is square.

    Args:
        height: output image height in pixels.
        width: output image width in pixels.

    Returns:
        albu.Compose: the composed augmentation pipeline.
    """
    train_transform = [
        # p=1.0 makes the resize unconditional; `always_apply` is deprecated in recent
        # albumentations releases, while `p` is accepted by all of them.
        albu.Resize(height=height, width=width, p=1.0),
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
    ]
    return albu.Compose(train_transform)


def to_tensor(x, **kwargs):
    """Reorder an HWC image array to CHW and cast it to float32.

    Extra keyword arguments are accepted and ignored so the function can be
    plugged into ``albu.Lambda``.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')

# Scales 8-bit intensities to [0, 1], as torchvision's to_tensor does:
# https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py, to_tensor
def to0_1(x, **kwargs):
    """Divide ``x`` by 255; extra keyword arguments (from ``albu.Lambda``) are ignored."""
    scaled = x / 255
    return scaled

def get_preprocessing():
    """Build the preprocessing pipeline applied after augmentation.

    Both the ``image`` and ``mask`` targets are converted from HWC to CHW float32
    (``to_tensor``) and then divided by 255 (``to0_1``).

    Returns:
        albu.Compose: the composed pipeline.
    """
    return albu.Compose([
        albu.Lambda(image=to_tensor, mask=to_tensor),
        albu.Lambda(image=to0_1, mask=to0_1),
    ])

In [None]:
class Dataset(BaseDataset):
    """Image classification dataset that reads each image from disk on access.

    Args:
        image_array: indexable sequence of image file paths.
        label_array: indexable sequence of labels aligned with ``image_array``.
        augmentation: optional callable invoked as ``augmentation(image=...)`` that
            returns a dict with an ``'image'`` entry (e.g. an ``albu.Compose``).
        preprocessing: optional callable with the same calling convention, applied
            after augmentation.
    """

    def __init__(self, image_array, label_array, augmentation=None, preprocessing=None):
        self.image_array = image_array
        self.label_array = label_array
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, label)`` for index ``i``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        fp = self.image_array[i]

        image = cv2.imread(fp)
        # cv2.imread returns None (instead of raising) for missing or unreadable files;
        # fail here with the offending path rather than with an opaque cv2.error below.
        if image is None:
            raise FileNotFoundError('Could not read image: {}'.format(fp))
        # OpenCV decodes in BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # TODO: optional stain normalization is disabled (the normalizeStaining import is
        # commented out at the top of the notebook).

        if self.augmentation:
            image = self.augmentation(image=image)['image']

        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        label = self.label_array[i]
        return image, label

    def __len__(self):
        return len(self.image_array)

# Model setup

In [None]:
import utils

In [None]:
bs = 64                  # batch size
epoch = 60               # number of training epochs
model_arch = 'resnet18'  # NOTE(review): not referenced by the visible cells; the model cell builds resnet18 directly
train_loss = 'ce'        # used in the checkpoint file name
train_metric = 'fscore'  # used in the checkpoint file name
use_sampler = True       # NOTE(review): not referenced by the visible DataLoader code; no sampler is applied
init_lr = 1e-4           # initial learning rate
nclass = 3               # number of output classes (percentile bins)

# Seed numpy and torch so weight initialisation of the new head and DataLoader shuffling
# are repeatable across Restart & Run All (the cohort split is already seeded separately).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def adjust_learning_rate(optimizer, epoch, base_lr=None, step=30, gamma=0.1):
    """Step-decay schedule: set every param group's LR to ``base_lr * gamma ** (epoch // step)``.

    With the defaults the learning rate is divided by 10 every 30 epochs.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an ``'lr'``
            key), e.g. a ``torch.optim.Optimizer``.
        epoch: zero-based epoch index.
        base_lr: starting learning rate; defaults to the notebook-level ``init_lr``.
        step: number of epochs between decays.
        gamma: multiplicative decay factor applied at each step.

    Returns:
        float: the learning rate that was set.
    """
    if base_lr is None:
        base_lr = init_lr
    lr = base_lr * (gamma ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [None]:
# ImageNet-pretrained ResNet-18 with its classification head replaced by an nclass-way layer.
# NOTE: `pretrained=True` is deprecated in torchvision >= 0.13 in favour of `weights=`;
# it is kept here for compatibility with older torchvision releases.
model = models.resnet18(pretrained=True)
# Take the head's input width from the backbone instead of hard-coding 512, so the
# replacement stays correct if another ResNet variant is substituted.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=nclass, bias=True)

In [None]:
# Training images are resized and randomly flipped before preprocessing.
# NOTE(review): `use_sampler` from the config cell is not applied; batches come from plain shuffling.
train_dataset = Dataset(train_images, train_labels,
                        augmentation=get_augmentation(),
                        preprocessing=get_preprocessing())
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=4)

# Validation images are only preprocessed (no augmentation) and read in a fixed order.
# NOTE(review): they also skip the Resize used in training — fine only if the images
# already have the training size; confirm.
valid_dataset = Dataset(valid_images, valid_labels,
                        preprocessing=get_preprocessing())
valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=4)

In [None]:
# Objective and evaluation metrics from the `utils` module imported above;
# metrics[0] (F-score) is the value monitored for checkpointing in the training loop.
loss = utils.metrics.CrossEntropy()
metrics = [utils.metrics.Fscore(), utils.metrics.Accuracy()]

# One parameter group covering the whole network at the initial learning rate.
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': init_lr}])

# NOTE(review): this cosine scheduler is never stepped — the training loop sets the LR
# with adjust_learning_rate and leaves lr_scheduler.step() commented out.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer, T_max=epoch, eta_min=1e-7)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if use_cuda else "cpu")

# Per-epoch runners: the training loop calls `.run(loader)` once per epoch and reads the
# returned logs by metric name (e.g. valid_logs[metrics[0].__name__]).
# FIXME(review): `smp` is never imported in this notebook, so on a fresh kernel this cell
# raises NameError. The TrainEpoch/ValidEpoch call signature matches
# `segmentation_models_pytorch.utils.train`; presumably
# `import segmentation_models_pytorch as smp` belongs in the imports cell (or the local
# `utils` package provides an equivalent `train` module) — confirm.
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

# Validation runner: same loss/metrics, no optimizer.
valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [None]:
import time
current_time = time.strftime("%Y_%m_%d_%H_%M", time.localtime())

# Checkpoint prefix: timestamp + experiment tag + loss name + batch size.
weight_dir = './weight/'
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(weight_dir, exist_ok=True)
model_name = weight_dir + "{}_MCnt-centile3-loss:{}_bs:{}".format(
    current_time, train_loss, bs)

# Best validation value of the monitored metric (metrics[0]) so far; a checkpoint is
# written each time it strictly improves.
cur_metric = 0
monitor = metrics[0].__name__

train_history = []
valid_history = []
for i in range(0, epoch):
    print('\nEpoch: {}, batch: {}'.format(i, bs))

    # Step-decay LR; the cosine `lr_scheduler` created earlier is not stepped here.
    adjust_learning_rate(optimizer, i)
    for param_group in optimizer.param_groups:
        print(param_group['lr'])

    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    train_history.append(train_logs)
    valid_history.append(valid_logs)

    if valid_logs[monitor] > cur_metric:
        cur_metric = valid_logs[monitor]
        # NOTE: despite the .h5 suffix this is a torch state_dict written by torch.save, not HDF5.
        ckpt_path = model_name + "_epoch{}_{}:{:.4f}".format(i, train_metric, cur_metric) + ".h5"
        torch.save(model.state_dict(), ckpt_path)
        print('Model saved!', ckpt_path)