# Domain Separation Network (DSN) Pipeline
This notebook implements the DSN architecture for unsupervised domain adaptation in speech recognition.

In [5]:
import importlib.metadata
import editdistance, jiwer

print("jiwer version:", importlib.metadata.version("jiwer"))
print("editdistance version:", importlib.metadata.version("editdistance"))
print("editdistance available:", editdistance.eval("kitten", "sitting"))


jiwer version: 4.0.0
editdistance version: 0.8.1
editdistance available: 3


In [6]:
# ─────────────────────────────────────────────────────────────────────────────
# 1) Imports & device
# ─────────────────────────────────────────────────────────────────────────────
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

device = torch.device("cpu")
print("Using device:", device)



Using device: cpu


In [7]:
# ─────────────────────────────────────────────────────────────────────────────
# 2) Collate fn for variable-length utterances (pad to max-T per batch)
# ─────────────────────────────────────────────────────────────────────────────
def collate_fn(batch):
    """
    Handles variable-length feature sequences (and optional labels).

    batch:
      - source: list of (features[T,F], labels[T]) tuples
      - target: list of features[T,F]

    Returns:
      source -> (padded_feats[B,T,F], padded_labels[B,T], lengths[B])
      target -> (padded_feats[B,T,F], lengths[B])
    """
    if isinstance(batch[0], tuple):  # source (features, labels)
        feats, labels = zip(*batch)  # each feats: [T,F], labels: [T]
        lengths = torch.tensor([f.size(0) for f in feats], dtype=torch.long)
        padded_feats = pad_sequence(feats, batch_first=True)               # [B,T,F]
        padded_labels = pad_sequence(labels, batch_first=True, padding_value=-100)  # [B,T]
        return padded_feats, padded_labels, lengths
    else:  # target (features only)
        feats = batch
        lengths = torch.tensor([f.size(0) for f in feats], dtype=torch.long)
        padded_feats = pad_sequence(feats, batch_first=True)               # [B,T,F]
        return padded_feats, lengths


In [8]:
# ─────────────────────────────────────────────────────────────────────────────
# 3) Lazy .npy datasets (per-utterance files), optional labels
# ─────────────────────────────────────────────────────────────────────────────
class LazyNPYDataset(Dataset):
    """
    Lazy loading dataset for .npy feature files and optional labels.
    Assumes: one .npy per utterance. If labels provided, filenames match.
    """
    def __init__(self, feature_dir, label_dir=None, feature_dtype=np.float32):
        self.feature_dir = feature_dir
        self.label_dir = label_dir
        self.file_list = sorted([f for f in os.listdir(feature_dir) if f.endswith(".npy")])
        self.feature_dtype = feature_dtype

        if label_dir:
            self.label_list = sorted([f for f in os.listdir(label_dir) if f.endswith(".npy")])
            assert len(self.file_list) == len(self.label_list), \
                f"Mismatch: {len(self.file_list)} feature files vs {len(self.label_list)} label files"
            # Optional: verify same basenames
            for f_feat, f_lab in zip(self.file_list, self.label_list):
                assert os.path.splitext(f_feat)[0] == os.path.splitext(f_lab)[0], \
                    f"Filename mismatch: {f_feat} vs {f_lab}"

    def __len__(self):
        return len(self.file_list)

    def _load_npy(self, path, dtype=None):
        arr = np.load(path, allow_pickle=False)
        if dtype is not None and arr.dtype != dtype:
            arr = arr.astype(dtype, copy=False)
        return arr

    def __getitem__(self, idx):
        feat_path = os.path.join(self.feature_dir, self.file_list[idx])
        features = self._load_npy(feat_path, dtype=self.feature_dtype)     # [T,F]
        features = torch.as_tensor(features, dtype=torch.float32)

        if self.label_dir:
            label_path = os.path.join(self.label_dir, self.label_list[idx])
            labels = self._load_npy(label_path)                             # [T] (per-frame senone id)
            labels = torch.as_tensor(labels, dtype=torch.long)
            return features, labels
        else:
            return features


In [9]:
# ─────────────────────────────────────────────────────────────────────────────
# 4) Model components
# ─────────────────────────────────────────────────────────────────────────────
def _mlp_stack(in_dim, hidden_dim, num_layers):
    layers = []
    for _ in range(num_layers):
        layers += [nn.Linear(in_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()]
        in_dim = hidden_dim
    return nn.Sequential(*layers), in_dim

def build_private_encoder(input_dim=1320, hidden_dim=512, num_layers=4):
    seq, _ = _mlp_stack(input_dim, hidden_dim, num_layers)
    return seq  # outputs [N, hidden_dim]

def build_shared_encoder(input_dim=1320, hidden_dim=1024, num_layers=6):
    seq, out_dim = _mlp_stack(input_dim, hidden_dim, num_layers)
    assert out_dim == hidden_dim
    return seq  # outputs [N, hidden_dim]

def build_decoder(bottleneck_dim=1024, output_dim=1320, num_layers=3):
    layers = []
    in_dim = bottleneck_dim
    for _ in range(num_layers):
        layers += [nn.Linear(in_dim, in_dim), nn.BatchNorm1d(in_dim), nn.ReLU()]
    layers += [nn.Linear(in_dim, output_dim)]
    return nn.Sequential(*layers)

def build_classifier(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        nn.Linear(1024, output_dim)
    )

def build_domain_classifier(input_dim, hidden_dim=256):
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, 2)  # 0=source, 1=target
    )


