In [56]:
import os
import torch.nn.functional as F
import random
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

In [57]:
def get_seq_length(filepath):
    max_length = 0
    with open(filepath, "r") as file:
        for line in file:
            _, seq = line.split(maxsplit=1)
            seq = seq.strip()
            if len(seq) > max_length:
                max_length = len(seq)
    return max_length

In [58]:
def one_hot_encode(sequence):
    mapping = {
        "A": [1, 0, 0, 0],
        "T": [0, 1, 0, 0],
        "G": [0, 0, 1, 0],
        "C": [0, 0, 0, 1],
    }
    return np.array(
        [mapping[char.upper()] for char in sequence.strip() if char.upper() in mapping],
        dtype=float,
    )

In [59]:
def merge_files(pos_dir, neg_dir, output_file, limit=None):
    with open(output_file, "w") as f:
        pos_files = sorted(os.listdir(pos_dir))
        neg_files = sorted(os.listdir(neg_dir))
        random.shuffle(pos_files)
        random.shuffle(neg_files)

        half_limit = len(pos_files) if limit is None else limit // 2
        half_limit = min(half_limit, len(pos_files), len(neg_files))

        for pos_file in pos_files[:half_limit]:
            with open(os.path.join(pos_dir, pos_file), "r") as pf:
                f.writelines(["1 " + line for line in pf])

        for neg_file in neg_files[:half_limit]:
            with open(os.path.join(neg_dir, neg_file), "r") as nf:
                f.writelines(["0 " + line for line in nf])

In [60]:
def split_data(input_file, limit=None):
    with open(input_file, "r") as f:
        lines = f.readlines()
        random.shuffle(lines)

        n = len(lines) if limit is None else limit

        train_end = int(n * 0.8)
        train_lines = lines[:train_end]
        test_lines = lines[train_end:]

    with open("train.txt", "w") as f:
        f.writelines(train_lines)
    with open("test.txt", "w") as f:
        f.writelines(test_lines)

In [61]:
class MyDataset(Dataset):
    def __init__(self, filepath, max_samples=None):
        self.data = []
        self.labels = []
        count = 0
        with open(filepath, "r") as f:
            for line in f:
                if max_samples and count >= max_samples:
                    break
                label, seq = line.split(maxsplit=1)
                encoded_seq = one_hot_encode(seq.strip())
                tensor_seq = torch.tensor(encoded_seq, dtype=torch.float).transpose(
                    0, 1
                )
                self.data.append(tensor_seq)
                self.labels.append(int(label))
                count += 1

        self.data = torch.stack(self.data)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


def create_dataloader(filepath, batch_size, max_samples=None):
    dataset = MyDataset(filepath, max_samples)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [62]:
