In [1]:
import os,sys,cv2,json
import matplotlib.pyplot as plt
import numpy as np
import albumentations as albu
import torch,torchvision
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import json
device = 'cuda:0'
sys.path.append('/rdfs/fast/home/chenyixin/')

In [2]:
def get_data(png_path):
    img = cv2.imread(png_path)[...,::-1]
    return {'img':img,'septum':0,'asd':0}        

In [3]:
def postprocess(pred):
    pred_ = torch.sigmoid(pred[0,0]).detach().cpu().numpy()
    pred_ = np.uint8(pred_ * 255)
    # septum_[septum != 0] = 255
    pred_ = cv2.dilate(pred_,np.ones((3,3),dtype=np.uint8),3)
    pred_ = cv2.erode(pred_,np.ones((3,3),dtype=np.uint8),3)
    pred_ = cv2.erode(pred_,np.ones((3,3),dtype=np.uint8),3)
    pred_ = cv2.dilate(pred_,np.ones((3,3),dtype=np.uint8),3)
    pred_[pred_< 25] = 0
    pred_[pred_!= 0] = 255
    contours,_ = cv2.findContours(pred_,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
    # find biggest contours
    if len(contours) == 0:
        return 0
    area = []
    for c in contours:
        area.append(cv2.contourArea(contours[0]))
    max_index = np.argmax(area)
    max_cont = contours[max_index][:,0]
    # cal biggest dis
    return np.array([np.max(max_cont[:,0]),np.min(max_cont[:,0]),np.max(max_cont[:,1]),np.min(max_cont[:,1])])
def create_box(arr):
    a,b,c,d = np.max(arr[:,0]),np.min(arr[:,1]),np.max(arr[:,2]),np.min(arr[:,3])
    width,height = a-b,c-d
    length = max(width,height) * 1
    ct_x = (a + b) / 2
    ct_y = (c+d) / 2
    a = ct_x + length / 2
    b = ct_x - length / 2    
    c = ct_y + length / 2    
    d = ct_y - length / 2    
    return a,b,c,d

In [18]:
# # create val_list:
# path = '/rdfs/data/echo/chenyixin/ASD-yixin/Train_SC-2A//'
# i = 0
# train_data_x = []
# train_data_y = []
# val_data_x = []
# val_data_y = []
# val_list = []
# for basepath,dirnames,files in os.walk(path):
#     for dirname in dirnames:
#         tmp_path = os.path.join(basepath,dirname)
#         files = os.listdir(tmp_path)
#         png_files = [os.path.join(basepath,dirname,i) for i in files if '.png' in i]
        
#         if len(png_files) == 0:
#             continue
#         if np.random.rand() > 0.8:
#             val_list.append(tmp_path)
#         i += 1
# with open('./val_list.txt','w') as f:
#     f.write('/n'.join(val_list))

In [19]:
path = '/rdfs/data/echo/chenyixin/ASD-yixin/Train_SC-2A//'
with open('./val_list.txt','r') as f:
    val_list = f.readlines()[0]
    val_list = val_list.split('/n')
i = 0
train_data_x = []
train_data_y = []
val_data_x = []
val_data_y = []
for basepath,dirnames,files in os.walk(path):
    for dirname in dirnames:
        tmp_path = os.path.join(basepath,dirname)
        files = os.listdir(tmp_path)
        png_files = [os.path.join(basepath,dirname,i) for i in files if '.png' in i]
        
        if len(png_files) == 0:
            continue
        png_files = sorted(png_files)
        asd_label,asd_pred = [],[]
        septum_label,septum_pred = [],[]
        imgs = []
        for png_file in png_files:
            d = get_data(png_file)
            input_shape = d['img'].shape[:2]
            img = torch.tensor(d['img']/255).float().permute(2,0,1).unsqueeze(1)
            img = F.interpolate(img,(240,320))
            img = img[:,:,:,40:280]
            imgs.append(img)
        imgs = torch.cat(imgs,dim=1)
        print('\r' + f'{i}',end='',flush=True)
        if tmp_path in val_list:
            val_data_x.append(imgs)
            if 'control' in tmp_path:
                val_data_y.append(0)
            else:
                val_data_y.append(1)
        else:
            train_data_x.append(imgs)
            if 'control' in tmp_path:
                train_data_y.append(0)
            else:
                train_data_y.append(1)
        i += 1

859

In [20]:
def data_generator(data_x,data_y,mode='train'):
    num = len(data_x)
    ite = 0
    if mode == 'train':
        max_iteration = 99999
        index = np.random.randint(num)
    if mode == 'val':
        max_iteration = len(data_x)
    while ite < max_iteration:
        if mode == 'train':
            index = np.random.randint(num)
        if mode == 'val':
            index = ite
        x = data_x[index]
        y = data_y[index]
        y = torch.tensor(y).float().unsqueeze(0).unsqueeze(0)
        if mode == 'train':
            if x.shape[1] <= 16:
                yield x,y
            else:
                start_frame = np.random.randint(x.shape[1]-16)
                x = x[:,start_frame:start_frame + 16]
                yield x,y
        if mode == 'val':
            yield x,y
        ite += 1
train_generator = data_generator(train_data_x,train_data_y,mode='train')


In [21]:
class MYMODEL(nn.Module):
    def __init__(self):
        super(MYMODEL,self).__init__()
        self.bb = torchvision.models.resnet18()
        self.bb.avgpool = nn.Identity()
        self.bb.fc = nn.Identity()
        self.bb.conv1 = nn.Conv2d(3,64,7,1,3)
        
        self.frame_max_pooling = nn.AdaptiveMaxPool2d(output_size=(1,15))
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(output_size=(1,1))
        
        self.fc = nn.Sequential(
            nn.Linear(512,512),
            nn.LeakyReLU(),
            nn.Linear(512,1)
        )
    
    def forward(self,x):
        x = x.permute(1,0,2,3)
        x = self.bb.conv1(x)
        x = self.bb.bn1(x)        
        x = self.bb.relu(x)
        x = self.bb.maxpool(x)
        
        x = self.bb.layer1(x)        
        x = self.bb.layer2(x)        
        x = self.bb.layer3(x)        
        x = self.bb.layer4(x)
        
        x = x.permute(1,2,0,3)
        x = self.frame_max_pooling(x)
        x = x.permute(2,0,1,3)
        x = self.global_avg_pooling(x)
        
        x = x.squeeze(2).squeeze(2)
        x = self.fc(x)
        return torch.sigmoid(x)

In [22]:
model = MYMODEL().to(device)
# model = torch.load('./cls_model.pth')
opt = torch.optim.Adam(model.parameters(),lr=1e-5,weight_decay=1e-4)
bce = nn.BCELoss()

In [23]:
from sklearn import metrics
def find_best_sentivity_specificity(gtA,A):
    fpr,tpr,thresholds = metrics.roc_curve(gtA,A)
#     plt.plot(fpr,tpr)
#     plt.show()
    smallest_dis = 1
    for i in range(len(fpr)):
        tmp1 = fpr[i]
        tmp2 = tpr[i]
        dis = np.power(np.power(0-tmp1,2)+np.power(1-tmp2,2),0.5)
        if smallest_dis > dis:
            smallest_dis = dis
            sen = tmp2
            spe = 1-tmp1
    return sen,spe,i
def plot_roc(gt,pred):
    auc = np.round(metrics.roc_auc_score(gt,pred),4)
    sen,spe,i = find_best_sentivity_specificity(gt,pred)
    fpr,tpr,thresholds = metrics.roc_curve(gt,pred)
    plt.plot(fpr,tpr,c= 'red',label=f'{0} ROC curve (AUROC:{0})')
    # plt.savefig('./forpaper/20211202/ROC_817_sen709_spe771.png',dpi=300)
    print(sen,spe,auc)

In [24]:
def val():
    val_generator = data_generator(val_data_x,val_data_y,mode='val')  
    LOSS = 0
    ite = 0
#     model.eval()
    ys = []
    preds = []
    for x,y in val_generator:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():
            pred = model(x)
            loss = bce(pred,y)
            LOSS += loss.item() / len(val_data_x)
            print('\r' + f'{ite}/{len(val_data_x)}',end='',flush=True)
            ys.append(y.item())
            preds.append(pred.item())
            ite += 1
#     print(ys,preds)
#     plot_roc(ys,preds)
    return LOSS
val()

136/137

0.5951385097782108

In [26]:
!gpustat

[1m[37mnode3.ib.com[m  Mon Dec 27 13:55:46 2021
[36m[0][m [34mTesla P100-PCIE-16GB[m |[31m 34'C[m, [32m  0 %[m | [36m[1m[33m11818[m / [33m16280[m MB | [1m[30mchenyixin[m([33m7295M[m) [1m[30mchenyixin[m([33m4513M[m)
[36m[1][m [34mTesla P100-PCIE-16GB[m |[31m 32'C[m, [32m  0 %[m | [36m[1m[33m   10[m / [33m16280[m MB |
[36m[2][m [34mTesla P100-PCIE-16GB[m |[31m 27'C[m, [32m  0 %[m | [36m[1m[33m   10[m / [33m16280[m MB |
[36m[3][m [34mTesla P100-PCIE-16GB[m |[31m 31'C[m, [32m  0 %[m | [36m[1m[33m15691[m / [33m16280[m MB | [1m[30mrenyike[m([33m15681M[m)
[36m[4][m [34mTesla P100-PCIE-16GB[m |[31m 28'C[m, [32m  0 %[m | [36m[1m[33m  773[m / [33m16280[m MB | [1m[30msongzihao[m([33m763M[m)
[36m[5][m [34mTesla P100-PCIE-16GB[m |[1m[31m 55'C[m, [1m[32m 78 %[m | [36m[1m[33m15691[m / [33m16280[m MB | [1m[30mlihui[m([33m15681M[m)
[36m[6][m [34mTesla P100-PCIE-16GB[m |[31m 27'C

In [25]:
ite = 0
best_val = val()
for x,y in train_generator:
    model.train()
    x,y = x.to(device),y.to(device)
    pred = model(x)
    loss = bce(pred,y)
    print('\r' + f'{ite},train_loss:{loss.item()}',end='',flush=True)
    
    opt.zero_grad();
    loss.backward()-*
    opt.step()
    
    ite += 1
    if ite % 200 == 0:
        val_loss = val()
        print(f'{ite},val_loss:{val_loss}')
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model,'./cls_model.pth')
            print('==> saved model')


136/137in_loss:0.33774024248123178200,val_loss:0.4962995949670347


  "type " + obj.__name__ + ". It won't be checked "


==> saved model
136/137in_loss:0.136990755796432544400,val_loss:0.4749891007468648
==> saved model
136/137in_loss:0.366361409425735535600,val_loss:0.44252341825270325
==> saved model
136/137in_loss:0.128110840916633677800,val_loss:0.4055574450805023
==> saved model
136/137in_loss:0.2367957681417465271000,val_loss:0.3841601309169385
==> saved model
136/137ain_loss:0.0300341080874204641200,val_loss:0.36731695803902004
==> saved model
136/137ain_loss:0.2262320965528488281400,val_loss:0.3412637718320981
==> saved model
136/137ain_loss:0.0105934180319309231600,val_loss:0.38462866106016197
136/137ain_loss:0.04324645549058914741800,val_loss:0.3398097587960099
==> saved model
136/137ain_loss:0.21394146978855133122000,val_loss:0.32129069133389754
==> saved model
136/137ain_loss:0.53179073333740236882200,val_loss:0.2818719882256331
==> saved model
136/137ain_loss:0.32662492990493774332400,val_loss:0.33799810388999013
136/137ain_loss:0.07117658853530884582600,val_loss:0.2693580028296782
==> saved

KeyboardInterrupt: 