# In[ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import timm
import argparse

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import skimage.transform
import scipy.ndimage as nd
import scipy.ndimage.morphology as mp
import random
import time
from tqdm import tqdm

def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``0`` to a bool.

    Without an explicit converter argparse keeps the raw string, and every
    non-empty string (including ``"False"``) is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description = "step1 Cancer Segmentation EfficientNet-b0")
parser.add_argument('--batch_size', default = 100, type = int, help='batch size')
parser.add_argument('--num_epochs', default = 50, type = int, help='training epochs')
parser.add_argument('--lr', default = 0.02, type = float, help = 'learning rate')
parser.add_argument('--num_workers', type = int, default = 6, help = 'DataLoader worker processes')
parser.add_argument('--level', type = int, default = 1,
                    help = 'name of the per-patient level sub-directory to read patches from')
parser.add_argument('--valid_s', type = int, default = 36,
                    help = 'first patient number held out for validation (inclusive)')
parser.add_argument('--valid_e', type = int, default = 41,
                    help = 'last patient number held out for validation (exclusive)')
# `--resume`, `--resume True` and `--resume False` are all accepted; the bare
# flag means True so existing invocations keep working.
parser.add_argument('--resume', type = _str2bool, nargs = '?', const = True, default = False,
                    help = 'load the checkpoint of epoch resume_epoch - 1 before training')
parser.add_argument('--resume_epoch', type = int, default = 0,
                    help = 'epoch to resume at; the weights of the previous epoch are loaded')

# Parse the command line once; the rest of the script reads its settings from `args`.
args = parser.parse_args()

# Directory that holds the training patches.
root_dir = '/mnt/hsyoo/data/patch/train'
# Former hard-coded settings, kept for reference (now supplied through the CLI):
# batch_size = 100
# learning_rate = 3e-4
# num_epochs = 50
# num_workers = 6
# level = 3
# valid_s = 36
# valid_e = 41
# resume = False
# resume_epoch = 0

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


# In[3]:


def random_seed(seed_value, use_cuda):
    """Seed the NumPy, PyTorch and Python random number generators.

    Args:
        seed_value: integer seed handed to every generator.
        use_cuda: when truthy, the CUDA generators are seeded as well and
            cuDNN is set to deterministic mode with benchmarking disabled.
    """
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed_value)

    if not use_cuda:
        return

    # Seed the current device and then every visible device.
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed every generator (CUDA ones included) before any data or model is built.
seed = 42
random_seed(seed_value = seed, use_cuda = True)

# In[4]:


def _collect_patches(level_dir, cancer_labels):
    """Return the ``{'image', 'label'}`` records found below one patient/level directory.

    ``level_dir`` is expected to contain one sub-directory per entry of
    ``cancer_labels``, each with an ``img`` folder of PNG files. The label of
    a record is the index of its class name in ``cancer_labels``.
    """
    cases = []
    for cancer_label in sorted(os.listdir(level_dir)):
        if cancer_label not in cancer_labels:
            continue
        label_value = cancer_labels.index(cancer_label)
        img_dir = os.path.join(level_dir, cancer_label, "img")
        for image in sorted(os.listdir(img_dir)):
            if image.split('.')[-1] != 'png':
                continue
            cases.append({
                'image' : os.path.join(img_dir, image),
                'label' : label_value
            })
    return cases


