In [1]:
import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    """One 3-D convolutional LSTM cell fed by a flat input vector.

    A flat vector of length ``input_len`` is projected by a fully connected
    layer to ``input_dim`` channels over a 3-D volume, concatenated with the
    previous hidden state along the channel axis, and passed through a single
    ``nn.Conv3d`` that produces the four LSTM gate pre-activations.

    Args:
        input_len: Length of the flat input vector.
        input_dim: Number of channels of the projected input volume.
        hidden_dim: Number of channels of the hidden and cell states.
        kernel_size: Three odd ints, the 3-D convolution kernel size.
        bias: Whether the convolution has a bias term.
        image_size: Optional (depth, height, width) of the state volume. It
            determines the size of the fully connected projection. Defaults
            to ``kernel_size``, which is what the projection was sized from
            before this parameter existed.

    Raises:
        ValueError: If any kernel dimension is even.
    """

    def __init__(self, input_len, input_dim, hidden_dim, kernel_size, bias, image_size=None):
        super(ConvLSTMCell, self).__init__()

        # Reject unsupported kernels before allocating any parameters.
        if any(k % 2 == 0 for k in kernel_size):
            raise ValueError("Only support odd kernel size")

        self.input_len = input_len
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # Odd kernel + half-kernel padding keeps the spatial size unchanged.
        self.padding_size = (kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)
        self.bias = bias
        self.image_size = tuple(kernel_size) if image_size is None else tuple(image_size)

        # The projection has to yield input_dim channels over the state volume,
        # because forward() reshapes its output to the hidden state's spatial size.
        depth, height, width = self.image_size
        self.fc = nn.Linear(input_len, input_dim * depth * height * width)

        self.conv = nn.Conv3d(self.input_dim + self.hidden_dim,
                              4 * self.hidden_dim,  # 4x: split into four gate tensors in forward()
                              self.kernel_size,
                              padding=self.padding_size,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Args:
            input_tensor: Either a tensor reshapeable to ``(-1, input_len)``,
                which is projected through the fully connected layer, or a
                5-D tensor whose trailing dims already equal
                ``(input_dim, depth, height, width)`` of the state volume
                (e.g. the hidden state of a previous layer), which is used
                as-is.
            cur_state: Pair ``(h_cur, c_cur)`` of 5-D tensors with
                ``hidden_dim`` channels.

        Returns:
            Tuple ``(h_next, c_next)`` with the same shape as the incoming state.

        Raises:
            RuntimeError: If the state's spatial size does not match the size
                the fully connected projection was built for.
        """
        h_cur, c_cur = cur_state
        spatial = tuple(h_cur.shape[2:])
        volume_shape = (self.input_dim,) + spatial

        already_volume = (
            input_tensor.dim() == 5
            and tuple(input_tensor.shape[1:]) == volume_shape
            # If such a volume holds exactly input_len values per sample, keep
            # the historical behaviour of flattening it through the FC layer.
            and self.fc.out_features != self.input_len
        )

        if already_volume:
            input_fc = input_tensor
        else:
            if self.input_dim * h_cur.shape[2:].numel() != self.fc.out_features:
                raise RuntimeError(
                    "hidden state spatial size {} does not match the size {} the "
                    "fully connected projection was built for".format(spatial, self.image_size)
                )
            # reshape (not view) so non-contiguous inputs are accepted too.
            flat = input_tensor.reshape(-1, self.input_len)
            input_fc = self.fc(flat).view(flat.size(0), *volume_shape)

        combined = torch.cat((input_fc, h_cur), dim=1)
        combined_conv = self.conv(combined)
        cc_f, cc_i, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        f = torch.sigmoid(cc_f)  # forget gate
        i = torch.sigmoid(cc_i)  # input gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell state

        # Shape trace printed on every call (kept so recorded notebook output stays valid).
        print("f: {}, i: {}, o: {}, g: {}, c_cur: {}".format(f.shape, i.shape, o.shape, g.shape, c_cur.shape))
        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states of shape
        ``(batch_size, hidden_dim, depth, height, width)`` on the same device
        and with the same dtype as the convolution weights."""
        depth, height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, depth, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

class ConvLSTM(nn.Module):
    """Stack of ``ConvLSTMCell`` layers advanced one time step per call.

    Args:
        input_len: Length of the flat input vector; passed to every cell.
        input_dim: Channel count of the first cell's projected input.
        hidden_dim: Per-layer hidden channel counts (a sequence with at least
            ``num_layers`` entries), or a single int used for every layer.
        kernel_size: Three odd ints, the convolution kernel size of every cell.
        num_layers: Number of stacked cells.
        image_size: (depth, height, width) used by ``init_hidden``.
        bias: Whether the cells' convolutions have a bias term.

    Raises:
        ValueError: If ``hidden_dim`` has fewer entries than ``num_layers``.

    NOTE(review): layers after the first are fed the previous layer's 5-D
    hidden state as their input; confirm ``ConvLSTMCell.forward`` accepts that
    shape before relying on ``num_layers > 1``.
    """

    def __init__(self, input_len, input_dim, hidden_dim, kernel_size, num_layers, image_size, bias=False):
        super(ConvLSTM, self).__init__()

        # Accept a single int as "same width for every layer".
        if isinstance(hidden_dim, int):
            hidden_dim = [hidden_dim] * num_layers
        else:
            hidden_dim = list(hidden_dim)
        if len(hidden_dim) < num_layers:
            raise ValueError(
                "hidden_dim has {} entries but num_layers is {}".format(len(hidden_dim), num_layers)
            )

        self.input_len = input_len
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.image_size = image_size
        self.bias = bias

        cell_list = []
        for i in range(num_layers):
            # Each layer after the first consumes the previous layer's hidden channels.
            cur_input_dim = input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_len=self.input_len,
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size,
                                          bias=self.bias))
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """Run one time step through every layer.

        Args:
            input_tensor: Input for the first layer.
            hidden_state: List with one ``(h, c)`` pair per layer. When
                omitted, zero states are created with ``init_hidden`` using
                the batch size obtained by reshaping ``input_tensor`` to
                ``(-1, input_len)``.

        Returns:
            Tuple ``(output, new_hidden_state)`` where ``output`` is the last
            layer's new hidden state and ``new_hidden_state`` is a list of
            ``[h, c]`` per layer.

        Raises:
            ValueError: If ``hidden_state`` has fewer entries than layers.
        """
        if hidden_state is None:
            batch_size = input_tensor.reshape(-1, self.input_len).size(0)
            hidden_state = self.init_hidden(batch_size)
        if len(hidden_state) < self.num_layers:
            raise ValueError(
                "hidden_state has {} entries but the model has {} layers".format(
                    len(hidden_state), self.num_layers)
            )

        cur_layer_input = input_tensor
        new_hidden_state = []

        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx]
            # Shape traces printed on every call (kept so recorded notebook output stays valid).
            print(f"Layer {layer_idx} - Hidden state h: {h.shape}, c: {c.shape}")
            h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input, cur_state=[h, c])
            print(f"Layer {layer_idx} - Output h: {h.shape}, c: {c.shape}")
            cur_layer_input = h  # the next layer consumes this layer's hidden state
            new_hidden_state.append([h, c])

        return cur_layer_input, new_hidden_state

    def init_hidden(self, batch_size):
        """Return a list with one zero ``(h, c)`` pair per layer, sized by ``image_size``."""
        return [cell.init_hidden(batch_size, self.image_size) for cell in self.cell_list]

