# Tutorial 1.b: Learning Uncertainty Representations from Data with Gradient Descent

In the previous tutorial we demonstrated several types of predictions and metrics to measure their quality. Naturally, given a quality measurement we can also optimize the predictions to maximize the quality. This is the goal of this tutorial. 


We will use a standard deep learning pipeline. We define a neural network that takes the features as input and outputs a tensor with the correct shape (for example, if we want a quantile prediction, the network should output a tensor of shape ``[batch_size, n_quantiles]``). We then use torchuq metrics as the learning objective. A key property of torchuq is that most metrics are differentiable, so they can be used as objective functions and optimized with gradient descent. To see which metrics are differentiable, see the reference list in [TBD]. 
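
As a minimal sketch of this idea, the snippet below treats a hand-written mean squared error as a stand-in for a differentiable metric and minimizes it with gradient descent. The function `toy_metric`, the synthetic data, and the hyperparameters are illustrative assumptions and are not part of torchuq.

```python
import torch

torch.manual_seed(0)

def toy_metric(pred, target):
    # Hypothetical stand-in for a differentiable metric (mean squared error)
    return ((pred - target) ** 2).mean()

# Synthetic 1-D regression data: y = 3x + noise
x = torch.randn(256, 1)
y = 3.0 * x[:, 0] + 0.1 * torch.randn(256)

w = torch.zeros(1, requires_grad=True)   # single learnable parameter
optimizer = torch.optim.SGD([w], lr=0.1)

initial_loss = toy_metric(x[:, 0] * w, y).item()
for step in range(100):
    optimizer.zero_grad()
    loss = toy_metric(x[:, 0] * w, y)    # the metric itself is the objective
    loss.backward()                      # gradients flow through the metric
    optimizer.step()
final_loss = toy_metric(x[:, 0] * w, y).item()
```

Any metric that is built from differentiable tensor operations can be dropped into the same loop in place of `toy_metric`.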

### Setting up the Environment 

We first set up the environment for the tutorial by loading the necessary dependencies.

In [2]:
from matplotlib import pyplot as plt
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append('../..')
import torchuq 

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

**Dataset**: We will use the UCI Boston dataset. For your convenience, torchuq includes a large collection of fairly small datasets behind a single interface: ``torchuq.dataset.regression.get_regression_datasets`` or ``torchuq.dataset.classification.get_classification_datasets``. All the data files are included with the repo, so you should be able to use these datasets out of the box. For a list of available datasets, see [link](https://github.com/ShengjiaZhao/torchuq/tree/main/torchuq/dataset). To access a dataset, call 

`` train_dataset, val_dataset, test_dataset = torchuq.dataset.regression.get_regression_datasets(dataset_name, val_fraction=0.2, test_fraction=0.2) ``

You can split the data into train/val/test by passing non-zero values for the arguments ``val_fraction`` and ``test_fraction``. The return values are PyTorch Dataset instances, which can be conveniently used with PyTorch dataloaders. 
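
To illustrate how a Dataset that returns `(features, label)` pairs interacts with a dataloader, here is a small sketch that uses a synthetic `TensorDataset` in place of a torchuq dataset. The sizes (405 samples, 13 features) and the names `toy_dataset` and `toy_loader` are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for a regression dataset: 405 samples, 13 features
features = torch.randn(405, 13)
labels = torch.randn(405)
toy_dataset = TensorDataset(features, labels)

x_dim = len(toy_dataset[0][0])   # feature dimension, read off the first sample
toy_loader = DataLoader(toy_dataset, batch_size=32, shuffle=True)

# Every batch holds 32 samples except the last, which holds the remainder
batch_shapes = [(bx.shape, by.shape) for bx, by in toy_loader]
```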

In [6]:
train_dataset, val_dataset, _ = torchuq.dataset.regression.get_regression_datasets('boston', val_fraction=0.2, test_fraction=0.0, verbose=True)
x_dim = len(train_dataset[0][0])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=1)

Loading dataset boston....
Splitting into train/val/test with 405/101/0 samples
Done loading dataset boston


**Prediction Model**: For simplicity, we use a three-layer fully connected neural network as the prediction function. 

In [7]:
class NetworkFC(nn.Module):
    def __init__(self, x_dim, out_dim=1, num_feat=30):
        super(NetworkFC, self).__init__()
        self.fc1 = nn.Linear(x_dim, num_feat)
        self.fc2 = nn.Linear(num_feat, num_feat)
        self.fc3 = nn.Linear(num_feat, out_dim)
        
    def forward(self, x):
        x = F.leaky_relu(self.fc2(F.leaky_relu(self.fc1(x))))
        out = self.fc3(x)
        return out
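
A quick shape check can confirm that the network produces the tensor shape each prediction type needs. The class is repeated here so that the snippet is self-contained; the input dimension of 13 and the batch size of 8 are arbitrary choices for illustration.

```python
import torch
from torch import nn
from torch.nn import functional as F

class NetworkFC(nn.Module):
    def __init__(self, x_dim, out_dim=1, num_feat=30):
        super().__init__()
        self.fc1 = nn.Linear(x_dim, num_feat)
        self.fc2 = nn.Linear(num_feat, num_feat)
        self.fc3 = nn.Linear(num_feat, out_dim)

    def forward(self, x):
        x = F.leaky_relu(self.fc2(F.leaky_relu(self.fc1(x))))
        return self.fc3(x)

batch = torch.randn(8, 13)

# Point prediction: one output per sample, flattened to shape [batch_size]
point_out = NetworkFC(13, out_dim=1)(batch).flatten()

# Quantile prediction: shape [batch_size, n_quantiles]
quantile_out = NetworkFC(13, out_dim=10)(batch)
```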

### Learning a Point Prediction

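Learning a point prediction follows the standard supervised regression recipe: the network outputs a single number per input, and we minimize the L2 loss between the prediction and the label. 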

In [None]:
from torchuq.metric.point import compute_l2_loss

net = NetworkFC(x_dim, out_dim=1).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)

for epoch in range(100):
    if epoch % 10 == 0:    # Evaluate the validation performance
        with torch.no_grad():  
            val_x, val_y = val_dataset[:]
            pred_val = net(val_x.to(device)).flatten()   # The network outputs an array of shape [batch_size, 1], must flatten to turn into point prediction
            loss = compute_l2_loss(pred_val, val_y.to(device))
            print("Epoch %d, loss=%.4f" % (epoch, loss))
    
    for i, (bx, by) in enumerate(train_loader):  # Standard pytorch training loop
        optimizer.zero_grad()
        pred = net(bx.to(device)).flatten()
        loss = compute_l2_loss(pred, by.to(device))
        loss.backward()
        optimizer.step()
        
# Record the point predictions on the validation set (they are saved at the end of this tutorial)
val_x, val_y = val_dataset[:]
predictions_point = net(val_x.to(device)).flatten().cpu().detach()

### Learning a Probability Prediction

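There is more flexibility with probability predictions, because we must choose both a family of distributions and a scoring rule to optimize. Here the network outputs two numbers per input, which we use as the mean and (after taking the absolute value) the standard deviation of a Gaussian, and we minimize the CRPS. The negative log likelihood (NLL) can be used instead by swapping in ``compute_nll``. 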

In [8]:
from torchuq.metric.distribution import compute_crps, compute_nll
from torch.distributions.normal import Normal

net = NetworkFC(x_dim, out_dim=2).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-4)

for epoch in range(100):
    if epoch % 10 == 0:    # Evaluate the validation performance
        with torch.no_grad():  
            val_x, val_y = val_dataset[:]
            pred_raw = net(val_x.to(device)) 
            pred_val = Normal(loc=pred_raw[:, 0], scale=pred_raw[:, 1].abs())
            loss = compute_crps(pred_val, val_y.to(device)).mean()
            # loss = compute_nll(pred_val, val_y.to(device)).mean()   
            print("Epoch %d, loss=%.4f" % (epoch, loss))
    
    for i, (bx, by) in enumerate(train_loader):  # Standard pytorch training loop
        optimizer.zero_grad()
        pred_raw = net(bx.to(device)) 
        pred = Normal(loc=pred_raw[:, 0], scale=pred_raw[:, 1].abs())
        loss = compute_crps(pred, by.to(device)).mean()
        loss.backward()
        optimizer.step()
        
# val_x, val_y = val_dataset[:]
# pred_raw = net(val_x.to(device)).cpu().detach()
# predictions_distribution = Normal(loc=pred_raw[:, 0], scale=pred_raw[:, 1].abs())

Epoch 0, loss=0.6168
Epoch 10, loss=0.5125
Epoch 20, loss=0.4510
Epoch 30, loss=0.4199
Epoch 40, loss=0.4008
Epoch 50, loss=0.3834
Epoch 60, loss=0.3589
Epoch 70, loss=0.3338
Epoch 80, loss=0.3075
Epoch 90, loss=0.2873
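
For reference, the CRPS of a Gaussian prediction has a well-known closed form, CRPS = σ [z (2Φ(z) − 1) + 2φ(z) − 1/√π] with z = (y − μ)/σ, where φ and Φ are the standard normal density and CDF. The sketch below implements this formula directly in PyTorch. `gaussian_crps` is a hypothetical helper written for illustration; it is not claimed to be how `compute_crps` is implemented.

```python
import math
import torch
from torch.distributions.normal import Normal

def gaussian_crps(pred, y):
    # Hypothetical helper: closed-form CRPS of a Normal prediction, one value per sample
    z = (y - pred.loc) / pred.scale
    std_normal = Normal(torch.zeros_like(z), torch.ones_like(z))
    pdf = torch.exp(std_normal.log_prob(z))
    cdf = std_normal.cdf(z)
    return pred.scale * (z * (2 * cdf - 1) + 2 * pdf - 1 / math.sqrt(math.pi))

loc = torch.tensor([0.0, 0.0], requires_grad=True)
raw_scale = torch.tensor([1.0, -2.0], requires_grad=True)
pred = Normal(loc=loc, scale=raw_scale.abs())   # abs() keeps the scale positive
y = torch.tensor([0.0, 1.0])

crps = gaussian_crps(pred, y)
crps.mean().backward()    # differentiable with respect to both network outputs
```

Because the score is differentiable in both the mean and the scale, it can serve directly as a training loss, which is what the training loop above relies on.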


### Learning a Quantile Prediction

In [None]:
from torchuq.metric.quantile import compute_pinball_loss

# Train quantile loss
net = NetworkFC(x_dim, out_dim=10).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)

for epoch in range(100):
    if epoch % 10 == 0:    # Evaluate the validation performance
        with torch.no_grad():  
            val_x, val_y = val_dataset[:]
            pred_val = net(val_x.to(device))
            loss = compute_pinball_loss(pred_val, val_y.to(device))
            print("Epoch %d, loss=%.4f" % (epoch, loss))
    
    for i, (bx, by) in enumerate(train_loader):  # Standard pytorch training loop
        optimizer.zero_grad()
        pred = net(bx.to(device))
        loss = compute_pinball_loss(pred, by.to(device))
        loss.backward()
        optimizer.step()
        
# Record the quantile predictions on the validation set
val_x, val_y = val_dataset[:]
# predictions_quantile = torch.cummax(net(val_x.to(device)), dim=1)[0].cpu().detach()
predictions_quantile = net(val_x.to(device)).cpu().detach()

In [None]:
from torchuq.metric import quantile
quantile.plot_quantiles(torch.sort(predictions_quantile, dim=1)[0], val_y)
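
The pinball loss for a quantile level α penalizes under- and over-prediction asymmetrically: L_α(q, y) = max(α (y − q), (α − 1)(y − q)). The sketch below is a hand-written version for a tensor of shape `[batch_size, n_quantiles]`. It assumes the columns correspond to equally spaced levels (i + 0.5)/n, which is an assumption about the convention that should be checked against the torchuq documentation; `pinball_loss_sketch` is a hypothetical name, not the library function. It also shows how sorting along the quantile dimension removes crossed quantiles, as is done before plotting above.

```python
import torch

def pinball_loss_sketch(pred, y):
    # pred: [batch_size, n_quantiles], y: [batch_size]
    n = pred.shape[1]
    # Assumed quantile levels: (i + 0.5) / n for i = 0, ..., n - 1
    levels = (torch.arange(n, dtype=pred.dtype) + 0.5) / n
    diff = y.unsqueeze(1) - pred                  # positive when the prediction is too low
    loss = torch.maximum(levels * diff, (levels - 1) * diff)
    return loss.mean()

y = torch.tensor([0.0, 1.0, 2.0])
# Crossed quantiles: the lower-level column is larger than the higher-level column
crossed = torch.tensor([[0.5, -0.5], [1.5, 0.5], [2.5, 1.5]])
sorted_pred = torch.sort(crossed, dim=1)[0]      # enforce monotone quantiles

loss_crossed = pinball_loss_sketch(crossed, y)
loss_sorted = pinball_loss_sketch(sorted_pred, y)
```

In this toy example, sorting the crossed quantiles lowers the loss, which is one reason to sort the network outputs before using them.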

In [None]:
from torchuq.metric.distribution import compute_nll
from torch.distributions.normal import Normal
from torchuq.transform.naive import quantile_to_distribution

net = NetworkFC(x_dim, out_dim=20).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-4)

for epoch in range(100):
    if epoch % 10 == 0:    # Evaluate the validation performance
        with torch.no_grad():  
            val_x, val_y = val_dataset[:]
            pred_raw = net(val_x.to(device)) 
            pred_val = quantile_to_distribution(pred_raw)

            loss = compute_nll(pred_val, val_y.to(device)).mean()   
            print("Epoch %d, loss=%.4f" % (epoch, loss))
    
    for i, (bx, by) in enumerate(train_loader):  # Standard pytorch training loop
        optimizer.zero_grad()
        pred_raw = net(bx.to(device)) 
        pred_val = quantile_to_distribution(pred_raw)
        loss = compute_nll(pred_val, by.to(device)).mean() 
        loss.backward()
        optimizer.step()
        
val_x, val_y = val_dataset[:]
pred_raw = net(val_x.to(device)).cpu().detach()
predictions_distribution = quantile_to_distribution(pred_raw)

In [None]:
from torchuq.metric.distribution import plot_density, plot_reliability_diagram
# Visualize the distribution predictions learned above on the validation set
plot_density(predictions_distribution, val_y)
plt.show()
plot_reliability_diagram(predictions_distribution, val_y)
plt.show()
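
One common way to assess calibration of distribution predictions is through the probability integral transform (PIT), i.e. the predicted CDF evaluated at the true label. For a calibrated predictor the PIT values are uniformly distributed, so their sorted values lie close to the diagonal. The sketch below computes this comparison by hand on synthetic Gaussian data; it illustrates the idea under that assumption and does not describe what `plot_reliability_diagram` does internally.

```python
import torch
from torch.distributions.normal import Normal

torch.manual_seed(0)
n = 5000
mu = torch.randn(n)
y = mu + torch.randn(n)                      # labels truly follow N(mu, 1)

calibrated = Normal(loc=mu, scale=torch.ones(n))
overconfident = Normal(loc=mu, scale=0.5 * torch.ones(n))   # predicted scale is too small

def calibration_gap(pred, y):
    # Largest distance between the sorted PIT values and the uniform quantiles
    pit = torch.sort(pred.cdf(y))[0]
    uniform = (torch.arange(len(y), dtype=pit.dtype) + 0.5) / len(y)
    return (pit - uniform).abs().max().item()

gap_calibrated = calibration_gap(calibrated, y)
gap_overconfident = calibration_gap(overconfident, y)
```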


In [None]:
# Save all the predictions we have learned in this tutorial. These are the predictions used in the first tutorial
torch.save({'labels': val_y.flatten().cpu(),
            'predictions_point': predictions_point,
            'predictions_quantile': predictions_quantile,
            'predictions_distribution': predictions_distribution,
            'predictions_interval': torch.cat([predictions_quantile[:, :1], predictions_quantile[:, -1:]], dim=1)
            }, 'pretrained/boston_pretrained.tar')
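
The interval prediction saved above is built from the first and last columns of the quantile predictions. The following sketch reproduces that construction on a small made-up tensor and round-trips it through `torch.save` and `torch.load` in a temporary directory; the tensor values and the file name are illustrative.

```python
import os
import tempfile
import torch

# Made-up quantile predictions of shape [batch_size, n_quantiles]
toy_quantiles = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                              [0.0, 0.5, 1.0, 1.5]])

# Keep the first and last columns: shape [batch_size, 2] holding (lower, upper)
toy_interval = torch.cat([toy_quantiles[:, :1], toy_quantiles[:, -1:]], dim=1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'toy_predictions.tar')
    torch.save({'predictions_interval': toy_interval}, path)
    loaded = torch.load(path)
```

Note that this construction only yields the outermost pair when the columns are ordered; if the predicted quantiles cross, sorting along `dim=1` first is a reasonable safeguard.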