In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader , random_split,SubsetRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from tqdm import tqdm
import pickle
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
import warnings

In [2]:
exp = pd.read_csv("/home/data/sdb/wt/model_data/cell_media_pre.csv")

In [3]:
exp

Unnamed: 0,cell,media,TSPAN6,TNMD,FGR,CFH,FUCA2,GCLC,NIPAL3,ENPP4,...,CCL3,RDM1,SSTR3,ZNF229,PAGR1,TMEM265,IQCJ-SCHIP1,FAM95C,POLR2J3,CDR1
0,ACH-001339,1,3.150560,0.0,0.056584,1.310340,6.687201,3.682573,3.494416,0.790772,...,1.035624,1.948601,0.000000,0.536053,3.005400,0.454176,4.912650,0.000000,4.807355,0.042644
1,ACH-001538,0,5.085340,0.0,0.000000,5.868390,6.165309,4.489928,4.418865,3.485427,...,0.722466,2.759156,0.000000,2.094236,3.075533,4.426265,4.541019,0.604071,3.702658,0.000000
2,ACH-000327,1,3.337711,0.0,0.014355,3.090853,6.011451,3.642702,3.780310,3.280956,...,0.000000,1.014355,0.000000,1.871844,3.758090,0.084064,1.807355,1.669027,4.986411,0.250962
3,ACH-000233,1,0.056584,0.0,0.028569,6.093602,3.033863,3.422233,3.821710,3.207893,...,2.253989,0.056584,0.028569,0.014355,3.330558,1.028569,3.102658,2.448901,6.766330,0.000000
4,ACH-000461,1,4.017031,0.0,0.028569,0.084064,5.588565,6.380937,3.053111,2.392317,...,0.070389,2.792855,0.042644,0.014355,3.414136,4.792335,3.168321,1.104337,5.464995,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
905,ACH-000285,1,0.056584,0.0,0.411426,0.097611,0.704872,4.829850,3.326250,2.321928,...,1.070389,2.451541,0.028569,0.014355,3.835924,0.000000,0.111031,0.014355,6.325530,0.042644
906,ACH-002669,1,3.111031,0.0,0.014355,3.624101,6.805421,4.472488,3.665620,2.330558,...,1.490570,2.972693,0.028569,0.014355,2.408712,4.935460,4.458119,0.516015,5.073392,0.000000
907,ACH-001858,1,4.390943,0.0,0.028569,3.286881,6.902194,5.410748,4.047015,2.757023,...,0.000000,0.887525,0.014355,1.111031,2.805292,3.238787,4.646739,0.111031,4.486714,0.526069
908,ACH-001997,1,5.057450,0.0,0.028569,4.079805,6.971659,4.469886,4.933100,3.275007,...,0.263034,1.879706,0.000000,1.275007,3.547203,4.915999,4.104337,0.799087,4.595146,0.000000


In [4]:
class media_dt(Dataset):
    def __init__(self, exp_dt):
        self.exp = exp_dt
    def __len__(self):
        return len(self.exp)
    def __getitem__(self, idx): 
        media = torch.tensor(self.exp.loc[idx,].media)
        exp = torch.tensor(list(self.exp.loc[idx,][2:7995]))
        cell = self.exp.loc[idx,].cell
        return exp, media, cell

In [5]:
dt = media_dt(exp)

In [6]:
from sklearn.model_selection import KFold
import random
splits = KFold(n_splits=5,shuffle=True,random_state=2024052001)

In [7]:
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(7993, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1)
        )

    def forward(self, x):
        x = self.encoder(x)
        return x

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_op = torch.nn.BCEWithLogitsLoss()

In [9]:
from tqdm import tqdm

def train():
    model.train()

    total_loss = 0
    step = 0 
    for _, data in enumerate(tqdm(train_loader)):
        input_x = data[0].to(device)
        ys = data[1].to(device)
        optimizer.zero_grad()
        pred = model(input_x.float())
        loss = loss_op(
            pred,
            ys.float().reshape(-1, 1),
        )
        total_loss += loss.item() 
        loss.backward()
        optimizer.step()
        step = step + 1
    return total_loss / step