# Smoke test: feed one flat input vector through a single-layer ConvLSTM for
# several time steps and print the resulting tensor shapes.
input_len = 3000  # length of the flat input vector at each time step
input_dim = 6
hidden_dim = [8]
kernel_size = (15, 15, 15)
num_layers = 1
# Spatial size (depth, height, width) of the hidden state. Kept equal to
# kernel_size: with this configuration the cell's fully connected output
# (input_dim * 15**3 values) is reshaped to exactly this spatial size.
image_size = (15, 15, 15)

model = ConvLSTM(input_len=input_len,
                 input_dim=input_dim,
                 hidden_dim=hidden_dim,
                 kernel_size=kernel_size,
                 num_layers=num_layers,
                 image_size=image_size,
                 bias=True)

# One time step of 1-D data, reused for every step below.
single_time_step_data = torch.randn(1, input_len)
hidden_state = model.init_hidden(batch_size=1)

# Simulate processing a time series. Only shapes are inspected, so autograd is
# disabled to avoid chaining a graph through the carried hidden state.
seq_len = 10
with torch.no_grad():
    for t in range(seq_len):
        output, hidden_state = model(single_time_step_data, hidden_state)
        print(f"Output at time step {t}: {output.shape}")
        print(f"Hidden state at time step {t}: {hidden_state[0][0].shape}")

