In [None]:
import numpy as np
from datetime import datetime
import pandas as pd
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
import h5py
from torch.utils.data.dataset import Dataset
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score
import shutil

In [None]:
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision
from torchvision import datasets, models, transforms
from torch.nn.utils import weight_norm
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import os,sys
import copy
from math import ceil
import math
# Report library versions so saved notebook output documents the environment.
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Pick the compute device once; fall back to CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# torch.cuda.get_device_name raises on CPU-only machines, so only query it
# when a CUDA device was actually selected.
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))

In [None]:
def double_conv(in_channels, out_channels):
    """Return a weight-normalised 3x3x3 Conv3d followed by an in-place ReLU.

    Despite the name, only one convolution is applied. ``padding=1`` with a
    kernel size of 3 keeps the spatial size of the input unchanged.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the convolution and the activation.
    """
    conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
    stages = [weight_norm(conv), nn.ReLU(inplace=True)]
    return nn.Sequential(*stages)
    
class VAE(nn.Module):
    """3D convolutional encoder with a classification head and a reconstruction decoder.

    NOTE(review): despite the class name, nothing in this class samples a
    latent variable or computes a KL term; ``forward`` is a plain
    encoder/decoder plus a classifier branch.

    ``fc1`` is built with ``filters[3]*20*20*20`` input features and the encoder
    applies three ``MaxPool3d(2)`` stages, so the classifier branch only works
    when the pooled feature map has 20*20*20 voxels (e.g. a cubic input of side
    160).

    Args:
        n_class: number of output classes of the classifier head.
        in_channels: channels of the input volume; also the channel count of
            the reconstruction.
        dropout: dropout probability applied after the first fully connected layer.
        filters: channel widths of the four convolution stages.
            NOTE(review): this is a mutable default list; it is only read here.
        latent_channels: width of the fully connected layer before the output layer.
    """

    def __init__(self, n_class, in_channels,dropout=0.2,filters=[16,32,64,128], latent_channels=512):
        super().__init__()
                
        # Encoder convolutions (down1..down3) and the classifier-only convolution (down4).
        self.dconv_down1 = double_conv(in_channels, filters[0])
        self.dconv_down2 = double_conv(filters[0], filters[1])
        self.dconv_down3 = double_conv(filters[1], filters[2])
        self.dconv_down4 = double_conv(filters[2], filters[3])

        # Shared, parameter-free resampling layers.
        self.maxpool = nn.MaxPool3d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)   
        
        # Decoder convolutions: filters[2] -> filters[1] -> filters[0].
        self.dconv_up2 = double_conv(filters[2], filters[1])
        self.dconv_up1 = double_conv(filters[1], filters[0])
        
        # Batch norms applied after each encoder stage (bn3d_4 follows dconv_down4).
        self.bn3d_1 = nn.BatchNorm3d(filters[0])
        self.bn3d_2 = nn.BatchNorm3d(filters[1])
        self.bn3d_3 = nn.BatchNorm3d(filters[2])
        self.bn3d_4 = nn.BatchNorm3d(filters[3])
        
        # Batch norms applied after each decoder convolution.
        self.bn3d_1_1 = nn.BatchNorm3d(filters[0])
        self.bn3d_2_1 = nn.BatchNorm3d(filters[1])
        
        
        # 1x1x1 convolution mapping decoder features back to the input channel count.
        self.conv_last = nn.Conv3d(filters[0], in_channels, 1)
        # Classifier head; the hard-coded 20*20*20 fixes the supported input size.
        self.fc1 = nn.Linear(filters[3]*20*20*20,latent_channels)
        self.bn_fc1 = nn.BatchNorm1d(latent_channels)
        self.dropout1 = nn.Dropout(dropout)
        self.out = nn.Linear(latent_channels, n_class)
        self.relu = nn.ReLU()
        
        # Activation applied to the reconstruction, making it non-negative.
        self.relu_out1 = nn.ReLU()
        
    def forward(self, x):
        """Run the encoder, then the classifier head and the decoder on the encoded features.

        Shape comments below use D for the side of a cubic input and the
        default ``filters``; D must satisfy the ``fc1`` size constraint described
        on the class.

        Returns:
            A tuple ``(out, prediction)``: ``out`` is the ReLU-activated
            reconstruction from ``conv_last`` and ``prediction`` is the
            ``log_softmax`` of the classifier output over dimension 1.
        """
        # Encoder stage 1: conv -> pool -> batch norm.
        x = self.dconv_down1(x) # filters[0] channels, D^3
        x = self.maxpool(x) # filters[0] channels, (D/2)^3
        x = self.bn3d_1(x)
        
        # Encoder stage 2.
        x = self.dconv_down2(x) # filters[1] channels, (D/2)^3
        x = self.maxpool(x) # filters[1] channels, (D/4)^3
        x = self.bn3d_2(x)
        
        # Encoder stage 3; its output feeds both the classifier and the decoder.
        x = self.dconv_down3(x) # filters[2] channels, (D/4)^3
        x = self.maxpool(x)    # filters[2] channels, (D/8)^3
        x = self.bn3d_3(x)
        
        # Classifier branch: one more convolution without pooling, then fully connected layers.
        x1 = self.dconv_down4(x)  # filters[3] channels, (D/8)^3
        x1 = self.bn3d_4(x1)
        
        flatten = x1.view(x1.size(0),-1)
        y_SSC = self.fc1(flatten)
        y_SSC = self.bn_fc1(y_SSC)
        y_SSC = self.relu(y_SSC)
        y_SSC = self.dropout1(y_SSC)
        
        
        y = y_SSC
        y = self.out(y)
        # Log-probabilities over the classes.
        prediction = nn.functional.log_softmax(y,dim=1)
    
        # Decoder: three x2 trilinear upsamplings back to the input resolution.
        x = self.upsample(x) # filters[2] channels, (D/4)^3
        x = self.dconv_up2(x) # filters[1] channels, (D/4)^3
        x = self.bn3d_2_1(x)
        
        x = self.upsample(x)  # filters[1] channels, (D/2)^3
        x = self.dconv_up1(x) # filters[0] channels, (D/2)^3
        x = self.bn3d_1_1(x)
        
        x = self.upsample(x)  # filters[0] channels, D^3
        x = self.conv_last(x) # in_channels channels, D^3
        
        out = self.relu_out1(x)
        
        return out, prediction

