In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#### Private Encoder (For Source and Target Domain)

In [None]:
class PrivateEncoder(nn.Module):
    """Feed-forward encoder producing the domain-private representation.

    The network is ``num_layers`` repetitions of Linear -> BatchNorm1d -> ReLU.
    The first Linear maps ``input_dim`` to ``hidden_dim``; every later one maps
    ``hidden_dim`` to ``hidden_dim``.

    Args:
        input_dim: Width of the input feature vectors.
        hidden_dim: Width of every hidden layer and of the output.
        num_layers: Number of Linear/BatchNorm/ReLU groups.
    """

    def __init__(self, input_dim=1320, hidden_dim=512, num_layers=4):
        super().__init__()
        # widths[k] -> widths[k + 1] is the k-th affine map of the stack.
        widths = [input_dim] + [hidden_dim] * num_layers
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.ReLU()]

        self.encoder = nn.Sequential(*stack)

    def forward(self, x):
        """Encode a ``(batch, input_dim)`` tensor into ``(batch, hidden_dim)``."""
        encoded = self.encoder(x)
        return encoded

#### Shared Encoder

In [None]:
class SharedEncoder(nn.Module):
    """Feed-forward encoder producing the domain-shared representation.

    Built from ``num_layers`` groups of Linear -> BatchNorm1d -> ReLU. Only the
    first group reads ``input_dim`` features; all others are
    ``hidden_dim`` -> ``hidden_dim``.

    Args:
        input_dim: Width of the input feature vectors.
        hidden_dim: Width of every hidden layer and of the output.
        num_layers: Number of Linear/BatchNorm/ReLU groups.
    """

    def __init__(self, input_dim=1320, hidden_dim=1024, num_layers=6):
        super().__init__()
        modules = []
        fan_in = input_dim
        for _ in range(num_layers):
            modules.extend(
                (nn.Linear(fan_in, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU())
            )
            # Every group after the first consumes the previous group's output.
            fan_in = hidden_dim

        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        """Encode a ``(batch, input_dim)`` tensor into ``(batch, hidden_dim)``."""
        encoded = self.encoder(x)
        return encoded

#### Senone Classifier

In [None]:
class SenoneClassifier(nn.Module):
    """MLP head that maps encoder features to senone logits.

    Two hidden groups of Linear -> BatchNorm1d -> ReLU are followed by a final
    Linear layer that emits ``output_dim`` unnormalised scores (no softmax).

    Args:
        input_dim: Width of the incoming feature vectors.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of output classes.
    """

    def __init__(self, input_dim=1024, hidden_dim=1024, output_dim=3080):
        super().__init__()
        hidden_stack = []
        for fan_in in (input_dim, hidden_dim):
            hidden_stack.extend(
                [nn.Linear(fan_in, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()]
            )
        output_layer = nn.Linear(hidden_dim, output_dim)
        self.classifier = nn.Sequential(*hidden_stack, output_layer)

    def forward(self, x):
        """Return logits of shape ``(batch, output_dim)``."""
        logits = self.classifier(x)
        return logits

#### Domain Classifier

In [None]:
class DomainClassifier(nn.Module):
    """Small MLP that scores which domain a feature vector came from.

    One hidden group (Linear -> BatchNorm1d -> ReLU) followed by a Linear layer
    with two outputs, one logit per domain (source or target).

    Args:
        input_dim: Width of the incoming feature vectors.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, input_dim=1024, hidden_dim=256):
        super().__init__()
        trunk = [nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()]
        domain_head = nn.Linear(hidden_dim, 2)  # source or target
        self.domain_classifier = nn.Sequential(*trunk, domain_head)

    def forward(self, x):
        """Return domain logits of shape ``(batch, 2)``."""
        domain_logits = self.domain_classifier(x)
        return domain_logits

#### Shared Decoder

In [None]:
class SharedDecoder(nn.Module):
    """MLP that reconstructs input features from an encoded representation.

    Three hidden groups of Linear -> BatchNorm1d -> ReLU are followed by a
    final Linear projection to ``output_dim`` (no output non-linearity).

    Args:
        input_dim: Width of the encoded representation fed to the decoder.
        hidden_dim: Width of the three hidden layers.
        output_dim: Width of the reconstructed feature vectors.
    """

    def __init__(self, input_dim=1024, hidden_dim=1024, output_dim=1320):
        super().__init__()
        hidden_stack = []
        for fan_in in (input_dim, hidden_dim, hidden_dim):
            hidden_stack.append(nn.Linear(fan_in, hidden_dim))
            hidden_stack.append(nn.BatchNorm1d(hidden_dim))
            hidden_stack.append(nn.ReLU())
        projection = nn.Linear(hidden_dim, output_dim)
        self.decoder = nn.Sequential(*hidden_stack, projection)

    def forward(self, x):
        """Return a reconstruction of shape ``(batch, output_dim)``."""
        reconstruction = self.decoder(x)
        return reconstruction

#### DSN Model Wrapper

In [None]:
class _GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by ``-scale`` in backward.

    Placed between the shared encoder and the domain classifier so that the
    classifier still learns to tell domains apart while the encoder receives
    the opposite gradient and is pushed towards domain-invariant features.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # `scale` is a plain number, so it gets no gradient (trailing None).
        return grad_output.neg() * ctx.scale, None


class DSN(nn.Module):
    """Domain Separation Network wrapper.

    Bundles one shared encoder, one private encoder per domain, a senone
    classifier, a domain classifier and a decoder, all with their default
    layer sizes.
    """

    def __init__(self):
        super(DSN, self).__init__()
        self.shared_encoder = SharedEncoder()
        self.private_encoder_source = PrivateEncoder()
        self.private_encoder_target = PrivateEncoder()
        self.senone_classifier = SenoneClassifier()
        self.domain_classifier = DomainClassifier()
        self.shared_decoder = SharedDecoder()

    def forward(self, x, domain='source', mode='train', grl_lambda=1.0):
        """Run the network on a batch of feature vectors.

        Args:
            x: Input features, fed to both the shared and the private encoder.
            domain: ``'source'`` selects the source private encoder; any other
                value selects the target private encoder. Ignored in
                inference mode.
            mode: ``'train'`` returns every intermediate output needed by the
                losses; ``'inference'`` returns only the senone logits.
            grl_lambda: Strength of the gradient reversal applied in front of
                the domain classifier (``0`` blocks the domain gradient from
                reaching the shared encoder).

        Returns:
            In ``'train'`` mode the tuple
            ``(private_feat, shared_feat, recon, senone_out, domain_out)``;
            in ``'inference'`` mode just ``senone_out``.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'inference'``.
        """
        if mode not in ('train', 'inference'):
            # Previously an unknown mode silently returned None.
            raise ValueError(f"mode must be 'train' or 'inference', got {mode!r}")

        shared_feat = self.shared_encoder(x)
        senone_out = self.senone_classifier(shared_feat)
        if mode == 'inference':
            # Only the shared path is needed to predict senones; skip the rest.
            return senone_out

        private_encoder = self.private_encoder_source if domain == 'source' else self.private_encoder_target
        private_feat = private_encoder(x)
        recon = self.shared_decoder(shared_feat)
        # Reverse the gradient so that minimising the domain loss trains the
        # domain classifier normally but trains the shared encoder adversarially.
        reversed_feat = _GradientReversal.apply(shared_feat, grl_lambda)
        domain_out = self.domain_classifier(reversed_feat)
        return private_feat, shared_feat, recon, senone_out, domain_out

### Loss Function and Training Setup

In [None]:
# Weights of the auxiliary objectives in the total loss (senone loss has weight 1).
beta = 0.25  # reconstruction loss
gamma = 0.075  # domain classification loss
delta = 0.1  # difference loss

criterion_recon = nn.MSELoss()
criterion_domain = nn.CrossEntropyLoss()
criterion_senone = nn.CrossEntropyLoss()

def difference_loss(private, shared):
    """Soft orthogonality penalty between private and shared representations.

    Each row of both inputs is L2-normalised, then the cross-correlation
    ``shared^T @ private`` is formed over the batch and its squared Frobenius
    norm is returned, divided by ``batch**2`` so the value lies in ``[0, 1]``
    regardless of batch size. The penalty is zero when every shared feature
    dimension is uncorrelated (over the batch) with every private one.

    Unlike an element-wise product, this works when the two representations
    have different widths (e.g. 512-d private vs. 1024-d shared features).

    Args:
        private: Tensor of shape ``(batch, d_private)``.
        shared: Tensor of shape ``(batch, d_shared)``.

    Returns:
        Scalar tensor holding the penalty.

    Raises:
        ValueError: If the two inputs have different batch sizes.
    """
    if private.size(0) != shared.size(0):
        raise ValueError(
            f"batch size mismatch: private has {private.size(0)} rows, shared has {shared.size(0)}"
        )
    n_samples = private.size(0)
    private_unit = F.normalize(private.flatten(start_dim=1), dim=1)
    shared_unit = F.normalize(shared.flatten(start_dim=1), dim=1)
    # (d_shared, d_private) cross-correlation; max() guards an empty batch.
    cross_corr = (shared_unit.t() @ private_unit) / max(n_samples, 1)
    return cross_corr.pow(2).sum()

In [None]:
# Instantiate the DSN with the default layer sizes of all sub-networks.
# NOTE: the training cell further down creates a fresh `DSN()` and moves it to
# `device`, so this CPU instance is replaced before any training happens.
model = DSN()

source_feats.npy: shape (15000, 1320)

source_labels.npy: shape (15000,) — senone labels

target_feats.npy: shape (2837, 1320) — target domain has no labels



In [None]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Mini-batch size shared by both domain loaders.
batch_size = 128

# Kaldi-extracted arrays: features become float32 tensors, senone targets int64 class indices.
source_feats = torch.tensor(np.load("source_feats.npy"), dtype=torch.float32)     # expected (15000, 1320)
source_labels = torch.tensor(np.load("source_labels.npy"), dtype=torch.long)      # expected (15000,)
target_feats = torch.tensor(np.load("target_feats.npy"), dtype=torch.float32)     # expected (2837, 1320)

# The target domain is unlabeled; a zero placeholder per row lets both datasets
# yield (features, label) pairs so they can be iterated side by side.
placeholder_labels = torch.zeros(target_feats.size(0))
source_dataset = TensorDataset(source_feats, source_labels)
target_dataset = TensorDataset(target_feats, placeholder_labels)

# Shuffled loaders that drop the trailing partial batch.
source_loader, target_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    for dataset in (source_dataset, target_dataset)
)

## TRAINING LOOP

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DSN().to(device)

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# step_size is expressed in optimizer steps (x0.95 every 20000 mini-batches),
# so the scheduler is advanced once per batch, not once per epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20000, gamma=0.95)

num_epochs = 20

if len(source_loader) == 0 or len(target_loader) == 0:
    raise ValueError("source_loader and target_loader must each yield at least one batch")


def _endless(loader):
    """Yield batches from `loader` forever, restarting it whenever it runs out."""
    while True:
        for batch in loader:
            yield batch


# The target set is much smaller than the source set. zip() would end every
# epoch after the last target batch and skip most source batches, so the
# target loader is cycled while one epoch covers the whole source loader.
target_batches = _endless(target_loader)

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for source_x, source_y in source_loader:
        target_x, _ = next(target_batches)
        source_x, source_y = source_x.to(device), source_y.to(device)
        target_x = target_x.to(device)

        # Domain labels: 0 = source, 1 = target.
        domain_y_src = torch.zeros(source_x.size(0), dtype=torch.long, device=device)
        domain_y_tgt = torch.ones(target_x.size(0), dtype=torch.long, device=device)

        # Forward for source (the only domain with senone labels)
        priv_src, shared_src, recon_src, senone_out, domain_out_src = model(source_x, domain='source', mode='train')
        loss_senone = criterion_senone(senone_out, source_y)
        loss_domain_src = criterion_domain(domain_out_src, domain_y_src)
        loss_recon_src = criterion_recon(recon_src, source_x)

        # Forward for target (senone output is unused: no labels available)
        priv_tgt, shared_tgt, recon_tgt, _, domain_out_tgt = model(target_x, domain='target', mode='train')
        loss_domain_tgt = criterion_domain(domain_out_tgt, domain_y_tgt)
        loss_recon_tgt = criterion_recon(recon_tgt, target_x)

        # Difference loss between private and shared representations
        loss_diff_src = difference_loss(priv_src, shared_src)
        loss_diff_tgt = difference_loss(priv_tgt, shared_tgt)

        # Total loss
        loss = loss_senone \
               + beta * (loss_recon_src + loss_recon_tgt) \
               + gamma * (loss_domain_src + loss_domain_tgt) \
               + delta * (loss_diff_src + loss_diff_tgt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs} - mean loss: {running_loss / len(source_loader):.4f}")

# Persist only the parameters (state_dict); the inference cells reload them into `model`.
torch.save(model.state_dict(), "dsn_model.pth")
print("Model saved to dsn_model.pth")

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Held-out features to decode, stored as float32 (expected shape: (558, 1320)).
test_feats = torch.tensor(np.load("test_feats.npy"), dtype=torch.float32)

# Inference needs no labels, so the dataset wraps a single tensor and every
# item is a 1-tuple containing one feature row.
test_dataset = TensorDataset(test_feats)

# No shuffling: predictions must stay aligned with the rows of test_feats.
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=128)

In [None]:
# Load the saved model weights before inference.
# Depends on `model` and `device` created in the training cell, so that cell
# must have been run in this kernel first.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust
# (recent PyTorch versions offer `weights_only=True` for plain state_dicts).
model.load_state_dict(torch.load("dsn_model.pth", map_location=device))

In [None]:
# Put BatchNorm layers in evaluation mode before predicting.
model.eval()
batch_predictions = []

# Gradients are not needed for decoding.
with torch.no_grad():
    for (features,) in test_loader:
        senone_logits = model(features.to(device), mode='inference')
        # Most likely senone per frame, moved back to the CPU for saving.
        batch_predictions.append(senone_logits.argmax(dim=1).cpu())

# One 1-D tensor of class indices, in the same order as the test features.
all_preds = torch.cat(batch_predictions)
np.save("test_predictions.npy", all_preds.numpy())
print("Inference predictions saved to test_predictions.npy")