In [None]:
import pickle
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from matplotlib import pyplot as plt
from scipy import interpolate
from sklearn.metrics import r2_score
from tqdm import tqdm

from model_evalue_plot import double_ellipse_draw,double_rectangle_draw,circle_draw,rec_draw,cross_draw,lack_rec_draw,ring_draw

In [None]:
def _pick_device():
    """Return the preferred torch device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


DEVICE = _pick_device()

In [2]:
class resnet50(nn.Module):
    """ResNet-50 backbone whose classifier is replaced by a sigmoid regression head.

    The head is Linear(2048 -> 1024) -> Dropout(0.5) -> Linear(1024 -> output_size)
    -> Sigmoid, so every output lies in [0, 1].

    Args:
        output_size: number of outputs of the final linear layer.
        pretrained: load torchvision's ImageNet weights (``IMAGENET1K_V1``, the set the
            legacy ``pretrained=True`` flag selected) before the layers are replaced.
            Defaults to ``True`` (original behaviour). Pass ``False`` when a full
            checkpoint is loaded right after construction, which avoids a download.
    """

    def __init__(self, output_size, pretrained=True):
        super().__init__()
        # `pretrained=` is deprecated in torchvision; `weights=` is the supported API.
        weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = torchvision.models.resnet50(weights=weights)
        # NOTE(review): this conv has exactly the stock conv1 configuration, so the swap
        # only re-initialises the first layer (its pretrained weights are discarded).
        # Kept so parameter names/shapes stay compatible with existing checkpoints.
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_ftrs = self.resnet.fc.in_features
        # Layer order/indices must not change: checkpoints key these as fc.0 / fc.2.
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_ftrs, 1024),
            nn.Dropout(0.5),
            nn.Linear(1024, output_size),
            nn.Sigmoid(),
        )

    def forward(self, x, w=None):
        """Run the network on image batch ``x``.

        ``w`` is accepted only for call-site compatibility and is ignored.
        """
        return self.resnet(x)

In [3]:
# Every signed integer/decimal in a prediction string, in order of appearance.
_NUMBER_PATTERN = re.compile(r'-?\d+\.?\d*')

# For each shape type: indices that reorder the numbers as they appear in the
# prediction string into the argument order later passed to the *_draw helpers.
# The tuple length is also the exact number of values the type must carry.
_DRAW_ARG_ORDER = {
    "cross": (0, 1, 2, 3, 4, 6, 7, 5),     # W1, L1, W2, L2, offset, phi, Px, Py -> ..., offset, Px, Py, phi
    "rec": (1, 0, 3, 4, 2),                # W, L, phi, Px, Py -> L, W, Px, Py, phi
    "ellipse": (0, 1, 3, 4, 2),            # a, b, phi, Px, Py -> a, b, Px, Py, phi
    "double_rec": (0, 1, 2, 3, 4, 5, 6),   # W1, L1, W2, L2, Px, Py, phi (unchanged)
    "double_ellipse": (0, 1, 2, 3, 4, 5),  # a, b, theta, Px, Py, phi (unchanged)
    "lack_rec": (1, 0, 2, 3, 4, 5, 6, 7),  # W, L, alpha, beta, gamma, Px, Py, phi -> L, W, ...
    "ring": (0, 1, 2, 3, 4, 5),            # R, r, theta, phi, Px, Py (unchanged)
}


def _parse_prediction(text):
    """Parse one prediction string into ``(shape_type, draw_args)``.

    The shape type is the text before the first comma. The numeric values are every
    number found in the string.

    Returns:
        ``(shape_type, [floats in draw order])`` on success.
        ``("Type Error", "Type Error")`` for an unknown shape type.
        ``("pred Error", ["pred Error"])`` when ``text`` is not a string, or when it has
        the wrong number of values for its type.
    """
    try:
        shape_type = text.split(",")[0].strip()
        values = [float(v) for v in _NUMBER_PATTERN.findall(text)]
    except (AttributeError, TypeError):
        # Not a str (e.g. None or bytes): treat it as an unparseable prediction.
        return "pred Error", ["pred Error"]
    order = _DRAW_ARG_ORDER.get(shape_type)
    if order is None:
        return "Type Error", "Type Error"
    if len(values) != len(order):
        return "pred Error", ["pred Error"]
    return shape_type, [values[k] for k in order]


def preprocess_evaluate(pred_data):
    """Load a prediction pickle and parse every predicted shape description.

    Args:
        pred_data: path to a pickle containing a dict with keys ``"input_data"`` and
            ``"pred"``. ``"pred"`` is a sequence of strings: presumably a shape type
            followed by its comma-separated numeric parameters (confirm against the
            generator).

    Returns:
        ``(types, args, label)``, with ``types``/``args`` aligned with ``"pred"``:
        - ``types[i]``: the shape type, ``"Type Error"`` or ``"pred Error"``.
        - ``args[i]``: the reordered parameters, ``"Type Error"`` (a str) or
          ``["pred Error"]``.
        - ``label``: the ``"input_data"`` entry, unchanged.
    """
    # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
    with open(pred_data, "rb") as fh:
        payload = pickle.load(fh)
    label = payload["input_data"]

    types = []
    args = []
    for text in tqdm(payload["pred"]):
        shape_type, draw_args = _parse_prediction(text)
        types.append(shape_type)
        args.append(draw_args)

    return types, args, label

In [4]:
def spectrum_interpolation(x, y, x_new, kind='linear', axis=1):
    """Resample spectra onto new sample positions.

    Args:
        x: 1-D sample positions of the input spectra; its length must equal
            ``y.shape[axis]``.
        y: spectra array. With the default ``axis=1``, each row is one spectrum.
        x_new: positions to evaluate at. They must lie within ``[min(x), max(x)]``,
            because ``interp1d`` raises ``ValueError`` otherwise (no extrapolation).
        kind: interpolation kind forwarded to ``scipy.interpolate.interp1d``.
        axis: axis of ``y`` that runs along ``x``.

    Returns:
        Array shaped like ``y`` with ``axis`` resized to ``len(x_new)``.
    """
    interpolator = interpolate.interp1d(x, y, kind=kind, axis=axis)
    return interpolator(x_new)

In [6]:
DEFAULT_RESNET_PATH = r'D:\codes\data_autmention_model_final\runs\2024082301\model\2024082301_best.pth'


def model_predict(img, resnet_path=DEFAULT_RESNET_PATH, batch_size=500, output_size=100):
    """Run the trained ResNet-50 regressor over a stack of images.

    Args:
        img: image tensor whose first dimension indexes samples (presumably
            ``(N, 3, H, W)``; the backbone's first conv takes 3 channels).
        resnet_path: checkpoint file holding a ``'model_state_dict'`` entry.
        batch_size: number of images per forward pass.
        output_size: width of the model head; must match the checkpoint.

    Returns:
        ``np.ndarray`` of shape ``(N, output_size)`` with the model outputs.
        For ``N == 0`` this is an empty ``(0, output_size)`` array.
    """
    model = resnet50(output_size).to(DEVICE)
    # map_location lets a checkpoint saved on CUDA load on an MPS- or CPU-only machine.
    checkpoint = torch.load(resnet_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    batch_outputs = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for start in tqdm(range(0, len(img), batch_size)):
            # Move one batch at a time rather than the whole stack, to bound device memory.
            batch = img[start:start + batch_size].to(DEVICE)
            # The second positional argument is resnet50.forward's unused `w`.
            batch_outputs.append(model(batch, 0).cpu().numpy())

    if not batch_outputs:
        return np.empty((0, output_size), dtype=np.float32)
    # Concatenate once at the end instead of growing an array inside the loop.
    return np.concatenate(batch_outputs, axis=0)


In [7]:
def R2score_plot(y, y_pred):
    """Scatter-plot predictions against ground truth, annotated with the R^2 score.

    Args:
        y: ground-truth values.
        y_pred: predicted values, same shape as ``y``.

    Both axis limits are fixed to [0, 1]. The figure is shown; nothing is returned.
    """
    _, ax = plt.subplots(figsize=(7, 6))
    score = r2_score(y, y_pred)

    ax.scatter(y, y_pred, color='blue')
    # Labels read: "visible spectrum", "true value", "predicted value", "R2 score".
    # NOTE(review): matplotlib's default font may lack CJK glyphs; configure a CJK
    # font if these render as empty boxes.
    ax.set_title('可见光谱')
    ax.set_xlabel('真实值')
    ax.set_ylabel('预测值')
    ax.text(
        0.05, 0.95, f'R2 得分: {score:.3f}',
        transform=ax.transAxes, fontsize=12, verticalalignment='top',
    )
    ax.set(xlim=(0, 1), ylim=(0, 1))

    plt.show()

In [None]:
PRED_PKL_PATH = r"D:\codes\transformer_for_metasurface\eval_data_VIT_05_epoch_6.pkl"

# Helper used to rasterise each parsed shape type into an image tensor.
# NOTE(review): "ellipse" is drawn with circle_draw; presumably it handles a != b. Confirm.
DRAW_FUNCS = {
    "cross": cross_draw,
    "rec": rec_draw,
    "ellipse": circle_draw,
    "double_rec": double_rectangle_draw,
    "double_ellipse": double_ellipse_draw,
    "lack_rec": lack_rec_draw,
    "ring": ring_draw,
}

# NOTE: `type` shadows the builtin; the name is kept because later cells index it.
type, args, label = preprocess_evaluate(PRED_PKL_PATH)

pred_tensor = []
labels = []
# Entry k of `labels` / `img` comes from position valid_indices[k] of the
# unfiltered `type` / `args` / `label` lists.
valid_indices = []
for i in tqdm(range(len(type))):
    if type[i] == "Type Error" or args[i][0] == "pred Error":
        continue
    draw = DRAW_FUNCS.get(type[i])
    if draw is None:
        # Fail loudly rather than appending a stale tensor from an earlier iteration.
        raise ValueError(f"no draw function registered for shape type {type[i]!r}")
    labels.append(label[i])
    pred_tensor.append(draw(args[i]))
    valid_indices.append(i)

if not pred_tensor:
    raise ValueError("no valid predictions to draw; every entry was a Type/pred Error")
img = torch.stack(pred_tensor)

In [None]:
# Predict a 100-point output vector for every image stacked in `img`.
wavePred_all = model_predict(img=img)

In [None]:
# Summarise the raw model output instead of dumping the whole array:
# shape (n_samples, 100) and value range (sigmoid outputs, so within [0, 1]).
wavePred_all.shape, float(wavePred_all.min()), float(wavePred_all.max())

In [10]:
# The model emits 100 samples over 400-800; the labels are compared on a 500-point grid.
WAVELENGTH_RANGE = (400, 800)
light_source = np.linspace(*WAVELENGTH_RANGE, num=100)
light_source_new = np.linspace(*WAVELENGTH_RANGE, num=500)
preds = spectrum_interpolation(light_source, wavePred_all, light_source_new)

In [None]:
# Summarise the interpolated predictions, shape (n_samples, 500), instead of dumping them.
preds.shape, float(preds.min()), float(preds.max())

In [None]:
# Inspect the first ground-truth entry; it is paired with preds[0].
labels[0]

In [None]:
# Plot one random prediction against its ground truth.
SAMPLE_SEED = None  # set to an int to make the drawn sample reproducible
sample_rng = np.random.default_rng(SAMPLE_SEED)
choice = int(sample_rng.integers(len(labels)))

fig, ax = plt.subplots()
ax.plot(light_source_new, preds[choice], label="pred")
ax.plot(light_source_new, labels[choice], label="label")
ax.set(title=f"Sample {choice}", xlabel="Wavelength", ylabel="Intensity")
ax.legend()
plt.show()

# `labels`/`preds` skip the "Type Error"/"pred Error" rows, but `type`/`args` do not.
# Map `choice` back to its position in the unfiltered lists before indexing them.
source_index = [
    i for i in range(len(type))
    if not (type[i] == "Type Error" or args[i][0] == "pred Error")
][choice]
type[source_index], args[source_index]

In [None]:
choice  # noted sample indices (presumably cases worth revisiting; confirm): 6438, 119, 4845, 7098, 363

In [None]:
# Inspect the first interpolated prediction; it is paired with labels[0].
preds[0]

In [None]:
# For 2-D targets sklearn averages R^2 uniformly over output columns (spectral points).
r2_score(y_true=labels, y_pred=preds)

In [None]:
# Per-spectrum mean squared / mean absolute error, then averaged over all spectra.
residuals = np.asarray(labels) - preds
mse = np.square(residuals).mean(axis=1)
mae = np.abs(residuals).mean(axis=1)
np.mean(mse), np.mean(mae)

In [None]:
# 1. Mean true spectrum vs. mean predicted spectrum
mean_label = np.mean(labels, axis=0)
mean_pred = np.mean(preds, axis=0)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(mean_label, label='Mean True Spectrum')
ax.plot(mean_pred, linestyle='--', label='Mean Predicted Spectrum')
ax.set(
    xlabel='Spectral Points',
    ylabel='Intensity',
    title='Comparison of Mean True and Predicted Spectra',
)
ax.legend()
plt.show()

In [None]:
# 2. Distribution of all pointwise errors (histogram)
errors = labels - preds
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(errors.ravel(), bins=50, alpha=0.7, color='g')
ax.set(
    xlabel='Error (True - Pred)',
    ylabel='Frequency',
    title='Distribution of Prediction Errors',
)
plt.show()

In [None]:
# 3. Mean error and a one-standard-deviation band at each spectral point
mean_error = errors.mean(axis=0)
std_error = errors.std(axis=0)
point_index = np.arange(mean_error.shape[0])

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(mean_error, label='Mean Error')
ax.fill_between(
    point_index,
    mean_error - std_error,
    mean_error + std_error,
    alpha=0.3,
    color='r',
    label='Error Std Dev',
)
ax.set(
    xlabel='Spectral Points',
    ylabel='Error (True - Pred)',
    title='Mean and Standard Deviation of Errors across Spectral Points',
)
ax.legend()
plt.show()

In [None]:
# 4. Box plot of the errors at each spectral point
# Axes.boxplot draws one box per *column* of a 2-D array. `errors` has one row per
# spectrum and one column per spectral point, so it is passed untransposed.
# Passing errors.T would draw one box per spectrum instead.
n_points = errors.shape[1]
fig, ax = plt.subplots(figsize=(10, 6))
ax.boxplot(errors, showfliers=False)
# Boxes sit at x = 1..n_points. Label about ten of them, using the 0-based point index
# from the other per-point plots, so the tick labels stay readable.
tick_positions = list(range(1, n_points + 1, max(1, n_points // 10)))
ax.set_xticks(tick_positions)
ax.set_xticklabels([str(p - 1) for p in tick_positions])
ax.set_xlabel('Spectral Points')
ax.set_ylabel('Error (True - Pred)')
ax.set_title('Box Plot of Errors at Each Spectral Point')
plt.show()

In [None]:
# 5. Heatmap of the errors (rows: spectra, columns: spectral points)
fig, ax = plt.subplots(figsize=(10, 6))
heatmap = ax.imshow(errors, aspect='auto', cmap='viridis', interpolation='none')
fig.colorbar(heatmap, ax=ax, label='Error (True - Pred)')
ax.set(
    xlabel='Spectral Points',
    ylabel='Spectra',
    title='Heatmap of Prediction Errors',
)
plt.show()