In [1]:
#utils
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
#preprocess
from sklearn.model_selection import train_test_split
import numpy as np
import os
import torch.nn.functional as F

import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import Levenshtein

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import os
import h5py
import numpy as np
from numpy import ndarray
from tqdm import tqdm
from itertools import islice as take
import copy

path = "/data/yuxin/eligos_IVT/pseudoU_single_fast5/"
output_directory = "/data/feiyang/output/"


def chunk_dataset(path, chunk_len, num_chunks=None):
    def all_chunks():
        for file_ in os.listdir(path):
            if not os.path.isdir(os.path.join(path, file_)):
                continue
            
            for file in os.listdir(os.path.join(path, file_)):
                if file.endswith('.fast5'):
                    path_ = os.path.join(path, file_)
                    with h5py.File(os.path.join(path_, file), 'r') as fast5_file:
                        data_path = fast5_file['Raw/Reads']
                        try:
                            for data in data_path:
                                if data.startswith("Read"):
                                    # Access the Signal dataset and convert to npy array
                                    signal = data_path[data]["Signal"][()]
                            for chunk, target in get_chunks(fast5_file, regular_break_points(len(signal), chunk_len)):
                                yield (chunk, target)
                        except KeyError:
                            continue

    all_chunks_gen = all_chunks()
    chunks, targets = zip(*tqdm(take(all_chunks_gen, num_chunks), total=num_chunks))
    targets, target_lens = pad_lengths(targets)  # convert refs from ragged arrray
    return ChunkDataSet(chunks, targets, target_lens)


def get_chunks(fast5_file, break_points):
    global array_start, array_target
    sample = scale(fast5_file)
    tmps = fast5_file["Analyses"].keys()
    Ref_to_signal = []
    Reference = []
    for tmp in tmps:
        if tmp.startswith('RawGenomeCorrected'):
            events = fast5_file["Analyses/" + tmp + "/BaseCalled_template/"]["Events"][()]
            for _, i in enumerate(events):
                Ref_to_signal.append(i[2])
                Reference.append(ACGT_2_num(i[4].decode()) + 1)

            array_start = np.stack(Ref_to_signal, axis=0)
            array_target = np.stack(Reference, axis=0)

    pointers = array_start
    target = array_target  # CTC convention
    return (
        (sample[i:j], target[ti:tj]) for (i, j), (ti, tj)
        in zip(break_points, np.searchsorted(pointers, break_points))
    )


def scale(fast5_file, normalise=True):
    """ scale and normalise a read """

    global scaled
    reads_group = fast5_file["Raw/Reads"]

    # Find the sample
    for group in reads_group:
        if group.startswith("Read"):
            # Access the Signal dataset and convert to npy array
            samples = reads_group[group]["Signal"][()]
            # scaled = (scaling * (samples + offset)).astype(np.float32)

    if normalise:
        tmps = fast5_file["Analyses"].keys()

        for tmp in tmps:

            if tmp.startswith('RawGenomeCorrected'):
                scale = fast5_file["Analyses/" + tmp + "/BaseCalled_template"].attrs["scale"]
                shift = fast5_file["Analyses/" + tmp + "/BaseCalled_template"].attrs["shift"]

                return (samples - shift) / scale

    return scaled


def ACGT_2_num(char):
    if char == 'A':
        return 0
    elif char == 'C':
        return 1
    elif char == 'G':
        return 2
    elif char == 'T':
        return 3
    else:
        raise ValueError('Invalid input')


# **********************************************************************************************************************************************************************
class ChunkDataSet:
    def __init__(self, chunks, targets, lengths):
        self.chunks = np.expand_dims(chunks, axis=1)
        self.targets = targets
        self.lengths = lengths

    def __getitem__(self, i):
        return (
            self.chunks[i].astype(np.float32),
            self.targets[i].astype(np.int64),
            self.lengths[i].astype(np.int64),
        )

    def __len__(self):
        return len(self.lengths)


