# General Setup & Imports 

In [1]:
!pip install trackio -qq

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.9/887.9 kB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.0/23.0 MB[0m [31m90.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.4/55.4 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m128.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from kaggle_secrets import UserSecretsClient
import os

user_secrets = UserSecretsClient()

os.environ["HF_TOKEN"] = user_secrets.get_secret("hf_access_token")
os.environ["PL_DISABLE_TENSORBOARD"] = "1"

In [3]:
import pandas as pd
import numpy as np

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from torchmetrics.classification import MulticlassF1Score

import trackio
import kagglehub

In [4]:
pl.seed_everything(1596)

Seed set to 1596


1596

In [5]:
# trackio.init(
#     project="25-t3-nppe2",
#     space_id="nikshhiremath/25-t3-nppe2",
# )

# Load Dataset

In [6]:
train_df = pd.read_csv("/kaggle/input/sep-25-dl-gen-ai-nppe-2/train.csv")
test_df  = pd.read_csv("/kaggle/input/sep-25-dl-gen-ai-nppe-2/test.csv")

print("Train rows:", len(train_df))
print("Test rows :", len(test_df))

Train rows: 7262
Test rows : 1816


# Tokenizers / Label Maps

In [7]:
# Allowed amino acids (with '*' for masked residues)
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY*")

aa2idx = {aa: i+1 for i, aa in enumerate(AMINO_ACIDS)}  # 0 = padding
idx2aa = {i+1: aa for i, aa in enumerate(AMINO_ACIDS)}

# Q8 labels
Q8_LABELS = ["B", "C", "E", "G", "H", "I", "S", "T"]
q8_to_idx = {c: i for i, c in enumerate(Q8_LABELS)}
idx_to_q8 = {i: c for c, i in q8_to_idx.items()}

# Q3 labels
Q3_LABELS = ["H", "E", "C"]
q3_to_idx = {c: i for i, c in enumerate(Q3_LABELS)}
idx_to_q3 = {i: c for c, i in q3_to_idx.items()}

# Dataset

In [8]:
class ProteinDataset(Dataset):
    def __init__(self, df, train=True):
        self.seqs = df["seq"].values
        self.train = train

        if train:
            self.sst8 = df["sst8"].values
            self.sst3 = df["sst3"].values

    def encode(self, seq, table):
        return torch.tensor([table[c] for c in seq], dtype=torch.long)

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq = self.encode(self.seqs[idx], aa2idx)

        if self.train:
            q8 = self.encode(self.sst8[idx], q8_to_idx)
            q3 = self.encode(self.sst3[idx], q3_to_idx)
            return seq, q8, q3

        return seq

In [9]:
def pad_collate(batch):
    # Training: (seq, q8, q3)
    if len(batch[0]) == 3:
        seqs, q8s, q3s = zip(*batch)
        
        lengths = [len(s) for s in seqs]
        max_len = max(lengths)

        padded_seqs = []
        padded_q8   = []
        padded_q3   = []

        for s, q8, q3 in zip(seqs, q8s, q3s):
            pad_len = max_len - len(s)

            padded_seqs.append(torch.cat([s, torch.zeros(pad_len, dtype=torch.long)]))
            padded_q8.append(torch.cat([q8, torch.zeros(pad_len, dtype=torch.long)]))
            padded_q3.append(torch.cat([q3, torch.zeros(pad_len, dtype=torch.long)]))

        return torch.stack(padded_seqs), torch.stack(padded_q8), torch.stack(padded_q3)

    # Test set: only sequences
    else:
        seqs = batch
        lengths = [len(s) for s in seqs]
        max_len = max(lengths)

        padded_seqs = [
            torch.cat([s, torch.zeros(max_len - len(s), dtype=torch.long)])
            for s in seqs
        ]
        return torch.stack(padded_seqs)

# Lightning DataModule

In [10]:
class ProteinDM(pl.LightningDataModule):
    def __init__(self, train_df, batch_size=32):
        super().__init__()
        self.train_df = train_df
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = ProteinDataset(self.train_df, train=True)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=pad_collate
        )

# Bi-Directional RNN

In [11]:
class BiRNN(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        self.rnn = nn.RNN(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

        self.q8_head = nn.Linear(hidden_dim * 2, len(Q8_LABELS))
        self.q3_head = nn.Linear(hidden_dim * 2, len(Q3_LABELS))

    def forward(self, x):
        x = self.embedding(x)               # (B, L, E)
        out, _ = self.rnn(x)                # (B, L, 2H)

        q8_logits = self.q8_head(out)       # (B, L, 8)
        q3_logits = self.q3_head(out)       # (B, L, 3)
        return q8_logits, q3_logits

# Bi-Directional LSTM

In [12]:
class BiLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        self.q8_head = nn.Linear(hidden_dim * 2, len(Q8_LABELS))
        self.q3_head = nn.Linear(hidden_dim * 2, len(Q3_LABELS))

    def forward(self, x):
        x = self.embedding(x)              # (B, L, E)
        out, _ = self.lstm(x)              # (B, L, 2H)

        q8_logits = self.q8_head(out)
        q3_logits = self.q3_head(out)
        return q8_logits, q3_logits

# LightningModule for Training

In [13]:
class ProteinLightning(pl.LightningModule):
    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.lr = lr

        self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)   # ignore padding

        # Metrics
        self.f1_q8 = MulticlassF1Score(
            num_classes=len(Q8_LABELS),
            average="macro"
        )
        self.f1_q3 = MulticlassF1Score(
            num_classes=len(Q3_LABELS),
            average="macro"
        )

    def training_step(self, batch, batch_idx):
        seq, q8, q3 = batch

        pred8, pred3 = self.model(seq)   # logits: (B, L, C)

        # Compute loss (ignore_index=0 handles padding)
        loss8 = self.loss_fn(pred8.transpose(1, 2), q8)
        loss3 = self.loss_fn(pred3.transpose(1, 2), q3)
        loss = loss8 + loss3

        mask_q8 = (q8 != 0)
        mask_q3 = (q3 != 0)

        pred8_flat = pred8.argmax(-1)[mask_q8]
        true8_flat = q8[mask_q8]

        pred3_flat = pred3.argmax(-1)[mask_q3]
        true3_flat = q3[mask_q3]

        # Compute F1
        f1_q8_val = self.f1_q8(pred8_flat, true8_flat)
        f1_q3_val = self.f1_q3(pred3_flat, true3_flat)

        # Logging
        self.log("train_loss", loss, prog_bar=True)
        self.log("f1_q8", f1_q8_val, prog_bar=True, on_step=True, on_epoch=True)
        self.log("f1_q3", f1_q3_val, prog_bar=True, on_step=True, on_epoch=True)

        trackio.log({
            "loss": float(loss.item()),
            "f1_q8": float(f1_q8_val.item()),
            "f1_q3": float(f1_q3_val.item()),
        })

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Training

In [14]:
# # Training Setup
# dm = ProteinDM(train_df, batch_size=32)
# dm.setup()

In [15]:
# # Train BiRNN
# print("Training BiRNN model...")

# rnn_model = BiRNN(
#     vocab_size=len(AMINO_ACIDS) + 1,
#     embed_dim=128,
#     hidden_dim=256
# )

# lit_rnn = ProteinLightning(rnn_model)

# trainer_rnn = pl.Trainer(
#     max_epochs=20,
#     accelerator="gpu" if torch.cuda.is_available() else "cpu",
#     devices=1
# )

# trainer_rnn.fit(lit_rnn, dm)

# # Save checkpoint
# trainer_rnn.save_checkpoint("birnn.ckpt")

In [16]:
# # Upload to KaggleHub
# kagglehub.model_upload(
#     "nikshhiremath/birnn/pytorch/lr1e-3",
#     "birnn.ckpt",
#     # license_name="apache-2.0"
# )

In [17]:
# # Train BiLSTM
# print("Training BiLSTM model...")

# lstm_model = BiLSTM(
#     vocab_size=len(AMINO_ACIDS) + 1,
#     embed_dim=128,
#     hidden_dim=256,
#     num_layers=2
# )

# lit_lstm = ProteinLightning(lstm_model)

# trainer_lstm = pl.Trainer(
#     max_epochs=20,
#     accelerator="gpu" if torch.cuda.is_available() else "cpu",
#     devices=1
# )

# trainer_lstm.fit(lit_lstm, dm)

# # Save checkpoint
# trainer_lstm.save_checkpoint("bilstm.ckpt")

In [18]:
# # Upload to KaggleHub
# kagglehub.model_upload(
#     "nikshhiremath/bilstm/pytorch/lr1e-3",
#     "bilstm.ckpt",
#     # license_name="apache-2.0"
# )

In [19]:
# # Finish TrackIO run
# trackio.finish()

# Inference

In [20]:
# Download model files
birnn_path = kagglehub.model_download("nikshhiremath/birnn/pytorch/lr1e-3")
bilstm_path = kagglehub.model_download("nikshhiremath/bilstm/pytorch/lr1e-3")

print("BiRNN loaded from:", birnn_path)
print("BiLSTM loaded from:", bilstm_path)

BiRNN loaded from: /kaggle/input/birnn/pytorch/lr1e-3/3
BiLSTM loaded from: /kaggle/input/bilstm/pytorch/lr1e-3/1


In [21]:
def strip_prefix(state_dict, prefix="model."):
    new_state = {}
    for k, v in state_dict.items():
        if k.startswith(prefix):
            new_state[k[len(prefix):]] = v
        else:
            new_state[k] = v
    return new_state

In [22]:
birnn_model = BiRNN(vocab_size=len(AMINO_ACIDS)+1)
birnn_state = strip_prefix(torch.load(f"{birnn_path}/birnn.ckpt", map_location="cpu")["state_dict"])
birnn_model.load_state_dict(birnn_state)
birnn_model.eval()

BiRNN(
  (embedding): Embedding(22, 128, padding_idx=0)
  (rnn): RNN(128, 256, batch_first=True, bidirectional=True)
  (q8_head): Linear(in_features=512, out_features=8, bias=True)
  (q3_head): Linear(in_features=512, out_features=3, bias=True)
)

In [23]:
bilstm_model = BiLSTM(vocab_size=len(AMINO_ACIDS)+1)
bilstm_state = strip_prefix(torch.load(f"{bilstm_path}/bilstm.ckpt", map_location="cpu")["state_dict"])
bilstm_model.load_state_dict(bilstm_state)
bilstm_model.eval()

BiLSTM(
  (embedding): Embedding(22, 128, padding_idx=0)
  (lstm): LSTM(128, 256, num_layers=2, batch_first=True, bidirectional=True)
  (q8_head): Linear(in_features=512, out_features=8, bias=True)
  (q3_head): Linear(in_features=512, out_features=3, bias=True)
)

In [24]:
def pad_collate_test(batch):
    max_len = max(len(s) for s in batch)
    padded = [
        torch.cat([s, torch.zeros(max_len - len(s), dtype=torch.long)])
        for s in batch
    ]
    return torch.stack(padded)

In [25]:
def encode(seq, table):
    return torch.tensor([table[c] for c in seq], dtype=torch.long)

test_encoded = [encode(seq, aa2idx) for seq in test_df["seq"]]

test_loader = DataLoader(
    test_encoded,
    batch_size=32,
    shuffle=False,
    collate_fn=pad_collate_test
)

In [26]:
preds_q8 = []
preds_q3 = []

with torch.no_grad():
    for batch in test_loader:
        # Use RNN predictions
        q8_logits, q3_logits = birnn_model(batch)

        # Get predictions
        pred8 = q8_logits.argmax(-1)
        pred3 = q3_logits.argmax(-1)

        for i in range(batch.size(0)):
            length = (batch[i] != 0).sum().item()

            seq_q8 = "".join(idx_to_q8[idx.item()] for idx in pred8[i][:length])
            seq_q3 = "".join(idx_to_q3[idx.item()] for idx in pred3[i][:length])

            preds_q8.append(seq_q8)
            preds_q3.append(seq_q3)

In [27]:
submission = pd.DataFrame({
    "id": test_df["id"],
    "sst8": preds_q8,
    "sst3": preds_q3
})

submission.to_csv("submission.csv", index=False)

print("submission.csv created!")
submission.head()

submission.csv created!


Unnamed: 0,id,sst8,sst3
0,0,CCCCTHHHCCHHHHHHHHHHHHHCSEEEEEESCCTCCCSEEEEEET...,CCCCCCCCCCCCEEEECCEECCCCCEEEEEECCCCCCCCCEEEECC...
1,1,CCCCCCCCCCEEEEEEEESTTEEEEEEECTTEHEHHHHCCCCTTCC...,CCCCCCCCCCEEEEEEECCCCEEEEEEECCCEEEECCCCCCCCCCC...
2,2,CCCHHEHHHHTHHHHHHHHHHHHHHHHHHHHHHHCCCTHHHHHHHH...,CCCEEEEECCCCCCCCCCCCCCCCCCCEEEECCCCCCCCCCCCCCC...
3,3,CCCCHHHHHHHHHCHHHHHHHHHHHHHHHHEEEECTHHHHHHHHHH...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCEEEECCCCCCCCCCCC...
4,4,CCEEEEEEETCTHHHHHHHHHHHHHHHCEEEEECSTTCEEEESSEE...,CCEEEEEECCCCCCCCCCCCEEECCCCCCCEECCCCCCEEEECCEE...
