# In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torchvision.models as models

import os
import numpy as np
from sklearn import metrics
from tqdm import trange, tqdm

import matplotlib.pyplot as plt
import torch.nn.functional as F
import utilities as UT
from ranksvm import get_dynamic_image

def prep_data(LABEL_PATH ,TEST_NUM):
    """Prepare the train/test label files for one round of 5-fold cross-validation.

    Fold ``TEST_NUM`` is held out as the test list. The four remaining
    ``fold_All_<k>.csv`` files are copied, in fold order, into
    ``combined_train_list_All.csv`` inside ``LABEL_PATH``; that file is
    overwritten on every call.

    Args:
        LABEL_PATH: Directory holding ``fold_All_0.csv`` ... ``fold_All_4.csv``.
        TEST_NUM: Index (0-4) of the fold to hold out.

    Returns:
        Tuple ``(TRAIN_LABEL, TEST_LABEL)`` with the paths of the combined
        training list and of the held-out fold.

    Raises:
        ValueError: If ``TEST_NUM`` does not name one of the five folds.
    """
    fold_files = [LABEL_PATH + '/fold_All_' + str(k) + '.csv' for k in range(5)]
    TEST_LABEL = LABEL_PATH + '/fold_All_' + str(TEST_NUM) + '.csv'
    if TEST_LABEL not in fold_files:
        # Checked before the combined file is opened so a bad index never
        # truncates an existing training list.
        raise ValueError('TEST_NUM must be one of 0-4, got %r' % (TEST_NUM,))

    TRAIN_LABEL = LABEL_PATH + '/combined_train_list_All.csv'
    with open(TRAIN_LABEL, 'w') as combined_train_list:
        for fold in fold_files:
            if fold == TEST_LABEL:
                continue
            # Context manager so every fold file is closed deterministically.
            with open(fold, 'r') as fold_file:
                combined_train_list.writelines(fold_file)

    return TRAIN_LABEL, TEST_LABEL
    
class Dataset_Early_Fusion(Dataset):
    """Dataset that loads one ``.npy`` volume per row of a label file.

    The class label is not read from the file contents: it is taken from the
    name of the directory that directly contains the ``.npy`` file.
    """

    # Name of the parent directory -> integer class index.
    _LABEL_IDS = {'CN': 0, 'AD': 1, 'EMCI': 2, 'LMCI': 3}

    def __init__(self, 
                 label_file='/data/scratch/xxing/adni_dl/Preprocessed/ADNI2_MRItrain_list.csv'):         
        """Read the list of samples.

        Args:
            label_file: Path handed to ``UT.read_csv``; the first field of
                every returned row is used as the path of a ``.npy`` file.
        """
        self.files = UT.read_csv(label_file)

    def __len__(self):
        """Return the number of rows read from the label file."""
        return len(self.files)

    def __getitem__(self,idx):
        """Load, reorder and normalise the sample at position ``idx``.

        Returns:
            Tuple ``(im, label, full_path)``: the array with its axes permuted
            by ``(3, 2, 0, 1)`` and divided by its maximum, the integer class
            index, and the path the array was loaded from.

        Raises:
            ValueError: If the parent directory is not one of the known
                class names (CN, AD, EMCI, LMCI).
        """
        full_path = self.files[idx][0]

        class_name = full_path.split('/')[-2]
        if class_name not in self._LABEL_IDS:
            # Fail before loading the volume, with a message that names the
            # offending file instead of a bare int() conversion error.
            raise ValueError('Label Error: unknown class %r in path %r'
                             % (class_name, full_path))
        label = self._LABEL_IDS[class_name]

        im = np.load(full_path)  # stored layout per the original author: [A,S,Co,C]
        im = np.transpose(im, (3, 2, 0, 1))  # -> [C,Co,A,S]
        # NOTE(review): an all-zero volume makes this a division by zero and
        # yields NaNs; confirm that cannot occur in the preprocessed data.
        im = im / im.max()
        return im, label, full_path


