# 03 – Model Prototyping

Train centralized prototypes of the multimodal risk model to validate architecture choices before moving to federated training.


In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset

from federated_health_risk.models.multimodal_model import MultimodalRiskNet

# Synthetic placeholder data for architecture prototyping.
# TODO: replace with real feature tensors from the feature store or the synthetic generator.
SEED = 42  # fixed so data generation and shuffle order are reproducible across runs
VITALS_DIM, AIR_DIM, TEXT_DIM = 8, 6, 16
# torch.rand draws uniformly from [0, 1), so labels are positive with probability ~0.3.
POSITIVE_THRESHOLD = 0.7

num_samples = 1024
# Local generator: keeps the notebook's global RNG state untouched.
rng = torch.Generator().manual_seed(SEED)

vitals = torch.rand(num_samples, VITALS_DIM, generator=rng)
air = torch.rand(num_samples, AIR_DIM, generator=rng)
text = torch.rand(num_samples, TEXT_DIM, generator=rng)
y = (torch.rand(num_samples, generator=rng) > POSITIVE_THRESHOLD).float()

dataset = TensorDataset(vitals, air, text, y)
# A dedicated seeded generator makes the per-epoch shuffle order reproducible too.
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Derive input widths from the tensors so the model can never disagree with the data.
model = MultimodalRiskNet(
    vitals_dim=vitals.shape[1],
    air_dim=air.shape[1],
    text_dim=text.shape[1],
)
model


MultimodalRiskNet(
  (vitals_branch): Sequential(
    (0): Linear(in_features=8, out_features=128, bias=True)
    (1): ReLU()
    (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (air_branch): Sequential(
    (0): Linear(in_features=6, out_features=64, bias=True)
    (1): ReLU()
    (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (text_branch): Sequential(
    (0): Linear(in_features=16, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
  )
  (head): Sequential(
    (0): Linear(in_features=320, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=128, out_features=1, bias=True)
    (4): Sigmoid()
  )
)

## Checklist
- Implement training loop with optimizer, scheduler, and MLflow logging.
- Run hyperparameter sweeps (dropout, hidden size, class weights).
- Export best model weights to `models/latest_model.pt` for the inference API.
