In [None]:
%pip install proteinshake

# ProteinShake Point Cloud Implementation
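Adapts the `SequenceLSTMBaseline` model to ProteinShake's point-cloud representation so that it can be trained and evaluated on `EnzymeClassTask` through ProteinShake's task framework.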

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import sys
from typing import List, Tuple, Optional
from proteinshake.tasks import EnzymeClassTask, DummyModel

def print_immediate(msg):
    """Write ``msg`` to stdout and flush right away so it is not held in a buffer."""
    print(msg, flush=True)

def collate_point_clouds(batch):
    """Collate ``(point_cloud, metadata)`` samples into padded batch tensors.

    Each point cloud is a 2-D tensor of shape ``(num_points, point_dim)``; clouds
    in a batch may have different ``num_points`` and are zero-padded at the end
    to the longest one.

    Args:
        batch: iterable of ``(point_cloud, metadata)`` pairs, where
            ``metadata['protein']['EC']`` is a dotted EC number string.

    Returns:
        Tuple ``(batch_clouds, batch_masks, batch_labels)``:
          * ``batch_clouds`` - ``(batch, max_points, point_dim)`` padded clouds,
            keeping the dtype/device of the input clouds.
          * ``batch_masks`` - ``(batch, max_points)`` float32 mask, 1 for real
            points and 0 for padding.
          * ``batch_labels`` - ``(batch,)`` long tensor holding the first EC
            field minus one (zero-based class index).

    Raises:
        ValueError: if ``batch`` is empty.
    """
    batch = list(batch)
    if not batch:
        raise ValueError("collate_point_clouds received an empty batch")

    point_clouds = []
    labels = []
    for point_cloud, metadata in batch:
        point_clouds.append(point_cloud)
        # The top-level EC class is the first dotted field; shift it to be zero-based.
        top_level_class = metadata['protein']['EC'].split('.')[0]
        labels.append(int(top_level_class) - 1)

    # pad_sequence allocates the padding with the same dtype/device as the
    # clouds, so it also works when the samples do not live on the CPU.
    batch_clouds = nn.utils.rnn.pad_sequence(point_clouds, batch_first=True)

    # Build the validity mask from the true lengths: position < length.
    device = batch_clouds.device
    lengths = torch.tensor([pc.shape[0] for pc in point_clouds], device=device)
    positions = torch.arange(batch_clouds.shape[1], device=device)
    batch_masks = (positions.unsqueeze(0) < lengths.unsqueeze(1)).to(torch.float32)

    batch_labels = torch.tensor(labels, dtype=torch.long)
    return batch_clouds, batch_masks, batch_labels

## Model Architecture

In [None]:
class PointCloudToSequenceAdapter(nn.Module):
    """Embed every point of a padded point cloud into a feature vector.

    Three kernel-size-1 ``Conv1d`` layers, each followed by batch norm and ReLU,
    are applied independently to every point, turning a
    ``(batch, num_points, point_cloud_dim)`` tensor into a
    ``(batch, num_points, sequence_embed_dim)`` sequence of embeddings.

    Args:
        point_cloud_dim: number of features per input point.
        sequence_embed_dim: size of the per-point output embedding.
        dropout: dropout probability applied after the first two layers.
    """

    def __init__(self, point_cloud_dim: int = 4, sequence_embed_dim: int = 128,
                 dropout: float = 0.3):
        super().__init__()
        self.point_conv1 = nn.Conv1d(point_cloud_dim, 64, 1)
        self.point_conv2 = nn.Conv1d(64, 128, 1)
        self.point_conv3 = nn.Conv1d(128, sequence_embed_dim, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(sequence_embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, point_cloud, mask=None):
        """Compute per-point embeddings.

        Args:
            point_cloud: tensor of shape ``(batch, num_points, point_dim)``.
            mask: optional ``(batch, num_points)`` tensor, 1 for real points and
                0 for padding. When given, embeddings of padded positions are
                zeroed in the output.

        Returns:
            Tensor of shape ``(batch, num_points, sequence_embed_dim)``.

        Raises:
            ValueError: if ``point_cloud`` is not 3-dimensional.
        """
        if point_cloud.dim() != 3:
            raise ValueError(
                "point_cloud must have shape (batch, num_points, point_dim), "
                f"got {tuple(point_cloud.shape)}"
            )

        # Conv1d / BatchNorm1d expect channels first: (batch, point_dim, num_points).
        x = point_cloud.transpose(1, 2)
        x = self.dropout(F.relu(self.bn1(self.point_conv1(x))))
        x = self.dropout(F.relu(self.bn2(self.point_conv2(x))))
        x = F.relu(self.bn3(self.point_conv3(x)))
        x = x.transpose(1, 2)

        # NOTE: the mask is only applied here, at the very end, so the batch-norm
        # layers above also see the padded positions.
        if mask is not None:
            x = x * mask.unsqueeze(-1)
        return x

class SequenceLSTMBaseline(nn.Module):
    """Bidirectional LSTM + self-attention classifier over embedded sequences.

    The input sequence is encoded by a bidirectional LSTM, refined by multi-head
    self-attention (8 heads), mean-pooled over the valid positions and passed to
    a two-layer MLP that produces class logits.

    Args:
        num_classes: number of output classes.
        hidden_dim: LSTM hidden size per direction; ``hidden_dim * 2`` is the
            attention embedding size and must be divisible by the 8 heads.
        max_length: accepted for interface compatibility; not used by the model.
        input_embed_dim: feature size of each input position.
        num_layers: number of stacked LSTM layers (inter-layer dropout of 0.3 is
            only enabled when there is more than one layer).
    """

    def __init__(self, num_classes: int = 7, hidden_dim: int = 128, max_length: int = 500,
                 input_embed_dim: int = 128, num_layers: int = 2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.3 if num_layers > 1 else 0
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )
        total_params = sum(p.numel() for p in self.parameters())
        print_immediate(f"SequenceLSTM initialized with {total_params:,} parameters")

    def forward(self, embedded_sequences, mask=None):
        """Classify a batch of embedded sequences.

        Args:
            embedded_sequences: ``(batch, seq_len, input_embed_dim)`` tensor.
            mask: optional ``(batch, seq_len)`` tensor, non-zero for valid
                positions and 0 for padding. Sequences are assumed to be
                right-padded, i.e. the valid positions form a prefix of each
                row (this is what ``collate_point_clouds`` produces).

        Returns:
            ``(batch, num_classes)`` tensor of logits.
        """
        if mask is None:
            lstm_out, _ = self.lstm(embedded_sequences)
            attended, _ = self.attention(lstm_out, lstm_out, lstm_out)
            return self.classifier(attended.mean(dim=1))

        valid = mask.bool()

        # Pack the batch so the LSTM never steps over padding. Without packing,
        # the backward direction starts on the padded tail and the hidden state
        # at valid positions depends on how much padding the batch happens to
        # contain. Lengths must live on the CPU and be >= 1 for packing.
        lengths = valid.sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded_sequences, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        # total_length restores the original seq_len so shapes match the mask.
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=embedded_sequences.size(1)
        )

        weights = valid.unsqueeze(-1).to(lstm_out.dtype)
        lstm_out = lstm_out * weights

        # True marks keys the attention must ignore.
        key_padding_mask = ~valid
        # A row with no valid position would mask every key and make the softmax
        # produce NaN; un-mask its first key instead. Its output is zeroed by
        # `weights` below anyway.
        empty_rows = ~valid.any(dim=1)
        key_padding_mask[:, 0] = key_padding_mask[:, 0] & ~empty_rows

        attended, _ = self.attention(
            lstm_out, lstm_out, lstm_out, key_padding_mask=key_padding_mask
        )
        attended = attended * weights

        # Mean over valid positions only; clamp avoids division by zero.
        pooled = attended.sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        return self.classifier(pooled)

## ProteinShake Integration

In [None]:
class ProteinShakePointLSTMModel(DummyModel):
    """Point-cloud classifier wired up for a ProteinShake task.

    Combines a ``PointCloudToSequenceAdapter`` (per-point embedding) with a
    ``SequenceLSTMBaseline`` (sequence classifier) and owns the optimizer and
    loss used to train them.

    Args:
        task: ProteinShake task; ``task.train[0]`` must yield a
            ``(point_cloud, metadata)`` pair and ``task.num_classes`` the number
            of target classes.
        device: device the sub-models are moved to.
        embed_dim: size of the per-point embedding fed to the LSTM.
        hidden_dim: LSTM hidden size per direction (``hidden_dim * 2`` must be
            divisible by the 8 attention heads of ``SequenceLSTMBaseline``).
        num_layers: number of stacked LSTM layers.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, task, device='cuda' if torch.cuda.is_available() else 'cpu',
                 embed_dim: int = 64, hidden_dim: int = 64, num_layers: int = 1,
                 lr: float = 0.001, weight_decay: float = 1e-5):
        super().__init__(task)
        self.device = device

        # Peek at one training sample to learn the per-point feature size.
        sample_data, _ = task.train[0]
        point_cloud_shape = sample_data.shape
        print_immediate(f"Point cloud shape: {point_cloud_shape}")
        print_immediate(f"Using device: {device}")

        self.point_adapter = PointCloudToSequenceAdapter(
            point_cloud_dim=point_cloud_shape[-1],
            sequence_embed_dim=embed_dim
        )
        self.sequence_model = SequenceLSTMBaseline(
            num_classes=task.num_classes,
            hidden_dim=hidden_dim,
            max_length=1000,
            input_embed_dim=embed_dim,
            num_layers=num_layers
        )
        self.point_adapter.to(device)
        self.sequence_model.to(device)

        # One optimizer over the parameters of both sub-models.
        self.optimizer = torch.optim.Adam(
            list(self.point_adapter.parameters()) + list(self.sequence_model.parameters()),
            lr=lr,
            weight_decay=weight_decay
        )
        self.criterion = nn.CrossEntropyLoss()

        total_params = (sum(p.numel() for p in self.point_adapter.parameters()) +
                        sum(p.numel() for p in self.sequence_model.parameters()))
        print_immediate(f"Total model parameters: {total_params:,}")

    def _forward(self, point_clouds, masks):
        """Run adapter + sequence model on tensors already on ``self.device``."""
        sequence_embeddings = self.point_adapter(point_clouds, masks)
        return self.sequence_model(sequence_embeddings, masks)

    def train_step(self, batch):
        """Run one optimization step on a collated batch.

        Args:
            batch: ``(point_clouds, masks, labels)`` as produced by
                ``collate_point_clouds``.

        Returns:
            Dict with the batch ``'loss'`` and ``'accuracy'`` as Python floats.
        """
        self.point_adapter.train()
        self.sequence_model.train()

        point_clouds, masks, labels = batch
        point_clouds = point_clouds.to(self.device)
        masks = masks.to(self.device)
        labels = labels.to(self.device)

        self.optimizer.zero_grad()
        logits = self._forward(point_clouds, masks)
        loss = self.criterion(logits, labels)
        loss.backward()
        self.optimizer.step()

        return {
            'loss': loss.item(),
            'accuracy': (logits.argmax(dim=1) == labels).float().mean().item()
        }

    def test_step(self, test_data, batch_size: int = 8):
        """Predict a class index for every sample in ``test_data``.

        Args:
            test_data: dataset of ``(point_cloud, metadata)`` samples.
            batch_size: number of samples per inference batch.

        Returns:
            NumPy array of predicted class indices, in dataset order.
        """
        self.point_adapter.eval()
        self.sequence_model.eval()

        predictions = []
        with torch.no_grad():
            # shuffle=False keeps predictions aligned with the dataset order.
            test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                                     collate_fn=collate_point_clouds)
            for point_clouds, masks, _ in test_loader:
                point_clouds = point_clouds.to(self.device)
                masks = masks.to(self.device)
                logits = self._forward(point_clouds, masks)
                preds = logits.argmax(dim=1)
                predictions.extend(preds.cpu().numpy())
        return np.array(predictions)

## Testing Function

In [None]:
def test_proteinshake_point_model(num_epochs: int = 10, batch_size: int = 8,
                                  seed: Optional[int] = None):
    """Train and evaluate ``ProteinShakePointLSTMModel`` on ``EnzymeClassTask``.

    Loads the task in its point-cloud / torch representation, trains the model
    for ``num_epochs`` epochs while printing per-epoch loss and accuracy, then
    predicts on the test split and scores the predictions with
    ``task.evaluate``.

    Args:
        num_epochs: number of passes over the training split.
        batch_size: training batch size.
        seed: if given, seeds the torch and NumPy global RNGs before the model
            and data loader are created, for reproducible runs. ``None`` leaves
            the RNG state untouched.

    Returns:
        The metrics mapping returned by ``task.evaluate``.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    print_immediate("Loading ProteinShake EnzymeClassTask with point cloud representation...")
    task = EnzymeClassTask().to_point().torch()
    print_immediate("Task loaded successfully!")
    print_immediate(f"Number of classes: {task.num_classes}")
    print_immediate(f"Train set size: {len(task.train)}")
    print_immediate(f"Test set size: {len(task.test)}")

    print_immediate("Creating model...")
    model = ProteinShakePointLSTMModel(task)

    print_immediate("Starting training...")
    train_loader = DataLoader(task.train, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_point_clouds)
    for epoch in range(num_epochs):
        print_immediate(f"\nEpoch {epoch+1}/{num_epochs}")
        epoch_losses = []
        epoch_accs = []
        for batch in train_loader:
            metrics = model.train_step(batch)
            epoch_losses.append(metrics['loss'])
            epoch_accs.append(metrics['accuracy'])
        # Unweighted mean over batches (a smaller final batch counts the same).
        avg_loss = np.mean(epoch_losses)
        avg_acc = np.mean(epoch_accs)
        print_immediate(f"Epoch {epoch+1} Summary - Loss: {avg_loss:.4f} | Accuracy: {avg_acc:.4f}")

    print_immediate("\nTraining completed! Starting evaluation...")
    prediction = model.test_step(task.test)
    metrics = task.evaluate(task.test_targets, prediction)

    print_immediate("\nFinal Evaluation Results:")
    for metric_name, metric_value in metrics.items():
        print_immediate(f"  {metric_name}: {metric_value}")
    return metrics

# Entry point: run the full train-and-evaluate pipeline when this code is
# executed as the main module.
if __name__ == "__main__":
    test_proteinshake_point_model()