In [1]:
import torch
import torch.nn as nn

print(torch.__version__)

2.2.2+cpu


In [2]:
def get_batch(batch_size: int, seq_len: int=10) -> torch.Tensor:
    x = torch.randn((batch_size, seq_len, 1))
    y = torch.cumsum(x, dim=1)
    
    return x, y

In [3]:
class Model(nn.Module):
    def __init__(self, hidden_dim: int):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(1, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
    
    def forward(self, x, target=None):
        h, _ = self.lstm(x)
        y = self.fc(h)

        if target is not None:
            loss = nn.MSELoss()(y, target)
            return y, loss
        else: 
            return y, None
        
model = Model(32)
x, y = get_batch(32, 10)
y_hat, loss = model(x, y)

In [4]:
model = Model(50)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

Number of parameters: 10651


In [5]:
# Training loop
total_loss = 0

for i in range(5000):
    x, y = get_batch(512)
    optimizer.zero_grad(set_to_none=True)
    y_pred, loss = model(x, y)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

    if i % 500 == 499:
        print(f"Epoch {i} - Loss: {total_loss / 100}")
        total_loss = 0

Epoch 499 - Loss: 0.6777813371364028
Epoch 999 - Loss: 0.0074887769510678485
Epoch 1499 - Loss: 0.006045350203276029
Epoch 1999 - Loss: 0.01876315464498475
Epoch 2499 - Loss: 0.0009290905124362325
Epoch 2999 - Loss: 0.0034977569584953017
Epoch 3499 - Loss: 0.0036023992789705516
Epoch 3999 - Loss: 0.006128725718372152
Epoch 4499 - Loss: 0.0031855731414907494
Epoch 4999 - Loss: 0.003935306051789667


In [6]:
with torch.no_grad():
    x, y = get_batch(1, 8)
    y_pred, _ = model(x)

    print(x.view(-1))
    print("GT:", y.view(-1))
    print("Pred:", y_pred.view(-1))

    print("Loss:", nn.MSELoss()(y_pred, y).item())
    print("Max difference:", (y_pred - y).abs().max().item())

tensor([-1.5836, -0.2945,  1.4568,  0.1181, -0.8401, -1.7151, -1.8321, -0.5571])
GT: tensor([-1.5836, -1.8781, -0.4213, -0.3032, -1.1433, -2.8584, -4.6906, -5.2477])
Pred: tensor([-1.5836, -1.8733, -0.4260, -0.3078, -1.1487, -2.8627, -4.7022, -5.2570])
Loss: 4.20053765992634e-05
Max difference: 0.011635780334472656
