# Model Layers Guide

This notebook explores the four core layers that sit on top of `FeatureEmbedding`:
**FM**, **DNN**, **CIN**, and **Multi-Head Self-Attention**.
Each layer consumes a different view of the embedding output and captures a different type of feature interaction.

## 0. Setup — Create a Synthetic Schema & Embeddings

We'll use a small synthetic schema so this notebook runs instantly without downloading data.

In [17]:
import torch
import pandas as pd

from deepfm.data.schema import FieldSchema, DatasetSchema, FeatureType
from deepfm.models.layers.embedding import FeatureEmbedding

# Define a small schema: 6 fields, mix of vocab sizes and embed dims
fields = {
    "user_id":    FieldSchema("user_id",    FeatureType.SPARSE, vocabulary_size=100, embedding_dim=16),
    "item_id":    FieldSchema("item_id",    FeatureType.SPARSE, vocabulary_size=200, embedding_dim=16),
    "gender":     FieldSchema("gender",     FeatureType.SPARSE, vocabulary_size=3,   embedding_dim=4),
    "age":        FieldSchema("age",        FeatureType.SPARSE, vocabulary_size=8,   embedding_dim=4),
    "occupation": FieldSchema("occupation", FeatureType.SPARSE, vocabulary_size=22,  embedding_dim=8),
    "city":       FieldSchema("city",       FeatureType.SPARSE, vocabulary_size=50,  embedding_dim=8),
}
schema = DatasetSchema(fields=fields, label_field="label")
FM_DIM = 16

# Build embedding layer and get the three outputs
emb = FeatureEmbedding(schema, fm_embed_dim=FM_DIM)
emb.eval()

batch = {
    "user_id":    torch.randint(1, 100, (8,)),
    "item_id":    torch.randint(1, 200, (8,)),
    "gender":     torch.randint(1, 3,   (8,)),
    "age":        torch.randint(1, 8,   (8,)),
    "occupation": torch.randint(1, 22,  (8,)),
    "city":       torch.randint(1, 50,  (8,)),
}

with torch.no_grad():
    first_order, field_embeddings, flat_embeddings = emb(batch)

B = field_embeddings.size(0)
F = field_embeddings.size(1)
D = field_embeddings.size(2)

print(f"Batch size B={B}, Fields F={F}, FM dim D={D}")
print(f"first_order:      {first_order.shape}")
print(f"field_embeddings: {field_embeddings.shape}")
print(f"flat_embeddings:  {flat_embeddings.shape}")

Batch size B=8, Fields F=6, FM dim D=16
first_order:      torch.Size([8, 1])
field_embeddings: torch.Size([8, 6, 16])
flat_embeddings:  torch.Size([8, 56])


---
## 1. FM Interaction Layer

The FM layer computes **second-order feature interactions** efficiently in O(F*D) using the identity:

```
sum_{i<j} <v_i, v_j> = 0.5 * ( (sum_i v_i)^2 - sum_i (v_i^2) )
```

This avoids the O(F^2 * D) cost of explicit pairwise dot products.

In [2]:
from deepfm.models.layers.fm import FMInteraction

fm = FMInteraction()
print(fm)
print(f"\nParameters: {sum(p.numel() for p in fm.parameters())} (none — FM is parameter-free!)")

FMInteraction()

Parameters: 0 (none — FM is parameter-free!)


In [3]:
with torch.no_grad():
    fm_output = fm(field_embeddings)

print(f"Input:  field_embeddings {field_embeddings.shape}  — (B, F, D)")
print(f"Output: fm_output        {fm_output.shape}         — (B, 1)")
print(f"\nFM interaction values (one scalar per sample):")
print(fm_output.squeeze())

Input:  field_embeddings torch.Size([8, 6, 16])  — (B, F, D)
Output: fm_output        torch.Size([8, 1])         — (B, 1)

FM interaction values (one scalar per sample):
tensor([-0.2719, -0.4070, -0.3825, -0.7346, -0.6332,  0.6506, -0.5905,  0.0814])


### Verifying the Math

Let's confirm the efficient formula matches the explicit pairwise computation.

In [4]:
# Explicit O(F^2) pairwise computation for verification
with torch.no_grad():
    explicit = torch.zeros(B, 1)
    for i in range(F):
        for j in range(i + 1, F):
            dot = (field_embeddings[:, i] * field_embeddings[:, j]).sum(dim=1, keepdim=True)
            explicit += dot

print("Efficient vs Explicit (should match):")
print(f"  Max difference: {(fm_output - explicit).abs().max().item():.2e}")
print(f"  All close:      {torch.allclose(fm_output, explicit, atol=1e-5)}")

Efficient vs Explicit (should match):
  Max difference: 1.79e-07
  All close:      True


---
## 2. DNN Layer

