In [1]:
import os
import math
import inspect
import logging
import typing as ty
from functools import partial
from pathlib import Path
from collections import OrderedDict

import torch
import torch.optim
import torch.nn as nn
import numpy as np
import webbrowser
import graphviz
import minlora
from minlora import (
    LoRAParametrization,
    add_lora,
    merge_lora,
    remove_lora
)
from minlora.utils import get_params_by_name, name_is_lora
from minlora.model import add_lora_by_name, apply_lora
from torch.optim import AdamW
from spconv.pytorch.conv import SubMConv3d
graphviz.set_jupyter_format('svg')
from lora_pytorch import LoRA
assert torch.cuda.is_available()
from torchview import draw_graph
from torchviz import make_dot
from graphviz import Digraph

from pointcept.engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from pointcept.engines.test import TESTERS
from pointcept.engines.launch import launch
from pointcept.engines.test import TesterBase, SemSegTester
from pointcept.models.point_prompt_training import PDNorm

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

repo_root = Path("../..")


class WeightFreezer:
    """
    Utility class for conditional, invertible freezing/unfreezing of model 
    weights with state tracking
    """
    def __init__(self, model: nn.Module) -> None:
        self.model = model
        self.original_states = {}
        self._store_initial_states()
        
    def _store_initial_states(self) -> None:
        for name, param in self.model.named_parameters():
            self.original_states[name] = param.requires_grad

    def freeze_if(self, filter_fn: ty.Callable[[str, nn.Parameter], bool] | None) -> None:
        filter_fn = filter_fn or (lambda n, p: True)
        for name, param in self.model.named_parameters():
            if filter_fn(name, param):
                param.requires_grad = False
    
    def freeze_all(self) -> None:
        return self.freeze_if(filter_fn=None)

    def unfreeze_if(
        self,
        filter_fn: ty.Callable[[str, nn.Parameter], bool] | None, 
        hard: bool = False
    ) -> None:
        """
        Defaults to restoring to original state if the filter_fn returns True,
        meaning if the initial model had certain parameters frozen, these will 
        faithfully still be frozen. Setting hard=True overrides this and unfreezes
        irrespective of the initial state.
        """
        filter_fn = filter_fn or (lambda n, p: True)
        for name, param in self.model.named_parameters():
            if filter_fn(name, param):
                if hard:
                    param.requires_grad = True
                else:
                    param.requires_grad = self.original_states.get(name, True)

    def unfreeze_all(self, hard: bool = False) -> None:
        return self.unfreeze_if(filter_fn=None, hard=hard)

    def reset(self) -> None:
        for name, param in self.model.named_parameters():
            param.requires_grad = self.original_states.get(name, True)

    def print_frozen_status(self, print_unfrozen: bool = False) -> None:
        for name, param in self.model.named_parameters():
            state = "unfrozen" if param.requires_grad else "frozen"
            if state == "unfrozen" and not print_unfrozen:
                continue
            print(f"{name}: {state}")


def count_trainable_parameters(model):
    return dict(
        trainable=sum(p.numel() for p in model.parameters() if p.requires_grad),
        frozen=sum(p.numel() for p in model.parameters() if not p.requires_grad)
    )


def named_trainable_parameters(model):
    return dict(
        trainable=[n for n, p in model.named_parameters() if p.requires_grad],
        frozen=[n for n, p in model.named_parameters() if not p.requires_grad]
    )

def is_lora(name: str, value: nn.Parameter) -> bool:
    return name_is_lora(name)


def filter_named_params(
    model: nn.Module,
    filter_fn: ty.Callable[[str, nn.Parameter], bool] | None
) -> tuple[str, nn.Parameter]:
    """
    generator which returns (parameter_name, weight tensor)
    for all tensors whose names match the filter function
    """
    for n, p in model.named_parameters():
        if filter_fn is None or filter_fn(n, p):
            yield n, p


get_named_lora_params = partial(filter_named_params, filter_fn=is_lora)
get_named_non_lora_params = partial(filter_named_params, filter_fn=(lambda x: not is_lora(x)))