In [10]:
# ─────────────────────────────────────────────────────────────────────────────
# 5) Losses
# ─────────────────────────────────────────────────────────────────────────────
def diff_loss(private_feat, shared_feat):
    """
    L_diff: encourage orthogonality between private and shared spaces.
    private_feat, shared_feat: [N, Dp], [N, Ds]
    """
    # Use normalized Frobenius norm of cross-covariance
    batch = private_feat.size(0)
    return torch.norm(private_feat.T @ shared_feat, p='fro') / max(batch, 1)

def recon_loss(original, reconstructed):
    """L_recon: reconstruct input from (shared + private)."""
    return nn.functional.mse_loss(reconstructed, original, reduction="mean")


In [11]:
# ─────────────────────────────────────────────────────────────────────────────
# 6) DSN wrapper
# ─────────────────────────────────────────────────────────────────────────────
class DSN(nn.Module):
    def __init__(self, input_dim=1320, num_senones=3080, shared_dim=1024):
        super().__init__()
        # private encoders same output dim as shared encoder
        self.private_s = build_private_encoder(input_dim, shared_dim, num_layers=4)
        self.private_t = build_private_encoder(input_dim, shared_dim, num_layers=4)
        self.shared = build_shared_encoder(input_dim, shared_dim, num_layers=6)
        self.decoder = build_decoder(shared_dim, input_dim, num_layers=3)

        self.senone_classifier = build_classifier(shared_dim, num_senones)
        self.domain_classifier = build_domain_classifier(shared_dim)

    def forward(self, x, mode='source'):
        """
        x: [N, F]
        mode: 'source' or 'target'
        """
        private = self.private_s(x) if mode == 'source' else self.private_t(x)
        shared = self.shared(x)
        recon = self.decoder(shared + private)

        senone_logits = self.senone_classifier(shared) if mode == 'source' else None
        domain_logits = self.domain_classifier(shared)

        return {
            "private": private,
            "shared": shared,
            "recon": recon,
            "senone_logits": senone_logits,
            "domain_logits": domain_logits
        }


In [12]:
# ─────────────────────────────────────────────────────────────────────────────
# 7) DataLoaders & training
# ─────────────────────────────────────────────────────────────────────────────
def make_loader(dataset, batch_size=32, shuffle=True):
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        pin_memory=True, collate_fn=collate_fn
    )

def get_dataloaders(path1, path2, sanskrit_train, sanskrit_test, batch_size=32):
    src_train = LazyNPYDataset(path1, path2)     # (x,y)
    tgt_train = LazyNPYDataset(sanskrit_train)   # x only
    tgt_test  = LazyNPYDataset(sanskrit_test)    # x only

    return (
        make_loader(src_train, batch_size=batch_size, shuffle=True),
        make_loader(tgt_train, batch_size=batch_size, shuffle=True),
        make_loader(tgt_test,  batch_size=batch_size, shuffle=False),
    )

def flatten_batch_time(x_bt_f):
    """[B,T,F] -> [B*T, F]"""
    return x_bt_f.view(-1, x_bt_f.size(-1))

def flatten_labels(y_bt):
    """[B,T] -> [B*T]"""
    return y_bt.reshape(-1)

def train_dsn(model, src_loader, tgt_loader, num_epochs=20,
              alpha=0.25, beta=0.075, gamma=0.1, lr=0.01):
    """
    Total loss = L_cls + alpha*L_domain + beta*L_diff + gamma*L_recon
    """
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    ce = nn.CrossEntropyLoss(ignore_index=-100)  # ignore padded labels

    for epoch in range(1, num_epochs + 1):
        model.train()
        running = 0.0
        steps = 0

        # zip will stop at the shorter loader (OK for UDA)
        for (src_x, src_y, _), (tgt_x, _) in zip(src_loader, tgt_loader):
            src_x = src_x.to(device, non_blocking=True)  # [B,T,F]
            src_y = src_y.to(device, non_blocking=True)  # [B,T]
            tgt_x = tgt_x.to(device, non_blocking=True)  # [B,T,F]

            # flatten to per-frame
            src_xf = flatten_batch_time(src_x)  # [N,F]
            tgt_xf = flatten_batch_time(tgt_x)  # [N,F]
            src_yf = flatten_labels(src_y)      # [N]

            # forward
            out_s = model(src_xf, mode='source')
            out_t = model(tgt_xf, mode='target')

            # losses
            l_cls  = ce(out_s["senone_logits"], src_yf)
            dom_s  = torch.zeros(src_xf.size(0), dtype=torch.long, device=device)
            dom_t  = torch.ones (tgt_xf.size(0), dtype=torch.long, device=device)
            l_dom  = ce(out_s["domain_logits"], dom_s) + ce(out_t["domain_logits"], dom_t)
            l_diff = diff_loss(out_s["private"], out_s["shared"]) + diff_loss(out_t["private"], out_t["shared"])
            l_rec  = recon_loss(src_xf, out_s["recon"]) + recon_loss(tgt_xf, out_t["recon"])

            loss = l_cls + alpha*l_dom + beta*l_diff + gamma*l_rec

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running += loss.item()
            steps += 1

        print(f"Epoch {epoch:02d}/{num_epochs} | Avg loss: {running/max(steps,1):.4f}")


