In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchsummary import summary
import wandb

from urnn import URNN

In [None]:
# Single source of truth for training hyperparameters; later cells should read from here
# rather than hardcoding their own copies.
hyp_params = {
    'epochs': 20,
    'batch_size': 1,
    'learning_rate': 1e-3,
}
batch_size = hyp_params['batch_size']

# Seed NumPy's global RNG (noise added to the synthetic data) and torch's default RNG
# (DataLoader shuffling and, presumably, URNN weight initialisation) so that
# Restart & Run All reproduces the same outputs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Experiment tracking is off by default; set to True to log this run to Weights & Biases.
USE_WANDB = False
if USE_WANDB:
    wandb.init(
        project='regression-neural-network',
        name='test run 3',
        entity='konradszafer',
        config=hyp_params,
    )

In [None]:
# Synthetic series: a straight ramp from 0 to 0.95 with small additive Gaussian noise.
n_points, noise_std = 30, 0.02
data = np.linspace(0, 0.95, n_points)
data += np.random.normal(0, noise_std, data.shape)

fig, ax = plt.subplots()
ax.plot(data)
ax.set(title='Synthetic noisy ramp', xlabel='time step', ylabel='value')
plt.show()

In [None]:
window = 3
# Sliding windows over the series: each input is `window` consecutive points and its
# target is the point immediately after that window.
starts = range(len(data) - window)
x = Tensor(np.array([data[s:s + window] for s in starts]))
y = Tensor(np.array([data[s + window] for s in starts]))

In [None]:
# Model over `window`-length inputs whose targets are binned within [min_value, max_value];
# latent_resolution presumably sets the number of bins — TODO confirm via print_bins().
model = URNN(
    input_size=window,
    latent_resolution=50,
    min_value=0.0,
    # Small headroom above 1.0, presumably so noisy values near the top stay in range.
    max_value=1 + 1e-2,
)
model.print_bins()

In [None]:
# Replace the continuous targets with bin indices from the model.
# F.cross_entropy in the training cell needs integer class indices, so digitize presumably
# returns an integer tensor. The guard keeps this cell idempotent: re-running it alone
# must not digitize the already-digitized indices a second time.
if torch.is_floating_point(y):
    y = model.digitize(y)
y

In [None]:
# Pair each input window with its target bin; reshuffle the order every epoch.
dataset = TensorDataset(x, y)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)
dataset[0]  # peek at the first (window, target) pair

In [None]:
# Train with Adam on cross-entropy over the target bins, reporting epoch-level mean loss
# and accuracy. Hyperparameters come from `hyp_params` in the config cell.
epochs = hyp_params['epochs']
optimizer = torch.optim.Adam(model.parameters(), lr=hyp_params['learning_rate'])

n_train = len(dataloader.dataset)
assert n_train > 0, 'training dataset is empty'

for epoch in range(1, epochs + 1):
    running_loss = 0.0
    total_true = 0
    for seq, target in dataloader:
        output = model(seq)

        # Mean over the *actual* batch: dividing a summed loss by the configured
        # batch_size would under-weight a smaller final batch.
        loss = F.cross_entropy(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate per-sample totals so the epoch metrics cover every sample,
        # not just the last batch.
        running_loss += loss.item() * len(target)
        predicted = output.argmax(dim=1)
        total_true += predicted.eq(target).sum().item()

    # Divide by the true sample count (robust to a partial last batch).
    epoch_loss = running_loss / n_train
    accuracy = total_true / n_train

    print(f'Epoch {epoch}/{epochs} '
          f'Loss: {epoch_loss:.3f} '
          f'Acc: {accuracy:.3f} ')
    # Log only when a W&B run has been initialised.
    if wandb.run is not None:
        wandb.log({'loss': epoch_loss, 'accuracy': accuracy})


In [None]:
# Predict the bin for one hand-made input window. The large figure is created first,
# presumably so predict_sample draws into it — TODO confirm.
# A dedicated name (`sample`) avoids clobbering the training inputs `x`, and `pred_bin`
# avoids shadowing the builtin `bin`.
fig = plt.figure(figsize=(20, 10))
sample = Tensor([0, 0.04, 0.07])
assert sample.numel() == window, 'sample length must match the model input_size (window)'
output, label, pred_bin = model.predict_sample(sample)
print(pred_bin)