In [2]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import random
from PIL import Image

import sys
sys.path.insert(0, "../mlflow-scripts")
from model import get_depthpro_model, LightningModel
from dataloader import get_dataloaders, Urban100Dataset, collate_fn

In [3]:
# model_path = "best_model.pth"  
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#model = torch.load(model_path, map_location=device, weights_only=False)
# _ = model.eval()  


model = LightningModel(get_depthpro_model(32))
checkpoint = torch.load("../model_weights_10.pth", map_location="cpu")
state_dict = checkpoint # ["state_dict"]  # Lightning 会把实际模型参数放在这个 key 下
model.load_state_dict(state_dict)
model.eval()  

Some weights of DepthProForDepthEstimation were not initialized from the model checkpoint at geetu040/DepthPro and are newly initialized: ['depth_pro.encoder.feature_projection.projections.4.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.3.2.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.4.3.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.5.4.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.0.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.1.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.2.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.3.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.4.weight', 'depth_pro.encoder.feature_upsample.upsample_blocks.6.5.weight', 'fusion_stage.layers.4.deconv.weight', 'fusion_stage.layers.5.projection.bias', 'fusion_stage.layers.5.projection.weight', 'fusion_stage.layers.5.residual_layer1.convolution1.bias', 'fusion_stage.layers.5.residual_l

LightningModel(
  (model): DepthProForSuperResolution(
    (depthpro_for_depth_estimation): DepthProForDepthEstimation(
      (depth_pro): DepthProModel(
        (encoder): DepthProEncoder(
          (patch_encoder): DepthProViT(
            (embeddings): DepthProViTEmbeddings(
              (patch_embeddings): DepthProViTPatchEmbeddings(
                (projection): Conv2d(3, 1024, kernel_size=(4, 4), stride=(4, 4))
              )
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (encoder): DepthProViTEncoder(
              (layer): ModuleList(
                (0-3): 4 x DepthProViTLayer(
                  (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
                  (attention): DepthProViTSdpaAttention(
                    (attention): DepthProViTSdpaSelfAttention(
                      (query): Linear(in_features=1024, out_features=1024, bias=True)
                      (key): Linear(in_features=1024, out_features=1024, bias=True

In [4]:
test_dataset = Urban100Dataset()
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0
)


lr_images_path: ['/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_001_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_002_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_003_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_004_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_005_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_006_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_007_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_008_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_009_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_010_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100/img_011_SRF_4_LR.png', '/mnt/object/urban100/Urban 100/X4 Ur

In [7]:
def offline_eval(batch):
    with torch.no_grad():
        lr, hr = batch  # adjust depending on dataset output
        lr = lr.to("cpu")
        hr = hr.to("cpu")
        sr = model(lr)
        sr = F.interpolate(sr, size=hr.shape[2:])
        return model.mse(sr, hr), model.psnr(sr, hr), model.ssim(sr, hr), model.snr(sr, hr)


for batch in test_loader:
    mse, psnr, ssim, snr = offline_eval(batch)
    print("test_mse_loss", mse)
    print("test_psnr", psnr)
    print("test_ssim", ssim)
    print("test_snr", snr)
    print('\n')
    

test_mse_loss tensor(0.0173)
test_psnr tensor(23.6338)
test_ssim tensor(0.6772)
test_snr tensor(13.0907)