def count_lora_parameters(model):
    """use minlora directly"""
    return sum(p.numel() for p in minlora.get_lora_params(model))


def count_lora_params_manual(model):
    """just looking at weight tensor names manually as a cross check"""
    return sum(p.numel() for n, p in get_named_lora_params(model))


def assert_lora_trainable(model):
    for param in minlora.get_lora_params(model):
        assert param.requires_grad


def configure_adamw_lora(
    model,
    weight_decay: float = 0.05,
    learning_rate: float = 0.005,
    betas: tuple[float, float] = (0.9, 0.999),
    device_type: str = "cuda"
) -> torch.optim.AdamW:
    """
    Create an AdamW optimiser which targets only LoRA parameters during
    gradient descent
    """
    # apply weight decay to all lora params
    optim_groups = [
        {"params": list(minlora.get_lora_params(model)) , "weight_decay": weight_decay},
        # could also add biases for fine-tuning,
        # {"params": minlora.get_bias_params(model), "weight_decay": 0.0}, # bias params don't get weight decay
    ]

    def parameter_count(optim_groups):
        n = sum(p.numel() for group in optim_groups for p in group["params"])
        if n < 1e6:
            return f"{n/1e3:.1f}k"
        else:
            return f"{n/1e6:.1f}M"

    logger.info(f"Optimizing {parameter_count(optim_groups)} parameters")

    # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
    use_fused = (device_type == "cuda") and ("fused" in inspect.signature(torch.optim.AdamW).parameters)
    logger.info(f"Using fused AdamW: {use_fused}")
    extra_args = dict(fused=True) if use_fused else dict()
    return torch.optim.AdamW(
        optim_groups,
        lr=learning_rate,
        betas=betas,
        **extra_args
    )


def total_optimized_params(optimizer: torch.optim.Optimizer) -> int:
    tot = 0
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            tot += param.numel()
    return tot

    
def create_spoofed_input(batch_size=2, num_points=1000, n_classes=5, num_features=6, device='cpu'):
    return {
        'coord': torch.rand(num_points * batch_size, num_features, device=device),
        'feat': torch.rand(num_points * batch_size, num_features, device=device),
        'grid_coord': torch.randint(0, 100, (num_points * batch_size, 3), device=device),
        'batch': torch.arange(batch_size, device=device).repeat_interleave(num_points),
        'offset': torch.tensor([num_points * i for i in range(1, batch_size + 1)], device=device),
        'condition': ['ScanNet'] * batch_size,
        'grid_size': torch.tensor([0.01], device=device),
        'segment': torch.randint(low=0, high=n_classes-1, size=(num_points * batch_size,), device=device)
    }


def patch_cfg(cfg: dict, repo_root: Path = repo_root) -> dict:
    cfg = cfg.copy()
    cfg["my_data_root"] = repo_root / cfg["my_data_root"]
    cfg["weight"] = repo_root / cfg["weight"]
    cfg["batch_size_test_per_gpu"] = 1
    return cfg


repo_root = Path("../..")
cfg_file = Path("../../test/custom-ppt-config.py"); assert cfg_file.exists
device = "cuda"

args = default_argument_parser().parse_args(args=["--config-file", f"{cfg_file}"])
cfg = default_config_parser(args.config_file, args.options); cfg = patch_cfg(cfg)

tester = TESTERS.build(dict(type=cfg.test.type, cfg=cfg))
model = tester.model
model.to(device)

# make this once at start, otherwise i gotta make it a singleton to 
# avoid subsequent runs redefining the "initial state"
wf = WeightFreezer(model) 
print("loaded")

[2024-09-06 17:53:23,934 INFO test.py line 41 131730] => Loading config ...
[2024-09-06 17:53:23,935 INFO test.py line 48 131730] => Building model ...
[2024-09-06 17:53:26,779 INFO test.py line 61 131730] Num params: 97447088
[2024-09-06 17:53:27,016 INFO test.py line 68 131730] Loading weight at: ../../models/PointTransformerV3/scannet-semseg-pt-v3m1-1-ppt-extreme/model/model_best.pth
[2024-09-06 17:53:27,580 INFO test.py line 84 131730] => Loaded weight '../../models/PointTransformerV3/scannet-semseg-pt-v3m1-1-ppt-extreme/model/model_best.pth' (epoch 94)
[2024-09-06 17:53:27,584 INFO test.py line 53 131730] => Building test dataset & dataloader ...
[2024-09-06 17:53:27,586 INFO scannet.py line 72 131730] Totally 0 x 1 samples in val set.


