In [None]:
# Trained on variable length
import random
from functools import partial
from typing import Iterator, List, Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def set_seed(seed=13):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with ``seed``.

    CUDA generators are seeded as well, but only when a GPU is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(13)

def rand_sequence(n_seqs: int, min_len: int = 50, max_len: int = 200) -> Iterator[List[int]]:
    """Lazily yield ``n_seqs`` random integer-encoded DNA sequences.

    Each sequence is a ``list`` of ints. Its length is drawn uniformly from
    ``[min_len, max_len]`` (inclusive), and every element is drawn uniformly
    from ``0..4``. These are the indices of the 5-letter ``'NACGT'`` alphabet
    used by the int/DNA lookup maps. All draws come from Python's global
    ``random`` generator, so seeding it (e.g. via ``set_seed``) makes the
    output reproducible.

    Raises:
        ValueError: on first iteration, if ``min_len > max_len``
            (raised by ``random.randint``).
    """
    for _ in range(n_seqs):
        # Draw the length first, then the bases: this fixed draw order is what
        # keeps seeded runs reproducible.
        seq_len = random.randint(min_len, max_len)
        yield [random.randint(0, 4) for _ in range(seq_len)]

def count_cpgs(seq: str) -> int:
    """Count occurrences of the dinucleotide "CG" in ``seq``.

    Every start position is checked, so e.g. "CGCG" counts 2.
    """
    total = 0
    for start in range(len(seq) - 1):
        if seq[start:start + 2] == "CG":
            total += 1
    return total

# Symbol table: code 0 is 'N', then A, C, G, T (codes 1..4).
alphabet = 'NACGT'
int2dna = dict(enumerate(alphabet))
dna2int = {base: code for code, base in int2dna.items()}
# Lazy converters: each call returns a `map` iterator; symbols missing from
# the table come out as None (dict.get), not as an error.
intseq_to_dnaseq = partial(map, int2dna.get)
dnaseq_to_intseq = partial(map, dna2int.get)

def prepare_data(num_samples=100, min_len=50, max_len=200):
    """Generate random integer sequences and their CpG counts.

    Returns:
        ``(int_seqs, cpg_counts)``: a list of ``num_samples`` integer-encoded
        sequences, and for each one the number of "CG" dinucleotides in its
        decoded DNA string.
    """
    int_seqs = list(rand_sequence(num_samples, min_len, max_len))
    cpg_counts = [count_cpgs("".join(intseq_to_dnaseq(s))) for s in int_seqs]
    return int_seqs, cpg_counts

def one_hot_encode(sequence, num_classes=5):
    """One-hot encode a sequence of integer codes.

    Returns a float32 tensor of shape ``(len(sequence), num_classes)``.
    """
    indices = torch.tensor(sequence, dtype=torch.long)
    return F.one_hot(indices, num_classes=num_classes).to(dtype=torch.float32)

def normalize_data(data):
    """Z-score ``data``: subtract its mean, then divide by its standard deviation.

    Works for any array-like with ``.mean()``/``.std()`` and arithmetic
    (e.g. torch tensors or numpy arrays). Note that ``torch`` ``std()`` is
    unbiased by default, while ``numpy`` uses ``ddof=0``, so results differ
    between the two.

    If the standard deviation is exactly zero (constant input), the centred
    data (all zeros) is returned instead of the NaNs that 0/0 would produce.
    """
    centered = data - data.mean()
    std = data.std()
    if std == 0:
        return centered
    return centered / std

def collate_fn(batch):
    """Collate ``(sequence_tensor, label)`` pairs into one padded batch.

    Returns:
        ``(padded, labels, lengths)``:
        - ``padded``: sequences zero-padded to the longest one, batch-first.
        - ``labels``: a float32 tensor.
        - ``lengths``: each sequence's original length (default int64 tensor).
    """
    sequences, labels = zip(*batch)
    length_tensor = torch.tensor([len(s) for s in sequences])
    label_tensor = torch.tensor(labels, dtype=torch.float32)
    padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    return padded, label_tensor, length_tensor

# Build the train/test splits. Both calls draw from the same seeded `random`
# stream, so keeping train before test keeps the generated data reproducible.
train_x, train_y = prepare_data(4096)
test_x, test_y = prepare_data(1024)

# One-hot encode each integer sequence (one_hot_encode defaults to 5 classes).
train_x_one_hot = list(map(one_hot_encode, train_x))
test_x_one_hot = list(map(one_hot_encode, test_x))

# In-memory datasets of (encoded_sequence, cpg_count) pairs.
train_dataset = list(zip(train_x_one_hot, train_y))
test_dataset = list(zip(test_x_one_hot, test_y))

# collate_fn pads each batch and also returns the true sequence lengths.
batch_size = 128
train_data_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

class LSTMCpGPredictor(nn.Module):
    """Regress a per-sequence value (here: the CpG count) from padded one-hot DNA.

    A stacked, unidirectional, batch-first LSTM reads each sequence up to its
    true length. The top layer's hidden state at the last real time step is
    mapped to ``output_size`` values by a linear head.
    """

    def __init__(self, input_size=5, hidden_size=64, num_layers=2, output_size=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, lengths):
        """Return predictions of shape ``(batch, output_size)``.

        Args:
            x: padded batch laid out ``(batch, max_len, input_size)``, since
                the LSTM is ``batch_first``.
            lengths: true length of each sequence, as a 1-D tensor (on any
                device) or a list of ints. Every length must be >= 1.
        """
        # pack_padded_sequence needs int64 lengths on the CPU, whatever x's device.
        cpu_lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        packed_input = nn.utils.rnn.pack_padded_sequence(
            x, cpu_lengths, batch_first=True, enforce_sorted=False
        )
        # With packed input, h_n already holds each sequence's state at its own
        # last valid step, restored to the original batch order. This equals
        # output[i, lengths[i] - 1] of the top layer, without unpacking/gathering.
        _, (h_n, _) = self.lstm(packed_input)
        return self.fc(h_n[-1])

