<a href="https://colab.research.google.com/github/YHL04/tryingoutideas/blob/main/quantifyinguncertainty.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        hidden = self.fc1(x)
        x = F.relu(hidden)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output, hidden


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output, _ = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, _ = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


# Training settings

batch_size = 64
test_batch_size = 1000
lr = 1.0
epochs = 1
gamma = 0.7

torch.manual_seed(0)
device = torch.device("cuda")

train_kwargs = {'batch_size': batch_size}
test_kwargs = {'batch_size': test_batch_size}

cuda_kwargs = {'num_workers': 1,
                'pin_memory': True,
                'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=lr)

scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 337862835.68it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 61740924.48it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 140179410.50it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3968033.49it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw







Test set: Average loss: 0.0432, Accuracy: 9859/10000 (99%)



In [None]:
hidden_features = [torch.zeros(128).cuda() for _ in range(10)]
count = [0 for _ in range(10)]

for batch_idx, (data, target) in enumerate(train_loader):
    data = data.cuda()
    target = target.cuda()

    with torch.no_grad():
        pred, hidden = model(data)

        for h, t in zip(hidden, target):
            hidden_features[t] = hidden_features[t] + h
            count[t] += 1


for i in range(len(hidden_features)):
    hidden_features[i] = hidden_features[i] / count[i]


for h in hidden_features:
    print(h.sum())



tensor(-51.0206, device='cuda:0')
tensor(14.8273, device='cuda:0')
tensor(-52.2406, device='cuda:0')
tensor(-30.6050, device='cuda:0')
tensor(-15.3873, device='cuda:0')
tensor(-31.0936, device='cuda:0')
tensor(-21.9191, device='cuda:0')
tensor(-33.0974, device='cuda:0')
tensor(-32.5966, device='cuda:0')
tensor(-20.4490, device='cuda:0')


In [45]:

import numpy as np


class ConfModel(nn.Module):
    def __init__(self):
        super(ConfModel, self).__init__()
        self.d_model = 128
        self.n_cos = 64

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, self.d_model)

        self.cos_embedding = nn.Linear(self.n_cos, self.d_model)
        self.linear = nn.Linear(self.d_model, self.d_model)
        self.out = nn.Linear(self.d_model, 1)

        self.pis = torch.FloatTensor([np.pi*i for i in range(1, self.n_cos+1)]).view(1, 1, self.n_cos).cuda()

        self.gelu = nn.GELU()

    def calc_cos(self, batch_size, n_tau=8):
        """
        Calculating the co-sin values depending on the number of tau samples
        """
        assert torch.equal(self.pis,
                           torch.FloatTensor([np.pi*i for i in range(1, self.n_cos+1)]).view(1, 1, self.n_cos).cuda())

        # (batch_size, n_tau, 1)
        taus = torch.rand(batch_size, n_tau).unsqueeze(-1).cuda()
        cos = torch.cos(taus * self.pis)

        assert cos.shape == (batch_size, n_tau, self.n_cos)

        cos = cos.view(batch_size * n_tau, self.n_cos)
        cos = self.gelu(self.cos_embedding(cos))
        cos = cos.view(batch_size, n_tau, self.d_model)

        return cos, taus

    def forward(self, x, n_tau):
        """
        :param x:     Tensor[batch_size, 1, d_model]
        :param n_tau: int
        :return:      Tensor[batch_size, n_tau]
                      Tensor[batch_size, n_tau]
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)

        # IQN begins here
        batch_size = x.size(0)
        x = x.unsqueeze(1)
        assert x.shape == (batch_size, 1, self.d_model)

        cos, taus = self.calc_cos(batch_size, n_tau)

        cos = cos.view(batch_size, n_tau, self.d_model)
        taus = taus.view(batch_size, n_tau)

        x = (x * cos).view(batch_size * n_tau, self.d_model)
        x = self.gelu(self.linear(x))
        x = self.out(x)
        x = x.view(batch_size, n_tau)

        return x, taus


def quantile_loss(expected, target, taus):
    batch_size = expected.size(0)
    n_tau = expected.size(1)

    assert expected.shape == (batch_size, n_tau)
    assert target.shape == (batch_size,)
    assert taus.shape == (batch_size, n_tau)
    assert not taus.requires_grad

    batch_size = expected.size(0)

    expected = expected.view(batch_size, n_tau, 1)
    target = target.unsqueeze(1).repeat(1, n_tau).unsqueeze(1)
    taus = taus.view(batch_size, n_tau, 1)

    assert target.shape == (batch_size, 1, n_tau)

    td_error = target - expected
    huber_loss = torch.where(td_error.abs() <= 1, 0.5 * td_error.pow(2), td_error.abs() - 0.5)
    quantile_loss = abs(taus - (td_error.detach() < 0).float()) * huber_loss

    loss = quantile_loss.sum(dim=1).mean(dim=1)
    loss = loss.mean()

    return loss


confidence_model = ConfModel().to(device)
optimizer = optim.Adam(confidence_model.parameters(), lr=lr)


for batch_idx, (data, target) in enumerate(train_loader):
    data = data.cuda()
    target = target.cuda()
    optimizer.zero_grad()

    pred, _ = model(data)

    error = F.nll_loss(pred, target, reduction="none")
    pred_error, taus = confidence_model(data, n_tau=16)

    loss = quantile_loss(pred_error, error, taus)
    loss.backward()
    optimizer.step()

    if batch_idx % 10 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))





In [50]:


for batch_idx, (data, target) in enumerate(test_loader):
    data = data.cuda()
    target = target.cuda()

    pred, _ = model(data)
    pred = torch.argmax(pred, dim=1)

    pred_error, taus = confidence_model(data, n_tau=16)
    pred_error = pred_error.mean(dim=1)

    print("pred {} real {} estimated error {}"
          .format(pred[0], target[0], pred_error[0]))



pred 4 real 4 estimated error 0.05287313461303711
pred 0 real 0 estimated error 0.05287313461303711
pred 8 real 8 estimated error 0.05287313461303711
pred 0 real 0 estimated error 0.05287313461303711
pred 6 real 6 estimated error 0.05287313461303711
pred 2 real 2 estimated error 0.05287313461303711
pred 7 real 2 estimated error 0.05287313461303711
pred 7 real 7 estimated error 0.05287313461303711
pred 9 real 9 estimated error 0.05287313461303711
pred 7 real 7 estimated error 0.05287313461303711