loaded


# Inject new normalisation layers for new datasets

In [6]:
def expand_ppt_model_conditions(
    model: nn.Module,
    new_conditions: list[str],
    condition_mapping: dict[str, str] = None
) -> nn.Module:
    """
    Expands a trained PPT model to handle new conditions (datasets). The appropriate 
    normalisation layers are either copied from the trained norm layers corresponding 
    to existing datasets or are initialised randomly (as specified by condition_mapping).
    
    Args:
    - model: The trained PPT model
    - new_conditions: List of new condition names to add
    - condition_mapping: dict mapping new conditions to existing ones for weight initialisation
    
    Returns:
    - Updated model with expanded normalisation layers
    """
    if condition_mapping is None:
        condition_mapping = {}
    for condition in new_conditions:
        if condition not in condition_mapping:
            condition_mapping[condition] = None

    original_conditions = model.conditions
    model.conditions = tuple(list(original_conditions) + new_conditions)

    def expand_pdnorm(pdnorm):
        if isinstance(pdnorm, PDNorm) and pdnorm.decouple:
            first_norm = pdnorm.norm[0]
            if isinstance(first_norm, nn.BatchNorm1d):
                new_norm_func = lambda: type(first_norm)(
                    first_norm.num_features,
                    eps=first_norm.eps,
                    momentum=first_norm.momentum,
                    affine=first_norm.affine,
                    track_running_stats=first_norm.track_running_stats
                )
            elif isinstance(first_norm, nn.LayerNorm):
                new_norm_func = lambda: type(first_norm)(
                    first_norm.normalized_shape,
                    eps=first_norm.eps,
                    elementwise_affine=first_norm.elementwise_affine
                )
            else:
                raise ValueError(f"Unsupported normalization type: {type(first_norm)}")

            new_norms = [new_norm_func() for _ in new_conditions]
            
            for i, condition in enumerate(new_conditions):
                if condition_mapping[condition] in original_conditions:
                    source_idx = original_conditions.index(condition_mapping[condition])
                    new_norms[i].weight.data.copy_(pdnorm.norm[source_idx].weight.data)
                    new_norms[i].bias.data.copy_(pdnorm.norm[source_idx].bias.data)
                    if isinstance(new_norms[i], nn.BatchNorm1d):
                        new_norms[i].running_mean.copy_(pdnorm.norm[source_idx].running_mean)
                        new_norms[i].running_var.copy_(pdnorm.norm[source_idx].running_var)
                else:
                    # Initialize with random values
                    nn.init.normal_(new_norms[i].weight, mean=1.0, std=0.02)
                    nn.init.zeros_(new_norms[i].bias)
            
            pdnorm.norm.extend(new_norms)
            pdnorm.conditions = model.conditions
        return pdnorm

    def update_norm_layers(module):
        for name, child in module.named_children():
            if isinstance(child, PDNorm):
                setattr(module, name, expand_pdnorm(child))
            else:
                update_norm_layers(child)

    update_norm_layers(model)

    old_embed = model.embedding_table
    new_embed = nn.Embedding(len(model.conditions), old_embed.embedding_dim)
    nn.init.normal_(new_embed.weight, mean=0.0, std=0.02)
    new_embed.weight.data[:len(original_conditions)] = old_embed.weight.data
    
    for i, condition in enumerate(new_conditions):
        new_idx = len(original_conditions) + i
        if condition_mapping.get(condition) in original_conditions:
            source_idx = original_conditions.index(condition_mapping[condition])
            new_embed.weight.data[new_idx] = old_embed.weight.data[source_idx]
    
    model.embedding_table = new_embed

    return model

In [9]:
import torch
import torch.nn as nn
from pointcept.models.point_prompt_training import PDNorm

