In [1]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import random
from PIL import Image

import sys
sys.path.insert(0, "../mlflow-scripts")
from model import get_depthpro_model, LightningModel
from dataloader import get_dataloaders, Urban100Dataset, collate_fn



In [2]:
# model_path = "best_model.pth"  
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#model = torch.load(model_path, map_location=device, weights_only=False)
# _ = model.eval()  


model = LightningModel(get_depthpro_model(32))
checkpoint = torch.load("../model_weights_10.pth", map_location="cpu")
state_dict = checkpoint # ["state_dict"]  # Lightning 会把实际模型参数放在这个 key 下
model.load_state_dict(state_dict)
model.eval()  

Some weights of DepthProForDepthEstimation were not initialized from the model checkpoint at geetu040/DepthPro and are newly initialized: ['depth_pro.encoder.feature_projection.projections.4.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.3.2.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.4.3.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.5.4.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.0.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.1.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.2.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.3.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.4.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.5.weight', 'fusion_stage.layers.4.deconv.weight', 'fusion_stage.layers.5.projection.bias', 'fusion_stage.layers.5.projection.weight', 'fusion_stage.layers.5.residual_layer1.convolution1.bias', 'fusion_stage.layers.5.residual_l

LightningModel(
  (model): DepthProForSuperResolution(
    (depthpro_for_depth_estimation): DepthProForDepthEstimation(
      (depth_pro): DepthProModel(
        (encoder): DepthProEncoder(
          (patch_encoder): DepthProViT(
            (embeddings): DepthProViTEmbeddings(
              (patch_embeddings): DepthProViTPatchEmbeddings(
                (projection): Conv2d(3, 1024, kernel_size=(4, 4), stride=(4, 4))
              )
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (encoder): DepthProViTEncoder(
              (layer): ModuleList(
                (0-3): 4 x DepthProViTLayer(
                  (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
                  (attention): DepthProViTSdpaAttention(
                    (attention): DepthProViTSdpaSelfAttention(
                      (query): Linear(in_features=1024, out_features=1024, bias=True)
                      (key): Linear(in_features=1024, out_features=1024, bias=True

In [3]:
test_dataset = Urban100Dataset()
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0
)


lr_images_path: ['../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_009_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_049_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_078_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_087_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_080_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_095_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_038_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_092_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_048_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_008_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_093_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_039_SRF_4_LR.png', '../dataset/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_094_SRF_4_LR.p

In [4]:
def offline_eval(batch):
    with torch.no_grad():
        lr, hr = batch  # adjust depending on dataset output
        lr = lr.to("cpu")
        hr = hr.to("cpu")
        sr = model(lr)
        sr = F.interpolate(sr, size=hr.shape[2:])
        return model.mse(sr, hr), model.psnr(sr, hr), model.ssim(sr, hr), model.snr(sr, hr)


mse_list, psnr_list, ssim_list, snr_list = [], [], [], []
count = 0
for batch in test_loader:
    print(f"processing img {count} out of {test_dataset.__len__()}")
    mse, psnr, ssim, snr = offline_eval(batch)
    count += 1
    mse_list.append(mse)
    psnr_list.append(psnr)
    ssim_list.append(ssim)
    snr_list.append(snr)

print("mse:", np.mean(mse_list))
print("psnr:", np.mean(psnr_list))
print("ssim:", np.mean(ssim_list))
print("snr:", np.mean(snr_list))

processing img 0 out of 100
processing img 1 out of 100
processing img 2 out of 100
processing img 3 out of 100
processing img 4 out of 100
processing img 5 out of 100
processing img 6 out of 100
processing img 7 out of 100
processing img 8 out of 100
processing img 9 out of 100
processing img 10 out of 100
processing img 11 out of 100
processing img 12 out of 100
processing img 13 out of 100
processing img 14 out of 100
processing img 15 out of 100
processing img 16 out of 100
processing img 17 out of 100
processing img 18 out of 100
processing img 19 out of 100
processing img 20 out of 100
processing img 21 out of 100
processing img 22 out of 100
processing img 23 out of 100
processing img 24 out of 100
processing img 25 out of 100
processing img 26 out of 100
processing img 27 out of 100
processing img 28 out of 100
processing img 29 out of 100
processing img 30 out of 100
processing img 31 out of 100
processing img 32 out of 100
processing img 33 out of 100
processing img 34 out of