def regular_break_points(n, chunk_len, overlap=0, align='mid'):
    num_chunks, remainder = divmod(n - overlap, chunk_len - overlap)
    start = {'left': 0, 'mid': remainder // 2, 'right': remainder}[align]
    starts = np.arange(start, start + num_chunks * (chunk_len - overlap), (chunk_len - overlap))
    return np.vstack([starts, starts + chunk_len]).T


def pad_lengths(ragged_array, max_len=None):
    lengths: ndarray = np.array([len(x) for x in ragged_array], dtype=np.uint16)
    padded = np.zeros((len(ragged_array), max_len or np.max(lengths)), dtype=ragged_array[0].dtype)
    for x, y in zip(ragged_array, padded):
        y[:len(x)] = x
    return padded, lengths


def typical_indices(x, n=2.5):
    mu, sd = np.mean(x), np.std(x)
    idx, = np.where((mu - n * sd < x) & (x < mu + n * sd))
    return idx


def filter_chunks(ds, idx):
    filtered = ChunkDataSet(ds.chunks.squeeze(1)[idx], ds.targets[idx], ds.lengths[idx])
    filtered.targets = filtered.targets[:, :filtered.lengths.max()]
    return filtered


def save_chunks(chunks, output_directory):
    a = chunks.chunks.squeeze(1)
    b = chunks.targets
    c = chunks.lengths
    indices = c != 0
    aa = np.compress(indices, a, axis=0)
    bb = np.compress(indices, b, axis=0)
    cc = np.compress(indices, c, axis=0)
    os.makedirs(output_directory, exist_ok=True)
    np.save(os.path.join(output_directory, "chunks.npy"), aa)
    np.save(os.path.join(output_directory, "references.npy"), bb)
    np.save(os.path.join(output_directory, "reference_lengths.npy"), cc)
    print()
    print("> data written to %s:" % output_directory)
    print("  - chunks.npy with shape", aa.shape)
    print("  - references.npy with shape", bb.shape)
    print("  - reference_lengths.npy shape", cc.shape)




training_chunks = chunk_dataset(path, 3600)
training_indices = typical_indices(training_chunks.lengths)
training_chunks = filter_chunks(training_chunks, np.random.permutation(training_indices))
save_chunks(training_chunks, output_directory)



print("ok")


115645it [1:24:15, 22.88it/s] 



> data written to /data/feiyang/output/:
  - chunks.npy with shape (55769, 3600)
  - references.npy with shape (55769, 186)
  - reference_lengths.npy shape (55769,)
ok


In [2]:
from sklearn.model_selection import train_test_split
import numpy as np
import os
import torch.nn.functional as F
import torch
chunks = np.load('/data/feiyang/output/chunks.npy')
chunks = (chunks - np.min(chunks)) / (np.max(chunks) - np.min(chunks))

references = np.load('/data/feiyang/output/references.npy')
reference_lengths = np.load('/data/feiyang/output/reference_lengths.npy')

#num_classes = 5  # 四个碱基类型和一个空白标签
#references = F.one_hot(torch.LongTensor(references), num_classes=num_classes).numpy()


train_chunks, valid_chunks, train_references, valid_references, train_reference_lengths, valid_reference_lengths = train_test_split(chunks, references, reference_lengths, test_size=0.1)
# valid_chunks, test_chunks, valid_references, test_references, valid_reference_lengths, test_reference_lengths = train_test_split(valid_chunks, valid_references, valid_reference_lengths, test_size=0.5, random_state=20)
output_directory="/data/feiyang/preprocessed/"
np.save(os.path.join(output_directory, "train_chunks.npy"), train_chunks)
np.save(os.path.join(output_directory, "train_references.npy"), train_references)
np.save(os.path.join(output_directory, "train_reference_lengths.npy"), train_reference_lengths)

np.save(os.path.join(output_directory, "valid_chunks.npy"), valid_chunks)
np.save(os.path.join(output_directory, "valid_references.npy"), valid_references)
np.save(os.path.join(output_directory, "valid_reference_lengths.npy"), valid_reference_lengths)

In [3]:
class TransformerLSTM(nn.Module):
    """
    Transformer-LSTM模型
    """
    def __init__(self, input_size, hidden_size, num_layers, num_heads, lstm_hidden_size, lstm_layers, lstm_dropout, output_size):
        super(TransformerLSTM, self).__init__()
        
        # Transformer编码器
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(self.transformer_encoder_layer, num_layers=num_layers)
        
        # LSTM解码器
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=lstm_hidden_size, num_layers=lstm_layers, dropout=lstm_dropout, bidirectional=True)
        
        # 全连接层
        self.fc = nn.Linear(lstm_hidden_size * 2, output_size)
    
    def forward(self, x):
        # Transformer编码器
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, input_size)
        x = self.transformer_encoder(x)
        
        # LSTM解码器
        x = x.permute(1, 0, 2)  # (batch_size, seq_len, input_size)
        x, _ = self.lstm(x)
        x = self.fc(x)
        
        return x.permute(0, 2, 1)  # (batch_size, output_size, seq_len)