@torch.no_grad()
def test():
    model.eval()

    ys, preds, preds_raw, cells = [], [], [], []
    total_loss = 0
    step = 0 
    for _, data in enumerate(tqdm(test_loader)):
        input_x = data[0].to(device)
        label = data[1].to(device)
        ys.append(label.cpu().detach().numpy())
        cells.append(data[2])
        out = model(input_x.float())
        loss = loss_op(
            out,
            label.float().reshape(-1, 1),
        )
        total_loss += loss.item() 
        step = step + 1
        preds.append(np.rint(torch.sigmoid(out).cpu().detach().numpy()))
        preds_raw.append(torch.sigmoid(out).cpu().detach().numpy())

    all_preds = np.concatenate(preds).ravel()
    all_labels = np.concatenate(ys).ravel()
    all_preds_raw = np.concatenate(preds_raw).ravel()
    all_cells = np.concatenate(cells).ravel()
    
    res = pd.DataFrame({"preds":all_preds,"preds_raw":all_preds_raw,"label":all_labels,"cells":all_cells})
    calculate_metrics(all_preds, all_labels, epoch, "test")
    return total_loss / step, res


def calculate_metrics(y_pred, y_true, epoch, type):
    print(f"\n Confusion matrix: \n {confusion_matrix(y_true, y_pred)}")
    print(f"F1 Score: {f1_score(y_true, y_pred)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")
    try:
        roc = roc_auc_score(y_true, y_pred)
        print(f"ROC AUC: {roc}")
    except:
        print(f"ROC AUC: notdefined")

In [10]:
for fold, (train_idx,val_idx) in enumerate(splits.split(np.arange(len(exp)))):

    print('Fold {}'.format(fold + 1))
    
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(val_idx)
    train_loader = DataLoader(dt, batch_size=16, sampler=train_sampler, num_workers = 12)
    test_loader = DataLoader(dt, batch_size=16, sampler=test_sampler, num_workers = 12)
    
    model = Net().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(15):
        loss_train = train()
        print(f"Epoch: {epoch:03d}, Train Loss: {loss_train:.4f}")
        
    loss_test, test_pre = test()
    print(f"Epoch: {epoch:03d}, Test Loss: {loss_test:.4f}")
    test_pre.to_csv("/home/data/sdb/wt/model_data/cv_fold_"+str(fold)+"_media.csv")

Fold 1


100%|██████████| 46/46 [00:01<00:00, 26.05it/s]


Epoch: 000, Train Loss: 1.1085


100%|██████████| 46/46 [00:01<00:00, 34.36it/s]


Epoch: 001, Train Loss: 0.4642


100%|██████████| 46/46 [00:01<00:00, 31.49it/s]


Epoch: 002, Train Loss: 0.4445


100%|██████████| 46/46 [00:01<00:00, 29.28it/s]


Epoch: 003, Train Loss: 0.4175


100%|██████████| 46/46 [00:01<00:00, 33.14it/s]


Epoch: 004, Train Loss: 0.3943


100%|██████████| 46/46 [00:01<00:00, 26.34it/s]


Epoch: 005, Train Loss: 0.4024


100%|██████████| 46/46 [00:01<00:00, 28.37it/s]


Epoch: 006, Train Loss: 0.3489


100%|██████████| 46/46 [00:01<00:00, 36.52it/s]


Epoch: 007, Train Loss: 0.3004


100%|██████████| 46/46 [00:01<00:00, 24.15it/s]


Epoch: 008, Train Loss: 0.2182


100%|██████████| 46/46 [00:01<00:00, 28.78it/s]


Epoch: 009, Train Loss: 0.2349


100%|██████████| 46/46 [00:01<00:00, 29.86it/s]


Epoch: 010, Train Loss: 0.1409


100%|██████████| 46/46 [00:01<00:00, 26.97it/s]


Epoch: 011, Train Loss: 0.0713


100%|██████████| 46/46 [00:01<00:00, 34.39it/s]


Epoch: 012, Train Loss: 0.0941


100%|██████████| 46/46 [00:01<00:00, 31.12it/s]


Epoch: 013, Train Loss: 0.0842


100%|██████████| 46/46 [00:01<00:00, 30.83it/s]


Epoch: 014, Train Loss: 0.0309


100%|██████████| 12/12 [00:00<00:00, 18.36it/s]



 Confusion matrix: 
 [[ 15  32]
 [  1 134]]
F1 Score: 0.8903654485049833
Accuracy: 0.8186813186813187
Precision: 0.8072289156626506
Recall: 0.9925925925925926
ROC AUC: 0.6558707643814027
Epoch: 014, Test Loss: 1.0759
Fold 2


100%|██████████| 46/46 [00:01<00:00, 29.19it/s]


Epoch: 000, Train Loss: 1.1654


100%|██████████| 46/46 [00:01<00:00, 27.02it/s]


