In [1]:
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder

max_seq_len = 1000
batch_size = 32
SEED = 42

# Seed the RNGs the notebook relies on (model weight init, DataLoader shuffling)
# so Restart & Run All reproduces the same results.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the second GPU when there is one; fall back to the first GPU, then CPU.
# (A hard-coded "cuda:1" fails with an invalid device ordinal on single-GPU machines.)
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [None]:
# Load the raw table and keep only sequences strictly shorter than max_seq_len.
df = pd.read_csv("../data/taxonomy_data_10000.csv")
lens = df['Sequence'].map(len)
df = df.loc[lens < max_seq_len]
print(df.shape)

In [None]:
def list_encoder(s):
    """Parse a stringified integer list such as "[1, 2, 3]" into [1, 2, 3].

    Whitespace around the separators is tolerated and "[]" yields [].
    """
    inner = s.strip().strip("[]")
    return [int(tok) for tok in inner.split(",") if tok.strip()]


# Per-sample taxonomy lineages parsed from the "tree trace" column.
taxonomy_ids_list = [list_encoder(tax_id_str) for tax_id_str in df['tree trace'].values]

# Unique taxonomy IDs across all samples.
union_set = set().union(*taxonomy_ids_list)
num_taxonomy_ids = len(union_set)
print("num_taxonomy_ids: ", num_taxonomy_ids)

# Sorted so the ID -> index mapping is deterministic across runs.
all_taxonomy_ids = sorted(union_set)
# Index 0 is reserved for padding, so real taxonomy IDs map to 1..num_taxonomy_ids.
taxonomy_id_to_idx = {tax_id: idx for idx, tax_id in enumerate(all_taxonomy_ids, start=1)}
# Position in the multi-hot vector built by encode_taxonomy (0-based) -> taxonomy ID.
# Position p holds padded index p + 1 from taxonomy_id_to_idx, so this is keyed
# 0..num_taxonomy_ids-1 to match how evaluate() enumerates prediction vectors.
taxonomy_idx_to_id = dict(enumerate(all_taxonomy_ids))

# Longest lineage; every mapped list is right-padded with 0 up to this length.
max_tax_len = max((len(tax_ids) for tax_ids in taxonomy_ids_list), default=0)
mapped_taxonomy_ids_list = [
    [taxonomy_id_to_idx[tax_id] for tax_id in tax_ids] + [0] * (max_tax_len - len(tax_ids))
    for tax_ids in taxonomy_ids_list
]

# Single-label targets: organism IDs -> contiguous class indices 0..num_organism_ids-1.
le = LabelEncoder()
numeric_labels = le.fit_transform(df['Organism (ID)'].values)
print("Numeric labels:", numeric_labels)

num_organism_ids = len(le.classes_)
print("num_organism_ids: ", num_organism_ids)


def encode_taxonomy(taxonomy):
    """Turn a padded list of mapped taxonomy indices into a multi-hot vector.

    Args:
        taxonomy: indices from taxonomy_id_to_idx (1..num_taxonomy_ids),
            possibly right-padded with 0.

    Returns:
        A list of num_taxonomy_ids ints where position idx - 1 is 1 iff idx
        occurs in `taxonomy`. Padding entries (0) are ignored.
    """
    encoded = [0] * num_taxonomy_ids
    for idx in taxonomy:
        if idx > 0:  # 0 is the padding index, not a taxonomy ID
            encoded[idx - 1] = 1
    return encoded

In [None]:
class ProteinDataset(Dataset):
    """Map-style dataset yielding (sequence tokens, multi-hot taxonomy, organism label) tensors.

    Args:
        sequences: indexable collection of raw sequences.
        taxonomy_ids: per-sample padded lists of mapped taxonomy indices.
        taxes: per-sample integer organism labels.
        sequence_encoder: callable turning a raw sequence into a list of token ids.
            Defaults to the notebook-level `encode_sequence`.
            NOTE(review): `encode_sequence` is not defined anywhere in this notebook — TODO.
        taxonomy_encoder: callable turning a taxonomy index list into a multi-hot list.
            Defaults to the notebook-level `encode_taxonomy`.

    Raises:
        ValueError: if the three per-sample collections differ in length.
    """

    def __init__(self, sequences, taxonomy_ids, taxes, sequence_encoder=None, taxonomy_encoder=None):
        if not (len(sequences) == len(taxonomy_ids) == len(taxes)):
            raise ValueError(
                f"length mismatch: {len(sequences)} sequences, "
                f"{len(taxonomy_ids)} taxonomy lists, {len(taxes)} labels"
            )
        self.sequences = sequences
        self.taxonomy_ids = taxonomy_ids
        self.taxes = taxes
        self.sequence_encoder = sequence_encoder
        self.taxonomy_encoder = taxonomy_encoder

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # Resolve the default encoders lazily so the notebook-level functions may be
        # defined (or redefined) after this class.
        encode_seq = encode_sequence if self.sequence_encoder is None else self.sequence_encoder
        encode_tax = encode_taxonomy if self.taxonomy_encoder is None else self.taxonomy_encoder

        sequence_encoded = torch.tensor(encode_seq(self.sequences[idx]), dtype=torch.long)
        taxonomy_encoded = torch.tensor(encode_tax(self.taxonomy_ids[idx]), dtype=torch.long)
        tax_encoded = torch.tensor(self.taxes[idx], dtype=torch.long)

        return sequence_encoded, taxonomy_encoded, tax_encoded

