Code for Urinary Bladder Cancer Screening with Electronic Noses Based on Few-Shot Contrastive Representation Learning and Open-Set Recognition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Optional, Tuple
from enum import IntEnum
import numpy as np
import pandas as pd
import torch.utils.data as Data
import torch.optim as optim
from torch.autograd import Variable
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, \
    confusion_matrix
import matplotlib.pyplot as plt
from keras.utils import to_categorical
from sklearn.metrics import roc_curve
import libmr
from sklearn.preprocessing import StandardScaler, label_binarize
import os
from torchsummary import summary
import torch.optim as optim
import torch.utils.data as Data
import torch.nn.functional as F

# Select the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device is {device}")

In [None]:
# Supervised contrastive learning (SuCL)

class SuCL(nn.Module):
    """Supervised contrastive loss over a batch of embeddings.

    For every ordered pair (i, j), i != j, whose labels match (a positive
    pair) the term

        -log( e_ij / (e_ij + sum_{k : label_k != label_i} e_ik) ),
        e_ab = exp(cos(z_a, z_b) / temperature)

    is computed; the loss is the mean of these terms over all positive pairs.

    Args:
        temperature: similarity temperature. Stored as a buffer so it moves
            with the module on ``.to(device)``.
    """

    def __init__(self, temperature):
        super(SuCL, self).__init__()
        self.register_buffer("temperature", torch.tensor(temperature))

    def forward(self, emb_i, label):
        """Return the scalar loss for one batch.

        Args:
            emb_i: (n, d) embeddings.
            label: n integer class labels, shape (n,) or (n, 1).

        Returns:
            0-dim tensor. It is exactly 0 (never NaN) when the batch has no
            positive pairs, or when every sample shares one label (no
            negatives), because each positive term is then -log(1).
        """
        # Keep masks on the embeddings' device instead of a global device.
        label = label.reshape(-1).to(emb_i.device)
        n = emb_i.shape[0]
        z = F.normalize(emb_i, dim=1)
        # For unit vectors cosine similarity is the dot product; the matmul
        # avoids materialising an (n, n, d) broadcast tensor.
        exp_sim = torch.exp((z @ z.t()) / self.temperature)

        same = label.unsqueeze(0) == label.unsqueeze(1)
        eye = torch.eye(n, dtype=torch.bool, device=emb_i.device)
        pos_mask = same & ~eye
        # Row i: sum of exp-similarities to samples with a different label.
        neg_sum = (exp_sim * (~same)).sum(dim=1, keepdim=True)

        # exp_sim > 0, so the denominator is never zero.
        ratio = exp_sim / (exp_sim + neg_sum)
        # Non-positive entries are set to 1 so their -log contributes exactly 0.
        ratio = torch.where(pos_mask, ratio, torch.ones_like(ratio))
        per_pair = -torch.log(ratio)
        num_pos = pos_mask.sum().clamp(min=1)
        return per_pair.sum() / num_pos

In [None]:
## Deep residual neural network (ResNet)

