<a href="https://colab.research.google.com/github/AkramBenamar/DomainAwareEmbedder/blob/master/DomainsAwareEmbedder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#DomainsAwareEmbeder

##Model

###Embedder

####PositionalEncoding

In [1]:
import unittest
import torch
import math
from torch import nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 512) -> None:
        super().__init__()
        self.pe = self._generate_encoding(d_model, max_len)

    def _generate_encoding(self, d_model: int, max_len: int) -> torch.Tensor:
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, seq_len: int) -> torch.Tensor:
        return self.pe[:, :seq_len]

class TestPositionalEncoding(unittest.TestCase):

    def test_shape(self):
        """Test output shape"""
        d_model = 16
        max_len = 100
        pe = PositionalEncoding(d_model, max_len)
        output = pe(50)
        self.assertEqual(output.shape, (1, 50, d_model))

    def test_values_repeatability(self):
        """Test same output for same inputs"""
        d_model = 32
        max_len = 60
        pe = PositionalEncoding(d_model, max_len)
        output1 = pe(10)
        output2 = pe(10)
        self.assertTrue(torch.allclose(output1, output2, atol=1e-6))

    def test_no_nan(self):
        """Test qnot NaN"""
        pe = PositionalEncoding(64, 128)
        output = pe(64)
        self.assertFalse(torch.isnan(output).any())

    def test_known_value(self):
        """Test values"""
        d_model = 4
        max_len = 1
        pe = PositionalEncoding(d_model, max_len)
        output = pe(1)[0, 0]  # shape: (d_model,)
        expected = torch.tensor([
            math.sin(0 / (10000 ** (0 / d_model))),  # sin(0) = 0
            math.cos(0 / (10000 ** (0 / d_model))),  # cos(0) = 1
            math.sin(0 / (10000 ** (2 / d_model))),  # sin(0) = 0
            math.cos(0 / (10000 ** (2 / d_model)))   # cos(0) = 1
        ])
        self.assertTrue(torch.allclose(output, expected, atol=1e-5))


unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestPositionalEncoding))


....
----------------------------------------------------------------------
Ran 4 tests in 0.192s

OK


<unittest.runner.TextTestResult run=4 errors=0 failures=0>

####DomainEmbedder

In [6]:
import unittest
import torch
class DomainAwareEmbedder(nn.Module):
    def __init__(
        self,
        num_domains: int,
        d_model: int,
        d_embed: int,
        n_heads: int = 4,
        max_seq_len: int = 512
    ) -> None:
        super().__init__()
        self.d_model = d_model

        self.domain_proj_layer = nn.Linear(num_domains, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len)
        self.query_proj = nn.Linear(d_embed, d_model)
        self.attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(d_model)

    def project_domains(self, m: torch.Tensor) -> torch.Tensor:
        """Project modality descriptors to embedding space."""
        return self.domain_proj_layer(m.float())

    def combine_domain_and_position(self, domain_proj: torch.Tensor, seq_len: int, device=None) -> torch.Tensor:
        """Add positional encoding to projected modality embeddings."""
        pos_enc = self.pos_encoder(seq_len).to(device or domain_proj.device)
        return domain_proj + pos_enc

    def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_embed)
            m: (batch_size, seq_len, num_modalities)
        Returns:
            Tensor of shape (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape

        domain_proj = self.project_domains(m)
        DC = self.combine_domain_and_position(domain_proj, seq_len)

        query = self.query_proj(x)
        attended, _ = self.attention(query, DC, DC)

        return self.layer_norm(attended)


import unittest
import torch

class TestDomainAwareEmbedder(unittest.TestCase):

    def setUp(self):
        self.batch_size = 2
        self.seq_len = 10
        self.d_embed = 32
        self.d_model = 64
        self.num_domains = 5

        self.model = DomainAwareEmbedder(
            num_domains=self.num_domains,
            d_model=self.d_model,
            d_embed=self.d_embed,
            n_heads=4,
            max_seq_len=100
        )

        self.x = torch.randn(self.batch_size, self.seq_len, self.d_embed)
        self.m = torch.randn(self.batch_size, self.seq_len, self.num_domains)

    def test_project_domains(self):
        projected = self.model.project_domains(self.m)
        self.assertEqual(projected.shape, (self.batch_size, self.seq_len, self.d_model))
        self.assertFalse(torch.isnan(projected).any())

    def test_combine_domain_and_position(self):
        domain_proj = self.model.project_domains(self.m)
        combined = self.model.combine_domain_and_position(domain_proj, self.seq_len)
        self.assertEqual(combined.shape, (self.batch_size, self.seq_len, self.d_model))
        self.assertFalse(torch.isnan(combined).any())

    def test_forward_output_shape(self):
        output = self.model(self.x, self.m)
        self.assertEqual(output.shape, (self.batch_size, self.seq_len, self.d_model))

    def test_forward_repeatability(self):
        output1 = self.model(self.x, self.m)
        output2 = self.model(self.x, self.m)
        self.assertTrue(torch.allclose(output1, output2, atol=1e-5))

    def test_forward_no_nan(self):
        output = self.model(self.x, self.m)
        self.assertFalse(torch.isnan(output).any())

def test_combine_domain_and_position_adds_encoding(self):
    domain_proj = self.model.project_domains(self.m)
    combined = self.model.combine_domain_and_position(domain_proj, self.seq_len)
    pos_enc = self.model.pos_encoder(self.seq_len).to(domain_proj.device)
    diff = combined - domain_proj
    self.assertTrue(torch.allclose(diff, pos_enc.expand_as(domain_proj), atol=1e-6))

unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestDomainAwareEmbedder))

.....
----------------------------------------------------------------------
Ran 5 tests in 0.022s

OK


<unittest.runner.TextTestResult run=5 errors=0 failures=0>

###Encoder

####TransformerEncoder

In [8]:
import torch
import torch.nn as nn
import unittest

class TransformerEncoder(nn.Module):
    """Transformer-based encoder that replaces RNN/GRU for sequence modeling."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.1
    ) -> None:
        super().__init__()

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
        Returns:
            Output tensor of shape (batch_size, hidden_dim), last token representation
        """
        x = self.input_proj(x)
        out = self.transformer_encoder(x)
        return out[:, -1, :]




class TestTransformerEncoder(unittest.TestCase):

    def setUp(self):
        self.batch_size = 4
        self.seq_len = 10
        self.input_dim = 32
        self.hidden_dim = 64
        self.encoder = TransformerEncoder(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            n_heads=4,
            n_layers=2
        )

    def test_output_shape(self):
        x = torch.randn(self.batch_size, self.seq_len, self.input_dim)
        out = self.encoder(x)
        self.assertEqual(out.shape, (self.batch_size, self.hidden_dim))

    def test_projection_works(self):
        x = torch.randn(self.batch_size, self.seq_len, self.input_dim)
        projected = self.encoder.input_proj(x)
        self.assertEqual(projected.shape, (self.batch_size, self.seq_len, self.hidden_dim))

    def test_determinism(self):
        torch.manual_seed(42)
        self.encoder.eval()
        x = torch.randn(self.batch_size, self.seq_len, self.input_dim)
        out1 = self.encoder(x)
        torch.manual_seed(42)
        out2 = self.encoder(x)
        self.assertTrue(torch.allclose(out1, out2, atol=1e-6))


unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestTransformerEncoder))

...
----------------------------------------------------------------------
Ran 3 tests in 0.042s

OK


<unittest.runner.TextTestResult run=3 errors=0 failures=0>

##Data