In [5]:

import os
import math
import argparse
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
# from model_train import CNN_1, CNN_2, CNN_3, CNN_4
from model_train2 import CNN, CNNwithSEBlock, CNN3D, CNNwithSEBlock3D, UNet, UNetwithSEBlock, UNetwithSelfattention, UNet3D, UNetwithSEBlock3D, UNetwithSelfattention3D
from DataSet import MaxMinNormalizeGlobalPerChannel,MyDataSet, dataset_2

random.seed(26)
np.random.seed(26)
torch.manual_seed(26)
torch.cuda.manual_seed(26)
torch.cuda.manual_seed_all(26) 
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # 或者 ":4096:8"

img_dir = 'Gauss_S1.00_NL0.30_B0.50/Gauss_S1.00_NL0.30_B0.50'




model_dict = {
    'MSE': (CNNwithSEBlock(), '/home/linux/3.3lab/outcomes/Old_MSE_2/CNNwithSEBlock.pth'),
    'Weighted MSE': (CNNwithSEBlock(), '/home/linux/3.3lab/outcomes/SSIM_comparison/CNNwithSEBlock.pth'),
    'SSIM_0.5': (CNNwithSEBlock(), '/home/linux/3.3lab/outcomes/SSIM_test01_0.5x/CNNwithSEBlock.pth'),
    'SSIM_1': (CNNwithSEBlock(), '/home/linux/3.3lab/outcomes/SSIM_test01_1x/CNNwithSEBlock.pth'),
    'SSIM_5': (CNNwithSEBlock(), '/home/linux/3.3lab/outcomes/SSIM_test01_5x/CNNwithSEBlock.pth'),
    'SSIM_only':(CNNwithSEBlock(),'/home/linux/3.3lab/outcomes/SSIM_only_1x/CNNwithSEBlock.pth')
}


In [6]:

# 定义训练以及预测时的预处理方法
data_transform = {
    "without_jet": transforms.Compose([MaxMinNormalizeGlobalPerChannel()]),
    "jet": transforms.Compose([MaxMinNormalizeGlobalPerChannel()])}

# 实例化训练数据集
data_set = MyDataSet(img_dir=img_dir,
                    group_size=10000,
                    size_in = 10000,
                    splition = True,
                    split_shuffle = False,
                    transform=data_transform["without_jet"])
train_dataset = dataset_2(data_set.train_X, data_set.train_Y)
val_dataset = dataset_2(data_set.val_X, data_set.val_Y)
test_dataset = dataset_2(data_set.test_X, data_set.test_Y)

del train_dataset
del test_dataset

test_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=50, shuffle=False, num_workers=0)

Y_min = 0.0
Y_max = 20.6504

def weight_MSE(predicted_img, true_img):
    weight = true_img / np.sum(true_img, axis=(1, 2), keepdims=True)
    return np.sum(weight * (predicted_img - true_img) ** 2)