class double_conv(nn.Module):
    """3x3 convolution -> ReLU -> global average pooling.

    Despite the name, a single convolution is applied. The adaptive pooling
    reduces every output channel to one value, so the result has spatial
    size 1x1.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=1),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return the pooled activations for the batch ``x``."""
        return self.conv(x)


class inconv(nn.Module):
    """Input stage: a thin wrapper that holds one ``double_conv`` block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        """Run ``x`` through the wrapped ``double_conv`` block."""
        return self.conv(x)


## Q-Net
class Qnet(nn.Module):
    """Scoring network: one ``inconv`` stage with an optional channel softmax.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of score channels produced.
        apply_softmax: When true, ``forward`` normalises the scores with a
            softmax over dim 1 (the channel dimension). With a single output
            channel that softmax is constant 1, so callers that want raw
            scores pass ``False``.
    """

    def __init__(self, in_channels=1, out_channels=1, apply_softmax=True):
        super().__init__()
        self.apply_softmax = apply_softmax
        self.inc = inconv(in_channels, out_channels)

    def forward(self, x):
        """Return the scores for ``x``, softmax-normalised if configured."""
        scores = self.inc(x)
        if not self.apply_softmax:
            return scores
        # softmax across the channel dimension
        return F.softmax(scores, dim=1)

    
class Self_Attn(nn.Module):
    """Self-attention layer over the spatial positions of a feature map.

    Queries and keys are 1x1 projections to ``in_dim // 8`` channels, values
    a 1x1 projection to ``in_dim`` channels. The attended features are
    scaled by the learnable scalar ``gamma`` and added to the input; since
    ``gamma`` is initialised to zero the layer starts as an identity mapping.
    """

    def __init__(self,in_dim):
        super().__init__()
        self.chanel_in = in_dim

        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self,x):
        """
            inputs :
                x : input feature maps (B x C x W x H)
            returns :
                out : gamma * attended values + input, same shape as x
        """
        batch, channels, width, height = x.size()
        n_pos = width * height  # N: number of spatial positions

        queries = self.query_conv(x).view(batch, -1, n_pos).transpose(1, 2)  # B x N x C//8
        keys = self.key_conv(x).view(batch, -1, n_pos)  # B x C//8 x N
        values = self.value_conv(x).view(batch, -1, n_pos)  # B x C x N

        # Row i holds the weights position i puts on every position j.
        attention = self.softmax(torch.bmm(queries, keys))  # B x N x N

        attended = torch.bmm(values, attention.transpose(1, 2))  # B x C x N
        attended = attended.view(batch, channels, width, height)

        return self.gamma * attended + x

class eca_layer(nn.Module):
    """Constructs an ECA (efficient channel attention) module.

    Every channel is globally average-pooled to one value, a 1-D convolution
    of width ``k_size`` mixes neighbouring channels, and the sigmoid of the
    result rescales the input channel-wise.

    Args:
        channel: Number of channels of the input feature map. Kept for
            interface compatibility; the layer itself does not depend on it.
        k_size: Odd kernel size of the 1-D convolution across channels.
            Defaults to 5, the value that used to be hard-coded.

    Raises:
        ValueError: If ``k_size`` is not a positive odd integer. An even
            kernel would change the channel count and break the rescaling.
    """
    def __init__(self, channel, k_size=5):
        super(eca_layer, self).__init__()
        if k_size < 1 or k_size % 2 == 0:
            raise ValueError('k_size must be a positive odd integer, got %r' % (k_size,))
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # "Same" padding keeps one gate value per channel.
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` with every channel scaled by its attention gate.

        Args:
            x: Input features with shape [b, c, h, w].
        """
        # Global spatial descriptor: [b, c, 1, 1]
        y = self.avg_pool(x)

        # [b, c, 1, 1] -> [b, 1, c] so the 1-D convolution slides along the
        # channel axis, then back to [b, c, 1, 1].
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Gate in (0, 1) per channel
        y = self.sigmoid(y)

        return x * y.expand_as(x)

