In [17]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torch.nn.functional as F
import math
import os
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import copy

In [18]:
# Global hyper-parameters shared by the sampling code below.
IMG_SIZE = 64      # side length (pixels) of the square images that are generated
BATCH_SIZE = 128
T = 300            # number of diffusion timesteps

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [19]:
# Model definition (same as your original code)
class Block(nn.Module):
    """One U-Net stage: conv -> add time embedding -> conv -> resample.

    Args:
        in_ch: Channel count of the incoming feature map. When ``up`` is
            true the first convolution expects ``2 * in_ch`` channels,
            because the caller concatenates a skip tensor onto the input.
        out_ch: Channel count produced by the stage.
        time_emb_dim: Width of the time-embedding vector passed to ``forward``.
        up: If true the final layer is a stride-2 transposed convolution
            (doubles height/width); otherwise a stride-2 convolution
            (halves height/width for even sizes).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Projects the time embedding to one value per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        first_conv_in = 2 * in_ch if up else in_ch
        resample_layer = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_conv_in, out_ch, 3, padding=1)
        # kernel 4, stride 2, padding 1 for both the down and the up variant
        self.transform = resample_layer(out_ch, out_ch, 4, 2, 1)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Run the stage on feature map ``x`` with time embedding ``t``."""
        # conv -> ReLU -> batch norm
        features = self.conv1(x)
        features = self.bnorm1(self.relu(features))

        # Per-channel time signal, given two trailing singleton axes so it
        # broadcasts over the spatial dimensions.
        time_signal = self.relu(self.time_mlp(t))
        time_signal = time_signal.unsqueeze(-1).unsqueeze(-1)
        features = features + time_signal

        # second conv -> ReLU -> batch norm
        features = self.conv2(features)
        features = self.bnorm2(self.relu(features))

        # stride-2 resampling (down or up, depending on construction)
        return self.transform(features)

In [20]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of scalar timesteps to sinusoidal feature vectors.

    The output has ``2 * (dim // 2)`` columns: the sines of the scaled
    timesteps followed by the cosines. ``dim // 2 - 1`` is used as a
    divisor, so ``dim`` must be at least 4; an odd ``dim`` yields
    ``dim - 1`` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the embedding for the 1-D tensor ``time``."""
        n_freqs = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [21]:
class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.

    Takes a 3-channel image batch plus one timestep per sample and returns
    a 3-channel tensor. Four stride-2 down blocks are followed by four
    stride-2 up blocks; each up block receives the output of the matching
    down block concatenated along the channel axis.
    """
    def __init__(self):
        super().__init__()
        image_channels = 3
        out_dim = 3
        time_emb_dim = 32
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)

        # Timestep -> sinusoidal features -> linear -> ReLU
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Lift the image to the first feature width.
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Encoder: one Block per consecutive pair of channel widths.
        self.downs = nn.ModuleList([
            Block(ch_in, ch_out, time_emb_dim)
            for ch_in, ch_out in zip(down_channels[:-1], down_channels[1:])
        ])
        # Decoder: mirrors the encoder, with up=True for transposed convs.
        self.ups = nn.ModuleList([
            Block(ch_in, ch_out, time_emb_dim, up=True)
            for ch_in, ch_out in zip(up_channels[:-1], up_channels[1:])
        ])

        # 1x1 convolution back to image channels.
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        """Return the network output for image batch ``x`` at ``timestep``."""
        t_emb = self.time_mlp(timestep)
        h = self.conv0(x)

        # Encoder pass; remember every stage output for the skip connections.
        skips = []
        for down_block in self.downs:
            h = down_block(h, t_emb)
            skips.append(h)

        # Decoder pass; consume the skips in reverse (last in, first out).
        for up_block in self.ups:
            skip = skips.pop()
            h = up_block(torch.cat((h, skip), dim=1), t_emb)

        return self.output(h)

In [22]:
# Scheduler functions
def linear_beta_scheduler(timesteps, start=0.0001, end=0.008):
    """Return ``timesteps`` evenly spaced values from ``start`` to ``end`` inclusive."""
    betas = torch.linspace(start, end, steps=timesteps)
    return betas

