# Novel Multi-Modal Transformer Architecture for Drug Mechanism of Action Prediction

## Abstract
This notebook presents a novel approach to drug mechanism of action (MoA) prediction using a **Multi-Modal Molecular Transformer (M3T)** architecture. Unlike traditional Graph Convolutional Networks, our approach combines:

1. **Molecular BERT-style Transformer** for SMILES sequence encoding
2. **Chemical Descriptor Fusion Network** for molecular properties
3. **Cross-Modal Attention Mechanism** for feature integration
4. **Contrastive Learning** for improved representation learning

## Key Innovations:
- 🚀 **Transformer-based molecular encoding** without graph structure assumptions
- 🚀 **Multi-modal fusion** of sequential and numerical molecular features
- 🚀 **Contrastive pre-training** for better molecular representations
- 🚀 **Attention visualization** for mechanism interpretability
- 🚀 **Statistical significance testing** against GCN baseline

## Research Hypothesis
We hypothesize that treating molecular SMILES as natural language sequences and applying transformer attention mechanisms will capture long-range molecular dependencies better than local graph convolutions, leading to improved MoA prediction accuracy and interpretability.

In [2]:
# Core libraries
import os
import sqlite3
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Deep learning frameworks
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

# Scientific computing
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer
from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score, 
    precision_recall_curve, classification_report,
    multilabel_confusion_matrix
)
from scipy import stats

# Molecular informatics
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect

# Seed the NumPy and PyTorch random number generators so that repeated runs
# start from the same state.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

# Choose the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

ModuleNotFoundError: No module named 'rdkit'

## 1. Enhanced Data Acquisition and Preprocessing

We extend the original ChEMBL query to include additional molecular and target information for multi-modal learning.

In [None]:
# Connect to the ChEMBL database.
# sqlite3.connect() silently creates a new, empty database file when the path
# does not exist; the query below would then fail with a misleading
# "no such table" error. Check for the file first and fail with a clear message.
db_path = os.path.expanduser("~/Downloads/chembl_35/chembl_35_sqlite/chembl_35.db")
if not os.path.isfile(db_path):
    raise FileNotFoundError(f"ChEMBL SQLite database not found: {db_path}")
conn = sqlite3.connect(db_path)

# Enhanced query with additional molecular and target information.
# ORDER BY makes the LIMIT deterministic: without it SQLite does not guarantee
# which rows are returned, undermining the reproducibility that RANDOM_SEED aims
# for. The five ORDER BY keys fully identify a DISTINCT row (the structure and
# property columns depend only on molregno). Note that this keeps the rows with
# the lowest molregno values.
enhanced_query = """
SELECT DISTINCT
    dm.molregno,
    dm.mechanism_of_action,
    dm.action_type,
    cs.canonical_smiles,
    cs.standard_inchi_key,
    cp.full_mwt,
    cp.alogp,
    cp.hba,
    cp.hbd,
    cp.psa,
    cp.rtb,
    cp.ro3_pass,
    cp.num_ro5_violations,
    td.target_type,
    td.organism
FROM drug_mechanism dm
JOIN compound_structures cs ON dm.molregno = cs.molregno
JOIN compound_properties cp ON dm.molregno = cp.molregno
LEFT JOIN target_dictionary td ON dm.tid = td.tid
WHERE cs.canonical_smiles IS NOT NULL
    AND dm.mechanism_of_action IS NOT NULL
    AND LENGTH(cs.canonical_smiles) BETWEEN 10 AND 200
ORDER BY dm.molregno, dm.mechanism_of_action, dm.action_type,
         td.target_type, td.organism
LIMIT 15000;
"""

print("Executing enhanced ChEMBL query...")
df_raw = pd.read_sql(enhanced_query, conn)
print(f"Retrieved {len(df_raw)} drug-mechanism pairs")
print(f"Unique mechanisms: {df_raw['mechanism_of_action'].nunique()}")
print(f"Unique molecules: {df_raw['molregno'].nunique()}")

# Display sample data
df_raw.head()

