In [None]:
#@title Download data from GCP bucket
# Fetch the competition data (train/test CSVs and precomputed embedding tensors)
# into ./indaba-data. Uses IPython `!` shell escapes, so this cell only runs in Jupyter/Colab.
import sys

# On Colab, copy the whole gs://indaba-data bucket with gsutil (-m = parallel copies).
# NOTE(review): the recorded output shows this branch copies 9 files (~3.3 GiB), i.e. more
# than the 6 files the non-Colab branch downloads; gsutil also suggests enabling sliced
# object download for the large files.
if 'google.colab' in sys.modules:
  !gsutil -m cp -r gs://indaba-data .
else:
  # Elsewhere, download only the train/test files over public HTTPS.
  # wget --continue resumes a partially downloaded file instead of starting over.
  !mkdir -p indaba-data/train
  !wget -P indaba-data/train https://storage.googleapis.com/indaba-data/train/train.csv --continue
  !wget -P indaba-data/train https://storage.googleapis.com/indaba-data/train/train_mut.pt --continue
  !wget -P indaba-data/train https://storage.googleapis.com/indaba-data/train/train_wt.pt --continue

  !mkdir -p indaba-data/test
  !wget -P indaba-data/test https://storage.googleapis.com/indaba-data/test/test.csv --continue
  !wget -P indaba-data/test https://storage.googleapis.com/indaba-data/test/test_mut.pt --continue
  !wget -P indaba-data/test https://storage.googleapis.com/indaba-data/test/test_wt.pt --continue

Copying gs://indaba-data/README.txt...
/ [0 files][    0.0 B/   33.0 B]                                                Copying gs://indaba-data/test/test.csv...
/ [0 files][    0.0 B/290.0 KiB]                                                Copying gs://indaba-data/test/test_mut.pt...
/ [0 files][    0.0 B/  9.6 MiB]                                                Copying gs://indaba-data/train/train.csv...
/ [0/9 files][    0.0 B/  3.3 GiB]   0% Done                                    Copying gs://indaba-data/test/test_wt.pt...
/ [0/9 files][    0.0 B/  3.3 GiB]   0% Done                                    Copying gs://indaba-data/train/train_mut.pt...
/ [0/9 files][    0.0 B/  3.3 GiB]   0% Done                                    Copying gs://indaba-data/train/train_wt.pt...
/ [0/9 files][    0.0 B/  3.3 GiB]   0% Done                                    ==> NOTE: You are downloading one or more large file(s), which would
run significantly faster if you enabled sliced object dow

In [None]:
#@title Imports and moving to working directory
import os

import torch
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader

# Move into the data folder. Guarded so that re-running this cell does not try to
# descend into indaba-data/indaba-data (a bare `%cd indaba-data` fails on a second run).
DATA_DIR = "indaba-data"
if os.path.basename(os.getcwd()) != DATA_DIR:
    os.chdir(DATA_DIR)
print(os.getcwd())

/content/indaba-data


In [None]:
# Load the precomputed embedding tensors and the training CSV.
# Embeddings were calculated with the pretrained ESM-2 650M model
# (details: https://huggingface.co/facebook/esm2_t33_650M_UR50D).
# Tensor shape of the embedded data: [data_len, 1280].
# The per-residue (sequence) dimension has already been averaged away
# (torch.mean(embed, dim=1)), so each row is a single fixed-size vector.
#
# map_location="cpu" lets tensors that were saved from a GPU load on CPU-only machines.
# NOTE: torch.load unpickles its input - only load trusted files
# (recent PyTorch versions offer weights_only=True to restrict this).
wt_emb = torch.load("train/train_wt.pt", map_location="cpu")
mut_emb = torch.load("train/train_mut.pt", map_location="cpu")
df = pd.read_csv("train/train.csv")

# The dataset pairs these three objects by row position, so their lengths must agree.
assert len(wt_emb) == len(mut_emb) == len(df), (len(wt_emb), len(mut_emb), len(df))

In [None]:
# Number of rows (labelled examples) in the training CSV.
df.shape[0]

339778

In [None]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

VAL_FRACTION = 0.21  # share of rows held out for validation
SPLIT_SEED = 42      # fixes the shuffled split so it is reproducible

# Positional 0..n-1 index so row i of df lines up with row i of the embedding tensors.
# Rebinding (rather than inplace=True) keeps this cell idempotent.
df = df.reset_index(drop=True)

