# Custom LSTM Experiments Demo

This notebook demonstrates two projects built with a custom LSTM cell implementation:
1. **Binary Counting**: learning to count the ones in binary sequences.
2. **Sine Wave Prediction**: learning to predict the next values in a sine wave sequence.

We'll explore training, evaluation, and visualization for both projects.

---

## Setup and Imports


In [None]:
import torch
import matplotlib.pyplot as plt
import time
from torch.utils.data import DataLoader, TensorDataset
import copy

from lstm_cell import LSTMCell
from binary_counting import LSTMModel, train_model
from sine_wave_prediction import LSTMModelTrig, input_data


---

## 1. Binary Counting Task

In this project, our custom LSTM model learns to count the number of ones in a binary sequence.

### Dataset
Random binary sequences of length 20 (2,000 for training, 500 for testing); each target is the number of ones in its sequence.

### Model
Custom LSTM cell + fully connected output layer.
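
The cell and model are imported from the `lstm_cell` and `binary_counting` modules, so their code is not shown in this notebook. As a rough picture of what "custom LSTM cell + fully connected output layer" can look like, here is a minimal sketch. The class names `SketchLSTMCell` and `SketchSequenceModel`, and their interfaces, are assumptions made for illustration; the actual `LSTMCell` and `LSTMModel` may be organized differently.

```python
import torch
from torch import nn


class SketchLSTMCell(nn.Module):
    """Hypothetical stand-in for the imported LSTMCell (interface assumed)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # One linear map produces all four gate pre-activations at once.
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x_t, state):
        h, c = state
        z = self.gates(torch.cat([x_t, h], dim=1))
        i, f, g, o = z.chunk(4, dim=1)
        # Standard LSTM update: forget old memory, add gated candidate.
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c


class SketchSequenceModel(nn.Module):
    """Unrolls the cell over time and maps the final hidden state to an output."""

    def __init__(self, input_size, hidden_size, output_dim):
        super().__init__()
        self.cell = SketchLSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_dim)

    def forward(self, x):  # x: (batch, seq_len, input_size)
        batch = x.size(0)
        h = x.new_zeros(batch, self.cell.hidden_size)
        c = x.new_zeros(batch, self.cell.hidden_size)
        for t in range(x.size(1)):
            h, c = self.cell(x[:, t, :], (h, c))
        return self.fc(h)


sketch = SketchSequenceModel(input_size=1, hidden_size=20, output_dim=1)
out = sketch(torch.randint(0, 2, (4, 20, 1)).float())
print(out.shape)  # one scalar prediction per sequence
```

Only the last hidden state is passed to the linear layer, which suits a task like counting where a single number summarizes the whole sequence.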


In [None]:
# Parameters
EPOCHS = 10  # defined here but not passed to train_model below
TRAINING_SAMPLES = 2000
BATCH_SIZE = 32
TEST_SAMPLES = 500
SEQUENCE_LENGTH = 20
HIDDEN_UNITS = 20

# Data
X_train = torch.randint(0, 2, (TRAINING_SAMPLES, SEQUENCE_LENGTH, 1)).float()
X_test = torch.randint(0, 2, (TEST_SAMPLES, SEQUENCE_LENGTH, 1)).float()
y_train = torch.sum(X_train, dim=1)
y_test = torch.sum(X_test, dim=1)

train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE)

# Model and training setup
model = LSTMModel(input_size=1, hidden_size=HIDDEN_UNITS, output_dim=1)
model_untrained = copy.deepcopy(model)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Train
train_loss, train_acc, valid_loss, valid_acc = train_model(
    model, loss_fn, optimizer, train_dl, test_dl, tolerance=1
)

# Visualization: Before and After Training
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
y_pred_before = model_untrained(X_test).detach().numpy().round()
plt.scatter(y_test.numpy(), y_pred_before, alpha=0.3)
plt.title("Before Training")
plt.xlabel("True Sum")
plt.ylabel("Predicted Sum")

plt.subplot(1, 2, 2)
y_pred_after = model(X_test).detach().numpy().round()
plt.scatter(y_test.numpy(), y_pred_after, alpha=0.3)
plt.title("After Training")
plt.xlabel("True Sum")
plt.ylabel("Predicted Sum")

plt.suptitle("Binary Counting Task")
plt.show()
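
The call to `train_model` above passes `tolerance=1`, but the function body lives in `binary_counting` and is not shown. One plausible reading, and it is only an assumption, is that a prediction counts as correct when it lies within `tolerance` of the true sum. A sketch of that metric, using the hypothetical helper name `accuracy_within`:

```python
import torch


def accuracy_within(pred, target, tolerance):
    """Fraction of predictions whose absolute error is at most `tolerance`.

    Assumed interpretation; train_model may compute accuracy differently.
    """
    return (pred - target).abs().le(tolerance).float().mean().item()


target = torch.tensor([[3.0], [10.0], [12.0], [7.0]])
pred = torch.tensor([[3.4], [11.5], [12.0], [6.0]])
acc = accuracy_within(pred, target, tolerance=1)
print(acc)  # 3 of the 4 predictions are within 1 of the target
```

A tolerance-based accuracy is a natural companion to MSE here, because the model outputs real numbers while the targets are integers.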


---

## 2. Sine Wave Prediction Task

Here, we train a custom LSTM model to predict the next value of a sine wave from a sliding window of the previous 40 values. Longer forecasts are produced by feeding each prediction back in as input.

### Dataset
A sine wave with a period of 40 samples, evaluated at 800 integer time steps (20 full periods); the last 40 points, exactly one period, are held out for testing.
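
The windowing helper `input_data` is imported from `sine_wave_prediction`, so its body is not shown in this notebook. Judging only by how the training loop consumes its result (iterating over `(x_batch, y_batch)` pairs and calling `len` on it), a sliding-window helper might look like the sketch below. The name `make_windows` and the exact shapes are assumptions, not the notebook's implementation.

```python
import torch


def make_windows(series, window_size):
    """Return a list of (window, next_value) pairs from a 1-D tensor.

    Hypothetical sketch: each window holds `window_size` consecutive values
    and the label is the single value that follows it.
    """
    pairs = []
    for start in range(len(series) - window_size):
        window = series[start:start + window_size]
        label = series[start + window_size:start + window_size + 1]
        pairs.append((window, label))
    return pairs


t = torch.linspace(0, 799, 800)
series = torch.sin(t * torch.pi * 2 / 40)[:-40]  # 760 training points
pairs = make_windows(series, window_size=40)
print(len(pairs), pairs[0][0].shape, pairs[0][1].shape)
```

With 760 training points and a window of 40, this yields 720 pairs, each consisting of a length-40 input and a length-1 target.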

### Model
Custom LSTM cell + fully connected output layer.


In [None]:
torch.manual_seed(20)
window_size = 40
num_epochs = 5
learning_rate = 0.01

# Prepare data
x = torch.linspace(0, 799, 800)
y = torch.sin(x * torch.pi * 2 / 40)
y_train, y_test = y[:-40], y[-40:]
train_data = input_data(y_train, window_size)

# Model setup
model = LSTMModelTrig(input_size=1, hidden_size=50)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    train_loss = 0.0
    start_time = time.time()
    model.train()
    for x_batch, y_batch in train_data:
        x_batch = x_batch.view(1, window_size, 1)
        y_pred = model(x_batch)[0]
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    train_loss /= len(train_data)

    # Forecast the next window_size points, seeded with the last training window
    sample_input = y_train[-window_size:].tolist()
    model.eval()
    with torch.no_grad():
        for _ in range(window_size):
            input_tensor = torch.tensor(sample_input[-window_size:]).view(1, window_size, 1)
            pred = model(input_tensor).item()
            sample_input.append(pred)

    test_loss = loss_fn(torch.tensor(sample_input[-window_size:]), y_test)
    print(f"Epoch {epoch+1}/{num_epochs} | {time.time() - start_time:.2f} sec | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")

# Visualization of predictions
plt.figure(figsize=(12, 4))
plt.plot(y.numpy(), color="#8000ff", label="True Signal")
plt.plot(range(760, 800), sample_input[-window_size:], color="#ff8000", label="Predicted")
plt.legend(loc="upper left")
plt.title("Sine Wave Prediction Task")
plt.xlabel("Time Step")
plt.ylabel("sin(x)")
plt.show()
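
The evaluation step in the loop above is a closed-loop (autoregressive) forecast: each prediction is appended to the history and becomes part of the next input window. The sketch below isolates that pattern in a hypothetical helper, `rollout`. To keep it runnable without the trained LSTM, it uses a stand-in predictor based on the exact sine recurrence y[t+1] = 2·cos(w)·y[t] − y[t−1]; this stand-in is an illustration only and is not the notebook's model.

```python
import math

import torch


def rollout(predict_next, history, steps, window_size):
    """Closed-loop forecast: feed each prediction back in as input."""
    values = list(history)
    for _ in range(steps):
        window = torch.tensor(values[-window_size:], dtype=torch.float64)
        values.append(float(predict_next(window)))
    return values[-steps:]


w = 2 * math.pi / 40  # angular step for a period of 40 samples


def sine_recurrence(window):
    """Stand-in predictor (not the LSTM): exact two-term sine recurrence."""
    return 2 * math.cos(w) * window[-1] - window[-2]


t = torch.arange(800, dtype=torch.float64)
y_full = torch.sin(t * w)
seed = y_full[720:760].tolist()  # last training window
forecast = rollout(sine_recurrence, seed, steps=40, window_size=40)
max_err = max(abs(f - v) for f, v in zip(forecast, y_full[760:].tolist()))
print(f"max abs error over 40 steps: {max_err:.2e}")
```

Because errors compound when predictions are fed back in, a closed-loop forecast is a stricter test than one-step-ahead prediction on true inputs.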


---

## Conclusion

- The **Binary Counting model** learns to count the ones in binary sequences: after training, the predicted sums in the scatter plot align much more closely with the true sums than before.
- The **Sine Wave Prediction model** forecasts one full period (40 steps) of the sine wave from the preceding window, feeding each prediction back in as input, and captures the periodicity.
- Both models use the same custom LSTM cell implementation, which shows that a single hand-written recurrent unit in PyTorch can serve both a counting task and a forecasting task.
- This notebook can serve as a foundation for experimenting with custom RNN/LSTM cells on various sequence tasks.

Feel free to explore and modify the models and training parameters to deepen your understanding!

---
