# Joint Alignment Models

In [36]:
import os, re, string, unicodedata
import torch
import torch.nn as nn, torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoFeatureExtractor

## Data Loading

In [26]:
from datasets import load_dataset

In [27]:
def preprocess_sentence(examples):
    """
    Clean the ``sentence`` column of a batch of examples.

    Each text goes through, in order:
    1. Removal of literal ``[UNK]`` markers
    2. Unicode normalization (NFKD)
    3. Removal of ASCII punctuation (``string.punctuation``)
    4. Lowercasing
    5. Collapsing whitespace runs into one space and trimming both ends

    Normalization runs before punctuation stripping so that compatibility
    characters which NFKD maps to ASCII punctuation (e.g. fullwidth ``！`` or
    the ellipsis ``…``) are removed as well. No tokenization happens here.

    Parameters
    ----------
    examples : dict
        Batch mapping with a ``'sentence'`` entry holding a list of strings.

    Returns
    -------
    dict
        ``{'gt-transcription': [...]}`` with one cleaned string per input.
    """
    # Build the deletion table once per batch rather than once per sentence.
    strip_punct = str.maketrans('', '', string.punctuation)

    def _clean(text):
        text = text.replace('[UNK]', '')
        text = unicodedata.normalize('NFKD', text)
        text = text.translate(strip_punct)
        text = text.lower()
        return re.sub(r'\s+', ' ', text).strip()

    return {'gt-transcription': [_clean(text) for text in examples['sentence']]}

In [28]:
# Root directory that holds the dataset and the HuggingFace caches
base = "/mnt/h/"

# HuggingFace caches.
# NOTE(review): these variables are assigned after `transformers` and
# `datasets` were imported in earlier cells; if the libraries read them at
# import time they will not take effect here — confirm, or set them before
# the imports. `load_dataset` below receives `cache_dir` explicitly.
# os.path.join avoids the doubled slash that "{base}/datasets" produced.
os.environ["HF_DATASETS_CACHE"] = os.path.join(base, "datasets")
os.environ["HF_HOME"] = base
os.environ["TRANSFORMERS_CACHE"] = os.path.join(base, "models")

# Load the dataset
orig_ds = load_dataset('ddamianos/hparl',
                       cache_dir=os.environ["HF_DATASETS_CACHE"])

# Add the cleaned transcription, then drop the columns that are not needed
keep_columns = {'audio', 'sentence'}
for split in orig_ds:
    # Computed from the columns present *before* mapping, so the freshly added
    # 'gt-transcription' column is not in this list and is kept.
    drop_columns = [
        col for col in orig_ds[split].column_names if col not in keep_columns
    ]
    orig_ds[split] = orig_ds[split].map(
        preprocess_sentence,
        batched=True,
        batch_size=256,
    ).remove_columns(drop_columns)

# Split train and test
test_ds = orig_ds['test']
train_ds = orig_ds['train']

# Display some samples; only the first rows are converted to pandas instead
# of materialising the whole training split.
train_ds.select(range(min(5, len(train_ds)))).to_pandas()

Map: 100%|██████████| 8679/8679 [00:33<00:00, 259.27 examples/s]
Map: 100%|██████████| 76341/76341 [05:03<00:00, 251.30 examples/s]


Unnamed: 0,sentence,audio,gt-transcription
0,[UNK] και συναδελφοι θα προσπαθησω να μη γινω ...,"{'array': [-0.0051879883, -0.0065307617, 0.007...",και συναδελφοι θα προσπαθησω να μη γινω και εγ...
1,[UNK] οικονομολογος διοτι δεν ειμαι οικονομολο...,"{'array': [-0.1210022, -0.08911133, -0.0612792...",οικονομολογος διοτι δεν ειμαι οικονομολογος οπ...
2,θα αναφερω για τον προ [UNK],"{'array': [0.10360718, 0.089660645, 0.08279419...",θα αναφερω για τον προ
3,[UNK] γνωστα κοινα σε ολους και κοινης,"{'array': [-0.08016968, -0.0892334, -0.1252441...",γνωστα κοινα σε ολους και κοινης
4,αποδοχης η κυβερνηση λοιπον παρουσιαζει τον,"{'array': [0.0038146973, -0.0016174316, -0.004...",αποδοχης η κυβερνηση λοιπον παρουσιαζει τον