In [None]:
# Data preprocessing and quality control
def preprocess_molecular_data(df, min_examples=5):
    """Clean the raw ChEMBL mechanism table for modelling.

    Steps:
      1. Drop duplicate (molregno, mechanism_of_action) pairs.
      2. Drop rows whose ``canonical_smiles`` is not a string or cannot be
         parsed by RDKit.
      3. Keep only mechanisms that occur in at least ``min_examples`` rows.

    Args:
        df: DataFrame with at least ``molregno``, ``mechanism_of_action`` and
            ``canonical_smiles`` columns.
        min_examples: Minimum number of rows a mechanism needs in order to be
            kept (default 5, the previously hard-coded threshold).

    Returns:
        A filtered DataFrame with a fresh 0..n-1 index.
    """
    print("Starting molecular data preprocessing...")

    # Remove duplicate molecule/mechanism pairs.
    df_clean = df.drop_duplicates(subset=['molregno', 'mechanism_of_action'])

    def _is_valid(smiles):
        # Chem.MolFromSmiles raises on non-string input (None/NaN), so such
        # rows are treated as invalid instead of aborting the whole run.
        return isinstance(smiles, str) and Chem.MolFromSmiles(smiles) is not None

    valid_mask = [
        _is_valid(smiles)
        for smiles in tqdm(df_clean['canonical_smiles'], desc="Validating SMILES")
    ]
    # A boolean ndarray with .loc also works for an empty frame. Plain []
    # indexing with an empty list would select zero *columns* instead.
    df_clean = df_clean.loc[np.asarray(valid_mask, dtype=bool)].reset_index(drop=True)
    print(f"Valid molecules after SMILES validation: {len(df_clean)}")

    # Keep only mechanisms with enough examples.
    mechanism_counts = df_clean['mechanism_of_action'].value_counts()
    frequent_mechanisms = mechanism_counts[mechanism_counts >= min_examples].index
    df_filtered = df_clean[df_clean['mechanism_of_action'].isin(frequent_mechanisms)]
    # Re-index so row labels match row positions, as for df_clean above.
    df_filtered = df_filtered.reset_index(drop=True)

    print(f"Mechanisms with ≥{min_examples} examples: {len(frequent_mechanisms)}")
    print(f"Final dataset size: {len(df_filtered)}")

    return df_filtered

# Clean the raw query result.
df_processed = preprocess_molecular_data(df_raw)

# Horizontal bar chart of the 20 most common mechanisms of action.
plt.figure(figsize=(12, 6))
top_mechanisms = df_processed['mechanism_of_action'].value_counts().head(20)
mechanism_counts = top_mechanisms
bar_positions = range(len(top_mechanisms))
plt.barh(bar_positions, top_mechanisms.values)
plt.yticks(bar_positions, top_mechanisms.index, fontsize=8)
plt.xlabel('Number of Compounds')
plt.title('Top 20 Most Frequent Mechanisms of Action')
plt.tight_layout()
plt.show()

# Summary statistics of the cleaned dataset.
smiles_lengths = df_processed['canonical_smiles'].str.len()
print(f"\nDataset Statistics:")
print(f"Total samples: {len(df_processed)}")
print(f"Unique mechanisms: {df_processed['mechanism_of_action'].nunique()}")
print(f"Average SMILES length: {smiles_lengths.mean():.1f}")

## 2. Multi-Modal Feature Engineering

We create three types of molecular representations:
1. **Sequential**: Tokenized SMILES for transformer input
2. **Numerical**: Chemical descriptors and properties
3. **Structural**: Morgan fingerprints for contrastive learning

