# Self KL-divergence for detecting out of distribution data and unsupervised text classification
> Running two models side by side gives a simple out-of-distribution detector for production models, with unsupervised text classification as a side bonus

- toc: true 
- badges: true
- comments: true
- categories: [ml, nlp, kldivergence]
- image: images/chart-preview.png

> TL;DR. By training two models at the same time (same architecture, same loss, but different initialization),
> I was able to obtain a consistent out-of-distribution detector by measuring the KL divergence between the models' outputs.
> Applied to text, this out-of-distribution measure could lead to unsupervised text classification.


## What's the problem?

ML models usually cannot tell how close the data you feed them is to what
was in their training set. This really matters for production models, which
can make really stupid mistakes simply because an input is far from the
training set.


                     

Let's train a simple MNIST model.

In [None]:
#hide
!pip install torch torchvision


In [None]:
#collapse
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import os


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def mnist():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        # Use the same data directory as the train set, and download it if missing
        datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
        
# Notebook specific hack
import sys; sys.argv=['']; del sys
mnist()



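Next, generate a random image on which the model is highly confident even though the image is completely absurd. We start from random noise and optimize the image itself to minimize the entropy of the model's prediction. This new image is out of distribution, yet the model does not know it. We want to avoid such mistakes in production.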

In [6]:
#collapse
from torch.distributions import Categorical
from torch.nn.parameter import Parameter
from torchvision import transforms

mnist_model = Net()
mnist_model.load_state_dict(torch.load('mnist_cnn.pt'))
mnist_model.eval()  # disable dropout so the predictions below are deterministic

dummy_input = Parameter(torch.rand(1, 1, 28, 28, requires_grad=True))

lr = 1
optimizer = optim.Adadelta([dummy_input], lr=lr)
for i in range(100):
    output = mnist_model(dummy_input)
    entropy = Categorical(logits = output).entropy()
    print(f'Entropy {entropy.item():.2f}')
    optimizer.zero_grad()
    entropy.backward()
    optimizer.step()

MAX = output[0].exp().max(dim=-1)
print(f"MNIST Model says : This is a {MAX.indices.item()} with probability {MAX.values.item() * 100:.2f}%")
pil_img = transforms.Resize((240, 240))(transforms.ToPILImage()(dummy_input[0]))
display(pil_img)



- [x] Find a good example of failure.
- [x] Find classical solutions to this problem.
- [ ] Explain what makes this measure better than other methods: it is expressed in bits/nats, so its value is directly interpretable.
- [ ] Add a less tl;dr explanation of why this should work (lottery ticket hypothesis, model structure forces some relevant manifold of the output).
                                                                               

## Other approaches


- https://ai.googleblog.com/2019/12/improving-out-of-distribution-detection.html
- https://arxiv.org/pdf/1910.04241.pdf (Manifold approximation via Embedding (VAE or GAN) and out-of-distribution sampling via Manifold perturbation.)
- https://paperswithcode.com/task/out-of-distribution-detection                
- https://openreview.net/pdf?id=Hkxzx0NtDB (Energy based model, hard to train but effective, learn `p(x, y)` at the same time as `p(y|x)`)
- https://arxiv.org/pdf/1802.04865v1.pdf (Adding extra loss making optimization problem joint between confidence and accuracy)

## Our approach

TL;DR: Make two similar models with two different random initializations, then train them at the same time.
Their converged average KL divergence on the train set gives you a baseline for what "similar" means.

Check what kind of values you get on the test/validation set. You should get something similar or higher.

Measuring this self KL divergence on a new sample then gives you a measure of its newness. From there it's a matter of choosing your
own threshold for what is acceptable or not.
                                                                               
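The intuition, still to be detailed, is that the model structure forces the outputs onto some relevant manifold linked to the training data, which leads to good properties in terms of out-of-distribution values.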
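
As a rough illustration of the measure described in this section, here is a minimal, framework-free sketch. It assumes each model is a callable returning log-probabilities for one input (like the `log_softmax` output of `Net.forward`). The function names are hypothetical and are not taken from the experiments in this post, and the direction KL(a || b) is only one possible choice.

```python
import math


def kl_divergence(log_p, log_q):
    """KL(p || q) in nats, for two categorical distributions given as log-probabilities."""
    total = 0.0
    for lp, lq in zip(log_p, log_q):
        if lp == float("-inf"):
            continue  # a class with p = 0 contributes nothing to the sum
        total += math.exp(lp) * (lp - lq)
    return total


def self_kl(model_a, model_b, x):
    """Disagreement of the two models on a single input x."""
    return kl_divergence(model_a(x), model_b(x))


def average_self_kl(model_a, model_b, samples):
    """Baseline: mean self KL divergence over a reference set such as the train set."""
    scores = [self_kl(model_a, model_b, x) for x in samples]
    return sum(scores) / len(scores)


def is_out_of_distribution(model_a, model_b, x, threshold):
    """Flag x when the two models disagree more than the chosen threshold."""
    return self_kl(model_a, model_b, x) > threshold
```

In practice you would compute `average_self_kl` on the train set and on the validation set, then pick `threshold` relative to those baselines. How far above the baseline to set it is a judgment call that depends on your tolerance for false alarms.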
                                                                               

## Experiments

- Test two identical networks. With the same training we should get a KL divergence of 0 everywhere, so there is no possibility of detecting out-of-distribution data.
  Also test on widely different architectures and check that we don't get correct results.
- On the same architecture with different initializations, show that it can be used for out-of-distribution detection on English, French and the train set.
- Test with various initialization patterns and various architectures. Show that it's linked to
  the descent method, and probably to the structure of the network (it does not seem fully generalizable).
- Test with random inputs to check that it works.
- Test with adversarial sampling to see if we can generate samples from it.
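
The first sanity check in this list can be sketched with toy models. This is only an illustration under strong assumptions: `make_model` is a hypothetical, untrained linear classifier standing in for a real network, so it shows why identical models carry no signal and says nothing about how trained models behave.

```python
import math
import random


def make_model(seed, n_in=4, n_out=3):
    """Toy untrained linear classifier whose weights depend only on the seed."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_out)]

    def model(x):
        logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
        # Numerically stable log-softmax
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        return [l - log_z for l in logits]

    return model


def kl_divergence(log_p, log_q):
    """KL(p || q) in nats for two categorical distributions given as log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


rng = random.Random(0)
inputs = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(100)]

# Same initialization: the two models are identical, so they never disagree.
same = [kl_divergence(make_model(1)(x), make_model(1)(x)) for x in inputs]
# Different initialization: the models disagree, which is the signal we rely on.
diff = [kl_divergence(make_model(1)(x), make_model(2)(x)) for x in inputs]

print(f"same init: max KL = {max(same)}")
print(f"different init: mean KL = {sum(diff) / len(diff):.4f}")
```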

## Unsupervised text classification

- Show that a small network trained on a single English book makes it possible to detect different languages
  or different patterns of writing (Old English, Irish, French, or even dictionaries).
- The detection is very fine-grained: it is capable of detecting English within a French book.
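
To make the fine-grained detection idea concrete, here is a small sketch of how per-position scores could be turned into labelled spans. It assumes you already have one self KL divergence value per position in the text (per character or per token). The function `segment_by_score`, the moving-average smoothing and the `window` parameter are assumptions made for illustration, not the method used in the experiments.

```python
def segment_by_score(scores, threshold, window=5):
    """Split positions into (start, end, is_ood) spans using a moving average.

    `end` is exclusive. A position is flagged when the mean score in the
    window centred on it exceeds `threshold`.
    """
    if not scores:
        return []
    half = window // 2
    flags = []
    for i in range(len(scores)):
        lo, hi = max(0, i - half), min(len(scores), i + half + 1)
        flags.append(sum(scores[lo:hi]) / (hi - lo) > threshold)

    # Merge consecutive positions that share the same flag into one span
    spans = []
    start = 0
    for i in range(1, len(flags) + 1):
        if i == len(flags) or flags[i] != flags[start]:
            spans.append((start, i, flags[start]))
            start = i
    return spans
```

Mapping the spans back to the text, for example `text[start:end]` for each flagged span, would then give the passages that look unlike the training book.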

## Limits

## Future Work