<a href="https://colab.research.google.com/github/AkramBenamar/DomainAwareEmbedder/blob/master/DomainsAwareEmbedder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#DomainsAwareEmbeder

##Model

###Embedder

####PositionalEncoding

In [1]:
import unittest
import torch
import math
from torch import nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 512) -> None:
        super().__init__()
        self.pe = self._generate_encoding(d_model, max_len)

    def _generate_encoding(self, d_model: int, max_len: int) -> torch.Tensor:
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, seq_len: int) -> torch.Tensor:
        return self.pe[:, :seq_len]

class TestPositionalEncoding(unittest.TestCase):

    def test_shape(self):
        """Test output shape"""
        d_model = 16
        max_len = 100
        pe = PositionalEncoding(d_model, max_len)
        output = pe(50)
        self.assertEqual(output.shape, (1, 50, d_model))

    def test_values_repeatability(self):
        """Test same output for same inputs"""
        d_model = 32
        max_len = 60
        pe = PositionalEncoding(d_model, max_len)
        output1 = pe(10)
        output2 = pe(10)
        self.assertTrue(torch.allclose(output1, output2, atol=1e-6))

    def test_no_nan(self):
        """Test qnot NaN"""
        pe = PositionalEncoding(64, 128)
        output = pe(64)
        self.assertFalse(torch.isnan(output).any())

    def test_known_value(self):
        """Test values"""
        d_model = 4
        max_len = 1
        pe = PositionalEncoding(d_model, max_len)
        output = pe(1)[0, 0]  # shape: (d_model,)
        expected = torch.tensor([
            math.sin(0 / (10000 ** (0 / d_model))),  # sin(0) = 0
            math.cos(0 / (10000 ** (0 / d_model))),  # cos(0) = 1
            math.sin(0 / (10000 ** (2 / d_model))),  # sin(0) = 0
            math.cos(0 / (10000 ** (2 / d_model)))   # cos(0) = 1
        ])
        self.assertTrue(torch.allclose(output, expected, atol=1e-5))


unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestPositionalEncoding))


....
----------------------------------------------------------------------
Ran 4 tests in 0.192s

OK


<unittest.runner.TextTestResult run=4 errors=0 failures=0>

###Encoder

##Data

##UnitTests

##Repo

In [2]:
!ssh-keygen -t rsa -b 4096 -C "ha_ben_amar@esi.dz" -f /root/.ssh/id_rsa -N ""
!cat /root/.ssh/id_rsa.pub


Generating public/private rsa key pair.
Created directory '/root/.ssh'.
Your identification has been saved in /root/.ssh/id_rsa
Your public key has been saved in /root/.ssh/id_rsa.pub
The key fingerprint is:
SHA256:+U2vCM3WFYmB4gen3mumWvw5j0xUoRBCd91AjlaK/CQ ha_ben_amar@esi.dz
The key's randomart image is:
+---[RSA 4096]----+
|     .o +..+*o   |
|       +o+o*.+.. |
|       .E=* + o  |
|        o*..   . |
|       .Soo . .  |
|       ..=.+ o   |
|        + *.o .  |
|       . *=+ .   |
|      ...+*oo    |
+----[SHA256]-----+
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDTlDiQa0g05Q7N0M+D4od/UXw6c5pIanq3ZHURFlblS4+47AQmiptCkWxzVNllfKtfqg8BaF33Na7OzeJ8M7oidpeSCoST/DnNC8CNzOm4MSQzcvzO47vmgNCu2UydfH3Rpt6T02fkWL6kETRqEmHAQKqdKaXJM5oaPoR0OoR3pSEcROuxZHHy6gGnRRf6QST93Ci3+vXYNsW4vHVTz79gkQncfwgWrislIUO+V2sm3IpJMdqpHXL6ZbTcLAAXCSZugn2mHDpLQRZ4dNtdkNOgKB246aOYByiWWFmkrSfQKtc6y6qgKbvDfrGRmDm6k+7dt+DttFd/YPSuA0/SyJtok1FNqZxw+FpzbyspGLMRhuWKcjpET7gBFwkmN0YsdrtS37ZZABRlBzcX0V/ebk5voyJBPwwjikqgnChGu8ygWl494vEF

In [4]:
!git config --global user.email "ha_ben_amar@esi.dz"
!git config --global user.name "AkramBenamar"
!git clone git@github.com:AkramBenamar/DomainAwareEmbedder.git


Cloning into 'DomainAwareEmbedder'...
Host key verification failed.
fatal: Could not read from remote repository.

Please make sure you have the correct access rights
and the repository exists.