# Split embeddings and labels together: train_test_split applies the same shuffled
# row selection to every array it is given, so the three splits stay aligned.
wt_emb_train, wt_emb_val, mut_emb_train, mut_emb_val, df_train, df_val = train_test_split(
    wt_emb, mut_emb, df, test_size=VAL_FRACTION, random_state=SPLIT_SEED
)

# Define the dataset class
class EmbeddingDataset(Dataset):
    """Map-style dataset pairing wild-type and mutant embeddings with a per-row target.

    Row ``i`` of ``wt_pt``, row ``i`` of ``mut_pt`` and positional row ``i`` of
    ``data_df`` describe the same example.

    Args:
        wt_pt: wild-type embeddings, indexed as ``wt_pt[index, :]``.
        mut_pt: mutant embeddings, indexed as ``mut_pt[index, :]``.
        data_df: table with a ``"ddg"`` column (labelled data) or, failing that,
            an ``"ID"`` column (unlabelled data). ``"ddg"`` wins if both exist.

    Each item is ``(wt_row, mut_row, target)`` where ``target`` has shape ``(1,)``:
    the float32 ddG value, or the int64 ID when there is no ``"ddg"`` column.

    Raises:
        ValueError: if the three inputs do not have the same number of rows.
    """

    def __init__(self, wt_pt, mut_pt, data_df):
        if not (len(wt_pt) == len(mut_pt) == len(data_df)):
            raise ValueError(
                f"length mismatch: wt={len(wt_pt)}, mut={len(mut_pt)}, df={len(data_df)}"
            )
        self.pt_wt = wt_pt
        self.pt_mut = mut_pt
        self.df = data_df
        # Resolve the target column once, instead of a per-item `df.iloc[index][col]`
        # lookup that builds a whole pandas row on every __getitem__ call.
        if "ddg" in data_df.columns:
            values = torch.tensor(data_df["ddg"].to_numpy(dtype="float32"))
        else:
            # IDs stay int64: float32 cannot represent integers above 2**24 exactly.
            values = torch.tensor(data_df["ID"].to_numpy(dtype="int64"))
        # Shape (N, 1) so each item's target is a length-1 tensor.
        self.targets = values.reshape(-1, 1)

    def __len__(self):
        return len(self.pt_wt)

    def __getitem__(self, index):
        return self.pt_wt[index, :], self.pt_mut[index, :], self.targets[index]

# Batch size shared by the training and validation loaders.
# 32 matches the test DataLoader and this notebook's recorded outputs
# (8389 train / 2230 val batches and [32, 1280] batch shapes); 26 did not.
BATCH_SIZE = 32

# Create separate datasets for the training and validation sets.
# reset_index gives each split a fresh 0..n-1 index matching its embedding rows.
train_dataset = EmbeddingDataset(wt_emb_train, mut_emb_train, df_train.reset_index(drop=True))
val_dataset = EmbeddingDataset(wt_emb_val, mut_emb_val, df_val.reset_index(drop=True))

# Shuffle only the training data; validation order does not affect the averaged loss.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


In [None]:
# Number of training batches per epoch (a DataLoader's length counts batches, not samples).
len(train_dataloader)

8389

In [None]:
# Number of validation batches per pass (a DataLoader's length counts batches, not samples).
len(val_dataloader)


2230

In [None]:
# Print the shapes of the first few training batches as a sanity check.
# Loop variables use *_batch names so they don't overwrite the full wt_emb / mut_emb
# tensors loaded earlier (re-running the split cell would otherwise split one batch).
N_PREVIEW_BATCHES = 5
for i, (wt_batch, mut_batch, _ddg_batch) in enumerate(train_dataloader):
    print(f'wt_emb shape: {wt_batch.shape}, mut_emb shape: {mut_batch.shape}')
    if i + 1 >= N_PREVIEW_BATCHES:  # stop after exactly N_PREVIEW_BATCHES batches
        break


wt_emb shape: torch.Size([32, 1280]), mut_emb shape: torch.Size([32, 1280])
wt_emb shape: torch.Size([32, 1280]), mut_emb shape: torch.Size([32, 1280])
wt_emb shape: torch.Size([32, 1280]), mut_emb shape: torch.Size([32, 1280])
wt_emb shape: torch.Size([32, 1280]), mut_emb shape: torch.Size([32, 1280])
wt_emb shape: torch.Size([32, 1280]), mut_emb shape: torch.Size([32, 1280])
wt_emb shape: torch.Size([32, 1280]), mut_emb shape: torch.Size([32, 1280])
wt_emb shape: torch.Size([32, 1280]), mut_emb shape: torch.Size([32, 1280])


