CS 4277: Deep Learning Group Project

## CS 4277: *Deep Learning* Group Project
### Members:
- Nicholas Hodge
- Joshua Peeples
- Jonathan Turner

### This project is our attempt at the Stanford RNA 3D Folding Challenge, found at:

https://www.kaggle.com/competitions/stanford-rna-3d-folding

**For this project to run:**

1. Install matplotlib in your Jupyter Kernel: Block [1]
2. Setup correct path files to your train dataset: Block [7] (there is a comment)

**Future work:**

1. Setup validation correctly
2. Test
3. Return Submission.csv as per requirements

In [1]:
# Uncomment and run if matplotlib not installed
# !  python -m pip install matplotlib

In [2]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# import matplotlib
# import matplotlib.pyplot as plt

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print (device)

cuda


# Training

In [4]:
NUC_TO_IDX = {
    "A": 0,
    "U": 1,
    "C": 2,
    "G": 3,
    "N": 4 # There are characters *other* than the above 4 sometimes. 'N' is standard for "unknown" (apparently)
}
PAD_IDX = 5
VOCAB_SIZE = len(NUC_TO_IDX) + 1

# Annotated to avoid later confusion - Nick
# This is inherits the Dataset class from pytorch to allow for the dataset to be an Iterable (i.e. work a LOT faster)
class RNADataset(Dataset):
    def __init__(self, seq_csv_path, coords_csv_path):
        # Read both train CSVs (the 'labels' csv -> 'coords')
        self.sequences_df = pd.read_csv(seq_csv_path)
        coords_df_raw = pd.read_csv(coords_csv_path)

        # We are going to get the base_id from each row in coords to associate them with the correct sequence.
        # 1SCL_A_5 becomes 1SCL_A
        coords_df_raw["base_id"] = coords_df_raw["ID"].apply(lambda x: "_".join(x.split("_")[:2]))

        # Now we are going to create groups of coords, where each group corresponds with the same sequence
        # Unfortunately some sequences have missing coord values, but I am going to assume that there are potential
        # sequences that have some missing and some not. So:

        # Method to remove entire groups where *any* row has missing coords
        def is_group_valid(group):
            return group[["x_1", "y_1", "z_1"]].notna().all().all() # returns only rows where all columns are good

        valid_groups = [
            group for _, group in coords_df_raw.groupby("base_id") if is_group_valid(group)
        ]

        # Concatenate all valid groups into a new coords_df
        self.coords_df = pd.concat(valid_groups, ignore_index=True)

        # Build groups and valid sequence IDs list
        self.coord_groups = self.coords_df.groupby("base_id")
        self.valid_ids = set(self.coord_groups.groups.keys())

        # Filter sequences to only include those with clean coordinate groups (prevents later tensors from being mishaped)
        self.sequences_df = self.sequences_df[self.sequences_df["target_id"].isin(self.valid_ids)]

    # Optional but Pytorch docs suggest this for 'Sampler' implmentations (might need that?)
    def __len__(self):
        return len(self.sequences_df)


    def __getitem__(self, idx):
        row = self.sequences_df.iloc[idx]
        seq_id = row["target_id"]
        sequence = row["sequence"]

        token_ids = [NUC_TO_IDX.get(nuc, NUC_TO_IDX["N"]) for nuc in sequence]
        token_ids = torch.tensor(token_ids, dtype=torch.long)

        # Here we introduce standardization to the coordinates

        # TODO: calculate the following values somewhere in the document in case the dataset changes:
        # Currently precalculated values
        mean_x = 80.44731529117061
        std_x = 147.42231938515297
        mean_y = 84.04072703411182
        std_y = 114.92890150429712
        mean_z = 98.61122565112208
        std_z = 119.41066506340083

        coords_standardized = self.coord_groups.get_group(seq_id)[["x_1", "y_1", "z_1"]].values
        coords_standardized[:, 0] = (coords_standardized[:, 0] - mean_x) / std_x
        coords_standardized[:, 1] = (coords_standardized[:, 1] - mean_y) / std_y
        coords_standardized[:, 2] = (coords_standardized[:, 2] - mean_z) / std_z

        coords = torch.tensor(coords_standardized, dtype=torch.float32)

        return token_ids, coords

