In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
import numpy as np

# Two-tower model

Labels Used in the Two-Tower Approach
Retrieval Labels:

Type: Binary (0 or 1)
Description: These labels indicate whether a user interacted with an ad. A value of 1 means the user interacted (e.g., clicked or purchased the ad), while 0 indicates no interaction. These labels are typically used during the retrieval stage of training, helping the model learn to rank ads based on their relevance to users.
Ranking Labels:

Type: Continuous (often in the range of [0, 1])
Description: These labels provide a more nuanced measure of the likelihood of interaction. They can represent probabilities or expected values, such as click-through rates (CTR). The ranking labels help fine-tune the model in the ranking stage, allowing it to predict how likely a user is to interact with a given ad more accurately.
Advantages of the DeepFM Model
Combination of Linear and Deep Learning Models:

DeepFM effectively combines Factorization Machines (FM) for capturing feature interactions and Deep Neural Networks (DNN) for modeling complex patterns. This allows it to leverage both the explicit feature interactions (captured by FM) and the deeper latent patterns (captured by DNN) in the data.
Handling Sparse Data:

The FM component is particularly useful in scenarios with high-dimensional sparse data, as it efficiently captures interactions between features without needing extensive amounts of data.
Scalability:

DeepFM can handle large-scale datasets efficiently. The architecture is designed to scale well with increasing data size, making it suitable for real-world applications.
Flexibility:

The model can easily incorporate different types of features, including categorical and numerical, by embedding them appropriately. This flexibility allows it to be used across various domains, such as click prediction and recommendation systems.
Interpretability:

The FM component provides some level of interpretability, as it directly models feature interactions. This can be helpful for understanding which features contribute to predictions.
End-to-End Learning:

DeepFM supports end-to-end learning, meaning you can train it directly from raw feature inputs to predictions without needing to separate feature engineering into distinct steps.

In [2]:
class TwoTowerModel(nn.Module):
    def __init__(self, user_feature_dim, ad_feature_dim, embedding_dim):
        super(TwoTowerModel, self).__init__()
        # User embedding tower
        self.user_embedding = nn.Sequential(
            nn.Linear(user_feature_dim, 128), nn.ReLU(), nn.Linear(128, embedding_dim)
        )
        # Ad embedding tower
        self.ad_embedding = nn.Sequential(
            nn.Linear(ad_feature_dim, 128), nn.ReLU(), nn.Linear(128, embedding_dim)
        )

    def forward(self, user_features, ad_features):
        user_emb = self.user_embedding(user_features)
        ad_emb = self.ad_embedding(ad_features)

        # Normalize for cosine similarity
        user_emb = F.normalize(user_emb, p=2, dim=1)
        ad_emb = F.normalize(ad_emb, p=2, dim=1)

        # Calculate similarity scores
        similarity = torch.sum(user_emb * ad_emb, dim=1)

        return similarity, user_emb, ad_emb

