In [6]:
import torch
import torch.nn as nn

In [2]:
# Define the RNN model
class RNN(nn.Module):
    """Sequence classifier: an ``nn.RNN`` encoder followed by a linear head.

    The RNN is built with ``batch_first=True``, so ``forward`` expects input of
    shape ``(batch, seq_len, input_size)``. It returns ``(batch, output_size)``
    logits computed from the hidden output of the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the RNN hidden state.
        output_size: Number of output logits (e.g. number of classes).
        num_layers: Number of stacked RNN layers. Defaults to 1, which matches
            the original single-layer model (same parameter names/state_dict).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return logits of shape ``(batch, output_size)`` for a batch of sequences ``x``."""
        # Zero initial hidden state. It is created directly on x's device AND dtype,
        # so the model still works after .double()/.half(); a float32 h0 would not
        # match the weights. Its first dimension must equal num_layers (unidirectional RNN).
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)

        # out has shape (batch, seq_len, hidden_size) because batch_first=True
        out, _ = self.rnn(x, h0)

        # Decode the hidden output of the last time step
        return self.fc(out[:, -1, :])

In [4]:
# Set the hyperparameters
input_size = 10
hidden_size = 32
output_size = 2
num_epochs = 10
learning_rate = 0.001
seed = 42

# Seed torch's global RNG before anything random happens (data, weight init and
# the later random test sample), so Restart Kernel -> Run All reproduces the same losses.
torch.manual_seed(seed)

# Generate some random data for training:
# x_train has shape (num_samples, seq_length, input_size); y_train holds integer class ids in [0, output_size).
# NOTE: the labels are drawn independently of x_train, so there is no real signal to learn;
# this only exercises the training loop (the loss should hover near ln(output_size)).
num_samples = 100
seq_length = 20
x_train = torch.randn(num_samples, seq_length, input_size)
y_train = torch.randint(0, output_size, (num_samples,))

# Create the RNN model
model = RNN(input_size, hidden_size, output_size)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [5]:
# Train the model (full-batch: each epoch is one optimizer step over all of x_train)
log_every = 1  # print the loss every `log_every` epochs

# Put the model in training mode explicitly; the evaluation below switches it to
# eval mode, so re-running this cell must switch it back.
model.train()
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(x_train)
    loss = criterion(outputs, y_train)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # The logged loss is the one computed before this epoch's optimizer step.
    if (epoch + 1) % log_every == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Test the model on a single random sequence (batch of 1, which is why .item() works below)
model.eval()
x_test = torch.randn(1, seq_length, input_size)
with torch.no_grad():
    predictions = model(x_test)
    predicted_labels = predictions.argmax(dim=1)
    print('Predicted Labels:', predicted_labels.item())

Epoch [1/10], Loss: 0.7047
Epoch [2/10], Loss: 0.6992
Epoch [3/10], Loss: 0.6938
Epoch [4/10], Loss: 0.6888
Epoch [5/10], Loss: 0.6839
Epoch [6/10], Loss: 0.6793
Epoch [7/10], Loss: 0.6748
Epoch [8/10], Loss: 0.6706
Epoch [9/10], Loss: 0.6665
Epoch [10/10], Loss: 0.6625
Predicted Labels: 1