The DNN is a standard MLP that processes the **flat concatenated embeddings**.
It captures arbitrary higher-order interactions through non-linear transformations.

Stack: `Linear → (BatchNorm) → Activation → Dropout` repeated per hidden layer.

In [5]:
from deepfm.models.layers.dnn import DNN

input_dim = flat_embeddings.shape[1]
dnn = DNN(
    input_dim=input_dim,
    hidden_units=[128, 64, 32],
    activation="relu",
    dropout=0.1,
    use_batch_norm=True,
)
print(dnn)
print(f"\nOutput dim: {dnn.output_dim}")

DNN(
  (mlp): Sequential(
    (0): Linear(in_features=56, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.1, inplace=False)
    (8): Linear(in_features=64, out_features=32, bias=True)
    (9): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.1, inplace=False)
  )
)

Output dim: 32


In [6]:
with torch.no_grad():
    dnn.eval()
    dnn_output = dnn(flat_embeddings)

print(f"Input:  flat_embeddings {flat_embeddings.shape}  — (B, total_dim)")
print(f"Output: dnn_output      {dnn_output.shape}       — (B, last_hidden)")

Input:  flat_embeddings torch.Size([8, 56])  — (B, total_dim)
Output: dnn_output      torch.Size([8, 32])       — (B, last_hidden)


### Layer-by-Layer Dimension Flow

In [7]:
rows = []
for i, module in enumerate(dnn.mlp):
    name = module.__class__.__name__
    params = sum(p.numel() for p in module.parameters())
    if hasattr(module, "in_features"):
        detail = f"{module.in_features} -> {module.out_features}"
    elif hasattr(module, "num_features"):
        detail = f"features={module.num_features}"
    elif hasattr(module, "p"):
        detail = f"p={module.p}"
    else:
        detail = "-"
    rows.append({"idx": i, "layer": name, "detail": detail, "params": params})

df = pd.DataFrame(rows)
print(f"Total DNN parameters: {sum(p.numel() for p in dnn.parameters()):,}\n")
df

Total DNN parameters: 18,080



Unnamed: 0,idx,layer,detail,params
0,0,Linear,56 -> 128,7296
1,1,BatchNorm1d,features=128,256
2,2,ReLU,-,0
3,3,Dropout,p=0.1,0
4,4,Linear,128 -> 64,8256
5,5,BatchNorm1d,features=64,128
6,6,ReLU,-,0
7,7,Dropout,p=0.1,0
8,8,Linear,64 -> 32,2080
9,9,BatchNorm1d,features=32,64


### Comparing Activation Functions

In [8]:
rows = []
for act_name in ["relu", "leaky_relu", "gelu", "tanh"]:
    d = DNN(input_dim, [64, 32], activation=act_name, dropout=0.0, use_batch_norm=False)
    d.eval()
    with torch.no_grad():
        out = d(flat_embeddings)
    rows.append({
        "activation": act_name,
        "output_mean": f"{out.mean().item():.4f}",
        "output_std": f"{out.std().item():.4f}",
        "pct_zero": f"{(out == 0).float().mean().item():.1%}",
    })

pd.DataFrame(rows)

Unnamed: 0,activation,output_mean,output_std,pct_zero
0,relu,0.0416,0.0552,48.0%
1,leaky_relu,0.0463,0.0625,0.0%
2,gelu,0.0162,0.037,0.0%
3,tanh,-0.0045,0.1195,0.0%


---
## 3. CIN Layer (Compressed Interaction Network)

CIN is the key innovation in **xDeepFM**. It captures **explicit, vector-wise** higher-order interactions
(unlike FM which is scalar-wise, and DNN which is implicit).

Each CIN layer:
1. Computes an outer product between the current hidden state and the original input
2. Compresses with Conv1d (kernel_size=1) — like a learned weighted sum of interaction maps
3. Optionally splits: half feeds forward, half goes to the output pool

In [9]:
from deepfm.models.layers.cin import CIN

cin = CIN(num_fields=F, embed_dim=D, layer_sizes=[128, 128], split_half=True)
print(cin)
print(f"\nOutput dim: {cin.output_dim}")

CIN(
  (conv_layers): ModuleList(
    (0): Conv1d(36, 128, kernel_size=(1,), stride=(1,))
    (1): Conv1d(384, 128, kernel_size=(1,), stride=(1,))
  )
)

Output dim: 192


In [10]:
with torch.no_grad():
    cin_output = cin(field_embeddings)

print(f"Input:  field_embeddings {field_embeddings.shape}  — (B, F, D)")
print(f"Output: cin_output       {cin_output.shape}       — (B, output_dim)")

Input:  field_embeddings torch.Size([8, 6, 16])  — (B, F, D)
Output: cin_output       torch.Size([8, 192])       — (B, output_dim)