class ModelResults:
    def __init__(self):
        self.energy_residual_2to5 = []
        self.energy_residual_5to10 = []
        self.energy_residual_10toinfty = []
        self.barycenter_shift = [[] for _ in range(3)]
        
    def compute_barycenter_shift(self, true_img, predict_img):
        barycenter_true_X = [0, 0, 0]
        barycenter_true_Y = [0, 0, 0]
        barycenter_predict_X = [0, 0, 0]
        barycenter_predict_Y = [0, 0, 0]
        true_energy = [0, 0, 0]
        
        for i in range(56):
            for j in range(56):
                if true_img[i][j] > 1e-1:
                    energy_residual = (true_img[i][j] - predict_img[i][j]) / true_img[i][j]
                    if true_img[i][j] > 2 and true_img[i][j] <= 5:
                        self.energy_residual_2to5.append(energy_residual)
                        self.update_barycenter(i, j, true_img, predict_img, barycenter_true_X, barycenter_true_Y,
                                               barycenter_predict_X, barycenter_predict_Y, true_energy, 0)
                    if true_img[i][j] > 5 and true_img[i][j] <= 10:
                        self.energy_residual_5to10.append(energy_residual)
                        self.update_barycenter(i, j, true_img, predict_img, barycenter_true_X, barycenter_true_Y,
                                               barycenter_predict_X, barycenter_predict_Y, true_energy, 1)
                    elif true_img[i][j] > 10:
                        self.energy_residual_10toinfty.append(energy_residual)
                        self.update_barycenter(i, j, true_img, predict_img, barycenter_true_X, barycenter_true_Y,
                                               barycenter_predict_X, barycenter_predict_Y, true_energy, 2)
        self.calculate_barycenter_shift(barycenter_true_X, barycenter_true_Y, barycenter_predict_X, barycenter_predict_Y,
                                        true_energy)

    def update_barycenter(self, i, j, true_img, predict_img, barycenter_true_X, barycenter_true_Y,
                          barycenter_predict_X, barycenter_predict_Y, true_energy, idx):
        barycenter_true_X[idx] += i * true_img[i][j]
        barycenter_true_Y[idx] += j * true_img[i][j]
        barycenter_predict_X[idx] += i * predict_img[i][j]
        barycenter_predict_Y[idx] += j * predict_img[i][j]
        true_energy[idx] += true_img[i][j]

    def calculate_barycenter_shift(self, barycenter_true_X, barycenter_true_Y, barycenter_predict_X,
                                barycenter_predict_Y, true_energy):
        for k in range(3):
            # 避免除零错误，确保除数不为零
            if true_energy[k] != 0:
                barycenter_True_X = barycenter_true_X[k] / true_energy[k]
                barycenter_True_Y = barycenter_true_Y[k] / true_energy[k]
                barycenter_Predict_X = barycenter_predict_X[k] / true_energy[k]
                barycenter_Predict_Y = barycenter_predict_Y[k] / true_energy[k]

                self.barycenter_shift[k].append(np.sqrt((barycenter_True_X - barycenter_Predict_X) ** 2 +
                                                    (barycenter_True_Y - barycenter_Predict_Y) ** 2))

    def print_results(self):
        print(f"\tEnergy residual 2-5: {np.mean(self.energy_residual_2to5)}")
        print(f"\tEnergy residual 5-10: {np.mean(self.energy_residual_5to10)}")
        print(f"\tEnergy residual 10-infty: {np.mean(self.energy_residual_10toinfty)}")
        print(f"\tBarycenter shift 2-5: {np.mean(self.barycenter_shift[0])}")
        print(f"\tBarycenter shift 5-10: {np.mean(self.barycenter_shift[1])}")
        print(f"\tBarycenter shift 10-infty: {np.mean(self.barycenter_shift[2])}")


def process_model(model_name, model, model_path, test_dataloader, Y_max, Y_min, data_set):
    model = model.to("cuda")
    model.load_state_dict(torch.load(model_path))
    model.eval()
    
    weight_MSE_list = []
    predicted_images = []
    model_results = ModelResults()
    
    with torch.no_grad():
        for i, (X_test, Y_test) in enumerate(test_dataloader):
            outputs = model(X_test.to("cuda"))
            weight_MSE_list.append(weight_MSE(outputs.cpu().detach().numpy(), Y_test.numpy()) / len(Y_test))
            predicted_images.append(outputs.cpu().detach().numpy())
        
        print(f"Weighted MSE of {model_name}: {np.mean(weight_MSE_list)}")
    
    predicted_images = np.concatenate(predicted_images, axis=0)
    predicted_images = predicted_images * (Y_max - Y_min) + Y_min
    truth_list = data_set.val_Y.numpy()
    predict_list = predicted_images
    
    
    for true_img, predict_img in zip(truth_list, predict_list):
        true_img = true_img[0] * (Y_max - Y_min) + Y_min
        predict_img = predict_img[0]
        model_results.compute_barycenter_shift(true_img, predict_img)

    model_results.print_results()
    


# 调用
for model_name, (model, model_path) in model_dict.items():
    process_model(model_name, model, model_path, test_dataloader, Y_max, Y_min, data_set)

transformation is not None


  model.load_state_dict(torch.load(model_path))


Weighted MSE of MSE: 0.02185518853366375
	Energy residual 2-5: 0.014413216151297092
	Energy residual 5-10: 0.01966192200779915
	Energy residual 10-infty: 0.0445161797106266
	Barycenter shift 2-5: 0.5772069692611694
	Barycenter shift 5-10: 0.8389326333999634
	Barycenter shift 10-infty: 2.6635375022888184
Weighted MSE of Weighted MSE: 0.015594887547194958
	Energy residual 2-5: -0.00048354602768085897
	Energy residual 5-10: 0.01433281134814024
	Energy residual 10-infty: 0.03573371469974518
	Barycenter shift 2-5: 0.1826164871454239
	Barycenter shift 5-10: 0.6390345692634583
	Barycenter shift 10-infty: 2.254971742630005
Weighted MSE of SSIM_0.5: 0.01761464588344097
	Energy residual 2-5: 0.009683451615273952
	Energy residual 5-10: 0.016395552083849907
	Energy residual 10-infty: 0.03519632667303085
	Barycenter shift 2-5: 0.40698885917663574
	Barycenter shift 5-10: 0.6977148652076721
	Barycenter shift 10-infty: 2.3959054946899414
Weighted MSE of SSIM_1: 0.016260061413049698
	Energy residual 2-