In [None]:
# 读取数据
import pickle
import os
import numpy as np

# Directory holding one pickle file per sample.
path = '/home/zw/Data/cfDNA_data/PRJNA574555process/CRC'

# Sorted so the sample order (and therefore every KFold split downstream) is
# reproducible: os.listdir returns entries in an arbitrary, filesystem-dependent order.
# endswith('.pkl') instead of a substring test, so e.g. 'x.pkl.bak' is not picked up.
name = sorted(
    os.path.join(path, f_path)
    for f_path in os.listdir(path)
    if f_path.endswith('.pkl')
)
if not name:
    raise FileNotFoundError(f"no .pkl files found in {path}")

seq = np.zeros((len(name), 1500, 4, 75), dtype='int')
label = []
for i, file_name in enumerate(name):
    # NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
    with open(file_name, 'rb') as f:
        data = pickle.load(f)
    # data[0] is written into a (1500, 4, 75) int slot and data[1] is collected as the label;
    # assumes data[1] is a scalar class label — TODO confirm against the preprocessing script.
    seq[i] = data[0]
    label.append(data[1])
labels = np.array(label, dtype='float')

## CLHCC model

In [None]:
import torch  
import torch.nn as nn  
import torch.nn.functional as F  

class CNN_LSTM(nn.Module):
    """Conv2d feature extractor followed by a 4-layer LSTM and a sigmoid binary head.

    Args:
        input_channels: number of channels of the input tensor, used as the
            Conv2d ``in_channels`` (this notebook passes 1500).
        hidden_size: hidden size of every LSTM layer.

    ``forward(x)`` takes ``x`` of shape (batch, input_channels, H, W), where H
    must make the 4x4 / stride-2 convolution produce an output height of 1
    (e.g. H == 4 as for this notebook's data), and returns probabilities of
    shape (batch, 1).
    """

    def __init__(self, input_channels, hidden_size):
        super(CNN_LSTM, self).__init__()
        self.input_channels = input_channels

        # CNN: a single 4x4 / stride-2 convolution.
        # On a (B, C, 4, W) input it yields (B, 256, 1, (W - 4) // 2 + 1).
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=256,
                      kernel_size=(4, 4), stride=2, padding=0)
        )

        # LSTM runs along the conv output's width axis, with the 256 conv channels as features.
        self.lstm = nn.LSTM(256, hidden_size, num_layers=4, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.conv(x)                 # (B, 256, 1, W')
        # Squeeze only the height axis: a bare .squeeze() also removes the batch
        # axis when B == 1 (e.g. a final partial batch) and permute then fails.
        out = out.squeeze(2)               # (B, 256, W')
        steps = out.permute(0, 2, 1)       # (B, W', 256) for the batch_first LSTM
        lstm_out, _ = self.lstm(steps)
        last = lstm_out[:, -1, :]          # output at the final time step
        last = self.dropout(last)
        return self.sigmoid(self.fc(last))


model = CNN_LSTM(input_channels=1500, hidden_size=800)
print(model)

## train

In [None]:
class EarlyStopping:
    """Early-stop training when the validation loss stops improving.

    Each call compares a new validation loss with the best one seen so far.
    An improvement (a decrease of at least ``delta``) saves the model's
    state_dict to ``path`` and resets the counter; otherwise the counter is
    incremented, and ``early_stop`` becomes True once it reaches ``patience``.
    """

    def __init__(self, patience=20, verbose=False, delta=0, path='best_model.pth'):
        """
        Args:
            patience (int): How many calls without improvement to tolerate
                            before setting ``early_stop``.
                            Default: 20
            verbose (bool): If True, prints a message every time the validation
                            loss improves and a checkpoint is written.
                            Default: False
            delta (float): Minimum decrease in the monitored loss that counts
                           as an improvement.
                           Default: 0
            path (str): File the best model's state_dict is saved to.
                        Default: 'best_model.pth'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None    # stored as -loss, so that "higher is better"
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            # First call: whatever we got is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # No (sufficient) improvement.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience} '
                  f'(best validation loss {-self.best_score:.4f})')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save the model's state_dict to ``self.path`` and record the new best loss.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
from sklearn.model_selection import KFold
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR  
from torch.utils.data import DataLoader, TensorDataset  
from sklearn.model_selection import train_test_split  
from sklearn.metrics import confusion_matrix  
import seaborn as sns  
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report

# 10-fold cross-validation of CNN_LSTM on (seq, label) from the data-loading cell.
SEED = 42
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

kf = KFold(n_splits=10, shuffle=True, random_state=SEED)


def _make_loader(features, targets, shuffle, batch_size=128):
    """Wrap NumPy features/labels in a DataLoader (features -> float32, labels -> int64)."""
    dataset = TensorDataset(torch.from_numpy(features).float(),
                            torch.from_numpy(targets).long())
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass; return (sum of per-batch losses, accuracy in %)."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device).float()
        optimizer.zero_grad()
        # reshape(-1) rather than squeeze(): stays 1-D even for a batch of one.
        probs = model(inputs).reshape(-1)
        loss = criterion(probs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += ((probs > 0.5).float() == targets).sum().item()
        seen += targets.numel()
    return total_loss, (correct / seen * 100) if seen else 0.0


def _evaluate(model, loader, criterion, device):
    """Score `loader` without gradients.

    Returns (sum of per-batch losses, accuracy in %, y_true, y_pred), where the
    label arrays are 1-D int arrays.
    """
    model.eval()
    total_loss = 0.0
    true_batches, pred_batches = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device).float()
            probs = model(inputs).reshape(-1)
            total_loss += criterion(probs, targets).item()
            pred_batches.extend((probs > 0.5).long().cpu().numpy())
            true_batches.extend(targets.long().cpu().numpy())
    y_true = np.array(true_batches, dtype=int)
    y_pred = np.array(pred_batches, dtype=int)
    acc = float((y_true == y_pred).mean() * 100) if y_true.size else 0.0
    return total_loss, acc, y_true, y_pred


# Built once; the training loops below no longer shadow this name.
labels = np.array(label, dtype='float')

all_acc = []
matrix = []
ACC = []
PRE = []
RECALL = []
F1 = []
num_epochs = 100
for train_index, val_index in kf.split(seq):
    # Fancy indexing copies the selected samples, keeping seq's dtype.
    train_data, val_data = seq[train_index], seq[val_index]
    train_labels, val_labels = labels[train_index], labels[val_index]

    dataloader_train = _make_loader(train_data, train_labels, shuffle=True)
    # Kept at notebook level: the evaluation cells reuse the last fold's loader.
    dataloader_test = _make_loader(val_data, val_labels, shuffle=False)

    model = CNN_LSTM(input_channels=1500, hidden_size=800).to(device)
    criterion = nn.BCELoss()
    early_stopping = EarlyStopping(patience=20, verbose=True)
    initial_lr = 0.0007
    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = ExponentialLR(optimizer, gamma=0.9)

    print("train begin!")
    train_loss, train_acc, test_loss, test_acc = [], [], [], []
    best_acc = 0.0
    best_model_wts = None
    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _train_one_epoch(model, dataloader_train, criterion, optimizer, device)
        # Decay the LR after the epoch's optimizer steps; stepping it first
        # (as before) skipped the initial learning rate entirely.
        scheduler.step()
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

        val_loss, val_acc, _, _ = _evaluate(model, dataloader_test, criterion, device)
        print(f'the test data: Loss: {val_loss:.4f}, accuracy:{val_acc:.2f}%')
        test_loss.append(val_loss)
        test_acc.append(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            # state_dict() returns live tensors; clone so later updates don't overwrite the snapshot.
            best_model_wts = {k: v.detach().clone() for k, v in model.state_dict().items()}

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the lowest-validation-loss checkpoint written by EarlyStopping and
    # record this fold's metrics whether or not early stopping fired (previously
    # folds that ran all epochs were silently missing from ACC/PRE/RECALL/F1).
    # NOTE(review): this fold's validation set also drives early stopping, so
    # these metrics are optimistic; a held-out test split would be unbiased.
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    _, _, y_true, y_pred = _evaluate(model, dataloader_test, criterion, device)
    matrix.append(confusion_matrix(y_true, y_pred))
    ACC.append(accuracy_score(y_true, y_pred))
    PRE.append(precision_score(y_true, y_pred, average='binary'))
    RECALL.append(recall_score(y_true, y_pred, average='binary'))
    F1.append(f1_score(y_true, y_pred, average='binary'))

    torch.save(model.state_dict(), 'model.pth')
    all_acc.append(best_acc)

## result

In [None]:
# Cross-fold summary: mean and population variance of each per-fold metric.
for caption, fold_scores in (("acc均值:", ACC),
                             ("precision均值:", PRE),
                             ("recall均值:", RECALL),
                             ("f1均值:", F1)):
    data = np.array(fold_scores)
    mean = data.mean()
    variance = data.var()
    print(caption, mean)
    print("方差:", variance)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Evaluate the checkpoint written by the training cell ('model.pth') on
# `dataloader_test`, which the training loop leaves pointing at the last fold.
# NOTE(review): this covers only the final CV fold, not all ten.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN_LSTM(input_channels=1500, hidden_size=800)
weight = torch.load('./model.pth', map_location=device)
# Strict loading (the default): a key mismatch now raises instead of silently
# leaving layers randomly initialised.
model.load_state_dict(weight)
model = model.to(device)
model.eval()

y_pred_list = []
y_true_list = []
y_score = []  # predicted probabilities, used by the ROC cell

with torch.no_grad():
    for inputs, target in dataloader_test:
        # Model output is (batch, 1); flatten to one probability per sample so
        # sklearn receives 1-D arrays rather than column vectors.
        probs = model(inputs.to(device)).reshape(-1).cpu()
        y_score.extend(probs.numpy())
        y_pred_list.extend((probs > 0.5).long().numpy())
        y_true_list.extend(target.cpu().numpy())

cm = confusion_matrix(y_true_list, y_pred_list)
print("Confusion Matrix:")
print(cm)
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Truth')
ax.set_title('Confusion matrix')
fig.savefig('hot.png')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# ROC curve from the labels and probabilities collected by the evaluation cell.
fpr, tpr, thresholds = roc_curve(y_true_list, y_score)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots()
lw = 2
ax.plot(fpr, tpr, color='darkred',
        lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance diagonal
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('CLHCC in CRC ROC')
ax.legend(loc="lower right")
fig.savefig('roc.png')
plt.show()