In [1]:
import os
import math
from pathlib import Path
from collections import OrderedDict

import torch
import torch.nn as nn
import numpy as np
import webbrowser
import graphviz
graphviz.set_jupyter_format('svg')
from lora_pytorch import LoRA
assert torch.cuda.is_available()
from torchview import draw_graph
from torchviz import make_dot
from graphviz import Digraph

from pointcept.engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from pointcept.engines.test import TESTERS
from pointcept.engines.launch import launch
from pointcept.engines.test import TesterBase, SemSegTester

repo_root = Path("../..")


def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def create_spoofed_input(batch_size=2, num_points=1000, n_classes=5, num_features=6, device='cpu'):
    return {
        'coord': torch.rand(num_points * batch_size, num_features, device=device),
        'feat': torch.rand(num_points * batch_size, num_features, device=device),
        'grid_coord': torch.randint(0, 100, (num_points * batch_size, 3), device=device),
        'batch': torch.arange(batch_size, device=device).repeat_interleave(num_points),
        'offset': torch.tensor([num_points * i for i in range(1, batch_size + 1)], device=device),
        'condition': ['ScanNet'] * batch_size,
        'grid_size': torch.tensor([0.01], device=device),
        'segment': torch.randint(low=0, high=n_classes-1, size=(num_points * batch_size,), device=device)
    }


def patch_cfg(cfg: dict, repo_root: Path = repo_root) -> dict:
    cfg = cfg.copy()
    cfg["my_data_root"] = repo_root / cfg["my_data_root"]
    cfg["weight"] = repo_root / cfg["weight"]
    cfg["batch_size_test_per_gpu"] = 1
    return cfg


repo_root = Path("../..")
cfg_file = Path("../../test/custom-ppt-config.py"); assert cfg_file.exists

args = default_argument_parser().parse_args(args=["--config-file", f"{cfg_file}"])
cfg = default_config_parser(args.config_file, args.options); cfg = patch_cfg(cfg)

tester = TESTERS.build(dict(type=cfg.test.type, cfg=cfg))
model = tester.model

[2024-08-23 21:15:46,222 INFO test.py line 41 81188] => Loading config ...
[2024-08-23 21:15:46,223 INFO test.py line 48 81188] => Building model ...


proj_head shape says Linear(in_features=64, out_features=512, bias=True)


[2024-08-23 21:15:49,005 INFO test.py line 61 81188] Num params: 97447088
[2024-08-23 21:15:49,213 INFO test.py line 68 81188] Loading weight at: ../../models/PointTransformerV3/scannet-semseg-pt-v3m1-1-ppt-extreme/model/model_best.pth
[2024-08-23 21:15:49,859 INFO test.py line 84 81188] => Loaded weight '../../models/PointTransformerV3/scannet-semseg-pt-v3m1-1-ppt-extreme/model/model_best.pth' (epoch 94)
[2024-08-23 21:15:49,864 INFO test.py line 53 81188] => Building test dataset & dataloader ...
[2024-08-23 21:15:49,866 INFO scannet.py line 72 81188] Totally 0 x 1 samples in val set.


DITCHING CLASS EMBEDDING


# Visualise netron

In [3]:
torch.save(model, "model.pth")

Now install netron and open this file:

```bash
snap install netron
snap run netron
```

# LoRA

Compare implementations, need to understand wtf is going on in lora-pytorch with the custom MHA implementation and why this isn't necessary in the pytora + claude-generated solutions

maybe need to compare param counts and see the difference in their behaviour

### lora-pytorch implementation

In [6]:
lora_model = LoRA.from_module(model, rank=50)
print("bare model: ", count_trainable_parameters(model))
print("lora:", count_trainable_parameters(lora_model))
torch.save(model, "model_lora.pth")

bare model:  110759388
lora: 13312300


### pytora implementation

TODO

### custom implementation (claude, unchecked but it runs lol)