In [None]:
# 60/20/20 split: first hold out 20% as the test set, then take 25% of the
# remaining 80% (i.e. 20% of the total) as the validation set.
(train_sequences, test_sequences,
 train_taxonomy, test_taxonomy,
 train_tax, test_tax) = train_test_split(
    df['Sequence'].values, mapped_taxonomy_ids_list, numeric_labels,
    test_size=0.2, random_state=42,
)
(train_sequences, val_sequences,
 train_taxonomy, val_taxonomy,
 train_tax, val_tax) = train_test_split(
    train_sequences, train_taxonomy, train_tax,
    test_size=0.25, random_state=42,
)

# Wrap each split in a Dataset; only the training loader shuffles.
train_dataset, val_dataset, test_dataset = (
    ProteinDataset(split_seqs, split_taxonomy, split_tax)
    for split_seqs, split_taxonomy, split_tax in (
        (train_sequences, train_taxonomy, train_tax),
        (val_sequences, val_taxonomy, val_tax),
        (test_sequences, test_taxonomy, test_tax),
    )
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: inspect the shapes of a single training batch.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    sequences_batch, taxonomy_batch, tax_batch = first_batch
    print(f"Sequences batch shape: {sequences_batch.shape}")
    print(f"Taxonomy batch shape: {taxonomy_batch.shape}")
    print(f"Organism batch shape: {tax_batch.shape}")
    print(tax_batch)

In [None]:
class SimpleAttentionClassifier(nn.Module):
    """Stack of self-attention blocks over token embeddings, mean-pooled into logits.

    Each block applies multi-head self-attention, a residual connection,
    LayerNorm and dropout. Positions whose token equals `padding_idx` are
    excluded as attention keys and from the final mean pooling, so the output
    does not depend on how much padding a sequence carries.

    NOTE(review): there is no positional encoding, so the model is invariant to
    token order within a sequence.

    Args:
        vocab_size: number of token ids (including the padding id).
        embedding_dim: embedding / model width; must be divisible by num_heads.
        num_taxonomy_ids: number of output logits.
        num_attention_layers: number of attention blocks.
        dropout_rate: dropout applied after each block.
        num_heads: attention heads per block.
        padding_idx: token id used for padding (0, matching the other models in
            this notebook), or None to attend to and pool over every position.
    """

    def __init__(self, vocab_size, embedding_dim, num_taxonomy_ids, num_attention_layers=3,
                 dropout_rate=0.1, num_heads=4, padding_idx=0):
        super(SimpleAttentionClassifier, self).__init__()
        self.padding_idx = padding_idx

        # Embedding layer for sequences; the padding row stays zero and receives no gradient.
        self.sequence_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        # Stack of attention layers, each followed by its own LayerNorm.
        self.attention_layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)
            for _ in range(num_attention_layers)
        ])
        self.norm_layers = nn.ModuleList([nn.LayerNorm(embedding_dim) for _ in range(num_attention_layers)])

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout_rate)

        # Fully connected layer producing one logit per taxonomy ID
        self.fc = nn.Linear(embedding_dim, num_taxonomy_ids)

    def forward(self, sequences):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, num_taxonomy_ids)."""
        hidden = self.sequence_embedding(sequences)  # (batch, seq_len, embed_dim)

        pad_mask = None
        if self.padding_idx is not None:
            pad_mask = sequences == self.padding_idx  # True = ignore this position
            # A row made entirely of padding would leave softmax without any valid key
            # (NaN output); treat such rows as fully unmasked instead.
            pad_mask = pad_mask & (~pad_mask).any(dim=1, keepdim=True)

        for attention_layer, norm_layer in zip(self.attention_layers, self.norm_layers):
            attn_output, _ = attention_layer(
                hidden, hidden, hidden, key_padding_mask=pad_mask, need_weights=False
            )
            # Residual connection -> LayerNorm -> dropout
            hidden = self.dropout(norm_layer(attn_output + hidden))

        # Mean pooling over the sequence dimension, restricted to non-padding positions.
        if pad_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            keep = (~pad_mask).unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

        return self.fc(pooled)  # (batch, num_taxonomy_ids)


In [None]:
class SequenceClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, max_seq_len):
        super(SequenceClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)  # +1 for padding
        self.conv1 = nn.Conv1d(embedding_dim, 128, kernel_size=5, padding=2)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.fc1 = nn.Linear(128 * (max_seq_len // 2), 512)  # Output size depends on conv and pooling layers
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.embedding(x).permute(0, 2, 1)  # (batch_size, embedding_dim, seq_len)
        x = self.conv1(x)
        x = self.pool(torch.relu(x))  # Max pooling
        x = x.view(x.size(0), -1)  # Flatten the output for the fully connected layer
        x = torch.relu(self.fc1(x))
        output = self.fc2(x)  # No softmax here, because we'll use CrossEntropyLoss, which applies it internally
        return output


In [None]:
class FNNClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, max_seq_len):
        super(FNNClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)  # +1 for padding
        self.fc1 = nn.Linear(embedding_dim * max_seq_len, 512)  # Output size depends on conv and pooling layers
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = x.view(x.size(0), -1)  # Flatten the output for the fully connected layer
        x = torch.relu(self.fc1(x))
        output = self.fc2(x)  # No softmax here, because we'll use CrossEntropyLoss, which applies it internally
        return output


In [None]:
# Hyperparameters
# NOTE(review): `vocab` (and `encode_sequence`, used by ProteinDataset) are not defined
# anywhere in this notebook — TODO add the cell that builds them.
vocab_size = len(vocab) + 1  # +1 for the padding token (index 0)
embedding_dim = 64

# Single-label organism classifier. It is trained with CrossEntropyLoss on `taxes`,
# which are LabelEncoder classes in [0, num_organism_ids), so the output layer needs
# num_organism_ids units (num_taxonomy_ids would give out-of-range targets / wrong argmax).
# Alternative: model_cls = SequenceClassifier(vocab_size, embedding_dim, num_organism_ids, max_seq_len).to(device)
model_cls = FNNClassifier(vocab_size, embedding_dim, num_organism_ids, max_seq_len).to(device)
criterion_cls = nn.CrossEntropyLoss()  # single-label, multi-class classification
optimizer_cls = optim.Adam(model_cls.parameters(), lr=0.001)

In [None]:
# Multi-label taxonomy model (used when Method == "Transformer").
vocab_size = len(vocab) + 1  # +1 for padding
embedding_dim = 64

model = SimpleAttentionClassifier(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    num_taxonomy_ids=num_taxonomy_ids,
).to(device)
# Each taxonomy ID is an independent binary target, hence BCEWithLogitsLoss.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Report model sizes in millions of parameters.
for _label, _module in (("classifier_model:", model_cls), ("model:", model)):
    _n_params = sum(param.numel() for param in _module.parameters())
    print(_label, _n_params / 1e6, 'M parameters')

In [None]:
# Training configuration.
num_epochs = 10
# Which model to train/evaluate: "Transformer" selects the multi-label attention model;
# any other value (e.g. "CNN") selects the single-label model_cls.
Method = "Transformer"

In [None]:
def _print_multilabel_sample(all_preds, all_labels, all_taxes):
    """Print the first evaluated sample (prediction, target, organism) as a sanity check."""
    print(all_preds[0])
    print(all_labels[0])
    print(all_taxes[0])
    # LabelEncoder.inverse_transform expects a 1-D array; all_taxes[0] would be 0-d.
    print(le.inverse_transform(all_taxes[:1].numpy()))
    # Multi-hot vector positions -> taxonomy IDs: predicted first, then ground truth.
    print(*[taxonomy_idx_to_id[i] for i in torch.nonzero(all_preds[0]).flatten().tolist()])
    print(*[taxonomy_idx_to_id[i] for i in torch.nonzero(all_labels[0]).flatten().tolist()])


def evaluate(model, test_loader, threshold=0.6):
    """Evaluate on `test_loader` and return (avg_loss, accuracy, f1).

    Behaviour depends on the module-level `Method` flag:

    * "Transformer": `model` is a multi-label classifier over taxonomy IDs.
      Loss is the module-level `criterion`, predictions are
      `sigmoid(logits) > threshold`, accuracy is element-wise (Hamming)
      accuracy and F1 is micro-averaged.
    * anything else: the module-level `model_cls` is evaluated as a
      single-label organism classifier (the `model` argument is not used in
      this branch). Loss is `criterion_cls`, the prediction is the argmax over
      the class dimension and F1 is macro-averaged.

    Args:
        model: network evaluated in the "Transformer" branch.
        test_loader: yields (sequences, taxonomy_ids, taxes) batches.
        threshold: probability cut-off for multi-label predictions.

    Raises:
        ValueError: if `test_loader` yields no batches.
    """
    if len(test_loader) == 0:
        raise ValueError("evaluate() needs a non-empty test_loader")

    is_multilabel = Method == "Transformer"
    net = model if is_multilabel else model_cls
    net.to(device)
    net.eval()  # disable dropout etc.

    running_loss = 0.0
    batch_preds, batch_targets, batch_taxes = [], [], []

    with torch.no_grad():
        for sequences, taxonomy_ids, taxes in test_loader:
            sequences = sequences.to(device)
            taxonomy_ids = taxonomy_ids.to(device)
            taxes = taxes.to(device)

            outputs = net(sequences)
            if is_multilabel:
                loss = criterion(outputs, taxonomy_ids.float())
                batch_preds.append((torch.sigmoid(outputs) > threshold).cpu())
                batch_targets.append(taxonomy_ids.cpu())
            else:
                loss = criterion_cls(outputs, taxes)
                # Softmax is monotonic, so the argmax of the raw logits over the class
                # dimension is the predicted class. (Softmax over dim=0 normalises across
                # the batch and changes the argmax.)
                batch_preds.append(outputs.argmax(dim=1).cpu())
            batch_taxes.append(taxes.cpu())

            running_loss += loss.item()

    all_taxes = torch.cat(batch_taxes)
    if is_multilabel:
        all_preds = torch.cat(batch_preds).int()
        all_labels = torch.cat(batch_targets)
        _print_multilabel_sample(all_preds, all_labels, all_taxes)
        # Element-wise agreement. Every row has the same length, so this equals the
        # mean of per-sample accuracy_score values.
        accuracy = (all_preds == all_labels).float().mean().item()
        f1 = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')
    else:
        all_index = torch.cat(batch_preds)
        print(all_index)
        print(all_taxes)
        accuracy = accuracy_score(all_taxes.numpy(), all_index.numpy())
        f1 = f1_score(all_taxes.numpy(), all_index.numpy(), average='macro')

    avg_loss = running_loss / len(test_loader)
    return avg_loss, accuracy, f1

In [None]:
train_losses = []
val_losses = []
val_accuracies = []
val_f1_scores = []

# Select the network, loss and optimizer for the configured method once, so that
# training and evaluation operate on the same model.
if Method == "Transformer":
    active_model, active_criterion, active_optimizer = model, criterion, optimizer
else:
    active_model, active_criterion, active_optimizer = model_cls, criterion_cls, optimizer_cls

for epoch in range(num_epochs):
    # evaluate() puts the model in eval mode, so re-enable training mode (dropout) every epoch.
    active_model.train()
    running_loss = 0.0

    for sequences, taxonomy_ids, taxes in tqdm(train_loader, desc=f"epoch {epoch + 1}"):
        sequences = sequences.to(device)
        outputs = active_model(sequences)

        if Method == "Transformer":
            # Multi-label target: multi-hot taxonomy vector (BCEWithLogitsLoss needs floats).
            targets = taxonomy_ids.to(device).float()
        else:
            # Single-label target: organism class index.
            targets = taxes.to(device)
        loss = active_criterion(outputs, targets)

        # Backpropagation: zero the gradients, compute the backward pass, update weights.
        active_optimizer.zero_grad()
        loss.backward()
        active_optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {epoch_loss:.4f}")

    # Evaluate the model on the validation set
    val_loss, val_accuracy, val_f1 = evaluate(active_model, val_loader)
    print(f"val Loss: {val_loss:.4f}, val Accuracy: {val_accuracy:.4f}, val F1 Score: {val_f1:.4f}")

    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    val_f1_scores.append(val_f1)

In [None]:
import matplotlib.pyplot as plt

# Left panel: train vs. validation loss; right panel: validation accuracy and F1.
fig, (loss_ax, score_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Loss per Epoch')
loss_ax.legend()

score_ax.plot(val_accuracies, label='Validation Accuracy')
score_ax.plot(val_f1_scores, label='Validation F1 Score')
score_ax.set(xlabel='Epoch', ylabel='Score', title='Accuracy and F1 Score per Epoch')
score_ax.legend()

fig.tight_layout()
plt.show()