def test_ppt_model_expansion(model, new_conditions=["NewDataset1", "NewDataset2"], device="cuda"):
    """
    Test function to verify the correctness of PDNorm expansion in a PPT model.
    
    Args:
    - model: The original PPT model
    - new_conditions: List of new conditions to add (default: ["NewDataset1", "NewDataset2"])
    - device: The device to run the test on (default: "cuda")
    
    Returns:
    - None, but raises AssertionError if any test fails
    """
    # Ensure the model is on the specified device
    model = model.to(device)
    
    # Setup
    original_conditions = model.conditions
    condition_mapping = {
        "NewDataset1": "ScanNet",  # Copy from ScanNet
        "NewDataset2": None  # Random initialization
    }
    
    # Store original embedding weights
    original_embedding_weights = model.embedding_table.weight.clone()
    
    # Expand the model
    expanded_model = expand_ppt_model_conditions(model, new_conditions, condition_mapping)
    expanded_model = expanded_model.to(device)
    
    # Helper function to check if tensors are close
    def tensors_close(a, b, rtol=1e-5, atol=1e-8):
        return torch.allclose(a.to(device), b.to(device), rtol=rtol, atol=atol)
    
    # Test embedding table
    assert expanded_model.embedding_table.weight.shape[0] == len(original_conditions) + len(new_conditions), "Embedding table size mismatch"
    assert tensors_close(expanded_model.embedding_table.weight[:len(original_conditions)], original_embedding_weights), "Original embeddings changed"
    assert tensors_close(
        expanded_model.embedding_table.weight[len(original_conditions)], 
        original_embedding_weights[original_conditions.index("ScanNet")]
    ), "NewDataset1 embedding not copied correctly"
    
    def check_pdnorm_layers(module, prefix=''):
        for name, child in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name
            if isinstance(child, PDNorm):
                assert len(child.norm) == len(original_conditions) + len(new_conditions), f"PDNorm {full_name} size mismatch"
                
                # Get the corresponding PDNorm from the original model
                original_pdnorm = model
                for part in full_name.split('.'):
                    original_pdnorm = getattr(original_pdnorm, part)
                
                # Check parameters of original conditions
                for i, condition in enumerate(original_conditions):
                    assert tensors_close(child.norm[i].weight, original_pdnorm.norm[i].weight), f"Weight mismatch for {condition} in {full_name}"
                    assert tensors_close(child.norm[i].bias, original_pdnorm.norm[i].bias), f"Bias mismatch for {condition} in {full_name}"
                    assert child.norm[i].eps == original_pdnorm.norm[i].eps, f"Eps mismatch for {condition} in {full_name}"
                
                # Check parameters of new conditions
                scannet_idx = original_conditions.index("ScanNet")
                new_dataset1_idx = len(original_conditions)
                new_dataset2_idx = len(original_conditions) + 1
                
                # NewDataset1 should be copied from ScanNet
                if not tensors_close(child.norm[new_dataset1_idx].weight, child.norm[scannet_idx].weight):
                    print(f"NewDataset1 weight: {child.norm[new_dataset1_idx].weight}")
                    print(f"ScanNet weight: {child.norm[scannet_idx].weight}")
                    raise AssertionError(f"NewDataset1 weight not copied correctly in {full_name}")
                
                if not tensors_close(child.norm[new_dataset1_idx].bias, child.norm[scannet_idx].bias):
                    print(f"NewDataset1 bias: {child.norm[new_dataset1_idx].bias}")
                    print(f"ScanNet bias: {child.norm[scannet_idx].bias}")
                    raise AssertionError(f"NewDataset1 bias not copied correctly in {full_name}")
                
                # NewDataset2 should be randomly initialized
                assert not tensors_close(child.norm[new_dataset2_idx].weight, child.norm[scannet_idx].weight, rtol=1e-3, atol=1e-3), f"NewDataset2 weight should not match ScanNet in {full_name}"
                
                # Check that NewDataset2 is properly initialized
                assert torch.allclose(child.norm[new_dataset2_idx].weight.mean(), torch.tensor(1.0, device=device), rtol=1e-1), f"NewDataset2 weight not properly initialized in {full_name}"
                assert torch.allclose(child.norm[new_dataset2_idx].bias.mean(), torch.tensor(0.0, device=device), rtol=1e-1), f"NewDataset2 bias not properly initialized in {full_name}"
                
                # Check eps and other parameters
                for i in range(len(child.norm)):
                    assert child.norm[i].eps == child.norm[0].eps, f"Eps mismatch in {full_name} for layer {i}"
                    if isinstance(child.norm[i], nn.BatchNorm1d):
                        assert child.norm[i].momentum == child.norm[0].momentum, f"Momentum mismatch in {full_name} for layer {i}"
                        assert child.norm[i].affine == child.norm[0].affine, f"Affine mismatch in {full_name} for layer {i}"
                        assert child.norm[i].track_running_stats == child.norm[0].track_running_stats, f"Track_running_stats mismatch in {full_name} for layer {i}"
                    elif isinstance(child.norm[i], nn.LayerNorm):
                        assert child.norm[i].elementwise_affine == child.norm[0].elementwise_affine, f"Elementwise_affine mismatch in {full_name} for layer {i}"
            else:
                check_pdnorm_layers(child, full_name)
    
    # Run the recursive check
    check_pdnorm_layers(expanded_model)
    
    print("All tests passed successfully!")

