In [16]:
import os 
import torch
import torch.nn as nn

## Functin Definition

In [74]:
# ================ 1) Low-Rank Decomposition of Conv1d(kernel_size=1) ================
def low_rank_conv1d(conv_layer, rank):
    """
    Perform low-rank decomposition on Conv1d(kernel_size=1):
    Original weight.shape = (out_channels, in_channels, 1).
    After decomposition, replace with: Sequential(Conv1d(in_channels, rank, 1), ReLU, Conv1d(rank, out_channels, 1)).
    """
    W = conv_layer.weight.data  # (out_channels, in_channels, 1)
    out_channels, in_channels, kernel_size = W.shape
    
    if kernel_size != 1:
        # If not a 1×1 conv, skip or raise an error
        raise ValueError("Currently only support Conv1d with kernel_size=1.")

    # 1) Reshape weight to (out_channels, in_channels)
    W_2d = W.view(out_channels, in_channels)

    # 2) Perform SVD decomposition
    U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)

    # 3) Truncate to rank
    U = U[:, :rank]        # (out_channels, rank)
    S = S[:rank]           # (rank,)
    Vh = Vh[:rank, :]      # (rank, in_channels)

    # 4) Create two smaller conv layers
    conv1 = nn.Conv1d(in_channels, rank, kernel_size=1, bias=False)
    conv2 = nn.Conv1d(rank, out_channels, kernel_size=1, bias=False)

    # 5) Assign weights
    conv1.weight.data = Vh.unsqueeze(2)            # (rank, in_channels, 1)
    conv2.weight.data = (U * S).unsqueeze(2)       # (out_channels, rank, 1)

    # 6) Replace the original layer with Sequential (optionally add activation function)
    return nn.Sequential(conv1, nn.ReLU(), conv2)


# ================ 2) Low-Rank Decomposition of SpectralConv (DenseTensor) ================
def low_rank_spectral_conv(spectral_layer, rank):
    """
    Approximate the weight of SpectralConv (DenseTensor(shape=[C_out, C_in, Nx, Ny])).
    This demonstrates an approach to performing truncated SVD on [C_out, C_in] dimensions (for reference).
    You need to modify this according to the actual DenseTensor API:
      - How to retrieve the underlying torch.Tensor?
      - How to write back to DenseTensor?
    """
    # -- Assume spectral_layer.weight is a DenseTensor --
    # (1) Retrieve the actual torch.Tensor
    #    (The following names are purely examples, adjust based on the actual implementation of DenseTensor)
    if not hasattr(spectral_layer.weight, "to_tensor"):
        print(f"[Warning] DenseTensor has no 'to_tensor()' method. Skipped.")
        return
    
    W_torch = spectral_layer.weight.to_tensor()  # => Shape [C_out, C_in, Nx, Ny], type: torch.Tensor
    
    C_out, C_in, Nx, Ny = W_torch.shape

    # (2) Perform 2D SVD only on channel dimensions, treating Nx, Ny as batch/extra dimensions
    #     Simplest approach: reshape -> (C_out*C_in, Nx*Ny)
    #     Perform a 2D SVD -> Truncate to rank
    W_2d = W_torch.view(C_out*C_in, Nx*Ny)  # => (C_out*C_in, Nx*Ny)

    U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)
    # U.shape: (C_out*C_in, C_out*C_in)
    # S.shape: (min(C_out*C_in, Nx*Ny),)
    # Vh.shape: (min(C_out*C_in, Nx*Ny), Nx*Ny)

    # Truncate to rank, ensuring it does not exceed the minimum dimension of U, Vh
    max_rank = min(rank, U.shape[1], Vh.shape[0])  
    U = U[:, :max_rank]
    S = S[:max_rank]
    Vh = Vh[:max_rank, :]

    # Approximate: W_2d_low = U * S * Vh
    #   (C_out*C_in, max_rank) x (max_rank,) x (max_rank, Nx*Ny)
    # First multiply (U * S), then multiply Vh
    U_S = U * S.unsqueeze(0)  # Broadcast
    W_2d_low = U_S @ Vh       # => (C_out*C_in, Nx*Ny)

    # Reshape back to original shape
    W_low = W_2d_low.view(C_out, C_in, Nx, Ny)

    # (3) Write back to DenseTensor
    # Assume DenseTensor has a from_tensor() or set_tensor() method
    if hasattr(spectral_layer.weight, "from_tensor"):
        spectral_layer.weight.from_tensor(W_low)
    else:
        print("[Warning] No method to write back to DenseTensor. Skipped.")