# Output of the last time step.
print(f"Final Output: {output.shape}")

# Hidden state of every layer after the last time step.
for layer_idx, (h, c) in enumerate(hidden_state):
    print(f"Layer {layer_idx} - Hidden state h: {h.shape}, c: {c.shape}")


Layer 0 - Hidden state h: torch.Size([1, 8, 15, 15, 15]), c: torch.Size([1, 8, 15, 15, 15])
f: torch.Size([1, 8, 15, 15, 15]), i: torch.Size([1, 8, 15, 15, 15]), o: torch.Size([1, 8, 15, 15, 15]), g: torch.Size([1, 8, 15, 15, 15]), c_cur: torch.Size([1, 8, 15, 15, 15])
Layer 0 - Output h: torch.Size([1, 8, 15, 15, 15]), c: torch.Size([1, 8, 15, 15, 15])
Output at time step 0: torch.Size([1, 8, 15, 15, 15])
Hidden state at time step 0: torch.Size([1, 8, 15, 15, 15])
Layer 0 - Hidden state h: torch.Size([1, 8, 15, 15, 15]), c: torch.Size([1, 8, 15, 15, 15])
f: torch.Size([1, 8, 15, 15, 15]), i: torch.Size([1, 8, 15, 15, 15]), o: torch.Size([1, 8, 15, 15, 15]), g: torch.Size([1, 8, 15, 15, 15]), c_cur: torch.Size([1, 8, 15, 15, 15])
Layer 0 - Output h: torch.Size([1, 8, 15, 15, 15]), c: torch.Size([1, 8, 15, 15, 15])
Output at time step 1: torch.Size([1, 8, 15, 15, 15])
Hidden state at time step 1: torch.Size([1, 8, 15, 15, 15])
Layer 0 - Hidden state h: torch.Size([1, 8, 15, 15, 15]), c:

In [5]:
import torch
import torch.nn as nn

# NOTE(review): presumably the length of the flat input vector fed to the model
# (same value as the 3000 used for input_len in the first cell) — confirm.
LSTM_NEUROES = 3000

