In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import random

def read_sequences(filename):
    """Read one sequence per line from a text file.

    Args:
        filename: Path of the text file to read.

    Returns:
        list[str]: Every non-blank line with surrounding whitespace stripped,
        in file order. If the file does not exist, a message is printed and an
        empty list is returned. Callers can therefore always iterate the result;
        previously ``None`` was returned implicitly, which crashed any caller
        that iterated it.
    """
    try:
        with open(filename, 'r') as f:
            stripped_lines = (line.strip() for line in f)
            return [line for line in stripped_lines if line]
    except FileNotFoundError:
        print(f"File {filename} not found.")
        return []

# Load both classes (positive file first, then negative).
positive_sequences = read_sequences("dnase_seq_output_positive_trimmed.txt")
negative_sequences = read_sequences("negative_file.txt")

# Build (sequence, label) pairs: 1 for positives, 0 for negatives. Positives
# are appended first, then the combined list is shuffled in place with the
# global `random` state (unseeded here, so the order differs between runs).
labeled_data = [(seq, 1) for seq in positive_sequences]
labeled_data.extend((seq, 0) for seq in negative_sequences)
random.shuffle(labeled_data)

def one_hot_encode(sequence):
    """One-hot encode a DNA string as a ``(len(sequence), 4)`` integer array.

    Columns are ordered A, C, G, T and matching is case-insensitive. Any other
    character (e.g. the ambiguous base ``N``) becomes an all-zero row instead of
    raising ``KeyError``. An empty string yields an array of shape ``(0, 4)``,
    so the column count is always 4.

    Args:
        sequence: DNA string.

    Returns:
        numpy.ndarray: Integer array with one row per base.
    """
    base_to_column = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    bases = sequence.upper()
    encoded = np.zeros((len(bases), 4), dtype=int)
    for position, base in enumerate(bases):
        column = base_to_column.get(base)
        if column is not None:  # unknown/ambiguous bases stay all-zero
            encoded[position, column] = 1
    return encoded

class DNADataset(Dataset):
    """In-memory dataset of one-hot encoded DNA sequences and their labels.

    Every sequence is passed through ``one_hot_encode`` once, at construction
    time. Indexing returns a ``(float32 tensor, label tensor)`` pair.

    Args:
        data: Iterable of ``(sequence, label)`` pairs, consumed in a single pass.
    """

    def __init__(self, data):
        encoded_pairs = [(one_hot_encode(sequence), label) for sequence, label in data]
        self.sequences = [encoding for encoding, _ in encoded_pairs]
        self.labels = [label for _, label in encoded_pairs]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        features = torch.FloatTensor(self.sequences[idx])
        target = torch.tensor(self.labels[idx])
        return features, target
    
class AlexNet1D(nn.Module):
    """AlexNet-style 1D CNN producing 2 output logits per DNA sequence.

    Input is expected as ``(batch, length, 4)``. ``forward`` moves the 4 base
    channels to dim 1 for ``Conv1d``. The first ``Linear`` layer hard-codes
    ``256 * 2`` input features, so the convolutional stack must reduce the
    sequence to length 2. By the Conv1d/MaxPool1d length arithmetic this holds
    for input lengths 99-130 (e.g. 120).
    """

    def __init__(self):
        super().__init__()
        conv_stack = [
            nn.Conv1d(4, 96, kernel_size=11, stride=4),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2),
            nn.Conv1d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2),
            nn.Conv1d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(256 * 2, 256),  # 256 channels x feature length 2
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        channels_first = x.permute(0, 2, 1)
        feature_maps = self.features(channels_first)
        flat = feature_maps.reshape(feature_maps.size(0), -1)
        return self.classifier(flat)

