In [1]:
# Import necessary libraries
import h5py
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
# Set the random seed for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fce25fd1330>

In [32]:
# Data Preprocessing and HDF5 Creation
# Define nucleotide mapping
nucleotides = ['A', 'C', 'G', 'T']
token2idx = {nuc: idx + 2 for idx, nuc in enumerate(nucleotides)}  # 'A':2, 'C':3, 'G':4, 'T':5
token2idx['PAD'] = 0  # Padding token
token2idx['MASK'] = 1  # Mask token
idx2token = {idx: nuc for nuc, idx in token2idx.items()}
vocab_size = len(token2idx)  # Should be 6

In [33]:
# Function to encode sequences
def encode_sequence_to_array(seq):
    seq = seq.upper()
    encoded_seq = [token2idx.get(nuc, token2idx['PAD']) for nuc in seq]  # Map unknown nucleotides to 'PAD'
    return np.array(encoded_seq, dtype=np.uint8)

In [54]:
# Function to encode sequences
def encode_sequence(seq, max_length):
    seq = seq.upper()
    seq_encoded = [token2idx.get(nuc, token2idx['PAD']) for nuc in seq]
    # Pad sequences
    if len(seq_encoded) < max_length:
        seq_encoded += [token2idx['PAD']] * (max_length - len(seq_encoded))
    else:
        seq_encoded = seq_encoded[:max_length]
    return seq_encoded

In [34]:
# Paths to input and output files
input_file = '/mnt/f/hprc/segments_b.hdf5'

In [35]:
# Split probabilities
train_prob = 0.8
val_prob = 0.1
test_prob = 0.1
assert train_prob + val_prob + test_prob == 1.0, "Split probabilities must sum to 1"

In [36]:
# Custom Dataset and DataLoader
# Function to mask sequences for MLM
def mask_sequence(seq_encoded, mask_prob=0.15):
    input_ids = seq_encoded.copy()
    labels = [-100] * len(seq_encoded)  # Initialize labels with -100 (ignore index)
    for i in range(len(seq_encoded)):
        if seq_encoded[i] == token2idx['PAD']:
            continue  # Skip padding tokens
        if random.random() < mask_prob:
            labels[i] = seq_encoded[i]  # Save the original token id for loss calculation
            input_ids[i] = token2idx['MASK']  # Replace with [MASK] token
    return input_ids, labels

In [37]:
# Custom Dataset class
class HDF5Dataset(Dataset):
    def __init__(self, hdf5_file_path, split='train', mask_prob=0.15, max_seq_len=512):
        self.hdf5_file_path = hdf5_file_path
        self.split = split
        self.mask_prob = mask_prob
        self.max_seq_len = max_seq_len
        self.hdf5_file = None  # Will be opened lazily in __getitem__

    def __len__(self):
        if self.hdf5_file is None:
            with h5py.File(self.hdf5_file_path, 'r') as hdf5_file:
                self.length = len(hdf5_file[f'{self.split}_sequences'])
        return self.length

    def __getitem__(self, idx):
        if self.hdf5_file is None:
            # Each worker opens its own file handle
            self.hdf5_file = h5py.File(self.hdf5_file_path, 'r')
            self.dataset = self.hdf5_file[f'{self.split}_sequences']
        
        seq_encoded = self.dataset[idx].tolist()
        
        # Truncate the sequence to max_seq_len
        if len(seq_encoded) > self.max_seq_len:
            seq_encoded = seq_encoded[:self.max_seq_len]
        
        input_ids, labels = mask_sequence(seq_encoded, mask_prob=self.mask_prob)
        return input_ids, labels

    def __del__(self):
        if self.hdf5_file is not None:
            self.hdf5_file.close()

In [None]:
class DNADataset(Dataset):
    def __init__(self, hdf5_file_path, chunk_size=512, overlap=0, vocab_size=4):
        self.hdf5_file_path = hdf5_file_path
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.vocab_size = vocab_size
        self.hdf5_file = h5py.File(self.hdf5_file_path, 'r')
        self.sequences = self.hdf5_file['sequences']
        self.chunks = self._create_chunks()

    def __len__(self):
        if self.hdf5_file is None:
            with h5py.File(self.hdf5_file_path, 'r') as hdf5_file:
                self.length = len(hdf5_file[f'{self.split}_sequences'])
        return self.length

    def __getitem__(self, idx):
        seq_idx, start, end = self.chunks[idx]
        seq = self.sequences[seq_idx][start:end]
        seq_tensor = torch.from_numpy(seq).long()  # Shape: (seq_len,)

        # Convert to one-hot encoding
        one_hot_seq = F.one_hot(seq_tensor, num_classes=self.vocab_size).float()  # Shape: (seq_len, vocab_size)
        return one_hot_seq

    def __del__(self):
        self.hdf5_file.close()

In [38]:
max_seq_len = 512