In [3]:
class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.scale = 0.01
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        return (x @ self.lora_A.T @ self.lora_B.T) * self.scale

class AdaptiveLoRAWrapper(nn.Module):
    def __init__(self, base_layer, rank=4):
        super().__init__()
        self.base_layer = base_layer
        if hasattr(base_layer, 'weight'):
            weight = base_layer.weight
            in_features, out_features = weight.shape[1], weight.shape[0]
        elif hasattr(base_layer, 'in_features') and hasattr(base_layer, 'out_features'):
            in_features, out_features = base_layer.in_features, base_layer.out_features
        else:
            raise ValueError(f"Unable to determine in_features and out_features for {type(base_layer)}")
        self.lora = LoRALayer(in_features, out_features, rank)

    def forward(self, x):
        return self.base_layer(x) + self.lora(x)

def get_in_out_features(layer):
    if hasattr(layer, 'in_features') and hasattr(layer, 'out_features'):
        return layer.in_features, layer.out_features
    elif hasattr(layer, 'weight'):
        return layer.weight.shape[1], layer.weight.shape[0]
    else:
        raise ValueError(f"Unable to determine in_features and out_features for {type(layer)}")

class LoRAQKV(nn.Module):
    def __init__(self, qkv_layer, rank=4):
        super().__init__()
        self.qkv_layer = qkv_layer
        in_features, out_features = get_in_out_features(qkv_layer)
        self.lora = LoRALayer(in_features, out_features, rank)

    def forward(self, x):
        return self.qkv_layer(x) + self.lora(x)

def apply_lora_to_ptv3(model, rank=4):
    for name, module in model.named_modules():
        if isinstance(module, SerializedAttention):
            module.qkv = LoRAQKV(module.qkv, rank)
            module.proj = AdaptiveLoRAWrapper(module.proj, rank)
        elif isinstance(module, MLP):
            module.fc1 = AdaptiveLoRAWrapper(module.fc1, rank)
            module.fc2 = AdaptiveLoRAWrapper(module.fc2, rank)

def apply_lora_to_ppt(model, rank=4):
    # Apply LoRA to PT-v3 backbone
    apply_lora_to_ptv3(model.backbone, rank)
    
    # Apply LoRA to the projection head
    model.proj_head = AdaptiveLoRAWrapper(model.proj_head, rank)

    def freeze_non_lora_params(model):
        for name, param in model.named_parameters():
            if 'lora' not in name:
                param.requires_grad = False

    freeze_non_lora_params(model)
    return model

# Usage:
# ppt_model = PointPromptTraining(...)
# ppt_model_with_lora = apply_lora_to_ppt(ppt_model)

In [8]:
ppt_model_with_lora = apply_lora_to_ppt(model) 

In [7]:
from pointcept.models.point_transformer_v3 import SerializedAttention, MLP

In [10]:
count_trainable_parameters(ppt_model_with_lora)

453888

In [9]:
ppt_model_with_lora

PointPromptTraining(
  (backbone): PointTransformerV3(
    (embedding): Embedding(
      (stem): PointSequential(
        (conv): SubMConv3d(6, 48, kernel_size=[5, 5, 5], stride=[1, 1, 1], padding=[1, 1, 1], dilation=[1, 1, 1], output_padding=[0, 0, 0], bias=False, algo=ConvAlgo.Native)
        (norm): PDNorm(
          (norm): ModuleList(
            (0-2): 3 x BatchNorm1d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          )
        )
        (act): GELU(approximate='none')
      )
    )
    (enc): PointSequential(
      (enc0): PointSequential(
        (block0): Block(
          (cpe): PointSequential(
            (0): SubMConv3d(48, 48, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1], output_padding=[0, 0, 0], algo=ConvAlgo.MaskImplicitGemm)
            (1): Linear(in_features=48, out_features=48, bias=True)
            (2): PDNorm(
              (norm): ModuleList(
                (0-2): 3 x LayerNorm((48,), eps=0.001, elem