class CNN(nn.Module):
    """Classifier: torchvision backbone -> (self-attention + ECA) -> MLP head.

    Args:
        num_classes: Number of output logits.
        feature: Backbone name, one of 'Alex', 'Res34', 'Res101', 'Vgg16',
            'Vgg11', 'Mobile'.
        pretrained: Forwarded to the torchvision model constructor.
        requires_grad: Whether the backbone parameters are trainable.

    Raises:
        ValueError: If ``feature`` is not a supported backbone name.
    """
    def __init__(self, 
                 num_classes=4, 
                 feature='Vgg11', 
                 pretrained=True, 
                 requires_grad=True):         
        
        super(CNN, self).__init__()

        # Expected (channels, height, width) of the backbone output. The
        # channel count is a property of the backbone; the spatial size also
        # depends on the input resolution and fixes the width of fc1.
        # NOTE(review): the spatial sizes were tuned per experiment; confirm
        # them against the input size before switching backbones.
        feature_shapes = {
            'Alex': (256, 5, 5),
            'Res34': (512, 3, 5),
            # ResNet-101 uses bottleneck blocks, so its last stage emits 2048
            # channels (the previous value 512 could never match).
            'Res101': (2048, 5, 5),
            'Vgg16': (512, 6, 6),
            'Vgg11': (512, 6, 6),
            'Mobile': (1280, 4, 4),
        }
        if feature not in feature_shapes:
            raise ValueError('Unknown feature extractor %r; expected one of %s'
                             % (feature, sorted(feature_shapes)))

        # Feature Extraction
        if(feature=='Alex'):
            self.ft_ext = models.alexnet(pretrained=pretrained) 
            # drop avgpool and classifier, keep the convolutional features
            self.ft_ext_modules = list(self.ft_ext.children())[:-2]            
            
        elif(feature=='Res34'):
            self.ft_ext = models.resnet34(pretrained=pretrained) 
            # drop the final average pooling and fc layers
            self.ft_ext_modules=list(self.ft_ext.children())[0:-2]
            
        elif(feature=='Res101'):
            self.ft_ext = models.resnet101(pretrained=pretrained) 
            # drop the final average pooling and fc layers
            self.ft_ext_modules=list(self.ft_ext.children())[0:-2]
            
        elif(feature=='Vgg16'):
            self.ft_ext = models.vgg16(pretrained=pretrained) 
            # first child only: the convolutional `features` stack
            self.ft_ext_modules=list(self.ft_ext.children())[0]
            
        elif(feature=='Vgg11'):
            self.ft_ext = models.vgg11(pretrained=pretrained) 
            # first child only: the convolutional `features` stack
            self.ft_ext_modules=list(self.ft_ext.children())[0]
            
        elif(feature=='Mobile'):
            self.ft_ext = models.mobilenet_v2(pretrained=pretrained) 
            # first child only: the convolutional `features` stack
            self.ft_ext_modules=list(self.ft_ext.children())[0]
            
        self.ft_ext=nn.Sequential(*self.ft_ext_modules)                
        for p in self.ft_ext.parameters():
            p.requires_grad = requires_grad
            
        # Classifier
        feature_shape = feature_shapes[feature]
        conv1_output_features = int(feature_shape[0])
        
        fc1_input_features = int(conv1_output_features*feature_shape[1]*feature_shape[2])
        fc1_output_features = int(conv1_output_features*2)
        fc2_output_features = int(fc1_output_features/4)
        
        self.sattn=Self_Attn(conv1_output_features)
        self.eca = eca_layer(conv1_output_features)
                
        # NOTE: conv1 is not used by forward(); it is kept so the module's
        # parameter set and state_dict keys stay unchanged.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=feature_shape[0],      
                out_channels=conv1_output_features,    
                kernel_size=1,       
            ),
            nn.BatchNorm2d(conv1_output_features),
            nn.ReLU()
        )                    
        self.fc1 = nn.Sequential(
             nn.Linear(fc1_input_features, fc1_output_features),
             nn.BatchNorm1d(fc1_output_features),            
             nn.ReLU()
         )

        self.fc2 = nn.Sequential(
             nn.Linear(fc1_output_features, fc2_output_features),
             nn.BatchNorm1d(fc2_output_features),
             nn.ReLU()
         )
        
        self.out = nn.Linear(fc2_output_features, num_classes)
        
    def forward(self, x, drop_prob=0.5):
        """Return the class logits for the image batch ``x``.

        Args:
            x: Input batch for the backbone.
            drop_prob: Dropout probability applied after fc1 and fc2. Dropout
                is active only while the module is in training mode.

        Returns:
            Tensor of unnormalised class scores, one row per sample.
        """
        x = self.ft_ext(x)
        xa = self.sattn(x)
        xc = self.eca(x)
        x = xa + xc
        x = x.view(x.size(0), -1) 
        x = self.fc1(x)
        # Functional dropout tied to self.training: a Dropout module built
        # here would always be in training mode and would keep dropping
        # activations after net.eval().
        x = F.dropout(x, p=drop_prob, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=drop_prob, training=self.training)
        prob = self.out(x) 
        
        return prob