In [None]:
class MolecularFeatureExtractor:
    """Extract multi-modal molecular features for transformer input.

    Produces three views of each SMILES string:
      * a vector of 15 RDKit chemical descriptors,
      * a Morgan bit-vector fingerprint,
      * a fixed-length character-level token-id sequence.

    Attributes:
        max_length: Length of every tokenized sequence, counting the
            ``<START>``/``<END>`` markers and ``<PAD>`` padding.
        scaler: A ``StandardScaler`` created unfitted; callers fit it.
    """

    # Their order fixes the ids: <PAD>=0, <UNK>=1, <START>=2, <END>=3.
    SPECIAL_TOKENS = ('<PAD>', '<UNK>', '<START>', '<END>')

    def __init__(self, max_length=128):
        self.max_length = max_length
        self.scaler = StandardScaler()

    def extract_chemical_descriptors(self, smiles_list):
        """Compute RDKit chemical descriptors for each SMILES.

        Unparseable SMILES get an all-zero row. Any non-finite value is also
        replaced by 0, because ``StandardScaler`` rejects NaN/inf input.

        Returns:
            (array of shape (len(smiles_list), 15), list of descriptor names)
        """
        descriptor_names = [
            'MolWt', 'LogP', 'NumHDonors', 'NumHAcceptors', 'TPSA',
            'NumRotatableBonds', 'NumAromaticRings', 'NumSaturatedRings',
            'NumAliphaticRings', 'RingCount', 'FractionCsp3',
            'NumHeteroatoms', 'BertzCT', 'BalabanJ', 'Ipc'
        ]
        # Same order as descriptor_names.
        descriptor_fns = [
            Descriptors.MolWt,
            Descriptors.MolLogP,
            Descriptors.NumHDonors,
            Descriptors.NumHAcceptors,
            Descriptors.TPSA,
            Descriptors.NumRotatableBonds,
            Descriptors.NumAromaticRings,
            Descriptors.NumSaturatedRings,
            Descriptors.NumAliphaticRings,
            Descriptors.RingCount,
            Descriptors.FractionCsp3,
            Descriptors.NumHeteroatoms,
            Descriptors.BertzCT,
            Descriptors.BalabanJ,
            Descriptors.Ipc,
        ]

        rows = []
        for smiles in tqdm(smiles_list, desc="Extracting descriptors"):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                rows.append([0] * len(descriptor_names))
            else:
                rows.append([fn(mol) for fn in descriptor_fns])

        if not rows:
            # Keep a 2-D shape so downstream .shape[1] still works.
            return np.zeros((0, len(descriptor_names))), descriptor_names

        descriptors = np.array(rows)
        # Ipc in particular grows extremely fast with molecule size; guard
        # against any NaN/inf using the same 0 fill as invalid molecules.
        descriptors = np.nan_to_num(descriptors, nan=0.0, posinf=0.0, neginf=0.0)
        return descriptors, descriptor_names

    def extract_morgan_fingerprints(self, smiles_list, radius=2, n_bits=1024):
        """Compute Morgan fingerprints (``n_bits`` bits, given ``radius``).

        Unparseable SMILES get an all-zero vector.

        Returns:
            Array of shape (len(smiles_list), n_bits).
        """
        fingerprints = []
        for smiles in tqdm(smiles_list, desc="Extracting fingerprints"):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                fingerprints.append(np.zeros(n_bits))
                continue

            fp = GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            fingerprints.append(np.array(fp))

        if not fingerprints:
            return np.zeros((0, n_bits))
        return np.array(fingerprints)

    def tokenize_smiles(self, smiles_list, vocab=None):
        """Map SMILES strings to fixed-length character-level token ids.

        Each row is ``<START> chars... <END>`` followed by ``<PAD>`` up to
        ``self.max_length``. Over-long SMILES are truncated in the character
        part, so ``<END>`` is kept whenever ``max_length >= 2``.

        Args:
            smiles_list: Sequence of SMILES strings.
            vocab: Optional existing vocabulary, e.g. one returned by an
                earlier call. It must contain the four special tokens. Reusing
                it keeps token ids consistent across datasets; characters
                missing from it map to ``<UNK>``. When None (default), a
                vocabulary is built from ``smiles_list``: the special tokens
                followed by the sorted set of characters.

        Returns:
            (int64 array of shape (len(smiles_list), max_length), vocab list,
             char -> id dict)
        """
        if vocab is None:
            all_chars = set(''.join(smiles_list))
            vocab = list(self.SPECIAL_TOKENS) + sorted(all_chars)
        char_to_idx = {char: idx for idx, char in enumerate(vocab)}

        pad_id = char_to_idx['<PAD>']
        unk_id = char_to_idx['<UNK>']
        start_id = char_to_idx['<START>']
        end_id = char_to_idx['<END>']
        # Room left for characters once <START> and <END> are placed.
        max_body = max(self.max_length - 2, 0)

        tokenized = []
        for smiles in smiles_list:
            body = [char_to_idx.get(char, unk_id) for char in smiles[:max_body]]
            tokens = ([start_id] + body + [end_id])[:self.max_length]
            tokens += [pad_id] * (self.max_length - len(tokens))
            tokenized.append(tokens)

        token_array = np.array(tokenized, dtype=np.int64).reshape(
            len(tokenized), self.max_length
        )
        return token_array, vocab, char_to_idx

