In [25]:
import torch
from src.data.utils                 import setup_dataloader
from src.models.transformer_wrapper import TransformerWrapper

In [21]:
# DATA
TRAIN_PATH = "data/TRAIN_RELEASE_3SEP2025/train_subtask1.csv"

# TRANSFOMER
TOKENIZER_PATH = "bert-base-uncased"
MODEL_PATH     = "bert-base-uncased"

In [22]:
train_loader, train_dataset = setup_dataloader(
    csv_path=TRAIN_PATH,
    tokenizer_path=TOKENIZER_PATH,
    max_text_length=512,
    batch_size=4,
    shuffle=True
)

In [23]:
transfomer = TransformerWrapper(
    model_path=MODEL_PATH,
    device=None,
    use_attention_grouped_pooling=True,
    n_groups=4,
)

transfomer.set_training_mode(mode="pooling_only")

Initialized attention-based grouped pooling: 4 groups


In [26]:
for batch in train_loader:
    
    B, S, T = batch['input_ids'].shape
    mask = batch["seq_attention_mask"].bool()  # [B, S]
    
    # Flatten valid sequences
    inputs = {
        'input_ids': batch["input_ids"][mask],          # [N_valid, T]
        'attention_mask': batch["attention_mask"][mask]  # [N_valid, T]
    }
    
    # Get embeddings
    embeddings_flat = transfomer.encode_grouped(inputs)  # [N_valid, G, H]
    N_valid, G, H = embeddings_flat.shape
    
    # Reconstruct padded tensor for RNN
    embeddings = torch.zeros(B, S, G, H, device=embeddings_flat.device, dtype=embeddings_flat.dtype)
    embeddings[mask] = embeddings_flat  # Scatter back
    
    print(f"embeddings shape: {embeddings.shape}")  # [B, S, G, H]
    
    # Now ready for RNN
    # If your RNN expects [B, S, features], flatten G and H:
    # embeddings_rnn = embeddings.view(B, S, G * H)  # [B, S, G*H]
    
    break

embeddings shape: torch.Size([4, 19, 4, 768])


In [27]:
batch["input_ids"].shape

torch.Size([4, 19, 512])

In [None]:
inputs["input_ids"].shape, embeddings.shape

(torch.Size([47, 512]), torch.Size([47, 4, 768]))

torch.Size([47, 4, 768])

In [None]:
embeddings.shape

torch.Size([68, 4, 768])

In [None]:
batch_input_ids.shape

torch.Size([2, 7, 512])

In [None]:
embeddings.shape

torch.Size([11, 4, 768])