In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Inspect a single training batch and print batch-sized shapes for all three items
# (not the full training tensors). Fresh *_batch names avoid overwriting the global
# wt_emb / mut_emb tensors.
wt_batch, mut_batch, ddg_batch = next(iter(train_dataloader))
print(wt_batch.shape, mut_batch.shape, ddg_batch.shape)


torch.Size([268424, 1280]) torch.Size([268424, 1280]) torch.Size([32, 1])


In [None]:
class OptimizedDeepFFNN(nn.Module):
    """Feed-forward regressor on the concatenation of wild-type and mutant embeddings.

    Three hidden stages, each Linear -> BatchNorm1d -> LeakyReLU -> Dropout, with
    widths hidden_dim -> hidden_dim // 2 -> hidden_dim // 4, followed by a Linear
    head producing ``output_dim`` values.

    Args:
        input_dim: width of ``torch.cat((wt_emb, mut_emb), dim=1)``, i.e. the sum of
            the two embedding widths.
        hidden_dim: width of the first hidden stage.
        output_dim: number of regression outputs.
        dropout: dropout probability in every hidden stage (default 0.5, the value
            previously hard-coded).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.5):
        super().__init__()
        # Attribute names and the Sequential layer order are unchanged, so
        # state_dict keys (layer1.0.weight, ..., final_layer.bias) stay the same.
        self.layer1 = self._hidden_block(input_dim, hidden_dim, dropout)
        self.layer2 = self._hidden_block(hidden_dim, hidden_dim // 2, dropout)
        self.layer3 = self._hidden_block(hidden_dim // 2, hidden_dim // 4, dropout)
        self.final_layer = nn.Linear(hidden_dim // 4, output_dim)

    @staticmethod
    def _hidden_block(in_features, out_features, dropout):
        """Build one Linear -> BatchNorm1d -> LeakyReLU -> Dropout stage."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, wt_emb, mut_emb):
        """Concatenate the two embeddings along dim 1 and return (batch, output_dim) predictions.

        Assumes 2-D (batch, features) inputs whose feature widths sum to input_dim.
        """
        x = torch.cat((wt_emb, mut_emb), dim=1)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return self.final_layer(x)