class ConvLSTMCell(nn.Module):
    """One 3-D convolutional LSTM cell fed by a flat input vector.

    A flat vector of length ``input_len`` is projected by a fully connected
    layer to ``input_dim`` channels over a 3-D volume, concatenated with the
    previous hidden state along the channel axis, and passed through a single
    ``nn.Conv3d`` that produces the four LSTM gate pre-activations.

    Args:
        input_len: Length of the flat input vector.
        input_dim: Number of channels of the projected input volume.
        hidden_dim: Number of channels of the hidden and cell states.
        kernel_size: Three odd ints, the 3-D convolution kernel size.
        bias: Whether the convolution has a bias term.
        image_size: Optional (depth, height, width) of the state volume. It
            determines the size of the fully connected projection. Defaults
            to ``kernel_size``, which is what the projection was sized from
            before this parameter existed.

    Raises:
        ValueError: If any kernel dimension is even.
    """

    def __init__(self, input_len, input_dim, hidden_dim, kernel_size, bias, image_size=None):
        super(ConvLSTMCell, self).__init__()

        # Reject unsupported kernels before allocating any parameters.
        if any(k % 2 == 0 for k in kernel_size):
            raise ValueError("Only support odd kernel size")

        self.input_len = input_len
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # Odd kernel + half-kernel padding keeps the spatial size unchanged.
        self.padding_size = (kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)
        self.bias = bias
        self.image_size = tuple(kernel_size) if image_size is None else tuple(image_size)

        # The projection has to yield input_dim channels over the state volume,
        # because forward() reshapes its output to the hidden state's spatial size.
        depth, height, width = self.image_size
        self.fc = nn.Linear(input_len, input_dim * depth * height * width)

        self.conv = nn.Conv3d(self.input_dim + self.hidden_dim,
                              4 * self.hidden_dim,  # 4x: split into four gate tensors in forward()
                              self.kernel_size,
                              padding=self.padding_size,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Args:
            input_tensor: Either a tensor reshapeable to ``(-1, input_len)``,
                which is projected through the fully connected layer, or a
                5-D tensor whose trailing dims already equal
                ``(input_dim, depth, height, width)`` of the state volume
                (e.g. the hidden state of a previous layer), which is used
                as-is.
            cur_state: Pair ``(h_cur, c_cur)`` of 5-D tensors with
                ``hidden_dim`` channels.

        Returns:
            Tuple ``(h_next, c_next)`` with the same shape as the incoming state.

        Raises:
            RuntimeError: If the state's spatial size does not match the size
                the fully connected projection was built for.
        """
        h_cur, c_cur = cur_state
        spatial = tuple(h_cur.shape[2:])
        volume_shape = (self.input_dim,) + spatial

        already_volume = (
            input_tensor.dim() == 5
            and tuple(input_tensor.shape[1:]) == volume_shape
            # If such a volume holds exactly input_len values per sample, keep
            # the historical behaviour of flattening it through the FC layer.
            and self.fc.out_features != self.input_len
        )

        if already_volume:
            input_fc = input_tensor
        else:
            if self.input_dim * h_cur.shape[2:].numel() != self.fc.out_features:
                raise RuntimeError(
                    "hidden state spatial size {} does not match the size {} the "
                    "fully connected projection was built for".format(spatial, self.image_size)
                )
            # reshape (not view) so non-contiguous inputs are accepted too.
            flat = input_tensor.reshape(-1, self.input_len)
            input_fc = self.fc(flat).view(flat.size(0), *volume_shape)

        combined = torch.cat((input_fc, h_cur), dim=1)
        combined_conv = self.conv(combined)
        cc_f, cc_i, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        f = torch.sigmoid(cc_f)  # forget gate
        i = torch.sigmoid(cc_i)  # input gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell state

        # Shape trace printed on every call (kept so recorded notebook output stays valid).
        print("f: {}, i: {}, o: {}, g: {}, c_cur: {}".format(f.shape, i.shape, o.shape, g.shape, c_cur.shape))
        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states of shape
        ``(batch_size, hidden_dim, depth, height, width)`` on the same device
        and with the same dtype as the convolution weights."""
        depth, height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, depth, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

class ConvLSTM(nn.Module):
    """Stack of ``ConvLSTMCell`` layers advanced one time step per call.

    Args:
        input_len: Length of the flat input vector; passed to every cell.
        input_dim: Channel count of the first cell's projected input.
        hidden_dim: Per-layer hidden channel counts (a sequence with at least
            ``num_layers`` entries), or a single int used for every layer.
        kernel_size: Three odd ints, the convolution kernel size of every cell.
        num_layers: Number of stacked cells.
        image_size: (depth, height, width) used by ``init_hidden``.
        bias: Whether the cells' convolutions have a bias term.

    Raises:
        ValueError: If ``hidden_dim`` has fewer entries than ``num_layers``.

    NOTE(review): layers after the first are fed the previous layer's 5-D
    hidden state as their input; confirm ``ConvLSTMCell.forward`` accepts that
    shape before relying on ``num_layers > 1``.
    """

    def __init__(self, input_len, input_dim, hidden_dim, kernel_size, num_layers, image_size, bias=False):
        super(ConvLSTM, self).__init__()

        # Accept a single int as "same width for every layer".
        if isinstance(hidden_dim, int):
            hidden_dim = [hidden_dim] * num_layers
        else:
            hidden_dim = list(hidden_dim)
        if len(hidden_dim) < num_layers:
            raise ValueError(
                "hidden_dim has {} entries but num_layers is {}".format(len(hidden_dim), num_layers)
            )

        self.input_len = input_len
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.image_size = image_size
        self.bias = bias

        cell_list = []
        for i in range(num_layers):
            # Each layer after the first consumes the previous layer's hidden channels.
            cur_input_dim = input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_len=self.input_len,
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size,
                                          bias=self.bias))
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """Run one time step through every layer.

        Args:
            input_tensor: Input for the first layer.
            hidden_state: List with one ``(h, c)`` pair per layer. When
                omitted, zero states are created with ``init_hidden`` using
                the batch size obtained by reshaping ``input_tensor`` to
                ``(-1, input_len)``.

        Returns:
            Tuple ``(output, new_hidden_state)`` where ``output`` is the last
            layer's new hidden state and ``new_hidden_state`` is a list of
            ``[h, c]`` per layer.

        Raises:
            ValueError: If ``hidden_state`` has fewer entries than layers.
        """
        if hidden_state is None:
            batch_size = input_tensor.reshape(-1, self.input_len).size(0)
            hidden_state = self.init_hidden(batch_size)
        if len(hidden_state) < self.num_layers:
            raise ValueError(
                "hidden_state has {} entries but the model has {} layers".format(
                    len(hidden_state), self.num_layers)
            )

        cur_layer_input = input_tensor
        new_hidden_state = []

        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx]
            # Shape traces printed on every call (kept so recorded notebook output stays valid).
            print(f"Layer {layer_idx} - Hidden state h: {h.shape}, c: {c.shape}")
            h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input, cur_state=[h, c])
            print(f"Layer {layer_idx} - Output h: {h.shape}, c: {c.shape}")
            cur_layer_input = h  # the next layer consumes this layer's hidden state
            new_hidden_state.append([h, c])

        return cur_layer_input, new_hidden_state

    def init_hidden(self, batch_size):
        """Return a list with one zero ``(h, c)`` pair per layer, sized by ``image_size``."""
        return [cell.init_hidden(batch_size, self.image_size) for cell in self.cell_list]
        



