In [None]:
import torch
import torch.nn as nn

# 定义一个简单的RNN模型
class SimpleRNN(nn.Module):
    """Minimal single-step recurrent cell with a log-softmax output head.

    The input and the previous hidden state are concatenated along dim 1 and
    passed through two independent linear layers: ``i2h`` produces the next
    hidden state (no non-linearity is applied to it) and ``i2o`` produces
    scores that are normalised with ``LogSoftmax(dim=1)``.

    Args:
        input_size: Width of the input along dim 1.
        hidden_size: Width of the hidden state along dim 1.
        output_size: Number of output scores per row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        joint_size = input_size + hidden_size
        # Layers are created in the order i2h, i2o so that parameter
        # registration (and seeded initialisation) stays the same.
        self.i2h = nn.Linear(joint_size, hidden_size)
        self.i2o = nn.Linear(joint_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one step of the cell.

        Args:
            input: 2-D tensor whose dim 1 has ``input_size`` entries.
            hidden: 2-D tensor whose dim 1 has ``hidden_size`` entries and
                whose dim 0 matches ``input``.

        Returns:
            ``(output, hidden)``: the log-softmax scores and the next hidden
            state, both computed from the same concatenated tensor.
        """
        joint = torch.cat([input, hidden], dim=1)
        log_probs = self.softmax(self.i2o(joint))
        next_hidden = self.i2h(joint)
        return log_probs, next_hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros((1, self.hidden_size))



In [None]:
# 定义一个简单的GRU模型
class SimpleGRU(nn.Module):
    """Single-step GRU cell followed by a log-softmax output head.

    One linear layer, ``i2h``, holds the parameters of all three GRU
    components. Its output rows are laid out as
    ``[reset | update | candidate]``, each ``hidden_size`` wide.

    Step equations (``x`` = input, ``h`` = previous hidden state)::

        r  = sigmoid(W_r [x, h] + b_r)
        z  = sigmoid(W_z [x, h] + b_z)
        n  = tanh(W_n [x, r * h] + b_n)
        h' = (1 - z) * n + z * h

    The output scores are ``LogSoftmax(i2o([x, h]))``, i.e. they are computed
    from the previous hidden state, not from ``h'``.

    Args:
        input_size: Width of the input along dim 1.
        hidden_size: Width of the hidden state along dim 1.
        output_size: Number of output scores per row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleGRU, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, 3 * hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one GRU step.

        Args:
            input: 2-D tensor whose dim 1 has ``input_size`` entries.
            hidden: 2-D tensor whose dim 1 has ``hidden_size`` entries and
                whose dim 0 matches ``input``.

        Returns:
            ``(output, hidden)``: the log-softmax scores and the next hidden
            state.
        """
        combined = torch.cat((input, hidden), 1)
        weight, bias = self.i2h.weight, self.i2h.bias
        gate_rows = 2 * self.hidden_size

        # Reset and update gates both look at the unmodified [x, h].
        gates = torch.sigmoid(
            nn.functional.linear(combined, weight[:gate_rows], bias[:gate_rows])
        )
        reset, update = torch.split(gates, self.hidden_size, dim=1)

        # The reset gate scales the previous state *before* it enters the
        # candidate projection, so the cell can drop stale history.
        gated = torch.cat((input, reset * hidden), 1)
        candidate = torch.tanh(
            nn.functional.linear(gated, weight[gate_rows:], bias[gate_rows:])
        )

        # Convex blend of old state and candidate keeps the state bounded.
        new_hidden = (1 - update) * candidate + update * hidden

        output = self.softmax(self.i2o(combined))
        return output, new_hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)