In [None]:
from pytorch_model_summary import summary
# Print a hierarchical layer summary of a 2-class, 1-channel VAE for a dummy
# all-zero batch holding one 160x160x160 volume.
print(summary(VAE(2,1),torch.zeros((1,1,160, 160, 160)), show_input=False, show_hierarchical=True))

In [None]:
def init_weights(m):
    """Re-initialise the weights of ``nn.Linear`` / ``nn.Conv2d`` modules in place.

    Intended for ``model.apply(init_weights)``. Only modules whose exact type
    is ``nn.Linear`` or ``nn.Conv2d`` are touched; their weights are drawn from
    a normal distribution with mean 0 and standard deviation 0.01. Biases and
    all other module types are left unchanged.

    NOTE(review): ``nn.Conv3d`` is not matched by this check -- confirm
    whether 3D convolutions were meant to be initialised as well.
    """
    if type(m) in (nn.Linear, nn.Conv2d):
        m.weight.data.normal_(0, 0.01)

In [None]:
import seaborn as sns
def make_confusion_matrix(cf,
                          group_names=None,
                          categories='auto',
                          count=True,
                          percent=True,
                          cbar=True,
                          xyticks=True,
                          xyplotlabels=True,
                          sum_stats=True,
                          figsize=None,
                          cmap='Blues',
                          my_dpi=100,
                          title=None,
                          saved=True,
                          save_name='Unsupervised Learning.png'):
    '''
    Plot a confusion matrix as an annotated Seaborn heatmap and optionally save it.

    Arguments
    ---------
    cf:            2-D numpy confusion matrix (rows are indexed as the true label in the
                   summary statistics below).
    group_names:   List of strings, one per cell in row-major order, shown at the top of each
                   square. Ignored unless its length equals cf.size.
    categories:    List of strings containing the categories to be displayed on the x,y axis.
                   Default is 'auto'.
    count:         If True, show each cell value in its square. The value is rendered with a
                   percent format ("{0:0.1%}"), so a value of 0.5 is shown as 50.0%.
    percent:       If True, also show each cell divided by the sum of its column.
    cbar:          If True, show the color bar. Default is True.
    xyticks:       If True, show x and y ticks. Default is True.
    xyplotlabels:  If True, show 'True label' and 'Predicted label' on the figure.
    sum_stats:     If True, append summary statistics to the x-axis label.
    figsize:       Tuple representing the figure size. Default is the matplotlib rcParams value.
    cmap:          Matplotlib colormap name. Default is 'Blues'.
    my_dpi:        DPI of the figure; the file is saved at ten times this value.
    title:         Title for the heatmap. Default is None.
    saved:         If True, write the figure to save_name.
    save_name:     Output path used when saved is True.
    '''

    # TEXT INSIDE EACH SQUARE
    n_cells = cf.size
    blanks = [''] * n_cells

    if group_names and len(group_names) == n_cells:
        group_labels = ["{}\n".format(value) for value in group_names]
    else:
        group_labels = blanks

    if count:
        group_counts = ["{0:0.1%}\n".format(value) for value in cf.flatten()]
    else:
        group_counts = blanks

    if percent:
        # Repeat the column sums once per row so they line up with the row-major
        # flattened matrix. The repeat count follows the matrix instead of being
        # hard-coded to 3, which only worked for 3-row matrices.
        column_totals = np.tile(np.sum(cf, axis=0), (cf.shape[0],))
        group_percentages = ["{0:.2%}".format(value) for value in cf.flatten() / column_totals]
    else:
        group_percentages = blanks

    box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)]
    box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1])

    # SUMMARY STATISTICS
    if sum_stats:
        # Sum of the diagonal divided by the sum of all cells.
        accuracy = np.trace(cf) / float(np.sum(cf))
        if len(cf) == 2:
            # Extra metrics for binary matrices; index 1 is treated as the positive class.
            precision = cf[1, 1] / sum(cf[:, 1])
            recall = cf[1, 1] / sum(cf[1, :])
            f1_score = 2 * precision * recall / (precision + recall)
            stats_text = "\n\nAccuracy={:0.3f}\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(
                accuracy, precision, recall, f1_score)
        else:
            # NOTE(review): trace/sum equals balanced accuracy only when every row of
            # cf has the same total (e.g. a row-normalised matrix).
            stats_text = "\n\nBalanced Accuracy={:0.3f}".format(accuracy)
    else:
        stats_text = ""

    # FIGURE PARAMETERS
    if figsize is None:
        # Fall back to matplotlib's default figure size.
        figsize = plt.rcParams.get('figure.figsize')

    if not xyticks:
        # Hide category tick labels when ticks are disabled.
        categories = False

    # HEATMAP
    plt.figure(figsize=figsize, dpi=my_dpi)
    sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, xticklabels=categories, yticklabels=categories)

    if xyplotlabels:
        plt.ylabel('True label')
        plt.xlabel('Predicted label' + stats_text)
    else:
        plt.xlabel(stats_text)

    if title:
        plt.title(title)
    if saved:
        plt.savefig(save_name, dpi=my_dpi*10, bbox_inches='tight')