def init_weights(m):
    """Xavier-uniform init for every nn.Linear weight, with the bias set to 0.01.

    Intended for ``model.apply(init_weights)``; modules that are not nn.Linear
    (BatchNorm, activations, containers) are left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # Linear layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
input_dim = 1280 * 2  # Multiply by 2 since we're concatenating wt_emb and mut_emb
hidden_dim = 512
output_dim = 1  # ddG value

# Training loop
from torch.optim.lr_scheduler import ReduceLROnPlateau

SEED = 42         # seeds weight init and DataLoader shuffling for reproducible runs
num_epochs = 20   # Adjust as needed
patience = 10     # epochs without val improvement before early stopping
lr_patience = 3   # epochs without improvement before the LR is reduced; must be
                  # smaller than `patience`, otherwise training stops before the
                  # scheduler ever lowers the LR (its default patience is also 10)

torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = OptimizedDeepFFNN(input_dim, hidden_dim, output_dim)
model.apply(init_weights)
model = model.to(device)  # move before building the optimizer

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=lr_patience)


def run_epoch(net, dataloader, loss_fn, device, optimizer=None):
    """Make one pass over `dataloader` and return the sample-weighted mean loss.

    With an `optimizer`, the model is put in train mode and updated on every batch;
    without one, it is put in eval mode and run with gradients disabled.

    Raises:
        ValueError: if `dataloader` yields no samples.
    """
    training = optimizer is not None
    net.train(training)
    total_loss, n_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for wt_batch, mut_batch, ddg_batch in dataloader:
            wt_batch = wt_batch.to(device)
            mut_batch = mut_batch.to(device)
            ddg_batch = ddg_batch.to(device)
            outputs = net(wt_batch, mut_batch)
            loss = loss_fn(outputs, ddg_batch)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight by batch size so a short final batch doesn't skew the average.
            batch_size = ddg_batch.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / n_samples


best_val_loss = float('inf')
best_state = None
counter = 0
for epoch in range(num_epochs):
    avg_train_loss = run_epoch(model, train_dataloader, criterion, device, optimizer)
    avg_val_loss = run_epoch(model, val_dataloader, criterion, device)
    scheduler.step(avg_val_loss)

    print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}')

    # Early stopping on validation loss, remembering the best weights seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # Clone: state_dict() returns live references that later steps would overwrite.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        counter = 0  # Reset counter if validation loss decreases
    else:
        counter += 1  # Increase counter if validation loss doesn't decrease
        print(f'EarlyStopping counter: {counter} out of {patience}')
        if counter >= patience:
            print('Early stopping triggered.')
            break

# Keep the best-validation weights, not whatever the last epoch produced.
if best_state is not None:
    model.load_state_dict(best_state)
model.eval()



Epoch 1, Train Loss: 1.0051616868276194, Val Loss: 0.8386187197381606
Epoch 2, Train Loss: 0.8629007320329307, Val Loss: 0.80184475255654
Epoch 3, Train Loss: 0.8351355713877295, Val Loss: 0.7693133246805102
Epoch 4, Train Loss: 0.8197600219879663, Val Loss: 0.7918837036653484
EarlyStopping counter: 1 out of 10
Epoch 5, Train Loss: 0.8077201493584452, Val Loss: 0.7554841244969133
Epoch 6, Train Loss: 0.7960810999784897, Val Loss: 0.7501168418759188
Epoch 7, Train Loss: 0.7889809217217398, Val Loss: 0.7438798307570641
Epoch 8, Train Loss: 0.7828568889900361, Val Loss: 0.7405863737979812
Epoch 9, Train Loss: 0.7760382743661346, Val Loss: 0.7275966319535345
Epoch 10, Train Loss: 0.770937412668147, Val Loss: 0.7150539197996593
Epoch 11, Train Loss: 0.7670144723355947, Val Loss: 0.7312756195463942
EarlyStopping counter: 1 out of 10
Epoch 12, Train Loss: 0.7624106828444392, Val Loss: 0.7243300113977338
EarlyStopping counter: 2 out of 10
Epoch 13, Train Loss: 0.7568239919501133, Val Loss: 0.7

## Prediction & submission

In [None]:
# Load the test-set embedding tensors and the test CSV.
# map_location="cpu" lets tensors that were saved from a GPU load on CPU-only machines.
# NOTE: torch.load unpickles its input - only load trusted files.
wt_test_emb = torch.load("test/test_wt.pt", map_location="cpu")
mut_test_emb = torch.load("test/test_mut.pt", map_location="cpu")
df_test = pd.read_csv("test/test.csv")

# Predictions are matched to IDs by row position, so the lengths must agree.
assert len(wt_test_emb) == len(mut_test_emb) == len(df_test), (
    len(wt_test_emb), len(mut_test_emb), len(df_test)
)

In [None]:
# Wrap the test tensors/CSV in the same Dataset class used for training, and iterate
# in a fixed order (no shuffling) so predictions stay aligned with their rows.
test_dataset = EmbeddingDataset(wt_test_emb, mut_test_emb, df_test)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Predict ddG for every test row and collect (ID, ddg) pairs for the submission.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Disable dropout and use BatchNorm running statistics; don't rely on an
# earlier cell having left the model in eval mode.
model.eval()

id_chunks, pred_chunks = [], []
with torch.no_grad():
  # Unpack in dataset order (wt, mut, target): the model was trained as model(wt, mut).
  for wt_batch, mut_batch, id_batch in tqdm(test_dataloader):
    y_pred = model(wt_batch.to(device), mut_batch.to(device))
    # reshape(-1) rather than squeeze() keeps 1-D arrays even for a final batch of size 1.
    ids = id_batch.reshape(-1)
    if ids.is_floating_point():
      ids = ids.round()  # float-encoded IDs: round instead of truncating on the int cast
    id_chunks.append(ids.long())
    pred_chunks.append(y_pred.reshape(-1).cpu())

# Build the frame once instead of concatenating inside the loop (quadratic copying).
all_ids = torch.cat(id_chunks) if id_chunks else torch.empty(0, dtype=torch.long)
all_preds = torch.cat(pred_chunks) if pred_chunks else torch.empty(0)
df_result = pd.DataFrame({"ID": all_ids.numpy().astype(int), "ddg": all_preds.numpy()})

60it [00:00, 65.83it/s]


In [None]:
# Write the submission file (columns ID and ddg) without the pandas index.
df_result.to_csv(path_or_buf="submission_deepffnn_LCODETWIL.csv", index=False)