### Understanding split_half

With `split_half=True`, each intermediate layer splits its feature maps:
- One half feeds into the **next** CIN layer (for deeper interactions)
- The other half goes directly to the **output pool** (for shallower interactions)

This gives the model multi-granularity: the output contains both 2nd-order and higher-order interactions.

In [11]:
rows = []
for i, (conv, ds, ns) in enumerate(zip(cin.conv_layers, cin.direct_sizes, cin.next_sizes)):
    rows.append({
        "layer": i,
        "conv_in_channels": conv.in_channels,
        "conv_out_channels": conv.out_channels,
        "to_output_pool": ds,
        "to_next_layer": ns,
        "interaction_order": i + 2,  # layer 0 = 2nd order, layer 1 = 3rd order, etc.
    })

print(f"With split_half=True, output_dim = {cin.output_dim} = {' + '.join(str(d) for d in cin.direct_sizes)}")
pd.DataFrame(rows)

With split_half=True, output_dim = 192 = 64 + 128


Unnamed: 0,layer,conv_in_channels,conv_out_channels,to_output_pool,to_next_layer,interaction_order
0,0,36,128,64,64,2
1,1,384,128,128,128,3


### Effect of split_half on Output Dimension

In [None]:
for split in [True, False]:
    c = CIN(num_fields=F, embed_dim=D, layer_sizes=[128, 128], split_half=split)
    with torch.no_grad():
        out = c(field_embeddings)
    params = sum(p.numel() for p in c.parameters())
    print(f"split_half={str(split):5s}  output_dim={c.output_dim:4d}  output_shape={out.shape}  params={params:,}")

---
## 4. Multi-Head Self-Attention Layer

The attention layer refines field embeddings by learning **which field pairs matter most**.
Used in **AttentionDeepFM** to replace or augment FM's uniform pairwise interactions.

Standard transformer-style: Q/K/V projections → scaled dot-product → multi-head → residual + LayerNorm.

In [12]:
from deepfm.models.layers.attention import MultiHeadSelfAttention

attn = MultiHeadSelfAttention(
    embed_dim=D,
    num_heads=4,
    attention_dim=64,
    num_layers=2,
    use_residual=True,
)
print(attn)
print(f"\nTotal parameters: {sum(p.numel() for p in attn.parameters()):,}")