def get_index_from_list(vals, t, x_shape):
    """
    Look up ``vals[t]`` for every sample in the batch.

    Args:
        vals: 1-D tensor of per-timestep values.
        t: 1-D integer tensor with one timestep index per sample.
        x_shape: Shape of the tensor the result will be combined with;
            only its length (number of dimensions) is used.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # Index on whatever device the table lives on. For CPU tables this is the
    # same as the previous ``t.cpu()``; it also works when ``vals`` has been
    # moved to an accelerator, where a CPU index tensor would be rejected.
    out = vals.gather(-1, t.to(vals.device))
    # Trailing singleton axes let the result broadcast against the image batch.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# Pre-compute parameters
# Per-timestep lookup tables derived from the beta schedule.
linear_betas = linear_beta_scheduler(T)
linear_alphas = 1. - linear_betas
# Running product of the alphas up to and including each step.
linear_alphas_cumprod = linear_alphas.cumprod(dim=0)
linear_sqrt_alphas_cumprod = linear_alphas_cumprod.sqrt()
linear_sqrt_one_minus_alphas_cumprod = (1. - linear_alphas_cumprod).sqrt()
linear_sqrt_recip_alphas = (1.0 / linear_alphas).sqrt()
# Same running product shifted right by one step, with 1.0 in the first slot.
linear_alphas_cumprod_prev = F.pad(linear_alphas_cumprod[:-1], (1, 0), value=1.0)
# Variance of the noise added back by sample_timestep at each step.
linear_posterior_variance = (
    linear_betas * (1. - linear_alphas_cumprod_prev) / (1. - linear_alphas_cumprod)
)

In [23]:
# Image processing functions
def show_tensor_image(image):
    """Convert a tensor scaled to [-1, 1] into a PIL image.

    Args:
        image: ``(C, H, W)`` tensor, or a ``(B, C, H, W)`` batch, in which
            case only the first image is converted.

    Returns:
        The PIL image produced by ``transforms.ToPILImage``.
    """
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),          # [-1, 1] -> [0, 1]
        # Sampled images are not guaranteed to stay inside [-1, 1]. Without
        # clamping, out-of-range values wrap around in the uint8 cast below
        # and show up as speckle artefacts.
        transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        # .numpy() needs a CPU tensor that does not require grad.
        transforms.Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch if needed
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    return reverse_transforms(image)

@torch.no_grad()
def sample_timestep(model, x, t):
    """
    Perform one reverse-diffusion step.

    Calls the model to predict the noise in ``x`` at timestep ``t``, removes
    it to obtain the step mean, and adds freshly drawn noise scaled by the
    posterior standard deviation for every sample whose timestep is not 0.

    Args:
        model: Callable ``model(x, t)`` returning the predicted noise.
        x: Image batch ``(B, C, H, W)``.
        t: Integer tensor of shape ``(B,)`` with each sample's timestep.

    Returns:
        Tensor shaped like ``x`` holding the image after this step.
    """
    betas_t = get_index_from_list(linear_betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        linear_sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(linear_sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(linear_posterior_variance, t, x.shape)

    # ``if t == 0`` on a tensor is only defined for a single element, so the
    # old code failed for batches. Use a per-sample mask instead.
    needs_noise = t != 0
    if not bool(needs_noise.any()):
        # Final step for every sample: return the mean without drawing noise.
        return model_mean

    noise = torch.randn_like(x)
    # (B,) -> (B, 1, ..., 1) so the mask broadcasts over the image axes.
    noise_mask = needs_noise.to(x.dtype).reshape(-1, *((1,) * (x.dim() - 1)))
    return model_mean + noise_mask * (torch.sqrt(posterior_variance_t) * noise)

@torch.no_grad()
def generate_image(model, seed=None):
    """
    Sample a single image from the trained Unet.

    Starts from Gaussian noise of shape ``(1, 3, IMG_SIZE, IMG_SIZE)`` on
    ``device`` and applies ``sample_timestep`` for timesteps ``T-1`` down
    to ``0``. When ``seed`` is given, torch's global RNG is re-seeded first
    so repeated calls start from the same noise.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Pure-noise starting point.
    sample = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)

    # Walk the chain backwards, one timestep per iteration.
    for step in reversed(range(T)):
        step_tensor = torch.full((1,), step, device=device, dtype=torch.long)
        sample = sample_timestep(model, sample, step_tensor)

    return sample