# In [None]:
def _fuse_slices(q_net, img):
    """Collapse a stack of slices into one image using learned slice weights.

    Args:
        q_net: Module producing one score per sample for a batch of slices.
        img: Batch whose dim 2 indexes the slices; every slice
            ``img[:, :, i]`` is scored by ``q_net`` and weighted in the sum.

    Returns:
        Tuple ``(quality_score, quality_final, fused_image)``: the raw slice
        scores stacked along dim 1, their softmax over the slice dimension,
        and the weighted sum of all slices repeated three times along the
        channel dimension.
    """
    n_slices = img.shape[2]

    # One score per slice, stacked along dim 1 (size taken from the data
    # rather than a fixed slice count).
    logits = [q_net(img[:, :, i, :, :]).squeeze(1) for i in range(n_slices)]
    quality_score = torch.stack(logits, dim=1)

    # Normalize quality scores: softmax across the slice dimension (dim=1).
    quality_final = torch.softmax(quality_score, dim=1)

    # Weighted sum of the slices; same shape/dtype/device as a single slice.
    fused_image = torch.zeros_like(img[:, :, 0, :, :])
    for i in range(n_slices):
        fused_image += quality_final[:, i, :, :].unsqueeze(1) * img[:, :, i, :, :]

    # Replicate the fused image to three channels for the backbone.
    fused_image = torch.cat((fused_image, fused_image, fused_image), dim=1)
    return quality_score, quality_final, fused_image


def _classification_metrics(y_true, y_pred):
    """Return (accuracy, micro-F1, micro-precision, micro-recall)."""
    pred_cls = torch.max(y_pred, 1)[1]
    acc = float(torch.sum(pred_cls == y_true)) / float(len(y_pred))
    f1 = metrics.f1_score(y_true, pred_cls, average='micro')
    precision = metrics.precision_score(y_true, pred_cls, average='micro')
    recall = metrics.recall_score(y_true, pred_cls, average='micro')
    return acc, f1, precision, recall


