In [None]:
import os
import sys
import cv2
import numpy as np
import torch
import pickle
from matplotlib import pyplot as plt


### Model

In [None]:
from omegaconf import OmegaConf
from scripts.inference import load_model_from_config
from ldm.util import instantiate_from_config

In [None]:
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    NOTE: this notebook-local definition shadows the ``load_model_from_config``
    name imported from ``scripts.inference`` in an earlier cell.

    Args:
        config: Configuration object exposing a ``model`` attribute; that
            attribute is handed to ``instantiate_from_config``.
        ckpt: Checkpoint location passed straight to ``torch.load``. The
            loaded object must contain a ``"state_dict"`` entry.
        verbose: When true, print the missing and unexpected keys reported
            by the non-strict ``load_state_dict`` call.

    Returns:
        The instantiated model, switched to eval mode. It is moved to the GPU
        when CUDA is available and left on the CPU otherwise.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU regardless of the device the checkpoint was saved from.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False: mismatched keys are returned instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False)
    if len(missing_keys) > 0 and verbose:
        print("missing keys:")
        print(missing_keys)
    if len(unexpected_keys) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected_keys)

    # Only move to the GPU when one exists; an unconditional .cuda() raises
    # on CPU-only hosts even though the weights were loaded onto the CPU.
    if torch.cuda.is_available():
        model.cuda()
    else:
        print("CUDA is not available; keeping model on CPU")
    model.eval()
    return model

In [None]:
# Read the experiment configuration, then build the model from it and load
# the checkpoint weights via load_model_from_config.
config = OmegaConf.load("configs/nusc.yaml")
model = load_model_from_config(config, "checkpoints/model.ckpt")

In [None]:
# Bare expression: the notebook renders the model's repr as this cell's output.
model

### Checkpoint

In [None]:
import torch

In [None]:
# Load two checkpoints onto the CPU and keep only their "state_dict" entries.
# NOTE(review): despite the names, model1/model2 hold state-dict mappings
# (they are indexed by key and iterated with .keys() in later cells), not model objects.
model1 = torch.load("checkpoints/model.ckpt", map_location="cpu")['state_dict']
model2 = torch.load("models/Paint-by-Example/2024-03-21T21-08-02_nusc/checkpoints/last.ckpt", map_location="cpu")['state_dict']

In [None]:
# Display the 'learnable_vector' entry of the first state dict, cast to int32.
# NOTE(review): if this tensor is floating point, the int32 cast drops the
# fractional part of the values shown — confirm the cast is intended.
model1['learnable_vector'].to(torch.int32)

In [None]:
# List every state-dict key whose name mentions the conditioning stage model.
for k in model1:
    if 'cond_stage_model' not in k:
        continue
    print(k)

In [None]:
# Diff the two state dicts: report entries that exist in only one of them and
# entries present in both whose tensors differ.
for k in model1:
    if k not in model2:
        print(f"{k} not in model2")
    elif not torch.equal(model1[k], model2[k]):
        # torch.equal is True only when size and elements both match.
        print(f"{k} is not equal")

# The loop above only walks model1's keys; also report keys that exist solely
# in model2 so the comparison is symmetric.
for k in model2:
    if k not in model1:
        print(f"{k} not in model1")