# 定义一个简单的LSTM模型
class SimpleLSTM(nn.Module):
    """Single-step LSTM cell followed by a log-softmax output head.

    One linear layer, ``i2h``, produces all four gate pre-activations; its
    output is split along dim 1 into ``[forget | input | output | candidate]``
    chunks of ``hidden_size`` each.

    Step equations (``x`` = input, ``h``/``c`` = previous hidden/cell state)::

        f  = sigmoid(.)   i = sigmoid(.)   o = sigmoid(.)   g = tanh(.)
        c' = f * c + i * g
        h' = o * tanh(c')

    The output scores are ``LogSoftmax(i2o([x, h]))``, i.e. they are computed
    from the previous hidden state, not from ``h'``.

    Args:
        input_size: Width of the input along dim 1.
        hidden_size: Width of the hidden and cell states along dim 1.
        output_size: Number of output scores per row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, 4 * hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one LSTM step.

        Args:
            input: 2-D tensor whose dim 1 has ``input_size`` entries.
            hidden: Either an ``(h, c)`` pair as returned by ``initHidden``
                (or by a previous call), or a single tensor. A single tensor
                is used as both ``h`` and ``c``; this keeps callers that
                pass a bare tensor working.

        Returns:
            ``(output, state)``. ``state`` is an ``(h', c')`` tuple when
            ``hidden`` was a pair. When ``hidden`` was a single tensor,
            ``state`` is the single tensor ``f * hidden + i * g`` (the output
            gate is not applied in that mode).
        """
        paired = isinstance(hidden, (tuple, list))
        if paired:
            prev_hidden, prev_cell = hidden
        else:
            prev_hidden = prev_cell = hidden

        combined = torch.cat((input, prev_hidden), 1)
        # Distinct names: do not shadow the `input` argument or the returned
        # `output` with gate tensors.
        forget_gate, input_gate, output_gate, candidate = torch.split(
            self.i2h(combined), self.hidden_size, dim=1
        )
        forget_gate = torch.sigmoid(forget_gate)
        input_gate = torch.sigmoid(input_gate)
        candidate = torch.tanh(candidate)
        cell = forget_gate * prev_cell + input_gate * candidate

        output = self.softmax(self.i2o(combined))
        if not paired:
            return output, cell

        # The output gate is a sigmoid that scales the squashed cell state.
        new_hidden = torch.sigmoid(output_gate) * torch.tanh(cell)
        return output, (new_hidden, cell)

    def initHidden(self):
        """Return an all-zero ``(h, c)`` pair, each of shape ``(1, hidden_size)``."""
        return (torch.zeros(1, self.hidden_size), torch.zeros(1, self.hidden_size))

# 定义一个简单的Bi-LSTM模型
class SimpleBiLSTM(nn.Module):
    """Pair of ``SimpleLSTM`` cells whose outputs and states are summed.

    NOTE(review): both cells receive the same ``input`` and ``hidden`` on every
    call; this class does not reverse anything itself. Feeding the second cell
    a reversed sequence would have to be done by the caller.

    NOTE(review): ``output`` is the sum of two log-softmax outputs, so it is
    not itself a normalised log-probability vector.

    Args:
        input_size: Width of the input along dim 1.
        hidden_size: Width of each cell's state along dim 1.
        output_size: Number of output scores per row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleBiLSTM, self).__init__()
        # initHidden() reads this attribute; without it the call raises
        # AttributeError.
        self.hidden_size = hidden_size
        self.forward_lstm = SimpleLSTM(input_size, hidden_size, output_size)
        self.backward_lstm = SimpleLSTM(input_size, hidden_size, output_size)

    def forward(self, input, hidden):
        """Run one step of both cells and add their results.

        Args:
            input: Tensor passed unchanged to both cells.
            hidden: State passed unchanged to both cells.

        Returns:
            ``(output, hidden)``: the element-wise sum of the two cells'
            outputs, and the element-wise sum of their states. If the cells
            return a tuple state, the sum is taken component by component.
        """
        forward_output, forward_hidden = self.forward_lstm(input, hidden)
        backward_output, backward_hidden = self.backward_lstm(input, hidden)
        output = forward_output + backward_output
        if isinstance(forward_hidden, (tuple, list)):
            # `+` on tuples would concatenate them instead of adding tensors.
            hidden = tuple(
                fwd + bwd for fwd, bwd in zip(forward_hidden, backward_hidden)
            )
        else:
            hidden = forward_hidden + backward_hidden
        return output, hidden

    def initHidden(self):
        """Return an all-zero state pair, each tensor of shape ``(1, hidden_size)``."""
        return (torch.zeros(1, self.hidden_size), torch.zeros(1, self.hidden_size))