# Usage:
# test_ppt_model_expansion(your_model, device="cuda")  # or "cpu" if preferred

In [10]:
test_ppt_model_expansion(model)

All tests passed successfully!


In [11]:
# Assume 'trained_model' is your PPT model trained on the original datasets

# Define new conditions and mapping
new_conditions = ["NewDataset1", "NewDataset2"]
condition_mapping = {
    "NewDataset1": "ScanNet",  # Initialize NewDataset1 with ScanNet's parameters
    "NewDataset2": None  # Initialize NewDataset2 with default initialization
}

# Expand the model
expanded_model = expand_ppt_model_conditions(model, new_conditions, condition_mapping)

In [12]:
expanded_model

PointPromptTraining(
  (backbone): PointTransformerV3(
    (embedding): Embedding(
      (stem): PointSequential(
        (conv): SubMConv3d(6, 48, kernel_size=[5, 5, 5], stride=[1, 1, 1], padding=[1, 1, 1], dilation=[1, 1, 1], output_padding=[0, 0, 0], bias=False, algo=ConvAlgo.Native)
        (norm): PDNorm(
          (norm): ModuleList(
            (0-8): 9 x BatchNorm1d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          )
        )
        (act): GELU(approximate='none')
      )
    )
    (enc): PointSequential(
      (enc0): PointSequential(
        (block0): Block(
          (cpe): PointSequential(
            (0): SubMConv3d(48, 48, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1], output_padding=[0, 0, 0], algo=ConvAlgo.MaskImplicitGemm)
            (1): Linear(in_features=48, out_features=48, bias=True)
            (2): PDNorm(
              (norm): ModuleList(
                (0-8): 9 x LayerNorm((48,), eps=0.001, elem

# Visualise netron

In [3]:
#torch.save(model, "model.pth")

Now install netron and open this file:

```bash
snap install netron
snap run netron
```

# LoRA

### minlora implementation

Quick test run to see things look reasonable

In [2]:
# lora adapter hyperparameters
lora_hparams = dict(
    lora_dropout_p = 0.0,
    rank=10,
    lora_alpha = 64
)

# optimizer hyperparameters
weight_decay = 0.05
learning_rate = 0.005
beta1, beta2 = 0.9, 0.999#0.95
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

lora_config = {
    torch.nn.Embedding: {
        "weight": partial(LoRAParametrization.from_embedding, **lora_hparams),
    },
    torch.nn.Linear: {
        "weight": partial(LoRAParametrization.from_linear, **lora_hparams),
    },
    SubMConv3d: {
        "weight": partial(LoRAParametrization.from_sparseconv3d, **lora_hparams),
    }
}

print("# params before LoRA:", count_trainable_parameters(model))
wf.print_frozen_status()
print("freezing all weights")
wf.freeze_all()
print("# params after freezing:", count_trainable_parameters(model))
print("applying LoRA adapters")
minlora.add_lora(model, lora_config=lora_config)

lora_trainable_params = count_trainable_parameters(model)["trainable"]
print("# params after LoRA:", lora_trainable_params)

# create AdamW optimizer (for LoRA weights only)
optimizer = configure_adamw_lora(
    model,
    weight_decay,
    learning_rate,
    (beta1, beta2),
    device_type
)

print("performing cross checks")
# check all lora parameters trainable
assert_lora_trainable(model)
# check manual lora parameter counting against minlora to check that it matches
assert count_lora_parameters(model) == count_lora_params_manual(model)
# cross check with lora params with gradient enabled
assert total_optimized_params(optimizer) == lora_trainable_params

print("restoring initial model state")
# remove adapters and unfreeze weights to original state
minlora.remove_lora(model)
wf.unfreeze_all()

print("# trainable params after removing lora:", count_trainable_parameters(model))
wf.print_frozen_status()

# if use_lora:
#     optimizer = configure_optimizers_lora(model, weight_decay, learning_rate, (beta1, beta2), device_type)
# else:
#     optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
# if init_from == 'resume':
#     optimizer.load_state_dict(checkpoint['optimizer'])


INFO:__main__:Optimizing 3.3M parameters
INFO:__main__:Using fused AdamW: True


# params before LoRA: {'trainable': 97447088, 'frozen': 1}
logit_scale: frozen
freezing all weights
# params after freezing: {'trainable': 0, 'frozen': 97447089}
applying LoRA adapters
# params after LoRA: 3314890
performing cross checks
restoring initial model state
# trainable params after removing lora: {'trainable': 97447088, 'frozen': 1}
logit_scale: frozen


## Test fwd/backward passes with LoRA

Now freeze + apply LoRA again and test a bit on dummy data:

In [3]:
wf.freeze_all()
minlora.add_lora(model, lora_config=lora_config)

In [6]:
def inspect_lora_gradients(model, x, num_steps=5):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    def check_grads():
        a_no_grad, b_no_grad = [], []
        a_with_grad, b_with_grad = 0, 0
        total_a, total_b = 0, 0
        trainable_params_with_grad = 0
        frozen_params = 0
        total_params = 0

        for name, param in model.named_parameters():
            total_params += param.numel()
            if not param.requires_grad:
                frozen_params += param.numel()
            elif param.grad is not None and torch.any(param.grad != 0):
                trainable_params_with_grad += param.numel()

            if 'lora_A' in name:
                total_a += 1
                if param.grad is None or torch.all(param.grad == 0):
                    a_no_grad.append(name)
                else:
                    a_with_grad += 1
            elif 'lora_B' in name:
                total_b += 1
                if param.grad is None or torch.all(param.grad == 0):
                    b_no_grad.append(name)
                else:
                    b_with_grad += 1

        return (a_with_grad, b_with_grad, total_a, total_b, a_no_grad, b_no_grad, 
                trainable_params_with_grad, frozen_params, total_params)

    # Initial forward and backward pass
    y = model(x)
    loss = y["loss"].sum()
    loss.backward()
    
    results = check_grads()
    (
        a_grad,
        b_grad,
        total_a,
        total_b,
        a_no_grad,
        b_no_grad,
        trainable_grad,
        frozen,
        total
    ) = results

    print("*** First Pass ***")
    print(f"Initial gradients: A: {a_grad}/{total_a}, B: {b_grad}/{total_b}")
    print(f"Trainable parameters with gradients: {trainable_grad:,}")
    print(f"Frozen parameters: {frozen:,}")
    print(f"Total parameters: {total:,}")
    if a_no_grad:
        print(f"Total A matrices without gradients: {len(a_no_grad)}")
    if b_no_grad:
        print(f"Total B matrices without gradients: {len(b_no_grad)}")

    # Perform several optimization steps
    for i in range(num_steps):
        optimizer.step()
        optimizer.zero_grad()
        
        y = model(x)
        loss = y["loss"].sum()
        loss.backward()
        
        results = check_grads()
        a_grad, b_grad, total_a, total_b, a_no_grad, b_no_grad, trainable_grad, frozen, total = results

        print(f"\nGradients after step {i+1}:")
        print(f"A: {a_grad}/{total_a}, B: {b_grad}/{total_b}")
        print(f"Trainable parameters with gradients: {trainable_grad:,}")
        print(f"Frozen parameters: {frozen:,}")
        print(f"Total parameters: {total:,}")
        if a_no_grad:
            print(f"A matrices without gradients: {a_no_grad}")
        if b_no_grad:
            print(f"B matrices without gradients: {b_no_grad}")
            
X = create_spoofed_input(device="cuda", batch_size=16)
inspect_lora_gradients(model, X)

*** First Pass ***
Initial gradients: A: 0/195, B: 194/195
Trainable parameters with gradients: 808,320
Frozen parameters: 97,447,089
Total parameters: 100,761,979
Total A matrices without gradients: 195
Total B matrices without gradients: 1

Gradients after step 1:
A: 194/195, B: 194/195
Trainable parameters with gradients: 3,312,300
Frozen parameters: 97,447,089
Total parameters: 100,761,979
A matrices without gradients: ['embedding_table.parametrizations.weight.0.lora_A']
B matrices without gradients: ['embedding_table.parametrizations.weight.0.lora_B']

Gradients after step 2:
A: 194/195, B: 194/195
Trainable parameters with gradients: 3,312,300
Frozen parameters: 97,447,089
Total parameters: 100,761,979
A matrices without gradients: ['embedding_table.parametrizations.weight.0.lora_A']
B matrices without gradients: ['embedding_table.parametrizations.weight.0.lora_B']

Gradients after step 3:
A: 194/195, B: 194/195
Trainable parameters with gradients: 3,312,300
Frozen parameters: 97

In [4]:
def showlora(model):
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.MultiheadAttention)):
            print(f"Module {name}:")
            if hasattr(module, 'parametrizations'):
                for param_name, param in module.parametrizations.items():
                    print(f"  - {param_name} LoRA parameters:")
                    for lora_name, lora_param in param.named_parameters():
                        print(f"    - {lora_name}: device = {lora_param.device}")
            elif isinstance(module, nn.MultiheadAttention):
                if hasattr(module.out_proj, 'parametrizations'):
                    for param_name, param in module.out_proj.parametrizations.items():
                        print(f"  - out_proj.{param_name} LoRA parameters:")
                        for lora_name, lora_param in param.named_parameters():
                            print(f"    - {lora_name}: device = {lora_param.device}")

showlora(model)

Module backbone.enc.enc0.block0.cpe.1:
  - weight LoRA parameters:
    - original: device = cuda:0
    - 0.lora_A: device = cuda:0
    - 0.lora_B: device = cuda:0
Module backbone.enc.enc0.block0.attn.qkv:
  - weight LoRA parameters:
    - original: device = cuda:0
    - 0.lora_A: device = cuda:0
    - 0.lora_B: device = cuda:0
Module backbone.enc.enc0.block0.attn.proj:
  - weight LoRA parameters:
    - original: device = cuda:0
    - 0.lora_A: device = cuda:0
    - 0.lora_B: device = cuda:0
Module backbone.enc.enc0.block0.mlp.0.fc1:
  - weight LoRA parameters:
    - original: device = cuda:0
    - 0.lora_A: device = cuda:0
    - 0.lora_B: device = cuda:0
Module backbone.enc.enc0.block0.mlp.0.fc2:
  - weight LoRA parameters:
    - original: device = cuda:0
    - 0.lora_A: device = cuda:0
    - 0.lora_B: device = cuda:0
Module backbone.enc.enc0.block1.cpe.1:
  - weight LoRA parameters:
    - original: device = cuda:0
    - 0.lora_A: device = cuda:0
    - 0.lora_B: device = cuda:0
Module 

In [8]:
#torch.save(model, "model_minlora.pth")

RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training

No biggie, just can't visualise this using netron as a consequence