def trainexp(train_dataloader, val_dataloader, feature='Res34', batch_size=16):
    """Train the slice-fusion classifier and keep the best validation epoch.

    Reads the module-level globals ``device``, ``EPOCHS`` and
    ``LOSS_WEIGHTS``. Each epoch trains ``CNN`` and ``Qnet`` jointly, then
    evaluates on ``val_dataloader``; the epoch with the highest validation
    accuracy (first one on ties) is remembered.

    Args:
        train_dataloader: Yields ``(img, label, path)`` training batches.
        val_dataloader: Yields ``(img, label, path)`` validation batches.
        feature: Backbone name passed to ``CNN``.
        batch_size: Unused; kept for backward compatibility with callers.

    Returns:
        Tuple of 11 items:
        ``train_hist`` / ``val_hist`` (per-epoch ``[loss, acc, f1, precision,
        recall]``), ``test_performance`` (``[best_epoch, loss, acc, f1,
        precision, recall]``), followed by data of the best epoch only:
        labels, outputs, sample paths, raw input batches, label batches,
        raw slice scores, fused images and normalised slice scores. All
        best-epoch items come from the same pass over ``val_dataloader``, so
        their rows correspond to each other even if that loader shuffles.

    Raises:
        RuntimeError: If no epoch was run (``EPOCHS`` < 1).
    """
    net = CNN(feature=feature).to(device)
    q_net = Qnet(in_channels=1, out_channels=1, apply_softmax=False).to(device)

    param = list(net.parameters()) + list(q_net.parameters())
    opt = torch.optim.Adam(param, lr=0.0001, weight_decay=0.001)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma= 0.985)

    loss_fcn = torch.nn.CrossEntropyLoss(weight=LOSS_WEIGHTS.to(device))

    t = trange(EPOCHS, desc=' ', leave=True)

    train_hist = []
    val_hist = []

    old_acc = 0
    test_acc = 0
    have_best = False
    for e in t:
        y_true = []
        y_pred = []

        val_y_true = []
        val_y_pred = []

        train_loss = 0
        val_loss = 0

        val_score = []
        val_score_f = []
        val_img = []
        val_label = []
        val_fuse = []
        full_path = []

        # training
        net.train()
        q_net.train()

        for step, (img, label, _paths) in enumerate(train_dataloader):
            img = img.float().to(device)
            label = label.long().to(device)
            opt.zero_grad()

            _, _, fused_image = _fuse_slices(q_net, img)

            out = net(fused_image)
            loss = loss_fcn(out, label)

            loss.backward()
            opt.step()

            label = label.cpu().detach()
            out = out.cpu().detach()
            # UT.assemble_labels is a project helper; presumably it appends
            # this batch to the running label/output tensors.
            y_true, y_pred = UT.assemble_labels(step, y_true, y_pred, label, out)

            train_loss += loss.item()

        train_loss = train_loss/(step+1)
        acc, f1, precision, recall = _classification_metrics(y_true, y_pred)

        scheduler.step()

        # val
        net.eval()
        q_net.eval()
        with torch.no_grad():
            for step, (img, label, paths) in enumerate(val_dataloader):
                # Raw CPU batches are kept so the caller can inspect them.
                val_img.append(img)
                val_label.append(label)

                img = img.float().to(device)
                label = label.long().to(device)

                quality_score, quality_final, fused_image = _fuse_slices(q_net, img)

                val_score.append(quality_score.cpu())
                val_score_f.append(quality_final.cpu())
                val_fuse.append(fused_image.cpu())

                out = net(fused_image)
                loss = loss_fcn(out, label)
                val_loss += loss.item()

                label = label.cpu().detach()
                out = out.cpu().detach()
                val_y_true, val_y_pred = UT.assemble_labels(step, val_y_true, val_y_pred, label, out)

                full_path.extend(paths)

        val_loss = val_loss/(step+1)
        val_acc, val_f1, val_precision, val_recall = _classification_metrics(val_y_true, val_y_pred)

        train_hist.append([train_loss, acc, f1, precision, recall])
        val_hist.append([val_loss, val_acc,  val_f1, val_precision, val_recall])             

        t.set_description("Epoch: %i, train loss: %.4f, train acc: %.4f, val loss: %.4f, val acc: %.4f, test acc: %.4f" 
                          %(e, train_loss, acc, val_loss, val_acc, test_acc))

        # The first epoch always seeds the snapshot so the return values
        # exist even if validation accuracy never rises above zero.
        if (not have_best) or old_acc < val_acc:
            have_best = True
            old_acc = val_acc
            test_acc = val_acc
            test_y_true = val_y_true
            test_y_pred = val_y_pred
            test_performance = [e, val_loss, val_acc, val_f1, val_precision, val_recall]
            # Snapshot everything from this same validation pass; mixing in
            # data from a later epoch would misalign rows when the loader
            # shuffles.
            best_full_path = full_path
            best_val_img = val_img
            best_val_label = val_label
            val_score_b = val_score
            val_fuse_b = val_fuse
            val_score_fb = val_score_f

    if not have_best:
        raise RuntimeError('trainexp ran no epochs (EPOCHS=%r); nothing to return' % (EPOCHS,))

    return train_hist, val_hist, test_performance, test_y_true, test_y_pred, best_full_path, best_val_img, best_val_label, val_score_b, val_fuse_b, val_score_fb
    
