In [None]:
!pip install unrar

In [None]:
!unrar x training_testing_datasets.rar

In [None]:
!pip install biopython

In [None]:
!pip install fair-esm

In [None]:
import torch

from pathlib import Path
from argparse import Namespace
from tqdm import tqdm

from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer

In [None]:
# ESM-2 embedding-extraction settings shared by the training and test runs.
model_location = "esm2_t30_150M_UR50D"
# Which outputs `run` saves per sequence: any of "per_tok", "mean", "bos", "contacts".
# Kept as a list so `run`'s `"..." in args.include` checks are exact membership tests,
# not substring matches against a single string.
include = ["mean"]
repr_layers = [-1]  # negative indices count from the end; `run` maps them to absolute layer numbers
truncation_seq_length = 1022
nogpu = False
toks_per_batch = 4096


def make_extract_config(fasta_file, output_dir):
    """Build the argparse-style Namespace that `run` expects for one FASTA file.

    All settings except the input/output paths come from the module-level constants above,
    so the training and test extractions are guaranteed to use identical parameters.

    Args:
        fasta_file: path (str or Path) of the input FASTA file.
        output_dir: directory (str or Path) where one ``.pt`` file per sequence is written.

    Returns:
        argparse.Namespace with the fields read by `run`.
    """
    return Namespace(
        model_location=model_location,
        fasta_file=Path(fasta_file),
        repr_layers=repr_layers,
        include=include,
        output_dir=Path(output_dir),
        truncation_seq_length=truncation_seq_length,
        nogpu=nogpu,
        toks_per_batch=toks_per_batch,
    )


config = make_extract_config("/content/training.fasta", "/content/train_embeddings_150M/")
test_config = make_extract_config("/content/testing.fasta", "/content/test_embeddings_150M/")


