# Train and test scBasset on positive vs. negative sequences

In [1]:
import os
import torch
import random
import numpy as np
import torch.nn.functional as F
from scbasset import scBasset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from tqdm import tqdm

import done!


In [2]:
# Fix the seeds of every RNG this notebook uses: Python's `random`, NumPy's
# legacy global RNG, and PyTorch (CPU, plus CUDA devices when present).
# Downstream randomness is repeatable only when cells are run in order.
SEED = 2023
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)


In [3]:
# Adam learning rate and the mini-batch size used by both DataLoaders.
lr, batch_size = 1e-3, 64

In [4]:
# Pick the compute device. The original run used GPU 3; override the index with
# the SCBASSET_GPU environment variable. A hard-coded ordinal that does not exist
# on this host would make `model.to(device)` fail, so fall back to cuda:0 in that
# case, and to the CPU when CUDA is unavailable.
_gpu_index = int(os.environ.get("SCBASSET_GPU", "3"))
if torch.cuda.is_available():
    if _gpu_index >= torch.cuda.device_count():
        _gpu_index = 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")

# Two output units, paired with CrossEntropyLoss for positive-vs-negative classification.
model = scBasset(n_cells=2)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

### Data processing

In [5]:
# Byte -> channel lookup: A=0, C=1, G=2, T=3. Every other byte (N, IUPAC
# ambiguity codes, '?', ...) maps to 4 and is left as an all-zero column.
_BASE_TO_INDEX = np.full(256, 4, dtype=np.int64)
_BASE_TO_INDEX[np.frombuffer(b"ACGT", dtype=np.uint8)] = np.arange(4)


def _read_sequences(path):
    """Return the sequence lines of a FASTA-like file, one sequence per line.

    Header lines starting with '>' (standard FASTA) or '<' (kept for backward
    compatibility with the previous loader), ';' comment lines and blank lines
    are skipped.
    NOTE(review): wrapped multi-line FASTA records are NOT joined; each
    remaining line is treated as its own sequence, as in the original loader.
    """
    sequences = []
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith((">", "<", ";")):
                sequences.append(line)
    return sequences


def load_data(pos_file, neg_file, seq_len=2000, progress=True):
    """Load positive/negative sequences and return one-hot encoded tensors.

    Parameters
    ----------
    pos_file, neg_file : str or path-like
        Files whose sequences get label 1 and label 0 respectively.
    seq_len : int, default 2000
        Each sequence is truncated to, or right-padded with all-zero columns
        to, exactly this many positions.
    progress : bool, default True
        Wrap the encoding loop in a tqdm progress bar.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        float32 one-hot tensor of shape (n, 4, seq_len) with channel order
        A, C, G, T, and int64 labels of shape (n,). All positives come first,
        then all negatives.
    """
    pos_seqs = _read_sequences(pos_file)
    neg_seqs = _read_sequences(neg_file)
    all_seqs = pos_seqs + neg_seqs
    labels = torch.tensor([1] * len(pos_seqs) + [0] * len(neg_seqs), dtype=torch.long)

    # Allocate the output once. Stacking per-sequence tensors would hold two
    # full copies of the data in memory at the same time.
    onehot = np.zeros((len(all_seqs), 4, seq_len), dtype=np.float32)
    positions = np.arange(seq_len)
    iterator = tqdm(all_seqs) if progress else all_seqs
    for row, seq in enumerate(iterator):
        # Upper-case so soft-masked (lower-case) bases are not treated as N.
        # errors="replace" keeps one byte per character, so positions stay aligned.
        raw = seq[:seq_len].upper().encode("ascii", errors="replace")
        codes = _BASE_TO_INDEX[np.frombuffer(raw, dtype=np.uint8)]
        known = codes < 4
        onehot[row, codes[known], positions[: len(codes)][known]] = 1.0

    return torch.from_numpy(onehot), labels


device(type='cuda', index=3)

In [6]:
# load_data(pos_file, neg_file): positives must come first, or every label is inverted.
features, labels = load_data('data/positives.fa', 'data/negatives.fa')
# load_data returns a (features, labels) tuple. Wrap it so random_split and
# DataLoader see one (x, y) pair per sequence instead of a 2-element tuple.
dataset = TensorDataset(features, labels)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# Use a dedicated, seeded generator so re-running this cell reproduces the same
# split. Drawing from the global RNG again would give a new split, and the
# "test" set would then contain samples the model has already trained on.
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(2023)
)

In [7]:
# Both loaders share batching/worker settings; only the training loader shuffles.
_loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
# Number of batches per loader.
len(train_loader), len(test_loader)

(32985, 8247)

## train and test

In [8]:
import matplotlib.pyplot as plt
from train import test, train

# Each epoch calls train() on the training loader, then test() on the held-out
# loader. The return value of each call (defined by the external `train`
# module) is kept per epoch in train_res / test_res.
# NOTE(review): the held-out split is evaluated after every epoch and is the
# only evaluation set, so it doubles as a validation set here.
num_epochs = 5
train_res, test_res = [], []
for epoch in tqdm(range(num_epochs)):
    train_res.append(train(model, train_loader, optimizer, criterion, device))
    test_res.append(test(model, test_loader, criterion, device))


  0%|                                                                                | 0/5 [00:00<?, ?it/s]

Train Loss: 0.4690, Acc: 0.8355, Precision: 0.8425, Recall: 0.8355, F1: 0.8359, AUC: 0.8390



 20%|█████████████▌                                                      | 1/5 [18:19<1:13:18, 1099.60s/it]

Test Loss: 0.4568, Acc: 0.8483, Precision: 0.8567, Recall: 0.8483, F1: 0.8478, AUC: 0.8484
Train Loss: 0.4429, Acc: 0.8637, Precision: 0.8697, Recall: 0.8637, F1: 0.8640, AUC: 0.8671



 40%|████████████████████████████                                          | 2/5 [36:36<54:54, 1098.26s/it]

Test Loss: 0.4405, Acc: 0.8662, Precision: 0.8713, Recall: 0.8662, F1: 0.8660, AUC: 0.8662
Train Loss: 0.4342, Acc: 0.8731, Precision: 0.8787, Recall: 0.8731, F1: 0.8734, AUC: 0.8763



 60%|██████████████████████████████████████████                            | 3/5 [54:48<36:30, 1095.16s/it]

Test Loss: 0.4321, Acc: 0.8750, Precision: 0.8784, Recall: 0.8750, F1: 0.8750, AUC: 0.8750
Train Loss: 0.4276, Acc: 0.8803, Precision: 0.8856, Recall: 0.8803, F1: 0.8806, AUC: 0.8834



 80%|██████████████████████████████████████████████████████▍             | 4/5 [1:12:52<18:10, 1090.73s/it]

Test Loss: 0.4309, Acc: 0.8761, Precision: 0.8807, Recall: 0.8761, F1: 0.8760, AUC: 0.8761
Train Loss: 0.4215, Acc: 0.8870, Precision: 0.8919, Recall: 0.8870, F1: 0.8873, AUC: 0.8899


100%|████████████████████████████████████████████████████████████████████| 5/5 [1:30:52<00:00, 1090.49s/it]

Test Loss: 0.4282, Acc: 0.8792, Precision: 0.8826, Recall: 0.8792, F1: 0.8792, AUC: 0.8793





In [9]:
# Pickle the whole nn.Module, not just its state_dict. Reloading it requires the
# scBasset class to be importable and, on recent PyTorch releases,
# torch.load(..., weights_only=False).
# NOTE(review): torch.save(model.state_dict(), ...) would be a more portable checkpoint.
torch.save(obj=model, f='epoch5_20230419.pth')