CS 4277: Deep Learning Group Project

## CS 4277: *Deep Learning* Group Project
### Members:
- Nicholas Hodge
- Joshua Peeples
- Jonathan Turner

### This project is our attempt at the Stanford RNA 3D Folding Challenge, found at:

https://www.kaggle.com/competitions/stanford-rna-3d-folding

**For this project to run:**

1. Install matplotlib in your Jupyter kernel: Block [1]
2. Set up the correct file paths to your training dataset: Block [7] (see the comment there)

**Future work:**

1. Set up validation correctly
2. Test the trained model
3. Produce Submission.csv as per the competition requirements

In [1]:
# Uncomment and run if matplotlib not installed
# !  python -m pip install matplotlib

In [2]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib
import matplotlib.pyplot as plt

In [3]:
# Token ID of each nucleotide letter, assigned by position in the alphabet
# string below: the four RNA bases first, then "N", the stand-in used for any
# other ("unknown") character that shows up in a sequence.
NUC_TO_IDX = {nucleotide: index for index, nucleotide in enumerate("AUCGN")}

# Number of distinct token IDs, i.e. the size the token embedding needs.
VOCAB_SIZE = len(NUC_TO_IDX)

class RNADataset(Dataset):
    """Pairs each RNA sequence with the coordinate rows recorded for it.

    Two CSV files are read:

    * the sequence file, which must have ``target_id`` and ``sequence`` columns;
    * the label file, which must have an ``ID`` column plus ``x_1``, ``y_1``
      and ``z_1`` columns.

    Label rows are matched to a sequence through a ``base_id`` derived from
    ``ID`` by dropping its final ``_<suffix>`` part. Sequences without any
    matching label rows are dropped.

    Args:
        seq_csv_path: Path to the sequence CSV.
        coords_csv_path: Path to the label (coordinate) CSV.
    """

    def __init__(self, seq_csv_path, coords_csv_path):
        self.sequences_df = pd.read_csv(seq_csv_path)
        self.coords_df = pd.read_csv(coords_csv_path)

        # Strip only the LAST "_"-separated part of the label ID. Keeping the
        # first two parts instead (the previous behavior) mangles every target
        # id that does not contain exactly one underscore, e.g. "R1107_1" was
        # kept whole and never matched target "R1107".
        # NOTE(review): assumes label IDs look like "<target_id>_<suffix>" --
        # confirm against the label file.
        self.coords_df["base_id"] = self.coords_df["ID"].str.rsplit("_", n=1).str[0]

        # Group label rows by base ID and keep only sequences that have labels
        self.coord_groups = self.coords_df.groupby("base_id")
        self.valid_ids = set(self.coord_groups.groups.keys())

        self.sequences_df = self.sequences_df[self.sequences_df["target_id"].isin(self.valid_ids)]

    def __len__(self):
        """Return the number of sequences that have label rows."""
        return len(self.sequences_df)

    def __getitem__(self, idx):
        """Return ``(token_ids, coords)`` for the sequence at position ``idx``.

        Returns:
            token_ids: ``torch.long`` tensor with one entry per character of
                the sequence; characters missing from ``NUC_TO_IDX`` map to
                the ID of ``"N"``.
            coords: ``torch.float32`` tensor holding the ``x_1``, ``y_1``,
                ``z_1`` columns of every label row of this sequence.
        """
        # Positional lookup: the filtered frame keeps its original index labels.
        row = self.sequences_df.iloc[idx]
        seq_id = row["target_id"]

        unknown_id = NUC_TO_IDX["N"]  # Use "N" for unknown characters
        token_ids = torch.tensor(
            [NUC_TO_IDX.get(nuc, unknown_id) for nuc in row["sequence"]],
            dtype=torch.long,
        )

        # NOTE(review): nothing here checks that the number of label rows
        # equals the sequence length -- verify against the data.
        coords = self.coord_groups.get_group(seq_id)[["x_1", "y_1", "z_1"]].values
        coords = torch.tensor(coords, dtype=torch.float32)

        return token_ids, coords

In [4]:
# Pad sequences in collate_fn
def collate_fn(batch):
    """Collate ``(token_ids, coords)`` pairs into padded batch tensors.

    Args:
        batch: Iterable of ``(token_ids, coords)`` tensor pairs, where the
            first dimension of each tensor is that sample's length.

    Returns:
        seq_padded: Token IDs padded with 0 to the longest sequence,
            shape ``(batch, max_len)``.
        coord_padded: Coordinates padded with 0.0 along the first dimension.
        mask: Boolean tensor of shape ``(batch, 1, 1, max_len)``; True for
            real tokens, False for padding.
    """
    sequences, coords = zip(*batch)

    # Pad sequences and coordinates
    seq_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
    coord_padded = torch.nn.utils.rnn.pad_sequence(coords, batch_first=True, padding_value=0.0)

    # Build the mask from the true lengths rather than from `seq_padded != 0`:
    # the padding value 0 is also the token ID of nucleotide "A" (NUC_TO_IDX),
    # so a value comparison would mark every real "A" as padding.
    lengths = torch.tensor([len(seq) for seq in sequences], device=seq_padded.device)
    positions = torch.arange(seq_padded.size(1), device=seq_padded.device)
    mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1).unsqueeze(2)

    return seq_padded, coord_padded, mask

