In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from scipy.optimize import linear_sum_assignment

# Simple Model Definition
class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return self.fc(x)

# Pruning Function: pruning based on magnitude
def prune_magnitude(model, prune_n=0, prune_m=0, sparsity_ratio=0.2):
    W = model.fc.weight.data
    W_metric = torch.abs(W)
    
    if prune_n != 0:
        W_mask = (torch.zeros_like(W) == 1)
        for ii in range(W_metric.shape[1]):
            if ii % prune_m == 0:
                tmp = W_metric[:, ii:(ii + prune_m)].float()
                _, indices = torch.topk(tmp, prune_n, dim=1, largest=False)
                W_mask.scatter_(1, indices, True)
    else:
        # Apply pruning based on the sparsity ratio
        thresh = torch.sort(W_metric.flatten())[0][int(W.numel() * sparsity_ratio)].cpu()
        W_mask = (W_metric <= thresh)

    W[W_mask] = 0

# Gumbel Sinkhorn (used to generate learnable permutation matrix)
def gumbel_sinkhorn(log_alpha, n_iters, tau, epsilon=1e-8):
    # Sample Gumbel noise
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(log_alpha) + epsilon) + epsilon)
    # Add noise to the logits
    noisy_logits = (log_alpha + 1 * gumbel_noise) / tau
    # Initialize the permutation matrix
    S = torch.exp(noisy_logits)
    for _ in range(n_iters):
        # Row normalization
        S = S / (S.sum(dim=1, keepdim=True) + epsilon)
        # Column normalization
        S = S / (S.sum(dim=0, keepdim=True) + epsilon)
    return S
# def gumbel_sinkhorn(log_alpha, n_iters, tau, noise_factor=1.0, epsilon=1e-8):
#     # Sample Gumbel noise
#     gumbel_noise = -torch.log(-torch.log(torch.rand_like(log_alpha) + epsilon) + epsilon)
#     # Add noise to the logits
#     noisy_logits = (log_alpha + noise_factor * gumbel_noise) / tau
#     # Initialize the permutation matrix
#     S = torch.exp(noisy_logits)
#     for _ in range(n_iters):
#         # Row normalization
#         S = S / (S.sum(dim=1, keepdim=True) + epsilon)
#         # Column normalization
#         S = S / (S.sum(dim=0, keepdim=True) + epsilon)
#     return S
# Learnable Permutation Matrix: defined as a soft parameter to be trained
class PermutationLayer(nn.Module):
    def __init__(self, input_dim, sinkhorn_iterations=10, tau=0.1, epsilon=1e-8):
        super(PermutationLayer, self).__init__()
        self.sinkhorn_iterations = sinkhorn_iterations
        self.tau = tau
        self.epsilon = epsilon

    def forward(self, W):
        # Step 1: Apply Gumbel Sinkhorn to get S_soft (soft permutation)
        S_soft = gumbel_sinkhorn(W, self.sinkhorn_iterations, self.tau, epsilon=self.epsilon)
        S_soft = torch.clamp(S_soft, min=1e-8, max=1 - 1e-8)
        
        # Step 2: Compute the hard permutation matrix P_hard
        with torch.no_grad():
            S_cpu = S_soft.detach().cpu().numpy()
            if not np.isfinite(S_cpu).all():
                print("Invalid values in S_cpu")
                print(S_cpu)
                exit()
            row_ind, col_ind = linear_sum_assignment(-S_cpu)
            P_hard = torch.zeros_like(S_soft)
            P_hard[row_ind, col_ind] = 1.0

        # Modify P_hard to allow gradient flow
        P_hard = (P_hard - S_soft).detach() + S_soft

        P_hard = P_hard.to(W.dtype)
        
        # Permute the columns of W using P_hard
        W_perm = torch.matmul(W, P_hard)
        return W_perm

# Initialize model, loss function, and optimizer
input_dim = 10
output_dim = 5
model = SimpleModel(input_dim, output_dim).cuda()
permutation_layer = PermutationLayer(input_dim).cuda()

# Use SGD optimizer
optimizer = optim.SGD(list(model.parameters()) + list(permutation_layer.parameters()), lr=0.01)

# Generate random input and target data
input_data = torch.randn(8, input_dim).cuda()  # batch size = 8
target_data = torch.randn(8, output_dim).cuda()

# Training Loop
num_epochs = 100
prune_n = 2
prune_m = 2
sparsity_ratio = 0.2

for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass before pruning and permutation
    output_before_pruning = model(input_data)

    # Apply pruning to the model
    prune_magnitude(model, prune_n=prune_n, prune_m=prune_m, sparsity_ratio=sparsity_ratio)

    # Forward pass after pruning (without permutation)
    output_after_pruning = model(input_data)

    # Apply permutation to the pruned weights
    pruned_weights_with_permutation = permutation_layer(model.fc.weight.data)

    # Set the model's pruned weights with permutation
    model.fc.weight.data = pruned_weights_with_permutation

    # Forward pass after pruning and permutation
    output_after_permutation = model(input_data)

    # Loss: difference between outputs before and after pruning and permutation
    loss = torch.mean(torch.abs(output_after_permutation - output_before_pruning))

    # Backpropagation
    loss.backward()
    optimizer.step()

    # Print loss every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.6f}')

# Final learned permutation matrix
print("Learned permutation matrix P:")
print(permutation_layer.permutation_matrix)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (5x10 and 5x10)