In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClickPredictor(nn.Module):
    """
    Click Predictor Module: Predicts the probability of a user clicking on a candidate news article.

    The click score is the dot product between the user representation and the
    candidate news representation, taken over the last (embedding) dimension.
    A sigmoid then maps the score to a probability. The module has no
    learnable parameters.
    """

    def __init__(self):
        super().__init__()
        # No additional parameters are needed; the click probability
        # is computed using the dot product followed by a sigmoid activation.

    def forward(self, user_repr, candidate_news_repr, *, return_logits=False):
        """
        Forward pass for click prediction.

        Args:
            user_repr (Tensor): User representation tensor of shape (batch_size, embedding_dim).
            candidate_news_repr (Tensor): Candidate news representation tensor of shape (batch_size, embedding_dim).
            return_logits (bool): If True, return the raw dot-product scores instead of
                probabilities. Use this with ``nn.BCEWithLogitsLoss`` during training;
                it is numerically more stable than ``torch.sigmoid`` followed by
                ``nn.BCELoss``. Defaults to False (returns probabilities).

        Returns:
            Tensor: Click probabilities of shape (batch_size, 1) with values in [0, 1],
            or the raw scores of the same shape when ``return_logits`` is True.

        Raises:
            ValueError: If the last (embedding) dimensions of the two inputs differ.
        """
        # Validate with an explicit exception rather than ``assert``: asserts are
        # stripped when Python runs with -O, which would silently skip this check.
        user_dim = user_repr.size(-1)
        news_dim = candidate_news_repr.size(-1)
        if user_dim != news_dim:
            raise ValueError(
                "Embedding dimensions must match! "
                f"Got {user_dim} (user) vs {news_dim} (candidate news)."
            )

        # Step 1: dot product per user-news pair, summed over the embedding
        # dimension -> one similarity score per pair, shape (batch_size, 1).
        click_score = (user_repr * candidate_news_repr).sum(dim=-1, keepdim=True)

        if return_logits:
            return click_score

        # Step 2: squash the scores into click probabilities in [0, 1].
        return torch.sigmoid(click_score)

# Example usage
if __name__ == "__main__":
    # Fix the RNG so the printed example is reproducible across runs.
    torch.manual_seed(0)

    # Simulate example inputs
    batch_size = 5       # Number of user-news pairs in a batch
    embedding_dim = 300  # Embedding dimension matching the NewsEncoder and UserEncoder outputs

    # Random user and news embeddings, shape (batch_size, embedding_dim).
    # Entries are zero-mean Gaussian scaled by embedding_dim ** -0.25, so the
    # dot product of two such vectors has unit variance and the probabilities
    # spread across (0, 1). Uniform [0, 1) inputs (torch.rand) would give dot
    # products near embedding_dim / 4 = 75, saturating the sigmoid so every
    # probability would print as 1.0.
    scale = embedding_dim ** -0.25
    user_repr = torch.randn(batch_size, embedding_dim) * scale
    candidate_news_repr = torch.randn(batch_size, embedding_dim) * scale

    # Instantiate the Click Predictor
    click_predictor = ClickPredictor()

    # Forward pass: Compute click probabilities
    click_probs = click_predictor(user_repr, candidate_news_repr)

    # Print the results
    print("User Representations Shape:", user_repr.shape)
    print("Candidate News Representations Shape:", candidate_news_repr.shape)
    print("Click Probabilities Shape:", click_probs.shape)
    print("Click Probabilities:\n", click_probs)


# News Recommendation System Documentation

This notebook explains the key components of a News Recommendation System, focusing on user representation, candidate news representation, and click prediction.

---

## 1. User Representation (`user_repr`)

### **Source**  
Generated by the `UserEncoder` module.

### **Computation**  
- The `UserEncoder` takes as input the embeddings of news articles that the user has interacted with.
- It uses the following mechanisms to produce a user representation:
  - **Multi-head additive attention** (`MultiHeadAdditiveAttention`): Aggregates information from multiple interaction embeddings.
  - **Additive attention** (`UserAdditiveAttention`): Focuses on the most relevant interactions to refine the user representation.

### **Shape**  
- Output tensor shape: `(batch_size, embedding_dim)`

---

## 2. Candidate News Representation (`candidate_news_repr`)

### **Source**  
Generated by the `NewsEncoder` module.

### **Computation**  
- The `NewsEncoder` takes the tokenized titles of candidate news articles as input.
- It uses the following components:
  - **Word embeddings**: Converts tokenized words into dense vector representations.
  - **Self-attention** (`MultiHeadSelfAttention`): Captures contextual relationships between words.
  - **Additive attention** (`AdditiveWordAttention`): Aggregates word-level information into a single news representation.

### **Shape**  
- Output tensor shape: `(batch_size, embedding_dim)`

---

## 3. Click Predictor (`ClickPredictor`)

### **Function**  
Predicts the probability that a user will click on a candidate news article.

### **Inputs**  
- `user_repr`: User representation tensor from `UserEncoder`.
- `candidate_news_repr`: News representation tensor from `NewsEncoder`.

### **Computation**  
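1. Computes the **dot product** between `user_repr` and `candidate_news_repr` for each user–news pair in the batch, summing over the embedding dimension. This yields one score per pair, measuring how well the user's interests match the news content.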
2. Applies a **sigmoid activation function** to map the scores to probabilities in the range `[0, 1]`.

### **Output**  
- `click_prob`: A tensor of shape `(batch_size, 1)` containing click probabilities.

