# Finding Mislabelled Samples through ResNet MNIST Training Process

This notebook trains a ResNet model on the MNIST dataset and employs the TrainIng Data analYzer (TIDY) method based on the Forgetting Events algorithm, specifically `ForgettingEventsInterpreter`, to investigate the training process by recording the model's predictions throughout training. Some samples are manually mislabelled, and we are able to find them by examining how their predictions evolve during training.

In [1]:
import paddle
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import interpretdl as it

Define a ResNet architecture for MNIST. The code is borrowed from the [PaddlePaddle Official Documentation](https://www.paddlepaddle.org.cn/tutorials/projectdetail/1516124).

In [2]:
import paddle.nn as nn
import paddle.nn.functional as F

class ConvBNLayer(paddle.nn.Layer):
    """Conv2D (without bias) followed by BatchNorm2D and an optional activation.

    Args:
        num_channels: number of input channels.
        num_filters: number of output channels.
        filter_size: convolution kernel size. Padding is
            ``(filter_size - 1) // 2``, so odd kernel sizes keep the spatial
            size unchanged when ``stride`` is 1.
        stride: convolution stride.
        groups: number of convolution groups.
        act: ``'relu'``, ``'leaky'`` (leaky ReLU with negative slope 0.1), or
            ``None``. Any other value applies no activation.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None):
        super(ConvBNLayer, self).__init__()

        # No conv bias: the following BatchNorm re-centres the output anyway.
        self._conv = nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            bias_attr=False)

        self._batch_norm = paddle.nn.BatchNorm2D(num_filters)

        self.act = act

    def forward(self, inputs):
        """Apply conv -> batch norm -> optional activation to ``inputs``."""
        y = self._conv(inputs)
        y = self._batch_norm(y)
        if self.act == 'leaky':
            # Apply to the normalized tensor `y` (previously referenced an
            # undefined name, so act='leaky' raised NameError).
            y = F.leaky_relu(x=y, negative_slope=0.1)
        elif self.act == 'relu':
            y = F.relu(x=y)
        return y

class BottleneckBlock(paddle.nn.Layer):
    """ResNet bottleneck block: 1x1 reduce, 3x3 (possibly strided), 1x1 expand.

    The expanded output has ``num_filters * 4`` channels. When ``shortcut`` is
    True the input is added back unchanged; otherwise a strided 1x1
    ConvBNLayer projects the input to the expanded width first. The sum is
    passed through ReLU.

    Args:
        num_channels: number of input channels.
        num_filters: width of the inner (bottleneck) convolutions.
        stride: stride of the 3x3 convolution and of the projection shortcut.
        shortcut: use the identity shortcut when True, a projection otherwise.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True):
        super(BottleneckBlock, self).__init__()
        expanded = num_filters * 4

        # 1x1: num_channels -> num_filters
        self.conv0 = ConvBNLayer(num_channels=num_channels,
                                 num_filters=num_filters,
                                 filter_size=1,
                                 act='relu')
        # 3x3: carries the block's stride
        self.conv1 = ConvBNLayer(num_channels=num_filters,
                                 num_filters=num_filters,
                                 filter_size=3,
                                 stride=stride,
                                 act='relu')
        # 1x1: num_filters -> expanded, no activation before the residual add
        self.conv2 = ConvBNLayer(num_channels=num_filters,
                                 num_filters=expanded,
                                 filter_size=1,
                                 act=None)
        if not shortcut:
            # Projection so the shortcut matches the expanded width and stride.
            self.short = ConvBNLayer(num_channels=num_channels,
                                     num_filters=expanded,
                                     filter_size=1,
                                     stride=stride)

        self.shortcut = shortcut
        self._num_channels_out = expanded

    def forward(self, inputs):
        """Return ``relu(shortcut(inputs) + conv2(conv1(conv0(inputs))))``."""
        residual = self.conv2(self.conv1(self.conv0(inputs)))
        identity = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=identity, y=residual))

class ResNet(paddle.nn.Layer):
    """Bottleneck ResNet (50, 101 or 152 layers) for single-channel images.

    Args:
        layers: network depth; must be one of 50, 101 or 152 (checked with
            ``assert``).
        class_dim: number of outputs of the final fully connected layer.
    """

    def __init__(self, layers=50, class_dim=1):
        super(ResNet, self).__init__()
        self.layers = layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(supported_layers, layers)

        # Number of bottleneck blocks in each of the four stages (c2..c5).
        depth = {
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[layers]

        num_filters = [64, 128, 256, 512]

        # Stem: 7x7 stride-2 conv on a 1-channel input, then 3x3 stride-2 max pool.
        self.conv = ConvBNLayer(
            num_channels=1,
            num_filters=64,
            filter_size=7,
            stride=2,
            act='relu')
        self.pool2d_max = nn.MaxPool2D(
            kernel_size=3,
            stride=2,
            padding=1)

        self.bottleneck_block_list = []
        num_channels = 64
        for block in range(len(depth)):
            # The first block of every stage changes the channel count, so it
            # uses a projection shortcut; the remaining blocks use identity.
            shortcut = False
            for i in range(depth[block]):
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    BottleneckBlock(
                        num_channels=num_channels,
                        num_filters=num_filters[block],
                        # c3, c4 and c5 use stride=2 in their first residual
                        # block; all other residual blocks use stride=1.
                        stride=2 if i == 0 and block != 0 else 1,
                        shortcut=shortcut))
                num_channels = bottleneck_block._num_channels_out
                self.bottleneck_block_list.append(bottleneck_block)
                shortcut = True

        self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D(output_size=1)

        import math
        # The classifier input width is the last stage's output channel count
        # (512 * 4 = 2048). It is derived here instead of hard-coded so it
        # stays in sync with num_filters.
        stdv = 1.0 / math.sqrt(num_channels * 1.0)

        self.out = nn.Linear(in_features=num_channels, out_features=class_dim,
                             weight_attr=paddle.ParamAttr(
                                 initializer=paddle.nn.initializer.Uniform(-stdv, stdv)))

    def forward(self, inputs):
        """Run stem, bottleneck stages, global average pool and the classifier."""
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        for bottleneck_block in self.bottleneck_block_list:
            y = bottleneck_block(y)
        y = self.pool2d_avg(y)
        # Flatten pooled features to (batch, features) for the Linear layer.
        y = paddle.reshape(y, [y.shape[0], -1])
        y = self.out(y)
        return y