In [5]:
# Pad sequences in train_collate_fn
def train_collate_fn(batch):
    sequences, coords = zip(*batch)

    # Pad sequences with PAD_IDX
    seq_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=PAD_IDX)
    coord_padded = torch.nn.utils.rnn.pad_sequence(coords, batch_first=True, padding_value=0.0)

    # Mask should check against PAD_IDX
    mask = (seq_padded != PAD_IDX).unsqueeze(1).unsqueeze(2)

    return seq_padded, coord_padded, mask

In [6]:
# Source: Aladdin Persson on YouTube (then modified to have an encoder-only architecture)

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads, dropout):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, values, keys, query, mask):
        N, value_len, _ = values.shape
        _, key_len, _ = keys.shape
        _, query_len, _ = query.shape

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape: (N, key_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            # mask: (batch, 1, 1, seq_len) -> broadcastable to (batch, heads, query_len, key_len)
            energy = energy.masked_fill(mask == 0, float("-1e9"))

        attention = torch.softmax(energy / (self.embed_size ** 0.5), dim=3)
        attention = self.dropout(attention)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten last two dimensions

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads, dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attention = self.attention(x, x, x, mask)

        x = self.dropout(self.norm1(attention + x))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class RNA3DFoldPredictor(nn.Module):
    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 forward_expansion,
                 dropout,
                 max_length):
        super().__init__()
        self.embed_size = embed_size
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            TransformerBlock(embed_size, heads, dropout, forward_expansion)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, 3)  # Predict (x, y, z)

    def forward(self, x, mask=None):
        N, seq_len = x.shape

        positions = torch.arange(0, seq_len).unsqueeze(0).expand(N, seq_len).to(x.device)

        out = self.token_embedding(x) + self.position_embedding(positions)
        for layer in self.layers:
            out = layer(out, mask)

        coords = self.fc_out(out)
        return coords

    def predict_multiple(self, x, n_samples=5):
        self.train()  # Activate dropout during inference
        with torch.no_grad():
            outputs = [self(x) for _ in range(n_samples)]
        return torch.stack(outputs)  # Shape: (n_samples, batch_size, seq_len, 3)

In [7]:
def compute_tm_score(pred_coords, true_coords):
    # pred_coords and true_coords: shape (seq_len, 3)
    # Mask out invalid coordinates (e.g. from padding or -1e18)
    valid_mask = ~(torch.isclose(true_coords, torch.tensor(-1e18)).any(dim=-1))

    pred_coords = pred_coords[valid_mask]
    true_coords = true_coords[valid_mask]

    if len(pred_coords) < 3 or len(true_coords) < 3:
        return 0.0  # Not enough points to compare

    # Convert to numpy
    pred_coords = pred_coords.detach().cpu().numpy()
    true_coords = true_coords.detach().cpu().numpy()

    # Superimpose using Procrustes
    mu_pred = np.mean(pred_coords, axis=0)
    mu_true = np.mean(true_coords, axis=0)
    pred_centered = pred_coords - mu_pred
    true_centered = true_coords - mu_true

    H = pred_centered.T @ true_centered
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    pred_aligned = pred_centered @ R

    rmsd = np.sqrt(np.mean(np.sum((pred_aligned - true_centered) ** 2, axis=1)))
    tm_score = 1 / (1 + (rmsd / 1.24))  # Simplified TM-score-like metric

    return tm_score

In [8]:
dataset = RNADataset("./data/train_sequences.csv", "./data/train_labels.csv") # replace with your *actual* path
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=train_collate_fn)