def path_list(root_dir, mode, valid_num, level):
    """Collect patch image paths and class labels for one data split.

    The expected layout is
    ``root_dir/<meta_label>/<patient>/<level>/<normal|cancer>/img/*.png``.
    Only the ``NOLN_metastasis`` and ``LN_metastasis`` top-level folders are
    scanned, and directory listings are sorted so the result order is stable.

    Args:
        root_dir: root directory of the patch tree.
        mode: ``'train'`` selects patients whose name contains ``'LN'`` and
            whose trailing ``_``-separated token is NOT in ``valid_num``;
            any other value selects the patients whose trailing token IS in
            ``valid_num``.
        valid_num: collection of patient-number strings held out for validation.
        level: level sub-directory to read; converted with ``str``.

    Returns:
        list of dicts ``{'image': <path>, 'label': 0 | 1}`` where 0 is
        ``normal`` and 1 is ``cancer``.

    Raises:
        OSError: propagated from ``os.listdir`` when an expected directory is
            missing or is not a directory.
    """
    meta_labels = ("NOLN_metastasis", "LN_metastasis")
    cancer_labels = ("normal", "cancer")  # position in the tuple is the class id
    level_dim = str(level)
    is_train = mode == 'train'

    data_list = []
    for meta_label in sorted(os.listdir(root_dir)):
        if meta_label not in meta_labels:
            continue
        meta_dir = os.path.join(root_dir, meta_label)
        for patient in sorted(os.listdir(meta_dir)):
            held_out = patient.split('_')[-1] in valid_num
            if is_train:
                # NOTE(review): the 'LN' name filter is applied to the training
                # split only, exactly as before - confirm this asymmetry is intended.
                if 'LN' not in patient or held_out:
                    continue
            elif not held_out:
                continue
            level_dir = os.path.join(meta_dir, patient, level_dim)
            data_list.extend(_collect_patches(level_dir, cancer_labels))
    return data_list
                            

# Build the train / validation patch lists and report how long the directory scan took.
start = time.perf_counter()
# Patient numbers in [valid_s, valid_e), zero-padded to three digits, form the validation split.
valid_num = [format(patient_id, '03') for patient_id in range(args.valid_s, args.valid_e)]
train_list, valid_list = (
    path_list(root_dir, split, valid_num, args.level) for split in ("train", "valid")
)
print(f"training Level : {args.level}")
print('* Time: %.3f' % (time.perf_counter() - start))
print(f"train_list : {len(train_list)}, valid_list : {len(valid_list)}")

