In [4]:
#@title Download data from GCP bucket
import sys

if 'google.colab' in sys.modules:
  !gsutil -m cp -r gs://indaba-data .
else:
  !mkdir -p indaba-data/train
  !wget -P indaba-data/train https://storage.googleapis.com/indaba-data/train/train.csv --continue
  !wget -P indaba-data/train https://storage.googleapis.com/indaba-data/train/train_mut.pt --continue
  !wget -P indaba-data/train https://storage.googleapis.com/indaba-data/train/train_wt.pt --continue

  !mkdir -p indaba-data/test
  !wget -P indaba-data/test https://storage.googleapis.com/indaba-data/test/test.csv --continue
  !wget -P indaba-data/test https://storage.googleapis.com/indaba-data/test/test_mut.pt --continue
  !wget -P indaba-data/test https://storage.googleapis.com/indaba-data/test/test_wt.pt --continue

Copying gs://indaba-data/README.txt...
/ [0 files][    0.0 B/   33.0 B]                                                Copying gs://indaba-data/test/test.csv...
Copying gs://indaba-data/test/test_mut.pt...
Copying gs://indaba-data/test/test_wt.pt...
Copying gs://indaba-data/train/train_wt.pt...
Copying gs://indaba-data/train/train.csv...
Copying gs://indaba-data/train/train_mut.pt...
==> NOTE: You are downloading one or more large file(s), which would
run significantly faster if you enabled sliced object downloads. This
feature is enabled by default but requires that compiled crcmod be
installed (see "gsutil help crcmod").



In [5]:
#@title Imports and moving to working directory
import torch 
import pandas as pd
from tqdm import tqdm

# move to data folder
%cd indaba-data

/content/indaba-data/indaba-data


In [25]:
# Load Embedding tensors & Traing csv
# Embeddings were calculated using the ESM 650M pretrained model 
# Tensor shape of embedded data:  [data_len,1280] 
# There are no sequences in the Embedding tensors as we've performed an average of it (torch.mean(embed, dim=1))
# More details in https://huggingface.co/facebook/esm2_t33_650M_UR50D

wt_emb = torch.load("train/train_wt.pt")
mut_emb = torch.load("train/train_mut.pt")
df = pd.read_csv("train/train.csv")

In [26]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

# Reset the index of the dataframe
df.reset_index(drop=True, inplace=True)

# Split data into train and validation
wt_emb_train, wt_emb_val, mut_emb_train, mut_emb_val, df_train, df_val = train_test_split(wt_emb, mut_emb, df, test_size=0.2, random_state=42)

# Define the dataset class
class EmbeddingDataset(Dataset):
  def __init__(self, wt_pt, mut_pt, data_df):
    self.pt_wt = wt_pt
    self.pt_mut = mut_pt
    self.df = data_df

  def __len__(self):
    return len(self.pt_wt)

  def __getitem__(self, index):
    if "ddg" in self.df.columns:
      df_out = torch.Tensor([self.df.iloc[index]["ddg"]])
    else:
      df_out = torch.Tensor([self.df.iloc[index]["ID"]])

    return self.pt_wt[index,:], self.pt_mut[index,:], df_out

# Create separate datasets for the training and validation sets
train_dataset = EmbeddingDataset(wt_emb_train, mut_emb_train, df_train.reset_index(drop=True))
val_dataset = EmbeddingDataset(wt_emb_val, mut_emb_val, df_val.reset_index(drop=True))

# Create dataloaders for the training and validation sets
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

In [28]:
len(train_dataloader)

8495

In [29]:
len(val_dataloader)

2124

In [23]:
import torch
import torch.nn as nn

class LSTMWithAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTMWithAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # LSTM layer
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

        # Attention mechanism
        self.attention = nn.Linear(hidden_dim, 1)

        # Fully connected layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, wt_emb, mut_emb):
        # Pass the embeddings through the LSTM layer
        _, (hidden_state, _) = self.lstm(wt_emb)

        if isinstance(hidden_state, tuple):
            # If hidden_state is a tuple (hidden_state, cell_state)
            hidden_state = hidden_state[0]

        # Apply attention mechanism
        attn_weights = torch.softmax(self.attention(hidden_state), dim=1)
        context_vector = torch.bmm(attn_weights.unsqueeze(2), hidden_state.unsqueeze(1)).squeeze(1)

        # Pass the context vector through the fully connected layer
        output = self.fc(context_vector)
        
        return output