class CNNLSTM(nn.Module):
    """
    (B, 1, S) -> (B, C, S_reducted)
    """
    def __init__(self, line, num_classes, stride, winlen,hidden_size,lstm_layers,lstm_dropout):
        super(CNNLSTM, self).__init__()
        
        self.conv1 = nn.Conv1d(line, 4, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(4, 16, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(16, num_classes, kernel_size=winlen, stride=stride, padding=winlen//2)
        self.lstm = nn.LSTM(
            input_size=num_classes,
            hidden_size=hidden_size,
            num_layers=lstm_layers,
            dropout=lstm_dropout,
            bidirectional=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.permute(2, 0, 1) #(S_reducted, B, C)
        S_reducted, B, _ = x.shape
        x, _n = self.lstm(x)# (S_reducted, B, 2 * H)
        x = x.view(S_reducted, B, 2, -1).sum(dim=2) # (S_reducted, B, H)
        x = self.fc(x)  #(S_reducted, B, C)
        return x.permute(1, 2, 0)#(B, C, S_reducted)
    
# class CNN(nn.Module):
#     def __init__(self, input_size, num_classes):
#         super(CNN, self).__init__()
        
#         self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
#         self.relu1 = nn.ReLU()
#         self.maxpool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        
#         self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
#         self.relu2 = nn.ReLU()
#         self.maxpool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        
#         self.conv3 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
#         self.relu3 = nn.ReLU()
#         self.maxpool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        
#         self.fc = nn.Linear(in_features=64 * (input_size // 8), out_features=num_classes)
        
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.relu1(x)
#         x = self.maxpool1(x)
        
#         x = self.conv2(x)
#         x = self.relu2(x)
#         x = self.maxpool2(x)
        
#         x = self.conv3(x)
#         x = self.relu3(x)
#         x = self.maxpool3(x)
        
#         x = x.view(x.size(0), -1)
#         x = self.fc(x)
        
#         return x


In [4]:
class RNADataset(Dataset):
    def __init__(self, chunks, references, reference_lengths):
        self.chunks = chunks
        self.references = references
        self.reference_lengths = reference_lengths

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, index):
        return self.chunks[index], self.references[index], self.reference_lengths[index]


In [5]:



def compute_ctc_loss(outputs, targets, target_lengths):
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    outputs = outputs.permute(2, 0, 1) #(S_reducted, B, C)
    S_reducted, B, C =outputs.shape
    outputs_lengths = torch.ones(B).type_as(outputs).int() * S_reducted
    
    
    loss = ctc_loss(outputs, targets,outputs_lengths, target_lengths)
    return loss

In [6]:
def train(batch_size,model, optimizer, criterion, dataloader, device):
    model.train()
    total_loss = 0.0
    for i, (inputs, targets, target_lengths) in enumerate(dataloader):
        inputs, targets,target_lengths = inputs.to(torch.float32).to(device), targets.to(device),target_lengths.int().to(device)
        #inputs torch.float32 torch.Size([B, 1, S])
        #targets torch.int64 torch.Size([B, base_seq_len])
        #target_lengths torch.float32  torch.Size([B])
        optimizer.zero_grad()
        outputs = model(inputs)
        #outputs torch.float32 torch.Size([B, C, S_reducted])
        loss = criterion(outputs, targets, target_lengths)
        if i%100==0:
            print(loss)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

class SeqErrorRate:
    """Using Levenshtein distance."""

    def __init__(self, ignore_tokens=None):
        if ignore_tokens is not None:
            self.ignore_tokens = set(ignore_tokens)
        else:
            self.ignore_tokens = set()

    def __call__(self, preds, targets):
        B = preds.shape[0]
        total_error = 0.0
        total_length = 0
        for i in range(B):
            pred = [p for p in preds[i].tolist() if p not in self.ignore_tokens]
            target = [t for t in targets[i].tolist() if t not in self.ignore_tokens]
            distance = Levenshtein.distance(''.join([chr(c) for c in pred]), ''.join([chr(c) for c in target]))
            error = distance / max(len(pred), len(target))
            total_error += error
            total_length += 1
        return total_error / total_length
def evaluate(batch_size,model, criterion, dataloader, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_err = 0
    true_labels = []
    pred_probs = []
    cal_err= SeqErrorRate(ignore_tokens=[0])
    with torch.no_grad():
        for i, (inputs, targets, target_lengths) in enumerate(dataloader):
            inputs, targets,target_lengths = inputs.to(torch.float32).to(device), targets.to(device),target_lengths.int().to(device)
            outputs = model(inputs)#(B, C, S_reducted)
            loss = criterion(outputs, targets, target_lengths)
            total_loss += loss.item()
#             print(outputs.shape, targets.shape)
            outputs = outputs.argmax(dim=1)
            error = cal_err(outputs, targets)
            total_err += error
            
            
    return total_loss / len(dataloader), total_err/ len(dataloader)

In [7]:
import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

batch_size = 32
input_size = 3600  # 输入信号段长度


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_chunks = np.load('/data/feiyang/preprocessed/train_chunks.npy')
train_references = np.load('/data/feiyang/preprocessed/train_references.npy')
train_reference_lengths = np.load('/data/feiyang/preprocessed/train_reference_lengths.npy')
train_reference_lengths =train_reference_lengths.astype(np.int32)

valid_chunks = np.load('/data/feiyang/preprocessed/valid_chunks.npy')
valid_references = np.load('/data/feiyang/preprocessed/valid_references.npy')
valid_reference_lengths = np.load('/data/feiyang/preprocessed/valid_reference_lengths.npy')
valid_reference_lengths =valid_reference_lengths.astype(np.int32)
train_dataset = RNADataset(torch.tensor(train_chunks).unsqueeze(1).to(torch.float32), torch.tensor(train_references).to(torch.int32), torch.tensor(train_reference_lengths))
#inputs targets target_lengths torch.Size([32, 1, 3600]) torch.float32 torch.Size([32, 130]) torch.int32 torch.Size([32]) torch.int32

valid_dataset = RNADataset(torch.tensor(valid_chunks).unsqueeze(1).to(torch.float32), torch.tensor(valid_references).to(torch.int32),
                           torch.tensor(valid_reference_lengths))
# test_dataset = RNADataset(torch.tensor(test_chunks), torch.tensor(test_references), torch.tensor(test_reference_lengths))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)



In [7]:

# model=CNNLSTM(line=line,num_classes=num_classes,stride=5,winlen=19,
#               hidden_size=hidden_size,lstm_layers=lstm_layers,lstm_dropout=lstm_dropout).to(device)


# for i, (inputs, targets, target_lengths) in enumerate(train_loader):
#     print(inputs.shape,inputs.dtype, targets.shape,targets.dtype, target_lengths.shape,target_lengths.dtype)
#     inputs, targets,target_lengths = inputs.to(device), targets.to(device),target_lengths.int().to(device)
#     print(target_lengths.dtype)
    
#     outputs = model(inputs) 
    
#     outputs = outputs.permute(2, 0, 1) #(S_reducted, B, C)
#     S_reducted, B, C =outputs.shape
#     outputs_lengths = torch.ones(B).type_as(outputs).int() * S_reducted
    
#     print(outputs_lengths.shape,outputs_lengths.dtype)
#     loss = ctc_loss(outputs, targets,outputs_lengths, target_lengths)
#     print(loss)
#     break

In [8]:
# Model

num_epochs = 10
line=1
num_classes = 5  # 类别数（四个碱基和一个空白标签）
input_size = 4  # 输入大小（A，C，G，T四个碱基）
hidden_size = 512  # 隐状态大小
lstm_layers = 2  # LSTM层数
lstm_dropout=0.2
learning_rate = 1e-4

model=CNNLSTM(line=line,num_classes=num_classes,stride=5,winlen=19,
              hidden_size=hidden_size,lstm_layers=lstm_layers,lstm_dropout=lstm_dropout).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=1, verbose=True)


model_dir = './model'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'clstm_model.pth')


if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print(f'Loaded model from {model_path}')

# Train

for epoch in range(num_epochs):
    train_loss = train(batch_size, model, optimizer, compute_ctc_loss, train_loader, device)

    valid_loss, valid_error = evaluate(batch_size, model, compute_ctc_loss, valid_loader, device)
    scheduler.step(valid_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.3f}, Valid Loss: {valid_loss:.3f}, Valid Err: {valid_error:.3f}')

    # 保存模型
    torch.save(model.state_dict(), model_path)
    print(f'Saved model to {model_path}')
print('ok')

Loaded model from ./model/clstm_model.pth
tensor(2.6122, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.6930, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.4994, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.5452, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.7493, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.4969, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.6958, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.5014, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.6815, device='cuda:0', grad_fn=<MeanBackward0>)
