In [None]:
!pip install timm torch torchvision



In [None]:
import torch
import torch.nn as nn
import math
import timm

class ConvLoRA(nn.Module):
    """
    LoRA adapter for an ``nn.Conv2d`` layer.

    Wraps ``target_conv``, freezes its weight (and bias), and adds a trainable
    low-rank residual path::

        y = target_conv(x) + lora_B(lora_A(x)) * (alpha / r)

    ``lora_A`` maps ``in_channels -> r`` and copies the target's kernel size,
    stride, padding, dilation and padding mode, so its output has the same
    spatial size as the target's output. ``lora_B`` is a 1x1 conv mapping
    ``r -> out_channels``. ``lora_B`` is zero-initialised, so right after
    construction the wrapper computes exactly what ``target_conv`` computes.

    Note: the low-rank path is dense across channels even when ``target_conv``
    is grouped/depthwise (``groups > 1``). The learned update is therefore not
    restricted to the target's group structure.

    Args:
        target_conv: the Conv2d layer to adapt (frozen in place).
        r: LoRA rank; must be a positive integer.
        alpha: LoRA scaling numerator; the residual is scaled by ``alpha / r``.

    Raises:
        ValueError: if ``r`` is not positive.
    """

    def __init__(self, target_conv: nn.Conv2d, r: int = 8, alpha: int = 16):
        super().__init__()
        if r <= 0:
            raise ValueError(f"LoRA rank r must be positive, got {r}")
        self.target_conv = target_conv
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        # Freeze the original target layer; only the LoRA path is trained.
        self.target_conv.weight.requires_grad = False
        if self.target_conv.bias is not None:
            self.target_conv.bias.requires_grad = False

        # Build the adapter on the same device/dtype as the frozen weights, so
        # wrapping a model that was already moved (.cuda() / .half()) works.
        factory_kwargs = {
            "device": target_conv.weight.device,
            "dtype": target_conv.weight.dtype,
        }

        # --- LoRA Layers ---
        # Matrix A: down-projection. Mirrors every spatial hyper-parameter of the
        # target (including dilation) so both paths produce the same H x W.
        self.lora_A = nn.Conv2d(
            in_channels=target_conv.in_channels,
            out_channels=r,
            kernel_size=target_conv.kernel_size,
            stride=target_conv.stride,
            padding=target_conv.padding,
            dilation=target_conv.dilation,
            padding_mode=target_conv.padding_mode,
            bias=False,
            **factory_kwargs,
        )

        # Matrix B: up-projection. A 1x1 conv acts as a per-pixel linear map
        # back to out_channels.
        self.lora_B = nn.Conv2d(
            in_channels=r,
            out_channels=target_conv.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            **factory_kwargs,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init for A, zeros for B.

        Zero B makes the low-rank delta start at exactly 0, so the wrapped
        module initially reproduces the frozen conv's output.
        """
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        """Return the frozen conv output plus the scaled low-rank update."""
        # Original frozen path
        original_out = self.target_conv(x)

        # LoRA path
        lora_out = self.lora_B(self.lora_A(x)) * self.scaling

        return original_out + lora_out

In [None]:
def apply_lora_to_fastvit(model, r=8, alpha=16, target_blocks=(0, 1, 2, 3)):
    """
    Inject ConvLoRA adapters into the 1x1 Conv2d layers of selected FastViT stages.

    Walks ``model.stages``. For each stage whose index is in ``target_blocks``,
    it recursively replaces every ``nn.Conv2d`` with a 1x1 kernel by a
    ``ConvLoRA`` wrapper, which freezes the wrapped conv. Modules that are
    already ``ConvLoRA`` are skipped, so calling this more than once on the same
    model does not nest adapters or wrap the adapters' own 1x1 convs.

    Only the wrapped convs are frozen. Every other parameter of ``model`` keeps
    its current ``requires_grad`` setting.

    NOTE(review): after patching, the attribute that used to hold a Conv2d
    (e.g. ``block.conv``) holds a ``ConvLoRA`` with no ``.weight`` of its own.
    Model code that reads that conv's weights directly (for example a
    reparameterization/fusion step) will likely break. Confirm before calling
    such methods on a patched model.

    Args:
        model: a model exposing an iterable ``stages`` attribute (e.g. timm FastViT).
        r: LoRA rank passed to ``ConvLoRA``.
        alpha: LoRA alpha passed to ``ConvLoRA``.
        target_blocks: indices of the stages to patch (any container supporting ``in``).

    Returns:
        The same ``model`` instance, modified in place.
    """

    def replace_conv_with_lora(module):
        # Swap 1x1 convs among ``module``'s direct children; recurse into the rest.
        for name, child in module.named_children():
            if isinstance(child, ConvLoRA):
                # Already adapted: don't descend into target_conv / lora_A / lora_B.
                continue
            if isinstance(child, nn.Conv2d):
                # Heuristic: only pointwise (1x1) convs are adapted; convs with
                # larger kernels are left untouched.
                if child.kernel_size in ((1, 1), 1):
                    print(f"  -> Patching Layer: {name} | Shape: {child.weight.shape}")
                    setattr(module, name, ConvLoRA(child, r=r, alpha=alpha))
            else:
                # Recursively search deeper
                replace_conv_with_lora(child)

    print(f"Integrate LoRA (r={r}) into FastViT...")

    for i, stage in enumerate(model.stages):
        if i in target_blocks:
            print(f"\nProcessing Stage {i}...")
            replace_conv_with_lora(stage)

    return model

In [None]:
# 1. Load pre-trained FastViT ('fastvit_t8' is a lightweight variant).
torch.manual_seed(0)  # makes the dummy input below reproducible
model = timm.create_model('fastvit_t8', pretrained=True)


def count_params(module):
    """Return ``(total, trainable)`` parameter counts for ``module``."""
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return total, trainable


# 2. Check parameters before LoRA
total_params, trainable_params = count_params(model)
print(f"Original Params: {total_params:,} | Trainable: {trainable_params:,}")

# 3. Apply LoRA to all 4 stages (default target_blocks = 0, 1, 2, 3)
model = apply_lora_to_fastvit(model, r=16, alpha=16)

# apply_lora_to_fastvit freezes only the convs it wraps. Freeze everything else
# so that only the LoRA A/B weights are trainable. Unfreeze the classifier head
# here if fine-tuning on a new label set.
for param_name, param in model.named_parameters():
    if "lora_A" not in param_name and "lora_B" not in param_name:
        param.requires_grad = False

# 4. Check parameters after LoRA
total_params_lora, trainable_params_lora = count_params(model)

print("\n" + "="*30)
print("LoRA Adapted Model Stats:")
print(f"Total Params: {total_params_lora:,}")
print(f"Trainable Params: {trainable_params_lora:,}")
print(f"Percentage Trainable: {(trainable_params_lora/total_params_lora)*100:.2f}%")
print("="*30)

# 5. Dummy forward pass to ensure shapes are correct. Run in eval mode without
# autograd so random data doesn't update the pretrained BatchNorm running
# statistics; restore the previous train/eval mode afterwards.
dummy_input = torch.randn(1, 3, 256, 256)
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(dummy_input)
model.train(was_training)
assert output.ndim == 2 and output.shape[0] == dummy_input.shape[0], output.shape
print(f"\nForward pass successful. Output shape: {output.shape}")

model.safetensors:   0%|          | 0.00/16.3M [00:00<?, ?B/s]

Original Params: 4,026,232 | Trainable: 4,026,232
Integrate LoRA (r=16) into FastViT...

Processing Stage 0...
  -> Patching Layer: conv | Shape: torch.Size([48, 1, 1, 1])
  -> Patching Layer: fc1 | Shape: torch.Size([144, 48, 1, 1])
  -> Patching Layer: fc2 | Shape: torch.Size([48, 144, 1, 1])
  -> Patching Layer: conv | Shape: torch.Size([48, 1, 1, 1])
  -> Patching Layer: fc1 | Shape: torch.Size([144, 48, 1, 1])
  -> Patching Layer: fc2 | Shape: torch.Size([48, 144, 1, 1])

Processing Stage 1...
  -> Patching Layer: conv | Shape: torch.Size([96, 96, 1, 1])
  -> Patching Layer: conv | Shape: torch.Size([96, 1, 1, 1])
  -> Patching Layer: fc1 | Shape: torch.Size([288, 96, 1, 1])
  -> Patching Layer: fc2 | Shape: torch.Size([96, 288, 1, 1])
  -> Patching Layer: conv | Shape: torch.Size([96, 1, 1, 1])
  -> Patching Layer: fc1 | Shape: torch.Size([288, 96, 1, 1])
  -> Patching Layer: fc2 | Shape: torch.Size([96, 288, 1, 1])

Processing Stage 2...
  -> Patching Layer: conv | Shape: torch.