In [None]:
import torch


def svd_internal_dimensionality_reduction(tensor, num_components):
    """
    Rank-``num_components`` approximation of ``tensor`` via a truncated SVD.

    Unlike classic SVD dimensionality reduction, the kept low-rank factors are
    multiplied back together, so the result has the same shape as the input;
    only its rank (the "internal" dimensionality) is reduced.

    Args:
        tensor: floating-point (or complex) tensor of shape (..., m, n); any
            leading dimensions are treated as a batch of matrices.
        num_components: number of leading singular triplets to keep. Values
            larger than min(m, n) keep all of them; 0 yields a zero tensor.

    Returns:
        Tensor of shape (..., m, n) whose trailing matrices have rank at most
        ``num_components``.

    Raises:
        ValueError: if ``num_components`` is negative.
    """
    if num_components < 0:
        raise ValueError(f"num_components must be non-negative, got {num_components}")

    # Reduced SVD: u (..., m, r), s (..., r), vh (..., r, n) with r = min(m, n).
    # torch.linalg.svd replaces the deprecated torch.svd and returns V^H directly,
    # which is also correct for complex inputs.
    u, s, vh = torch.linalg.svd(tensor, full_matrices=False)
    k = num_components
    # Scale each kept left singular vector (column of u) by its singular value,
    # then map back with the matching rows of V^H.
    return (u[..., :k] * s[..., None, :k]) @ vh[..., :k, :]


In [None]:
# Fix the RNG so the singular values shown below (and later random results)
# are reproducible under Restart & Run All.
torch.manual_seed(0)

# Product of a (10, 3) and a (3, 10) factor: a 10x10 matrix of rank at most 3.
A = torch.randn(10, 3)
B = torch.randn(3, 10)
C_low_rank = A @ B

# A generic Gaussian 10x10 matrix (full rank with probability 1).
C_full_rank = torch.randn(10, 10)

In [None]:
# Keep only 3 singular triplets of the full-rank matrix; the shape is unchanged.
C_full_rank_svd = svd_internal_dimensionality_reduction(C_full_rank, num_components=3)
C_full_rank_svd.size()

In [None]:
# Singular values of the rank-3 reconstruction: all but the first 3 should be
# numerically zero. torch.linalg.svdvals replaces the deprecated torch.svd.
torch.linalg.svdvals(C_full_rank_svd)

In [None]:
# Singular values of the genuinely rank-3 product, for comparison with the
# reconstruction above. torch.linalg.svdvals replaces the deprecated torch.svd.
torch.linalg.svdvals(C_low_rank)

In [None]:
import torch

def random_projection_dim_reduction(tensor, target_dim):
    """
    Random-projection (Johnson-Lindenstrauss style) reduction of the last dimension.

    The last dimension of ``tensor`` is projected onto ``target_dim`` random
    Gaussian directions (each normalised to unit length) and then mapped back
    with the transpose of the same matrix. The result therefore keeps the shape
    of the input while its rank along the last dimension is at most
    ``target_dim`` -- only the inner dimensionality is reduced.

    Args:
        tensor: tensor of shape (..., d).
        target_dim: number of random directions to project onto.

    Returns:
        Tensor with the same shape and dtype as ``tensor``.
    """
    original_dtype = tensor.dtype
    original_shape = tensor.shape
    # Work in float32 regardless of the input dtype; cast back at the end.
    tensor = tensor.to(dtype=torch.float32)

    # (d, target_dim) Gaussian matrix with unit-norm columns.
    random_matrix = torch.randn(tensor.shape[-1], target_dim, dtype=torch.float32, device=tensor.device)
    random_matrix /= torch.norm(random_matrix, dim=0, keepdim=True)

    # Project down to target_dim coordinates, then lift back into the original
    # d-dimensional space. Without the lift the result would be
    # (..., target_dim) and the shape check below would fail.
    reduced = torch.matmul(tensor, random_matrix)
    new_matrix = torch.matmul(reduced, random_matrix.T).to(dtype=original_dtype)
    assert new_matrix.shape == original_shape
    return new_matrix


In [None]:
# 100 samples with 10 features, reduced to 2 random directions.
A = torch.randn((100, 10))
B = random_projection_dim_reduction(A, target_dim=2)

In [None]:
# Inspect the shape of the projected tensor.
B.size()

In [None]:
import torch
import matplotlib.pyplot as plt

