In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Define input and output dimensions.
# nn.Linear stores its weight as (out_features, in_features).
in_features = 64
out_features = 32

# Fraction of weights to zero out. An exact number of entries is zeroed
# (rounded to the nearest integer) instead of sampling an independent
# Bernoulli mask, so the realized sparsity does not drift from run to run.
sparsity = 0.9

# Create a linear layer
linear = nn.Linear(in_features, out_features)

# Dense snapshot of the freshly initialized weights. Cloned so that it stays
# dense after the layer's parameter is overwritten in place below.
weights = linear.weight.detach().clone()

# Pick exactly `num_keep` random positions to keep; every other entry is zeroed.
num_weights = weights.numel()
num_keep = round(num_weights * (1 - sparsity))
keep_idx = torch.randperm(num_weights, device=weights.device)[:num_keep]
mask = torch.zeros(num_weights, dtype=torch.bool, device=weights.device)
mask[keep_idx] = True
mask = mask.view_as(weights)
sparse_weights = weights * mask

# Update the weights. Writing through `.data` bypasses autograd's checks;
# an in-place copy under no_grad() is the supported way to overwrite
# parameter values while keeping the same Parameter object.
with torch.no_grad():
    linear.weight.copy_(sparse_weights)

# Convert weights to a host-side NumPy array for plotting. detach()/cpu()
# keep this working even if the tensor tracks gradients or is not on the CPU.
weight_matrix = sparse_weights.detach().cpu().numpy()

# Measure the sparsity actually present in the matrix instead of hard-coding
# the target, so the title always describes what is drawn.
shown_sparsity = (weight_matrix == 0).mean() * 100

# Create a figure with a larger size
plt.figure(figsize=(10, 6))

# Create heatmap
sns.heatmap(weight_matrix,
            cmap='RdBu_r',  # Red-Blue diverging colormap: sign of each weight is visible
            center=0,       # Zero maps to the colormap's neutral midpoint
            cbar_kws={'label': 'Weight Value'},
            xticklabels=10, # Label every 10th column
            yticklabels=10) # Label every 10th row

plt.title(f'Sparse Weight Matrix ({shown_sparsity:.0f}% Sparsity)')
plt.xlabel('Input Features')
plt.ylabel('Output Features')

plt.show()

# Report the realized sparsity of the masked weight matrix.
# Zeros are counted directly on the tensor, so a weight that happened to be
# exactly zero before masking would also be included in the zero count.
total_weights = weights.numel()
zero_weights = (sparse_weights == 0).sum().item()
sparsity_percentage = zero_weights / total_weights * 100

# One print call, one statistic per line.
print(
    f"Total weights: {total_weights}",
    f"Zero weights: {zero_weights}",
    f"Sparsity percentage: {sparsity_percentage:.2f}%",
    sep="\n",
)