Epoch: 001, Train Loss: 0.4347


100%|██████████| 46/46 [00:01<00:00, 23.27it/s]


Epoch: 002, Train Loss: 0.4429


100%|██████████| 46/46 [00:01<00:00, 25.38it/s]


Epoch: 003, Train Loss: 0.4408


100%|██████████| 46/46 [00:01<00:00, 24.07it/s]


Epoch: 004, Train Loss: 0.3452


100%|██████████| 46/46 [00:01<00:00, 32.29it/s]


Epoch: 005, Train Loss: 0.3471


100%|██████████| 46/46 [00:01<00:00, 26.38it/s]


Epoch: 006, Train Loss: 0.2124


100%|██████████| 46/46 [00:01<00:00, 26.08it/s]


Epoch: 007, Train Loss: 0.2495


100%|██████████| 46/46 [00:01<00:00, 32.82it/s]


Epoch: 008, Train Loss: 0.1586


100%|██████████| 46/46 [00:01<00:00, 29.90it/s]


Epoch: 009, Train Loss: 0.1616


100%|██████████| 46/46 [00:01<00:00, 27.10it/s]


Epoch: 010, Train Loss: 0.2282


100%|██████████| 46/46 [00:01<00:00, 29.67it/s]


Epoch: 011, Train Loss: 0.0647


100%|██████████| 46/46 [00:01<00:00, 31.82it/s]


Epoch: 012, Train Loss: 0.0491


100%|██████████| 46/46 [00:01<00:00, 32.72it/s]


Epoch: 013, Train Loss: 0.2275


100%|██████████| 46/46 [00:01<00:00, 30.63it/s]


Epoch: 014, Train Loss: 0.0556


100%|██████████| 12/12 [00:00<00:00, 15.33it/s]



 Confusion matrix: 
 [[ 27  17]
 [ 22 116]]
F1 Score: 0.8560885608856088
Accuracy: 0.7857142857142857
Precision: 0.8721804511278195
Recall: 0.8405797101449275
ROC AUC: 0.7271080368906456
Epoch: 014, Test Loss: 0.8016
Fold 3


100%|██████████| 46/46 [00:01<00:00, 29.10it/s]


Epoch: 000, Train Loss: 0.8934


100%|██████████| 46/46 [00:01<00:00, 24.94it/s]


Epoch: 001, Train Loss: 0.4515


100%|██████████| 46/46 [00:01<00:00, 28.67it/s]


Epoch: 002, Train Loss: 0.4305


100%|██████████| 46/46 [00:01<00:00, 29.15it/s]


Epoch: 003, Train Loss: 0.3843


100%|██████████| 46/46 [00:01<00:00, 27.43it/s]


Epoch: 004, Train Loss: 0.2910


100%|██████████| 46/46 [00:01<00:00, 31.84it/s]


Epoch: 005, Train Loss: 0.2865


100%|██████████| 46/46 [00:01<00:00, 26.05it/s]


Epoch: 006, Train Loss: 0.2339


100%|██████████| 46/46 [00:01<00:00, 27.82it/s]


Epoch: 007, Train Loss: 0.1741


100%|██████████| 46/46 [00:01<00:00, 26.90it/s]


Epoch: 008, Train Loss: 0.1894


100%|██████████| 46/46 [00:01<00:00, 30.94it/s]


Epoch: 009, Train Loss: 0.1293


100%|██████████| 46/46 [00:01<00:00, 30.04it/s]


Epoch: 010, Train Loss: 0.2195


100%|██████████| 46/46 [00:01<00:00, 31.72it/s]


Epoch: 011, Train Loss: 0.0617


100%|██████████| 46/46 [00:01<00:00, 35.22it/s]


Epoch: 012, Train Loss: 0.0653


100%|██████████| 46/46 [00:01<00:00, 34.43it/s]


Epoch: 013, Train Loss: 0.1628


100%|██████████| 46/46 [00:01<00:00, 28.95it/s]


Epoch: 014, Train Loss: 0.0578


100%|██████████| 12/12 [00:00<00:00, 22.36it/s]



 Confusion matrix: 
 [[ 17  24]
 [ 12 129]]
F1 Score: 0.8775510204081632
Accuracy: 0.8021978021978022
Precision: 0.8431372549019608
Recall: 0.9148936170212766
ROC AUC: 0.66476388168137
Epoch: 014, Test Loss: 0.9159
Fold 4


100%|██████████| 46/46 [00:01<00:00, 34.16it/s]