In [None]:
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the weighted sum, the count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: per-sample class scores; the k best classes are taken along dim 1.
        target: class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list of 0-dim tensors, one per k, holding the percentage of samples
        whose target is among the k highest-scoring classes.
    """
    largest_k = max(topk)
    n_samples = target.size(0)
    top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1]
    # Transpose so row r holds the (r+1)-th best class of every sample.
    ranked = top_indices.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]

def cal_consistency_weight(epoch, init_ep=0, end_ep=150, init_w=0.0, end_w=20.0):
    """Return the consistency-loss weight for ``epoch``.

    The weight is ``end_w`` after ``end_ep`` and ``init_w`` before ``init_ep``.
    In between it ramps up as ``exp(-5 * (1 - T)^2)`` scaled to
    ``[init_w, end_w]``, where ``T`` is the fraction of the ramp completed.
    """
    if epoch > end_ep:
        return end_w
    if epoch < init_ep:
        return init_w

    progress = float(epoch - init_ep) / float(end_ep - init_ep)
    remaining = 1.0 - progress
    # Exponential ramp-up; the value is init_w + (end_w - init_w) * exp(-5) at init_ep.
    return (math.exp(-5.0 * remaining * remaining)) * (end_w - init_w) + init_w

def update_ema_variables(model, model_teacher, alpha, global_step):
    """Update ``model_teacher`` parameters as an exponential moving average of ``model``.

    Args:
        model: student network providing the current parameters.
        model_teacher: teacher network whose parameters are updated in place.
        alpha: EMA decay; the teacher keeps ``alpha`` of its old value.
        global_step: number of updates done so far, used to warm up the decay.

    Only ``parameters()`` are averaged; module buffers are not touched.
    """
    # Use the true average until the exponential average is more correct
    alpha = min(1.0 - 1.0 / float(global_step + 1), alpha)
    for param_t, param in zip(model_teacher.parameters(), model.parameters()):
        # Keyword form of add_: teacher = alpha * teacher + (1 - alpha) * student.
        # The positional add_(scalar, tensor) overload is deprecated.
        param_t.data.mul_(alpha).add_(param.data, alpha=1 - alpha)


In [None]:
from collections import defaultdict
import torch.nn.functional as F

In [None]:
def calc_loss(out, prediction, target, original, MSE_weight=0.5):
    """Blend reconstruction MSE and classification cross-entropy into one loss.

    Args:
        out: reconstructed volume.
        prediction: class scores passed to ``F.cross_entropy``.
        target: class indices.
        original: reference volume for the reconstruction.
        MSE_weight: weight of the MSE term; cross-entropy gets ``1 - MSE_weight``.

    Returns:
        ``MSE * MSE_weight + CE * (1 - MSE_weight)``, where the MSE is computed
        after multiplying both volumes by 65535.
    """
    classification_term = F.cross_entropy(prediction, target)
    reconstruction_term = F.mse_loss(out * 65535, original * 65535)
    return reconstruction_term * MSE_weight + classification_term * (1 - MSE_weight)

def print_metrics(metrics, epoch_samples, phase):
    """Print every metric in ``metrics`` divided by ``epoch_samples`` on one line.

    The line has the form ``"<phase>: <name>: <value>, <name>: <value>"``.
    """
    rendered = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples)
        for name, total in metrics.items()
    )
    print("{}: {}".format(phase, rendered))

In [None]:
def train_sup(label_loader, model, device, criterions, optimizers, epoch, args, k=2):
    """Train ``model`` for one epoch on labelled batches.

    Each batch loss is
    ``MSE(output*255, inputs*255) * args.MSE_weight + CE * (1 - args.MSE_weight)``.
    The first optimizer is stepped while ``epoch < 5``, the second afterwards.

    Args:
        label_loader: loader yielding ``(inputs, target)`` batches.
        model: network returning ``(reconstruction, class_scores)``.
        device: device the batches are moved to.
        criterions: ``(cross_entropy_criterion, mse_criterion)``.
        optimizers: ``(optimizer1, optimizer2)``.
        epoch: current epoch index.
        args: namespace providing ``MSE_weight`` and ``print_freq``.
        k: second cut-off reported next to top-1 accuracy.

    Returns:
        ``(top-1 average, average loss, balanced accuracy, plain accuracy)``.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    MSE_weight = args.MSE_weight

    # switch to train mode
    model.train()

    criterion_ce, criterion_mse = criterions
    optimizer1, optimizer2 = optimizers
    # The epoch is fixed for the whole call, so the optimizer can be chosen once.
    active_optimizer = optimizer1 if epoch < 5 else optimizer2

    end = time.time()

    # Both accumulators start with a dummy element that is sliced off before scoring.
    seen_labels = torch.tensor([1]).to(device)
    seen_preds = torch.tensor([1]).to(device)
    epoch_samples = 0
    running_corrects = 0

    label_iter = iter(label_loader)
    for i in range(len(label_iter)):
        inputs, target = next(label_iter)

        # measure data loading time
        data_time.update(time.time() - end)
        batch_size = inputs.shape[0]
        inputs = inputs.to(device)
        target = target.to(device)
        epoch_samples += batch_size

        # forward pass and combined loss
        output, predictions = model(inputs)
        loss_ce = criterion_ce(predictions, target)
        loss_mse = criterion_mse(output*255, inputs*255)
        loss = loss_mse * MSE_weight + loss_ce * (1 - MSE_weight)

        _, preds = torch.max(predictions, 1)
        seen_labels = torch.cat([seen_labels, target.view(-1)], dim=0)
        seen_preds = torch.cat([seen_preds, preds.view(-1)], dim=0)
        running_corrects += (preds == target.data).sum()

        # measure accuracy and record loss
        prec1, prec5 = accuracy(predictions.data, target, topk=(1, k))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and take an optimizer step
        active_optimizer.zero_grad()
        loss.backward()
        active_optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec @ 1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec @ {k} {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(label_iter), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, k=k, top5=top5))

    epoch_acc_balanced = balanced_accuracy_score(seen_labels[1:].cpu(), seen_preds[1:].cpu())
    epoch_acc = running_corrects.double() / epoch_samples
    print('\n\nEpoch: {0}\t Balanced Accuracy: {1}\t Running Accuracy: {2}\n\n'.format(epoch, epoch_acc_balanced, epoch_acc))
    return top1.avg , losses.avg, epoch_acc_balanced, epoch_acc