In [13]:
# ─────────────────────────────────────────────────────────────────────────────
# 8) Example usage (set your actual folders)
#    As you asked earlier: keep feature folder as path1 etc.
# ─────────────────────────────────────────────────────────────────────────────
# CHANGE THESE to your real paths:
path1 = r"C:\npy_feats\hindi"          # Hindi features (.npy per utterance)
path2 = r"C:\senone_labels"     # Matching senone labels (.npy per utterance)
sanskrit_train = r"C:\npy_feats\sanskrit_train" # Sanskrit train features (.npy per utt)
sanskrit_test  = r"C:\npy_feats\sanskrit_test"  # Sanskrit test features (.npy per utt)

batch_size = 16
src_loader, tgt_loader, tgt_test_loader = get_dataloaders(
    path1, path2, sanskrit_train, sanskrit_test, batch_size=batch_size
)

dsn_model = DSN(input_dim=1320, num_senones=3080).to(device)
train_dsn(dsn_model, src_loader, tgt_loader, num_epochs=20)

Epoch 01/20 | Avg loss: 8.0693
Epoch 02/20 | Avg loss: 3.2824
Epoch 03/20 | Avg loss: 2.9245
Epoch 04/20 | Avg loss: 2.7586
Epoch 05/20 | Avg loss: 2.5964
Epoch 06/20 | Avg loss: 2.4429
Epoch 07/20 | Avg loss: 2.3908
Epoch 08/20 | Avg loss: 2.2921
Epoch 09/20 | Avg loss: 2.2524
Epoch 10/20 | Avg loss: 2.1740
Epoch 11/20 | Avg loss: 2.1843
Epoch 12/20 | Avg loss: 2.1048
Epoch 13/20 | Avg loss: 2.0864
Epoch 14/20 | Avg loss: 2.0719
Epoch 15/20 | Avg loss: 2.0752
Epoch 16/20 | Avg loss: 2.0412
Epoch 17/20 | Avg loss: 1.9732
Epoch 18/20 | Avg loss: 1.9297
Epoch 19/20 | Avg loss: 1.9885
Epoch 20/20 | Avg loss: 1.8976


In [None]:

# ─────────────────────────────────────────────────────────────────────────────
# 9) Model evaluation: inference on test set + WER & CER
# ─────────────────────────────────────────────────────────────────────────────
import os
from jiwer import wer
import editdistance

# Path to transcript file (update this)
transcript_file = r"C:\Users\prasu\Desktop\filtered_transcripts.txt"

# Load reference transcripts
references = {}
with open(transcript_file, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split("|")
        if len(parts) == 2:
            filename, transcript = parts
            npy_filename = filename.replace(".m4a", ".npy")
            references[npy_filename] = transcript.lower()

# Run inference on test set
dsn_model.eval()
predictions = {}
with torch.no_grad():
    for feats, lengths, fnames in tgt_test_loader:
        feats = feats.to(device)
        xf = feats.view(-1, feats.size(-1))
        out = dsn_model(xf, mode='target')
        if out["senone_logits"] is None:
            continue

        pred_ids = out["senone_logits"].argmax(dim=-1).cpu().numpy().tolist()
        pred_str = " ".join(map(str, pred_ids))   # convert to space-separated string

        for fname in fnames:
            predictions[fname] = pred_str


# Align predictions and references
y_true, y_pred = [], []
for fname, ref_text in references.items():
    if fname in predictions:
        y_true.append(ref_text)
        y_pred.append(predictions[fname])

# Compute WER and CER
wer_score = wer(y_true, y_pred)

def cer(r, h):
    return editdistance.eval(r, h) / len(r) if len(r) > 0 else 0

cer_scores = [cer(r, h) for r, h in zip(y_true, y_pred)]
cer_score = sum(cer_scores) / len(cer_scores) if cer_scores else 0

print(f"WER: {wer_score:.3f}")
print(f"CER: {cer_score:.3f}")


WER: 0.000
CER: 0.000