In [24]:
# Pruning functions
def iterative_magnitude_pruning(model, sparsity):
    """
    Zero out the smallest-magnitude weights of a copy of ``model``.

    Despite the name this is a single, one-shot pass: one global magnitude
    threshold is computed across all prunable tensors and applied once.
    Prunable tensors are parameters whose name contains ``'weight'`` and
    that have more than one dimension; biases and 1-D normalisation
    weights are left untouched.

    Args:
        model: The PyTorch model to prune (not modified; a deep copy is pruned)
        sparsity: Target fraction of prunable weights to zero (0.0 to 1.0)

    Returns:
        Pruned deep copy of ``model``. Weights tied in magnitude with the
        threshold are all pruned, so the achieved sparsity can slightly
        exceed the target.

    Raises:
        ValueError: If ``sparsity`` is outside [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be between 0.0 and 1.0, got {sparsity}")

    pruned_model = copy.deepcopy(model)

    # Only prune weights in Conv and Linear layers
    weights = [
        (name, param)
        for name, param in pruned_model.named_parameters()
        if 'weight' in name and param.dim() > 1
    ]

    total_weights = sum(param.numel() for _, param in weights)
    num_to_prune = int(total_weights * sparsity)

    zero_weight_count = 0
    if num_to_prune > 0:
        magnitudes = torch.cat([param.data.reshape(-1).abs() for _, param in weights])
        # The k-th smallest magnitude sits at index k - 1. Indexing with
        # ``num_to_prune`` (as before) pruned one weight too many, pruned a
        # weight even at sparsity 0, and raised IndexError at sparsity 1.0.
        threshold = torch.sort(magnitudes)[0][num_to_prune - 1]

        for _, param in weights:
            keep_mask = param.data.abs() > threshold
            param.data.mul_(keep_mask)
            zero_weight_count += (~keep_mask).sum().item()

    # Guard the ratio for models without any prunable weights.
    achieved = zero_weight_count / total_weights if total_weights else 0.0
    print(f"IMP Pruning: Set {zero_weight_count} weights to zero ({achieved:.2%} sparsity)")
    return pruned_model

In [25]:
def hilbert_schmidt_pruning(model, sparsity):
    """
    Zero out the weights contributing least to the squared Hilbert-Schmidt
    (Frobenius) norm of a copy of ``model``.

    Each weight's contribution is its square, and one global threshold on
    that score is applied to all prunable tensors (parameters whose name
    contains ``'weight'`` and that have more than one dimension).

    NOTE: ranking by ``w ** 2`` is the same ordering as ranking by ``|w|``,
    so apart from float32 rounding ties this selects the same weights as
    global magnitude pruning at the same sparsity.

    Args:
        model: The PyTorch model to prune (not modified; a deep copy is pruned)
        sparsity: Target fraction of prunable weights to zero (0.0 to 1.0)

    Returns:
        Pruned deep copy of ``model``. Weights whose score ties with the
        threshold are all pruned, so the achieved sparsity can slightly
        exceed the target.

    Raises:
        ValueError: If ``sparsity`` is outside [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be between 0.0 and 1.0, got {sparsity}")

    pruned_model = copy.deepcopy(model)

    # Only prune weights in Conv and Linear layers
    weights = [
        (name, param)
        for name, param in pruned_model.named_parameters()
        if 'weight' in name and param.dim() > 1
    ]

    total_weights = sum(param.numel() for _, param in weights)
    num_to_prune = int(total_weights * sparsity)

    zero_weight_count = 0
    if num_to_prune > 0:
        # Per-weight contribution to the squared Frobenius norm.
        hs_norms = {name: param.data.pow(2) for name, param in weights}
        all_norm_contribs = torch.cat([score.reshape(-1) for score in hs_norms.values()])
        # The k-th smallest score sits at index k - 1. Indexing with
        # ``num_to_prune`` (as before) pruned one weight too many, pruned a
        # weight even at sparsity 0, and raised IndexError at sparsity 1.0.
        threshold = torch.sort(all_norm_contribs)[0][num_to_prune - 1]

        for name, param in weights:
            keep_mask = hs_norms[name] > threshold
            param.data.mul_(keep_mask)
            zero_weight_count += (~keep_mask).sum().item()

    # Guard the ratio for models without any prunable weights.
    achieved = zero_weight_count / total_weights if total_weights else 0.0
    print(f"HS Pruning: Set {zero_weight_count} weights to zero ({achieved:.2%} sparsity)")
    return pruned_model

In [26]:
def count_parameters(model):
    """Return ``(total, nonzero)`` element counts over the model's trainable parameters.

    Every parameter with ``requires_grad`` set is included, whatever its
    name or number of dimensions.
    """
    trainable = [p for _, p in model.named_parameters() if p.requires_grad]
    total_params = sum(p.numel() for p in trainable)
    nonzero_params = sum((p != 0).sum().item() for p in trainable)
    return total_params, nonzero_params

In [27]:
def load_model(model, path):
    """Load a state dict from ``path`` into ``model`` and prepare it for inference.

    The model is moved to the global ``device`` and switched to eval mode.
    The same model object is returned.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (newer torch versions accept weights_only=True).
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

In [28]:
def generate_and_save_images(model, output_dir, num_images=5, prefix=""):
    """Generate ``num_images`` samples, save each one, and save a one-row grid.

    Image ``i`` is generated with seed ``i``, so different models can be
    compared on identical starting noise.

    Args:
        model: Model passed to ``generate_image``.
        output_dir: Directory for the PNG files (created if missing).
        num_images: Number of samples to generate.
        prefix: Filename prefix; files are ``{prefix}_sample_{i}.png`` and
            ``{prefix}_grid.png``.

    Returns:
        List of the generated PIL images (empty if ``num_images`` < 1, in
        which case no grid file is written).
    """
    os.makedirs(output_dir, exist_ok=True)

    all_images = []
    for i in range(num_images):
        # Use a consistent seed for each position to ensure fair comparison
        img = generate_image(model, seed=i)
        pil_img = show_tensor_image(img.detach().cpu())
        all_images.append(pil_img)
        pil_img.save(os.path.join(output_dir, f"{prefix}_sample_{i}.png"))

    if not all_images:
        # plt.subplots rejects zero columns; there is nothing to plot anyway.
        return all_images

    # squeeze=False keeps ``axes`` two-dimensional even for a single image;
    # the default would return a bare Axes object that cannot be iterated.
    n_cols = len(all_images)
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 4, 4), squeeze=False)
    for i, (ax, img) in enumerate(zip(axes[0], all_images)):
        ax.imshow(np.array(img))
        ax.set_title(f"Sample {i}")
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{prefix}_grid.png"))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
    return all_images

In [29]:
def _sparsity_percent(sparsity):
    """Convert a sparsity fraction to a whole-number percentage (0.9 -> 90)."""
    # Round before truncating: int(0.29 * 100) would give 28 because
    # 0.29 * 100 evaluates to 28.999999999999996.
    return int(round(sparsity * 100))


def _save_comparison_grid(original_images, all_results, method_names,
                          sparsity_levels, output_path):
    """Save a grid figure: header row, original-model row, one row per method.

    Column ``j + 1`` belongs to ``sparsity_levels[j]`` and shows sample
    ``j`` (modulo the number of samples) in every row, so the images in one
    column were generated from the same seed.
    """
    num_rows = len(method_names) + 2      # header row + original row + methods
    num_cols = len(sparsity_levels) + 1   # label column + one per sparsity
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(num_cols * 4, num_rows * 4),
                             squeeze=False)

    # Header row: titles only, no image content.
    axes[0, 0].axis('off')
    for col, sparsity in enumerate(sparsity_levels, start=1):
        axes[0, col].set_title(f"{_sparsity_percent(sparsity)}% Sparsity")
        axes[0, col].axis('off')

    # Original model row.
    axes[1, 0].set_title("Original")
    axes[1, 0].axis('off')
    for j in range(len(sparsity_levels)):
        img = original_images[j % len(original_images)]
        axes[1, j + 1].imshow(np.array(img))
        axes[1, j + 1].axis('off')

    # One row per pruning method.
    for row, method_name in enumerate(method_names, start=2):
        axes[row, 0].set_title(method_name.upper())
        axes[row, 0].axis('off')
        for j, sparsity in enumerate(sparsity_levels):
            images = all_results[method_name][f"{_sparsity_percent(sparsity)}%"]
            axes[row, j + 1].imshow(np.array(images[j % len(images)]))
            axes[row, j + 1].axis('off')

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)


def main():
    """Prune the trained model at several sparsity levels and save sample images.

    Loads ``model.pth``, generates reference samples with the unpruned
    model, then for every pruning method and sparsity level prunes a copy,
    reports its sparsity, and generates samples from the same seeds. All
    PNGs and a combined comparison grid are written to ``pruning_results/``.
    """
    num_samples = 5  # images generated per model; single source of truth

    # Create the model and load the trained weights into it
    model = SimpleUnet()
    original_model = load_model(model, "model.pth")

    # Calculate and print statistics
    total_params, nonzero_params = count_parameters(original_model)
    print(f"Original model: {total_params} total parameters, {nonzero_params} non-zero ({nonzero_params/total_params:.2%})")

    # Create output directory
    output_dir = "pruning_results"
    os.makedirs(output_dir, exist_ok=True)

    # Generate images with the original model
    print("Generating images with original model...")
    original_images = generate_and_save_images(
        original_model, output_dir, num_images=num_samples, prefix="original")

    # Apply different pruning methods at different sparsity levels
    sparsity_levels = [0.25, 0.50, 0.75, 0.90]
    pruning_methods = {
        "imp": iterative_magnitude_pruning,
        "hs": hilbert_schmidt_pruning
    }

    # Store all images for final grid
    all_results = {"original": original_images}

    for method_name, method_func in pruning_methods.items():
        method_results = {}
        for sparsity in sparsity_levels:
            percent = _sparsity_percent(sparsity)
            print(f"Applying {method_name} pruning with {sparsity:.0%} sparsity...")
            pruned_model = method_func(original_model, sparsity)

            # count_parameters covers every trainable parameter, including
            # ones the pruning functions never touch (biases, 1-D weights),
            # so this figure comes out slightly below the target.
            total, nonzero = count_parameters(pruned_model)
            actual_sparsity = 1 - nonzero/total
            print(f"  Achieved sparsity: {actual_sparsity:.2%}")

            # Generate images with the pruned model
            print(f"Generating images with {method_name} pruned model ({sparsity:.0%} sparsity)...")
            pruned_images = generate_and_save_images(
                pruned_model, output_dir, num_images=num_samples,
                prefix=f"{method_name}_{percent}")
            method_results[f"{percent}%"] = pruned_images

        all_results[method_name] = method_results

    # Create a comprehensive comparison grid
    _save_comparison_grid(
        original_images, all_results, list(pruning_methods.keys()),
        sparsity_levels, os.path.join(output_dir, "comprehensive_comparison.png"))

    print(f"All results saved to {output_dir}/")

# Run the experiment when this code is executed as the main module.
if __name__ == "__main__":
    main()

Original model: 62438883 total parameters, 62438883 non-zero (100.00%)
Generating images with original model...
Applying imp pruning with 25% sparsity...
IMP Pruning: Set 15603940 weights to zero (25.00% sparsity)
  Achieved sparsity: 24.99%
Generating images with imp pruned model (25% sparsity)...
Applying imp pruning with 50% sparsity...
IMP Pruning: Set 31207875 weights to zero (50.00% sparsity)
  Achieved sparsity: 49.98%
Generating images with imp pruned model (50% sparsity)...
Applying imp pruning with 75% sparsity...
IMP Pruning: Set 46811809 weights to zero (75.00% sparsity)
  Achieved sparsity: 74.97%
Generating images with imp pruned model (75% sparsity)...
Applying imp pruning with 90% sparsity...
IMP Pruning: Set 56174172 weights to zero (90.00% sparsity)
  Achieved sparsity: 89.97%
Generating images with imp pruned model (90% sparsity)...
Applying hs pruning with 25% sparsity...
HS Pruning: Set 15603940 weights to zero (25.00% sparsity)
  Achieved sparsity: 24.99%
Generati