In [12]:
def create_model(
    rna_vocab_size: int = 5,
    protein_vocab_size: int = 21,
    **kwargs
) -> RNAProteinInteractionModel:
    """
    Build an RNAProteinInteractionModel from vocabulary sizes and extra options.

    Args:
        rna_vocab_size: Number of entries in the RNA vocabulary (default 5)
        protein_vocab_size: Number of entries in the protein vocabulary (default 21)
        **kwargs: Any further keyword arguments, forwarded to the model
            constructor unchanged

    Returns:
        The newly constructed model
    """
    # Collect the explicitly named sizes first, then forward them together
    # with the caller's extra options in a single constructor call.
    vocab_sizes = {
        "rna_vocab_size": rna_vocab_size,
        "protein_vocab_size": protein_vocab_size,
    }
    return RNAProteinInteractionModel(**vocab_sizes, **kwargs)



# Instantiate the model with its default vocabulary sizes and report its size.
model = create_model()
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

# Dimensions of the toy batch (real batches would pad sequences to a common length).
batch_size, rna_seq_len, protein_seq_len = 2, 100, 200

# Stand-in token ids drawn at random; real inputs would be encoded sequences.
# `high` is exclusive, so RNA ids fall in 1-4 (A,C,G,U) and protein ids in
# 1-20 (amino acids).
rna_seq = torch.randint(low=1, high=5, size=(batch_size, rna_seq_len))
protein_seq = torch.randint(low=1, high=21, size=(batch_size, protein_seq_len))

# Run the model in evaluation mode with gradient tracking disabled.
model.eval()
with torch.no_grad():
    logits = model(rna_seq, protein_seq)
    probabilities = model.predict_proba(rna_seq, protein_seq)
    predictions = model.predict(rna_seq, protein_seq)

# Show the raw outputs alongside the derived probabilities and predictions.
print(f"\nLogits shape: {logits.shape}")
print(f"Logits: {logits.squeeze()}")
print(f"\nProbabilities: {probabilities.squeeze()}")
print(f"Predictions: {predictions.squeeze()}")


Model created with 2,764,545 parameters

Logits shape: torch.Size([2, 1])
Logits: tensor([-0.7329, -1.0068])

Probabilities: tensor([0.3246, 0.2676])
Predictions: tensor([0, 0])
