In [1]:
import numpy as  np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import trange, tqdm

In [2]:
# Device configuration: prefer the GPU whenever PyTorch can see one, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Hyper parameters
input_size = 1        # features per time step (one bit; see generate_data)
hidden_size = 4       # width of the LSTM hidden/cell state
num_layers = 1        # number of stacked LSTM layers
learning_rate = 0.01  # Adam step size

# Seed every RNG the notebook draws from (numpy for data generation, torch for weight
# initialisation and DataLoader shuffling) so Restart & Run All reproduces the results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Generating random data
def generate_data(num_points, seq_length, rng=None):
    """Generate random bit sequences labelled with their parity.

    Args:
        num_points: number of sequences to generate.
        seq_length: number of time steps (bits) per sequence.
        rng: optional ``numpy.random.Generator``. Pass a seeded one for reproducible
            data; when None, numpy's global RNG is used (the original behaviour).

    Returns:
        x: integer array of shape (num_points, seq_length, 1) with values in {0, 1}.
        y: integer array of shape (num_points, 1); 1 when the sequence contains an odd
           number of ones, 0 otherwise.
    """
    shape = (num_points, seq_length, 1)
    if rng is None:
        x = np.random.randint(2, size=shape)
    else:
        x = rng.integers(2, size=shape)
    # Summing over the time axis counts the ones; mod 2 gives the parity label.
    y = x.sum(axis=1) % 2
    return x, y

In [5]:
# Recurrent neural network (many to one)
class RNN(nn.Module):
    """Many-to-one LSTM classifier.

    Reads a whole sequence and predicts a label from the LSTM output at the final time
    step.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden and cell states.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output logits (default 2, matching the original model).
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes=2):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # num_layers must be passed to the LSTM: forward() builds h0/c0 with
        # num_layers as the first dimension, so omitting it breaks any num_layers != 1.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        x is expected as (batch, seq_len, input_size) because the LSTM uses
        batch_first=True.
        """
        # Zero initial hidden and cell states, created on x's own device and dtype so the
        # module works wherever it was moved (no reliance on a notebook-global `device`).
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        # Forward LSTM; out is (batch, seq_len, hidden_size) with batch_first=True.
        out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden state of the last time step
        return self.fc(out[:, -1, :])

In [6]:
# Loss: cross-entropy over the model's raw two-class logits.
criterion = nn.CrossEntropyLoss()
# Build the LSTM classifier from the hyper-parameters and place it on the chosen device,
# then let Adam update all of its parameters.
model = RNN(input_size, hidden_size, num_layers).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [7]:
# Curriculum training on parity: sequence lengths 2, 4, 6, 8. For each length, train
# until the model classifies every dev sequence correctly (or max_epochs is reached),
# then move on to the next length with the same model.
num_train, num_dev, batch_size = 100, 1000, 10
max_epochs = 1000  # safety cap so a length that never converges cannot loop forever

