In [1]:
import torch
import torch.nn as nn

class MultipleLinearTransform(nn.Module):
    """Repeatedly apply one shared square linear layer and sum every step.

    With ``f = self.linear``, ``forward(x)`` returns
    ``f(x) + f(f(x)) + ...`` containing ``num_transforms`` terms. The input
    ``x`` itself is not included in the sum, and ``num_transforms == 0``
    yields a zero tensor shaped like ``x``.

    Args:
        input_dim: Size of the last dimension of the input. The layer maps
            it back to the same size so its output can be fed into it again.
        num_transforms: Number of times the layer is applied.
    """

    def __init__(self, input_dim, num_transforms):
        super().__init__()
        # Square layer so that its output can be chained back into itself.
        self.linear = nn.Linear(input_dim, input_dim)
        self.num_transforms = num_transforms

    def forward(self, x):
        """Return the sum of the successive outputs of ``self.linear``."""
        running_sum = torch.zeros_like(x)
        state = x
        for _ in range(self.num_transforms):
            # Each step feeds the previous step's output, not the raw input.
            state = self.linear(state)
            running_sum = running_sum + state
        return running_sum


class DualNetwork(nn.Module):
    """Two-tower model that scores inputs from two feature spaces against each other.

    Each tower projects its input into a common ``inflect_dim``-dimensional
    space with its own linear layer. The projection then goes through one
    ``MultipleLinearTransform`` instance, which is shared by both towers
    (same weights). The result is standardized feature-wise across dim 0,
    and the two towers are compared with a matrix product.

    Args:
        features_dim1: Size of the last dimension of ``x1``.
        features_dim2: Size of the last dimension of ``x2``.
        inflect_dim: Size of the common projection space.
        num_transforms: Number of repeated linear steps summed by the shared
            ``MultipleLinearTransform``.
    """

    def __init__(self, features_dim1, features_dim2, inflect_dim, num_transforms):
        super(DualNetwork, self).__init__()
        # One single-layer projection per input space.
        self.model1 = nn.Linear(features_dim1, inflect_dim)  # first tower
        self.model2 = nn.Linear(features_dim2, inflect_dim)  # second tower
        # Deliberately a single instance: both towers use the same weights here.
        self.transform = MultipleLinearTransform(inflect_dim, num_transforms)

    def forward(self, x1, x2):
        """Score every row of ``x1`` against every row of ``x2``.

        Args:
            x1: Input to the first tower; presumably shaped
                (N1, features_dim1) — TODO confirm with callers.
            x2: Input to the second tower; presumably shaped
                (N2, features_dim2) — TODO confirm with callers.

        Returns:
            Tuple ``(dot_product, sum_result1, sum_result2)``:
            ``dot_product`` is
            ``standardize(sum_result1) @ standardize(sum_result2).T``
            (shape (N1, N2) for 2-D inputs). ``sum_result1`` and
            ``sum_result2`` are the shared transform's outputs before
            standardization.
        """
        output1 = self.model1(x1)
        output2 = self.model2(x2)

        # Repeated linear transforms, summed (shared weights for both towers).
        sum_result1 = self.transform(output1)
        sum_result2 = self.transform(output2)

        standardized_output1 = self.standardize(sum_result1)
        standardized_output2 = self.standardize(sum_result2)

        # NOTE(review): `.T` is an ordinary transpose only for 2-D tensors.
        # It leaves 1-D tensors unchanged and reverses all dims for 3-D+.
        dot_product = torch.matmul(standardized_output1, standardized_output2.T)
        return dot_product, sum_result1, sum_result2

    def standardize(self, tensor):
        """Z-score ``tensor`` per feature across dim 0.

        Uses the unbiased (Bessel-corrected) standard deviation, floored at
        1e-8 so that constant features map to 0 instead of dividing by zero.

        With only one row along dim 0, the unbiased std is NaN, and clamping
        does not remove NaN. That case is therefore handled explicitly: the
        centered values are exactly zero, so the result is a zero tensor
        rather than NaN.
        """
        mean = tensor.mean(dim=0, keepdim=True)
        if tensor.shape[0] > 1:
            std = tensor.std(dim=0, keepdim=True)
        else:
            # A single sample has no spread. A zero std (clamped below) keeps
            # the result finite, since tensor - mean is all zeros here.
            std = torch.zeros_like(mean)
        std = torch.clamp(std, min=1e-8)  # avoid division by zero
        return (tensor - mean) / std