class ResidualBlock(nn.Module):
    """1-D basic residual block: conv-BN-ReLU-conv-BN plus a skip path.

    Args:
        inchannel: channels of the input.
        outchannel: channels produced by both 3-wide convolutions.
        stride: stride of the first convolution.
        shortcut: optional module applied to the input on the skip path.
            When None the input is added unchanged, so it must already have
            the main path's output shape.
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super().__init__()
        main_path = [
            nn.Conv1d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm1d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv1d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm1d(outchannel),
        ]
        self.left = nn.Sequential(*main_path)
        self.right = shortcut

    def forward(self, x):
        """Return relu(main_path(x) + skip(x))."""
        main = self.left(x)
        skip = x if self.right is None else self.right(x)
        return F.relu(main + skip)


class ResNet(nn.Module):
    """1-D ResNet: conv stem, four residual stages, global average pool, linear head.

    The input is taken as (batch, length, channels) and permuted to
    (batch, channels, length) for the 1-D convolutions.

    Args:
        blocks: number of ResidualBlocks in each of the four stages.
        num_classes: output size of the final linear layer.
        in_channels: channel count of the input (its last axis). Defaults to
            6, the value previously hard-coded in the stem.

    ``forward`` returns ``(logits, features)``: the linear-layer output of
    shape (batch, num_classes) and the pooled 512-d vector it was computed
    from.
    """

    def __init__(self, blocks, num_classes=1000, in_channels=6):
        super(ResNet, self).__init__()
        # NOTE: fixed string, set regardless of ``blocks`` (Resnet18 too).
        self.model_name = 'resnet34'

        # Stem: stride-2 conv followed by stride-2 max-pool.
        self.pre = nn.Sequential(
            nn.Conv1d(in_channels, 64, 7, 2, 3, bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(3, 2, 1))

        self.layer1 = self._make_layer(64, 64, blocks[0])
        self.layer2 = self._make_layer(64, 128, blocks[1], stride=2)
        self.layer3 = self._make_layer(128, 256, blocks[2], stride=2)
        self.layer4 = self._make_layer(256, 512, blocks[3], stride=2)

        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        """Build one stage of ``block_num`` residual blocks.

        Only the first block changes channels/stride. It always receives a
        1x1 projection shortcut (conv + BN + ReLU), even when shapes already
        match (layer1). That differs from the textbook ResNet, but changing it
        would alter the parameter layout stored in the saved checkpoints.
        """
        shortcut = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm1d(outchannel),
            nn.ReLU()
        )

        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        layers.extend(ResidualBlock(outchannel, outchannel) for _ in range(1, block_num))
        return nn.Sequential(*layers)

    def forward(self, x):
        # (batch, length, channels) -> (batch, channels, length)
        x = x.permute(0, 2, 1)
        x = self.pre(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pool over the length axis. The functional form avoids
        # constructing a new nn.AdaptiveAvgPool1d module on every call.
        x = F.adaptive_avg_pool1d(x, 1)
        x = x.view(x.size(0), -1)
        return self.fc(x), x
    

def Resnet18(num_class=10):
    """Build the 18-layer variant: two residual blocks in each of the four stages."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(stage_depths, num_classes=num_class)


