In [None]:
# NLP Quest

## Summary
1. Transformer encoder
2. Data normalization
3. MLP (classification)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

class TwitterSpamDetector(nn.Module):
    """Two-class classifier over a token sequence plus numeric account features.

    The token ids are embedded, passed through a Transformer encoder and
    mean-pooled over the sequence axis. The pooled vector is concatenated
    with the ``following``, ``followers``, ``actions`` and ``is_retweet``
    features and fed to an MLP that outputs two raw logits (no softmax).

    Args:
        num_tokens: Vocabulary size of the embedding table.
        emb_size: Embedding / model dimension (PyTorch requires it to be
            divisible by ``num_heads``).
        num_heads: Number of attention heads per encoder layer.
        hidden_size: Feed-forward width inside each encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout_prob: Dropout probability inside the encoder layers.
        num_following: Number of columns in the ``following`` feature.
        num_followers: Number of columns in the ``followers`` feature.
        num_actions: Number of columns in the ``actions`` feature.
    """

    def __init__(self, num_tokens, emb_size, num_heads, hidden_size, num_layers,
                 dropout_prob, num_following, num_followers, num_actions):
        super().__init__()

        # Embedding layer: (batch, seq) token ids -> (batch, seq, emb_size).
        self.embedding = nn.Embedding(num_tokens, emb_size)

        # Transformer encoder. batch_first=True is required: forward() passes
        # (batch, seq, emb) tensors and pools over dim=1. With the PyTorch
        # default (batch_first=False) dim 0 would be read as the sequence
        # axis, and attention would mix different samples of the batch.
        # NOTE: TransformerEncoder deep-copies this template layer, so its own
        # parameters are never used in forward(). It is kept as an attribute
        # so that existing state_dict checkpoints still load.
        self.encoder_layer = nn.TransformerEncoderLayer(
            emb_size, num_heads, hidden_size, dropout_prob, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers)

        # MLP head: pooled text vector + feature columns (+1 for is_retweet) -> 2 logits.
        self.mlp = nn.Sequential(
            nn.Linear(emb_size + num_following + num_followers + num_actions + 1, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 2)
        )

        self.num_following = num_following
        self.num_followers = num_followers
        self.num_actions = num_actions

    @staticmethod
    def _as_columns(t, dtype):
        """Cast ``t`` to ``dtype`` and shape it as (batch, width); 1-D -> (batch, 1)."""
        t = t.to(dtype)
        return t.unsqueeze(1) if t.dim() == 1 else t.flatten(start_dim=1)

    def forward(self, x, following, followers, actions, is_retweet, padding_mask=None):
        """Return logits of shape (batch, 2).

        Args:
            x: Token ids, shape (batch, seq).
            following, followers, actions: Numeric features, shape (batch,)
                for a single column or (batch, k) for k columns. The widths
                must match the ``num_*`` values passed to the constructor.
            is_retweet: One flag/value per sample, shape (batch,).
            padding_mask: Optional bool tensor of shape (batch, seq), True at
                padding positions. Padded tokens are excluded from attention
                and from the mean pooling. If None, every position is used.
        """
        # Encode input sequence with Transformer: (batch, seq, emb_size).
        x = self.embedding(x)
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)

        # Mean-pool over the sequence axis, skipping padding when masked.
        if padding_mask is None:
            pooled = x.mean(dim=1)
        else:
            pad = padding_mask.unsqueeze(-1)  # (batch, seq, 1), True = padding
            # masked_fill (rather than multiplying by 0) also clears NaNs that
            # an all-padding row can produce in attention.
            summed = x.masked_fill(pad, 0.0).sum(dim=1)
            counts = (~pad).sum(dim=1).clamp(min=1).to(x.dtype)
            pooled = summed / counts

        # Concatenate the raw feature values. No normalization is applied
        # here, so large counts should be scaled by the caller.
        features = [
            self._as_columns(t, pooled.dtype)
            for t in (following, followers, actions, is_retweet)
        ]
        input_vec = torch.cat([pooled, *features], dim=1)

        # Pass through MLP layers
        return self.mlp(input_vec)
