In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Corrected DFSMN Layer
class DFSMNLayer(nn.Module):
    def __init__(self, input_dim, output_dim, left_context=8, right_context=2):
        super(DFSMNLayer, self).__init__()
        self.left_context = left_context
        self.right_context = right_context
        self.linear1 = nn.Linear(input_dim, output_dim)
        self.linear2 = nn.Linear(input_dim * (left_context + right_context + 1), output_dim)

    def forward(self, x):
        batch_size, seq_len, feature_dim = x.size()
        
        # Context padding
        padded_x = F.pad(x, (0, 0, self.left_context, self.right_context))
        
        # Collect left and right context
        left_context = [padded_x[:, t: t + seq_len] for t in range(self.left_context)]
        right_context = [padded_x[:, t + self.left_context + 1: t + self.left_context + 1 + seq_len] for t in range(self.right_context)]
        
        # Combine contexts and the central frame
        context_combined = torch.cat(left_context + [x] + right_context, dim=-1)
        
        # Apply linear layers
        out1 = self.linear1(x)
        out2 = self.linear2(context_combined)
        
        # Combine outputs
        out = out1 + out2
        return out


# DFSMN Encoder
class DFSMNEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(DFSMNEncoder, self).__init__()
        self.layers = nn.ModuleList([DFSMNLayer(input_dim if i == 0 else hidden_dim, hidden_dim) for i in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
class Conv1dPredictor(nn.Module):
    def __init__(self, input_dim, output_dim, context_size=4):
        super(Conv1dPredictor, self).__init__()
        self.context_size = context_size
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=output_dim, kernel_size=context_size, stride=1)

    def forward(self, y):
        y = y.float()  # Ensure input is of type float
        y = F.pad(y, (self.context_size - 1, 0))
        if y.dim() == 2:  # If y is 2D, add a batch dimension
            y = y.unsqueeze(0)
        y = y.transpose(1, 2)  # Change shape to [batch_size, input_dim, sequence_length]
        y = self.conv1d(y)
        y = y.transpose(1, 2)  # Change shape back to [batch_size, sequence_length, output_dim]
        return y


# Joint Network
class JointNetwork(nn.Module):
    def __init__(self, encoder_dim, predictor_dim, joint_dim):
        super(JointNetwork, self).__init__()
        self.linear_enc = nn.Linear(encoder_dim, joint_dim)
        self.linear_pred = nn.Linear(predictor_dim, joint_dim)
        self.linear_out = nn.Linear(joint_dim, encoder_dim)

    def forward(self, enc_out, pred_out):
        joint_out = torch.tanh(self.linear_enc(enc_out) + self.linear_pred(pred_out))
        joint_out = self.linear_out(joint_out)
        return joint_out

# Transducer Model
class Transducer(nn.Module):
    def __init__(self, encoder, predictor, joint_network):
        super(Transducer, self).__init__()
        self.encoder = encoder
        self.predictor = predictor
        self.joint_network = joint_network

    def forward(self, x, y):
        enc_out = self.encoder(x)
        pred_out = self.predictor(y)
        joint_out = self.joint_network(enc_out, pred_out)
        return joint_out

# Phone Synchronous Decoding (PSD) with Blank Skipping
def psd_decoding(transducer_model, x, blank_threshold=0.95):
    enc_out = transducer_model.encoder(x)
    y_in = torch.zeros((enc_out.size(0), transducer_model.predictor.context_size, enc_out.size(2)), device=x.device)
    pred_out = transducer_model.predictor(y_in)
    
    blank_deweight_value = 0.1
    w = []

    for t in range(enc_out.size(1)):
        joint_out = transducer_model.joint_network(enc_out[:, t, :], pred_out[:, -1, :])
        log_probs = F.log_softmax(joint_out, dim=-1)
        blank_prob = log_probs[:, 0].exp()
        if blank_prob < blank_threshold:
            pred_out = torch.cat([pred_out, joint_out.unsqueeze(1)], dim=1)
            w.append(torch.argmax(log_probs, dim=-1).item())
        else:
            log_probs[:, 0] -= blank_deweight_value

    return w

# Training Function
def train_model(model, train_loader, criterion, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            outputs = model(x, y)
            loss = criterion(outputs.view(-1, outputs.size(-1)), y.view(-1))
            loss.backward()
            optimizer.step()

# Model Compression with SVD
def apply_svd_to_model(model, rank=50):
    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() == 2:
            U, S, V = torch.svd(param)
            model.state_dict()[name].copy_(torch.mm(U[:, :rank], torch.mm(torch.diag(S[:rank]), V[:, :rank].t())))
    return model

# Evaluation Function
def evaluate_model(model, test_loader, criterion):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        for x, y in test_loader:
            outputs = model(x, y)
            loss = criterion(outputs.view(-1, outputs.size(-1)), y.view(-1))
            total_loss += loss.item()
    return total_loss / len(test_loader)

# Example usage with random data (replace with actual data loaders)
input_dim = 40  # PNCC features dimension
hidden_dim = 400
num_layers = 8
joint_dim = 100

# Instantiate model components
encoder = DFSMNEncoder(input_dim, hidden_dim, num_layers)
predictor = Conv1dPredictor(hidden_dim, joint_dim)
joint_network = JointNetwork(hidden_dim, joint_dim, joint_dim)

# Instantiate the transducer model
transducer_model = Transducer(encoder, predictor, joint_network)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transducer_model.parameters(), lr=0.0005)

# Dummy data loaders for demonstration (replace with actual data)
train_loader = [(torch.rand(32, 100, input_dim), torch.randint(0, 10, (32, 100))) for _ in range(100)]
test_loader = [(torch.rand(32, 100, input_dim), torch.randint(0, 10, (32, 100))) for _ in range(20)]

# Train the model
train_model(transducer_model, train_loader, criterion, optimizer, num_epochs=10)

# Compress the transducer model
transducer_model = apply_svd_to_model(transducer_model)

# Fine-tune the compressed model
train_model(transducer_model, train_loader, criterion, optimizer, num_epochs=5)

# Evaluate the fine-tuned model
test_loss = evaluate_model(transducer_model, test_loader, criterion)
print(f'Test Loss: {test_loss}')


RuntimeError: Given groups=1, weight of size [100, 400, 4], expected input[1, 103, 32] to have 400 channels, but got 103 channels instead