## Models

### Helper Functions

In [33]:
def mean_pooling(hidden_state, mask):
    '''
    Mean pooling with an attention mask.

    Computes a mask-aware mean over the temporal dimension of a sequence of
    hidden states. Only positions where the mask is 1 contribute to the mean.
    Rows whose mask is entirely 0 yield a zero vector (the denominator is
    clamped to a small positive value instead of dividing by zero).

    Parameters
    ----------
    hidden_state : torch.Tensor
        Input tensor of shape ``[B, T, D]`` containing the sequence of hidden
        representations (e.g., token embeddings from a transformer).

    mask : torch.Tensor
        Attention mask of shape ``[B, T]`` (or ``[B, T, 1]``) with values in
        ``{0, 1}``. Positions with value ``1`` are included in the pooling
        operation, while positions with value ``0`` are ignored. Integer and
        boolean masks are converted to float32 before use.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``[B, D]`` containing the mean-pooled representations
        for each example in the batch.
    '''
    if mask.dim() == 2:
        mask = mask.unsqueeze(-1)
    # Integer/bool masks are cast explicitly so the float epsilon in the clamp
    # below is never truncated to 0 by integer arithmetic.
    if not mask.is_floating_point():
        mask = mask.float()
    masked_hidden = hidden_state * mask
    sum_hidden = masked_hidden.sum(dim=1)
    lengths = mask.sum(dim=1).clamp(min=1e-9)
    pooled = sum_hidden / lengths
    return pooled

### CNN Aligner

