In [1]:
# Program 9

import torch
import torch.nn as nn
import torch.optim as optim


class SimpleRNNModel(nn.Module):
    """Sequence model: an ``nn.RNN`` encoder followed by a linear head.

    The RNN is constructed with ``batch_first=True``, so ``forward`` expects
    input shaped ``(batch, seq_len, input_size)``. Only the RNN output at the
    final time step is fed to the linear layer, and its result is returned
    without any activation, i.e. shape ``(batch, output_size)``.

    Args:
        input_size: Number of features at each time step.
        hidden_size: Width of the RNN hidden state.
        output_size: Number of output values produced per sequence.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # The attribute names ``rnn`` and ``fc`` determine the state_dict keys.
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return one un-activated output vector per input sequence."""
        all_steps, _final_hidden = self.rnn(x)  # (batch, seq_len, hidden_size)
        last_step = all_steps[:, -1]            # final time step: (batch, hidden_size)
        return self.fc(last_step)

# Hyperparameters
input_size = 8      # Features per time step
hidden_size = 32    # Number of RNN hidden units
output_size = 1     # One logit per sequence for binary classification.
                    # Multi-class would need output_size=num_classes, integer
                    # class targets, and nn.CrossEntropyLoss instead of BCE.
seq_len = 10
batch_size = 16
num_epochs = 10     # Used by both the training loop and the progress log
learning_rate = 0.001
seed = 0            # Fixed seed so data, initial weights, and losses are reproducible

torch.manual_seed(seed)

# Generate random input and target data
X = torch.randn(batch_size, seq_len, input_size)  # Shape: (batch, seq_len, input_size)
y = torch.randint(0, 2, (batch_size, 1)).float()   # Binary targets in {0., 1.}, shape (batch, 1)
print(X)
print(y)

model = SimpleRNNModel(input_size, hidden_size, output_size)
# BCEWithLogitsLoss applies the sigmoid internally, so the model outputs raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Full-batch training: each epoch is a single optimizer step over all of X.
for epoch in range(num_epochs):
    outputs = model(X)
    loss = criterion(outputs, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

tensor([[[ 0.6361, -1.2551,  0.1355,  ..., -1.0857, -0.2139, -0.9448],
         [ 0.2159, -0.5999,  0.9355,  ...,  0.0058,  0.4996, -0.9498],
         [-0.3503, -1.5568, -0.0612,  ..., -1.8335, -0.5826,  0.8628],
         ...,
         [-1.5708, -1.3555,  0.4085,  ..., -0.1421, -0.8089, -0.7730],
         [-0.8273, -1.2098, -0.2047,  ..., -1.2532,  0.7175, -0.1643],
         [-1.1004, -0.1862,  1.3991,  ...,  0.2250,  0.3076,  0.6879]],

        [[ 0.7800, -0.9208,  1.9385,  ..., -0.1484, -1.7166, -0.7470],
         [-0.4261, -1.5691,  0.8748,  ...,  0.3645,  0.6376, -1.3929],
         [ 0.6706,  0.4302, -1.1697,  ...,  0.2187, -0.0316,  0.1746],
         ...,
         [-0.4948, -1.6718,  0.9679,  ..., -0.0189,  0.5890,  0.0240],
         [-0.7097, -0.1453,  0.7242,  ..., -0.5420, -0.8249,  0.2958],
         [-0.6559, -1.1007,  0.3938,  ..., -0.3635, -1.6678, -0.1567]],

        [[ 0.4035, -2.3576, -1.5998,  ...,  0.5959, -0.7309, -2.7380],
         [ 0.8042, -1.3175,  1.1589,  ..., -0