Epoch: 000, Train Loss: 1.3426


100%|██████████| 46/46 [00:01<00:00, 23.23it/s]


Epoch: 001, Train Loss: 0.5037


100%|██████████| 46/46 [00:01<00:00, 27.84it/s]


Epoch: 002, Train Loss: 0.4390


100%|██████████| 46/46 [00:02<00:00, 22.70it/s]


Epoch: 003, Train Loss: 0.4119


100%|██████████| 46/46 [00:01<00:00, 31.15it/s]


Epoch: 004, Train Loss: 0.4319


100%|██████████| 46/46 [00:01<00:00, 29.18it/s]


Epoch: 005, Train Loss: 0.3798


100%|██████████| 46/46 [00:01<00:00, 30.66it/s]


Epoch: 006, Train Loss: 0.3770


100%|██████████| 46/46 [00:01<00:00, 24.92it/s]


Epoch: 007, Train Loss: 0.2964


100%|██████████| 46/46 [00:01<00:00, 35.17it/s]


Epoch: 008, Train Loss: 0.2614


100%|██████████| 46/46 [00:02<00:00, 22.43it/s]


Epoch: 009, Train Loss: 0.1723


100%|██████████| 46/46 [00:01<00:00, 26.05it/s]


Epoch: 010, Train Loss: 0.2342


100%|██████████| 46/46 [00:01<00:00, 34.59it/s]


Epoch: 011, Train Loss: 0.1150


100%|██████████| 46/46 [00:01<00:00, 30.40it/s]


Epoch: 012, Train Loss: 0.0684


100%|██████████| 46/46 [00:01<00:00, 28.56it/s]


Epoch: 013, Train Loss: 0.1840


100%|██████████| 46/46 [00:01<00:00, 33.58it/s]


Epoch: 014, Train Loss: 0.1482


100%|██████████| 12/12 [00:00<00:00, 23.03it/s]



 Confusion matrix: 
 [[ 22  21]
 [  5 134]]
F1 Score: 0.91156462585034
Accuracy: 0.8571428571428571
Precision: 0.864516129032258
Recall: 0.9640287769784173
ROC AUC: 0.7378283419775807
Epoch: 014, Test Loss: 0.4493
Fold 5


100%|██████████| 46/46 [00:01<00:00, 40.69it/s]


Epoch: 000, Train Loss: 0.8783


100%|██████████| 46/46 [00:01<00:00, 24.46it/s]


Epoch: 001, Train Loss: 0.4434


100%|██████████| 46/46 [00:01<00:00, 30.20it/s]


Epoch: 002, Train Loss: 0.4471


100%|██████████| 46/46 [00:01<00:00, 35.22it/s]


Epoch: 003, Train Loss: 0.3919


100%|██████████| 46/46 [00:01<00:00, 28.33it/s]


Epoch: 004, Train Loss: 0.3843


100%|██████████| 46/46 [00:01<00:00, 28.75it/s]


Epoch: 005, Train Loss: 0.3111


100%|██████████| 46/46 [00:01<00:00, 37.88it/s]


Epoch: 006, Train Loss: 0.3643


100%|██████████| 46/46 [00:01<00:00, 31.59it/s]


Epoch: 007, Train Loss: 0.2324


100%|██████████| 46/46 [00:01<00:00, 26.98it/s]


Epoch: 008, Train Loss: 0.1974


100%|██████████| 46/46 [00:01<00:00, 32.43it/s]


Epoch: 009, Train Loss: 0.1699


100%|██████████| 46/46 [00:01<00:00, 31.36it/s]


Epoch: 010, Train Loss: 0.1822


100%|██████████| 46/46 [00:01<00:00, 31.84it/s]


Epoch: 011, Train Loss: 0.1609


100%|██████████| 46/46 [00:01<00:00, 23.61it/s]


Epoch: 012, Train Loss: 0.0567


100%|██████████| 46/46 [00:01<00:00, 25.51it/s]


Epoch: 013, Train Loss: 0.0883


100%|██████████| 46/46 [00:01<00:00, 31.42it/s]


Epoch: 014, Train Loss: 0.0764


100%|██████████| 12/12 [00:00<00:00, 19.79it/s]


 Confusion matrix: 
 [[ 28  18]
 [ 26 110]]
F1 Score: 0.8333333333333333
Accuracy: 0.7582417582417582
Precision: 0.859375
Recall: 0.8088235294117647
ROC AUC: 0.7087595907928389
Epoch: 014, Test Loss: 0.6983