train_batch_nums = 0  # total training epochs (full passes over a train set), all lengths
train_samples = 0     # total training sequences seen, all lengths
for seq_length in range(2, 10, 2):
    print("sequence length: %s" % seq_length)
    train_X, train_Y = generate_data(num_train, seq_length)
    dev_X, dev_Y = generate_data(num_dev, seq_length)
    train_x = torch.from_numpy(train_X)
    train_y = torch.from_numpy(train_Y)
    dev_x = torch.from_numpy(dev_X)
    dev_y = torch.from_numpy(dev_Y)
    train = TensorDataset(train_x, train_y)
    dev = TensorDataset(dev_x, dev_y)
    train_loader = DataLoader(dataset=train, batch_size=batch_size, shuffle=True)
    dev_loader = DataLoader(dataset=dev, batch_size=batch_size, shuffle=False)

    for _ in range(max_epochs):
        train_batch_nums += 1
        train_samples += len(train)

        # Train the model for one epoch
        model.train()
        for seqs, labels in train_loader:
            seqs = seqs.to(device, dtype=torch.float)
            # Labels arrive as (batch, 1) and CrossEntropyLoss wants (batch,).
            # squeeze(1) keeps the batch dim even for a single-sample batch,
            # whereas a bare squeeze() would collapse it to a 0-d tensor.
            labels = labels.squeeze(1).to(device, dtype=torch.long)

            # Forward
            outputs = model(seqs)
            loss = criterion(outputs, labels)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Dev the model: stop this length once every dev sequence is classified correctly.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for seqs, labels in dev_loader:
                seqs = seqs.to(device, dtype=torch.float)
                labels = labels.squeeze(1).to(device, dtype=torch.long)
                predicted = model(seqs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        if correct == total:
            break
    else:
        # Loop exhausted without reaching 100 % dev accuracy.
        print("sequence length %s: dev accuracy %d/%d after %d epochs, moving on"
              % (seq_length, correct, total, max_epochs))

print("Finish Training")
print("train samples:%s" % train_samples)
        
    

Processing: 100%|██████████| 10/10 [00:00<00:00, 249.32it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 571.29it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 604.09it/s]

torch.Size([100, 2, 1]) torch.Size([100, 1])
torch.Size([1000, 2, 1]) torch.Size([1000, 1])



Processing: 100%|██████████| 10/10 [00:00<00:00, 397.74it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 598.58it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 639.91it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 569.34it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 588.23it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 558.73it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 527.69it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 767.06it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 583.05it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 609.88it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 617.43it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 582.41it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 879.25it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 527.02it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 506.35it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 485.52it/s]
Processing: 100%|██████

torch.Size([100, 4, 1]) torch.Size([100, 1])
torch.Size([1000, 4, 1]) torch.Size([1000, 1])



Processing: 100%|██████████| 10/10 [00:00<00:00, 545.54it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 453.30it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 572.90it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 502.91it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 478.55it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 423.14it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 385.07it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 408.29it/s]

torch.Size([100, 6, 1]) torch.Size([100, 1])
torch.Size([1000, 6, 1]) torch.Size([1000, 1])



Processing: 100%|██████████| 10/10 [00:00<00:00, 377.94it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 385.80it/s]
Processing:   0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([100, 8, 1]) torch.Size([100, 1])
torch.Size([1000, 8, 1]) torch.Size([1000, 1])


Processing: 100%|██████████| 10/10 [00:00<00:00, 293.01it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 355.44it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 323.47it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 312.45it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 283.85it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 244.95it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 377.58it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 311.93it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 402.81it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 287.51it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 250.50it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 348.38it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 325.81it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 349.90it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 308.83it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 304.56it/s]
Processing: 100%|███████

Processing: 100%|██████████| 10/10 [00:00<00:00, 260.58it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 356.91it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 371.40it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 361.64it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 351.48it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 381.16it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 363.38it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 386.81it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 360.48it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 357.37it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 371.65it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 324.38it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 340.51it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 402.21it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 377.53it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 399.39it/s]
Processing: 100%|███████

Processing: 100%|██████████| 10/10 [00:00<00:00, 331.38it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 333.85it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 326.09it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 336.91it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 371.11it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 362.75it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 316.80it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 328.45it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 381.60it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 291.49it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 263.27it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 291.63it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 341.09it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 323.97it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 259.00it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 262.64it/s]
Processing: 100%|███████

Processing: 100%|██████████| 10/10 [00:00<00:00, 377.14it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 349.62it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 366.17it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 349.11it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 400.86it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 375.98it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 382.34it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 379.85it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 359.16it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 320.68it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 363.49it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 353.97it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 347.70it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 372.95it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 349.05it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 382.74it/s]
Processing: 100%|███████

Processing: 100%|██████████| 10/10 [00:00<00:00, 378.24it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 370.87it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 347.97it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 360.45it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 370.85it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 383.56it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 327.84it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 405.59it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 367.85it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 369.73it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 361.48it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 359.26it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 280.00it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 378.10it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 338.48it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 362.03it/s]
Processing: 100%|███████

Processing: 100%|██████████| 10/10 [00:00<00:00, 285.73it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 282.53it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 280.04it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 312.43it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 388.17it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 333.20it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 301.57it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 359.23it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 300.73it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 357.60it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 367.52it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 331.78it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 311.55it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 364.47it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 342.54it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 347.90it/s]
Processing: 100%|███████

Processing: 100%|██████████| 10/10 [00:00<00:00, 358.87it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 359.02it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 245.39it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 256.74it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 351.36it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 351.60it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 360.63it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 365.89it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 365.82it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 364.22it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 286.45it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 267.37it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 287.23it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 371.81it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 353.62it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 256.01it/s]
Processing: 100%|███████

Processing: 100%|██████████| 10/10 [00:00<00:00, 286.27it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 289.22it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 308.63it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 343.18it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 292.56it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 272.03it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 259.13it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 313.72it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 336.92it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 284.90it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 329.03it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 264.44it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 286.44it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 378.14it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 348.90it/s]
Processing: 100%|██████████| 10/10 [00:00<00:00, 346.84it/s]
Processing: 100%|███████

In [10]:
# Test generalisation: evaluate on sequences of length 10, longer than any length seen
# during training (2-8).
num_test, test_seq_length = 10000, 10
test_X, test_Y = generate_data(num_test, test_seq_length)
test_x = torch.from_numpy(test_X)
test_y = torch.from_numpy(test_Y)
test = TensorDataset(test_x, test_y)
test_loader = DataLoader(dataset=test, batch_size=10, shuffle=False)

# Test the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for seqs, labels in test_loader:
        seqs = seqs.to(device, dtype=torch.float)
        # squeeze(1): (batch, 1) -> (batch,) without collapsing a single-sample batch.
        labels = labels.squeeze(1).to(device, dtype=torch.long)
        predicted = model(seqs).argmax(dim=1)
        total += labels.size(0)
        # .item() keeps the running count a Python int, so the accuracy prints as a number
        # rather than as "tensor(...)".
        correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the {} test sequences of length {}: {:.2f} %'.format(
    total, test_seq_length, 100 * correct / total))


Test Accuracy of the model on the 10000 test images: 99 %