# Default-sized regressor, mean-squared-error objective, and Adam with lr = 1e-3.
model = LSTMCpGPredictor()
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_model(model, train_loader, loss_fn, optimizer, epochs=50):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Each batch must be an ``(inputs, labels, lengths)`` triple, the layout
    ``collate_fn`` produces. The model is called as ``model(inputs, lengths)``,
    and its output is compared with ``labels.unsqueeze(1)``. One line per
    epoch is printed.

    Returns:
        A list with the mean training loss of each epoch. The mean is weighted
        by batch size, so a smaller final batch does not skew it. It is NaN
        for an epoch in which the loader yields no samples.
    """
    model.train()
    history = []
    for epoch in range(epochs):
        loss_sum, n_seen = 0.0, 0
        for inputs, labels, lengths in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs, lengths)
            loss = loss_fn(outputs, labels.unsqueeze(1))
            loss.backward()
            optimizer.step()
            # Assumes loss_fn averages over the batch (MSELoss's default), so
            # re-weight by batch size before accumulating.
            batch_n = labels.shape[0]
            loss_sum += loss.item() * batch_n
            n_seen += batch_n
        epoch_loss = loss_sum / n_seen if n_seen else float("nan")
        history.append(epoch_loss)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}')
    return history

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and collect targets and predictions.

    Puts the model in eval mode (and leaves it there) and disables autograd.
    Batches must be ``(inputs, labels, lengths)`` triples; the model is called
    as ``model(inputs, lengths)``.

    Returns:
        ``(res_gs, res_pred)`` numpy arrays. The first holds the labels
        concatenated across batches. The second holds the model outputs,
        flattened per batch and concatenated. Both are empty float arrays when
        the loader yields nothing.
    """
    model.eval()
    gold_parts, pred_parts = [], []
    with torch.no_grad():
        for inputs, labels, lengths in test_loader:
            outputs = model(inputs, lengths)
            # .numpy() raises on CUDA tensors, so bring everything to the CPU first.
            gold_parts.append(labels.cpu())
            pred_parts.append(outputs.cpu().reshape(-1))
    if not gold_parts:
        return np.array([]), np.array([])
    return torch.cat(gold_parts).numpy(), torch.cat(pred_parts).numpy()

# NOTE(review): in the captured output of this cell, the epoch loss settles at
# roughly 6.8 by epoch 3 and barely moves afterwards. That suggests training
# stalls early; worth investigating (e.g. scaling the count targets or tuning
# the learning rate). TODO.
train_model(model, train_data_loader, loss_fn, optimizer)

res_gs, res_pred = evaluate_model(model, test_data_loader)

# Held-out regression metrics, all derived from the same residual vector.
residuals = res_gs - res_pred
mse = np.mean(residuals ** 2)
mae = np.mean(np.abs(residuals))
r2 = 1 - np.sum(residuals ** 2) / np.sum((res_gs - np.mean(res_gs)) ** 2)

print(f'Mean Squared Error: {mse:.4f}')
print(f'Mean Absolute Error: {mae:.4f}')
print(f'R-squared: {r2:.4f}')

# Persist only the learned parameters (the state dict), not the module object.
torch.save(model.state_dict(), 'cpg_detector_model.pth')

Epoch 1/50, Loss: 20.3555
Epoch 2/50, Loss: 6.9104
Epoch 3/50, Loss: 6.7933
Epoch 4/50, Loss: 6.8048
Epoch 5/50, Loss: 6.7984
Epoch 6/50, Loss: 6.7950
Epoch 7/50, Loss: 6.7990
Epoch 8/50, Loss: 6.8015
Epoch 9/50, Loss: 6.7959
Epoch 10/50, Loss: 6.7962
Epoch 11/50, Loss: 6.7947
Epoch 12/50, Loss: 6.7976
Epoch 13/50, Loss: 6.7936
Epoch 14/50, Loss: 6.7945
Epoch 15/50, Loss: 6.7995
Epoch 16/50, Loss: 6.7956
Epoch 17/50, Loss: 6.7969
Epoch 18/50, Loss: 6.7960
Epoch 19/50, Loss: 6.7943
Epoch 20/50, Loss: 6.7943
Epoch 21/50, Loss: 6.7979
Epoch 22/50, Loss: 6.7972
Epoch 23/50, Loss: 6.7943
Epoch 24/50, Loss: 6.8079
Epoch 25/50, Loss: 6.8068
Epoch 26/50, Loss: 6.7952
Epoch 27/50, Loss: 6.7965
Epoch 28/50, Loss: 6.8246
Epoch 29/50, Loss: 6.8015
Epoch 30/50, Loss: 6.8081
Epoch 31/50, Loss: 6.8018
Epoch 32/50, Loss: 6.8038
Epoch 33/50, Loss: 6.8026
Epoch 34/50, Loss: 6.7972
Epoch 35/50, Loss: 6.7946
Epoch 36/50, Loss: 6.8016
Epoch 37/50, Loss: 6.7908
Epoch 38/50, Loss: 6.8035
Epoch 39/50, Loss: 6