In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

## basic elm

In [2]:
class ELM(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        # default_args
        self.args = {
            'input_dim': 32,
            'output_dim': 24,
            'hidden_dim': 8,
            'activation_fun': 'sigmoid'
        }

        self.args.update(kwargs) # 更新默认参数

        self.activation_fun = getattr(F,self.args["activation_fun"])

        self.fc1 = nn.Linear(self.args['input_dim'], self.args['hidden_dim'])
        self.fc2 = nn.Linear(self.args['hidden_dim'], self.args['output_dim'], bias=False) # ELM输出层不需要偏置

    def forward(self, x):
        with torch.no_grad():   # 不需要反向传播
            x = self.activation_fun(self.fc1(x))
            return self.fc2(x)
        
    def fit(self, data, ground_truth):
        # 给定N条data和对应ground_truth，更新fc2的参数
        
        with torch.no_grad():   # 不需要反向传播
            hidden_mat = self.activation_fun(self.fc1(data))
            beta = torch.matmul(torch.pinverse(hidden_mat), ground_truth)  # 计算H的广义逆，求beta
            self.fc2.weight = nn.Parameter(beta)    # 更新fc2的参数为beta



## 处理svc数据

In [3]:
def read_SVC_data(file_path):
    df = pd.read_excel(file_path, header=None)
    data = torch.tensor(df.iloc[2:].values.astype(float), dtype=torch.float32)
    target = torch.tensor(df.iloc[1].values.astype(float), dtype=torch.float32).unsqueeze(1)
    return data.transpose(0,1), target

In [4]:
data, target = read_SVC_data("C:\\Users\\29147\\Desktop\\近期使用\\测试 的副本.xls")

In [5]:
data.shape

torch.Size([256, 973])

In [6]:
model = ELM(input_dim=973, hidden_dim=1300, output_dim=1)

In [7]:
model.fit(data, target)



In [10]:
model(data[0].unsqueeze(0))



RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1300 and 1x1300)

In [11]:
data[0].unsqueeze(0)

tensor([[12.4750, 12.0750, 12.1650, 12.3750, 12.1650, 12.3000, 11.8750, 11.9850,
         12.0200, 12.1000, 11.8550, 11.9450, 12.0900, 11.9300, 12.0900, 11.8600,
         12.1200, 11.8250, 11.7200, 11.7850, 11.9150, 11.7700, 11.8550, 11.9400,
         11.8600, 11.8300, 11.7600, 11.7100, 11.7250, 11.6950, 11.6900, 11.6750,
         11.7150, 11.7350, 11.6550, 11.7250, 11.6850, 11.7500, 11.6700, 11.7000,
         11.7150, 11.7500, 11.6900, 11.6600, 11.6800, 11.6950, 11.7050, 11.7400,
         11.7350, 11.7300, 11.7450, 11.7050, 11.7050, 11.7450, 11.7050, 11.7000,
         11.7500, 11.7350, 11.7550, 11.7150, 11.7050, 11.7050, 11.7200, 11.6950,
         11.6950, 11.6900, 11.6950, 11.7000, 11.6850, 11.6500, 11.6750, 11.6550,
         11.6500, 11.6350, 11.6400, 11.6450, 11.6600, 11.6550, 11.6350, 11.6200,
         11.6150, 11.6250, 11.6250, 11.6200, 11.6250, 11.6000, 11.6050, 11.5900,
         11.5950, 11.6000, 11.5900, 11.5900, 11.6100, 11.6100, 11.6200, 11.6050,
         11.5900, 11.5950, 1