In [39]:
# Custom collate function for variable-length sequences
def collate_fn(batch):
    input_ids = [torch.tensor(item[0], dtype=torch.long) for item in batch]
    labels = [torch.tensor(item[1], dtype=torch.long) for item in batch]
    
    # Truncate sequences to max_seq_len
    input_ids = [seq[:max_seq_len] for seq in input_ids]
    labels = [lbl[:max_seq_len] for lbl in labels]
    
    # Pad sequences to the maximum length in the batch (which will be <= max_seq_len)
    input_ids_padded = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=token2idx['PAD']
    )
    labels_padded = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-100  # -100 is the ignore index for loss
    )
    return input_ids_padded, labels_padded


In [40]:
# Create datasets for each split
train_dataset = HDF5Dataset(hdf5_file_path=output_file, split='train', mask_prob=0.15)
val_dataset = HDF5Dataset(hdf5_file_path=output_file, split='val', mask_prob=0.15)

In [41]:
# Create DataLoaders
#train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, collate_fn=collate_fn)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,  # Further reduce if necessary
    shuffle=True,
    num_workers=2,  # Reduce from 4 to 2 or 1
    collate_fn=collate_fn,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
    persistent_workers=True
)

In [42]:
# Model Definition
# Positional Encoding class
class PositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

    def forward(self, x):
        seq_len = x.size(1)
        device = x.device

        position = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1)  # (seq_len, 1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float, device=device) * 
                             (-np.log(10000.0) / self.d_model))

        pe = torch.zeros(seq_len, self.d_model, device=device)  # (seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # Apply sine to even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Apply cosine to odd indices

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        x = x + pe
        return x


In [68]:
# Transformer Encoder Model
class TransformerEncoderModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, dff, num_layers, dropout_rate=0.1):
        super(TransformerEncoderModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=token2idx['PAD'])
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, 
                                                   dim_feedforward=dff, dropout=dropout_rate)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, src):
        # src: (batch_size, seq_len)
        src_key_padding_mask = src == token2idx['PAD']  # (batch_size, seq_len)
        x = self.embedding(src) * np.sqrt(self.embedding.embedding_dim)  # (batch_size, seq_len, d_model)
        x = self.pos_encoder(x)
        x = x.transpose(0, 1)  # (seq_len, batch_size, d_model)
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        x = x.transpose(0, 1)  # (batch_size, seq_len, d_model)
        logits = self.fc_out(x)  # (batch_size, seq_len, vocab_size)
        return logits


In [44]:
# Model Training

In [45]:
# Hyperparameters
d_model = 64
num_heads = 4
dff = 256
num_layers = 2
dropout_rate = 0.1
max_seq_len = 520  # Adjust as needed

In [46]:
# Instantiate the model
model = TransformerEncoderModel(
    vocab_size=vocab_size, 
    d_model=d_model, 
    num_heads=num_heads,
    dff=dff, 
    num_layers=num_layers, 
    dropout_rate=dropout_rate)



In [66]:
# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

TransformerEncoderModel(
  (embedding): Embedding(6, 64, padding_idx=0)
  (pos_encoder): PositionalEncoding()
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc_out): Linear(in_features=64, out_features=6, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [48]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [49]:
# Training loop
epochs = 10

In [50]:
for epoch in range(epochs):
    # Training Phase
    model.train()
    total_train_loss = 0
    #for batch_input_ids, batch_labels in train_loader:
    for batch_input_ids, batch_labels in tqdm(train_loader):
        batch_input_ids = batch_input_ids.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        outputs = model(batch_input_ids)  # (batch_size, seq_len, vocab_size)
        loss = criterion(outputs.view(-1, vocab_size), batch_labels.view(-1))
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    print(f'Epoch {epoch+1}/{epochs}, Training Loss: {avg_train_loss:.4f}')

    # Validation Phase
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch_input_ids, batch_labels in val_loader:
            batch_input_ids = batch_input_ids.to(device)
            batch_labels = batch_labels.to(device)

            outputs = model(batch_input_ids)
            loss = criterion(outputs.view(-1, vocab_size), batch_labels.view(-1))
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    print(f'Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss:.4f}')


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:19<00:00, 129.14it/s]

Epoch 1/10, Training Loss: 1.3578





Epoch 1/10, Validation Loss: 1.3443


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:16<00:00, 149.75it/s]


Epoch 2/10, Training Loss: 1.3485
Epoch 2/10, Validation Loss: 1.3440


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:18<00:00, 137.15it/s]


Epoch 3/10, Training Loss: 1.3477
Epoch 3/10, Validation Loss: 1.3339


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:17<00:00, 140.64it/s]


Epoch 4/10, Training Loss: 1.3464
Epoch 4/10, Validation Loss: 1.3328


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:20<00:00, 123.53it/s]


Epoch 5/10, Training Loss: 1.3455
Epoch 5/10, Validation Loss: 1.3354


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:20<00:00, 123.27it/s]