def validate(val_loader, model, device, criterions, args, mode='valid', k=2, weight_pi=20):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    The loss is combined exactly as in training:
    ``MSE(output*255, inputs*255) * args.MSE_weight + CE * (1 - args.MSE_weight)``.

    Args:
        val_loader: loader yielding ``(inputs, target)`` batches.
        model: network returning ``(reconstruction, class_scores)``.
        device: device the batches are moved to.
        criterions: ``(cross_entropy_criterion, mse_criterion)``.
        args: namespace providing ``MSE_weight`` and ``print_freq``.
        mode: ``'test'`` prefixes progress lines with ``Test``; any other value uses ``Valid``.
        k: second cut-off reported next to top-1 accuracy.
        weight_pi: accepted but not used by this function.

    Returns:
        ``(top-1 average, average loss, balanced accuracy, plain accuracy,
        classification report, confusion matrix)``; the report and the matrix
        are computed for labels ``[0, 1]``.
    """
    batch_time = AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    MSE_weight = args.MSE_weight

    # switch to evaluate mode
    model.eval()

    criterion_ce, criterion_mse = criterions

    # Both accumulators start with a dummy element that is sliced off before scoring.
    seen_labels = torch.tensor([1]).to(device)
    seen_preds = torch.tensor([1]).to(device)
    epoch_samples = 0
    running_corrects = 0

    # Only the prefix of the progress line depends on the mode.
    progress_template = (('Test' if mode == 'test' else 'Valid') +
                         ': [{0}/{1}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec @ 1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec @ {k} {top5.val:.3f} ({top5.avg:.3f})')

    end = time.time()
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):
            batch_size = inputs.shape[0]
            target = target.to(device)
            inputs = inputs.to(device)
            epoch_samples += batch_size

            # forward pass and combined loss
            output, predictions = model(inputs)
            loss_ce = criterion_ce(predictions, target)
            loss_mse = criterion_mse(output*255, inputs*255)
            loss = loss_mse * MSE_weight + loss_ce * (1 - MSE_weight)

            _, preds = torch.max(predictions, 1)
            seen_labels = torch.cat([seen_labels, target.view(-1)], dim=0)
            seen_preds = torch.cat([seen_preds, preds.view(-1)], dim=0)
            running_corrects += (preds == target.data).sum()

            # measure accuracy and record loss
            prec1, prec5 = accuracy(predictions.data, target, topk=(1, k))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print(progress_template.format(
                    i, len(val_loader), batch_time=batch_time, loss=losses,
                    top1=top1, top5=top5, k=k))

    labels_cpu = seen_labels[1:].cpu()
    preds_cpu = seen_preds[1:].cpu()
    epoch_acc_balanced = balanced_accuracy_score(labels_cpu, preds_cpu)
    epoch_acc = running_corrects.double() / epoch_samples
    report = classification_report(labels_cpu, preds_cpu, labels=[0,1])
    confusionmatrix = confusion_matrix(labels_cpu, preds_cpu, labels=[0,1])
    print(' ****** Prec @ 1 {top1.avg:.3f} Prec @ {k} {top5.avg:.3f} Loss {loss.avg:.3f} '
          .format(top1=top1, top5=top5, loss=losses, k=k))
    print('\n\nBalanced Accuracy: {0}\t Running Accuracy: {1}\n\n'.format(epoch_acc_balanced, epoch_acc))

    return top1.avg, losses.avg, epoch_acc_balanced, epoch_acc, report, confusionmatrix

In [None]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', dirname='.'):
    """Save ``state`` as the latest checkpoint and optionally copy it as the best one.

    The file is written to ``<dirname>/<filename>_latest.pth.tar``. When
    ``is_best`` is true it is also copied to ``<dirname>/<filename>_best.pth.tar``.
    """
    latest_path = os.path.join(dirname, filename + '_latest.pth.tar')
    torch.save(state, latest_path)
    if not is_best:
        return
    best_path = os.path.join(dirname, filename + '_best.pth.tar')
    shutil.copyfile(latest_path, best_path)

In [None]:
import copy
def train_model(model, dataloaders, optimizers, schedulers, criterions, args, device):
    """Run the supervised train/validate loop and keep the best-scoring weights.

    For every epoch in ``[args.start_epoch, args.epochs)`` the model is trained
    with ``train_sup``, evaluated with ``validate``, and a checkpoint is written
    with ``save_checkpoint``. The epoch with the highest validation balanced
    accuracy is treated as best.

    Args:
        model: network to train.
        dataloaders: ``(label_loader, val_loader)``.
        optimizers: ``(optimizer1, optimizer2)``; the first is used while ``epoch < 5``.
        schedulers: ``(scheduler1, scheduler2)``, stepped on the validation loss.
        criterions: ``(cross_entropy_criterion, mse_criterion)``.
        args: namespace with ``model``, ``resume``, ``start_epoch``, ``epochs``,
            ``ckpt``, ``arch``, ``optim``, ``unlabel_percent``, ``batch_size``,
            ``boundary`` plus whatever ``train_sup`` / ``validate`` read.
        device: device used for training and for mapping a resumed checkpoint.

    Returns:
        ``(model, acc1_tr_balanced, acc1_val_balanced, acc1_tr, acc1_val,
        losses_tr, losses_val, learning_rate, val_report_best,
        val_confusionmatrix_best)``. ``model`` carries the best weights seen.
        If no epoch improves on the starting best score, the report and matrix
        of the first evaluated epoch are returned; both are ``None`` when no
        epoch runs at all.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'baseline'``, the only
            training mode implemented here.
    """
    if args.model != 'baseline':
        # Other modes previously failed after one validation pass with an
        # UnboundLocalError; fail early with a clear message instead.
        raise NotImplementedError(
            "train_model only implements the 'baseline' training mode, got %r" % (args.model,))

    best_acc_balanced = 0
    best_test_acc_balanced = 0
    prec1s_tr, acc1_tr_balanced, acc1_tr, losses_tr = [], [], [], []
    losses_cl_tr = []  # never filled; kept so the checkpoint keys stay unchanged
    prec1s_val, acc1_val_balanced, acc1_val, losses_val = [], [], [], []
    learning_rate = []
    # Defined up front so the return statement cannot hit an unbound name.
    val_report_best = None
    val_confusionmatrix_best = None

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): checkpoints written below hold non-tensor objects;
            # recent PyTorch releases may need weights_only=False here -- confirm
            # against the installed version.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints written by this function store these keys; the
            # previous lookup of 'best_prec1' did not match any saved key.
            best_acc_balanced = checkpoint.get('best_acc_balanced', best_acc_balanced)
            best_test_acc_balanced = checkpoint.get('best_test_acc_balanced', best_test_acc_balanced)

            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Snapshot after any resume so the fallback weights are the loaded ones,
    # not the pre-resume initialisation.
    best_model_wts = copy.deepcopy(model.state_dict())

    ckpt_dir = args.ckpt+'_'+args.arch+'_'+args.model+'_'+args.optim+'_ul%.3f'%(args.unlabel_percent)
    ckpt_dir = ckpt_dir + '_e%d'%(args.epochs)
    os.makedirs(ckpt_dir, exist_ok=True)
    print(ckpt_dir)

    batch_size_label = args.batch_size

    label_loader, val_loader = dataloaders

    print("Batch size (label): ", batch_size_label)

    optimizer1, optimizer2 = optimizers
    scheduler1, scheduler2 = schedulers

    for epoch in range(args.start_epoch, args.epochs):
        epoch_starttime = time.time()
        # Report the learning rate of whichever optimizer train_sup will step.
        active_optimizer = optimizer1 if epoch < 5 else optimizer2
        for param_group in active_optimizer.param_groups:
            print("LR:", param_group['lr'])
            lr = param_group['lr']

        # train for one epoch
        print('Supervised Training')
        prec1_tr, loss_tr, train_acc_balanced, train_acc = train_sup(label_loader, model, device, criterions, optimizers, epoch, args)

        # evaluate on validation set
        prec1_val, loss_val, val_acc_balanced, val_acc, val_report, val_confusionmatrix = validate(val_loader, model, device, criterions, args, mode='valid')

        # learning scheduler: scheduler1 for the first five epochs, scheduler2
        # only once a third of the run has passed; no scheduler steps in between
        if epoch < 5:
            scheduler1.step(loss_val)
        elif epoch > args.epochs * (1/3):
            scheduler2.step(loss_val)

        # append values
        acc1_tr_balanced.append(train_acc_balanced)
        acc1_tr.append(train_acc)
        acc1_val_balanced.append(val_acc_balanced)
        acc1_val.append(val_acc)
        prec1s_tr.append(prec1_tr)
        losses_tr.append(loss_tr)
        prec1s_val.append(prec1_val)
        losses_val.append(loss_val)

        learning_rate.append(lr)

        # remember the best balanced accuracy and save checkpoint
        is_best = val_acc_balanced > best_acc_balanced
        if is_best:
            best_test_acc_balanced = val_acc_balanced
            best_model_wts = copy.deepcopy(model.state_dict())
        if is_best or val_report_best is None:
            # Seed from the first evaluated epoch so these are defined even if
            # no epoch ever beats the starting best score.
            val_report_best = val_report
            val_confusionmatrix_best = val_confusionmatrix
        print("Best test balanced accuracy: %.3f"%best_test_acc_balanced)
        best_acc_balanced = max(val_acc_balanced, best_acc_balanced)
        dict_checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc_balanced': best_acc_balanced,
            'best_test_acc_balanced' : best_test_acc_balanced,
            'prec1s_tr': prec1s_tr,
            'acc1_tr': acc1_tr,
            'acc1_tr_balanced': acc1_tr_balanced,
            'losses_tr': losses_tr,
            'losses_cl_tr': losses_cl_tr,
            'prec1s_val': prec1s_val,
            'acc1_val': acc1_val,
            'acc1_val_balanced':acc1_val_balanced,
            'losses_val': losses_val,
            'learning_rate' : learning_rate,
            'val_report': val_report,
            'val_confusionmatrix': val_confusionmatrix,
        }

        save_checkpoint(dict_checkpoint, is_best, str(args.boundary), dirname=ckpt_dir)
        print('\nEpoch time:', time.time()-epoch_starttime, 's\n')
    model.load_state_dict(best_model_wts)
    return model, acc1_tr_balanced, acc1_val_balanced, acc1_tr, acc1_val, losses_tr, losses_val, learning_rate, val_report_best, val_confusionmatrix_best

In [None]:
import os
def _find_mat_files(root_dir):
    """Return the paths of all ``.mat`` files found by walking ``root_dir``."""
    found = []
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        found.extend(os.path.join(dirpath, name) for name in filenames if name.endswith(".mat"))
    return found


def datapreparation(Datadir, fold):
    """Split the ``.mat`` files under ``Datadir`` into training and validation paths.

    The fold at index ``fold`` of ``Fold1`` .. ``Fold5`` is held out for
    validation; the other four folds form the training set. Files are read
    from ``<Datadir>/High/<fold>/`` and ``<Datadir>/Low1/<fold>/``, with all
    ``High`` paths listed before the ``Low1`` paths in both results.

    Args:
        Datadir: root directory containing the ``High`` and ``Low1`` folders.
        fold: index (0-4) of the fold used for validation.

    Returns:
        ``(trainpath, valpath)``, two lists of file paths.
    """
    fold_names = ['Fold1', 'Fold2', 'Fold3', 'Fold4', 'Fold5']
    held_out = fold_names.pop(fold)
    trainpath, valpath = [], []

    # Training paths: every remaining fold, one class folder at a time.
    train_groups = (("/High/", 'High Train data length: %d'),
                    ("/Low1/", 'Low Train data length: %d'))
    for class_dir, banner in train_groups:
        group = []
        for fold_name in fold_names:
            group += _find_mat_files(Datadir + class_dir + fold_name + "/")
        print(banner %(len(group)))
        trainpath.extend(group)

    # Validation paths: only the held-out fold.
    for class_dir in ("/High/", "/Low1/"):
        group = _find_mat_files(Datadir + class_dir + held_out + "/")
        print('Val Train data length: %d' %(len(group)))
        valpath.extend(group)

    return trainpath,valpath

In [None]:
import scipy.io as sio

class MyDataset(Dataset):
    """Dataset of ``.mat`` volumes whose label is derived from the file path.

    Each file must contain an array under the key ``'data'``. Values below
    3000 are zeroed and the array is divided by 65535 as float32. The label is
    0 when the path contains the substring ``'Low'`` and 1 otherwise.
    """

    def __init__(self, paths, transforms=None):
        self.paths = paths
        self.transforms = transforms

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(volume_tensor, label)``."""
        path = self.paths[index]
        volume = sio.loadmat(path)['data']
        # Suppress low intensities before scaling.
        volume[volume < 3000] = 0
        # Two leading singleton axes are added before the optional transform.
        x = torch.from_numpy(volume.astype(np.float32)/65535)[None, None]

        label = int(0) if 'Low' in path else int(1)

        if self.transforms:
            x = self.transforms(x)
        # Drop the first leading axis again; one singleton axis stays in front.
        return x.squeeze(dim = 0), label

    def __len__(self):
        return len(self.paths)


