In [11]:
from Bio import SeqIO

def read_and_pad_fasta(file_path, max_length=512):
    sequences = []
    for record in SeqIO.parse(file_path, "fasta"):
        seq = str(record.seq)
        # 填充到指定长度
        if len(seq) < max_length:
            seq += 'N' * (max_length - len(seq))
        sequences.append(seq[:max_length])  # 确保不超过max_length
    return sequences

# Path is relative to the kernel's working directory.
FASTA_PATH = "../../../data/5utr_95.fasta"

sequences = read_and_pad_fasta(FASTA_PATH)
assert sequences, f"no FASTA records parsed from {FASTA_PATH}"

sequences[0]

'CCACGCGTCCGGGAAGCGAGCGGCTGAGTTGCTGCGGGGAAAAAATAAAAAATAAATAAAAAGCCAAATTAGTTGTGTCTTGCGGGAAGTGGAAGCCTCTGGTTGTTGTGTCTGCGGTTAAACAGCGCTTCTATTCGCGGCTTGTCTTGTTCCCAACTGTGTGAGATTTGCTAGTGACCCGGCCTGTGTACTCCCCTGCCAGGCATACATANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN'

In [12]:
import numpy as np

def one_hot_encode(sequences, max_length=512, alphabet="ATCGN"):
    """One-hot encode nucleotide strings into a channels-first int8 array.

    Parameters
    ----------
    sequences : sequence of str
        Nucleotide strings. Matching against ``alphabet`` is case-insensitive,
        so lowercase (e.g. soft-masked) bases are encoded like uppercase ones.
    max_length : int, default 512
        Number of positions in the output. Longer sequences are truncated;
        positions past the end of a shorter sequence stay all-zero columns.
    alphabet : str, default "ATCGN"
        Characters to encode; the i-th character maps to channel i. The
        default gives the channel order A=0, T=1, C=2, G=3, N=4.

    Returns
    -------
    numpy.ndarray
        int8 array of shape ``(len(sequences), len(alphabet), max_length)``.
        Characters not in ``alphabet`` produce an all-zero column.
    """
    channel_of = {base.upper(): idx for idx, base in enumerate(alphabet)}
    # Allocate channels-first directly so the result is contiguous (no transpose view).
    encoded = np.zeros((len(sequences), len(alphabet), max_length), dtype=np.int8)

    for i, seq in enumerate(sequences):
        # Truncate first: positions >= max_length would otherwise raise IndexError.
        for j, base in enumerate(seq[:max_length]):
            channel = channel_of.get(base.upper())
            if channel is not None:
                encoded[i, channel, j] = 1

    return encoded

one_hot_sequences = one_hot_encode(sequences)
# Sanity check: one entry per FASTA record, each (5 channels, 512 positions).
# This also catches drift between the 512 defaults of the reader and the encoder.
assert one_hot_sequences.shape == (len(sequences), 5, 512), one_hot_sequences.shape

In [13]:
# Inspect the first encoded sequence: rows are channels, columns are positions.
first_encoded = one_hot_sequences[0]
print(first_encoded)
print(first_encoded.shape)

[[0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [1 1 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 1 1]]
(5, 512)


In [14]:
import torch
from torch.utils.data import DataLoader, random_split, TensorDataset

OUTPUT_PATH = '5utr_95.pt'

# float32 copy of the one-hot array; presumably the dtype the downstream
# model expects — TODO confirm before switching to a smaller dtype.
sequences_tensor = torch.tensor(one_hot_sequences, dtype=torch.float32)
dataset = TensorDataset(sequences_tensor)

# NOTE(review): the split sizes are not used in this cell; confirm a later
# cell consumes them, otherwise remove them (and the random_split import).
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # remaining 20% for validation

torch.save(sequences_tensor, OUTPUT_PATH)

print(f"data saved in {OUTPUT_PATH}")

data saved in utr_95.pt


In [15]:
# example usage
# Example usage: reload the saved tensor and inspect one mini-batch.
# weights_only=True restricts unpickling to tensors and primitive types,
# so the file cannot execute arbitrary code on load.
loaded_sequences = torch.load('5utr_95.pt', weights_only=True)

# A plain tensor is indexable along dim 0, so DataLoader can batch it directly.
dataloader = DataLoader(loaded_sequences, batch_size=32)

batch = next(iter(dataloader))
print(batch.shape)  # (batch_size, 5, 512): channels-first, as saved

torch.Size([32, 5, 512])


In [16]:
import numpy as np


# TODO(cleanup): scratch cell — saves a small subset of the data for quick
# debugging runs; delete once no longer needed.
# WARNING: this re-binds sequences_tensor / dataset / train_size / val_size,
# replacing the full-dataset values created by the earlier save cell.
import torch
from torch.utils.data import DataLoader, random_split, TensorDataset

N_DEBUG_SAMPLES = 1000
DEBUG_OUTPUT_PATH = '5utr_95_tmp.pt'

sequences_tensor = torch.tensor(one_hot_sequences[:N_DEBUG_SAMPLES], dtype=torch.float32)
dataset = TensorDataset(sequences_tensor)

train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # remaining 20% for validation

torch.save(sequences_tensor, DEBUG_OUTPUT_PATH)

print(f"data saved in {DEBUG_OUTPUT_PATH}")

data saved in utr_95_tmp.pt
