In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

class DeepSignal(nn.Module):
    """Two-branch methylation classifier: BRNN over sequence features + CNN over raw signal.

    The sequence embedding (from a 3-layer bidirectional LSTM) and the signal
    embedding (from ``SignalCNN``) are concatenated and mapped to
    ``num_classes`` raw logits by a fully connected classifier. No activation
    is applied to the returned logits.

    Args:
        sequence_feature_size: width of the sequence embedding. It is split
            evenly between the two LSTM directions, so it must be a positive
            even integer.
        signal_feature_size: width of the signal embedding produced by
            ``SignalCNN``. The ``SignalCNN`` in this file ends in a
            512-channel inception block followed by global average pooling,
            so this must stay 512 while that module is unchanged.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``sequence_feature_size`` is not a positive even integer.
    """

    def __init__(self, sequence_feature_size=128, signal_feature_size=512, num_classes=2):
        super(DeepSignal, self).__init__()

        # Validate before building sub-modules: with an odd size the BRNN would
        # emit 2 * (size // 2) features while the classifier expects `size`,
        # which would only fail later inside forward().
        if sequence_feature_size <= 0 or sequence_feature_size % 2 != 0:
            raise ValueError(
                'sequence_feature_size must be a positive even integer, '
                f'got {sequence_feature_size}'
            )

        # Sequence Feature Module (BRNN)
        self.sequence_brnn = nn.LSTM(
            input_size=4,  # nucleotide type, mean, std, num_signals
            hidden_size=sequence_feature_size // 2,  # //2 for bidirectional
            num_layers=3,
            bidirectional=True,
            batch_first=True,
        )

        # Signal Feature Module (CNN with Inception blocks)
        self.signal_cnn = SignalCNN()

        # Classification Module
        self.classifier = nn.Sequential(
            nn.Linear(sequence_feature_size + signal_feature_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

        # NOTE(review): not used by forward(), which returns raw logits; kept
        # only so existing code that references `model.sigmoid` keeps working.
        self.sigmoid = nn.Sigmoid()

    def forward(self, sequence_features, signal_features):
        """Return raw class logits of shape (batch, num_classes).

        Args:
            sequence_features: float tensor (batch, seq_len, 4) — batch_first
                LSTM input; seq_len is 17 for the intended data.
            signal_features: float tensor (batch, 1, signal_len) — one input
                channel for SignalCNN; signal_len is 360 for the intended data.
        """
        # h_n has shape (num_layers * 2, batch, hidden); its last two rows are
        # the final hidden states of the top layer's forward and backward passes.
        _, (h_n, _) = self.sequence_brnn(sequence_features)
        seq_embedding = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (batch, sequence_feature_size)

        sig_embedding = self.signal_cnn(signal_features)  # (batch, signal_feature_size)

        combined = torch.cat([seq_embedding, sig_embedding], dim=1)
        return self.classifier(combined)

class InceptionBlock(nn.Module):
    """Five-branch 1-D inception block whose 1x3 branch carries a residual shortcut.

    All branches read the same input and preserve its length (paddings match
    the kernel sizes and the max-pool uses stride 1). Their outputs are
    concatenated along the channel axis, batch-normalised, and passed through
    ReLU.

    Args:
        in_channels: channel count of the input ``x`` of shape
            (batch, in_channels, length).
        out_channels: channel count of the output (batch, out_channels, length).
            Must be at least 5 so every branch gets at least one channel.

    Raises:
        ValueError: if ``out_channels`` < 5.
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionBlock, self).__init__()
        if out_channels < 5:
            raise ValueError(f'out_channels must be >= 5, got {out_channels}')

        # Split out_channels across the five branches so the concatenation has
        # exactly out_channels channels. Using out_channels // 5 for every branch
        # drops the remainder (e.g. 5 * (64 // 5) == 60), which breaks
        # BatchNorm1d(out_channels); the last branch absorbs the remainder.
        width = out_channels // 5
        last_width = out_channels - 4 * width

        # 1x1 convolution branch
        self.branch1 = nn.Conv1d(in_channels, width, kernel_size=1)

        # 1x3 convolution branch
        self.branch2 = nn.Sequential(
            nn.Conv1d(in_channels, width, kernel_size=1),
            nn.Conv1d(width, width, kernel_size=3, padding=1),
        )

        # 1x5 convolution branch
        self.branch3 = nn.Sequential(
            nn.Conv1d(in_channels, width, kernel_size=1),
            nn.Conv1d(width, width, kernel_size=5, padding=2),
        )

        # Residual 1x3 convolution branch
        self.branch4 = nn.Sequential(
            nn.Conv1d(in_channels, width, kernel_size=3, padding=1),
        )
        # The residual must match branch4's channel count: identity when it
        # already does, otherwise a 1x1 projection of the input.
        if in_channels == width:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv1d(in_channels, width, kernel_size=1, bias=False)

        # 1x3 maxpool branch
        self.branch5 = nn.Sequential(
            nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
            nn.Conv1d(in_channels, last_width, kernel_size=1),
        )

        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map (batch, in_channels, length) to (batch, out_channels, length)."""
        branch4 = self.branch4(x) + self.shortcut(x)  # residual connection
        out = torch.cat(
            [self.branch1(x), self.branch2(x), self.branch3(x), branch4, self.branch5(x)],
            dim=1,
        )
        return self.relu(self.bn(out))

class SignalCNN(nn.Module):
    """1-D CNN that embeds a raw-signal window into a 512-dimensional vector.

    Input: tensor of shape (batch, 1, length); the rest of this file
    documents length = 360.
    Output: tensor of shape (batch, 512) — the channels of the last inception
    block, average-pooled over the length axis.
    """

    def __init__(self):
        super(SignalCNN, self).__init__()

        # Initial layers: stride-2 conv then stride-2 max-pool, which shrinks
        # the length by roughly 4x (360 -> 180 -> 90) before the inception stack.
        self.initial_conv = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
        )

        # Inception blocks: 13 in total, in three stages (3 + 5 + 5).
        # NOTE(review): the DeepSignal paper was previously cited here as using
        # 11 blocks; confirm which count is intended.
        self.inception_blocks = nn.Sequential(
            # Stage 1: 3 blocks at 64 channels
            InceptionBlock(64, 64),
            InceptionBlock(64, 64),
            InceptionBlock(64, 64),

            # Stage 2: 5 blocks, widening to 128 channels
            InceptionBlock(64, 128),
            InceptionBlock(128, 128),
            InceptionBlock(128, 128),
            InceptionBlock(128, 128),
            InceptionBlock(128, 128),

            # Stage 3: 5 blocks, widening to 256 then 512 channels
            InceptionBlock(128, 256),
            InceptionBlock(256, 256),
            InceptionBlock(256, 256),
            InceptionBlock(256, 256),
            InceptionBlock(256, 512),
        )

        # Global average pooling over the length axis.
        self.final_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Map (batch, 1, length) to (batch, 512)."""
        x = self.initial_conv(x)
        x = self.inception_blocks(x)
        x = self.final_pool(x)  # (batch, 512, 1)
        return torch.flatten(x, 1)

class MethylationDataset(Dataset):
    """In-memory dataset of (sequence features, signal window, label) triples as float32 tensors."""

    def __init__(self, sequence_features, signal_features, labels):
        """
        sequence_features: array-like of shape (num_samples, 17, 4)
        signal_features: array-like of shape (num_samples, 360); an input that
            already has the channel axis, (num_samples, 1, 360), is accepted as-is
        labels: array-like of shape (num_samples,) with 0/1 labels

        Raises ValueError if the three inputs disagree on num_samples.
        """
        self.sequence_features = torch.as_tensor(sequence_features, dtype=torch.float32)

        signals = torch.as_tensor(signal_features, dtype=torch.float32)
        # SignalCNN consumes (batch, channels=1, length): add the channel axis
        # unless the caller already supplied a 3-D (N, C, L) array.
        if signals.dim() != 3:
            signals = signals.unsqueeze(1)
        self.signal_features = signals

        # Float targets are what BCE-style losses expect.
        self.labels = torch.as_tensor(labels, dtype=torch.float32)

        lengths = {len(self.sequence_features), len(self.signal_features), len(self.labels)}
        if len(lengths) != 1:
            raise ValueError(
                'sequence_features, signal_features and labels must have the same '
                f'number of samples, got {len(self.sequence_features)}, '
                f'{len(self.signal_features)} and {len(self.labels)}'
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return (sequence (17, 4), signal (1, 360), scalar label) for sample ``idx``."""
        return self.sequence_features[idx], self.signal_features[idx], self.labels[idx]

class DeepSignalTrainer:
    """Train and validate a two-logit methylation model with early stopping.

    The model is called as ``model(sequence_features, signal_features)`` and
    must return a tensor of shape (batch, >=2). Column 0 is read as the
    methylated score and column 1 as the unmethylated score; the predicted
    probability of methylation is ``s0 / (s0 + s1)`` with
    ``s_i = sigmoid(outputs[:, i])``. Labels must be float 0/1 with 1 meaning
    methylated.
    """

    def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.device = device

        # BCEWithLogitsLoss is fed the *logit* of the normalised methylation
        # probability (see _methylation_logit), so the sigmoid is applied exactly
        # once. Feeding it the probability itself would apply a second sigmoid.
        self.criterion = nn.BCEWithLogitsLoss()
        self.optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Multiply the learning rate by 0.1 every 2 epochs.
        # NOTE(review): after 50 epochs this reaches 1e-3 * 0.1**25; confirm the
        # schedule is intended.
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=2, gamma=0.1)

    @staticmethod
    def _methylation_logit(outputs):
        """Return logit(P) for P = s0 / (s0 + s1), where s_i = sigmoid(outputs[:, i]).

        Because P / (1 - P) = s0 / s1, logit(P) = log s0 - log s1. Computing it
        with logsigmoid avoids the precision loss of forming P first.
        """
        return (nn.functional.logsigmoid(outputs[:, 0])
                - nn.functional.logsigmoid(outputs[:, 1]))

    def _batch_loss(self, seq_features, sig_features, labels):
        """Move one batch to the device, run the model, and return the loss tensor."""
        seq_features = seq_features.to(self.device)
        sig_features = sig_features.to(self.device)
        labels = labels.to(self.device)

        outputs = self.model(seq_features, sig_features)
        return self.criterion(self._methylation_logit(outputs), labels)

    @staticmethod
    def _mean_loss(total_loss, num_batches):
        """Average per-batch losses; an empty loader is a caller error, not a division by zero."""
        if num_batches == 0:
            raise ValueError('dataloader yielded no batches')
        return total_loss / num_batches

    def train_epoch(self, dataloader):
        """Run one optimisation pass over ``dataloader`` and return the mean per-batch loss."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (seq_features, sig_features, labels) in enumerate(dataloader):
            self.optimizer.zero_grad()
            loss = self._batch_loss(seq_features, sig_features, labels)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 100 == 0:
                print(f'Batch {batch_idx}, Loss: {loss.item():.4f}')

        return self._mean_loss(total_loss, num_batches)

    def train(self, train_loader, val_loader, epochs=50, early_stopping_patience=5,
              checkpoint_path='deepsignal_best_model.pth'):
        """Train for up to ``epochs`` epochs with validation-based early stopping.

        After each epoch the model is validated and the LR scheduler is stepped.
        Whenever the validation loss improves, the model's ``state_dict`` is
        saved to ``checkpoint_path``. Training stops once
        ``early_stopping_patience`` consecutive epochs pass without improvement.
        The in-memory model keeps the weights of the last epoch run, not the
        best ones.
        """
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(epochs):
            train_loss = self.train_epoch(train_loader)
            val_loss = self.validate(val_loader)

            self.scheduler.step()

            print(f'Epoch {epoch+1}/{epochs}:')
            print(f'  Train Loss: {train_loss:.4f}')
            print(f'  Val Loss: {val_loss:.4f}')
            print(f'  LR: {self.optimizer.param_groups[0]["lr"]:.6f}')

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Save best model
                torch.save(self.model.state_dict(), checkpoint_path)
            else:
                patience_counter += 1

            if patience_counter >= early_stopping_patience:
                print(f'Early stopping at epoch {epoch+1}')
                break

    def validate(self, dataloader):
        """Return the mean per-batch loss over ``dataloader`` without updating weights."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for seq_features, sig_features, labels in dataloader:
                total_loss += self._batch_loss(seq_features, sig_features, labels).item()
                num_batches += 1

        return self._mean_loss(total_loss, num_batches)

def main():
    """Train DeepSignal on the train/validation splits, checkpointing the best model.

    NOTE(review): reads module-level arrays sequence_train, signal_train,
    labels_train, sequence_val, signal_val and labels_val, which are not
    defined in this part of the file — presumably created in an earlier
    notebook cell; confirm.
    """

    def make_loader(split_arrays, shuffle):
        # Wrap one split's (sequences, signals, labels) arrays in a batched loader.
        return DataLoader(MethylationDataset(*split_arrays), batch_size=512, shuffle=shuffle)

    train_arrays = (sequence_train, signal_train, labels_train)
    val_arrays = (sequence_val, signal_val, labels_val)

    trainer = DeepSignalTrainer(DeepSignal())
    trainer.train(
        make_loader(train_arrays, shuffle=True),
        make_loader(val_arrays, shuffle=False),
        epochs=50,
    )

if __name__ == "__main__":
    main()

No test set is used here because the available data is limited. The model can still be tested on similar signals: for example, download another set of raw signals, pass them through the same pipeline, and evaluate the model on the result.