# In [None]:
import torch
from opacus import PrivacyEngine
from torch.utils.data import DataLoader

def train_securemed_dp(model, train_dataset, optimizer=None, criterion=None,
                        batch_size=32, num_epochs=3, epsilon_target=3.0, delta=1e-5,
                        max_grad_norm=1.0, lr=1e-5, noise_multiplier=None):
    """
    Fine-tune a SecureMed-LLM model with Differential Privacy (DP-SGD).

    Uses the Opacus >= 1.0 ``PrivacyEngine.make_private*`` API. By default the
    noise multiplier is calibrated so that ``num_epochs`` epochs of training
    meet the ``(epsilon_target, delta)`` budget.

    Args:
        model (torch.nn.Module): The model to train. Opacus wraps it, but the
            wrapper shares its parameters, so ``model`` is trained in place.
        train_dataset (torch.utils.data.Dataset): Training dataset; iterated as
            ``(inputs, labels)`` pairs.
        optimizer (torch.optim.Optimizer, optional): Optimizer over
            ``model.parameters()``. Defaults to Adam with learning rate ``lr``.
        criterion (torch.nn.Module, optional): Loss function. Defaults to
            CrossEntropyLoss. Its ``reduction`` ("mean" or "sum") is forwarded
            to Opacus so per-sample gradients are scaled consistently.
        batch_size (int): Expected batch size. With Opacus's Poisson sampling
            the actual size of each batch varies around this value.
        num_epochs (int): Number of training epochs.
        epsilon_target (float): Target privacy budget ε. Used to calibrate the
            noise multiplier when ``noise_multiplier`` is None.
        delta (float): Target δ for the DP guarantee.
        max_grad_norm (float): Per-sample gradient clipping norm.
        lr (float): Learning rate for the default optimizer.
        noise_multiplier (float, optional): If given, use this fixed noise
            multiplier instead of calibrating to ``epsilon_target``.

    Returns:
        float: Privacy budget ε spent after training, at the given δ.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss()

    # Opacus must know how the loss aggregates over the batch to scale
    # per-sample gradients; it only supports "mean" and "sum".
    loss_reduction = getattr(criterion, "reduction", "mean")
    if loss_reduction not in ("mean", "sum"):
        loss_reduction = "mean"

    # The shuffle flag is moot: make_private* replaces this loader with a
    # Poisson-sampling loader, which is what the privacy accounting assumes.
    data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    privacy_engine = PrivacyEngine()
    common_kwargs = dict(
        module=model,
        optimizer=optimizer,
        data_loader=data_loader,
        max_grad_norm=max_grad_norm,
        loss_reduction=loss_reduction,
    )
    if noise_multiplier is None:
        # Calibrate the noise so that num_epochs of training meets (ε, δ).
        dp_model, dp_optimizer, dp_loader = privacy_engine.make_private_with_epsilon(
            target_epsilon=epsilon_target,
            target_delta=delta,
            epochs=num_epochs,
            **common_kwargs,
        )
    else:
        dp_model, dp_optimizer, dp_loader = privacy_engine.make_private(
            noise_multiplier=noise_multiplier,
            **common_kwargs,
        )

    dp_model.train()
    for _ in range(num_epochs):
        for inputs, labels in dp_loader:
            dp_optimizer.zero_grad()
            outputs = dp_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # The DP optimizer clips per-sample grads, adds noise, and records
            # the step with the privacy accountant.
            dp_optimizer.step()

    epsilon = privacy_engine.get_epsilon(delta)
    print(f"Achieved privacy: ε = {epsilon:.2f}, δ = {delta}")

    return epsilon
