## How to load self-supervised (selfsup) network weights into a supervised model

1) Load the supervised model.
2) Copy the self-supervised backbone weights into the supervised model, resizing any weights whose shapes differ.

In [None]:
#| default_exp networks/selfsup_utils

In [None]:
#| export 
import torch 
from loguru import logger 

In [15]:
import torch 
import fastcore.all as fc

from mmengine.config import Config
from voxdet.utils import locate_cls
from voxdet.networks.monai_retina3d import retina_detector

In [40]:
# Experiment config for the supervised detector (ConvNeXtV2 backbone, per the file name)
cfg_path = "../../configs/lidc/exp_2_convnextv2.py"
cfg = Config.fromfile(cfg_path)

In [41]:
# Build the supervised detector from the config; its backbone is what the selfsup weights overwrite below
model = retina_detector(cfg)

Loading backbone from MedCT


In [42]:
# Inspect the detector's backbone (a ConvNextV2Model3d, per the output) that will receive the selfsup weights
model.network.feature_extractor.body

ConvNextV2Model3d(
  (embeddings): ConvNextV2Embeddings3d(
    (patch_embeddings): Conv3d(2, 40, kernel_size=(2, 4, 4), stride=(2, 4, 4))
    (layernorm): ConvNextV2LayerNorm3d()
  )
  (encoder): ConvNextV2Encoder3d(
    (stages): ModuleList(
      (0): ConvNextV2Stage3d(
        (downsampling_layer): Identity()
        (layers): Sequential(
          (0): ConvNextV2Layer3d(
            (dwconv): Conv3d(40, 40, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=40)
            (layernorm): ConvNextV2LayerNorm3d()
            (pwconv1): Linear(in_features=40, out_features=160, bias=True)
            (act): GELUActivation()
            (grn): ConvNextV2GRN3d()
            (pwconv2): Linear(in_features=160, out_features=40, bias=True)
            (drop_path): Identity()
          )
          (1): ConvNextV2Layer3d(
            (dwconv): Conv3d(40, 40, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=40)
            (layernorm): ConvNextV2LayerNorm3d()
   

> Load the self-supervised (selfsup) checkpoint and inspect its contents

In [14]:
# Lightning checkpoint from the self-supervised run; the model tensors live under "state_dict"
SELFSUP_CKPT = "../../resources/selfsup/exp1_epoch=744-step=168370-val_rloss=1.235.ckpt"
weights = torch.load(SELFSUP_CKPT)
weights.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters', 'cfg'])

In [17]:
# List the checkpoint's parameter names to compare them with the backbone's names
selfsup_state = weights["state_dict"]
keys = fc.L(selfsup_state.keys())
keys

(#73) ['beddings.mask_token','beddings.patch_embeddings.weight','beddings.patch_embeddings.bias','beddings.layernorm.weight','beddings.layernorm.bias','ncoder.stages.0.layers.0.dwconv.weight','ncoder.stages.0.layers.0.dwconv.bias','ncoder.stages.0.layers.0.layernorm.weight','ncoder.stages.0.layers.0.layernorm.bias','ncoder.stages.0.layers.0.pwconv1.weight'...]

In [23]:
# Preview the remapping used below: drop the first (mangled) key component and keep the rest
keys[1].split(".", 1)[1]

'patch_embeddings.weight'

In [34]:
# The checkpoint keys are the Lightning module's "model.model.<...>" / "model.decoder.<...>"
# names with extra leading characters removed, e.g. "embeddings" -> "beddings",
# "encoder" -> "ncoder", "layernorm" -> "ayernorm", "decoder" -> "coder".
# NOTE(review): this pattern matches the prefix having been stripped with
# `str.lstrip("model.")` (which strips a character *set*, not a prefix) — confirm in the
# training code. Map each mangled first component back to the real module name.
SELFSUP_PREFIX_FIX = {
    "beddings": "embeddings",
    "ncoder": "encoder",
    "ayernorm": "layernorm",
    # "coder" comes from "decoder" (the masked-image-modeling head), NOT "encoder";
    # mapping it to "encoder." would save the decoder weights under bogus encoder keys.
    "coder": "decoder",
}

new_dict = {}
for old_key, value in weights["state_dict"].items():
    base, rest = old_key.split(".", 1)
    # Keys with an unknown first component are dropped, as before.
    if base in SELFSUP_PREFIX_FIX:
        new_dict[f"{SELFSUP_PREFIX_FIX[base]}.{rest}"] = value

In [38]:
# Persist the remapped dict; the exported loader accepts a raw state dict like this one
CLEANED_CKPT = "../../resources/selfsup/exp1_cleaned_epoch=744-step=168370-val_rloss=1.235.ckpt"
torch.save(new_dict, CLEANED_CKPT)

In [44]:
# Target parameters: the detector backbone's current state dict, whose keys/shapes we must match
backbone = model.network.feature_extractor.body
model_dict = backbone.state_dict()
network_keys = fc.L(model_dict.keys())
network_keys

(#70) ['embeddings.patch_embeddings.weight','embeddings.patch_embeddings.bias','embeddings.layernorm.weight','embeddings.layernorm.bias','encoder.stages.0.layers.0.dwconv.weight','encoder.stages.0.layers.0.dwconv.bias','encoder.stages.0.layers.0.layernorm.weight','encoder.stages.0.layers.0.layernorm.bias','encoder.stages.0.layers.0.pwconv1.weight','encoder.stages.0.layers.0.pwconv1.bias'...]

In [30]:
# The last backbone keys: the final stage's layer params, then the top-level `layernorm`
network_keys[-5:]

(#5) ['encoder.stages.1.layers.2.grn.bias','encoder.stages.1.layers.2.pwconv2.weight','encoder.stages.1.layers.2.pwconv2.bias','layernorm.weight','layernorm.bias']

In [53]:
# Stop at the first backbone parameter whose checkpoint shape differs and report it.
# `k`, `m` (checkpoint tensor) and `n` (model tensor) stay bound for the next cell.
for k in network_keys:
    m, n = new_dict[k], model_dict[k]
    if m.shape != n.shape:
        print(k, m.shape, n.shape)
        break

embeddings.patch_embeddings.weight torch.Size([40, 2, 4, 8, 8]) torch.Size([40, 2, 2, 4, 4])


In [55]:
# Resample the checkpoint kernel `m` to the model's kernel size (the trailing three dims of `n`).
# interpolate treats dims 0/1 as batch/channel, so only the spatial kernel dims change.
# (Despite the name, this is the patch-embedding conv kernel, not a positional embedding.)
patch_pos_embed = torch.nn.functional.interpolate(
    m,
    size=n.shape[2:],
    mode="trilinear",
    align_corners=False,
)
patch_pos_embed.shape

torch.Size([40, 2, 2, 4, 4])

In [58]:
def _resize_like(src, ref):
    """Return `src` unchanged if it already has `ref`'s shape; otherwise trilinearly
    resample its trailing three dims to `ref.shape[2:]`."""
    if src.shape == ref.shape:
        return src
    return torch.nn.functional.interpolate(
        src, size=ref.shape[2:], mode="trilinear", align_corners=False
    )


# One entry per backbone key, taken from the remapped selfsup weights
final_weights = {k: _resize_like(new_dict[k], model_dict[k]) for k in network_keys}

In [60]:
# Copy the (resized) selfsup weights into the detector backbone
backbone = model.network.feature_extractor.body
backbone.load_state_dict(final_weights)
print("Loaded")

Loaded


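Summary of the self-supervised Lightning module that produced the checkpoint. Its backbone (`model.model`) matches the detector's `ConvNextV2Model3d`. The extra `model.decoder` head (a `Conv3d` weight and bias) plus the embeddings' `mask_token` have no counterpart in the detector backbone. That is why the checkpoint has 73 keys while the network has 70.

```
┏━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name                   ┃ Type                             ┃ Params ┃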
┡━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ model                  │ ConvNextV2ForMaskedImageModeling │  699 K │
│ 1 │ model.model            │ ConvNextV2Model3d                │  367 K │
│ 2 │ model.model.embeddings │ ConvNextV2Embeddings3d           │ 20.6 K │
│ 3 │ model.model.encoder    │ ConvNextV2Encoder3d              │  347 K │
│ 4 │ model.model.layernorm  │ LayerNorm                        │    160 │
│ 5 │ model.decoder          │ Sequential                       │  331 K │
│ 6 │ model.decoder.0        │ Conv3d                           │  331 K │
│ 7 │ model.decoder.1        │ PixelShuffle3d                   │      0 │
```

In [65]:
#| export 
def load_from_selfup_retina(seflsup_weight_loc, model):
    """Build a state dict for a retina detector's backbone from a self-supervised checkpoint.

    For every key of ``model.network.feature_extractor.body.state_dict()``:

    * if the checkpoint has the key with the same shape, the checkpoint tensor is used;
    * if the shapes differ but both tensors are 5-D with equal first two dims
      (e.g. a 3-D conv kernel), the checkpoint tensor is trilinearly resampled to
      the model's trailing three dims;
    * if the checkpoint lacks the key, a warning is logged and the model's current
      tensor is kept, so the returned dict is still complete.

    Checkpoint keys with no counterpart in the backbone are ignored (and logged).
    This only *builds* the dict; apply it with
    ``model.network.feature_extractor.body.load_state_dict(...)``.

    Args:
        seflsup_weight_loc: path to a ``torch.save``-d checkpoint, either a raw state
            dict or a dict holding one under ``"state_dict"``. Its keys must already
            use the backbone's parameter names (e.g. a "cleaned" checkpoint).
        model: detector exposing ``model.network.feature_extractor.body``.

    Returns:
        dict mapping every backbone key to a tensor of the model's shape.

    Raises:
        ValueError: if a shape mismatch cannot be fixed by spatial interpolation.
    """
    # Load onto CPU so GPU-saved checkpoints also open on CPU-only machines;
    # load_state_dict copies into the model's own device afterwards.
    # NOTE(review): torch.load unpickles — only use with trusted checkpoints.
    weights = torch.load(seflsup_weight_loc, map_location="cpu")
    if "state_dict" in weights:
        weights = weights["state_dict"]

    model_dict = model.network.feature_extractor.body.state_dict()

    final_weights = {}
    for k, n in model_dict.items():
        if k not in weights:
            # Keep the model's own tensor so the result still loads with strict=True.
            logger.warning(f"{k} not in weights; keeping the model's current values")
            final_weights[k] = n
            continue
        m = weights[k]
        logger.info(f"mapping weights {k}-{m.shape}")
        if m.shape == n.shape:
            final_weights[k] = m
            continue
        # Trilinear interpolate needs 5-D input and only resizes the last three dims,
        # so the first two (out/in channels for a conv kernel) must already agree.
        if m.ndim != 5 or n.ndim != 5 or m.shape[:2] != n.shape[:2]:
            raise ValueError(
                f"cannot resize weights {k} from {tuple(m.shape)} to {tuple(n.shape)}"
            )
        logger.info(f"resizing weights {k} from {m.shape} to {n.shape}")
        final_weights[k] = torch.nn.functional.interpolate(
            m, size=n.shape[2:], mode="trilinear", align_corners=False
        )

    unused = sorted(set(weights) - set(model_dict))
    if unused:
        logger.info(f"ignoring {len(unused)} checkpoint keys not in the model: {unused}")
    return final_weights

In [66]:
#| hide
# Regenerate the exported module from the `#| export` cells of this notebook
import nbdev

nbdev.nbdev_export()