In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchsummary import summary
import wandb

from prnn import PRNN
from prnn.loss_functions import dist_loss, focal_loss

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using {0} device'.format(device))

# Turn on the cuDNN benchmark flag (cuDNN auto-tunes its algorithm choice).
torch.backends.cudnn.benchmark = True

In [None]:
# Hyper-parameters shared by the model construction, the data loader and the
# training loop further down.
HP = dict(
    epochs=50,
    batch_size=1,
    learning_rate=1e-3,
    latent_resolution=15,
)

# Experiment tracking is currently disabled; uncomment to log this run.
# wandb.init(
#     entity='konradszafer',
#     project='probabilistic-regression-neural-network',
#     name='run 0',
#     notes='''
#         test run
#     ''',
#     config=HP
# )

In [None]:
# Toy series: 30 evenly spaced points from 0 to 0.95 (noise is left disabled).
data = np.linspace(0, 0.95, 30)
# data += np.random.normal(0, .01, data.shape)
plt.plot(data)
plt.show()

# Sliding-window samples: every input is `window` consecutive values and its
# target is the value that immediately follows them.
window = 3
n_windows = len(data) - window
x = torch.FloatTensor(
    np.array([data[start:start + window] for start in range(n_windows)])
)
y = torch.FloatTensor(
    np.array([data[start + window] for start in range(n_windows)])
)

In [None]:
# Constructor arguments for the network: one input per window element and a
# value range of [0.0, 1.01] (presumably the range that is split into
# `latent_resolution` intervals — confirm against PRNN).
prnn_kwargs = {
    'input_size': window,
    'min_value': 0.0,
    'max_value': 1 + 1e-2,
    'latent_resolution': HP['latent_resolution'],
}
model = PRNN(**prnn_kwargs)
model.to(device)

# Show the intervals the model was configured with.
model.print_intervals()

In [None]:
# Replace the continuous targets with the output of model.digitize
# (presumably the index of the interval each value falls into — confirm
# against PRNN.digitize).
# NOTE(review): `y` is overwritten in place, so re-running this cell without
# first re-running the data-preparation cell feeds already-digitized values
# back into digitize.
y = model.digitize(y)
y  # last expression: display the digitized targets as the cell output

In [None]:
# Pair every input window with its digitized target and serve the pairs in
# shuffled mini-batches of HP['batch_size'].
dataset = TensorDataset(x, y)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=HP['batch_size'],
    shuffle=True,
)
dataset[0]  # peek at the first (input, target) pair

In [None]:
# Train the model with Adam and the focal loss, reporting the mean batch loss
# and the classification accuracy once per epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=HP['learning_rate'])

# Real number of training samples. HP['batch_size'] * len(dataloader)
# over-counts whenever the dataset size is not a multiple of the batch size,
# because the final batch is then shorter.
num_train_samples = len(dataloader.dataset)

for epoch in range(1, HP['epochs'] + 1):

    total_true = 0
    running_loss = 0.0
    for seq, target in dataloader:
        seq, target = seq.to(device), target.to(device)
        output = model(seq)

        # Alternative losses that were tried:
        # loss = F.cross_entropy(output, target, reduction='sum')
        # require output normalization
        # loss = dist_loss(output, target)
        # print(output, target.unsqueeze(0))
        loss = focal_loss(output, target)
        # Normalise by the samples actually present in this batch rather than
        # the configured batch size (they differ for a short final batch).
        loss = loss.sum() / target.size(0)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()

        # Predicted class = index of the largest output along dim 1. Kept in
        # its own name so `output` still holds the raw model output.
        predicted = output.argmax(dim=1)
        total_true += predicted.eq(target).sum().item()

    # printing and logging
    accuracy = total_true / num_train_samples
    # Mean of the per-batch losses; the last batch alone is a noisy figure
    # because the loader shuffles every epoch.
    mean_loss = running_loss / len(dataloader)
    print( (f'Epoch {epoch}/{HP["epochs"]} '
            f'Loss: {mean_loss:.3f} '
            f'Acc: {accuracy:.3f} ')
    )
    # wandb.log({'train loss': mean_loss})
    # wandb.log({'train acc': accuracy})

In [None]:
# Run the trained model on one hand-written window of three values.
# The sample gets its own name: rebinding `x` here would overwrite the
# training inputs, so re-running the dataset cell afterwards would pair this
# single sample with the full target tensor.
sample = torch.tensor([0, 0.04, 0.07], dtype=torch.float32).to(device)

# `output` is reused by the plotting and estimation cells below.
output, label, interval = model.predict_sample(sample)
print(f'Predicted interval: <{interval[0]}, {interval[1]})')

In [None]:
# Names of the output normalizations to visualise, one figure per entry.
normalizations = [
    'softmax', 'logarithmic', 'exp_softmax', 'sigmoid',
    'linear', 'relu', 'leaky_relu',
]

figure_size = (10, 5)
for normalization in normalizations:
    # Open a fresh figure so each normalization is drawn on its own canvas.
    fig = plt.figure(figsize=figure_size)
    plot_title = f'Output normalization: {normalization}'
    model.plot_latent_distribution(
        output,
        normalization,
        plot_title,
        # f'latient_distribution_{normalization}_0.jpg',
    )

In [None]:
# Reduce the model output from the prediction cell to a single estimate and
# the probability that model.estimate returns alongside it.
value, probability = model.estimate(output)

# Both results are one-element tensors; pull out the Python scalars first.
predicted_value = value.item()
predicted_probability = probability.item()
print(f'Predicted value: {predicted_value:.4f}')
print(f'Probability: {predicted_probability:.4f}')