Epoch 6/10, Training Loss: 1.3431
Epoch 6/10, Validation Loss: 1.3431


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:16<00:00, 150.82it/s]


Epoch 7/10, Training Loss: 1.3431
Epoch 7/10, Validation Loss: 1.3354


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:20<00:00, 124.29it/s]


Epoch 8/10, Training Loss: 1.3455
Epoch 8/10, Validation Loss: 1.3324


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:18<00:00, 138.62it/s]


Epoch 9/10, Training Loss: 1.3443
Epoch 9/10, Validation Loss: nan


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2506/2506 [00:20<00:00, 124.91it/s]


Epoch 10/10, Training Loss: 1.3416
Epoch 10/10, Validation Loss: 1.3395


In [51]:
# for epoch in range(epochs):
#     # Training Phase
#     model.train()
#     total_train_loss = 0
#     #for batch_input_ids, batch_labels in train_loader:
#     for batch_input_ids, batch_labels in tqdm(train_loader):
#         batch_input_ids = batch_input_ids.to(device)
#         batch_labels = batch_labels.to(device)

#         optimizer.zero_grad()
#         outputs = model(batch_input_ids)  # (batch_size, seq_len, vocab_size)
#         loss = criterion(outputs.view(-1, vocab_size), batch_labels.view(-1))
#         loss.backward()
#         optimizer.step()

#         total_train_loss += loss.item()

#     avg_train_loss = total_train_loss / len(train_loader)
#     print(f'Epoch {epoch+1}/{epochs}, Training Loss: {avg_train_loss:.4f}')

#     # Validation Phase
#     model.eval()
#     total_val_loss = 0
#     with torch.no_grad():
#         for batch_input_ids, batch_labels in val_loader:
#             batch_input_ids = batch_input_ids.to(device)
#             batch_labels = batch_labels.to(device)

#             outputs = model(batch_input_ids)
#             loss = criterion(outputs.view(-1, vocab_size), batch_labels.view(-1))
#             total_val_loss += loss.item()

#     avg_val_loss = total_val_loss / len(val_loader)
#     print(f'Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss:.4f}')

In [69]:
def sequence_to_feature_vector(sequence):
    model.eval()

    # Encode sequence and convert to tensor
    seq_encoded = encode_sequence(sequence, max_seq_len)
    seq_tensor = torch.tensor([seq_encoded], dtype=torch.long).to(device)  # Move to specified device

    with torch.no_grad():
        # Get embeddings
        batch_size, seq_len = seq_tensor.size()

        # Apply embeddings and positional encoding
        x = model.embedding(seq_tensor) * np.sqrt(model.embedding.embedding_dim)  # Embedding scaling
        x = model.pos_encoder(x)  # Apply positional encoding

        # Apply transformer encoder layers
        x = x.transpose(0, 1)  # (seq_len, batch_size, d_model)
        x = model.encoder(x)
        x = x.transpose(0, 1)  # (batch_size, seq_len, d_model)

        # Pooling to get feature vector
        feature_vector = x.mean(dim=1).squeeze(0).cpu().numpy()  # Move to CPU before converting to numpy
    return feature_vector


In [70]:
# Example usage
sequence = 'ACGTGCTAGC'
feature_vector = sequence_to_feature_vector(sequence)
print(f"Feature Vector (shape {feature_vector.shape}):\n{feature_vector}")

Feature Vector (shape (64,)):
[ 5.8099587e-02 -6.1848026e-02  6.0076017e-02 -4.3027827e-01
 -2.8560886e-01  3.1843016e-01  4.0122762e-02  5.3647733e-01
 -1.0901927e+00  1.6983964e-02  2.1897623e-02 -3.3709139e-02
 -5.7947643e-02 -4.0864423e-03  1.1620720e-02  3.5534676e-02
  4.3646872e-02  2.5381869e-02  3.2397246e-01 -1.0952059e+00
  2.0940056e-01 -4.1635954e-03 -1.0789967e-02  8.6933456e-02
 -2.3620518e-01 -4.8311031e-04  2.1125045e-02  4.4094608e-03
  3.2754294e-03 -2.8062798e-02 -1.9266700e-02  4.1266456e-02
 -6.1162031e-01 -5.6737955e-03 -9.3305456e-03  1.2559577e-02
 -1.8635046e-02  8.6001754e-03  9.2126913e-02 -1.1084046e-01
 -4.7628932e-02  2.2788800e-01  8.9139994e-03 -1.2620354e-01
  1.1468509e-01 -1.6168781e-01 -1.0524358e-02 -1.6406955e-01
 -3.4234583e-02 -3.6349267e-02  2.7449679e-02  7.0067993e-03
  1.5577413e-01 -1.5072285e-02 -1.5130562e-02 -2.5287632e-02
 -5.3113955e-03  5.4835291e-03 -2.5302169e-03  2.5189964e-02
 -4.1460890e-02 -9.7081564e-02  7.1518809e-02 -1.127862