In [5]:
# Source: Aladdin Persson on YouTube (then modified to have an encoder-only architecture)

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding dimension is split into ``heads`` chunks of size
    ``head_dim = embed_size // heads``. The value/key/query projections are
    bias-free ``head_dim -> head_dim`` linear layers applied to every head
    (so their weights are shared across heads); the concatenated heads then
    pass through one output projection.

    Args:
        embed_size: Size of the input/output embedding; must be divisible by
            ``heads``.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_size, heads, dropout):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(N, heads, query_len, key_len)``; positions where it equals
                0 receive (effectively) zero attention weight. ``None``
                disables masking.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N, value_len, _ = values.shape
        _, key_len, _ = keys.shape
        _, query_len, _ = query.shape

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            # A very large negative score becomes a zero weight after softmax.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Each dot product sums over head_dim terms, so the scores are scaled
        # by sqrt(head_dim) -- the per-head key dimension -- not by the full
        # embedding size, which would over-flatten the softmax when heads > 1.
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
        attention = self.dropout(attention)

        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten last two dimensions
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Both sub-layers use a residual connection whose sum is layer-normalised
    and then passed through dropout.

    Args:
        embed_size: Embedding size of the inputs and outputs.
        heads: Number of attention heads.
        dropout: Dropout probability, used in the attention, inside the
            feed-forward network and after each normalisation.
        forward_expansion: Width multiplier of the feed-forward hidden layer
            relative to ``embed_size``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads, dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the layer on ``x``; ``mask`` is handed to the attention as-is."""
        # Attention sub-layer: x is used as values, keys and queries.
        attended = self.attention(x, x, x, mask)
        hidden = self.dropout(self.norm1(attended + x))

        # Feed-forward sub-layer with its own residual connection.
        projected = self.feed_forward(hidden)
        return self.dropout(self.norm2(projected + hidden))
    
class RNA3DFoldPredictor(nn.Module):
    """Encoder-only transformer that predicts one (x, y, z) triple per token.

    Token and learned position embeddings are summed, passed through
    ``num_layers`` ``TransformerBlock`` layers, and projected to 3 outputs
    per position.

    Args:
        vocab_size: Number of distinct token IDs.
        embed_size: Embedding size used throughout the model.
        num_layers: Number of stacked transformer blocks.
        heads: Number of attention heads per block.
        forward_expansion: Feed-forward width multiplier per block.
        dropout: Dropout probability used in the blocks.
        max_length: Number of position embeddings; input sequences must not
            be longer than this.
    """

    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 forward_expansion,
                 dropout,
                 max_length):
        super().__init__()
        self.embed_size = embed_size
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            TransformerBlock(embed_size, heads, dropout, forward_expansion)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, 3)  # Predict (x, y, z)

    def forward(self, x, mask=None):
        """Predict coordinates for a batch of token ID sequences.

        Args:
            x: Integer tensor of token IDs, shape ``(N, seq_len)``.
            mask: Optional attention mask passed unchanged to every block.

        Returns:
            Tensor of shape ``(N, seq_len, 3)``.
        """
        N, seq_len = x.shape

        # Create the position indices directly on the input's device instead
        # of building them on the CPU and copying them over on every call.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(N, seq_len)

        out = self.token_embedding(x) + self.position_embedding(positions)
        for layer in self.layers:
            out = layer(out, mask)

        coords = self.fc_out(out)
        return coords

    def predict_multiple(self, x, n_samples=5, mask=None):
        """Draw several stochastic predictions with dropout left active.

        Args:
            x: Integer tensor of token IDs, shape ``(N, seq_len)``.
            n_samples: Number of forward passes to run.
            mask: Optional attention mask forwarded to every pass (needed
                when ``x`` contains padding). Defaults to no mask.

        Returns:
            Tensor of shape ``(n_samples, N, seq_len, 3)``.
        """
        # Remember the current mode so sampling does not silently leave an
        # evaluation-mode model in training mode afterwards.
        was_training = self.training
        self.train()  # Activate dropout during inference
        try:
            with torch.no_grad():
                outputs = [self(x, mask) for _ in range(n_samples)]
        finally:
            self.train(was_training)
        return torch.stack(outputs)  # Shape: (n_samples, batch_size, seq_len, 3)

In [6]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cpu


In [None]:
# Build the dataset, model and optimizer, then train for a fixed number of epochs.
dataset = RNADataset("./data/train_sequences.csv", "./data/train_labels.csv") # replace with your *actual* path
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

model = RNA3DFoldPredictor(
    vocab_size=VOCAB_SIZE,
    embed_size=256,
    num_layers=6,
    heads=8,
    forward_expansion=4,
    dropout=0.2,
    max_length=8192, # next power of 2 above the longest sequence...actual max is 4298
).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# One-off sanity check on a single batch. It runs once, before training and
# without gradient tracking, instead of repeating an extra forward pass at the
# start of every epoch.
model.eval()
with torch.no_grad():
    test_batch = next(iter(loader))
    seqs, coords, mask = [x.to(device) for x in test_batch]

    print("Max token ID:", torch.max(seqs))  # Must be < VOCAB_SIZE (IDs are 0..4, 4 being "N")
    print("Embedding size:", model.token_embedding.num_embeddings)  # Equals VOCAB_SIZE

    outputs = model(seqs, mask)
    print("Output shape:", outputs.shape)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for seqs, coords, mask in loader:
        seqs, coords, mask = seqs.to(device), coords.to(device), mask.to(device)

        optimizer.zero_grad()
        outputs = model(seqs, mask)

        # Average the error over real residues only. Padded positions carry
        # all-zero target coordinates and would otherwise pull predictions
        # toward the origin and dilute the reported loss.
        # NOTE(review): relies on collate_fn's mask (shape (N, 1, 1, L)) being
        # True exactly at the non-padded positions.
        valid = mask[:, 0, 0, :]
        loss = criterion(outputs[valid], coords[valid])
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    print(f"Epoch {epoch+1} Loss: {total_loss / len(loader):.4f}")

Max token ID: tensor(3)
Embedding size: 5


KeyboardInterrupt: 