In [1]:
import torch
from transformers import AutoTokenizer, AutoModel

from src.transformer_wrapper import TransformerWrapper

from src.dataset_utils import prepare_df

import warnings
# Silence UserWarning-category warnings for the rest of the session so the
# example outputs stay readable. The filter is installed after the imports
# above, so anything they warn about at import time is not affected by it.
warnings.filterwarnings("ignore", category=UserWarning)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the subtask-1 training CSV, relative to the working directory.
TRAIN_SUBTASK1_CSV = "data/TRAIN_RELEASE_3SEP2025/train_subtask1.csv"
df = prepare_df(dataset_path=TRAIN_SUBTASK1_CSV)

**All Usage Examples**

In [None]:
# The tokenizer and the model are loaded from the same checkpoint name.
model_path = "bert-base-uncased"
tokenizer_path = model_path

# Two short sample sentences shared by every example below.
texts = [
    "Hello world!",
    "Transformers are amazing for NLP.",
]

Single embeddings: torch.Size([2, 768])
Grouped embeddings: torch.Size([2, 8, 768])
Trainable grouped embeddings: torch.Size([2, 8, 768])


**1. Single Embedding (Mean / Max / CLS)**

In [None]:
wrapper = TransformerWrapper(tokenizer_path, model_path)

# Run each single-vector pooling strategy in turn and print the result's shape.
pooled_by_mode = {}
for pooling_mode in ("mean", "max", "cls"):
    pooled_by_mode[pooling_mode] = wrapper.encode(texts, pooling=pooling_mode)
    print(pooled_by_mode[pooling_mode].shape)  # (2, 768)

# Keep the individually named results available under their usual names.
mean_embeds = pooled_by_mode["mean"]
max_embeds = pooled_by_mode["max"]
cls_embeds = pooled_by_mode["cls"]

torch.Size([2, 768])
torch.Size([2, 768])
torch.Size([2, 768])


**Explanation**:
- Each text is reduced to a single vector.
- Mean and max pooling ignore padding tokens.
- CLS pooling just takes the [CLS] token.

**2. Grouped Embedding with Conv1D Pooling**

In [5]:
# Grouped encoding with Conv1D pooling, asking for 4 groups per text.
conv_pooling_options = {"use_conv_pooling": True, "n_groups": 4}
wrapper_conv = TransformerWrapper(tokenizer_path, model_path, **conv_pooling_options)

grouped_embeds = wrapper_conv.encode_grouped(texts)
print(grouped_embeds.shape)  # (2, 4, 768)

torch.Size([2, 4, 768])


**Explanation**:

- The sequence is split into 4 windows.
- Depthwise Conv1D pools each embedding dimension independently.
- Output: 4 embeddings per text, each summarizing a portion of the sequence.

**2a. With Overlapping Windows**

In [None]:
# Conv1D pooling into 4 groups, this time with overlapping windows switched on.
overlap_pooling_options = dict(use_conv_pooling=True, n_groups=4, overlap_pooling=True)
wrapper_conv_overlap = TransformerWrapper(
    tokenizer_path, model_path, **overlap_pooling_options
)

grouped_embeds_overlap = wrapper_conv_overlap.encode_grouped(texts)
print(grouped_embeds_overlap.shape)  # (2, 4, 768)

torch.Size([2, 4, 768])


**Explanation**:

- Stride = kernel_size // 2 -> windows overlap by 50%.
- Adaptive pooling ensures output is exactly n_groups.

**3. Trainable Weighted Pooling**

In [7]:
# Trainable pooling into 4 groups per text.
wrapper_trainable = TransformerWrapper(
    tokenizer_path, model_path, n_groups=4, use_trainable_pooling=True
)

# Switch to the "pooling_only" training mode before encoding with training=True.
wrapper_trainable.set_training_mode("pooling_only")

grouped_trainable = wrapper_trainable.encode_grouped(texts, training=True)
print(grouped_trainable.shape)  # (2, 4, 768)

torch.Size([2, 4, 768])


**Explanation**:

- Learns token-level weights for each group.
- The transformer is frozen; only the pooling weights are trainable.
- Mask-aware: padding tokens contribute 0.

In [None]:
# Example: depthwise convolution pooling in pooling-only training mode.
# NOTE: this rebinds `wrapper_conv`, replacing any wrapper previously stored
# under that name.
conv_training_options = {
    "use_conv_pooling": True,   # enable depthwise Conv1D pooling
    "n_groups": 4,              # reduce the token dimension to 4 pooled embeddings
    "overlap_pooling": False,   # non-overlapping windows
}
wrapper_conv = TransformerWrapper(tokenizer_path, model_path, **conv_training_options)

# Pooling-only training mode (transformer frozen).
wrapper_conv.set_training_mode("pooling_only")

# Encode the texts through the convolution pooling path with training enabled.
grouped_conv = wrapper_conv.encode_grouped(texts, training=True)

print(grouped_conv.shape)  # (2, 4, 768)

torch.Size([2, 4, 768])