In [None]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import copy
import argparse
from sklearn.model_selection import StratifiedKFold
from datetime import date

# Command-line style configuration. parse_known_args is used so that
# arguments the parser does not know are ignored instead of raising.
parser = argparse.ArgumentParser(description='PyTorch Semi-supervised learning Training')

# Model / optimizer / architecture selection
parser.add_argument(
    '--model', metavar='MODEL', default='baseline',
    help='model: (default: baseline)', choices=['baseline', 'pi', 'mt'])
parser.add_argument(
    '--optim', '-o', metavar='OPTIM', default='adam',
    help='optimizer: '+' (default: adam)', choices=['adam', 'sgd'])
parser.add_argument(
    '--arch', metavar='ARCH', default='VAE',
    help='architecture: (default: VAE)', choices=['VAE', 'UNet'])

# Schedule and optimisation hyper-parameters
parser.add_argument(
    '--epochs', default=150, type=int, help='number of total epochs to run')
parser.add_argument(
    '--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
parser.add_argument(
    '--batch_size', default=2, type=int, help='mini-batch size (default: 32)')
parser.add_argument(
    '--lr', default=5e-5, type=float, help='initial learning rate')
parser.add_argument(
    '--momentum', default=0.9, type=float, help='momentum')
parser.add_argument(
    '--weight-decay', default=1e-4, type=float, help='weight decay (default: 1e-4)')
parser.add_argument(
    '--weight_l1', default=1e-3, type=float, help='l1 regularization (default: 1e-3)')

# Logging, checkpointing and run bookkeeping
parser.add_argument(
    '--print_freq', default=800, type=int, help='print frequency (default: 100)')
parser.add_argument(
    '--resume', default='', type=str, help='path to latest checkpoint (default: none)')
parser.add_argument(
    '--num_classes',default=3, type=int, help='number of classes in the model')
parser.add_argument(
    '--ckpt', default='ckpt', type=str, help='path to save checkpoint (default: ckpt)')
parser.add_argument(
    '--boundary',default=0, type=int, help='different label/unlabel division [0,9]')
parser.add_argument(
    '--gpu',default=0, type=str, help='cuda_visible_devices')

# Loss weights
parser.add_argument(
    '--weight-pi', default=5, type=float, help='weight pi (default: 1)')
parser.add_argument(
    '--weight-mt', default=8, type=float, help='weight mt (default: 8)')
parser.add_argument(
    '--MSE_weight', default=0.5, type=float, help='MSe weight (default: 0.1)')
parser.add_argument(
    '--unlabel_percent', default=0, type=float, help='unlabel percent (default: 0.95)')

args, unknown = parser.parse_known_args()
print(args)

# NOTE(review): CUDA_VISIBLE_DEVICES usually has to be set before CUDA is first
# used in the process -- confirm it still takes effect at this point.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu)
torch.set_num_threads(8)

# 5-fold cross-validation driver: for every fold, build the loaders, train a
# fresh model, then save the loss / accuracy / learning-rate curves and the
# confusion matrix of the best validation epoch into the checkpoint directory.
Datadir = "../Data_Augumented"
args.ckpt = 'ckpt_'+ str(date.today())
for runs in range(0,5):
    print("Fold-",str(runs),": Initializing Datasets and Dataloaders...")
    
    Fold = runs
    args.boundary = runs  # used by train_model as the checkpoint file prefix
    
    trainpath,valpath = datapreparation(Datadir,Fold)
    print('label train data vol.: ',len(trainpath))
    print('val data vol.: ',len(valpath))
    
    # Every volume is upsampled x2 with trilinear interpolation on load.
    transforms_data = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) 
    label_dataset = MyDataset(trainpath, transforms=transforms_data)
    val_dataset = MyDataset(valpath, transforms=transforms_data)
    label_loader = torch.utils.data.DataLoader(label_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True, pin_memory=True)
    # Validation must score every sample: no shuffling and no dropped last
    # batch (drop_last=True silently excluded samples from the metrics).
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False, pin_memory=True)
    dataloaders = (label_loader, val_loader)
    
    # Detect if we have a GPU available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    
    num_class = 2
    in_channel = 1
    
    if args.arch == 'VAE':
        model = VAE(num_class, in_channel)
    elif args.arch == 'UNet':
        # NOTE(review): UNet is not defined in the visible part of this
        # notebook -- confirm it is provided elsewhere before using it.
        model = UNet(num_class, in_channel)
    else:
        # Raising a plain string is itself a TypeError; raise a real exception.
        raise NotImplementedError('Architecture not implemented')
    model.apply(init_weights)
    model = model.to(device)

    # Two Adam optimizers over the same parameters: the first (args.lr) is used
    # by the training code for epochs < 5, the second (lr=1e-2) afterwards.
    optimizer_ft1 = optim.Adam(model.parameters(), lr=args.lr,betas = (0.9, 0.999),eps=1e-08,weight_decay=0)
    optimizer_ft2 = optim.Adam(model.parameters(), lr=1e-2,betas = (0.9, 0.999),eps=1e-08,weight_decay=0)
    exp_lr_scheduler1 = lr_scheduler.ReduceLROnPlateau(optimizer_ft1, mode='min', factor=0.5, patience=5, verbose=True, threshold=0.0001, threshold_mode='rel', cooldown=1, min_lr=1e-7, eps=1e-08)
    exp_lr_scheduler2 = lr_scheduler.ReduceLROnPlateau(optimizer_ft2, mode='min', factor=0.5, patience=5, verbose=True, threshold=0.0001, threshold_mode='rel', cooldown=1, min_lr=1e-7, eps=1e-08)       
    optimizers = (optimizer_ft1, optimizer_ft2)
    schedulers = (exp_lr_scheduler1, exp_lr_scheduler2)
    criterion_ce = nn.CrossEntropyLoss(reduction='mean')
    criterion_mse = nn.MSELoss(reduction='mean')
    
    criterions = (criterion_ce, criterion_mse)
    
    model,acc1_tr_balanced,acc1_val_balanced,acc1_tr,acc1_val,losses_tr,losses_val,learning_rate,val_report_best,val_confusionmatrix_best = train_model(model, dataloaders, optimizers, schedulers, criterions, args, device)
    
    # Same directory name that train_model builds for its checkpoints.
    ckpt_dir = args.ckpt+'_'+args.arch+'_'+args.model+'_'+args.optim+'_ul%.3f'%(args.unlabel_percent)
    ckpt_dir = ckpt_dir + '_e%d'%(args.epochs)
    
    # Always define the file prefix; it used to be set only for 'baseline',
    # leaving it undefined for every other model choice.
    savebasename = ('Baseline' if args.model == 'baseline' else args.model) + '_' + str(date.today())
    fname1 = savebasename+'_'+str(runs)+'Loss'+'.png'
    fname2 = savebasename+'_'+str(runs)+'BA'+'.png'
    fname3 = savebasename+'_'+str(runs)+'LR'+'.png'
    fname4 = savebasename+'_'+str(runs)+'Val_CM'+'.png'
    
    my_dpi = 200
    # The histories hold one entry per epoch actually run, which is fewer than
    # args.epochs after a resume; build the x-axis from the history length.
    epoch_axis = range(args.start_epoch, args.start_epoch + len(losses_tr))
    
    ## Loss figure
    plt.figure(figsize=(5, 5), dpi=my_dpi)
    plt.semilogy(epoch_axis, losses_tr,label='Training Loss')
    plt.semilogy(epoch_axis, losses_val,label='Validation Loss')
    plt.legend()
    plt.xlabel('Epoch')
    plt.savefig(os.path.join(ckpt_dir, fname1), dpi=my_dpi * 10)
    plt.show()
    
    ## Balanced Accuracy figure
    plt.figure(figsize=(5, 5), dpi=my_dpi)
    plt.plot(epoch_axis, acc1_tr_balanced,label='Training Accuracy')
    plt.plot(epoch_axis, acc1_val_balanced,label='Validation Accuracy')
    plt.legend()
    plt.xlabel('Epoch')
    plt.savefig(os.path.join(ckpt_dir, fname2), dpi=my_dpi * 10)
    plt.show()
    
    ## Learning Rate figure
    plt.figure(figsize=(5, 5), dpi=my_dpi)
    plt.plot(epoch_axis, learning_rate)
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.savefig(os.path.join(ckpt_dir, fname3), dpi=my_dpi * 10)
    plt.show()
    
    ## Validation confusion matrix, normalised so that every row sums to 1
    confusionMat = np.asarray(val_confusionmatrix_best)
    percentconfusion_val = confusionMat / confusionMat.sum(axis=1, keepdims=True)
    
    print('Validation Report: \n', val_report_best)
    categories = ['Low','High']
    make_confusion_matrix(percentconfusion_val, 
                          categories=categories,
                          percent=False,
                          cbar=False,
                          figsize=(4 ,4),
                          cmap='Greens',my_dpi=200,title = 'Supervised Learning',
                          saved=True, save_name=os.path.join(ckpt_dir, fname4))

    del label_dataset