# Build the three feature views for every preprocessed molecule.
feature_extractor = MolecularFeatureExtractor(max_length=128)

print("Extracting multi-modal molecular features...")

smiles_corpus = df_processed['canonical_smiles'].tolist()

# View 1: RDKit chemical descriptors.
descriptors, descriptor_names = feature_extractor.extract_chemical_descriptors(smiles_corpus)
# View 2: Morgan fingerprints.
fingerprints = feature_extractor.extract_morgan_fingerprints(smiles_corpus)
# View 3: character-level token ids.
tokenized_smiles, vocab, char_to_idx = feature_extractor.tokenize_smiles(smiles_corpus)

print(f"Chemical descriptors shape: {descriptors.shape}")
print(f"Morgan fingerprints shape: {fingerprints.shape}")
print(f"Tokenized SMILES shape: {tokenized_smiles.shape}")
print(f"Vocabulary size: {len(vocab)}")

# Standardize each descriptor column.
# NOTE(review): the scaler is fit on ALL rows, before the train/val/test split
# made in a later cell. Test-set statistics therefore leak into the training
# features; consider fitting on the training rows only.
descriptors_normalized = feature_extractor.scaler.fit_transform(descriptors)

# Indicator label matrix. Every row holds exactly one mechanism, so the task is
# effectively multi-class. Preprocessing deduplicated (molregno, mechanism)
# pairs, so a molecule with several mechanisms shows up as several rows.
mlb = MultiLabelBinarizer()
mechanism_lists = [[mechanism] for mechanism in df_processed['mechanism_of_action']]
labels = mlb.fit_transform(mechanism_lists)

print(f"\nLabel encoding:")
print(f"Number of unique mechanisms: {len(mlb.classes_)}")
print(f"Labels shape: {labels.shape}")
print(f"Label sparsity: {(labels == 0).sum() / labels.size:.3f}")

NameError: name 'StandardScaler' is not defined

## 3. Multi-Modal Molecular Transformer (M3T) Architecture

Our novel architecture combines:
- **SMILES Transformer Encoder**: Processes tokenized molecular sequences
- **Chemical Descriptor Network**: Handles numerical molecular properties
- **Cross-Modal Attention**: Fuses sequential and numerical representations
- **Contrastive Learning Head**: Improves molecular representation quality

In [None]:
class PositionalEncoding(nn.Module):
    """Positional encoding for transformer input."""
    
    def __init__(self, d_model, max_len=512):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-np.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class SMILESTransformerEncoder(nn.Module):
    """Transformer encoder for SMILES sequences."""
    
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, 
                 dim_feedforward=1024, max_len=128, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_mask=None):
        # src shape: (batch_size, seq_len)
        src = self.embedding(src) * np.sqrt(self.d_model)
        src = self.pos_encoder(src.transpose(0, 1)).transpose(0, 1)
        src = self.dropout(src)
        
        # Create padding mask
        if src_mask is None:
            src_mask = (src.sum(dim=-1) == 0)  # Padding positions
            
        output = self.transformer_encoder(src, src_key_padding_mask=src_mask)
        
        # Global average pooling (excluding padding)
        mask_expanded = (~src_mask).unsqueeze(-1).float()
        output_masked = output * mask_expanded
        output_pooled = output_masked.sum(dim=1) / mask_expanded.sum(dim=1)
        
        return output_pooled, output  # pooled representation and full sequence

class ChemicalDescriptorNetwork(nn.Module):
    """MLP that embeds a vector of chemical descriptors.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout.

    Args:
        input_dim: Number of descriptor features.
        hidden_dims: Widths of the hidden layers. The last one is the output
            width, exposed as ``self.output_dim``. Any sequence of ints works;
            the default is a tuple to avoid a shared mutable default argument.
        dropout: Dropout probability after each hidden layer.

    Note: BatchNorm1d in training mode needs more than one sample per batch.
    """

    def __init__(self, input_dim, hidden_dims=(512, 256, 128), dropout=0.1):
        super().__init__()

        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim

        self.network = nn.Sequential(*layers)
        # Width of forward()'s output. Equals input_dim when hidden_dims is
        # empty, in which case the empty Sequential returns its input.
        self.output_dim = prev_dim

    def forward(self, x):
        """Map (batch, input_dim) descriptors to (batch, output_dim)."""
        return self.network(x)