class MyDataset(Dataset):
    """Map-style dataset over ``{'image': <path>, 'label': <int>}`` records.

    Args:
        path_list: sequence of records; ``'image'`` is a file path readable by
            PIL and ``'label'`` is the integer class id.
        transform: optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, path_list, transform = None):
        self.path_list = path_list
        self.transform = transform

    def __getitem__(self, index):
        """Return ``{'image': <transformed image>, 'label': <int>}`` for one record."""
        record = self.path_list[index]

        # The context manager releases the file handle deterministically;
        # convert() loads the pixel data and returns an independent image.
        with Image.open(record['image']) as handle:
            image = handle.convert("RGB")

        if self.transform:
            image = self.transform(image)

        # A plain int is enough here - no need to build a throw-away tensor
        # only to call .item() on it.
        return {'image' : image, 'label' : int(record['label'])}

    def __len__(self):
        """Number of records in the dataset."""
        return len(self.path_list)


# Per-channel normalisation constants shared by both pipelines
# (presumably the usual ImageNet statistics expected by the pretrained backbone).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random flip / colour / affine / rotation augmentation,
# then tensor conversion and normalisation.
tra = [
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness = 0.3, contrast = 0.3, saturation = 0.3, hue = 0.3),
    transforms.RandomAffine((-10, 10), shear = 10, scale = (0.9, 1.2)),
    transforms.RandomRotation(90),
    # transforms.Resize((224,224)),  (resizing is disabled)
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]

# Validation pipeline: no augmentation, only tensor conversion and normalisation.
val = [
    # transforms.Resize((224,224)),  (resizing is disabled)
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]

train_transform = transforms.Compose(tra)
valid_transform = transforms.Compose(val)
trainset = MyDataset(train_list, transform = train_transform)
validset = MyDataset(valid_list, transform = valid_transform)


# In[9]:


# Settings shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(
    batch_size = args.batch_size,
    num_workers = args.num_workers,
    pin_memory = True,
)
train_loader = DataLoader(trainset, shuffle = True, **_loader_kwargs)
valid_loader = DataLoader(validset, shuffle = False, **_loader_kwargs)

# Pretrained EfficientNet-B0 from timm with a two-class head, moved to the selected device.
net = timm.create_model("tf_efficientnet_b0_ns", pretrained = True, num_classes = 2).to(device)

loss = nn.CrossEntropyLoss()  # classification loss
alg = optim.SGD(net.parameters(), lr = args.lr)  # plain SGD optimiser


# In[ ]:


def _run_epoch(loader, num_samples, training):
    """Run one pass over ``loader`` and return ``(mean batch loss, accuracy)``.

    Args:
        loader: DataLoader yielding dicts with ``'image'`` and ``'label'``.
        num_samples: dataset size used as the accuracy denominator.
        training: when True the network is put in train mode and the
            optimiser is stepped; otherwise eval mode with gradients disabled.

    Both returned values are plain Python floats.
    """
    net.train(training)
    running_loss = 0.0
    num_correct = 0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for item in tqdm(loader):
            image = item['image'].to(device)
            y = item['label'].type(torch.long).to(device)
            # CrossEntropyLoss applies log-softmax itself, so it must be fed the
            # raw logits; an extra softmax in front squashes the loss and gradients.
            logits = net(image)
            l = loss(logits, y)
            if training:
                alg.zero_grad()
                l.backward()
                alg.step()
            # .item() detaches the values so the autograd graph of each batch
            # can be freed instead of being kept alive for the whole epoch.
            running_loss += l.item()
            num_correct += (logits.argmax(dim = 1) == y).sum().item()
            num_batches += 1
    # max(..., 1) keeps an empty loader / dataset from raising ZeroDivisionError.
    return running_loss / max(num_batches, 1), num_correct / max(num_samples, 1)


if args.resume:
    # NOTE(review): only the weights are restored; the epoch counter below still
    # starts at 0, and the per-epoch checkpoint save further down is commented
    # out - confirm how resuming is meant to work.
    resume_path = f'/mnt/hsyoo/pth/pth_80/level_{args.level}/eff_b0_pth_{args.resume_epoch - 1}.pth'
    # map_location lets a checkpoint saved on a GPU be loaded on whatever device is in use.
    net.load_state_dict(torch.load(resume_path, map_location = device))

# Per-epoch history, one entry appended per epoch.
loss_train = np.array([])
loss_valid = np.array([])
accs_train = np.array([])
accs_valid = np.array([])
best_metric = -1
best_metric_epoch = -1

best_model_path = f"/mnt/hsyoo/pth/pth_80/level_{args.level}/best_model_level{args.level}.pth"
loss_plot_path = f"/mnt/pathology/hsyoo/result/loss/loss_80/level_{args.level}/eff_b0_step1_cla_loss.png"
# Make sure the output folders exist so saving does not fail after a full epoch of work.
for out_path in (best_model_path, loss_plot_path):
    os.makedirs(os.path.dirname(out_path), exist_ok = True)

for epoch in range(args.num_epochs):
    stime = time.time()

    train_loss, train_acc = _run_epoch(train_loader, len(trainset), training = True)
    loss_train = np.append(loss_train, train_loss)
    accs_train = np.append(accs_train, train_acc)

    valid_loss, valid_acc = _run_epoch(valid_loader, len(validset), training = False)
    loss_valid = np.append(loss_valid, valid_loss)
    accs_valid = np.append(accs_valid, valid_acc)

    # Keep the weights of the epoch with the best validation accuracy.
    if valid_acc > best_metric:
        best_metric = valid_acc
        best_metric_epoch = epoch
        torch.save(net.state_dict(), best_model_path)
        print("saved new best metric model")

#     torch.save(net.state_dict(), f'/nfs/paip2021/paip2021/data/pth/multi_clf_aug/level_{level}/eff_b0_pth_{epoch}.pth')

    # Redraw the loss / accuracy curves and overwrite the saved plot.
    fig = plt.figure(figsize = (12, 6))
    ax = fig.add_subplot(1,2,1)
    ax.plot(loss_train, label='train loss')
    ax.plot(loss_valid, label='valid loss')
    ax.legend(loc='lower left')
    ax.set_title('epoch: %d '%(epoch+1))

    ax = fig.add_subplot(1,2,2)
    ax.plot(accs_train, label='train accuracy')
    ax.plot(accs_valid, label='valid accuracy')
    ax.legend(loc='lower left')
    plt.pause(.0001)
    plt.show()
    fig.savefig(loss_plot_path)
    # A new figure is created every epoch; close it so figures do not pile up in memory.
    plt.close(fig)

    print('train loss: ',loss_train[-1])
    print('valid loss: ', loss_valid[-1])
    print('train accuracy: ',accs_train[-1])
    print('valid accuracy: ',accs_valid[-1])
    print(f"best metric epoch : {best_metric_epoch}, best metric accuracy : {best_metric}")
    print(f"1 epoch time : {(time.time() - stime) / 60} min")