def run(args):
    """Extract ESM embeddings for every sequence in ``args.fasta_file`` and save them to disk.

    One file ``args.output_dir / f"{label}.pt"`` is written per FASTA record. Each file holds a
    dict with ``"label"`` plus, depending on ``args.include``:

    - ``"per_tok"``  -> ``"representations"``: {layer: per-residue tensor}
    - ``"mean"``     -> ``"mean_representations"``: {layer: residue-averaged tensor}
    - ``"bos"``      -> ``"bos_representations"``: {layer: tensor at token position 0}
    - ``"contacts"`` -> ``"contacts"``: the model's ``contacts`` output, cropped to the sequence

    Args:
        args: Namespace with ``model_location``, ``fasta_file``, ``repr_layers``, ``include``,
            ``output_dir`` (a Path), ``truncation_seq_length``, ``nogpu`` and ``toks_per_batch``.

    Raises:
        ValueError: if the loaded model is an MSA Transformer.
        AssertionError: if a requested layer index is outside the model's layer range.
    """
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )
    use_gpu = torch.cuda.is_available() and not args.nogpu
    if use_gpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(args.fasta_file)
    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
        batch_sampler=batches,
    )
    print(f"Read {args.fasta_file} with {len(dataset)} sequences")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    return_contacts = "contacts" in args.include

    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers)
    # Map negative layer indices (e.g. -1) onto absolute layer numbers in [0, num_layers].
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers]

    with torch.no_grad():
        # Passing the DataLoader itself (not enumerate(...)) lets tqdm read its length.
        for labels, strs, toks in tqdm(data_loader):
            if use_gpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)

            # Only copy what is saved below back to the CPU; the logits are never used.
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")

            for i, label in enumerate(labels):
                output_file = args.output_dir / f"{label}.pt"
                # Presumably for labels containing "/", which become sub-directories.
                output_file.parent.mkdir(parents=True, exist_ok=True)
                result = {"label": label}
                truncate_len = min(args.truncation_seq_length, len(strs[i]))
                # Position 0 holds the token saved as "bos"; residues occupy 1..truncate_len.
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                if "per_tok" in args.include:
                    result["representations"] = {
                        layer: t[i, 1 : truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in args.include:
                    result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in args.include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[i, :truncate_len, :truncate_len].clone()

                torch.save(result, output_file)

In [None]:
# Extract embeddings first for the training FASTA, then for the held-out test FASTA.
for extract_args in (config, test_config):
    run(extract_args)

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

class EmbeddingsDataset(Dataset):
    """Dataset over the per-sequence ``.pt`` embedding files in one directory.

    Each item is ``(embedding, class_label)``:

    - ``embedding`` is ``torch.load(path)["mean_representations"][layer]``;
    - ``class_label`` is ``1`` if the file name contains ``"|1|"``, else ``0`` (presumably the
      class is encoded in the FASTA header that became the file name — confirm against the data).

    Args:
        root_dir: directory to scan; only files ending in ``.pt`` are used.
        layer: key to read from ``mean_representations``. The default 30 matches the original
            hard-coded value (presumably the last layer of ``esm2_t30_150M_UR50D``).
    """

    def __init__(self, root_dir, layer=30):
        self.root_dir = Path(root_dir)
        self.layer = layer
        # Sorted so item order (and therefore any seeded split/shuffle built on it) does not
        # depend on the filesystem's arbitrary directory-listing order.
        self.file_list = sorted(f for f in os.listdir(self.root_dir) if f.endswith(".pt"))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        full_path = self.root_dir / file_name
        embeddings = torch.load(full_path)["mean_representations"][self.layer]

        class_label = int("|1|" in file_name)

        return embeddings, class_label

# Example usage: expose the extracted training embeddings through a shuffled DataLoader.
data_directory = "/content/train_embeddings_150M"
batch_size = 32

dataset = EmbeddingsDataset(data_directory)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
# Sanity check: inspect a single batch. Looping over the whole loader here would read every
# embedding file from disk and print two lines per batch.
first_batch = next(iter(dataloader), None)
if first_batch is None:
    print("DataLoader is empty - no embedding files found")
else:
    batch_embeddings, batch_labels = first_batch
    print("Batch embeddings shape:", batch_embeddings.shape)
    print("Batch labels:", batch_labels)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pathlib import Path
from sklearn.model_selection import train_test_split
import numpy as np

# Define the MLP binary classifier with dropout
class MLPClassifier(nn.Module):
    """Feed-forward binary classifier with dropout.

    Architecture: for each width in ``hidden_sizes`` a ``Linear -> ReLU -> Dropout`` triple
    (held in ``self.hidden_layers``), then ``self.output_layer`` (Linear to ``output_size``)
    and a sigmoid. ``forward`` therefore returns probabilities, suitable for ``nn.BCELoss``.

    Args:
        input_size: number of input features.
        hidden_sizes: widths of the hidden layers, in order (may be empty).
        output_size: number of output units.
        dropout_rate: dropout probability applied after every hidden ReLU.
    """

    def __init__(self, input_size, hidden_sizes, output_size, dropout_rate):
        super().__init__()
        widths = [input_size, *hidden_sizes]
        blocks = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            blocks += [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
        self.hidden_layers = nn.Sequential(*blocks)
        self.output_layer = nn.Linear(widths[-1], output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.hidden_layers(x)
        return self.sigmoid(self.output_layer(hidden))


# Load and preprocess the dataset
data_directory = "/content/train_embeddings_150M"
dataset = EmbeddingsDataset(data_directory)  # EmbeddingsDataset is defined in an earlier cell

# Seed torch's global RNG so the split, weight init, dropout masks and shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Split the dataset into training and validation sets (90% / 10%)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoader instances for training and validation
batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

# Initialize the model, loss function, and optimizer
input_size = dataset[0][0].shape[0]  # embedding width, read from the first sample
hidden_sizes = [320, 160, 80, 40]  # hidden layer widths
output_size = 1  # single probability for binary classification
learning_rate = 0.003
dropout_rate = 0.2

model = MLPClassifier(input_size, hidden_sizes, output_size, dropout_rate)
# MLPClassifier already ends in a sigmoid, so plain BCELoss (on probabilities) is the matching loss.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Early stopping parameters
patience = 3  # Number of epochs to wait without improvement
min_val_loss = np.inf
counter = 0
best_state = None  # copy of the weights from the epoch with the lowest validation loss

# Training loop
num_epochs = 6

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch_embeddings, batch_labels in train_dataloader:
        optimizer.zero_grad()

        # squeeze(1), not squeeze(): a final batch of one sample must stay shape (1,), otherwise
        # BCELoss sees a 0-d input vs. a (1,) target and raises.
        outputs = model(batch_embeddings.float()).squeeze(1)
        loss = criterion(outputs, batch_labels.float())
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_loss:.4f}")

    # Validation loop
    model.eval()
    val_loss = 0.0
    correct_predictions = 0

    with torch.no_grad():
        for batch_embeddings, batch_labels in val_dataloader:
            outputs = model(batch_embeddings.float()).squeeze(1)
            loss = criterion(outputs, batch_labels.float())
            val_loss += loss.item()

            # Both tensors are (B,) here. Comparing a (B, 1) prediction with (B,) labels would
            # broadcast to (B, B) and count cross-sample matches as correct.
            predicted_labels = (outputs >= 0.5).long()
            correct_predictions += (predicted_labels == batch_labels).sum().item()

    avg_val_loss = val_loss / len(val_dataloader)
    val_accuracy = correct_predictions / len(val_dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    # Early stopping
    if avg_val_loss < min_val_loss:
        min_val_loss = avg_val_loss
        counter = 0
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping!")
            break

# Restore the best checkpoint so later evaluation and saving use it rather than the last epoch.
if best_state is not None:
    model.load_state_dict(best_state)


# Benchmarking

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score

In [None]:
model.eval()

# Load the test dataset
test_directory = "/content/test_embeddings_150M"
test_dataset = EmbeddingsDataset(test_directory)  # EmbeddingsDataset is defined in an earlier cell
test_dataloader = DataLoader(test_dataset, batch_size=32)

# Evaluate the model on the test dataset
true_labels = []
predicted_labels = []

with torch.no_grad():
    for batch_embeddings, batch_labels in test_dataloader:
        outputs = model(batch_embeddings.float())
        # view(-1) flattens the (B, 1) output to (B,) so each prediction is a plain int,
        # matching the 1-D true labels, instead of a one-element list.
        predicted_batch = (outputs >= 0.5).long().view(-1)

        true_labels.extend(batch_labels.tolist())
        predicted_labels.extend(predicted_batch.tolist())

# Calculate evaluation metrics.
# labels=[0, 1] keeps the confusion matrix 2x2 even if one class is absent from the test set.
conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])
accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels)
recall = recall_score(true_labels, predicted_labels)

# Print evaluation metrics
print("Confusion Matrix:")
print(conf_matrix)
print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)

In [None]:
# Persist only the learned weights (state_dict); reload them into a fresh MLPClassifier.
save_path = "/content/solubility_model.pth"
torch.save(obj=model.state_dict(), f=save_path)