# ---------------------------------------------------------------------------
# 5-fold cross-validation driver: trains one model per held-out fold and
# reports metrics pooled over the best epoch of every fold.
# ---------------------------------------------------------------------------
LABEL_PATH = '/u/amo-d0/grad/xxi242/Preprocessed/ADNI_AV45_AF'

GPU = 1
BATCH_SIZE = 16
EPOCHS = 150

# NOTE: trainexp() currently hard-codes the same learning rate and does not
# read this constant.
LR = 0.0001
# Per-class weights of the cross-entropy loss, read by trainexp(). It must be
# defined or trainexp() fails with a NameError. Uniform weights are
# equivalent to an unweighted loss; swap in the line below to re-weight.
LOSS_WEIGHTS = torch.ones(4)
#LOSS_WEIGHTS = torch.tensor([1., 1.28, 1, 1.29 ]) 

device = torch.device('cuda:'+str(GPU) if torch.cuda.is_available() else 'cpu')

train_hist = []
val_hist = []
test_performance = []
# Per-fold results are collected in lists and concatenated once at the end,
# which avoids joining int/str arrays onto empty float arrays.
fold_y_true = []
fold_y_pred = []
fold_paths = []
for i in range(0, 5):
    print('Train Fold', i)
    
    TEST_NUM = i
    TRAIN_LABEL, TEST_LABEL = prep_data(LABEL_PATH, TEST_NUM)
    
    train_dataset = Dataset_Early_Fusion(label_file=TRAIN_LABEL)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=1, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    val_dataset = Dataset_Early_Fusion(label_file=TEST_LABEL)
    # No shuffling for evaluation: keeps the sample order identical in every
    # epoch so paths and predictions always refer to the same rows.
    val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        
    cur_result = trainexp(train_dataloader, val_dataloader, batch_size=BATCH_SIZE)
    
    train_hist.append(cur_result[0])
    val_hist.append(cur_result[1]) 
    test_performance.append(cur_result[2]) 
    fold_y_true.append(cur_result[3].numpy())
    fold_y_pred.append(cur_result[4].numpy())
    fold_paths.extend(cur_result[5])

full_path = np.asarray(fold_paths)

print(test_performance)

test_y_true = torch.tensor(np.concatenate(fold_y_true)).long()
test_y_pred = torch.tensor(np.vstack(fold_y_pred))
test_pred_cls = torch.max(test_y_pred, 1)[1]
test_acc = float(torch.sum(test_pred_cls==test_y_true))/ float(len(test_y_pred))
#test_auc = metrics.roc_auc_score(test_y_true, test_y_pred[:,1])
# average='micro' matches the per-epoch metrics in trainexp(); sklearn's
# default average='binary' raises ValueError for multiclass targets.
test_f1 = metrics.f1_score(test_y_true, test_pred_cls, average='micro')
test_precision = metrics.precision_score(test_y_true, test_pred_cls, average='micro')
test_recall = metrics.recall_score(test_y_true, test_pred_cls, average='micro')
#test_ap = metrics.average_precision_score(test_y_true, torch.max(test_y_pred, 1)[1])

print('ACC %.4f, F1 %.4f, Prec %.4f, Recall %.4f' 
      %(test_acc, test_f1, test_precision, test_recall))