MultiHeadSelfAttention(
  (layers): ModuleList(
    (0-1): 2 x _AttentionBlock(
      (W_q): Linear(in_features=16, out_features=64, bias=True)
      (W_k): Linear(in_features=16, out_features=64, bias=True)
      (W_v): Linear(in_features=16, out_features=64, bias=True)
      (W_out): Linear(in_features=64, out_features=16, bias=True)
      (layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
    )
  )
)

Total parameters: 8,672


In [13]:
with torch.no_grad():
    attn.eval()
    attn_output = attn(field_embeddings)

print(f"Input:  field_embeddings {field_embeddings.shape}  — (B, F, D)")
print(f"Output: attn_output      {attn_output.shape}      — (B, F, D)  (same shape!)")
print(f"\nAttention preserves the (B, F, D) shape — it refines embeddings, not reduces them.")

Input:  field_embeddings torch.Size([8, 6, 16])  — (B, F, D)
Output: attn_output      torch.Size([8, 6, 16])      — (B, F, D)  (same shape!)

Attention preserves the (B, F, D) shape — it refines embeddings, not reduces them.


### Visualizing Attention Weights

Let's extract the attention weights from the first layer to see which fields attend to each other.

In [14]:
import math

# Manually extract attention weights from the first layer, first head
block = attn.layers[0]
with torch.no_grad():
    Q = block.W_q(field_embeddings)  # (B, F, attn_dim)
    K = block.W_k(field_embeddings)
    head_dim = block.head_dim
    num_heads = block.num_heads

    # Reshape to multi-head
    Q = Q.view(B, F, num_heads, head_dim).transpose(1, 2)  # (B, H, F, hd)
    K = K.view(B, F, num_heads, head_dim).transpose(1, 2)

    # Attention weights for head 0, sample 0
    scores = torch.matmul(Q[0, 0], K[0, 0].T) / math.sqrt(head_dim)  # (F, F)
    weights = torch.softmax(scores, dim=-1)

field_names = list(schema.fields.keys())
print("Attention weights (head 0, sample 0):")
print(f"{'':15s}", "  ".join(f"{n:>10s}" for n in field_names))
for i, name in enumerate(field_names):
    row = "  ".join(f"{weights[i, j].item():10.3f}" for j in range(F))
    print(f"{name:15s} {row}")

Attention weights (head 0, sample 0):
                   user_id     item_id      gender         age  occupation        city
user_id              0.167       0.168       0.163       0.169       0.164       0.170
item_id              0.167       0.168       0.162       0.168       0.163       0.172
gender               0.166       0.169       0.164       0.165       0.164       0.172
age                  0.168       0.167       0.162       0.169       0.163       0.172
occupation           0.168       0.167       0.164       0.169       0.163       0.168
city                 0.167       0.167       0.164       0.168       0.163       0.171


### Residual Connection Effect

With `use_residual=True`, the output is `LayerNorm(attention_out + input)`. This means
the model can learn to pass through the original embeddings when attention isn't helpful.

In [15]:
with torch.no_grad():
    attn_res = MultiHeadSelfAttention(D, num_heads=4, attention_dim=64, num_layers=1, use_residual=True)
    attn_no_res = MultiHeadSelfAttention(D, num_heads=4, attention_dim=64, num_layers=1, use_residual=False)
    attn_res.eval()
    attn_no_res.eval()

    out_res = attn_res(field_embeddings)
    out_no_res = attn_no_res(field_embeddings)

    # How much does the output differ from input?
    diff_res = (out_res - field_embeddings).norm() / field_embeddings.norm()
    diff_no_res = (out_no_res - field_embeddings).norm() / field_embeddings.norm()

print(f"Relative change from input:")
print(f"  With residual:    {diff_res.item():.4f}")
print(f"  Without residual: {diff_no_res.item():.4f}")
print(f"\nResidual connections keep outputs closer to the input (easier to train).")

Relative change from input:
  With residual:    4.3467
  Without residual: 1.2193

Residual connections keep outputs closer to the input (easier to train).


---
## 5. Comparing All Layers

In [16]:
rows = [
    {
        "layer": "FMInteraction",
        "input": "field_embeddings (B,F,D)",
        "output_shape": str(tuple(fm_output.shape)),
        "interaction_type": "2nd-order, explicit",
        "params": sum(p.numel() for p in fm.parameters()),
        "used_in": "DeepFM, AttentionDeepFM",
    },
    {
        "layer": "DNN",
        "input": "flat_embeddings (B,total_dim)",
        "output_shape": str(tuple(dnn_output.shape)),
        "interaction_type": "higher-order, implicit",
        "params": sum(p.numel() for p in dnn.parameters()),
        "used_in": "DeepFM, xDeepFM, AttentionDeepFM",
    },
    {
        "layer": "CIN",
        "input": "field_embeddings (B,F,D)",
        "output_shape": str(tuple(cin_output.shape)),
        "interaction_type": "higher-order, explicit (vector-wise)",
        "params": sum(p.numel() for p in cin.parameters()),
        "used_in": "xDeepFM",
    },
    {
        "layer": "MultiHeadSelfAttention",
        "input": "field_embeddings (B,F,D)",
        "output_shape": str(tuple(attn_output.shape)),
        "interaction_type": "adaptive pairwise weighting",
        "params": sum(p.numel() for p in attn.parameters()),
        "used_in": "AttentionDeepFM",
    },
]

pd.DataFrame(rows)

Unnamed: 0,layer,input,output_shape,interaction_type,params,used_in
0,FMInteraction,"field_embeddings (B,F,D)","(8, 1)","2nd-order, explicit",0,"DeepFM, AttentionDeepFM"
1,DNN,"flat_embeddings (B,total_dim)","(8, 32)","higher-order, implicit",18080,"DeepFM, xDeepFM, AttentionDeepFM"
2,CIN,"field_embeddings (B,F,D)","(8, 192)","higher-order, explicit (vector-wise)",54016,xDeepFM
3,MultiHeadSelfAttention,"field_embeddings (B,F,D)","(8, 6, 16)",adaptive pairwise weighting,8672,AttentionDeepFM


## 6. How Models Compose These Layers

```
FeatureEmbedding(batch)
    |
    +-- first_order (B,1)
    +-- field_embeddings (B,F,D)
    +-- flat_embeddings (B,total_dim)


DeepFM:                                  xDeepFM:                               AttentionDeepFM:
  logit = first_order                      logit = first_order                    logit = first_order
        + FM(field_emb)                          + Linear(CIN(field_emb))               + FM(field_emb)
        + Linear(DNN(flat_emb))                  + Linear(DNN(flat_emb))                + Linear(DNN(cat(
                                                                                            Attn(field_emb).flatten(),
                                                                                            flat_emb)))


Key differences:
  - DeepFM:          FM (2nd order explicit) + DNN (higher order implicit)
  - xDeepFM:         CIN (higher order explicit, vector-wise) + DNN (higher order implicit)
  - AttentionDeepFM: FM + Attention-refined embeddings fed to DNN (learned interaction importance)
```

All three share the same `FeatureEmbedding` — the only difference is how they process its outputs.