model = RNA3DFoldPredictor(
    vocab_size=VOCAB_SIZE,
    embed_size=64,
    num_layers=4,
    heads=4,
    forward_expansion=4,
    dropout=0.2,
    max_length=4298, # nearest multiple of 2 is 8192...actual max is 4298
).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    batch_num = 0

    test_batch = next(iter(loader))
    seqs, coords, mask = [x.to(device) for x in test_batch]

    print("Max token ID:", torch.max(seqs))  # Should be <= 3
    print("Embedding size:", model.token_embedding.num_embeddings)  # Should be 4

    with torch.no_grad():
        outputs = model(seqs, mask)
    print("Output shape:", outputs.shape)

    for seqs, coords, mask in loader:
        batch_num += 1
        # print("Max token ID in batch:", torch.max(seqs))
        seqs, coords, mask_attention = seqs.to(device), coords.to(device), mask.to(device)

        # check for any NaN values
        if torch.isnan(coords).any() or torch.isinf(coords).any():
            print(f"WARNING: NaN/Inf found in target coordinates in batch {batch_num}! Skipping batch.")
            continue

        optimizer.zero_grad()
        outputs = model(seqs, mask_attention)


        if torch.isnan(outputs).any() or torch.isinf(outputs).any():
             print(f"WARNING: NaN/Inf found in model outputs BEFORE loss calculation in batch {batch_num}!")

        non_pad_mask = (seqs != PAD_IDX) # Shape: (batch_size, seq_len)

        # Flatten outputs and coords, then apply the mask
        outputs_flat = outputs.view(-1, 3) # Shape: (batch * seq_len, 3)
        coords_flat = coords.view(-1, 3)   # Shape: (batch * seq_len, 3)
        non_pad_mask_flat = non_pad_mask.view(-1) # Shape: (batch * seq_len)

        outputs_masked = outputs_flat[non_pad_mask_flat]
        coords_masked = coords_flat[non_pad_mask_flat]

        # Calculate loss ONLY on non-padded elements
        if outputs_masked.nelement() > 0: # Check if there are any non-padded elements
            loss = criterion(outputs_masked, coords_masked)

            if torch.isnan(loss):
               print(f"WARNING: NaN detected in loss for batch {batch_num}!")
               # Add more debugging here if needed: print outputs_masked, coords_masked
               continue # Skip optimization step for this batch

            print(f"Batch {batch_num} Loss: {loss.item():.6f}") # Print loss *before* backward

            loss.backward()
            # Gradient Clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
        else:
            print(f"Skipping batch {batch_num} due to only padding elements.")

    # Avoid division by zero if loader is empty or all batches were skipped
    if len(loader) > 0 and total_loss > 0:
         print(f"Epoch {epoch+1} Loss: {total_loss / len(loader):.4f}") # Or divide by number of valid batches processed
    else:
         print(f"Epoch {epoch+1} had no valid batches or zero total loss.")


Max token ID: tensor(5, device='cuda:0')
Embedding size: 6
Output shape: torch.Size([4, 134, 3])
Batch 1 Loss: 1.075750
Batch 2 Loss: 1.697220
Batch 3 Loss: 1.186286
Batch 4 Loss: 0.785351
Batch 5 Loss: 1.685617
Batch 6 Loss: 1.633016
Batch 7 Loss: 0.988838
Batch 8 Loss: 1.577247
Batch 9 Loss: 1.084945
Batch 10 Loss: 0.918133
Batch 11 Loss: 0.986582
Batch 12 Loss: 0.663105
Batch 13 Loss: 0.656975
Batch 14 Loss: 1.483961
Batch 15 Loss: 1.234468
Batch 16 Loss: 1.436586
Batch 17 Loss: 1.311145
Batch 18 Loss: 0.646610
Batch 19 Loss: 0.819395
Batch 20 Loss: 1.119095
Batch 21 Loss: 1.187753
Batch 22 Loss: 1.205928
Batch 23 Loss: 1.253905
Batch 24 Loss: 1.017939
Batch 25 Loss: 1.844937
Batch 26 Loss: 0.880586
Batch 27 Loss: 1.777552
Batch 28 Loss: 0.467045
Batch 29 Loss: 0.917429
Batch 30 Loss: 0.934313
Batch 31 Loss: 1.509476
Batch 32 Loss: 1.638286
Batch 33 Loss: 1.050958
Batch 34 Loss: 0.854604
Batch 35 Loss: 1.048206
Batch 36 Loss: 0.920672
Batch 37 Loss: 0.824061
Batch 38 Loss: 1.088969


In [9]:
model_path = './model.pth'
torch.save(model.state_dict(), model_path)

# Testing