class CrossModalAttention(nn.Module):
    """Fuse a pooled SMILES vector with a descriptor vector through attention.

    Both inputs are linearly projected to ``hidden_dim``. The SMILES projection
    is the query; the descriptor projection is both key and value. A residual
    connection + LayerNorm and a feed-forward sub-layer (also residual +
    LayerNorm) follow.

    NOTE(review): query and key each have length 1, so the softmax runs over a
    single key and the returned attention weights are always 1.0. The layer
    thus acts like ``norm(seq + W_o W_v desc)`` followed by the FFN, and the
    weights carry no information for interpretability. Meaningful weights
    would require attending over the per-token SMILES encoder outputs.

    Args:
        seq_dim: Width of the SMILES feature vector.
        desc_dim: Width of the descriptor feature vector.
        hidden_dim: Common width after projection; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, seq_dim, desc_dim, hidden_dim=256, num_heads=8):
        super().__init__()
        # Submodule creation order is kept stable: under a fixed seed it
        # determines the parameter initialization.
        self.seq_proj = nn.Linear(seq_dim, hidden_dim)
        self.desc_proj = nn.Linear(desc_dim, hidden_dim)
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 2 * hidden_dim),
            nn.ReLU(),
            nn.Linear(2 * hidden_dim, hidden_dim),
        )

    def forward(self, seq_features, desc_features):
        """Return (fused features (batch, hidden_dim), attention weights)."""
        # Treat each projected vector as a length-1 sequence.
        query = self.seq_proj(seq_features).unsqueeze(1)
        context = self.desc_proj(desc_features).unsqueeze(1)

        attended, weights = self.multihead_attn(query=query, key=context, value=context)

        hidden = self.norm1(query + attended)
        hidden = self.norm2(hidden + self.ffn(hidden))
        return hidden.squeeze(1), weights

In [None]:
class MultiModalMolecularTransformer(nn.Module):
    """Main M3T model combining all components.

    SMILES tokens go through ``SMILESTransformerEncoder`` (pooled, ``d_model``
    wide). Descriptors go through ``ChemicalDescriptorNetwork`` (output width =
    last entry of ``desc_hidden_dims``). ``CrossModalAttention`` fuses both to
    ``fusion_dim``. The fused vector feeds a classifier (``num_classes``
    logits) and a 64-d contrastive projection head.

    Args:
        vocab_size: Token vocabulary size.
        num_descriptors: Number of descriptor features.
        num_classes: Number of output logits.
        d_model, nhead, num_layers, dropout: SMILES encoder hyper-parameters.
            ``nhead`` is also used by the cross-modal attention.
        pad_idx: Token id used for padding (default 0). Positions holding it
            are masked out in the SMILES encoder.
        fusion_dim: Width of the fused representation (default 256). It must
            be divisible by ``nhead``.
        desc_hidden_dims: Hidden widths of the descriptor MLP. The last entry
            is its output width and is passed to the fusion layer.
    """

    def __init__(self, vocab_size, num_descriptors, num_classes,
                 d_model=256, nhead=8, num_layers=6, dropout=0.1,
                 pad_idx=0, fusion_dim=256, desc_hidden_dims=(512, 256, 128)):
        super().__init__()

        self.pad_idx = pad_idx
        desc_hidden_dims = list(desc_hidden_dims)
        # Output width of the descriptor MLP (its input width if it has no layers).
        desc_out_dim = desc_hidden_dims[-1] if desc_hidden_dims else num_descriptors

        # Component networks (creation order kept stable for seeded init).
        self.smiles_encoder = SMILESTransformerEncoder(
            vocab_size=vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout
        )

        self.descriptor_network = ChemicalDescriptorNetwork(
            input_dim=num_descriptors,
            hidden_dims=desc_hidden_dims,
            dropout=dropout
        )

        self.cross_modal_attention = CrossModalAttention(
            seq_dim=d_model,
            desc_dim=desc_out_dim,
            hidden_dim=fusion_dim,
            num_heads=nhead
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes)
        )

        # Contrastive learning head
        self.contrastive_head = nn.Sequential(
            nn.Linear(fusion_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )

    def forward(self, smiles_tokens, descriptors, return_attention=False):
        """Return (logits, contrastive_repr[, attention_weights]).

        Args:
            smiles_tokens: Integer token ids, shape (batch, seq_len).
            descriptors: Descriptor features, shape (batch, num_descriptors).
            return_attention: Also return the cross-modal attention weights.
        """
        # Build the key-padding mask from the token ids (True = padding) and
        # pass it explicitly instead of letting the encoder infer it.
        padding_mask = smiles_tokens == self.pad_idx
        seq_features, _ = self.smiles_encoder(smiles_tokens, src_mask=padding_mask)

        # Process chemical descriptors
        desc_features = self.descriptor_network(descriptors)

        # Cross-modal fusion
        fused_features, attention_weights = self.cross_modal_attention(
            seq_features, desc_features
        )

        logits = self.classifier(fused_features)
        contrastive_repr = self.contrastive_head(fused_features)

        if return_attention:
            return logits, contrastive_repr, attention_weights
        return logits, contrastive_repr

# Initialize model
model = MultiModalMolecularTransformer(
    vocab_size=len(vocab),
    num_descriptors=descriptors_normalized.shape[1],
    num_classes=labels.shape[1],
    d_model=256,
    nhead=8,
    num_layers=6,
    dropout=0.1
).to(device)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Actual parameter storage, using each tensor's element size rather than
# assuming 4-byte floats.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print(f"\nM3T Model Architecture:")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {param_bytes / 1024**2:.1f} MB")

# Smoke-test a forward pass in eval mode. In train mode dropout would be active,
# and BatchNorm would update its running statistics from these 4 samples even
# under no_grad, altering the model before training starts.
model.eval()
with torch.no_grad():
    sample_smiles = torch.tensor(tokenized_smiles[:4], dtype=torch.long).to(device)
    sample_desc = torch.tensor(descriptors_normalized[:4]).float().to(device)

    logits, contrastive = model(sample_smiles, sample_desc)
    print(f"\nForward pass test:")
    print(f"Input SMILES shape: {sample_smiles.shape}")
    print(f"Input descriptors shape: {sample_desc.shape}")
    print(f"Output logits shape: {logits.shape}")
    print(f"Contrastive representation shape: {contrastive.shape}")
# Restore training mode (a freshly constructed module starts in train mode).
model.train()

## 4. Dataset Preparation and Stratified Splitting

We implement stratified splitting to ensure balanced representation of mechanisms across train/validation/test sets.

In [None]:
class MolecularDataset(Dataset):
    """Map-style dataset bundling the four per-molecule arrays.

    All inputs are array-likes sharing the first dimension. Token ids are
    stored as int64 tensors; descriptors, fingerprints and labels are stored as
    float32 tensors. Indexing returns a dict with the keys ``smiles_tokens``,
    ``descriptors``, ``fingerprints`` and ``labels``.
    """

    def __init__(self, smiles_tokens, descriptors, fingerprints, labels):
        self.smiles_tokens = torch.tensor(smiles_tokens, dtype=torch.long)
        self.descriptors, self.fingerprints, self.labels = (
            torch.tensor(array, dtype=torch.float32)
            for array in (descriptors, fingerprints, labels)
        )

    def __len__(self):
        return len(self.smiles_tokens)

    def __getitem__(self, idx):
        field_names = ('smiles_tokens', 'descriptors', 'fingerprints', 'labels')
        return {name: getattr(self, name)[idx] for name in field_names}

# Stratified splitting for multi-label data
def stratified_multilabel_split(X, y, test_size=0.2, val_size=0.1, random_state=42):
    """Split into train/val/test, stratifying on one representative label per row.

    Full multi-label stratification is not attempted. Each row is reduced to a
    single stratum: among the labels present in that row, the one that is most
    frequent in the whole dataset (ties go to the lowest column index).
    Choosing the most frequent label keeps strata large, which
    ``train_test_split(stratify=...)`` needs (each stratum requires at least
    two members). For single-label rows the stratum is simply the row's label.
    Rows with no positive label fall into stratum 0.

    Args:
        X: Array-like of samples (e.g. row indices).
        y: Binary indicator matrix of shape (n_samples, n_labels).
        test_size: Fraction of all samples placed in the test split.
        val_size: Fraction of all samples placed in the validation split.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        X_train, X_val, X_test, y_train, y_val, y_test
    """
    y_arr = np.asarray(y)
    label_freq = y_arr.sum(axis=0)
    # Dataset frequency of each present label, -1 where the label is absent;
    # argmax then selects the most frequent present label.
    stratify_labels = np.where(y_arr > 0, label_freq, -1).argmax(axis=1)

    # First split: train+val vs test
    X_temp, X_test, y_temp, y_test, strat_temp, _ = train_test_split(
        X, y, stratify_labels, test_size=test_size,
        stratify=stratify_labels, random_state=random_state
    )

    # Second split: train vs val. Rescale so val_size stays a fraction of the
    # full dataset rather than of the remaining train+val portion.
    val_size_adjusted = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted,
        stratify=strat_temp, random_state=random_state
    )

    return X_train, X_val, X_test, y_train, y_val, y_test

# Prepare data for splitting: split row indices, then gather each feature view.
X_combined = np.arange(len(tokenized_smiles))  # Just indices

# Perform stratified split
idx_train, idx_val, idx_test, y_train, y_val, y_test = stratified_multilabel_split(
    X_combined, labels, test_size=0.2, val_size=0.1, random_state=RANDOM_SEED
)

# Create datasets
train_dataset = MolecularDataset(
    tokenized_smiles[idx_train],
    descriptors_normalized[idx_train],
    fingerprints[idx_train],
    y_train
)

val_dataset = MolecularDataset(
    tokenized_smiles[idx_val],
    descriptors_normalized[idx_val],
    fingerprints[idx_val],
    y_val
)

test_dataset = MolecularDataset(
    tokenized_smiles[idx_test],
    descriptors_normalized[idx_test],
    fingerprints[idx_test],
    y_test
)

# Create data loaders
batch_size = 32
# Pinned host memory only speeds up host-to-CUDA copies; on MPS/CPU it brings
# no benefit and PyTorch warns about it. Enable it for CUDA only.
use_pinned_memory = device.type == 'cuda'

# drop_last on the training loader: ChemicalDescriptorNetwork uses BatchNorm1d,
# which raises in training mode on a single-sample batch. That batch occurs
# whenever len(train_dataset) % batch_size == 1.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
    num_workers=0, pin_memory=use_pinned_memory
)

val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False,
    num_workers=0, pin_memory=use_pinned_memory
)

test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False,
    num_workers=0, pin_memory=use_pinned_memory
)

print(f"Dataset splits:")
print(f"Training: {len(train_dataset)} samples")
print(f"Validation: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")
print(f"\nBatch configuration:")
print(f"Batch size: {batch_size}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# Verify label distribution
train_label_dist = y_train.sum(axis=0)
val_label_dist = y_val.sum(axis=0)
test_label_dist = y_test.sum(axis=0)

print(f"\nLabel distribution verification:")
print(f"Train labels per class (mean±std): {train_label_dist.mean():.1f}±{train_label_dist.std():.1f}")
print(f"Val labels per class (mean±std): {val_label_dist.mean():.1f}±{val_label_dist.std():.1f}")
print(f"Test labels per class (mean±std): {test_label_dist.mean():.1f}±{test_label_dist.std():.1f}")

## 5. Training Setup with Contrastive Learning

We implement a hybrid loss function combining:
1. **Multi-label classification loss** (Binary Cross Entropy)
2. **Contrastive learning loss** (InfoNCE) for better molecular representations
3. **Weight-decay regularization** (applied by the AdamW optimizer rather than as an explicit loss term) for model stability

In [None]:
class ContrastiveLoss(nn.Module):
    """InfoNCE loss pairing each learned representation with its own fingerprint.

    Within a batch, row i of ``representations`` should be most similar
    (cosine similarity / temperature) to row i of ``fingerprints`` and less
    similar to every other row. The loss is cross-entropy over that similarity
    matrix, with the diagonal as targets.

    The two inputs may have different widths. In this notebook the contrastive
    head emits 64-d vectors while the fingerprints have 1024 bits, and a plain
    matrix product of the two would fail. When the widths differ, fingerprints
    are mapped to the representation width by a fixed Gaussian random
    projection, which approximately preserves cosine similarities. The
    projection is not trained. It is drawn from a dedicated generator seeded
    with ``projection_seed``, so it is reproducible and leaves the global RNG
    untouched. Inputs of equal width are compared directly without projection.

    Args:
        temperature: Divisor applied to the cosine similarities.
        projection_seed: Seed for the fingerprint projection matrix.
    """

    def __init__(self, temperature=0.1, projection_seed=0):
        super().__init__()
        self.temperature = temperature
        self.projection_seed = projection_seed
        # Projection matrices keyed by (in_dim, out_dim, device, dtype).
        self._projections = {}

    def _project_fingerprints(self, fingerprints, out_dim):
        """Randomly project (batch, fp_dim) fingerprints to (batch, out_dim)."""
        in_dim = fingerprints.size(1)
        key = (in_dim, out_dim, fingerprints.device, fingerprints.dtype)
        projection = self._projections.get(key)
        if projection is None:
            generator = torch.Generator().manual_seed(self.projection_seed)
            projection = torch.randn(in_dim, out_dim, generator=generator)
            projection = projection.to(device=fingerprints.device, dtype=fingerprints.dtype)
            self._projections[key] = projection
        return fingerprints @ projection

    def forward(self, representations, fingerprints):
        """Return the scalar InfoNCE loss for a batch.

        Args:
            representations: Learned embeddings, shape (batch, d).
            fingerprints: Fingerprint vectors, shape (batch, k); k may differ from d.
        """
        batch_size = representations.size(0)
        repr_dim = representations.size(1)
        if fingerprints.size(1) != repr_dim:
            fingerprints = self._project_fingerprints(fingerprints, repr_dim)

        repr_norm = F.normalize(representations, dim=1)
        fp_norm = F.normalize(fingerprints, dim=1)

        similarity_matrix = torch.matmul(repr_norm, fp_norm.T) / self.temperature

        # Positive pairs lie on the diagonal.
        targets = torch.arange(batch_size, device=representations.device)
        return F.cross_entropy(similarity_matrix, targets)

class HybridLoss(nn.Module):
    """Weighted sum of a multi-label BCE loss and a contrastive loss.

    total = classification_weight * BCEWithLogits(logits, labels)
          + contrastive_weight * ContrastiveLoss(representations, fingerprints)

    ``forward`` returns ``(total, classification_term, contrastive_term)``.
    The last two are unweighted, so they can be logged separately.
    """

    def __init__(self, classification_weight=1.0, contrastive_weight=0.1,
                 temperature=0.1):
        super().__init__()
        self.classification_weight = classification_weight
        self.contrastive_weight = contrastive_weight
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.contrastive_loss = ContrastiveLoss(temperature)

    def forward(self, logits, labels, representations, fingerprints):
        cls_term = self.bce_loss(logits, labels)
        contrastive_term = self.contrastive_loss(representations, fingerprints)

        weighted_cls = self.classification_weight * cls_term
        weighted_contrastive = self.contrastive_weight * contrastive_term
        return weighted_cls + weighted_contrastive, cls_term, contrastive_term

# Initialize loss function and optimizer
criterion = HybridLoss(
    classification_weight=1.0,
    contrastive_weight=0.1,
    temperature=0.1
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
    betas=(0.9, 0.999)
)

# Halve the learning rate when the metric passed to scheduler.step() has not
# decreased for 5 consecutive steps. The `verbose` argument is omitted because
# it is deprecated since PyTorch 2.2; call scheduler.get_last_lr() to inspect
# the current learning rate instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

print("Training setup completed:")
print(f"Loss function: Hybrid (BCE + InfoNCE)")
print(f"Optimizer: AdamW (lr=1e-4, weight_decay=1e-5)")
print(f"Scheduler: ReduceLROnPlateau")
print(f"Device: {device}")