In [3]:
def train_retrieval_stage(model, data_loader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    criterion = nn.BCEWithLogitsLoss()  # Move criterion inside the function

    for batch_idx, (user_features, ad_features, retrieval_labels, _) in enumerate(
        data_loader
    ):
        user_features = user_features.to(device)
        ad_features = ad_features.to(device)
        retrieval_labels = retrieval_labels.to(
            device
        ).squeeze()  # Remove extra dimensions

        # Get similarity scores
        similarity, _, _ = model(user_features, ad_features)

        # Calculate loss
        loss = criterion(similarity, retrieval_labels.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(
                f"Retrieval Stage - Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}"
            )

    return total_loss / len(data_loader)

# Deep Factorization Machine Model

In [None]:
# Factorization Machine Layer
class FactorizationMachineLayer(nn.Module):
    def __init__(self, feature_dim, embedding_dim):
        super(FactorizationMachineLayer, self).__init__()
        self.embeddings = nn.Embedding(feature_dim, embedding_dim)

    def forward(self, x):
        embed_x = self.embeddings(x)
        sum_square = torch.sum(embed_x, dim=1) ** 2
        square_sum = torch.sum(embed_x**2, dim=1)
        second_order = 0.5 * (sum_square - square_sum).unsqueeze(
            1
        )  # Ensure correct dimension
        return second_order


class DeepFM(nn.Module):
    def __init__(self, feature_dim, embedding_dim, hidden_dims):
        super(DeepFM, self).__init__()
        self.fm_layer = FactorizationMachineLayer(feature_dim, embedding_dim)
        self.embeddings = nn.Embedding(feature_dim, embedding_dim)

        # Deep neural network layers
        dnn_layers = []
        input_dim = feature_dim * embedding_dim  # Number of inputs to DNN
        for hidden_dim in hidden_dims:
            dnn_layers.append(nn.Linear(input_dim, hidden_dim))
            dnn_layers.append(nn.ReLU())
            input_dim = hidden_dim
        self.dnn = nn.Sequential(*dnn_layers)

        # Output layer
        self.output = nn.Linear(input_dim + 1, 1)  # +1 for FM output
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        fm_out = self.fm_layer(x.long())
        embed_x = self.embeddings(x).view(x.size(0), -1)  # Flatten embeddings
        dnn_out = self.dnn(embed_x)
        combined = torch.cat([fm_out, dnn_out], dim=1)
        output = self.output(combined)
        return output  # Sigmoid for binary classification

In [None]:
def train_ranking_stage(model, data_loader, optimizer, criterion, device, epoch):
    model.train()
    total_loss = 0

    for batch_idx, (features, labels) in enumerate(data_loader):
        # Convert features to float and normalize
        features = features.to(device).float()
        # Ensure features are in reasonable range
        features = (features - features.mean(dim=0, keepdim=True)) / (
            features.std(dim=0, keepdim=True) + 1e-7
        )

        labels = labels.to(device).float()

        # Forward pass
        outputs = model(features)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping to prevent explosions
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(
                f"Ranking Stage - Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}"
            )

    return total_loss / len(data_loader)


# Modified data preparation for ranking stage
def prepare_ranking_data(user_features, ad_features, ranking_labels, batch_size):
    # Combine features
    combined_features = torch.cat((user_features, ad_features), dim=1)

    # Normalize features
    combined_features = (
        combined_features - combined_features.mean(dim=0, keepdim=True)
    ) / (combined_features.std(dim=0, keepdim=True) + 1e-7)

    # Create dataset and dataloader
    ranking_dataset = TensorDataset(combined_features, ranking_labels)
    ranking_loader = DataLoader(ranking_dataset, batch_size=batch_size, shuffle=True)

    return ranking_loader

# Helper functions

In [6]:
def retrieve_top_k(model, data_loader, device, k=100):
    """
    Retrieve top k candidates from retrieval model.
    Returns both retrieval and ranking labels for the retrieved items.
    """
    model.eval()
    all_user_features = []
    all_ad_features = []
    all_similarities = []
    all_retrieval_labels = []
    all_ranking_labels = []

    with torch.no_grad():
        for user_features, ad_features, retrieval_labels, ranking_labels in data_loader:
            user_features = user_features.to(device)
            ad_features = ad_features.to(device)

            similarity, _, _ = model(user_features, ad_features)

            # Move everything to CPU for concatenation
            all_user_features.append(user_features.cpu())
            all_ad_features.append(ad_features.cpu())
            all_similarities.append(similarity.cpu())
            all_retrieval_labels.append(retrieval_labels)
            all_ranking_labels.append(ranking_labels)

    # Concatenate all batches on CPU
    all_user_features = torch.cat(all_user_features, dim=0)
    all_ad_features = torch.cat(all_ad_features, dim=0)
    all_similarities = torch.cat(all_similarities, dim=0)
    all_retrieval_labels = torch.cat(all_retrieval_labels, dim=0)
    all_ranking_labels = torch.cat(all_ranking_labels, dim=0)

    # Get top k indices on CPU
    top_k_indices = torch.topk(
        all_similarities, k=min(k, len(all_similarities))
    ).indices

    # Gather top k items
    retrieved_user_features = all_user_features[top_k_indices]
    retrieved_ad_features = all_ad_features[top_k_indices]
    retrieved_ranking_labels = all_ranking_labels[top_k_indices]

    # Move results back to the specified device
    retrieved_user_features = retrieved_user_features.to(device)
    retrieved_ad_features = retrieved_ad_features.to(device)
    retrieved_ranking_labels = retrieved_ranking_labels.to(device)

    return retrieved_user_features, retrieved_ad_features, retrieved_ranking_labels

# Data setup

User Features
For each user, the data might include:

Age: Integer value representing age (e.g., 25, 40).
Gender: Categorical (encoded as integer, e.g., 0 for male, 1 for female).
Country: Categorical, converted to numerical IDs (e.g., 1 for USA, 2 for Canada).
Following/Followers: Number of people the user follows and their followers (e.g., 150 follows, 300 followers).

Ad Features
Each ad might include:

Ad Category: Categorical feature describing the ad type (e.g., 0 for electronics, 1 for clothing).
Price: Numeric feature for ad item price.
Brand Name: Numeric ID representing brand (e.g., 12 for Nike).
Conversion Rate: Historical interaction rate, a float (e.g., 0.05 representing 5% conversion rate).
Likes and Shares: Historical likes and shares count for the ad.

Labels
Binary label indicating if a user interacted with the ad:

1 for positive interaction (e.g., click or purchase).
0 for no interaction.

In [7]:
# Hyperparameters
batch_size = 64
num_samples = 10000
user_feature_dim = 5
ad_feature_dim = 5
embedding_dim = 64
num_epochs = 2
k = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate synthetic data with two types of labels
user_features = torch.randn(num_samples, user_feature_dim)
ad_features = torch.randn(num_samples, ad_feature_dim)

# Retrieval labels: Binary labels indicating if the ad is relevant (1) or not (0)
retrieval_labels = torch.randint(0, 2, (num_samples, 1)).float()

# Ranking labels: Continuous values indicating the likelihood of interaction (e.g., CTR)
# These could be more fine-grained than retrieval labels
ranking_labels = torch.rand(num_samples, 1)  # Values between 0 and 1

# Create dataset and dataloader with both types of labels
dataset = TensorDataset(user_features, ad_features, retrieval_labels, ranking_labels)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [8]:
# Initialize models
two_tower_model = TwoTowerModel(user_feature_dim, ad_feature_dim, embedding_dim).to(
    device
)
deepfm_model = DeepFM(
    feature_dim=user_feature_dim + ad_feature_dim,
    embedding_dim=16,
    hidden_dims=[128, 64, 32],
).to(device)

# Optimizers and losses
retrieval_optimizer = torch.optim.Adam(two_tower_model.parameters(), lr=0.001)
ranking_optimizer = torch.optim.Adam(deepfm_model.parameters(), lr=0.001)
retrieval_criterion = nn.BCEWithLogitsLoss  # Binary cross entropy for retrieval
ranking_criterion = nn.BCEWithLogitsLoss()  # Mean squared error for ranking

In [9]:
# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Train retrieval model - note we're not passing criterion anymore
    retrieval_loss = train_retrieval_stage(
        two_tower_model, data_loader, retrieval_optimizer, device, epoch
    )

    # Retrieve top k items and get corresponding ranking labels
    retrieved_user_features, retrieved_ad_features, retrieved_ranking_labels = (
        retrieve_top_k(two_tower_model, data_loader, device, k)
    )

    # Prepare data for ranking stage
    if retrieved_user_features is not None:
        ranking_features = torch.cat(
            (retrieved_user_features, retrieved_ad_features), dim=1
        )
        ranking_dataset = TensorDataset(ranking_features, retrieved_ranking_labels)
        ranking_loader = DataLoader(
            ranking_dataset, batch_size=batch_size, shuffle=True
        )

        # Train ranking model with ranking labels
        ranking_loss = train_ranking_stage(
            deepfm_model,
            ranking_loader,
            ranking_optimizer,
            ranking_criterion,
            device,
            epoch,
        )

        print(
            f"Epoch {epoch+1} - Retrieval Loss: {retrieval_loss:.4f}, Ranking Loss: {ranking_loss:.4f}"
        )


Epoch 1/2
Retrieval Stage - Epoch 0, Batch 0, Loss: 0.6784
Retrieval Stage - Epoch 0, Batch 100, Loss: 0.6775


RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long