In [6]:
import torch
from torch import nn

import shutil

from compressai.zoo import load_state_dict, models

In [7]:
def load_checkpoint(arch: str, checkpoint_path: str, strict=True) -> nn.Module:
    state_dict = load_state_dict(
        torch.load(checkpoint_path, map_location=torch.device("cpu"))["state_dict"]
    )
    return models[arch].from_state_dict(state_dict, strict)

In [8]:
def save_checkpoint(state, is_best, filename):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, filename[:-8] + "_best" + filename[-8:])

In [9]:
pretrain_ckpt = "./ckpt/cnn_025.pth.tar"
save_ckpt = "./ckpt/lora_cnn_025.pth.tar"

In [18]:
model = load_checkpoint("cnn", pretrain_ckpt, False)
lora_model = load_checkpoint("lora_cnn", pretrain_ckpt, False)

param_dict = {n: p for n, p in model.named_parameters()}
lora_dict = lora_model.state_dict()

for n, p in lora_dict.items():
    if n.find(".conv.weight") != -1:
        lora_dict[n] = param_dict[n.replace(".conv.weight", ".weight")]
    if n.find(".conv.bias") != -1:
        lora_dict[n] = param_dict[n.replace(".conv.bias", ".bias")]

for n in param_dict.keys():
    print(n)

lora_dict = {n: p for n, p in lora_dict.items() if n.find("lora") == -1}
for n in lora_dict.keys():
    print(n)

g_a.0.weight
g_a.0.bias
g_a.1.beta
g_a.1.gamma
g_a.2.weight
g_a.2.bias
g_a.3.beta
g_a.3.gamma
g_a.4.conv_a.0.conv.0.weight
g_a.4.conv_a.0.conv.0.bias
g_a.4.conv_a.0.conv.2.weight
g_a.4.conv_a.0.conv.2.bias
g_a.4.conv_a.0.conv.4.weight
g_a.4.conv_a.0.conv.4.bias
g_a.4.conv_a.1.conv.0.weight
g_a.4.conv_a.1.conv.0.bias
g_a.4.conv_a.1.conv.2.weight
g_a.4.conv_a.1.conv.2.bias
g_a.4.conv_a.1.conv.4.weight
g_a.4.conv_a.1.conv.4.bias
g_a.4.conv_a.2.conv.0.weight
g_a.4.conv_a.2.conv.0.bias
g_a.4.conv_a.2.conv.2.weight
g_a.4.conv_a.2.conv.2.bias
g_a.4.conv_a.2.conv.4.weight
g_a.4.conv_a.2.conv.4.bias
g_a.4.conv_b.0.attn.relative_position_bias_table
g_a.4.conv_b.0.attn.qkv.weight
g_a.4.conv_b.0.attn.qkv.bias
g_a.4.conv_b.0.attn.proj.weight
g_a.4.conv_b.0.attn.proj.bias
g_a.4.conv_b.1.conv.0.weight
g_a.4.conv_b.1.conv.0.bias
g_a.4.conv_b.1.conv.2.weight
g_a.4.conv_b.1.conv.2.bias
g_a.4.conv_b.1.conv.4.weight
g_a.4.conv_b.1.conv.4.bias
g_a.4.conv_b.2.conv.0.weight
g_a.4.conv_b.2.conv.0.bias
g_a.4.c

In [20]:

save_checkpoint({"state_dict":lora_dict}, False, save_ckpt)