# In [None]:
# lmfn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# -------------------------
# Utility: sliding windows
# -------------------------
def sliding_windows(x, window_size, stride):
    """
    Cut a batch of sequences into fixed-length windows along the time axis.

    x: (B, T, D)
    window_size: number of timesteps per window (must be >= 1)
    stride: step between consecutive window starts (must be >= 1)
    returns: contiguous tensor (B, num_windows, window_size, D) with
        num_windows = (max(T, window_size) - window_size) // stride + 1

    Sequences shorter than window_size are zero-padded at the end so that
    exactly one window is produced. Trailing timesteps that do not fill a
    complete window are dropped.

    raises: ValueError if window_size or stride is smaller than 1, or if x is
        not 3-dimensional.
    """
    if window_size < 1 or stride < 1:
        raise ValueError(
            f"window_size and stride must be >= 1, "
            f"got window_size={window_size}, stride={stride}"
        )
    _, T, _ = x.shape  # unpacking also rejects inputs that are not 3-D
    if window_size > T:
        # pad time dimension on the right (constant zeros)
        x = F.pad(x, (0, 0, 0, window_size - T))
    # unfold yields (B, num_windows, D, window_size); swap the last two axes
    windows = x.unfold(1, window_size, stride).permute(0, 1, 3, 2)
    # materialise the strided view so callers can safely .view() the result
    return windows.contiguous()


