# Deploying Classification Models with Gradio & FastAPI

**DOST-ITDI AI Training Workshop**  
**Module 7: Classification Model Deployment**

---

## Learning Objectives
1. Load trained classification models
2. Create interactive web interfaces for predictions
3. Build REST APIs for classification
4. Deploy binary classification models
5. Handle class imbalance when training and predicting

## Use Case: Drug Activity Prediction

We'll deploy a model that predicts whether a compound is active or inactive against a biological target (BACE-1 enzyme).

**Applications:**
- Virtual screening
- Lead compound identification  
- Drug discovery pipelines

## Part 1: Train and Save Classification Model

In [None]:
# Import libraries
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import joblib
import warnings
warnings.filterwarnings('ignore')

print("Libraries imported successfully!")

In [None]:
# Load BACE dataset
url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv"
df = pd.read_csv(url)

print(f"Dataset shape: {df.shape}")
print(f"\nClass distribution:")
print(df['Class'].value_counts())
print(f"\nFirst few rows:")
df.head()

In [None]:
# Calculate molecular descriptors
from rdkit import Chem
from rdkit.Chem import Descriptors

def calculate_descriptors(smiles):
    """Calculate molecular descriptors from SMILES"""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        
        return {
            'MolWeight': Descriptors.MolWt(mol),
            'LogP': Descriptors.MolLogP(mol),
            'NumHDonors': Descriptors.NumHDonors(mol),
            'NumHAcceptors': Descriptors.NumHAcceptors(mol),
            'TPSA': Descriptors.TPSA(mol),
            'NumRotatableBonds': Descriptors.NumRotatableBonds(mol),
            'NumAromaticRings': Descriptors.NumAromaticRings(mol),
            'NumAliphaticRings': Descriptors.NumAliphaticRings(mol),
            'FractionCSP3': Descriptors.FractionCSP3(mol)
        }
    except Exception:  # a bare except would also swallow KeyboardInterrupt
        return None

# Calculate descriptors
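descriptors_list = []
for smiles in df['mol']:
    desc = calculate_descriptors(smiles)
    # Use an empty dict for unparseable SMILES so the row becomes all-NaN
    # and is removed by dropna() below
    descriptors_list.append(desc if desc is not None else {})

# Create descriptor DataFrame (same row index as df, so df.loc works below)
descriptors_df = pd.DataFrame(descriptors_list, index=df.index)
descriptors_df = descriptors_df.dropna()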

# Merge with target
df_clean = df.loc[descriptors_df.index].copy()
X = descriptors_df.values
y = df_clean['Class'].values

print(f"Features shape: {X.shape}")
print(f"Target shape: {y.shape}")
print(f"\nFeature names: {list(descriptors_df.columns)}")

In [None]:
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train Random Forest with class weights (handle imbalance)
model = RandomForestClassifier(
    n_estimators=100, 
    random_state=42, 
    class_weight='balanced',  # Handle imbalanced classes
    n_jobs=-1
)
model.fit(X_train_scaled, y_train)

# Evaluate
y_pred = model.predict(X_test_scaled)
y_prob = model.predict_proba(X_test_scaled)[:, 1]

print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=['Inactive', 'Active']))

print(f"\nROC-AUC Score: {roc_auc_score(y_test, y_prob):.3f}")

print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred)
print(cm)
print("\nModel trained successfully!")
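
For a binary problem, `model.predict` effectively applies a 0.5 cutoff to the positive-class probability. With imbalanced classes a different cutoff sometimes works better. The sketch below, in plain Python, sweeps candidate thresholds and keeps the one with the highest F1 score. The helper names and the toy values are assumptions made for illustration; in practice you would pass validation-set labels and `predict_proba` outputs, and avoid tuning the threshold on the test set.

```python
def f1_at_threshold(y_true, y_prob, threshold):
    """F1 score for the positive class when predicting 1 if prob >= threshold."""
    tp = fp = fn = 0
    for label, prob in zip(y_true, y_prob):
        predicted = 1 if prob >= threshold else 0
        if predicted == 1 and label == 1:
            tp += 1
        elif predicted == 1 and label == 0:
            fp += 1
        elif predicted == 0 and label == 1:
            fn += 1
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def best_threshold(y_true, y_prob, candidates=None):
    """Return (threshold, f1) with the highest F1 among the candidates."""
    if candidates is None:
        candidates = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95
    scored = [(f1_at_threshold(y_true, y_prob, t), t) for t in candidates]
    best_f1, best_t = max(scored, key=lambda pair: pair[0])  # first maximum wins ties
    return best_t, best_f1


# Toy values for illustration only (not taken from the BACE data)
toy_labels = [0, 0, 0, 1, 1, 0, 1, 0]
toy_probs = [0.10, 0.30, 0.45, 0.40, 0.80, 0.20, 0.35, 0.55]
threshold, f1 = best_threshold(toy_labels, toy_probs)
print(f"Best threshold: {threshold:.2f} (F1 = {f1:.3f})")
```

If a tuned threshold is used, it needs to be saved alongside the model and applied to `predict_proba` output at serving time instead of calling `predict`.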

In [None]:
# Save model and scaler
joblib.dump(model, 'bace_classifier.pkl')
joblib.dump(scaler, 'classifier_scaler.pkl')

# Save feature names
feature_names = list(descriptors_df.columns)
joblib.dump(feature_names, 'classifier_features.pkl')

print("Model, scaler, and feature names saved!")
print(f"  - bace_classifier.pkl")
print(f"  - classifier_scaler.pkl")
print(f"  - classifier_features.pkl")
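
Because the three saved files only work together, it can help to record them in a small manifest with a version label and checksums. The following is a minimal sketch using only the standard library; `write_manifest`, `verify_manifest`, and the `model_manifest.json` file name are hypothetical and not part of the workshop code. Pickle-based files such as these can run arbitrary code when loaded, so only load artifacts from a trusted source.

```python
import hashlib
import json
from datetime import datetime, timezone
from pathlib import Path


def sha256_of(path):
    """Return the SHA-256 hex digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()


def write_manifest(artifact_paths, version, manifest_path="model_manifest.json"):
    """Write a JSON manifest listing each artifact with its size and checksum."""
    manifest = {
        "version": version,
        "created_utc": datetime.now(timezone.utc).isoformat(),
        "artifacts": {
            Path(p).name: {"sha256": sha256_of(p), "bytes": Path(p).stat().st_size}
            for p in artifact_paths
        },
    }
    Path(manifest_path).write_text(json.dumps(manifest, indent=2))
    return manifest


def verify_manifest(manifest_path, directory="."):
    """Return the names of artifacts whose checksum no longer matches."""
    manifest = json.loads(Path(manifest_path).read_text())
    return [
        name
        for name, info in manifest["artifacts"].items()
        if sha256_of(Path(directory) / name) != info["sha256"]
    ]
```

After saving, a call such as `write_manifest(['bace_classifier.pkl', 'classifier_scaler.pkl', 'classifier_features.pkl'], version='1.0.0')` would record the current files, and `verify_manifest('model_manifest.json')` could run before loading them in the app.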

## Part 2: Create Gradio Interface

In [None]:
import gradio as gr
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
from PIL import Image

# Load trained model
model = joblib.load('bace_classifier.pkl')
scaler = joblib.load('classifier_scaler.pkl')
feature_names = joblib.load('classifier_features.pkl')

def predict_activity(smiles):
    """
    Predict drug activity from SMILES notation
    
    Args:
        smiles (str): SMILES notation
    
    Returns:
        tuple: (prediction_text, probability_plot, molecule_image, descriptors_text)
    """
    try:
        # Calculate descriptors
        descriptors = calculate_descriptors(smiles)
        
        if descriptors is None:
            return "Invalid SMILES notation", None, None, "Error: Could not parse SMILES"
        
        # Prepare features
        features = np.array([[descriptors[feat] for feat in feature_names]])
        features_scaled = scaler.transform(features)
        
        # Predict
        prediction = model.predict(features_scaled)[0]
        probabilities = model.predict_proba(features_scaled)[0]
        
        # Generate molecule image
        mol = Chem.MolFromSmiles(smiles)
        mol_img = Draw.MolToImage(mol, size=(300, 300))
        
        # Create probability plot
        fig, ax = plt.subplots(figsize=(6, 4))
        classes = ['Inactive (0)', 'Active (1)']
        colors = ['#FF6B6B', '#4ECDC4']
        bars = ax.bar(classes, probabilities, color=colors, alpha=0.7)
        ax.set_ylabel('Probability', fontsize=12)
        ax.set_title('Prediction Confidence', fontsize=14, fontweight='bold')
        ax.set_ylim([0, 1])
        ax.grid(axis='y', alpha=0.3)
        
        # Add probability labels on bars
        for bar, prob in zip(bars, probabilities):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                   f'{prob:.1%}', ha='center', va='bottom', fontsize=11, fontweight='bold')
        
        plt.tight_layout()
        
        # Format prediction text
        if prediction == 1:
            pred_text = f"## Prediction: ACTIVE\n\n"
            pred_text += f"**Confidence:** {probabilities[1]:.1%}\n\n"
            pred_text += "This compound is predicted to be **active** against BACE-1."
        else:
            pred_text = f"## Prediction: INACTIVE\n\n"
            pred_text += f"**Confidence:** {probabilities[0]:.1%}\n\n"
            pred_text += "This compound is predicted to be **inactive** against BACE-1."
        
        # Format descriptors
        desc_text = "**Molecular Descriptors:**\n\n"
        for name, value in descriptors.items():
            desc_text += f"- {name}: {value:.2f}\n"
        
        # Drug-likeness check (Lipinski's Rule of Five)
        desc_text += "\n**Drug-Likeness (Lipinski's Rule of 5):**\n\n"
        violations = 0
        
        if descriptors['MolWeight'] > 500:
            desc_text += "- MW > 500: VIOLATION\n"
            violations += 1
        else:
            desc_text += "- MW ≤ 500: PASS\n"
            
        if descriptors['LogP'] > 5:
            desc_text += "- LogP > 5: VIOLATION\n"
            violations += 1
        else:
            desc_text += "- LogP ≤ 5: PASS\n"
            
        if descriptors['NumHDonors'] > 5:
            desc_text += "- H-Donors > 5: VIOLATION\n"
            violations += 1
        else:
            desc_text += "- H-Donors ≤ 5: PASS\n"
            
        if descriptors['NumHAcceptors'] > 10:
            desc_text += "- H-Acceptors > 10: VIOLATION\n"
            violations += 1
        else:
            desc_text += "- H-Acceptors ≤ 10: PASS\n"
        
        if violations == 0:
            desc_text += "\n**Result:** Drug-like (0 violations)"
        else:
            desc_text += f"\n**Result:** {violations} violation(s)"
        
        return pred_text, fig, mol_img, desc_text
        
    except Exception as e:
        return f"Error: {str(e)}", None, None, "Prediction failed; see the error message"

# Example SMILES for trying the interface (common molecules, not verified BACE-1 actives or inactives)
examples = [
    ["O=S(=O)(Nc1cccc(-c2cnc3ccccc3n2)c1)c1cccs1"],  # Complex sulfonamide
    ["CC(=O)Oc1ccccc1C(=O)O"],  # Aspirin
    ["CCO"],  # Ethanol (likely inactive)
    ["CC(C)Cc1ccc(C(C)C(=O)O)cc1"],  # Ibuprofen
    ["CN1C=NC2=C1C(=O)N(C(=O)N2C)C"],  # Caffeine
]

# Create Gradio interface
iface = gr.Interface(
    fn=predict_activity,
    inputs=gr.Textbox(
        label="Enter SMILES Notation",
        placeholder="e.g., CC(=O)Oc1ccccc1C(=O)O",
        lines=2
    ),
    outputs=[
        gr.Markdown(label="Prediction"),
        gr.Plot(label="Probability Distribution"),
        gr.Image(label="Molecule Structure", type="pil"),
        gr.Markdown(label="Molecular Properties")
    ],
    title="Drug Activity Predictor (BACE-1)",
    description="""Predict whether a compound is active or inactive against BACE-1 enzyme.  
    **BACE-1** (beta-secretase 1) is a widely studied drug target in Alzheimer's disease research.  
    This model uses Random Forest with balanced class weights.""",
    examples=examples,
    theme="soft",
    allow_flagging="never"
)

# Launch interface
print("Launching Gradio interface...")
iface.launch(share=False, inbrowser=True)
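
The four Lipinski checks in `predict_activity` follow the same pattern, so they could be driven from a small table of rules. This is one possible refactoring, not the workshop's implementation: `LIPINSKI_RULES` and `lipinski_report` are hypothetical names, and the descriptor values in the demo are approximate numbers typed in by hand rather than computed with RDKit.

```python
# (descriptor key, display label, upper limit) for each rule checked above
LIPINSKI_RULES = [
    ("MolWeight", "MW", 500),
    ("LogP", "LogP", 5),
    ("NumHDonors", "H-Donors", 5),
    ("NumHAcceptors", "H-Acceptors", 10),
]


def lipinski_report(descriptors):
    """Return (violation_count, markdown_lines) for a descriptor dict."""
    lines = []
    violations = 0
    for key, label, limit in LIPINSKI_RULES:
        if descriptors[key] > limit:
            lines.append(f"- {label} > {limit}: VIOLATION")
            violations += 1
        else:
            lines.append(f"- {label} ≤ {limit}: PASS")
    return violations, lines


# Approximate aspirin-like values, entered by hand for illustration
aspirin_like = {"MolWeight": 180.16, "LogP": 1.31, "NumHDonors": 1, "NumHAcceptors": 3}
count, report = lipinski_report(aspirin_like)
print(f"{count} violation(s)")
print("\n".join(report))
```

Keeping the limits in one table means the Gradio text and the API's violation count cannot drift apart if a threshold is changed later.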

## Part 3: Create FastAPI REST API

In [None]:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, List
import uvicorn

# Create FastAPI app
app = FastAPI(
    title="BACE-1 Activity Prediction API",
    description="REST API for predicting drug activity against BACE-1 enzyme",
    version="1.0.0"
)

# Request models
class SinglePredictionRequest(BaseModel):
    smiles: str
    
    # Pydantic v1 style; Pydantic v2 renames this key to json_schema_extra
    class Config:
        schema_extra = {
            "example": {
                "smiles": "CC(=O)Oc1ccccc1C(=O)O"
            }
        }

class BatchPredictionRequest(BaseModel):
    smiles_list: List[str]
    
    class Config:
        schema_extra = {
            "example": {
                "smiles_list": [
                    "CC(=O)Oc1ccccc1C(=O)O",
                    "CCO",
                    "c1ccccc1"
                ]
            }
        }

# Response models
class PredictionResponse(BaseModel):
    prediction: int
    prediction_label: str
    probability_inactive: float
    probability_active: float
    confidence: float
    descriptors: Dict[str, float]
    drug_like: bool
    lipinski_violations: int

class BatchPredictionResponse(BaseModel):
    predictions: List[PredictionResponse]
    total_active: int
    total_inactive: int

@app.get("/")
def read_root():
    return {
        "message": "BACE-1 Activity Prediction API",
        "endpoints": {
            "/predict": "POST - Predict single compound",
            "/predict/batch": "POST - Predict multiple compounds",
            "/health": "GET - Check API health"
        }
    }

@app.get("/health")
def health_check():
    return {"status": "healthy", "model_loaded": True}

def check_lipinski(descriptors):
    """Check Lipinski's Rule of Five"""
    violations = 0
    if descriptors['MolWeight'] > 500:
        violations += 1
    if descriptors['LogP'] > 5:
        violations += 1
    if descriptors['NumHDonors'] > 5:
        violations += 1
    if descriptors['NumHAcceptors'] > 10:
        violations += 1
    return violations

@app.post("/predict", response_model=PredictionResponse)
def predict(request: SinglePredictionRequest):
    """
    Predict activity for a single compound
    """
    try:
        # Calculate descriptors
        descriptors = calculate_descriptors(request.smiles)
        
        if descriptors is None:
            raise HTTPException(status_code=400, detail="Invalid SMILES notation")
        
        # Prepare features
        features = np.array([[descriptors[feat] for feat in feature_names]])
        features_scaled = scaler.transform(features)
        
        # Predict
        prediction = int(model.predict(features_scaled)[0])
        probabilities = model.predict_proba(features_scaled)[0]
        
        # Check drug-likeness
        violations = check_lipinski(descriptors)
        
        return PredictionResponse(
            prediction=prediction,
            prediction_label="Active" if prediction == 1 else "Inactive",
            probability_inactive=float(probabilities[0]),
            probability_active=float(probabilities[1]),
            confidence=float(max(probabilities)),
            descriptors=descriptors,
            drug_like=(violations == 0),
            lipinski_violations=violations
        )
        
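    except HTTPException:
        # Re-raise as-is so the 400 for invalid SMILES is not turned into a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))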

@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """
    Predict activity for multiple compounds
    """
    predictions = []
    active_count = 0
    inactive_count = 0
    
    for smiles in request.smiles_list:
        try:
            result = predict(SinglePredictionRequest(smiles=smiles))
            predictions.append(result)
            
            if result.prediction == 1:
                active_count += 1
            else:
                inactive_count += 1
                
        except HTTPException:
            # Skip invalid SMILES
            continue
    
    return BatchPredictionResponse(
        predictions=predictions,
        total_active=active_count,
        total_inactive=inactive_count
    )

print("FastAPI app created!")
print("\nTo run the API, save this cell (plus the imports, calculate_descriptors, and model loading) as app.py, then:")
print("  uvicorn app:app --reload --port 8001")
print("\nAPI will be available at: http://localhost:8001")
print("Interactive docs: http://localhost:8001/docs")
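
Once the server is running, the `/predict` endpoint can be called from any HTTP client. Below is a minimal client sketch using only the standard library. It assumes the API is reachable at `http://localhost:8001` as printed above; `build_predict_request` and `predict_remote` are hypothetical helper names.

```python
import json
import urllib.request

API_BASE = "http://localhost:8001"  # assumed: matches the --port used above


def build_predict_request(smiles, base_url=API_BASE):
    """Build (but do not send) a POST request for the /predict endpoint."""
    payload = json.dumps({"smiles": smiles}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predict",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict_remote(smiles, base_url=API_BASE, timeout=10):
    """Send the request and return the decoded JSON response as a dict."""
    request = build_predict_request(smiles, base_url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server up, `predict_remote("CC(=O)Oc1ccccc1C(=O)O")` should return a dict with the fields of `PredictionResponse`, such as `prediction_label` and `probability_active`. An invalid SMILES produces an HTTP error status, which `urlopen` raises as `urllib.error.HTTPError`.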

## Summary

### What We Learned:

1. **Classification Model Deployment**
   - Handled imbalanced classes with class weights
   - Predicted binary outcomes (active/inactive)
   - Calculated prediction probabilities

2. **Enhanced Gradio Interface**
   - Probability visualizations
   - Molecule structure display
   - Drug-likeness assessment
   - Lipinski's Rule of Five checking

3. **Advanced FastAPI Features**
   - Single and batch predictions
   - Detailed response models
   - Automatic validation
   - Error handling

### Deployment Considerations:

1. **Model Performance**
   - Monitor accuracy over time
   - Retrain periodically with new data
   - Version your models

2. **API Performance**
   - Add caching for common queries
   - Implement rate limiting
   - Use async endpoints for better concurrency

3. **Security**
   - Add authentication (API keys, OAuth)
   - Input validation and sanitization
   - HTTPS in production

### Real-World Applications:

- Virtual screening of compound libraries
- Lead optimization workflows
- Integration with laboratory information management systems (LIMS)
- High-throughput prediction pipelines
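
For library-scale screening, sending every compound in a single request is rarely practical. A rough sketch of splitting a SMILES collection into request bodies for `/predict/batch` follows. The `chunked` and `batch_payloads` helpers and the default batch size of 500 are assumptions, not tuned values.

```python
from itertools import islice


def chunked(items, size):
    """Yield successive lists of at most `size` items from any iterable."""
    if size < 1:
        raise ValueError("size must be at least 1")
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, size))
        if not batch:
            return
        yield batch


def batch_payloads(smiles_iterable, batch_size=500):
    """Yield request bodies shaped like BatchPredictionRequest."""
    for batch in chunked(smiles_iterable, batch_size):
        yield {"smiles_list": batch}


# Small demo with a batch size of 2
library = ["CCO", "c1ccccc1", "CC(=O)O", "CCN", "CCC"]
for payload in batch_payloads(library, batch_size=2):
    print(payload)
```

Note that the batch endpoint above silently skips invalid SMILES, so the returned predictions cannot be matched to inputs by position alone. Echoing the input SMILES in each response item would be one way to keep results traceable.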

---

**Congratulations!** You can now deploy both regression and classification models!

---

**Next Steps:**
1. Deploy to cloud platforms (AWS, GCP, Azure)
2. Add monitoring and logging
3. Create CI/CD pipelines
4. Scale with Docker and Kubernetes