In [5]:
import torch
import torch.nn as nn

LSTM_NEUROES = 13*13*13

class LSTM(nn.Module):
    def __init__(self, input_size=LSTM_NEUROES, hidden_size=128, num_layers=2, output_size=LSTM_NEUROES):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.input_fc = nn.Linear(input_size + output_size, input_size)  # 映射到 input_size

    def forward(self, x, h_0, c_0):
        print("输入:", x.shape)
        out, (h_n, c_n) = self.lstm(x, (h_0, c_0))
        out = out[:, -1, :]  # 取LSTM最后一个时间步的输出
        out = self.fc(out)
        print("输出:", out.shape)
        return out, h_n, c_n

# 初始化模型
model = LSTM()
batch_size = 1

# 初始化隐藏层状态
h_0 = torch.zeros(model.num_layers, batch_size, model.hidden_size)
c_0 = torch.zeros(model.num_layers, batch_size, model.hidden_size)

# 第一次输入
input_1 = torch.randn(batch_size, 1, LSTM_NEUROES)  # (batch_size, seq_len, input_size)
output, h_n, c_n = model(input_1, h_0, c_0)
print("第一次输出:", output)
print("h_n:", h_n.shape)
print("c_n:", c_n.shape)

# 第二次输入，使用第一次的输出作为一部分输入
input_2 = torch.randn(batch_size, 1, LSTM_NEUROES)  # 另外的输入
new_input = torch.cat((input_2, output.unsqueeze(1)), dim=2)  # 合并输出作为输入的一部分
new_input = model.input_fc(new_input)  # 映射到原始input_size

output, h_n, c_n = model(new_input, h_n, c_n)
print("第一次输出:", output)
print("h_n:", h_n.shape)
print("c_n:", c_n.shape)
# 多次迭代输入进行修正
num_iterations = 5
for i in range(num_iterations):
    input_next = torch.randn(batch_size, 1, LSTM_NEUROES)  # 另外的输入
    new_input = torch.cat((input_next, output.unsqueeze(1)), dim=2)  # 合并输出作为输入的一部分
    new_input = model.input_fc(new_input)  # 映射到原始input_size

    output, h_n, c_n = model(new_input, h_n, c_n)
    print(f"第{i+3}次输出:", output)


输入: torch.Size([1, 1, 2197])
输出: torch.Size([1, 2197])
第一次输出: tensor([[-0.0111,  0.0593, -0.0235,  ...,  0.0436, -0.0385, -0.1238]],
       grad_fn=<AddmmBackward0>)
h_n: torch.Size([2, 1, 128])
c_n: torch.Size([2, 1, 128])
输入: torch.Size([1, 1, 2197])
输出: torch.Size([1, 2197])
第一次输出: tensor([[ 0.0033,  0.0623, -0.0255,  ...,  0.0506, -0.0307, -0.1387]],
       grad_fn=<AddmmBackward0>)
h_n: torch.Size([2, 1, 128])
c_n: torch.Size([2, 1, 128])
输入: torch.Size([1, 1, 2197])
输出: torch.Size([1, 2197])
第3次输出: tensor([[-0.0026,  0.0404, -0.0338,  ...,  0.0430, -0.0272, -0.1315]],
       grad_fn=<AddmmBackward0>)
输入: torch.Size([1, 1, 2197])
输出: torch.Size([1, 2197])
第4次输出: tensor([[ 0.0039,  0.0343, -0.0273,  ...,  0.0455, -0.0167, -0.1245]],
       grad_fn=<AddmmBackward0>)
输入: torch.Size([1, 1, 2197])
输出: torch.Size([1, 2197])
第5次输出: tensor([[ 0.0062,  0.0417, -0.0219,  ...,  0.0495, -0.0286, -0.1408]],
       grad_fn=<AddmmBackward0>)
输入: torch.Size([1, 1, 2197])
输出: torch.Size([1, 2197])

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim

LSTM_NEUROES = 13 * 13 * 13

class LSTM(nn.Module):
    def __init__(self, input_size=LSTM_NEUROES, hidden_size=128, num_layers=2, output_size=LSTM_NEUROES):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.input_fc = nn.Linear(input_size + output_size, input_size)  # 映射到 input_size

    def forward(self, input_t, prev_output, h_0, c_0):
        if prev_output is not None:
            input_t = torch.cat((input_t, prev_output), dim=1)
            input_t = self.input_fc(input_t)  # 映射到原始input_size

        input_t = input_t.unsqueeze(1)  # 调整形状为 (batch_size, seq_len, input_size)
        out, (h_0, c_0) = self.lstm(input_t, (h_0, c_0))
        out = self.fc(out[:, -1, :])
        return out, h_0, c_0

# 定义数据集
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples, seq_len, input_size):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.input_size = input_size
        self.data = torch.randn(num_samples, seq_len, input_size)
        self.targets = torch.randn(num_samples, input_size)
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

# 创建数据集和数据加载器
num_samples = 1000
seq_len = 5  # 假设序列长度为5
dataset = CustomDataset(num_samples, seq_len, LSTM_NEUROES)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# 初始化模型、损失函数和优化器
model = LSTM()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 20
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0
    for inputs, targets in data_loader:
        optimizer.zero_grad()

        batch_size = inputs.size(0)
        h_0 = torch.zeros(model.num_layers, batch_size, model.hidden_size).to(inputs.device)
        c_0 = torch.zeros(model.num_layers, batch_size, model.hidden_size).to(inputs.device)
        prev_output = None

        for t in range(seq_len):
            input_t = inputs[:, t, :]
            print("input_t:", input_t.shape)
            output, h_0, c_0 = model(input_t, prev_output, h_0, c_0)
            print("out: {}, h_0: {}, c_0: {}".format(output.shape, h_0.shape, c_0.shape))
            prev_output = output

        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(data_loader)}')

# 保存模型
torch.save(model.state_dict(), 'lstm_model.pth')


input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]), h_0: torch.Size([2, 1, 128]), c_0: torch.Size([2, 1, 128])
input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]), h_0: torch.Size([2, 1, 128]), c_0: torch.Size([2, 1, 128])
input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]), h_0: torch.Size([2, 1, 128]), c_0: torch.Size([2, 1, 128])
input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]), h_0: torch.Size([2, 1, 128]), c_0: torch.Size([2, 1, 128])
input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]), h_0: torch.Size([2, 1, 128]), c_0: torch.Size([2, 1, 128])
input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]), h_0: torch.Size([2, 1, 128]), c_0: torch.Size([2, 1, 128])
input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]), h_0: torch.Size([2, 1, 128]), c_0: torch.Size([2, 1, 128])
input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]), h_0: torch.Size([2, 1, 128]), c_0: torch.Size([2, 1, 128])
input_t: torch.Size([1, 2197])
out: torch.Size([1, 2197]

KeyboardInterrupt: 