# -------------------------
# Small temporal encoder (shared MLP per timestep)
# -------------------------
class TimeStepEncoder(nn.Module):
    """
    Two-layer MLP applied independently to every timestep.

    The same weights are used for all timesteps, so the module only mixes
    information along the feature axis, never along time.
    """

    def __init__(self, input_dim, hidden_dim):
        """
        input_dim: feature size of each timestep
        hidden_dim: width of the hidden layer and of the output
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x):
        """
        x: (B, T, input_dim); any number of leading dimensions is accepted
        returns: same leading dimensions with last dim == hidden_dim
        """
        # nn.Linear only touches the last dimension, so no flattening to
        # (B * T, D) is needed. Skipping the .view() also makes the encoder
        # work on non-contiguous inputs (e.g. transposed tensors).
        return self.net(x)

# -------------------------
# Local fusion block
# -------------------------
class LocalFusionBlock(nn.Module):
    """
    Summarise one modality's window into a fused vector using a learned gate.

    A scalar gate in (0, 1) is predicted from the flattened window. The window
    is mean-pooled over time, and the gated and ungated pooled vectors are
    concatenated and passed through an MLP.
    """

    def __init__(self, hidden_dim, window_size, fused_dim):
        """
        hidden_dim: per-modality timestep representation dim
        window_size: number of timesteps in local window
        fused_dim: output dim of local fusion
        """
        super().__init__()
        # gating network: flattened window (window_size * hidden_dim) -> one logit per sample
        self.gate_mlp = nn.Sequential(
            nn.Linear(hidden_dim * window_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # fusion MLP: input is [gated pooled, ungated pooled], hence 2 * hidden_dim.
        # Mean-pooling over the window keeps the input size independent of window_size.
        self.local_fusion_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, fused_dim),
            nn.ReLU(),
            nn.Linear(fused_dim, fused_dim)
        )

    def forward(self, modality_window):
        """
        modality_window: (B, window_size, hidden_dim) for a single modality;
            window_size and hidden_dim must match the constructor arguments
            because the gate MLP's first layer is sized from them
        returns: fused vector (B, fused_dim) and gate scalar (B, 1)
        """
        # 1) gate scalar per sample from the flattened window; flatten (unlike
        #    .view) also accepts non-contiguous windows
        flat = modality_window.flatten(start_dim=1)  # (B, W * H)
        gate = torch.sigmoid(self.gate_mlp(flat))  # (B, 1) in 0..1
        # 2) pooled vector (mean over window)
        pooled = modality_window.mean(dim=1)  # (B, H)
        gated_pooled = pooled * gate  # gate broadcasts over H
        # 3) fuse gated and ungated summaries (the ungated one acts as a residual path)
        fused = self.local_fusion_mlp(torch.cat([gated_pooled, pooled], dim=1))  # (B, fused_dim)
        return fused, gate


# -------------------------
# LMFN main model
# -------------------------
class LMFN(nn.Module):
    """
    Multimodal sequence model built from local (windowed) fusion.

    Pipeline: per-modality timestep encoders -> sliding windows -> one
    LocalFusionBlock per modality -> per-window combiner across modalities ->
    1D conv + average pooling over windows -> classifier.
    """

    def __init__(self, input_dims, timestep_hidden=128, window_size=5, stride=1, fused_dim=128, agg_channels=128, out_dim=2):
        """
        input_dims: list of input dims for each modality (per timestep)
        timestep_hidden: encoder output dim for each timestep per modality
        window_size: temporal window length for local fusion
        stride: sliding stride
        fused_dim: dim of each per-modality fused vector and of the combined per-window vector
        agg_channels: channels for temporal aggregator (1D conv)
        out_dim: output dimension (classes or 1)
        """
        super().__init__()
        self.num_modalities = len(input_dims)
        self.window_size = window_size
        self.stride = stride
        # timestep encoders for each modality
        self.encoders = nn.ModuleList([TimeStepEncoder(d, timestep_hidden) for d in input_dims])
        # local fusion blocks: one per modality, each producing one fused vector per window
        self.local_blocks = nn.ModuleList([LocalFusionBlock(timestep_hidden, window_size, fused_dim) for _ in input_dims])
        # combiner that merges per-modality local fused outputs into a single window-level fused vector
        self.window_combiner = nn.Sequential(
            nn.Linear(fused_dim * self.num_modalities, fused_dim * 2),
            nn.ReLU(),
            nn.Linear(fused_dim * 2, fused_dim)
        )
        # temporal aggregator: 1D conv over windows (num_windows dimension)
        self.agg_conv = nn.Conv1d(in_channels=fused_dim, out_channels=agg_channels, kernel_size=3, padding=1)
        self.agg_pool = nn.AdaptiveAvgPool1d(1)
        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(agg_channels, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, out_dim)
        )

    def forward(self, *modal_seq):
        """
        modal_seq: one tensor (B, T, D_i) per modality, in the order of
            input_dims; all modalities must share B and T
        returns: logits (B, out_dim) and an aux dict with keys
            "encoded": list of (B, T, timestep_hidden) tensors,
            "fused_windows": (B, num_windows, fused_dim),
            "gates": (B, num_windows, num_modalities)
        raises: ValueError if the number of tensors differs from the number of
            modalities, or if the modalities yield different window counts
        """
        if len(modal_seq) != self.num_modalities:
            raise ValueError(
                f"expected {self.num_modalities} modality tensors, got {len(modal_seq)}"
            )
        # 1) encode each modality per timestep -> (B, T, H)
        encoded = [enc(x) for enc, x in zip(self.encoders, modal_seq)]
        # 2) build sliding windows per modality: (B, num_windows, window_size, H)
        windows_per_mod = [sliding_windows(e, self.window_size, self.stride) for e in encoded]
        num_windows = windows_per_mod[0].shape[1]
        if any(w.shape[1] != num_windows for w in windows_per_mod):
            raise ValueError(
                "all modalities must produce the same number of windows; got "
                f"{[w.shape[1] for w in windows_per_mod]}"
            )
        # 3) local fusion: fold the window axis into the batch axis so each
        #    modality's block runs once over all windows instead of once per window
        per_mod_fused = []
        per_mod_gates = []
        for block, windows in zip(self.local_blocks, windows_per_mod):
            B, N, W, H = windows.shape
            fused_vec, gate = block(windows.reshape(B * N, W, H))  # (B*N, F), (B*N, 1)
            per_mod_fused.append(fused_vec.reshape(B, N, -1))
            per_mod_gates.append(gate.reshape(B, N, 1))
        # concatenate per-modality fused vectors -> (B, num_windows, F * num_modalities);
        # the combiner's Linear layers act on the last dimension only
        fused_windows = self.window_combiner(torch.cat(per_mod_fused, dim=-1))  # (B, num_windows, fused_dim)
        gates = torch.cat(per_mod_gates, dim=-1)  # (B, num_windows, num_modalities)
        # 4) aggregate over windows using 1D conv/pooling
        # conv expects (B, C, L) where C == fused_dim, L == num_windows
        x = fused_windows.permute(0, 2, 1)  # (B, fused_dim, num_windows)
        x = F.relu(self.agg_conv(x))
        x = self.agg_pool(x).squeeze(-1)  # (B, agg_channels)
        logits = self.classifier(x)  # (B, out_dim)
        aux = {
            "encoded": encoded,
            "fused_windows": fused_windows,
            "gates": gates
        }
        return logits, aux

# -------------------------
# Synthetic data & small training loop
# -------------------------
def synthetic_sequence_data(num_samples=400, T=30, dims=(16, 8, 12), num_classes=2):
    """
    Generate a reproducible toy multimodal sequence classification dataset.

    num_samples: number of sequences
    T: timesteps per sequence
    dims: per-timestep feature size of each modality (any iterable of ints)
    num_classes: accepted for interface compatibility but not used; labels
        are always binary (0 or 1)
    returns: (Xs, y) where Xs is a list of float tensors (num_samples, T, d)
        drawn from a standard normal, and y is a long tensor (num_samples,)

    Side effect: reseeds the global torch RNG with 0, so repeated calls
    return identical data.
    """
    torch.manual_seed(0)
    Xs = [torch.randn(num_samples, T, d) for d in dims]
    # label = sign of the first modality's mean over the second half of the
    # sequence (all features), perturbed by a little Gaussian noise
    half = T // 2
    key_signal = Xs[0][:, half:, :].mean(dim=(1, 2))
    y = (key_signal + 0.1 * torch.randn(num_samples) > 0).long()
    return Xs, y

def train_example():
    """
    Train an LMFN on the synthetic three-modality dataset for 8 epochs,
    printing loss/accuracy per epoch, then print a sample inference.
    """
    input_dims = [16, 8, 12]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LMFN(
        input_dims,
        timestep_hidden=64,
        window_size=5,
        stride=2,
        fused_dim=64,
        agg_channels=64,
        out_dim=2,
    )
    model.to(device)

    # NOTE: synthetic_sequence_data reseeds the global RNG, so it is called
    # after the model has been initialised, as in the original ordering
    Xs, y = synthetic_sequence_data(num_samples=600, T=30, dims=input_dims, num_classes=2)
    batches = DataLoader(TensorDataset(*Xs, y), batch_size=32, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(8):
        total_loss = 0.0
        correct = 0
        total = 0
        # each batch is (modality_0, modality_1, modality_2, labels)
        for *inputs, lbl in batches:
            inputs = [t.to(device) for t in inputs]
            lbl = lbl.to(device)
            batch_n = inputs[0].size(0)

            optimizer.zero_grad()
            logits, _ = model(*inputs)
            loss = criterion(logits, lbl)
            loss.backward()
            optimizer.step()

            # weight the batch-mean loss by batch size to get a per-sample epoch average
            total_loss += loss.item() * batch_n
            correct += (logits.argmax(dim=1) == lbl).sum().item()
            total += batch_n
        print(f"Epoch {epoch+1}: loss={total_loss/total:.4f} acc={correct/total:.4f}")

    # sample inference on the first four sequences
    model.eval()
    with torch.no_grad():
        sample = [x[:4].to(device) for x in Xs]
        logits, aux = model(*sample)
        print("Logits sample:", logits)
        print("Fused windows shape:", aux["fused_windows"].shape)
        print("Gates shape:", aux["gates"].shape)

# Script entry point: run the synthetic training demo only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train_example()