In [24]:
# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the LSTM model with attention
input_dim = 1280  # Assuming the input dimensions of your embeddings are 1280
hidden_dim = 256
num_layers = 1
output_dim = 1
model = LSTMWithAttention(input_dim, hidden_dim, num_layers, output_dim).to(device)
# Training parameters
epochs = 10
learning_rate = 0.001
# Define the loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    model.train()
    for i, (wt, mut, ddg) in enumerate(train_dataloader):
        wt, mut, ddg = wt.to(device), mut.to(device), ddg.to(device)

        # Forward pass
        outputs = model(wt, mut)
        loss = criterion(outputs, ddg)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_losses = []
    with torch.no_grad():
        for i, (wt_val, mut_val, ddg_val) in enumerate(val_dataloader):
            wt_val, mut_val, ddg_val = wt_val.to(device), mut_val.to(device), ddg_val.to(device)

            # Forward pass and calculate loss
            outputs = model(wt_val, mut_val)
            val_loss = criterion(outputs, ddg_val)
            val_losses.append(val_loss.item())

    avg_val_loss = sum(val_losses) / len(val_losses)
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {loss.item()}, Validation Loss: {avg_val_loss}")

# Save the trained model
torch.save(model.state_dict(), 'lstm_with_attention.pth')


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/10, Train Loss: 1.4516680240631104, Validation Loss: 1.0927000908696718


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 2/10, Train Loss: 0.7721603512763977, Validation Loss: 1.0848169151812848
Epoch 3/10, Train Loss: 2.2502424716949463, Validation Loss: 1.0843591267302661


KeyboardInterrupt: ignored

In [None]:
# Example of training script
device = torch.device("cuda")
model =  StabilityModel().to(device)
optimizer = torch.optim.Adadelta(model.parameters(), lr=0.0001)
criterion = torch.nn.MSELoss()
epoch_loss = 0
for i in range(1):
  epoch_loss = 0
  for batch_idx, (data_mut,data_wt , target) in tqdm(enumerate(train_dataloader)):
      # extract input from datallader
      x1 = data_wt.to(device)
      x2 = data_mut.to(device)
      y = target.to(device)
      # make prediction
      y_pred = model(x1,x2)
      # calculate loss and run optimizer
      loss = torch.sqrt(criterion(y, y_pred))
      loss.backward()
      optimizer.step()
      epoch_loss += loss
  print("epoch_",i," = ", epoch_loss/len(train_dataloader))
  # [Recommended] Save trained models to select best checkpoint for prediction (or add prediction in the epochs loop)

## Prediction & submission

In [None]:
# load embedding tensors & traing csv
wt_test_emb = torch.load("test/test_wt.pt")
mut_test_emb = torch.load("test/test_mut.pt")
df_test = pd.read_csv("test/test.csv")

In [None]:
# creating testing dataset and loading the embedding
test_dataset = EmbeddingDataset(wt_test_emb,mut_test_emb,df_test)
# preparing a dataloader for the testing
test_dataloader = torch.utils.data.dataloader.DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=2,
    )

In [None]:
df_result = pd.DataFrame()
with torch.no_grad():
  for batch_idx, (data_mut,data_wt , target) in tqdm(enumerate(test_dataloader)):
    x1 = data_wt.to(device)
    x2 = data_mut.to(device)
    id = target.to(device)
    # make prediction
    y_pred = model(x1,x2)
    df_result = pd.concat([df_result, pd.DataFrame({"ID":id.squeeze().cpu().numpy().astype(int) , "ddg" : y_pred.squeeze().cpu().numpy()})])

In [None]:
# df_result.to_csv("submission.csv",index=False)