之前，我们尝试了将所有的资产投影到一个较低维度的空间上，得到了每个资产的低维嵌入向量表示。

对于这个嵌入向量，一开始我的思路是把他作为标签输入模型，让模型可以学习到这个资产的现实属性，比如这个资产是商品还是金融，是逆周期还是顺周期

比如 股指期货就属于 (金融期货，顺周期...等等其他属性) 大豆就属于 (商品期货，逆周期...等等其他属性) 

但一个更好的思路是，在得到了每个资产的嵌入向量之后，我们可以反解出得到某一种纯因子所需要的资产组合，比如我们可以用有色金属和贵金属，构造出纯的周期因子（举个例子）

这样做至少有3个好处：

1. 这种因子资产的收益率是真实的，并非不可交易的虚拟资产（当然交易成本略高一些）

2. 这种因子资产是由多种资产组合而成，本身就有一定的多空对冲，天然的风险较小

3. 这种因子资产的特征更为集中，不同资产的噪声会互相抵消，从而让主要的特征更突出，更不容易过拟合。

这种范式在A股和期货其实是等同的，因为是纯机器学习模型，没有人为先验输入，在这个范式看来，期货市场就是一个规模小一点的、资产特征更鲜明的证券市场，把53个期货换成5000个股票，把整个模型的规模横向扩大。（还有更大的算力和显存）

前面我们已经将筛选出的53中期货品种嵌入到了10个特征维度上，得到了一个（53，10）的特征嵌入矩阵，表示每个资产的因子成分

现在要做的是，反解出一个维度为（10，53）的资产组合矩阵，来定义10个资产组合，让这10个资产组合恰好对应这10中纯因子，即I10的单位矩阵

求解一个矩阵 B（10*53） ，使得 B A = I 10

当然，这个问题和前面的组合资产一样，B是有多解的，我们也是需要找到最小范数解　B = (At * A)^-1 * At

In [1]:
import os
os.chdir('d:/future/Index_Future_Prediction')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tqdm
import optuna
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import lr_scheduler, Adam, AdamW
from torch.utils.data import TensorDataset, DataLoader

from utils import *
from modules import *

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
embedding = AssetsEmbedding(num_base_assets = 53, embedding_dim = 10)
embedding.load_state_dict(torch.load('params/assets_embedding.params'))
embedding_matrix = embedding.embedding.weight
embedding_matrix.shape

torch.Size([53, 10])

直接应用pinv就可以得到纯因子资产组合的权重（暂时把资产到因子的映射称为成分，因子到资产的组合称为权重）

把reverse_matrix 和 embedding_matrix 乘起来验证一下，主对角线元素确实都是1，其他位置的元素应该是0，但是因为计算误差有一些很小的值，可以忽略，构建资产组合的时候会自动筛选掉，这里不会有太大的影响

In [3]:
reverse_matrix = torch.linalg.pinv(embedding_matrix)
result = reverse_matrix @ embedding_matrix
print(result)