In [34]:
class CnnAligner(nn.Module):
    '''
    CNN Aligner with multi-kernel convolutions and temporal downsampling.

    Each layer runs several length-preserving 1-D convolutions (one per
    kernel size) in parallel, concatenates their outputs along the channel
    axis, max-pools over time and applies LayerNorm over channels. A final
    1x1 convolution projects to ``text_dim``.

    input:  speech_embs [B, T, speech_dim]
    output: aligned_embs [B, T', text_dim]
            pooled_mask  [B, T']

    ``T'`` is ``T`` floor-divided by ``pool_stride`` once per layer.

    Parameters
    ----------
    speech_dim : int
        Channel size of the incoming speech embeddings.
    text_dim : int
        Channel size of the produced embeddings.
    hidden_dim : int
        Output channels of every individual convolution.
    kernel_sizes : iterable of int
        Kernel widths used in parallel inside each layer. Odd and even
        widths may be mixed; every branch keeps the sequence length.
    num_layers : int
        Number of conv -> pool -> norm layers.
    pool_stride : int
        Kernel size and stride of the temporal max pooling.
    dropout : float
        Dropout probability applied after the concatenated convolutions.
    **kwargs
        Accepted and ignored.
    '''

    def __init__(self,
                 speech_dim: int = 384,
                 text_dim: int = 384,
                 hidden_dim: int = 256,
                 kernel_sizes=(3, 5, 7),
                 num_layers: int = 2,
                 pool_stride: int = 2,
                 dropout: float = 0.1,
                 **kwargs):

        super().__init__()

        # Normalise numeric hyper-parameters (they may arrive as floats or
        # strings from a config); pool_stride and num_layers included.
        speech_dim = int(speech_dim)
        text_dim = int(text_dim)
        hidden_dim = int(hidden_dim)
        num_layers = int(num_layers)
        pool_stride = int(pool_stride)
        kernel_sizes = [int(k) for k in kernel_sizes]

        self.kernel_sizes = kernel_sizes
        self.pool_stride = pool_stride

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        in_channels = speech_dim
        concat_channels = hidden_dim * len(kernel_sizes)

        for _ in range(num_layers):
            layer_convs = nn.ModuleList([
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=k,
                    # 'same' equals k // 2 for odd k, and also keeps the
                    # length for even k (where k // 2 would yield T + 1 and
                    # break the concatenation and the mask alignment).
                    padding='same',
                )
                for k in kernel_sizes
            ])
            self.convs.append(layer_convs)
            self.norms.append(nn.LayerNorm(concat_channels))
            in_channels = concat_channels

        self.pool = nn.MaxPool1d(
            kernel_size=pool_stride,
            stride=pool_stride
        )

        self.proj = nn.Conv1d(in_channels, text_dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

        self.speech_dim = speech_dim
        self.text_dim = text_dim

    def forward(self, speech_embs, attn_mask):
        """
        Run the aligner.

        speech_embs: [B, T, speech_dim]
        attn_mask:   [B, T] or None

        Returns ``(aligned_embs, pooled_mask)``; ``pooled_mask`` is ``None``
        when ``attn_mask`` is ``None``.
        """

        # Conv1d expects channels first: [B, C, T]
        x = speech_embs.transpose(1, 2)
        mask = attn_mask

        for convs, norm in zip(self.convs, self.norms):
            # Multi-kernel conv
            feats = [self.act(conv(x)) for conv in convs]
            x = torch.cat(feats, dim=1)   # [B, C', T]
            x = self.dropout(x)

            # Pool in time; the mask is pooled identically to stay aligned
            x = self.pool(x)
            mask = self._pool_mask(mask)

            # LayerNorm over channels
            x = x.transpose(1, 2)          # [B, T', C']
            x = norm(x)
            x = x.transpose(1, 2)          # [B, C', T']

        x = self.proj(x)                  # [B, text_dim, T']
        x = x.transpose(1, 2)             # [B, T', text_dim]

        return x, mask

    def _pool_mask(self, mask):
        """
        Downsample attention mask consistently with temporal pooling.

        A pooled position is valid (1) when any position inside its window
        was valid. Returns ``None`` for a ``None`` mask, otherwise a long
        tensor of shape [B, T'].
        """
        if mask is None:
            return None

        mask = mask.unsqueeze(1).float()  # [B, 1, T]
        mask = F.max_pool1d(
            mask,
            kernel_size=self.pool_stride,
            stride=self.pool_stride,
        )
        return mask.squeeze(1).long() # [B, T']

### Full Aligner Model

In [38]:
class AlignmentModel(nn.Module):
    """
    Model for learning a joint representation between speech segments and text.

    Initializes a frozen Whisper-tiny encoder (speech) and a frozen
    multilingual-e5-small encoder (text). Neither backbone is trainable, and
    both are kept in evaluation mode even when the model is switched to
    training mode, so their dropout layers stay inactive. Only the CNN
    aligner and the temperature ``log_tau`` carry trainable parameters.

    Parameters
    ----------
    hidden_dim, kernel_sizes, num_layers, pool_stride, dropout
        Forwarded to ``CnnAligner``.
    init_tau : float
        Initial value of the contrastive temperature; it is stored as
        ``log(init_tau)`` and learned.
    """

    def __init__(self,
                 hidden_dim=256,
                 kernel_sizes=(3, 5, 7), 
                 num_layers=2,
                 pool_stride=2, 
                 dropout=0.1,
                 init_tau=0.07):
        super().__init__()

        # ---- Speech encoder (Whisper-tiny) ----
        self.speech_encoder = AutoModel.from_pretrained("openai/whisper-tiny")
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
        self._freeze(self.speech_encoder)

        # ---- Text encoder (multilingual-e5-small) ----
        self.text_encoder = AutoModel.from_pretrained("intfloat/multilingual-e5-small")
        self.tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")
        self._freeze(self.text_encoder)

        # ---- Aligner ----
        self.aligner = CnnAligner(
            speech_dim=self.speech_encoder.config.d_model,
            text_dim=self.text_encoder.config.hidden_size,
            hidden_dim=hidden_dim,
            kernel_sizes=kernel_sizes,
            num_layers=num_layers,
            pool_stride=pool_stride,
            dropout=dropout)

        # ---- Learnable temperature ----
        # Stored in log space so that tau = exp(log_tau) is always positive.
        self.log_tau = nn.Parameter(
            torch.log(torch.tensor(init_tau, dtype=torch.float32))
        )

    @staticmethod
    def _freeze(module):
        """Disable gradients for ``module`` and put it in evaluation mode."""
        for param in module.parameters():
            param.requires_grad = False
        module.eval()

    def train(self, mode: bool = True):
        """
        Set the training mode of the aligner while keeping the frozen
        backbones in evaluation mode.

        ``nn.Module.train`` propagates the mode to every child, which would
        re-enable dropout inside the frozen encoders; they are switched back
        to eval right after.
        """
        super().train(mode)
        self.speech_encoder.eval()
        self.text_encoder.eval()
        return self

    def preprocess(self, audio, text):
        """
        Preprocess raw audio and text into model-ready tensors.

        Parameters
        ----------
        audio : dict
            Batched audio dictionary with keys:
              - 'array'         : list of 1-D waveform arrays, one per sample
              - 'sampling_rate' : int, sampling rate shared by all samples
        text : list[str]
            List of input strings, one per sample.

        Returns
        -------
        input_features : Tensor [B, n_mels, T]
            Mel-spectrogram features for the speech encoder.
        tok : dict[str, Tensor]
            Tokenized text inputs (input_ids, attention_mask, etc.).
        """
        # Tensors are moved to wherever the model parameters live.
        device = next(self.parameters()).device

        features = self.feature_extractor(
            audio['array'],
            sampling_rate=audio['sampling_rate'],
            return_tensors='pt',
            padding='max_length',
        )
        input_features = features.input_features.to(device)

        tok = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
        )
        tok = {k: v.to(device) for k, v in tok.items()}

        return input_features, tok

    def forward(self, audio, text):
        """
        Parameters
        ----------
        audio : dict
            Batched audio dictionary with keys:
              - 'array'         : list of 1-D waveform arrays, one per sample
              - 'sampling_rate' : int, sampling rate shared by all samples
        text : list[str]
            List of input strings, one per sample (same length as audio['array']).

        Returns
        -------
        dict with keys:
          - 'loss'       : scalar, symmetric InfoNCE contrastive loss
          - 'logits'     : Tensor [B, B], cosine-similarity logits scaled by τ
          - 'speech_emb' : Tensor [B, D], L2-normalised speech embeddings
          - 'text_emb'   : Tensor [B, D], L2-normalised text embeddings
        """
        device = next(self.parameters()).device
        input_features, tok = self.preprocess(audio, text)

        # ---- Speech branch ----
        speech_hidden = self.speech_encoder.encoder(
            input_features
        ).last_hidden_state  # [B, T', speech_dim]

        # NOTE(review): every encoder frame is marked valid. Since the
        # features are padded with padding='max_length', frames that come
        # from padding presumably take part in the pooling as well — consider
        # deriving this mask from the feature extractor's attention mask.
        speech_mask = torch.ones(
            speech_hidden.shape[:2], dtype=torch.long, device=device
        )

        aligned_speech, aligned_mask = self.aligner(speech_hidden, speech_mask)
        speech_emb = mean_pooling(aligned_speech, aligned_mask)  # [B, D]

        # ---- Text branch ----
        text_hidden = self.text_encoder(**tok).last_hidden_state  # [B, T, D]
        text_emb = mean_pooling(text_hidden, tok['attention_mask'])  # [B, D]

        # ---- Logits & loss ----
        speech_emb = F.normalize(speech_emb, dim=-1)
        text_emb = F.normalize(text_emb, dim=-1)

        tau = torch.exp(self.log_tau)
        logits = (speech_emb @ text_emb.T) / tau  # [B, B]

        # The i-th audio clip matches the i-th sentence, so the targets are
        # the diagonal; the loss averages both retrieval directions.
        labels = torch.arange(logits.size(0), device=device)
        loss = 0.5 * (
            F.cross_entropy(logits, labels)
            + F.cross_entropy(logits.T, labels)
        )

        return {
            'loss': loss,
            'logits': logits,
            'speech_emb': speech_emb,
            'text_emb': text_emb,
        }