@torch.no_grad()
def random_pruning(tensor, prune_ratio):
    """
    Randomly zero out a fraction of the entries of ``tensor``.

    Each entry is independently zeroed with probability ``prune_ratio``. The
    result has the same shape, dtype and device as the input: entries are
    zeroed, not removed.

    Args:
        tensor: input tensor of any shape (integer or floating point).
        prune_ratio: probability in [0, 1] that an entry is zeroed.

    Returns:
        A new tensor with the pruned entries set to 0.

    Raises:
        ValueError: if ``prune_ratio`` is outside [0, 1].
    """
    if not 0.0 <= prune_ratio <= 1.0:
        raise ValueError(f"prune_ratio must be in [0, 1], got {prune_ratio}")
    # torch.rand (unlike rand_like) also works for integer tensors. Samples lie
    # in [0, 1), so `>=` keeps every entry at ratio 0 and none at ratio 1.
    keep_mask = torch.rand(tensor.shape, device=tensor.device) >= prune_ratio
    # masked_fill rather than multiplying by the mask: inf * 0 would give NaN.
    return tensor.masked_fill(~keep_mask, 0)

# Create a 2D tensor with random values
tensor = torch.rand((10, 10))

# Define a list of pruning ratios
prune_ratios = [0.1, 0.3, 0.5, 0.7, 0.9]

# Use one colour scale for every panel (including the 0 written into pruned
# entries); otherwise imshow normalises each image separately and the panels
# cannot be compared.
vmin = min(tensor.min().item(), 0.0)
vmax = max(tensor.max().item(), 0.0)

# Initialize a figure: the original tensor plus one panel per pruning ratio
fig, axs = plt.subplots(1, len(prune_ratios) + 1, figsize=(20, 5))

# Plot the original tensor
image = axs[0].imshow(tensor.numpy(), cmap='viridis', vmin=vmin, vmax=vmax)
axs[0].set_title('Original Tensor')

# Apply pruning for each ratio and plot the resulting tensors
for ax, prune_ratio in zip(axs[1:], prune_ratios):
    pruned_tensor = random_pruning(tensor.clone(), prune_ratio)
    ax.imshow(pruned_tensor.numpy(), cmap='viridis', vmin=vmin, vmax=vmax)
    ax.set_title(f'Pruned Tensor (ratio = {prune_ratio})')

# One colorbar shared by all panels, since they use the same scale
fig.colorbar(image, ax=axs, shrink=0.8)

# Display the plot
plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt

@torch.no_grad()
def magnitude_pruning(tensor, prune_ratio):
    """
    Zero out the ``prune_ratio`` fraction of entries with the smallest magnitude.

    ``round(prune_ratio * tensor.numel())`` entries are targeted; entries whose
    magnitude ties with the largest pruned magnitude are pruned as well. The
    shape, dtype and device are unchanged: entries are zeroed, not removed.

    Args:
        tensor: input tensor of any shape (integer or floating point).
        prune_ratio: fraction in [0, 1] of entries to zero.

    Returns:
        A new tensor with the pruned entries set to 0.

    Raises:
        ValueError: if ``prune_ratio`` is outside [0, 1].
    """
    if not 0.0 <= prune_ratio <= 1.0:
        raise ValueError(f"prune_ratio must be in [0, 1], got {prune_ratio}")

    num_to_prune = round(prune_ratio * tensor.numel())
    if num_to_prune == 0:
        # Nothing to prune (this also covers empty tensors). Comparing against
        # the minimum magnitude would otherwise still drop the smallest entry.
        return tensor.clone()

    tensor_magnitude = torch.abs(tensor)
    # kthvalue gives the k-th smallest magnitude exactly. Unlike torch.quantile
    # it does not interpolate, accepts integer tensors and is not subject to
    # quantile's input-size limit.
    threshold = tensor_magnitude.flatten().kthvalue(num_to_prune).values
    # masked_fill rather than multiplying by a 0/1 mask: inf * 0 would give NaN.
    return tensor.masked_fill(tensor_magnitude <= threshold, 0)

# Create a 2D tensor with random values
tensor = torch.rand((10, 10))

# Define a list of pruning ratios
prune_ratios = [0.1, 0.3, 0.5, 0.7, 0.9]

# Use one colour scale for every panel (including the 0 written into pruned
# entries); otherwise imshow normalises each image separately and the panels
# cannot be compared.
vmin = min(tensor.min().item(), 0.0)
vmax = max(tensor.max().item(), 0.0)

# Initialize a figure: the original tensor plus one panel per pruning ratio
fig, axs = plt.subplots(1, len(prune_ratios) + 1, figsize=(20, 5))

# Plot the original tensor
image = axs[0].imshow(tensor.numpy(), cmap='viridis', vmin=vmin, vmax=vmax)
axs[0].set_title('Original Tensor')

# Apply pruning for each ratio and plot the resulting tensors
for ax, prune_ratio in zip(axs[1:], prune_ratios):
    pruned_tensor = magnitude_pruning(tensor.clone(), prune_ratio)
    ax.imshow(pruned_tensor.numpy(), cmap='viridis', vmin=vmin, vmax=vmax)
    ax.set_title(f'Pruned Tensor (ratio = {prune_ratio})')

# One colorbar shared by all panels, since they use the same scale
fig.colorbar(image, ax=axs, shrink=0.8)

# Display the plot
plt.show()
