<a href="https://colab.research.google.com/github/AkramBenamar/DomainAwareEmbedder/blob/master/DomainsAwareEmbedder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DomainsAwareEmbedder

## Model

### Embedder

#### PositionalEncoding

In [3]:
import math
import unittest
from typing import Iterator, Optional

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (sine on even channels, cosine on odd ones).

    The table is precomputed once for ``max_len`` positions and registered as a
    non-persistent buffer. It therefore moves with ``module.to(device)`` like any
    other module state, but it is not written to ``state_dict``, so checkpoints
    saved before this change still load.
    """

    def __init__(self, d_model: int, max_len: int = 512) -> None:
        """
        Args:
            d_model: Encoding width; odd values are supported.
            max_len: Number of positions to precompute.
        """
        super().__init__()
        self.register_buffer("pe", self._generate_encoding(d_model, max_len), persistent=False)

    def _generate_encoding(self, d_model: int, max_len: int) -> torch.Tensor:
        """Return the encoding table with shape ``(1, max_len, d_model)``."""
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Frequencies 1 / 10000^(2i / d_model), one per sine channel.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine channel than sine channels.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, seq_len: int) -> torch.Tensor:
        """Return the encodings of positions ``0 .. seq_len - 1``.

        Returns:
            Tensor of shape ``(1, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` is negative or larger than the precomputed ``max_len``.
        """
        max_len = self.pe.size(1)
        if not 0 <= seq_len <= max_len:
            raise ValueError(f"seq_len must be in [0, {max_len}], got {seq_len}")
        return self.pe[:, :seq_len]

class TestPositionalEncoding(unittest.TestCase):
    """Unit tests for the sinusoidal PositionalEncoding module."""

    def test_shape(self):
        """The output has shape (1, seq_len, d_model)."""
        width = 16
        encoder = PositionalEncoding(width, 100)
        self.assertEqual(encoder(50).shape, (1, 50, width))

    def test_values_repeatability(self):
        """Two calls with the same length return the same values."""
        encoder = PositionalEncoding(32, 60)
        first, second = encoder(10), encoder(10)
        self.assertTrue(torch.allclose(first, second, atol=1e-6))

    def test_no_nan(self):
        """The encoding contains no NaN values."""
        encoder = PositionalEncoding(64, 128)
        self.assertFalse(torch.isnan(encoder(64)).any())

    def test_known_value(self):
        """At position 0 every sine channel is 0 and every cosine channel is 1."""
        width = 4
        encoder = PositionalEncoding(width, 1)
        row = encoder(1)[0, 0]  # shape: (d_model,)
        # Channel order is sin, cos per frequency exponent 0, 2, ...
        expected = torch.tensor([
            fn(0 / (10000 ** (exponent / width)))
            for exponent in (0, 2)
            for fn in (math.sin, math.cos)
        ])
        self.assertTrue(torch.allclose(row, expected, atol=1e-5))


unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestPositionalEncoding))


....
----------------------------------------------------------------------
Ran 4 tests in 0.013s

OK


<unittest.runner.TextTestResult run=4 errors=0 failures=0>

#### DomainAwareEmbedder

In [8]:
import unittest
import torch

class DomainAwareEmbedder(nn.Module):
    """Build a per-token domain context by cross-attending from tokens to domains.

    Queries are the token features ``x`` projected to ``d_model``. Keys and
    values are the domain descriptors ``m``, projected to ``d_model``, plus
    sinusoidal positional encodings. The attention output is layer-normalised.
    """

    def __init__(
        self,
        num_domains: int,
        d_model: int,
        d_embed: int,
        n_heads: int = 4,
        max_seq_len: int = 512
    ) -> None:
        """
        Args:
            num_domains: Size of the last dimension of the domain descriptors ``m``.
            d_model: Width of the attention space and of the returned tensor.
            d_embed: Size of the last dimension of the token features ``x``.
            n_heads: Number of attention heads.
            max_seq_len: Longest sequence the positional encoding is built for.
        """
        super().__init__()
        self.d_model = d_model

        self.domain_proj_layer = nn.Linear(num_domains, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len)
        self.query_proj = nn.Linear(d_embed, d_model)
        self.attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(d_model)

    def project_domains(self, m: torch.Tensor) -> torch.Tensor:
        """Project domain descriptors ``(..., num_domains)`` to ``(..., d_model)``.

        ``m`` is first cast to the projection layer's dtype. Integer or one-hot
        descriptors therefore work, and the result matches the model precision
        (float32 by default, as before).
        """
        return self.domain_proj_layer(m.to(self.domain_proj_layer.weight.dtype))

    def combine_domain_and_position(self, domain_proj: torch.Tensor, seq_len: int, device=None) -> torch.Tensor:
        """Add positional encodings for the first ``seq_len`` positions.

        Args:
            domain_proj: Projected domain descriptors, ``(batch, seq_len, d_model)``.
            seq_len: Number of positions to encode.
            device: Device for the encoding; defaults to ``domain_proj.device``.

        Returns:
            ``domain_proj`` plus the ``(1, seq_len, d_model)`` encoding, broadcast over the batch.
        """
        # Match device and dtype so the sum never silently changes precision.
        pos_enc = self.pos_encoder(seq_len).to(device=device or domain_proj.device, dtype=domain_proj.dtype)
        return domain_proj + pos_enc

    def forward(
        self,
        x: torch.Tensor,
        m: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_embed) token features.
            m: (batch_size, seq_len, num_domains) domain descriptors.
            key_padding_mask: Optional (batch_size, seq_len) mask that is passed
                unchanged to ``nn.MultiheadAttention``. True marks positions that
                must not be attended to, such as padding. Omit it to attend over
                every position, as before.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        seq_len = x.shape[1]
        domain_context = self.combine_domain_and_position(self.project_domains(m), seq_len)

        query = self.query_proj(x)
        attended, _ = self.attention(query, domain_context, domain_context, key_padding_mask=key_padding_mask)

        return self.layer_norm(attended)



class TestDomainAwareEmbedder(unittest.TestCase):
    """Unit tests for DomainAwareEmbedder and its two helper steps."""

    def setUp(self):
        self.batch_size, self.seq_len = 2, 10
        self.d_embed, self.d_model = 32, 64
        self.num_domains = 5

        self.model = DomainAwareEmbedder(
            num_domains=self.num_domains,
            d_model=self.d_model,
            d_embed=self.d_embed,
            n_heads=4,
            max_seq_len=100,
        )

        self.x = torch.randn(self.batch_size, self.seq_len, self.d_embed)
        self.m = torch.randn(self.batch_size, self.seq_len, self.num_domains)

    def _expected_shape(self):
        """Shape every per-token output of the embedder should have."""
        return (self.batch_size, self.seq_len, self.d_model)

    def _project_and_combine(self):
        """Return (domain projection, projection + positions, positional encoding)."""
        projected = self.model.project_domains(self.m)
        combined = self.model.combine_domain_and_position(projected, self.seq_len)
        positions = self.model.pos_encoder(self.seq_len).to(projected.device)  # (1, seq_len, d_model)
        return projected, combined, positions

    def test_project_domains(self):
        projected = self.model.project_domains(self.m)
        self.assertEqual(projected.shape, self._expected_shape())
        self.assertFalse(torch.isnan(projected).any())

    def test_combine_domain_and_position(self):
        _, combined, _ = self._project_and_combine()
        self.assertEqual(combined.shape, self._expected_shape())
        self.assertFalse(torch.isnan(combined).any())

    def test_forward_output_shape(self):
        self.assertEqual(self.model(self.x, self.m).shape, self._expected_shape())

    def test_forward_repeatability(self):
        first = self.model(self.x, self.m)
        second = self.model(self.x, self.m)
        self.assertTrue(torch.allclose(first, second, atol=1e-5))

    def test_forward_no_nan(self):
        self.assertFalse(torch.isnan(self.model(self.x, self.m)).any())

    def test_combine_domain_and_position_adds_encoding(self):
        projected, combined, positions = self._project_and_combine()
        added = combined - projected
        self.assertTrue(torch.allclose(added, positions.expand_as(projected), atol=1e-6))

    def test_positional_encoding_applied_per_position(self):
        projected, combined, positions = self._project_and_combine()
        for b in range(self.batch_size):
            for i in range(self.seq_len):
                self.assertTrue(
                    torch.allclose(combined[b, i, :], projected[b, i, :] + positions[0, i, :], atol=1e-6),
                    msg=f"Incorrect positional encoding at batch {b}, position {i}",
                )


unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestDomainAwareEmbedder))

.......
----------------------------------------------------------------------
Ran 7 tests in 0.028s

OK


<unittest.runner.TextTestResult run=7 errors=0 failures=0>

### Encoder

#### TransformerEncoder

In [9]:
import torch
import torch.nn as nn
import unittest

class TransformerEncoder(nn.Module):
    """Transformer-based encoder that summarises a sequence by one token's output.

    Input features are linearly projected to ``hidden_dim`` and passed through a
    stack of batch-first ``nn.TransformerEncoderLayer`` blocks. The summary is the
    encoder output at the last position. When a padding mask is given, it is the
    output at the last non-padded position of each sequence instead.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.1
    ) -> None:
        """
        Args:
            input_dim: Size of the last dimension of the input features.
            hidden_dim: Model width (feed-forward width is ``4 * hidden_dim``).
            n_heads: Attention heads per layer.
            n_layers: Number of stacked encoder layers.
            dropout: Dropout probability inside each encoder layer.
        """
        super().__init__()

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers
        )

    def forward(self, x: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim).
            src_key_padding_mask: Optional (batch_size, seq_len) mask, True where a
                position is padding. It is forwarded to ``nn.TransformerEncoder``,
                and the summary token becomes the last non-padded position of
                each row (left or right padding both work).

        Returns:
            Output tensor of shape (batch_size, hidden_dim).
        """
        hidden = self.transformer_encoder(self.input_proj(x), src_key_padding_mask=src_key_padding_mask)
        if src_key_padding_mask is None:
            # Without a mask every position is real, so use the final one.
            return hidden[:, -1, :]

        padding = src_key_padding_mask.bool()
        positions = torch.arange(padding.size(1), device=padding.device).expand_as(padding)
        # Right-most non-padded index per row; rows that are all padding fall back to 0.
        last_real = positions.masked_fill(padding, -1).amax(dim=1).clamp(min=0)
        rows = torch.arange(hidden.size(0), device=hidden.device)
        return hidden[rows, last_real.to(hidden.device)]




class TestTransformerEncoder(unittest.TestCase):
    """Unit tests for the last-token TransformerEncoder."""

    def setUp(self):
        self.batch_size, self.seq_len = 4, 10
        self.input_dim, self.hidden_dim = 32, 64
        self.encoder = TransformerEncoder(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            n_heads=4,
            n_layers=2,
        )

    def _random_input(self):
        """Draw a (batch, seq_len, input_dim) tensor of standard-normal features."""
        return torch.randn(self.batch_size, self.seq_len, self.input_dim)

    def test_output_shape(self):
        pooled = self.encoder(self._random_input())
        self.assertEqual(pooled.shape, (self.batch_size, self.hidden_dim))

    def test_projection_works(self):
        projected = self.encoder.input_proj(self._random_input())
        self.assertEqual(projected.shape, (self.batch_size, self.seq_len, self.hidden_dim))

    def test_determinism(self):
        torch.manual_seed(42)
        self.encoder.eval()
        sample = self._random_input()
        first = self.encoder(sample)
        torch.manual_seed(42)
        second = self.encoder(sample)
        self.assertTrue(torch.allclose(first, second, atol=1e-6))


unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestTransformerEncoder))

...
----------------------------------------------------------------------
Ran 3 tests in 0.066s

OK


<unittest.runner.TextTestResult run=3 errors=0 failures=0>

#### DomainAwareTransformerEncoder

In [10]:
# torchinfo provides `summary`, which TestDomainAwareTransformerEncoder uses to print the layer table.
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [13]:
import torch
import torch.nn as nn
import unittest
import pandas as pd
from torchinfo import summary


class DomainAwareTransformerEncoder(nn.Module):
    """
    Sequence classifier that adds a domain-aware context to token embeddings
    before a TransformerEncoder.

    Pipeline:
    1. Token ids go through ``nn.Embedding`` and a linear projection to ``hidden_dim``.
    2. Optionally, a ``DomainAwareEmbedder`` context, computed from the
       embeddings and the domain descriptors, is added.
    3. ``TransformerEncoder`` reduces the sequence to one vector.
    4. A linear classifier produces the class scores.
    """

    def __init__(
        self,
        num_domains: int,
        input_dim: int,
        hidden_dim: int,
        domain_embed_dim: int,
        max_seq_len: int = 512,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.1,
        num_classes: int = 31,
        vocab_size: int = 119547,
        inject_domain_bias: bool = True
    ) -> None:
        """
        Args:
            num_domains: Size of the last dimension of the domain descriptors ``m``.
            input_dim: Token-embedding width.
            hidden_dim: Width used by the domain embedder, the encoder and the classifier input.
            domain_embed_dim: Not used by this module; kept so existing call sites keep working.
            max_seq_len: Longest sequence the domain embedder's positional encoding covers.
            n_heads: Attention heads for both the domain embedder and the encoder.
            n_layers: Number of encoder layers.
            dropout: Dropout probability inside the encoder layers.
            num_classes: Number of output scores per sequence.
            vocab_size: Rows of the token-embedding table. The default equals the
                vocabulary size printed for ``bert-base-multilingual-cased``.
            inject_domain_bias: If False, no domain embedder is created and the
                encoder only sees the projected token embeddings.
        """
        super().__init__()
        self.inject_domain_bias = inject_domain_bias
        if self.inject_domain_bias:
            self.domain_embedder = DomainAwareEmbedder(
                num_domains=num_domains,
                d_model=hidden_dim,
                d_embed=input_dim,
                n_heads=n_heads,
                max_seq_len=max_seq_len
            )

        self.encoder = TransformerEncoder(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout
        )
        self.embedding = nn.Embedding(vocab_size, input_dim)
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.LongTensor, m: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Token ids of shape (batch_size, seq_len).
            m: Domain descriptors of shape (batch_size, seq_len, num_domains).

        Returns:
            Unnormalised class scores (no softmax) of shape (batch_size, num_classes).
        """
        embedded = self.embedding(x)          # (B, S, input_dim)
        enriched = self.input_proj(embedded)  # (B, S, hidden_dim)
        if self.inject_domain_bias:
            # The domain context is added, not concatenated, so the width stays hidden_dim.
            enriched = enriched + self.domain_embedder(embedded, m)
        # NOTE(review): no padding mask reaches the attention layers here, so padded
        # positions take part in attention and in the encoder's summary token.
        return self.classifier(self.encoder(enriched))







class TestDomainAwareTransformerEncoder(unittest.TestCase):
    """Smoke tests for the full DomainAwareTransformerEncoder classifier."""

    def setUp(self):
        self.batch_size, self.seq_len = 2, 10
        self.input_dim, self.hidden_dim = 32, 64
        self.domain_embed_dim = 32
        self.num_domains = 5
        self.num_classes = 31

        self.model = DomainAwareTransformerEncoder(
            num_domains=self.num_domains,
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            domain_embed_dim=self.domain_embed_dim,
            max_seq_len=50,
        )

        # Token ids are drawn below the model's default vocabulary size.
        token_shape = (self.batch_size, self.seq_len)
        self.x = torch.randint(0, 35222, token_shape, dtype=torch.long)
        self.m = torch.randn(self.batch_size, self.seq_len, self.num_domains)

    def test_output_shape(self):
        logits = self.model(self.x, self.m)
        self.assertEqual(logits.shape, (self.batch_size, self.num_classes))

    def test_embedder_parameters_are_trainable(self):
        params = list(self.model.domain_embedder.parameters())
        self.assertTrue(any(p.requires_grad for p in params))
        self.assertTrue(any(p.numel() > 0 for p in params))

    def test_deterministic_output(self):
        runs = []
        for _ in range(2):
            # Re-seeding makes dropout draw the same masks on both passes.
            torch.manual_seed(42)
            runs.append(self.model(self.x, self.m))
        self.assertTrue(torch.allclose(runs[0], runs[1], atol=1e-5))

    def test_total_trainable_params(self):
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.assertGreater(trainable, 0)

    def test_visualize_parameters(self):
        num_domains = 5
        model = DomainAwareTransformerEncoder(
            num_domains=num_domains,
            input_dim=32,
            hidden_dim=64,
            domain_embed_dim=32,
            max_seq_len=50,
        )

        token_ids = torch.randint(0, 30522, (self.batch_size, self.seq_len), dtype=torch.long)
        domains = torch.randn(self.batch_size, self.seq_len, num_domains)

        summary(
            model,
            input_data=(token_ids, domains),
            col_names=["input_size", "output_size", "num_params", "trainable"],
            depth=4,
            verbose=1,
        )


unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestDomainAwareTransformerEncoder))

.....
----------------------------------------------------------------------
Ran 5 tests in 0.397s

OK


Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Trainable
DomainAwareTransformerEncoder                      [2, 10]                   [2, 31]                   --                        True
├─Embedding: 1-1                                   [2, 10]                   [2, 10, 32]               3,825,504                 True
├─Linear: 1-2                                      [2, 10, 32]               [2, 10, 64]               2,112                     True
├─DomainAwareEmbedder: 1-3                         [2, 10, 32]               [2, 10, 64]               --                        True
│    └─Linear: 2-1                                 [2, 10, 5]                [2, 10, 64]               384                       True
│    └─PositionalEncoding: 2-2                     --                        [1, 10, 64]               --                        --
│    └─Linear: 2-3                                 [2, 10, 

<unittest.runner.TextTestResult run=5 errors=0 failures=0>

## Data

In [16]:
# Mount Google Drive so the CSV files under /content/drive/MyDrive/Efrei_Datasets can be read below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### DataViz

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


class MultilingualDatasetAnalyzer:
    """
    Analyze a multilingual dataset containing reviews, languages, and product categories.

    The analyzer works on a private copy of ``df``, so later changes to the
    caller's frame do not affect it.
    """

    def __init__(self, df: pd.DataFrame, text_column: str = "review_body",
                 lang_column: str = "language", category_column: str = "product_category") -> None:
        """
        Args:
            df: Source data; it must contain the three named columns.
            text_column: Column holding the review text.
            lang_column: Column holding the language code.
            category_column: Column holding the product category.

        Raises:
            ValueError: If any of the named columns is missing.
        """
        self.df = df.copy()
        self.text_column = text_column
        self.lang_column = lang_column
        self.category_column = category_column

        self._validate_columns()

    def _validate_columns(self) -> None:
        """Ensure required columns are present in the dataset."""
        required = {self.text_column, self.lang_column, self.category_column}
        missing = required - set(self.df.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

    def compute_language_distribution(self) -> pd.DataFrame:
        """Count reviews per language and express each count as a percentage of all rows.

        Rows whose language is missing are not counted, but they still contribute
        to the denominator ``len(df)``.
        """
        lang_stats = self.df[self.lang_column].value_counts().reset_index()
        lang_stats.columns = ["Language", "Count"]
        lang_stats["Percentage"] = 100 * lang_stats["Count"] / len(self.df)
        return lang_stats

    def unique_categories_per_language(self) -> pd.DataFrame:
        """List the sorted unique product categories for each language."""
        return (
            self.df.groupby(self.lang_column)[self.category_column]
            .apply(lambda x: sorted(x.dropna().unique().tolist()))
            .reset_index(name="Unique Categories")
        )

    def sample_review_per_language(self, random_state: int = 42) -> pd.DataFrame:
        """Sample one non-missing review per language ("N/A" if a language has none)."""

        def pick(texts: pd.Series) -> str:
            available = texts.dropna()  # computed once per group
            if available.empty:
                return "N/A"
            return available.sample(1, random_state=random_state).values[0]

        return (
            self.df.groupby(self.lang_column)[self.text_column]
            .apply(pick)
            .reset_index(name="Example Review")
        )

    def plot_language_distribution(self) -> None:
        """Visualize the number of reviews per language as a bar chart."""
        stats = self.compute_language_distribution()
        plt.figure(figsize=(10, 6))
        # NOTE(review): recent seaborn versions warn when `palette` is passed without `hue`;
        # confirm against the installed version before changing this call.
        sns.barplot(data=stats, x="Language", y="Count", palette="mako")
        plt.title("Number of Samples per Language")
        plt.ylabel("Review Count")
        plt.xlabel("Language")
        plt.tight_layout()
        plt.show()

    def summarize(self) -> None:
        """Print summary statistics and samples."""
        # Drop missing languages before sorting: NaN (a float) cannot be ordered against str codes.
        languages = sorted(self.df[self.lang_column].dropna().unique())

        print("Dataset Summary")
        print(f"Total samples: {len(self.df)}")
        print(f"Languages ({len(languages)}): {languages}")

        print("Language Distribution")
        print(self.compute_language_distribution().to_string(index=False))

        print("Unique Categories per Language")
        print(self.unique_categories_per_language().to_string(index=False))

        print("Example Review per Language")
        print(self.sample_review_per_language().to_string(index=False))


# Load the training split and print an overview of its languages, categories and sample reviews.
df = pd.read_csv("/content/drive/MyDrive/Efrei_Datasets/train.csv")
analyzer = MultilingualDatasetAnalyzer(
    df,
    text_column="review_body",
    lang_column="language",
    category_column="product_category",
)
analyzer.summarize()
# Uncomment to draw the per-language bar chart as well:
# analyzer.plot_language_distribution()

Dataset Summary
Total samples: 1200000
Languages (6): ['de', 'en', 'es', 'fr', 'ja', 'zh']
Language Distribution
Language  Count  Percentage
      de 200000   16.666667
      en 200000   16.666667
      es 200000   16.666667
      fr 200000   16.666667
      ja 200000   16.666667
      zh 200000   16.666667
Unique Categories per Language
language                                                                                                                                                                                                                                                                                                                                                                        Unique Categories
      de [apparel, automotive, baby_product, beauty, book, camera, digital_ebook_purchase, digital_video_download, drugstore, electronics, furniture, grocery, home, home_improvement, industrial_supplies, jewelry, kitchen, lawn_and_garden, luggage, musical_instruments, offic

### Dataset

In [14]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from typing import List, Dict, Tuple
from collections import defaultdict
import random


class MultilingualDomainDataset(Dataset):
    """Tokenised multilingual reviews with a one-hot language (domain) descriptor per token.

    Only rows whose language is listed in ``LANGUAGES`` and whose text and
    label are both present are kept. Their order follows ``df``.
    """

    LANGUAGES = ['en', 'fr', 'de']
    # One-hot descriptor per language; each vector has 3 entries (one per supported language).
    DOMAIN_MAP = {
        'en': [1, 0, 0],
        'fr': [0, 1, 0],
        'de': [0, 0, 1]
    }

    def __init__(self, df: pd.DataFrame, tokenizer_name: str, label2id: Dict[str, int],
                 text_column: str = "review_body", lang_column: str = "language",
                 label_column: str = "product_category", max_length: int = 128):
        """
        Args:
            df: Source rows.
            tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
            label2id: Mapping from label value to class index.
            text_column: Column holding the review text.
            lang_column: Column holding the language code.
            label_column: Column holding the label.
            max_length: Every sample is padded/truncated to this many tokens.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.label2id = label2id
        self.text_column = text_column
        self.lang_column = lang_column
        self.label_column = label_column
        self.max_length = max_length
        self.vocab_size = self.tokenizer.vocab_size

        self.samples = self._select_samples(df, text_column, lang_column, label_column)

    @classmethod
    def _select_samples(cls, df: pd.DataFrame, text_column: str, lang_column: str,
                        label_column: str) -> List[Tuple[str, str, str]]:
        """Return ``(text, language, label)`` tuples for the usable rows of ``df``.

        This uses a vectorised boolean filter instead of a per-row ``iterrows``
        loop, which is far faster on large frames. Rows with missing text are
        dropped too; otherwise a NaN would be handed to the tokenizer in
        ``__getitem__``.
        """
        keep = (
            df[lang_column].isin(cls.LANGUAGES)
            & df[label_column].notna()
            & df[text_column].notna()
        )
        kept = df.loc[keep]
        return list(zip(kept[text_column], kept[lang_column], kept[label_column]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Tokenise one sample and attach its per-token domain descriptor.

        Returns:
            A dict with these entries:
            - ``input_ids`` and ``attention_mask``: shape ``(max_length,)``.
            - ``domain_embedding``: shape ``(max_length, 3)``, the language's
              one-hot vector repeated for every token.
            - ``language``: the language code string.
            - ``label``: the class index as a long tensor.
        """
        text, lang, label = self.samples[idx]
        encoded = self.tokenizer(text, padding="max_length", truncation=True,
                                 max_length=self.max_length, return_tensors="pt")

        seq_len = encoded["input_ids"].shape[1]
        domain_vector = torch.tensor(self.DOMAIN_MAP[lang], dtype=torch.float)
        domain_matrix = domain_vector.repeat(seq_len, 1)

        return {
            "input_ids": encoded["input_ids"].squeeze(0),           # (seq_len)
            "attention_mask": encoded["attention_mask"].squeeze(0), # (seq_len)
            "domain_embedding": domain_matrix,                      # (seq_len, 3)
            "language": lang,
            "label": torch.tensor(self.label2id[label], dtype=torch.long)
        }

    def get_vocab_size(self):
        """Return the tokenizer's vocabulary size."""
        return self.vocab_size


def _indices_by_language(dataset: "MultilingualDomainDataset") -> Dict[str, List[int]]:
    """Group dataset indices by the language recorded in ``dataset.samples``.

    Languages appear in order of first occurrence; indices within a language
    keep dataset order.
    """
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, (_, lang, _) in enumerate(dataset.samples):
        groups[lang].append(idx)
    return groups


def multilingual_batch_sampler(
    dataset: "MultilingualDomainDataset",
    batch_size: int,
    *,
    drop_last: bool = True,
    rng: Optional[random.Random] = None,
) -> List[List[int]]:
    """Split dataset indices into shuffled, language-homogeneous batches.

    The indices are processed as follows:
    1. Group them by language.
    2. Shuffle each language's indices.
    3. Cut each group into consecutive chunks of ``batch_size``.
    4. Shuffle the resulting batches across languages.

    Args:
        dataset: Any object exposing ``samples`` as ``(text, language, label)`` tuples.
        batch_size: Number of indices in a full batch; must be >= 1.
        drop_last: When True (the original behaviour), each language's trailing
            partial batch is discarded. Pass False for evaluation so that every
            sample appears exactly once.
        rng: Generator used for shuffling; defaults to the global ``random`` state.

    Returns:
        A list of batches, each a list of dataset indices that share one language.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    shuffle = random.shuffle if rng is None else rng.shuffle

    indices_by_lang = _indices_by_language(dataset)
    for indices in indices_by_lang.values():
        shuffle(indices)

    all_batches = []
    for indices in indices_by_lang.values():
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            if len(batch) == batch_size or not drop_last:
                all_batches.append(batch)

    shuffle(all_batches)
    return all_batches


class LanguageBatchSampler:
    """Re-iterable batch sampler that draws fresh language-homogeneous batches on each pass.

    Passing the list returned by :func:`multilingual_batch_sampler` straight to
    a ``DataLoader`` replays exactly the same batches, in the same order, every
    epoch. This wrapper instead re-runs the sampler each time it is iterated.
    """

    def __init__(self, dataset: "MultilingualDomainDataset", batch_size: int, drop_last: bool = True) -> None:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        return iter(multilingual_batch_sampler(self.dataset, self.batch_size, drop_last=self.drop_last))

    def __len__(self) -> int:
        """Number of batches yielded by one pass."""
        sizes = [len(indices) for indices in _indices_by_language(self.dataset).values()]
        if self.drop_last:
            return sum(n // self.batch_size for n in sizes)
        return sum(-(-n // self.batch_size) for n in sizes)  # ceiling division


def multilingual_collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Stack per-sample tensors into batch tensors.

    ``language`` is taken from the first item, because the batches built above
    contain a single language.
    """
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "domain_embedding": torch.stack([item["domain_embedding"] for item in batch]),
        "labels": torch.stack([item["label"] for item in batch]),
        "language": batch[0]["language"]
    }


def create_label_mapping(*dfs: pd.DataFrame, label_column: str = "product_category") -> Dict[str, int]:
    """Map every non-missing label found in any of ``dfs`` to an index, in sorted label order."""
    labels = set()
    for df in dfs:
        labels.update(df[label_column].dropna().unique())
    return {label: idx for idx, label in enumerate(sorted(labels))}


def build_dataloaders(train_path: str, val_path: str, test_path: str,
                      tokenizer_name: str, batch_size: int = 16, max_length: int = 128):
    """Read the three CSV splits and wrap each in a language-homogeneous DataLoader.

    Training batches are re-drawn each epoch, and each language's partial last
    batch is dropped. Validation and test loaders keep partial batches, so every
    evaluation sample is used.

    Returns:
        ``(train_loader, val_loader, test_loader, label2id, vocab_size)``.
    """
    df_train = pd.read_csv(train_path)
    df_val = pd.read_csv(val_path)
    df_test = pd.read_csv(test_path)

    # One label space shared by all splits, so class ids agree everywhere.
    label2id = create_label_mapping(df_train, df_val, df_test)

    train_dataset = MultilingualDomainDataset(df_train, tokenizer_name, label2id, max_length=max_length)
    val_dataset = MultilingualDomainDataset(df_val, tokenizer_name, label2id, max_length=max_length)
    test_dataset = MultilingualDomainDataset(df_test, tokenizer_name, label2id, max_length=max_length)

    train_loader = DataLoader(
        train_dataset,
        batch_sampler=LanguageBatchSampler(train_dataset, batch_size),
        collate_fn=multilingual_collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_sampler=LanguageBatchSampler(val_dataset, batch_size, drop_last=False),
        collate_fn=multilingual_collate_fn
    )

    test_loader = DataLoader(
        test_dataset,
        batch_sampler=LanguageBatchSampler(test_dataset, batch_size, drop_last=False),
        collate_fn=multilingual_collate_fn
    )
    vocab_size = train_dataset.get_vocab_size()
    print(f"Vocab size: {vocab_size}")

    return train_loader, val_loader, test_loader, label2id, vocab_size


In [17]:
# Build the train/validation/test loaders plus the shared label map and the tokenizer vocabulary size.
(
    train_loader,
    val_loader,
    test_loader,
    label2id,
    vocab_size,
) = build_dataloaders(
    train_path="/content/drive/MyDrive/Efrei_Datasets/train.csv",
    val_path="/content/drive/MyDrive/Efrei_Datasets/validation.csv",
    test_path="/content/drive/MyDrive/Efrei_Datasets/test.csv",
    tokenizer_name="bert-base-multilingual-cased",
    batch_size=16,
    max_length=128,
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

Vocab size: 119547


In [28]:
# Peek at one training batch to check tensor shapes and the batch's language tag.
batch = next(iter(train_loader))
for shown in (
    batch["input_ids"],               # (B, seq_len)
    batch["domain_embedding"].shape,  # (B, seq_len, 3)
    batch["labels"],                  # (B,)
    batch["language"],                # language code shared by the whole batch
):
    print(shown)


tensor([[  101, 37282, 61470,  ...,     0,     0,     0],
        [  101, 12944, 93151,  ...,     0,     0,     0],
        [  101,   146, 11850,  ...,     0,     0,     0],
        ...,
        [  101,   146, 10392,  ...,   119, 54690,   102],
        [  101, 12936, 91415,  ...,     0,     0,     0],
        [  101, 10747, 11897,  ...,     0,     0,     0]])
torch.Size([16, 128, 3])
tensor([12, 20,  0,  1, 26, 13,  8, 20, 21, 12, 13, 16,  1,  3, 21, 13])
en


## Trainer