In [None]:
import torch
from src.models.tokenizer_wrapper   import TokenizerWrapper
from src.models.transformer_wrapper import TransformerWrapper

In [None]:
# Sample sentences: tokenised together as a single batch in the next cell,
# so the batch size of every embedding below equals len(texts).
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Transformers provide state-of-the-art embeddings for NLP tasks.",
]

In [None]:
# Both the tokenizer and the encoder load the same pretrained checkpoint.
CHECKPOINT = "bert-base-uncased"
MAX_LEN = 512  # maximum sequence length handed to TokenizerWrapper

tokenizer_path = CHECKPOINT
model_path = CHECKPOINT
tokenizer = TokenizerWrapper(path=tokenizer_path, max_len=MAX_LEN)

# Mapping of tensors keyed by 'input_ids', 'token_type_ids' and 'attention_mask'
# (as shown in the saved output of this cell).
inputs = tokenizer(texts=texts)
inputs  # last expression, so a fresh Restart & Run All reproduces the displayed batch

{'input_ids': tensor([[  101,  1996,  4248,  ...,     0,     0,     0],
        [  101, 19081,  3073,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [None]:


# -----------------------------
# 0. Device, constants & reproducibility
# -----------------------------
# TransformerWrapper's second positional parameter is the device. The previous call
# TransformerWrapper(tokenizer_path, model_path) handed it "bert-base-uncased" and
# failed with "RuntimeError: Invalid device string: 'bert-base-uncased'".
# Every wrapper below now receives the checkpoint first and an explicit device second.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_GROUPS = 4            # number of group embeddings produced by the grouped wrappers
LSTM_HIDDEN_SIZE = 128  # hidden size of the downstream demo LSTM
torch.manual_seed(0)    # learnable pooling / Conv1D layers and the LSTM are randomly initialised

# The tokenizer output holds CPU tensors; move them to the chosen device.
# NOTE(review): harmless if encode() already moves its inputs itself. TODO confirm.
model_inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}

# -----------------------------
# 1. Standard single embeddings
# -----------------------------
# NOTE(review): the first positional parameter is assumed to be the model checkpoint
# path. tokenizer_path and model_path are both "bert-base-uncased" here. TODO confirm.
wrapper = TransformerWrapper(model_path, DEVICE)
wrapper.set_training_mode("inference")  # no gradients

# Mean pooling
emb_mean = wrapper.encode(model_inputs, pooling="mean")
print("Mean pooling shape:", emb_mean.shape)  # (2, hidden_size)

# Max pooling
emb_max = wrapper.encode(model_inputs, pooling="max")
print("Max pooling shape:", emb_max.shape)

# CLS token
emb_cls = wrapper.encode(model_inputs, pooling="cls")
print("CLS token shape:", emb_cls.shape)

# -----------------------------
# 2. Learnable attention pooling
# -----------------------------
wrapper_attn = TransformerWrapper(model_path, DEVICE, use_learnable_single_pooling=True)
wrapper_attn.set_training_mode("pooling_only")  # train pooling only
emb_attn = wrapper_attn.encode(model_inputs, pooling="learnable_attention")
print("Learnable attention shape:", emb_attn.shape)

# -----------------------------
# 3. Grouped embeddings (split-based)
# -----------------------------
wrapper_grouped = TransformerWrapper(model_path, DEVICE, n_groups=N_GROUPS)
wrapper_grouped.set_training_mode("inference")
emb_groups = wrapper_grouped.encode_grouped(model_inputs, pooling="mean")
print("Grouped embeddings (split) shape:", emb_groups.shape)  # (2, N_GROUPS, hidden_size)
assert emb_groups.shape[:2] == (len(texts), N_GROUPS), emb_groups.shape

# -----------------------------
# 4. Grouped embeddings (Conv1D + adaptive pooling)
# -----------------------------
wrapper_conv = TransformerWrapper(model_path, DEVICE,
                                  use_conv_grouped_pooling=True,
                                  n_groups=N_GROUPS, conv_kernel_size=3)
wrapper_conv.set_training_mode("pooling_only")
emb_conv = wrapper_conv.encode_grouped(model_inputs)
print("Grouped embeddings (Conv1D) shape:", emb_conv.shape)

# -----------------------------
# 5. Use in an LSTM pipeline
# -----------------------------
# Feed the grouped embeddings to an LSTM as a sequence of N_GROUPS steps.
lstm_input = emb_groups  # shape (B, N_GROUPS, H)
# The LSTM must live on the same device as its input, or the forward pass raises
# a device-mismatch error when DEVICE is "cuda".
lstm = torch.nn.LSTM(
    input_size=lstm_input.size(2),
    hidden_size=LSTM_HIDDEN_SIZE,
    batch_first=True,
).to(lstm_input.device)
lstm_out, (h_n, c_n) = lstm(lstm_input)
print("LSTM output shape:", lstm_out.shape)  # (B, N_GROUPS, LSTM_HIDDEN_SIZE)
assert lstm_out.shape == (len(texts), N_GROUPS, LSTM_HIDDEN_SIZE), lstm_out.shape


RuntimeError: Invalid device string: 'bert-base-uncased'