class MiniCNN(nn.Module):
    def __init__(self, input_length):
        super(MiniCNN, self).__init__()
        self.conv1 = nn.Conv1d(4, 6, kernel_size=5)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(6, 16, kernel_size=3)

        def conv_output_size(L, K, S=1, P=0):
            return (L - K + 2 * P) // S + 1

        L1 = conv_output_size(input_length, 5)
        L2 = conv_output_size(L1, 2, S=2)
        L3 = conv_output_size(L2, 3)
        L4 = conv_output_size(L3, 2, S=2)

        fc_input_features = 16 * L4
        self.fc1 = nn.Linear(fc_input_features, 120)
        self.fc2 = nn.Linear(120, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [63]:
class AlexNet(nn.Module):
    def __init__(self, input_length):
        super(AlexNet, self).__init__()
        self.conv1 = nn.Conv1d(4, 96, kernel_size=11, stride=4)
        self.local_response1 = nn.LocalResponseNorm(
            size=5, alpha=0.0001, beta=0.75, k=2
        )
        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv1d(96, 256, kernel_size=5, padding=2)
        self.local_response2 = nn.LocalResponseNorm(
            size=5, alpha=0.0001, beta=0.75, k=2
        )
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=2)
        self.conv3 = nn.Conv1d(256, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv1d(384, 384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv1d(384, 256, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool1d(kernel_size=3, stride=2)

        self._to_linear = self.calculate_to_linear(input_length)

        self.fc1 = nn.Linear(self._to_linear, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 4)

    def calculate_to_linear(self, length):
        length = (length - 11) // 4 + 1
        length = (length - 3) // 2 + 1
        length = (length + 2 * 2 - 5) // 1 + 1
        length = (length - 3) // 2 + 1
        length = (length + 2 * 1 - 3) // 1 + 1
        length = (length + 2 * 1 - 3) // 1 + 1
        length = (length + 2 * 1 - 3) // 1 + 1
        length = (length - 3) // 2 + 1

        return length * 256

    def forward(self, x):
        x = self.pool1(F.relu(self.local_response1(self.conv1(x))))
        x = self.pool2(F.relu(self.local_response2(self.conv2(x))))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool5(F.relu(self.conv5(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.5)
        x = self.fc3(x)
        return x

In [64]:
def train(model, device, train_dataloader, optimizer, epochs):
    model.train()
    for batch_ids, (inputs, labels) in enumerate(train_dataloader):
        labels = labels.type(torch.LongTensor)
        inputs, labels = inputs.to(device), labels.to(device)
        torch.autograd.set_detect_anomaly(True)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        if (batch_ids + 1) % 2 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epochs,
                    batch_ids * len(inputs),
                    len(train_dataloader.dataset),
                    100.0 * batch_ids / len(train_dataloader),
                    loss.item(),
                )
            )


def test(model, device, test_dataloader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            test_loss += F.nll_loss(outputs, labels, reduction="sum").item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()

        test_loss /= len(test_dataloader)
        res = "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            len(test_dataloader.dataset),
            100.0 * correct / len(test_dataloader.dataset),
        )
        print(res)
        print("=" * 30)
        return res

In [65]:
pos_dir = "output/positive/cropped"
neg_dir = "output/negative/cropped"


output_file = "merged_data.txt"

merge_files(pos_dir, neg_dir, output_file)
split_data(output_file)
input_length = get_seq_length(output_file)

In [66]:
epochs = 5
batch_size = 1024
max_samples = None

if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

NVIDIA GeForce RTX 2070 SUPER
cuda


In [67]:
train_dataloader = create_dataloader(
    filepath="train.txt", batch_size=batch_size, max_samples=max_samples
)
test_dataloader = create_dataloader(
    filepath="test.txt", batch_size=batch_size, max_samples=max_samples
)

In [68]:
alex_net = AlexNet(input_length).to(device)
alex_optimizer = torch.optim.Adam(alex_net.parameters(), lr=0.0001)
for epoch in range(1, epochs + 1):
    train(alex_net, device, train_dataloader, alex_optimizer, epoch)
    test(alex_net, device, test_dataloader)


Test set: Average loss: -4047.1676, Accuracy: 31262/49128 (64%)


Test set: Average loss: -3433.3346, Accuracy: 31818/49128 (65%)


Test set: Average loss: -3540.0957, Accuracy: 32055/49128 (65%)


Test set: Average loss: -3595.2825, Accuracy: 32458/49128 (66%)


Test set: Average loss: -3605.4940, Accuracy: 32798/49128 (67%)



In [69]:
mini_cnn = MiniCNN(input_length).to(device)
mini_optimizer = torch.optim.Adam(mini_cnn.parameters(), lr=0.0001)
for epoch in range(1, epochs + 1):
    train(mini_cnn, device, train_dataloader, mini_optimizer, epoch)
    test(mini_cnn, device, test_dataloader)


Test set: Average loss: -58.6825, Accuracy: 30964/49128 (63%)


Test set: Average loss: -101.8478, Accuracy: 31311/49128 (64%)


Test set: Average loss: -131.5284, Accuracy: 31514/49128 (64%)


Test set: Average loss: -156.3552, Accuracy: 31698/49128 (65%)


Test set: Average loss: -164.4649, Accuracy: 31824/49128 (65%)

