<a href="https://colab.research.google.com/github/antonpirhonen/rna-folding-nn/blob/main/main_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Testing RNA folding with a Neural Network

In [1]:
import numpy as np
import pandas as pd
import torch

In [2]:
print(torch.cuda.is_available())

True


In [3]:
import torch.nn as nn

class RNAFoldingNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNAFoldingNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        return out

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Read the training data
df = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/data_with_folded.csv")

# Convert RNA sequence to one-hot encoding
nucleotides = ['A', 'C', 'G', 'U']
def one_hot_nucl(seq):
    x = np.zeros((len(seq), 4))
    for i, nuc in enumerate(seq):
        x[i, nucleotides.index(nuc)] = 1
    return x

# Convert the RNA sequence to a one-hot encoding
train_data = np.array([one_hot_nucl(seq) for seq in df['seq']])
train_data = torch.tensor(train_data.reshape(train_data.shape[0], -1), dtype=torch.float)

# Convert the folded sequence to a one-hot encoding
hbonds = ['.', '(', ')']
def one_hot_hbond(seq):
    x = np.zeros((len(seq), 3))
    for i, hbond in enumerate(seq):
        x[i, hbonds.index(hbond)] = 1
    return x

train_labels = np.array([one_hot_hbond(seq) for seq in df['folded_seq']])
train_labels = torch.tensor(train_labels.reshape(train_labels.shape[0], -1), dtype=torch.float)

# Split the data into training and validation sets
from sklearn.model_selection import train_test_split

train_data, val_data, train_labels, val_labels = train_test_split(train_data, train_labels, test_size=0.2)

In [6]:
N = 30
# Create the network
input_size = 4*N
hidden_size = 128
output_size = 3*N

net = RNAFoldingNet(input_size, hidden_size, output_size)

if torch.cuda.is_available():
  net.cuda()
  train_data = torch.Tensor(train_data).cuda()
  train_labels = torch.Tensor(train_labels).cuda()
  # val_data = torch.Tensor(val_data).cuda()
  # val_labels = torch.Tensor(val_labels).cuda()


# Create the optimizer
import torch.optim as optim
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Train the network
num_epochs = 100
batch_size = 32

for epoch in range(num_epochs):
    running_loss = 0.0
    for i in range(0, len(train_data), batch_size):
        inputs = train_data[i:i+batch_size]
        labels = train_labels[i:i+batch_size]
        inputs = torch.Tensor(inputs).cuda()
        labels = torch.Tensor(labels).cuda()

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print('Epoch %d, loss: %.3f' % (epoch+1, running_loss/len(train_data)))


Epoch 1, loss: 0.014
Epoch 2, loss: 0.013
Epoch 3, loss: 0.013
Epoch 4, loss: 0.013
Epoch 5, loss: 0.013
Epoch 6, loss: 0.013
Epoch 7, loss: 0.013
Epoch 8, loss: 0.013
Epoch 9, loss: 0.013
Epoch 10, loss: 0.013
Epoch 11, loss: 0.013
Epoch 12, loss: 0.013
Epoch 13, loss: 0.013
Epoch 14, loss: 0.013
Epoch 15, loss: 0.013
Epoch 16, loss: 0.013
Epoch 17, loss: 0.013
Epoch 18, loss: 0.012
Epoch 19, loss: 0.012
Epoch 20, loss: 0.012
Epoch 21, loss: 0.012
Epoch 22, loss: 0.012
Epoch 23, loss: 0.012
Epoch 24, loss: 0.012
Epoch 25, loss: 0.012
Epoch 26, loss: 0.012
Epoch 27, loss: 0.012
Epoch 28, loss: 0.012
Epoch 29, loss: 0.012
Epoch 30, loss: 0.012
Epoch 31, loss: 0.012
Epoch 32, loss: 0.012
Epoch 33, loss: 0.012
Epoch 34, loss: 0.012
Epoch 35, loss: 0.012
Epoch 36, loss: 0.012
Epoch 37, loss: 0.012
Epoch 38, loss: 0.012
Epoch 39, loss: 0.012
Epoch 40, loss: 0.012
Epoch 41, loss: 0.012
Epoch 42, loss: 0.012
Epoch 43, loss: 0.012
Epoch 44, loss: 0.012
Epoch 45, loss: 0.012
Epoch 46, loss: 0.0

In [10]:
# Use the network to make predictions
if torch.cuda.is_available():
    net.cpu()

with torch.no_grad():
    outputs = net(val_data)
    outputs = outputs.numpy()
    val_labels = val_labels

# Convert the one-hot encoding back to a folded sequence
def one_hot_to_hbond(x):
    return hbonds[np.argmax(x)]

preds = []
for i in range(len(outputs)):
    pred = []
    for j in range(0, len(outputs[i]), 3):
        pred.append(one_hot_to_hbond(outputs[i][j:j+3]))
    preds.append(''.join(pred))

# Display the predictions and the actual folded sequences
for i in range(100):
    print(preds[i])
    # Reshape the labels to show secondary structure with hydrogen bonds using one-hot to hbond
    print(''.join([one_hot_to_hbond(val_labels[i][j:j+3]) for j in range(0, len(val_labels[i]), 3)]))
    print()

....(((..(((((......))))))))).
....((.((((((.(.....).))))))))

......(((..........)))))......
.................(((......))).

............(((((.......))))).
..........((((.(((...))).)))).

.....((((.(..((....)))........
.(((((.(((......))))))))......

............(((((((.....))))).
..............((.(((....))))).

..((..(((((.......))))........
..(((.(((...))).)))...........

..........(((((......)))).....
.((.(((......)))))....((....))

..............................
..............................

.........(((((.....)....))))).
..........((((....))))........

..((((.((........)))))..)))...
((((..((.((.....)))).)))).....

............(.................
...........(((((......)))))...

........(................))...
..............................

..((((.....)().......))))))...
......(((((.(((...)))..)))))..

.........((((..(..........)...
((((....((((((.....)))).))))))

..(((............(()......))).
.(((.((((((.(......))))))).)))

..(((((......))))))...........
..((((((....))))))......