# ================ 3) Iterate Through Model and Apply Low-Rank Decomposition ================
def apply_low_rank_decomposition(model, rank=16):
    """
    1) Apply low-rank decomposition to all Conv1d(kernel_size=1) layers (replacing with two smaller Conv1d layers)
    2) Apply low-rank decomposition to DenseTensor in SpectralConv
    """
    for name, module in model.named_modules():
        # --- (a) Conv1d ---
        if isinstance(module, nn.Conv1d):
            # Check kernel_size
            if module.kernel_size == (1,):
                print(f"[LowRank] Replacing 1x1 Conv1d at: {name}")
                # Find parent module
                parent_name = ".".join(name.split(".")[:-1])
                child_name = name.split(".")[-1]

                # Construct low-rank replacement layer
                new_layer = low_rank_conv1d(module, rank)
                # Replace new_layer in parent
                if parent_name:
                    parent_module = dict(model.named_modules())[parent_name]
                    setattr(parent_module, child_name, new_layer)
                else:
                    setattr(model, name, new_layer)

        # --- (b) SpectralConv (DenseTensor) ---
        # Need to determine based on actual class name.
        # Example: if isinstance(module, SpectralConvClass): ...
        # Here, we simply check if "SpectralConv" is in the class name
        elif "SpectralConv" in type(module).__name__:
            # DenseTensor is typically stored as module.weight: DenseTensor(...)
            # Verify first
            if hasattr(module, "weight"):
                print(f"[LowRank] Approximating DenseTensor at: {name} shape={module.weight.shape}")
                low_rank_spectral_conv(module, rank)
            else:
                print(f"[Warning] Found SpectralConv but no 'weight' attribute: {name}")

In [75]:
def get_model_size(file_path):
    size_in_bytes = os.path.getsize(file_path)
    size_in_mb = size_in_bytes / (1024 ** 2)  # to MB
    return size_in_mb
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

## Decomposition

In [67]:
base_dir = 'e:/UCLCS/UCL Project/Neural Operator/COMP0031-Model-Compression-on-Neural-Operator'
model_dir = os.path.join(base_dir, 'models', 'darcy_small.pth')

In [68]:
model = torch.load(model_dir)
original_size = get_model_size(model_dir)
original_params = count_parameters(model)
model.eval()

  model = torch.load(model_dir)


FNO(
  (positional_embedding): GridEmbeddingND()
  (fno_blocks): FNOBlocks(
    (convs): ModuleList(
      (0-3): 4 x SpectralConv(
        (weight): DenseTensor(shape=torch.Size([32, 32, 16, 9]), rank=None)
      )
    )
    (fno_skips): ModuleList(
      (0-3): 4 x Flattened1dConv(
        (conv): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
      )
    )
    (channel_mlp): ModuleList(
      (0-3): 4 x ChannelMLP(
        (fcs): ModuleList(
          (0): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
          (1): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
        )
      )
    )
    (channel_mlp_skips): ModuleList(
      (0-3): 4 x SoftGating()
    )
  )
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (1): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
    )
  )
  (projection): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
      (1): Conv1d(64, 1, kernel_size=(

In [71]:
apply_low_rank_decomposition(model, rank=8)

[LowRank] Approximating DenseTensor at: fno_blocks.convs.0 shape=torch.Size([32, 32, 16, 9])
[LowRank] Approximating DenseTensor at: fno_blocks.convs.1 shape=torch.Size([32, 32, 16, 9])
[LowRank] Approximating DenseTensor at: fno_blocks.convs.2 shape=torch.Size([32, 32, 16, 9])
[LowRank] Approximating DenseTensor at: fno_blocks.convs.3 shape=torch.Size([32, 32, 16, 9])
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.fno_skips.0.conv.0
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.fno_skips.0.conv.2
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.fno_skips.1.conv.0
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.fno_skips.1.conv.2
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.fno_skips.2.conv.0
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.fno_skips.2.conv.2
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.fno_skips.3.conv.0
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.fno_skips.3.conv.2
[LowRank] Replacing 1x1 Conv1d at: fno_blocks.channel_mlp.0.fcs.0.0
[LowRank] Replacing 1x1 Conv1d at: fno_b

In [76]:
low_rank_params = count_parameters(model)
print(f"Original model parameter number: {original_params}")
print(f"Low-Rank model parameter number: {low_rank_params}")
print(f"parameter compression rate: {(1 - low_rank_params / original_params) * 100:.2f}%")

Original model parameter number: 602977
Low-Rank model parameter number: 591076
parameter compression rate: 1.97%


In [73]:
model.eval()

FNO(
  (positional_embedding): GridEmbeddingND()
  (fno_blocks): FNOBlocks(
    (convs): ModuleList(
      (0-3): 4 x SpectralConv(
        (weight): DenseTensor(shape=torch.Size([32, 32, 16, 9]), rank=None)
      )
    )
    (fno_skips): ModuleList(
      (0-3): 4 x Flattened1dConv(
        (conv): Sequential(
          (0): Sequential(
            (0): Conv1d(32, 8, kernel_size=(1,), stride=(1,), bias=False)
            (1): ReLU()
            (2): Conv1d(8, 1, kernel_size=(1,), stride=(1,), bias=False)
          )
          (1): ReLU()
          (2): Sequential(
            (0): Conv1d(1, 8, kernel_size=(1,), stride=(1,), bias=False)
            (1): ReLU()
            (2): Conv1d(8, 32, kernel_size=(1,), stride=(1,), bias=False)
          )
        )
      )
    )
    (channel_mlp): ModuleList(
      (0-3): 4 x ChannelMLP(
        (fcs): ModuleList(
          (0): Sequential(
            (0): Sequential(
              (0): Conv1d(32, 8, kernel_size=(1,), stride=(1,), bias=False)
  