tensor([[ 1.0000e+00,  9.2385e-09, -5.5037e-08,  1.8948e-07, -1.2968e-07,
         -1.6432e-07, -1.4985e-07, -4.6812e-08, -9.0374e-08,  1.0174e-07],
        [ 4.1399e-08,  1.0000e+00, -7.8192e-08,  1.7928e-07, -2.2836e-07,
          1.1250e-07, -2.3063e-07,  1.7747e-07, -1.2770e-07, -6.1568e-08],
        [ 7.5361e-09, -1.6685e-07,  1.0000e+00, -5.8917e-08,  1.8342e-07,
         -4.3520e-07,  6.8001e-08,  8.4062e-08,  4.3654e-07,  2.3909e-07],
        [ 1.6990e-07,  1.9003e-08,  6.5449e-08,  1.0000e+00, -3.7454e-09,
          2.8470e-07,  1.4713e-07, -2.4187e-07, -2.0271e-07, -2.0462e-07],
        [-7.2681e-08, -4.8685e-08,  1.2193e-07,  1.8203e-08,  1.0000e+00,
         -1.0288e-07,  6.2561e-08, -1.7167e-07,  5.6759e-08, -1.9859e-07],
        [-1.1601e-07,  4.2707e-08, -2.3216e-07,  9.3580e-08, -2.3373e-07,
          1.0000e+00, -2.5321e-07, -1.1867e-07, -1.3324e-07,  1.3054e-07],
        [-1.6203e-07, -2.0793e-07,  1.7137e-08,  2.0561e-07,  1.1871e-08,
         -1.7457e-07,  1.0000e+0

这样，我们就得到了一个纯因子的资产组合矩阵；

例如这里的因子A，我们没有对因子A进行定义，因此也无从知道因子A本质上是什么，这是一个黑箱，但我们可以从因子A的组成，来大致猜测因子A可能是什么资产

In [4]:
assets_names = [
    # 股指期货
    '上证50', '沪深300', '中证500',
    # 国债期货
    '2年国债', '5年国债', '10年国债', '30年国债',
    # 黑色金属
    '铁矿石', '焦煤', '螺纹钢', '热轧卷板', '不锈钢', '硅铁', '锰硅',
    # 有色金属
    '沪铜', '沪铝', '沪锌', '沪镍',
    # 贵金属
    '黄金', '白银',
    # 能源化工
    '燃油', '低硫燃料油', '沥青', 'LPG', 'PTA', '乙二醇', '短纤', 
    '塑料', '聚丙烯', 'PVC', '苯乙烯', '甲醇', '尿素', '橡胶',
    # 农产品
    '豆一', '豆二', '豆粕', '菜粕', '豆油', '菜油', '棕榈油', '花生',
    '玉米', '玉米淀粉', '棉花', '白糖', '红枣', '苹果', '纸浆', 
    '鸡蛋', '生猪',
    # 建材
    '玻璃', '纯碱'
]

In [5]:
factor_weights_table = pd.DataFrame(reverse_matrix.T.detach().numpy(), index = assets_names)
factor_weights_table

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
上证50,-0.038973,0.027604,0.082788,-0.00633,-0.014438,0.023319,-0.043292,0.017257,0.028938,-0.019075
沪深300,-0.039625,0.04123,0.089986,0.008605,-0.016247,0.036675,-0.061822,0.029305,0.016257,-0.023709
中证500,-0.021256,0.050249,0.076292,0.008628,-0.030373,0.046459,-0.091532,0.0519,0.024304,-0.033552
2年国债,0.002744,-0.006155,-0.001078,-0.007063,0.008967,-0.000528,0.013049,0.010399,-0.016529,-0.000532
5年国债,-0.000972,-0.001183,0.002451,0.000805,-0.001019,-0.004438,0.00227,0.002091,0.000858,0.004283
10年国债,0.000377,-0.005568,0.000841,-0.00489,-9.6e-05,-0.008532,0.009507,0.003773,-0.006315,0.006048
30年国债,0.000285,-0.004539,-0.00187,-0.006711,0.004771,-0.006064,0.011014,0.011683,-0.01808,0.005411
铁矿石,-0.035613,-0.001491,-0.063429,-0.039211,-0.02356,0.042069,0.012628,0.069581,0.004404,-0.033339
焦煤,-0.029242,0.099465,-0.072996,-0.005888,0.04644,0.023217,-0.004966,0.027133,0.019449,-0.009757
螺纹钢,-0.025187,0.024535,0.000392,-0.025029,0.005133,0.030057,-0.000776,0.02455,-0.005247,-0.045744


这个矩阵非常大，对每一个因子，我们尝试阅读其多的前5和空的前5来理解这个因子的含义

In [6]:
for i in range(10):
    factor = factor_weights_table[i].copy().sort_values()
    print(f'因子{i}')
    print('正权重成分',factor.tail(10).index.values)
    print('负权重成分', factor.head(10).index.values)

因子0
正权重成分 ['豆油' '低硫燃料油' '玉米' '菜油' '苹果' '豆一' '豆粕' '菜粕' '红枣' '豆二']
负权重成分 ['白银' '锰硅' '沪铜' '纯碱' '黄金' '苯乙烯' '沪深300' '上证50' '塑料' '沪铝']
因子1
正权重成分 ['纯碱' '上证50' 'LPG' '豆一' '低硫燃料油' '沪深300' '红枣' '中证500' '生猪' '焦煤']
负权重成分 ['沪镍' '菜粕' '甲醇' '玉米' '不锈钢' '橡胶' '尿素' '豆二' '白银' '乙二醇']
因子2
正权重成分 ['棉花' '豆油' '燃油' '低硫燃料油' '菜油' '棕榈油' '沥青' '中证500' '上证50' '沪深300']
负权重成分 ['红枣' '焦煤' '铁矿石' '纯碱' '尿素' '白银' '玻璃' '豆粕' '沪铜' '橡胶']
因子3
正权重成分 ['沪铝' '锰硅' '塑料' '棉花' '硅铁' '生猪' '白糖' '红枣' '玻璃' '橡胶']
负权重成分 ['尿素' '纯碱' '菜油' '豆油' '玉米' '棕榈油' '黄金' '菜粕' '沥青' '铁矿石']
因子4
正权重成分 ['锰硅' '纯碱' '沪镍' '塑料' '纸浆' '生猪' '焦煤' '沪锌' '硅铁' '橡胶']
负权重成分 ['红枣' '黄金' '菜粕' '白银' '苹果' '燃油' '沥青' '棉花' '豆粕' '沪铜']
因子5
正权重成分 ['铁矿石' '棉花' '中证500' '豆二' '白糖' '玻璃' '沪镍' '鸡蛋' '豆粕' '菜粕']
负权重成分 ['黄金' '生猪' '乙二醇' '沥青' '低硫燃料油' '苯乙烯' '纸浆' 'LPG' '10年国债' '甲醇']
因子6
正权重成分 ['聚丙烯' 'LPG' '乙二醇' '沥青' '短纤' '棉花' 'PTA' '苯乙烯' '低硫燃料油' '燃油']
负权重成分 ['中证500' '沪深300' '沪镍' '上证50' '纯碱' '硅铁' '花生' '锰硅' '白银' '沪铝']
因子7
正权重成分 ['螺纹钢' '焦煤' '沪深300' '苹果' '菜油' '豆油' '中证500' '铁矿石' '棕榈油' '橡胶']
负权重成分 ['生猪' '锰硅' '沪镍' 

这里还是比较混乱的，我丢给AI去找规律，能得出如下的几个比较有特点的信息

因子0: 经济衰退/滞胀对冲因子
正权重 (做多): 豆油, 低硫燃料油, 玉米, 菜油, 苹果, 豆一, 豆粕, 菜粕, 红枣, 豆二

负权重 (做空): 白银, 锰硅, 沪铜, 纯碱, 黄金, 苯乙烯, 沪深300, 上证50, 塑料, 沪铝

解读: 这是一个非常典型的宏观对冲因子。它几乎做多了所有核心农产品和部分基础能源，同时做空了股票指数、工业金属和化工品。这个组合的核心逻辑是押注于生活必需品（食品、燃料）的韧性将强于依赖于经济增长的金融和工业资产。该因子在滞胀（经济停滞、通胀高企）或经济衰退时期会有非常出色的防御表现。

因子1: 中国经济顺周期因子
正权重 (做多): 纯碱, 上证50, LPG, 豆一, 低硫燃料油, 沪深300, 红枣, 中证500, 生猪, 焦煤

负权重 (做空): 沪镍, 菜粕, 甲醇, 玉米, 不锈钢, 橡胶, 尿素, 豆二, 白银, 乙二醇

解读: 该因子的灵魂是做多中国A股三大指数。多头中的其他品种，如焦煤、纯碱（基建地产上游）和生猪（消费），都与国内经济活动密切相关。因此，这是一个与中国宏观经济景气度高度绑定的顺周期因子。当市场预期国内经济向好、政策发力时，该因子预计会上涨。

因子6: 石化产业链 vs. 金属/矿业
正权重 (做多): 聚丙烯, LPG, 乙二醇, 沥青, 短纤, 棉花, PTA, 苯乙烯, 低硫燃料油, 燃油

负权重 (做空): 中证500, 沪深300, 沪镍, 上证50, 纯碱, 硅铁, 花生, 锰硅, 白银, 沪铝

解读: 这是一个非常干净的跨行业对冲因子。它几乎做多了完整的石油化工产业链（从LPG到塑料、化纤），同时做空股票指数和金属/矿业综合体。这可能是在押注消费需求（驱动石化产品）将比工业投资需求（驱动金属）更强劲，或者仅仅是一个在硬资产内部的相对价值策略。

因子8: 工业材料 vs. 粮食/饲料
正权重 (做多): 沪镍, 橡胶, LPG, 沥青, 豆二, 乙二醇, 纸浆, 玻璃, 甲醇, 红枣

负权重 (做空): 生猪, 鸡蛋, 玉米淀粉, 豆粕, 玉米, 白糖, 燃油, 菜粕, PTA, 锰硅

解读: 该因子构建了一个清晰的对立面：做多工业材料，同时做空“粮食+饲料”复合体。它押注的是工业通胀将超过农业通胀。在工业复苏、需求旺盛，而全球粮食丰收、食品价格稳定的宏观背景下，该因子会有最佳表现。

接下来，我们构建因子的收益曲线

In [7]:
assets_list = [
    # 股指期货
    'IH.CFX', 'IF.CFX', 'IC.CFX',
    # 国债期货
    'TS.CFX', 'TF.CFX', 'T.CFX', 'TL1.CFX',
    # 黑色金属产业链
    'I.DCE', 'JM.DCE', 'RB.SHF', 'HC.SHF', 'SS.SHF', 'SF.ZCE', 'SM.ZCE',
    # 有色金属
    'CU.SHF', 'AL.SHF', 'ZN.SHF', 'NI.SHF',
    # 贵金属
    'AU.SHF', 'AG.SHF',
    # 能源化工
    'FU.SHF', 'LU.INE', 'BU.SHF', 'PG.DCE', 'TA.ZCE', 'EG.DCE', 'PF.ZCE', 
    'L.DCE', 'PP.DCE', 'V.DCE', 'EB.DCE', 'MA.ZCE', 'UR.ZCE', 'RU.SHF',
    # 农产品
    'A.DCE', 'B.DCE', 'M.DCE', 'RM.ZCE', 'Y.DCE', 'OI.ZCE', 'P.DCE', 'PK.ZCE',
    'C.DCE', 'CS.DCE', 'CF.ZCE', 'SR.ZCE', 'CJ.ZCE', 'AP.ZCE', 'SP.SHF', 
    'JD.DCE', 'LH.DCE',
    # 建材
    'FG.ZCE', 'SA.ZCE'
]

feature_columns = ['inday_chg_open','inday_chg_high','inday_chg_low','inday_chg_close','inday_chg_amplitude', 'ma_10','ma_26','ma_45','ma_90','ma_vol',]
label_columns = ['label_return','down_prob','middle_prob','up_prob']

In [8]:
seq_len = 120
batch_size = 64

In [9]:
# 训练数据
start_date = 20220901
end_date = 20250901

feature = []
label = []
current_assets = []

data = pd.read_csv(f'data/C.DCE.csv')
data = data[data['trade_date'] > start_date].copy()
data = data[data['trade_date'] < end_date].copy()
full_length = len(data)

for asset_code in assets_list:
    data = pd.read_csv(f'data/{asset_code}.csv')
    data = data[data['trade_date'] > start_date]
    data = data[data['trade_date'] < end_date]
    if len(data) == full_length:
        current_assets.append(asset_code)
        feature.append(torch.tensor(data[feature_columns].values, dtype = torch.float32))
        label.append(torch.tensor(data[label_columns].values, dtype = torch.float32))

feature = torch.stack(feature, dim = 1)
label = torch.stack(label, dim = 1)
feature = feature.unfold(dimension = 0, size = seq_len, step = 1).transpose(2,3)
label = label[seq_len-1:]

In [10]:
reshape_feature = feature.permute(0,2,1,3).flatten(0,1)
expand_reverse = reverse_matrix.unsqueeze(0).unsqueeze(0).repeat(606, 120, 1, 1).flatten(0,1)
factor_feature = torch.bmm(expand_reverse, reshape_feature).reshape(606, 120, 10, 10).permute(0,2,1,3)
factor_feature.shape

torch.Size([606, 10, 120, 10])

In [11]:
expand_reverse = reverse_matrix.unsqueeze(0).repeat(606, 1, 1)
factor_label = torch.bmm(expand_reverse, label)
factor_label.shape

torch.Size([606, 10, 4])

In [12]:
factor_weights = expand_reverse

In [13]:
train_feature, test_feature = factor_feature[:400], factor_feature[400:]
train_weights, test_weights = factor_weights[:400], factor_weights[400:]
train_label, test_label = factor_label[:400], factor_label[400:]

In [14]:
train_data = TensorDataset(train_feature, train_weights, train_label)
train_loader = DataLoader(dataset = train_data, batch_size = batch_size)
test_data = TensorDataset(test_feature, test_weights, test_label)
test_loader = DataLoader(dataset = test_data, batch_size = batch_size)

In [15]:
class PanelTransformerBackbone(nn.Module):
    def __init__(self, dim_patch_feature, dim_projection, dim_temporal_embedding, dim_assets_embedding, num_bass_assets, num_head, num_layer, dropout):
        super().__init__()
        dim_encoder_input = dim_projection +  dim_temporal_embedding + dim_assets_embedding
        self.projection = nn.Linear(dim_patch_feature, dim_projection)
        self.assets_embedding = AssetsEmbedding(num_base_assets = num_bass_assets, embedding_dim = dim_assets_embedding, target_ratio = 0.2, freeze = True)
        self.temporal_embedding = TemporalEmbedding(dim_embedding = dim_temporal_embedding)
        self.panel_encoder = MultiLayerPanelEncoder(num_layer = num_layer, d_model = dim_encoder_input, num_head = num_head, num_ffn_hidden = dim_encoder_input * 2, dropout = dropout)

    def forward(self, x, weights):
        x = self.projection(x)
        x = self.temporal_embedding(x)
        x = self.assets_embedding(x, weights)
        x = self.panel_encoder(x)
        return x

model = PanelTransformerBackbone(dim_patch_feature = 120, dim_projection = 128, dim_temporal_embedding = 6, dim_assets_embedding = 10, num_bass_assets = 53, num_head = 8, num_layer = 3 , dropout = 0.5)
model.assets_embedding.load_state_dict(torch.load('params/assets_embedding.params'))

<All keys matched successfully>

In [16]:
class RankPanelTransformer(nn.Module):
    """Panel Time Series Transformer"""
    def __init__(self, dim_raw_feature, patch_size, stride, mask_expand_size, seq_len, dim_projection, dim_temporal_embedding, dim_assets_embedding, num_bass_assets, num_head, num_layer, dropout):
        super().__init__()
        # 模型参数
        self.device = 'cuda:0'
        self.input_size = dim_raw_feature
        self.patch_size = patch_size
        self.stride = stride
        self.mask_expand_size = mask_expand_size
        self.num_patch = int(np.floor((seq_len - patch_size) / stride) + 1)

        self.dim_projection = dim_projection

        dim_encoder_input = dim_projection +  dim_temporal_embedding + dim_assets_embedding

        # 前置层
        self.patch = TimeSeriesPatcher(patch_size, stride)

        # 编码层
        self.encoder = PanelTransformerBackbone(dim_patch_feature = dim_raw_feature * patch_size,
                                                  dim_projection = dim_projection,
                                                  dim_temporal_embedding = dim_temporal_embedding,
                                                  dim_assets_embedding = dim_assets_embedding,
                                                  num_bass_assets = num_bass_assets,
                                                  num_head = num_head,
                                                  num_layer = num_layer,
                                                  dropout = dropout)

        # 线性输出
        self.decoder = nn.Sequential(
            nn.Flatten(start_dim = -2),
            # nn.Linear(self.num_patch * dim_encoder_input, self.num_patch *dim_encoder_input),
            nn.Dropout(dropout),
            nn.Linear(self.num_patch * dim_encoder_input, 1),
            nn.Flatten(start_dim = -2),
        )

        # # 注意力式输出
        # self.decoder = nn.Sequential(
        #     nn.Flatten(start_dim = -2),
        #     # nn.Dropout(dropout),
        #     nn.Linear(dim_encoder_input*self.num_patch, dim_encoder_input),
        #     nn.TransformerEncoder(encoder_layer=nn.TransformerEncoderLayer(d_model = dim_encoder_input, nhead = 4, dim_feedforward=dim_encoder_input, dropout = dropout, batch_first=True, norm_first=True),num_layers = 3),
        #     # nn.Dropout(dropout),
        #     nn.Linear(dim_encoder_input, 1),
        #     nn.Flatten(start_dim = -2),
        # )
    
    def forward(self, x, weights):
        x_patched = self.patch(x)
        enc_out = self.encoder(x_patched, weights)
        output = self.decoder(enc_out)
        return output
    
    # # 重写 train 方法，来固定模型只训练输出层
    # def train(self, mode = True):
    #     super().train(mode)
    #     if mode:
    #         self.encoder.eval()

In [17]:
def accuracy(pred, real, top):
    """
    计算前10名资产的识别准确率
    """
    # 获取预测和真实的top资产的索引
    _, pred_top10_indices = torch.topk(pred, top, dim=1)
    _, real_top10_indices = torch.topk(real, top, dim=1)

    batch_accs = []
    # 对批次中的每个样本进行计算
    for i in range(pred.shape[0]):
        pred_set = set(pred_top10_indices[i].tolist())
        real_set = set(real_top10_indices[i].tolist())
        
        # 计算交集中的元素数量
        intersection_size = len(pred_set.intersection(real_set))
        
        # 计算该样本的准确率
        acc = intersection_size / top
        batch_accs.append(acc)
        
    # 返回批次的平均准确率
    return np.mean(batch_accs)

In [18]:
dim_raw_feature = 10
patch_size = 10
stride = 5
mask_expand_size = 3
seq_len = 120
dim_projection = 102
dim_temporal_embedding = 16
dim_assets_embedding = 10
num_bass_assets = 53
num_head = 8
num_layer = 5
dropout = 0.3

batch_size = 32
learning_rate = 1e-3
weight_decay = 0
gamma = 1

test_size = 128

epochs = 200

top = 2

In [19]:
loss_fn = WeightedRankLoss(alpha=0.5, p=1, q=1)

model = RankPanelTransformer(dim_raw_feature,patch_size,stride,mask_expand_size,seq_len,dim_projection,dim_temporal_embedding,dim_assets_embedding,num_bass_assets,num_head,num_layer,dropout).to('cuda:0')
model.encoder.load_state_dict(torch.load('params/panel_tf_backbone.params'))
model.encoder.assets_embedding.load_state_dict(torch.load('params/assets_embedding.params'))

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay = weight_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

In [20]:
num_assets = feature.shape[1]
print(f'baseline accuracy: {top/10:.2%}')

test_losses, test_accs = [], [] 
model.eval()
with torch.no_grad():
    for batch_x, batch_weights, batch_y in test_loader:
        batch_x, batch_weights, batch_y = batch_x.to('cuda:0'), batch_weights.to('cuda:0'), batch_y.to('cuda:0')
        pred = model(batch_x, batch_weights)
        batch_y = batch_y[...,0]# y 需要处理，只选第一列
        loss = loss_fn(pred, batch_y)
        acc = accuracy(pred, batch_y, top)
        test_losses.append(loss.item())
        test_accs.append(acc)
base_loss = np.mean(test_losses)
base_accs = np.mean(test_accs)
print(f'模型未得到训练前的基准损失：{base_loss:.2f}, 基准准确率：{base_accs:.2%}')

baseline accuracy: 20.00%
模型未得到训练前的基准损失：0.72, 基准准确率：24.83%


In [21]:
def epoch():
    train_losses, train_accs = [], []
    model.train()
    for batch_x, batch_weights, batch_y in train_loader:
        batch_x, batch_weights, batch_y =  batch_x.to('cuda:0'), batch_weights.to('cuda:0'), batch_y.to('cuda:0')
        optimizer.zero_grad()
        pred = model(batch_x, batch_weights)
        batch_y = batch_y[...,0]# y 需要处理，只选第一列
        loss = loss_fn(pred, batch_y)
        acc = accuracy(pred, batch_y, top)
        train_losses.append(loss.item())
        train_accs.append(acc)
        loss.backward()
        optimizer.step()

    test_losses, test_accs = [], []
    model.eval()
    with torch.no_grad():
        for batch_x, batch_weights, batch_y in test_loader:
            batch_x, batch_weights, batch_y = batch_x.to('cuda:0'), batch_weights.to('cuda:0'), batch_y.to('cuda:0')
            pred = model(batch_x, batch_weights)
            batch_y = batch_y[...,0]# y 需要处理，只选第一列
            loss = loss_fn(pred, batch_y)
            acc = accuracy(pred, batch_y, top)
            test_losses.append(loss.item())
            test_accs.append(acc)

    return np.mean(train_losses), np.mean(test_losses), np.mean(train_accs), np.mean(test_accs)


def train(epochs = 30):
    train_losses, test_losses = [], []
    train_accs, test_accs = [], []
    for i in tqdm.tqdm(range(epochs)):
        train_loss, test_loss, train_acc, test_acc = epoch()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accs.append(train_acc)
        test_accs.append(test_acc)
        scheduler.step()

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(range(epochs), train_losses, label='Train Loss')
    plt.plot(range(epochs), test_losses, label='Test Loss')
    plt.title('Loss Curve')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)


    plt.subplot(1, 2, 2)
    plt.plot(range(epochs), train_accs, label='Train Accuracy')
    plt.plot(range(epochs), test_accs, label='Test Accuracy')
    plt.title(f'Top-{top} Accuracy Curve') # 更新图表标题
    plt.xlabel('Epoch')
    plt.ylabel(f'Top-{top} Accuracy') # 更新Y轴标签
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    return np.mean(test_losses[-5:])

final_loss = train(epochs)
print(final_loss)

  0%|          | 0/200 [00:00<?, ?it/s]


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.