In [10]:
IDX_TO_NUC = {v: k for k, v in NUC_TO_IDX.items()}
# Testing dataset
class RNATestDataset(Dataset):
    def __init__(self, test_seq_path):
        # Read test csv
        self.data = pd.read_csv(test_seq_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        seq_id = row["target_id"]
        sequence = row["sequence"]

        token_ids = [NUC_TO_IDX.get(nuc, NUC_TO_IDX["N"]) for nuc in sequence]
        token_ids = torch.tensor(token_ids, dtype=torch.long)

        return seq_id, token_ids

In [11]:
test_dataset = RNATestDataset(test_seq_path='./data/test_sequences.csv')
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

model.load_state_dict(torch.load('model.pth')) # load trained model

model.to(device)
model.eval()

# enable dropout
for m in model.modules():
        if m.__class__.__name__.startswith('Dropout'):
            m.train()

submission_rows = []

with torch.no_grad():
    for seq_id, token_ids in test_loader:
        seq_id = seq_id[0]  # unpack from list
        token_ids = token_ids.to(device).squeeze(0)  # [seq_len]
        sequence = [IDX_TO_NUC[i.item()] for i in token_ids]

        predictions = []

        # Generate 5 predictions
        for _ in range(5):
            output = model(token_ids.unsqueeze(0))  # [1, seq_len, 3]
            coords = output.squeeze(0).cpu().numpy()  # [seq_len, 3]
            predictions.append(coords)

        predictions = np.stack(predictions, axis=0)  # [5, seq_len, 3]

        seq_len = len(sequence)

        # Loop over each nucleotide in the sequence
        for i in range(seq_len):
            row = {
                "ID": f"{seq_id}_{i+1}",
                "resname": sequence[i],
                "resid": i+1
            }
            for j in range(5):  # 5 predictions
                row[f"x_{j+1}"] = predictions[j, i, 0]
                row[f"y_{j+1}"] = predictions[j, i, 1]
                row[f"z_{j+1}"] = predictions[j, i, 2]

            submission_rows.append(row)

# Convert to DataFrame and save
submission_df = pd.DataFrame(submission_rows)
submission_df.to_csv("submission.csv", index=False)

In [12]:
display(pd.read_csv('submission.csv'))

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,R1107_1,G,1,-0.395682,0.026159,-0.460790,-0.494894,-0.650087,-0.418244,-0.416741,0.150414,-0.912710,-0.010037,-0.118914,-0.552068,-0.288310,-0.017140,-0.589925
1,R1107_2,G,2,-0.391142,-0.348778,-0.625145,-1.039486,-0.452833,-0.256414,-0.282264,-0.324641,-0.428678,0.153570,-0.090237,-0.447006,-0.015217,-0.057593,-0.473987
2,R1107_3,G,3,-0.406647,0.237234,-0.645078,-0.472471,-0.608401,-0.907962,-0.309058,-0.521215,-0.297594,0.273179,-0.576949,-0.514920,-0.529903,-0.336794,-0.105902
3,R1107_4,G,4,-0.255734,-0.193740,-0.450975,-1.073963,-0.451119,-0.545389,-0.265850,-0.288817,-0.768573,-0.674212,0.003626,-0.213552,0.183197,-0.327977,-0.753004
4,R1107_5,G,5,-0.275882,-0.376889,-0.640550,0.141876,-0.380515,-0.370196,-0.239424,-0.610117,-0.164057,-1.026388,0.407131,-0.269336,-0.650812,-1.035031,-0.041577
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2510,R1190_114,U,114,0.373998,-0.722683,-0.175997,-0.131089,-0.624439,-0.328461,-0.198958,-0.287816,0.287879,0.217500,-1.008544,-0.449152,-0.742779,-0.390416,-0.336146
2511,R1190_115,U,115,-0.256396,-0.835813,-0.537425,-0.499511,-0.407664,-0.626575,-0.321698,-0.102411,-0.289411,-0.508839,-0.105264,0.337541,-0.595533,0.091612,-0.496264
2512,R1190_116,U,116,0.756600,-0.402746,-0.086196,-0.124368,-0.887070,-0.260656,-0.721631,-0.328267,-0.258278,-0.448919,-0.096239,-0.546463,-1.014522,-0.884052,-0.059010
2513,R1190_117,U,117,-0.372658,-0.214195,-0.348880,-0.585410,-0.147604,-0.384459,-0.355588,-0.096559,-0.569185,0.057323,-0.432629,-0.723334,-0.958681,0.019751,-0.448138