In [53]:
import lightning as L
from torch.utils.data import Dataset, DataLoader


class AlignmentLitModule(L.LightningModule):
    """
    Lightning wrapper around AlignmentModel.

    Parameters
    ----------
    lr : float
        Learning rate of the AdamW optimizer.
    weight_decay : float
        Weight decay applied to the aligner parameters (never to the
        temperature).
    **model_kwargs
        Forwarded unchanged to ``AlignmentModel``.
    """

    def __init__(self, lr=1e-4, weight_decay=0.0, **model_kwargs):
        super().__init__()
        # Makes lr / weight_decay available through self.hparams.
        self.save_hyperparameters()
        self.model = AlignmentModel(**model_kwargs)

    def forward(self, audio, text):
        """Delegate to the wrapped model; returns its output dict."""
        return self.model(audio, text)

    def _shared_step(self, batch, stage):
        """
        Run the model on one collated batch and log ``{stage}_loss``.

        ``batch`` holds an 'audio' list of per-sample dicts (with 'array' and
        'sampling_rate') and a 'sentence' list of strings. Returns the model
        output dict.
        """
        # Re-pack the per-sample audio dicts into the batched layout the
        # model expects; the sampling rate is taken from the first sample.
        audio = {
            'array': [a['array'] for a in batch['audio']],
            'sampling_rate': batch['audio'][0]['sampling_rate'],
        }
        text = batch['sentence']
        out = self.model(audio, text)
        # batch_size is passed explicitly because the batch is made of
        # Python lists rather than tensors.
        self.log(f'{stage}_loss', out['loss'], prog_bar=True, batch_size=len(text))
        return out

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        return self._shared_step(batch, 'train')['loss']

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for one batch."""
        self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """
        Build an AdamW optimizer over the trainable parameters only.

        The aligner and the temperature are optimised (backbones are frozen).
        The temperature lives in its own group without weight decay, since
        decay would keep pulling ``log_tau`` toward 0 regardless of the loss.
        """
        param_groups = [
            {
                'params': list(self.model.aligner.parameters()),
                'weight_decay': self.hparams.weight_decay,
            },
            {
                'params': [self.model.log_tau],
                'weight_decay': 0.0,
            },
        ]
        return torch.optim.AdamW(param_groups, lr=self.hparams.lr)


In [54]:
class SpeechTextDataModule(L.LightningDataModule):
    """
    Lightning DataModule serving a training and a validation dataset whose
    items expose 'audio' and 'sentence' entries.
    """

    def __init__(self, train_ds, val_ds, batch_size=8, num_workers=4):
        super().__init__()
        self.train_ds, self.val_ds = train_ds, val_ds
        self.batch_size, self.num_workers = batch_size, num_workers

    @staticmethod
    def _collate(batch):
        """
        Gather the raw 'audio' and 'sentence' entries of each item into two
        parallel lists; no tensor conversion happens here.
        """
        collated = {}
        for key in ('audio', 'sentence'):
            collated[key] = [sample[key] for sample in batch]
        return collated

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader for ``dataset`` with the module's settings."""
        loader_kwargs = {
            'batch_size': self.batch_size,
            'shuffle': shuffle,
            'num_workers': self.num_workers,
            'collate_fn': self._collate,
            'pin_memory': True,
        }
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_ds, True)

    def val_dataloader(self):
        """Sequential (unshuffled) loader over the validation dataset."""
        return self._make_loader(self.val_ds, False)


In [None]:
# ---- Model and data ----
lit_model = AlignmentLitModule(lr=1e-4)

# The test split doubles as the validation set.
dm = SpeechTextDataModule(train_ds=train_ds, val_ds=test_ds, batch_size=8, num_workers=4)

# ---- Trainer configuration ----
trainer_kwargs = {
    'max_epochs': 10,
    'accelerator': 'auto',
    'precision': '16-mixed',
    'log_every_n_steps': 10,
    'val_check_interval': 0.25,  # run validation four times per epoch
}
trainer = L.Trainer(**trainer_kwargs)

# ---- Training ----
trainer.fit(lit_model, datamodule=dm)