Use the MNIST dataset from **paddle.vision** to get the labels, and manually mislabel 1% of the samples.

In [3]:
from paddle.vision.transforms import ToTensor, Resize, Compose
from paddle.vision.datasets import MNIST

# MNIST training split; each image is resized (size=32) and converted to a tensor.
train_dataset = MNIST(
    mode='train',
    transform=Compose([
        Resize(size=32),
        ToTensor(),
    ]),
)

In [4]:
# Prepare manually mislabelled samples: for every 100th training sample, draw
# a wrong label uniformly at random from the remaining classes (np.arange(10)
# assumes the 10 MNIST digit classes). The loop bound comes from the dataset
# itself instead of a hard-coded 60000, so the labels stay aligned with
# whatever split was loaded.
# NOTE: uses NumPy's global RNG, so the chosen labels differ between runs
# unless np.random.seed(...) is called beforehand.
labels = [
    np.random.choice(np.delete(np.arange(10), train_dataset[i][-1]))
    for i in range(0, len(train_dataset), 100)
]

Initialize the model.

In [5]:
# ResNet-50 (the default depth) with one output per digit class.
model = ResNet(layers=50, class_dim=10)

Define a new data generator based on the MNIST dataset. It replaces 1% of the true labels (those of every 100th sample) with wrong ones.

**Important:** the data generator should yield the index of each sample as the first element, so that each sample's behavior can be recorded according to its index.

In [6]:
def reader_prepare(dataset, new_labels, interval=100):
    """Build a sample reader that injects label noise into ``dataset``.

    Every ``interval``-th sample (indices 0, interval, 2 * interval, ...) has
    its label replaced by ``new_labels[idx // interval]``; all other labels
    are kept. The sample index is yielded first so that each sample's
    training behavior can be tracked by index.

    Args:
        dataset: iterable of ``(data, label)`` pairs.
        new_labels: replacement labels, one per noisy sample. An IndexError is
            raised while reading if it has too few entries.
        interval: stride between samples whose labels are replaced
            (default 100, i.e. 1% of the samples).

    Returns:
        A zero-argument callable returning a generator of
        ``(idx, data, int(label))`` tuples.
    """
    def reader():
        for idx, (data, label) in enumerate(dataset):
            if idx % interval == 0:
                label = new_labels[idx // interval]
            yield idx, data, int(label)
    return reader

Set up a data loader with a batch size of 128 and a Momentum optimizer for training.

In [7]:
BATCH_SIZE = 128
# Group the noisy-label (idx, data, label) reader into batches of up to
# BATCH_SIZE samples.
train_reader = paddle.batch(reader_prepare(train_dataset, labels),
                            batch_size=BATCH_SIZE)
# Momentum optimizer over the ResNet's parameters.
optimizer = paddle.optimizer.Momentum(
    parameters=model.parameters(),
    learning_rate=0.001,
    momentum=0.9,
)

First initialize the `ForgettingEventsInterpreter`, and then start `interpret`ing the training process by training for 100 epochs.

*stats* is a dictionary that maps each image index to its predictions during training and whether they were correct; *noisy_samples* is a list of the IDs of images identified as likely mislabelled. *stats* is saved to "assets/stats.pkl".

In [8]:
# The interpreter wraps the model and runs training itself on device 'gpu:0'.
fe = it.ForgettingEventsInterpreter(model, device='gpu:0')

epochs = 100
print('Training %d epochs. This may take some time.' % (epochs,))
# Train while recording per-sample predictions; noisy_labels=True additionally
# returns the sample indices flagged as likely mislabelled.
stats, noisy_samples = fe.interpret(
    train_reader, optimizer,
    batch_size=BATCH_SIZE, epochs=epochs,
    noisy_labels=True, save_path='assets')

Training 100 epochs. This may take some time.
| Epoch [  1/100] Iter[  2]		Loss: 2.5311 Acc@1: 10.938%

  "When training, we now always track global mean and variance.")


| Epoch [100/100] Iter[469]		Loss: 0.0000 Acc@1: 100.000%

Calculate the recall, precision, and F1 score of the noisy samples we found.

88.7% of the mislabelled samples have been found, and among the samples found, 80.1% are indeed mislabelled.

In [10]:
# Samples whose index is a multiple of 100 are exactly the ones whose labels
# were replaced by reader_prepare, so they serve as ground truth here.
num_mislabelled = len(range(0, 60000, 100))
num_hits = sum(id_ % 100 == 0 for id_ in noisy_samples)
recall = num_hits / num_mislabelled
# Report 0.0 rather than dividing by zero when nothing was flagged or when
# both scores are zero. len() is used so a NumPy array also works here.
precision = num_hits / len(noisy_samples) if len(noisy_samples) else 0.0
f1 = 2 * (recall * precision) / (recall + precision) if recall + precision else 0.0
print('Recall: ', recall)
print('Precision: ', precision)
print('F1 Score: ', f1)

Recall:  0.8866666666666667
Precision:  0.8012048192771084
F1 Score:  0.8417721518987342
