In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [30]:
class LSMT(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSMT, self).__init__()
        self.LSMT = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.LSMT(x)
        out = self.fc(out[:, -1, :])
        return out

In [None]:
num_samples = 5000
sequence_length = 100
input_size = 100
hidden_size = 32
output_size = 1

# Generate synthetic data
data = np.random.rand(num_samples, sequence_length, input_size)
labels = np.random.rand(num_samples, output_size)

# Convert data to PyTorch tensors
data_tensor = torch.tensor(data, dtype=torch.float32)
labels_tensor = torch.tensor(labels, dtype=torch.float32)

# Create DataLoader
dataset = TensorDataset(data_tensor, labels_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Create RNN model
model = LSMT(input_size, hidden_size, output_size)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/100], Loss: 0.0963


In [None]:
# Evaluation
with torch.no_grad():
    tot_error = 0
    num_correct = 0
    model.eval()
    for inputs, labels in dataloader:
        outputs = model(inputs)
        tot_error += criterion(outputs, labels).item()
        num_correct += (outputs.round() == labels).sum().item()
    print(f"Total error: {tot_error:.4f}")
    print(f"Number of correct predictions: {num_correct}")
    print(f"Accuracy: {num_correct / len(dataset):.4f}")

Total error: 4.9826
Number of correct predictions: 0
Accuracy: 0.0000