# %%
class CTNN3DDecoder(nn.Module):
    """Decode a 5-D feature volume into a 4-channel volume of fixed spatial size.

    Three stages of (ConvTranspose3d -> ReLU -> 2x nearest upsample ->
    Dropout3d) reduce the channels to 12, 8 and 4; a final
    (ConvTranspose3d -> ReLU -> nearest resize) stage produces 4 channels at
    ``output_size``. The transposed convolutions use kernel 3, stride 1 and
    padding 1, so only the upsampling layers change the spatial size.

    Args:
        input_dim: Number of channels of the input volume.
        output_size: (depth, height, width) of the returned volume.
            Defaults to (32, 32, 32).
        dropout_p: Channel dropout probability used after each of the first
            three stages. Defaults to 0.25.
    """

    def __init__(self, input_dim, output_size=(32, 32, 32), dropout_p=0.25):
        super(CTNN3DDecoder, self).__init__()

        self.conv1 = nn.ConvTranspose3d(in_channels=input_dim, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.dropout1 = nn.Dropout3d(p=dropout_p)

        self.conv2 = nn.ConvTranspose3d(in_channels=12, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.dropout2 = nn.Dropout3d(p=dropout_p)

        self.conv3 = nn.ConvTranspose3d(in_channels=8, out_channels=4, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.dropout3 = nn.Dropout3d(p=dropout_p)

        self.conv4 = nn.ConvTranspose3d(in_channels=4, out_channels=4, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()
        # The last stage resizes to an absolute size instead of scaling by 2.
        self.upsample4 = nn.Upsample(size=tuple(output_size), mode='nearest')

    def forward(self, x):
        """Return the decoded volume of shape ``(batch, 4, *output_size)``."""
        stages = (
            (self.conv1, self.relu1, self.upsample1, self.dropout1),
            (self.conv2, self.relu2, self.upsample2, self.dropout2),
            (self.conv3, self.relu3, self.upsample3, self.dropout3),
        )
        out = x
        for conv, relu, upsample, dropout in stages:
            out = dropout(upsample(relu(conv(out))))

        # Final stage has no dropout.
        return self.upsample4(self.relu4(self.conv4(out)))
    
class LSTMDecoder(nn.Module):
    """A ConvLSTM whose per-step output is decoded by a CTNN3DDecoder.

    The constructor arguments are forwarded unchanged to ``ConvLSTM``; the
    resulting module is exposed as ``self.lstm`` and the decoder as
    ``self.decoder``.
    """

    def __init__(self, input_len, input_dim, hidden_dim, kernel_size, num_layers, image_size, bias=False):
        super(LSTMDecoder, self).__init__()
        self.lstm = ConvLSTM(input_len, input_dim, hidden_dim, kernel_size, num_layers, image_size, bias)
        # ConvLSTM.forward returns the hidden state of the LAST layer, so the
        # decoder's input channels must be that layer's width (not layer 0's).
        self.decoder = CTNN3DDecoder(self.lstm.hidden_dim[num_layers - 1])

    def forward(self, x, hidden_state):
        """Advance the ConvLSTM one step and decode its output.

        Args:
            x: Input for the ConvLSTM.
            hidden_state: List with one ``(h, c)`` pair per ConvLSTM layer.

        Returns:
            Tuple ``(decoded_output, new_hidden_state)``.
        """
        out, hidden_state = self.lstm(x, hidden_state)
        out = self.decoder(out)
        return out, hidden_state
    
# Smoke test: one forward step through ConvLSTM + decoder, printing the output shape.
# Bind input_len here so this cell does not depend on a variable left over
# from an earlier cell.
input_len = LSTM_NEUROES
hidden_dim = [8]
kernel_size = (15, 15, 15)
num_layers = 1
# Spatial size (depth, height, width) of the hidden state. Kept equal to
# kernel_size: with this configuration the cell's fully connected output
# (input_dim * 15**3 values) is reshaped to exactly this spatial size.
image_size = (15, 15, 15)
single_time_step_data = torch.randn(1, input_len)

model = LSTMDecoder(input_len=input_len, input_dim=16, hidden_dim=hidden_dim, kernel_size=kernel_size, num_layers=num_layers, image_size=image_size, bias=True)
hidden_state = model.lstm.init_hidden(batch_size=1)
# Only the output shape is inspected, so skip building an autograd graph.
with torch.no_grad():
    output, hidden_state = model(single_time_step_data, hidden_state)
print(f"Output: {output.shape}")


Layer 0 - Hidden state h: torch.Size([1, 8, 15, 15, 15]), c: torch.Size([1, 8, 15, 15, 15])
f: torch.Size([1, 8, 15, 15, 15]), i: torch.Size([1, 8, 15, 15, 15]), o: torch.Size([1, 8, 15, 15, 15]), g: torch.Size([1, 8, 15, 15, 15]), c_cur: torch.Size([1, 8, 15, 15, 15])
Layer 0 - Output h: torch.Size([1, 8, 15, 15, 15]), c: torch.Size([1, 8, 15, 15, 15])
Output: torch.Size([1, 4, 32, 32, 32])