class NiN1D(nn.Module):
    """Network-in-Network style 1D CNN producing 2 output logits per sequence.

    Each spatial convolution is followed by two 1x1 convolutions, each with a
    ReLU. Input is expected as ``(batch, length, 4)``, and ``forward`` moves
    the base channels to dim 1 for ``Conv1d``. The first ``Linear`` layer
    hard-codes ``384 * 3`` input features, so the feature stack must end at
    length 3. By the Conv1d/MaxPool1d length arithmetic this holds for input
    lengths 115-146 (e.g. 120).
    """

    def __init__(self):
        super().__init__()
        conv_stack = [
            # Stage 1: wide strided conv, two 1x1 convs, then pooling.
            nn.Conv1d(4, 96, kernel_size=11, stride=4),
            nn.ReLU(),
            nn.Conv1d(96, 96, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(96, 96, kernel_size=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2),
            # Stage 2: strided 5-wide conv, two 1x1 convs, then pooling.
            nn.Conv1d(96, 256, kernel_size=5, padding=2, stride=2),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2),
            # Stage 3: length-preserving 3-wide conv and two 1x1 convs.
            nn.Conv1d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(384, 384, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(384, 384, kernel_size=1),
            nn.ReLU(),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(384 * 3, 256),  # 384 channels x feature length 3
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        channels_first = x.permute(0, 2, 1)
        feature_maps = self.features(channels_first)
        flat = feature_maps.reshape(feature_maps.size(0), -1)
        return self.classifier(flat)
    
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over every batch in ``loader``."""
    model.train()
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()


def _evaluate(model, loader, device):
    """Score ``model`` on ``loader``; return ``(correct, total, preds, labels)``."""
    model.eval()
    correct = 0
    total = 0
    preds = []
    targets = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            _, predicted = model(inputs).max(1)  # arg-max class per sample
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            preds.extend(predicted.cpu().numpy())
            targets.extend(labels.cpu().numpy())
    return correct, total, preds, targets


def train_model(model, train_loader, test_loader, optimizer, criterion, epochs=5):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each.

    The model is moved to CUDA when available, otherwise CPU.

    Args:
        model: Module mapping an input batch to per-class scores.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(outputs, labels)``.
        epochs: Number of train/evaluate rounds.

    Returns:
        tuple: ``(accuracies, all_preds, all_labels)``.

        - ``accuracies`` holds one test accuracy (percent) per epoch, or NaN
          when the test loader yields no samples.
        - ``all_preds`` / ``all_labels`` are the predictions and true labels
          from the FINAL epoch's evaluation only. Previously they accumulated
          across all epochs, so a confusion matrix built from them counted
          each test sample ``epochs`` times and mixed different model states.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    accuracies = []
    all_preds = []
    all_labels = []

    for epoch in range(epochs):
        _train_one_epoch(model, train_loader, optimizer, criterion, device)
        correct, total, all_preds, all_labels = _evaluate(model, test_loader, device)

        # Guard against an empty test set instead of raising ZeroDivisionError.
        accuracy = 100. * correct / total if total else float('nan')
        accuracies.append(accuracy)
        print(f'Epoch {epoch+1}, Test Accuracy: {accuracy:.2f}%')

    return accuracies, all_preds, all_labels

def _plot_accuracy_curves(results):
    """Plot per-epoch test accuracy for every configuration on one figure."""
    plt.figure(figsize=(10, 6))
    for name, accuracies in results.items():
        plt.plot(range(1, len(accuracies) + 1), accuracies, label=name)
    plt.title("Accuracy vs. Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True)
    plt.show()


def _plot_confusion_matrices(conf_matrices):
    """Show one confusion-matrix figure per configuration."""
    for name, conf_matrix in conf_matrices.items():
        disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=[0, 1])
        disp.plot(cmap='Blues', values_format='d')
        plt.title(f"Confusion Matrix for {name}")
        plt.show()


def main():
    """Train every model configuration, then plot accuracy curves and confusion matrices.

    Builds a ``DNADataset`` from the module-level ``labeled_data`` and splits it
    80/20 into train/test subsets. Each ``(name, model, learning rate, batch
    size)`` configuration is trained with Adam and cross-entropy loss via
    ``train_model``. A confusion matrix is computed from the predictions and
    labels that ``train_model`` returns.

    Returns:
        dict[str, list[float]]: Per-configuration test accuracies (percent), one
        per epoch, keyed by configuration name.
    """
    dataset = DNADataset(labeled_data)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    # Note: all four models are instantiated here, before any training starts.
    configs = [
        ("AlexNet-0.01-16", AlexNet1D(), 0.01, 16),
        ("AlexNet-0.0001-64", AlexNet1D(), 0.0001, 64),
        ("NiN-0.0001-16", NiN1D(), 0.0001, 16),
        ("NiN-0.01-64", NiN1D(), 0.01, 64)
    ]

    results = {}
    all_conf_matrices = {}
    for name, model, lr, batch_size in configs:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

        print(f"\nTraining {name} with batch size {batch_size} and learning rate {lr}...")
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        accuracies, all_preds, all_labels = train_model(model, train_loader, test_loader, optimizer, criterion)
        results[name] = accuracies

        # Pin labels=[0, 1] so the matrix is always 2x2 and matches the
        # display_labels used when plotting. Without it, a class missing from
        # both y_true and y_pred would yield a smaller matrix, and the plot
        # would raise.
        all_conf_matrices[name] = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    _plot_accuracy_curves(results)
    _plot_confusion_matrices(all_conf_matrices)
    return results

# Run the full pipeline when this cell/script executes: train every
# configuration and plot the results. `results` receives main()'s return value.
results = main()