# In [None]:
import os
import pandas as pd
import numpy as np
from Bio import SeqIO
import torch
import esm

# Reproducibility: one seed shared by NumPy and every PyTorch generator
# (CPU, current CUDA device, and all CUDA devices).
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)

# Make cuDNN pick deterministic kernels instead of autotuning for speed.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

# Read FASTA files and assign labels
def read_fasta_and_label(file_path, label):
    sequences = []
    labels = []
    for record in SeqIO.parse(file_path, "fasta"):
        sequences.append(str(record.seq))
        labels.append(label)
    return pd.DataFrame({'sequence': sequences, 'label': labels})

# Positive examples get label 1, negative examples label 0.
positive_df = read_fasta_and_label('/content/sample_data/clustered_positive_0.4.fasta', 1)
negative_df = read_fasta_and_label('/content/sample_data/clustered_negative_0.4.fasta', 0)

# Stack both classes, shuffle all rows with the global seed, and renumber
# the index 0..N-1 so row order carries no class information.
data = (
    pd.concat([positive_df, negative_df])
    .sample(frac=1, random_state=random_seed)
    .reset_index(drop=True)
)

print(f'Number of positive samples: {len(positive_df)}')
print(f'Number of negative samples: {len(negative_df)}')

# Load the pretrained ESM-2 model (per its name: 6 layers, ~8M parameters)
# together with the alphabet that tokenizes protein strings for it.
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
# Inference only: eval() disables dropout so the extracted embeddings are
# deterministic (as the ESM README does before extracting representations),
# then move the weights to the selected device.
model = model.eval().to(device)
# Callable mapping [(name, sequence), ...] -> (labels, strs, token tensor).
batch_converter = alphabet.get_batch_converter()

# Function to extract ESM features
def extract_esm_features(sequences, repr_layer=6):
    """Embed protein sequences with the ESM model and mean-pool over tokens.

    Args:
        sequences: iterable of protein sequence strings.
        repr_layer: transformer layer whose per-token representations are
            pooled. Defaults to 6, the layer this script has always used.

    Returns:
        NumPy array with one pooled embedding row per input sequence.

    Pooling averages over every non-padding token. That covers the residues
    plus any special tokens the batch converter adds (e.g. BOS/EOS for
    ESM-2). Padding positions are excluded, so a sequence's embedding does
    not change depending on which other sequences share its batch.
    """
    named = [("seq" + str(i), seq) for i, seq in enumerate(sequences)]
    _, _, batch_tokens = batch_converter(named)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[repr_layer])
    token_representations = results["representations"][repr_layer]

    # A plain .mean(1) would also average the pad-token embeddings of shorter
    # sequences in a mixed-length batch. Mask them out instead. With a
    # single sequence there is no padding, so this matches the old unmasked
    # mean (up to float rounding).
    mask = (batch_tokens != alphabet.padding_idx).unsqueeze(-1)
    mask = mask.to(token_representations.dtype)
    pooled = (token_representations * mask).sum(1) / mask.sum(1)
    return pooled.cpu().numpy()

# Embed each sequence in its own batch of one and keep its pooled vector
# (row 0 of the returned array) in a new column.
data['esm_features'] = data['sequence'].map(
    lambda seq: extract_esm_features([seq])[0]
)

# Labels as a NumPy vector. Stack the per-row embedding vectors into one
# 2-D matrix whose rows line up with y.
y = data['label'].values
X = np.vstack(data['esm_features'].values)

# Persist both arrays as .npy files in the current working directory.
np.save('esm_features.npy', X)
np.save('labels.npy', y)

print("ESM features and labels have been saved locally.")