def Resnet34(num_class=10):
    """Build the 34-layer variant: 3, 4, 6 and 3 residual blocks per stage."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(stage_depths, num_classes=num_class)

In [None]:
from datas import load_data

LOADPATH = "./data/"
# Keep 6 of the 13 recorded sensors (presumably indices along the sensor axis
# of the raw samples — confirm against datas.load_data).
picked_sensors = [4, 6, 7, 9, 11, 12]

# Shapes per the original notes: raw samples (n, 65, 13) -> selected (n, 65, 6).
# NOTE(review): labels are noted as (n,), yet the training loop calls
# y.squeeze(dim=1), which needs (n, 1) — TODO confirm.
(train_dataset, val_dataset, test_dataset,
 open_test_dataset, weibull_dataset) = load_data(LOADPATH, picked_sensors)

Batch_size = 128
# Batched closed-set loaders; only the training data is shuffled.
train_loader, valid_loader, test_loader = (
    Data.DataLoader(dataset=ds, batch_size=Batch_size, shuffle=do_shuffle)
    for ds, do_shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)
# Open-set test and Weibull-fitting data are consumed one sample at a time.
open_test_loader, weibull_loader = (
    Data.DataLoader(dataset=ds, batch_size=1, shuffle=False)
    for ds in (open_test_dataset, weibull_dataset)
)

print("created data loader successfully...")

In [None]:
# utils
def get_confusion_matrix(trues, preds, labels=None):
    """Return the scikit-learn confusion matrix of ``trues`` vs. ``preds``.

    Rows are true labels and columns are predicted labels (sklearn convention).

    Args:
        trues: ground-truth labels.
        preds: predicted labels, same length as ``trues``.
        labels: optional explicit label list/order forwarded to
            ``sklearn.metrics.confusion_matrix``. When None (default) sklearn
            uses the sorted union of labels present in ``trues`` and
            ``preds``, so a class that never occurs gets no row/column.

    Returns:
        Square integer ndarray of counts.
    """
    return confusion_matrix(trues, preds, labels=labels)


def plot_confusion_matrix(conf_matrix, save_path='heatmap_confusion_matrix_18.jpg'):
    """Draw, save and show a heat-map of a (rows=true, cols=pred) confusion matrix.

    The matrix is transposed for display so true labels run along the x-axis
    and predicted labels along the y-axis; each cell is annotated with its count.

    Args:
        conf_matrix: square ndarray of counts.
        save_path: image file written via ``plt.savefig``.
    """
    plt.imshow(conf_matrix.T, cmap=plt.cm.Greens)
    # Tick labels must match the matrix size; a hard-coded 6 made xticks raise
    # whenever a class was absent and the matrix was smaller.
    ticks = list(range(conf_matrix.shape[0]))

    plt.xticks(ticks, ticks)
    plt.yticks(ticks, ticks)
    plt.colorbar()
    plt.xlabel('y_true')
    plt.ylabel('y_pred')

    for true_idx in range(conf_matrix.shape[0]):
        for pred_idx in range(conf_matrix.shape[1]):
            plt.text(true_idx, pred_idx, conf_matrix[true_idx, pred_idx])
    plt.savefig(save_path)
    plt.show()


def plot_confusion_matrix_open(conf_matrix, save_path='heatmap_confusion_matrix_18_open.jpg'):
    """Draw, save and show a heat-map of the open-set confusion matrix.

    The matrix (rows=true, cols=pred) is transposed for display so true labels
    run along the x-axis and predicted labels along the y-axis; each cell is
    annotated with its count.

    Args:
        conf_matrix: square ndarray of counts.
        save_path: image file written via ``plt.savefig``.
    """
    plt.imshow(conf_matrix.T, cmap=plt.cm.Greens)
    ticks = list(range(conf_matrix.shape[0]))

    plt.xticks(ticks, ticks)
    plt.yticks(ticks, ticks)
    plt.colorbar()
    plt.xlabel('y_true')
    plt.ylabel('y_pred')

    for true_idx in range(conf_matrix.shape[0]):
        for pred_idx in range(conf_matrix.shape[1]):
            plt.text(true_idx, pred_idx, conf_matrix[true_idx, pred_idx])
    plt.savefig(save_path)
    plt.show()


In [None]:
training = False
load_model = True
load_model_epoch = 700
save_model = True
save_each_epoch = 100
EPOCH = 0
test = True  # also evaluate on the closed-set test split every epoch

# hyperparameters: weights of the cross-entropy and supervised-contrastive terms
w_ce = 1
w_cl = 0.5


def _closed_set_metrics(trues, preds):
    """Return (accuracy, precision, recall, f1); the last three micro-averaged."""
    return (
        accuracy_score(trues, preds),
        precision_score(trues, preds, average='micro'),
        recall_score(trues, preds, average='micro'),
        f1_score(trues, preds, average='micro'),
    )


def _evaluate_closed_set(net, loader, ce_criterion, cl_criterion):
    """Evaluate ``net`` on ``loader`` with gradients disabled.

    Returns (total_loss, preds, trues): the sum over batches of
    ``w_ce * CE + w_cl * SuCL`` as a Python float, the argmax predictions,
    and the true labels as yielded by the loader (converted to numpy).
    """
    total, preds, trues = 0.0, [], []
    with torch.no_grad():
        for xb, yb in loader:
            # .float() matches the input dtype used in the training loop.
            logits, feats = net(xb.float().to(device))
            yb = yb.long().to(device)
            target = yb.squeeze(dim=1)
            batch_loss = w_ce * ce_criterion(logits, target) + w_cl * cl_criterion(feats, target)
            total += batch_loss.item()
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            trues.extend(yb.cpu().numpy())
    return total, preds, trues


if training:
    model = Resnet34(num_class=6).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001)
    loss_function = nn.CrossEntropyLoss()
    scl_loss = SuCL(temperature=0.7)

    if load_model:
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        checkpoint = torch.load('./checkpoints/model_checkpoint_' + str(load_model_epoch) + '.tar',
                                map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        load_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print('load model over ... ')
        start_epoch = load_epoch + 1
    else:
        start_epoch = 0

    print('start training ... ')
    hist_loss = np.zeros(EPOCH)
    hist_loss_val = np.zeros(EPOCH)
    hist_loss_test = np.zeros(EPOCH)

    train_dict = {}
    valid_dict = {}
    test_dict = {}

    for epoch in range(start_epoch, start_epoch + EPOCH):
        # Evaluation below switches to eval mode; restore train mode every epoch
        # (previously this only happened when the test branch ran).
        model.train()
        tol_loss = 0.0
        train_preds = []
        train_trues = []
        for x, y in train_loader:
            batch_x = x.float().to(device)
            batch_y = y.long().to(device)
            target = batch_y.squeeze(dim=1)
            output, outscl = model(batch_x)

            loss_ce = loss_function(output, target)
            loss_cl = scl_loss(outscl, target)
            loss = w_ce * loss_ce + w_cl * loss_cl

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tol_loss += loss.item()
            train_preds.extend(output.argmax(dim=1).detach().cpu().numpy())
            train_trues.extend(batch_y.detach().cpu().numpy())

        sklearn_accuracy, sklearn_precision, sklearn_recall, sklearn_f1 = \
            _closed_set_metrics(train_trues, train_preds)
        hist_loss[epoch - start_epoch] = tol_loss
        print(
            "[sklearn_metrics] Epoch:{} loss:{:.4f} accuracy:{:.4f} precision:{:.4f} recall:{:.4f} f1:{:.4f}".format(
                epoch,
                tol_loss,
                sklearn_accuracy,
                sklearn_precision,
                sklearn_recall,
                sklearn_f1))

        model.eval()
        tol_val_loss, val_pres, val_trues = _evaluate_closed_set(model, valid_loader, loss_function, scl_loss)
        sklearn_val_accuracy, sklearn_val_precision, sklearn_val_recall, sklearn_val_f1 = \
            _closed_set_metrics(val_trues, val_pres)
        hist_loss_val[epoch - start_epoch] = tol_val_loss
        print(
            "[sklearn_val_metrics] Epoch:{} val_loss:{:.4f} val_accuracy:{:.4f} val_precision:{:.4f} val_recall:{:.4f} val_f1:{:.4f}"
                .format(epoch, tol_val_loss, sklearn_val_accuracy, sklearn_val_precision, sklearn_val_recall,
                        sklearn_val_f1))

        # close set test
        if test:
            tol_test_loss, test_pres, test_trues = _evaluate_closed_set(model, test_loader, loss_function, scl_loss)
            sklearn_test_accuracy, sklearn_test_precision, sklearn_test_recall, sklearn_test_f1 = \
                _closed_set_metrics(test_trues, test_pres)
            hist_loss_test[epoch - start_epoch] = tol_test_loss
            print(
                "[sklearn_closeset_test_metrics] Epoch:{} test_loss:{:.4f} test_accuracy:{:.4f} test_precision:{:.4f} test_recall:{:.4f} test_f1:{:.4f}"
                    .format(epoch, tol_test_loss, sklearn_test_accuracy, sklearn_test_precision, sklearn_test_recall,
                            sklearn_test_f1))

            print("print confusuionMatrix of closeSet classification:\n")
            # Plot the *test* predictions (previously the training ones were used here).
            conf_matrix = get_confusion_matrix(test_trues, test_pres)
            plot_confusion_matrix(conf_matrix)

        if epoch % save_each_epoch == 0 and save_model:
            os.makedirs('./checkpoints', exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': tol_loss
            }, './checkpoints/model_checkpoint_' + str(epoch) + '.tar'
            )
    plt.plot(hist_loss, 'o-', label='train')
    plt.plot(hist_loss_val, 'o-', label='val_loss')
    plt.legend()
    plt.savefig('loss_18.jpg')
    plt.show()


In [None]:
# open set test
testing = True
start_ind = 0
ts_percent = 0.1  # fraction of each class's distances used as the Weibull tail
ts_limit = 15     # lower bound on the tail size
thresd = 0.8      # minimum softmax probability to accept a known-class prediction

# decide checkpoints which need to test
start_ckeckpoint = 700
end_checkpoint = 701
checkpoint_step = 100


@torch.no_grad()
def _fit_weibull_models(net, loader, n_classes, tail_percent, tail_min):
    """Fit one libmr Weibull model per known class.

    Every sample whose prediction matches its label is stored under its true
    class; only the first element of each batch is checked, so batch_size=1
    is assumed. Per class, the mean logit vector is the class centre, and
    ``MR.fit_high`` is fitted to the Euclidean distances to that centre.

    Returns:
        (centers, models): per-class centre tensors and fitted ``libmr.MR`` objects.

    Raises:
        RuntimeError: if a class has no correctly classified sample.
    """
    per_class = [[] for _ in range(n_classes)]
    for x, y in loader:
        output, _ = net(x.float().to(device))
        pred = output.argmax(dim=1).cpu()
        yy = y.reshape(-1).int()
        if pred.eq(yy)[0]:
            per_class[int(yy[0].item())].append(output)

    centers, models = [], []
    for i, vecs in enumerate(per_class):
        if not vecs:
            raise RuntimeError(f'class {i} has no correctly classified Weibull samples')
        class_va = torch.cat(vecs)
        class_mean = torch.mean(class_va, dim=0)
        centers.append(class_mean)
        dists = torch.sqrt(torch.pow(class_va - class_mean, 2).sum(dim=1))
        dists = torch.sort(dists).values.double().cpu().tolist()
        # At least tail_min points, but never more than the data available:
        # libmr would otherwise index outside the buffer it is given.
        tail_size = min(max(int(len(dists) * tail_percent), tail_min), len(dists))
        mr = libmr.MR()
        mr.fit_high(dists, tail_size)
        models.append(mr)
    return centers, models


@torch.no_grad()
def _open_set_predict(net, loader, centers, models, threshold):
    """Score each sample in ``loader`` (batch_size=1 assumed) with recalibrated logits.

    Logit i is scaled by (1 - w_i), where w_i is libmr's ``w_score`` of the
    sample's distance to centre i; the removed mass is summed into an extra
    "unknown" score appended last. The prediction is the argmax of these
    scores when their softmax maximum exceeds ``threshold``, otherwise the
    unknown label ``len(centers)``.

    Returns:
        (trues, preds, rows): true labels, predicted labels, and per-sample
        rows [scores..., softmax probs..., true label] for the CSV dump.
    """
    unknown_label = len(centers)
    trues, preds, rows = [], [], []
    for x, y in loader:
        output, _ = net(x.float().to(device))
        true_label = int(y.item())
        scores = []
        unknown_mass = 0.0
        for ii, cent in enumerate(centers):
            dist = torch.sqrt(torch.pow(output - cent, 2).sum(dim=1)).item()
            w = models[ii].w_score(dist)
            logit = output[0, ii].item()
            modified = logit * (1 - w)
            scores.append(modified)
            unknown_mass += logit - modified
        scores.append(unknown_mass)
        scores = torch.tensor(scores)
        probs = F.softmax(scores, dim=0)

        if torch.max(probs) > threshold:
            preds.append(scores.argmax().item())
        else:
            preds.append(unknown_label)
        trues.append(true_label)
        rows.append(np.append(scores.numpy(), np.append(probs.numpy(), true_label)))
    return trues, preds, rows


def _binary_screening_metrics(trues, preds, positive_label=5):
    """Binarize labels as ``positive_label`` vs. rest; return (f1, precision, recall, fpr, cm).

    ``cm`` is the 2x2 sklearn confusion matrix (rows = true, columns = predicted).
    ``labels=[0, 1]`` keeps it 2x2 even if one class is absent. Ratios with a
    zero denominator are reported as 0.0.
    """
    y_true = (np.asarray(trues) == positive_label).astype(int)
    y_pred = (np.asarray(preds) == positive_label).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    # Previously precision/recall were swapped and FPR was computed as FN/(TN+FN).
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    fpr = fp / (fp + tn) if fp + tn else 0.0
    return f1, precision, recall, fpr, cm


if testing:
    model = Resnet34(num_class=6).to(device)
    model.eval()

    result = []
    for cp in range(start_ckeckpoint, end_checkpoint, checkpoint_step):
        # map_location lets GPU-saved checkpoints be evaluated on a CPU-only machine.
        checkpoint = torch.load(f'./checkpoints/model_checkpoint_{cp}.tar', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        center, mr_all = _fit_weibull_models(model, weibull_loader, 6, ts_percent, ts_limit)
        test_trues, test_pres, pred_list = _open_set_predict(model, open_test_loader, center, mr_all, thresd)

        df = pd.DataFrame(np.array(pred_list))
        df.to_csv('OSCresult.csv', index=False, header=False)
        print("saved OSC result successfully... ")

        # Label 5 is the positive class (presumably the cancer class — confirm with the data).
        f1, precision, recall, fpr, conf_matrix = _binary_screening_metrics(test_trues, test_pres)
        print(f'f1: {f1 :.4f}, precision: {precision :.4f}, recall: {recall :.4f}, fpr: {fpr :.4f}')
        plot_confusion_matrix_open(conf_matrix)
        result.append((f1, precision, recall, fpr))