<a href="https://colab.research.google.com/github/Rohan-Narayan/sru-counting/blob/main/SRU_Counting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [82]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class SRUCell(nn.Module):
    def __init__(self, input_size, hidden_size, activation=nn.Tanh()):
        super(SRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation

        self.reset_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.update_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.candidate = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, inputs, state):
        combined = torch.cat((inputs, state), dim=1)
        reset = torch.sigmoid(self.reset_gate(combined))
        update = torch.sigmoid(self.update_gate(combined))
        reset_state = reset * state
        candidate_state = self.candidate(combined)
        new_state = update * state + (1 - update) * candidate_state
        output = reset_state + (1 - reset) * self.activation(new_state)
        return output, new_state

class SRUModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(SRUModel, self).__init__()
        self.sru_cell1 = SRUCell(input_size, hidden_size)
        self.sru_cell2 = SRUCell(hidden_size, hidden_size)
        self.dense = nn.Linear(hidden_size, 1)

    def forward(self, inputs):
        seq_len = inputs.size(1)
        state = torch.zeros((inputs.size(0), self.sru_cell1.hidden_size), device=inputs.device)
        for i in range(seq_len):
            x = inputs[:, i].view(-1, 1)
            x, state = self.sru_cell1(x, state)
            output, state = self.sru_cell2(x, state)

        output = self.dense(output)
        return output


In [87]:
def generate_training_data(num_sequences, sequence_length):
    X_train = []
    y_train = []

    for _ in range(num_sequences):
        start_value = np.random.randint(1, 50)
        new_sequence = [start_value + i for i in range(sequence_length)]
        X_train.append(new_sequence[:-1])
        y_train.append(new_sequence[-1])

    X_train = np.array(X_train)
    y_train = np.array(y_train)

    return X_train, y_train

In [88]:
X_train, y_train = generate_training_data(1000, 4)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)

input_size = 1
hidden_size = 32
model = SRUModel(input_size, hidden_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 10001
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs.squeeze(), y_train)
    loss.backward()
    optimizer.step()

    if epoch % 1000 == 0:
        print(f'Epoch {epoch}/{epochs}, Loss: {loss.item()}')

# Test the model
X_test = torch.tensor([[2, 3, 4]], dtype=torch.float32)
prediction = model(X_test)
print("Predicted next number:", prediction.item())

Epoch 0/10001, Loss: 1043.95166015625
Epoch 1000/10001, Loss: 0.0015386036830022931
Epoch 2000/10001, Loss: 5.955374945187941e-05
Epoch 3000/10001, Loss: 1.4294002539827488e-05
Epoch 4000/10001, Loss: 4.229873411532026e-06
Epoch 5000/10001, Loss: 1.3856449641025392e-06
Epoch 6000/10001, Loss: 7.454631827386038e-07
Epoch 7000/10001, Loss: 6.734469479852123e-07
Epoch 8000/10001, Loss: 1.6297078673233045e-06
Epoch 9000/10001, Loss: 4.649822926694469e-07
Epoch 10000/10001, Loss: 2.5979539941545227e-07
Predicted next number: 5.001199245452881


In [92]:
X_test = torch.tensor([[1,2,3]], dtype=torch.float32)
prediction = model(X_test)
print("Predicted next number:", round(prediction.item()))

Predicted next number: 4


In [93]:
X_test = torch.tensor([[7,8,9]], dtype=torch.float32)
prediction = model(X_test)
print("Predicted next number:", round(prediction.item()))

Predicted next number: 10


In [94]:
X_test = torch.tensor([[20,21,22]], dtype=torch.float32)
prediction = model(X_test)
print("Predicted next number:", round(prediction.item()))

Predicted next number: 23


In [95]:
X_test = torch.tensor([[83,84,85]], dtype=torch.float32)
prediction = model(X_test)
print("Predicted next number:", round(prediction.item()))

Predicted next number: 86


In [97]:
X_test = torch.tensor([[-3,-2,-1]], dtype=torch.float32)
prediction = model(X_test)
print("Predicted next number:", round(prediction.item()))

Predicted next number: 0
