In [4]:
import sys
# Put ../algorithm (relative to the working directory) on the module search path;
# presumably this is where the project modules imported further down live -- TODO confirm.
sys.path.append('../algorithm')

In [5]:
# Fetch the CIFAR train / validation / test splits from ../data.
from datasets import load_data
from config import Dataset

# `subset` does not matter for this experiment; `seed` pins the shuffle.
train_data, val_data, test_data = load_data("../data", Dataset.CIFAR, subset=0.8, seed=0)

seed: 0
shuffle: [1670, 13379, 10234, 4719, 7003, 2831, 13014, 11979, 8610, 519, ...]


In [6]:
# Apply known noise to the labels
# Let's first try the same noise as the MNIST_FASHION_05 has
import numpy as np

# Dedicated, seeded generator: the injected noise (and every figure derived from it
# further down) is identical on each Restart & Run All.
rng = np.random.default_rng(0)

# Labels are edited in place through this reference, so the noise is meant to land
# in `test_data` itself.
# NOTE: re-running this cell on its own stacks noise on already-noisy labels;
# re-run the data-loading cell first.
test_labels = test_data.tensors[1]
test_labels_clean = test_labels.clone().detach()

# Row i of T is the distribution of the noisy label given true label i.
T = np.array([[0.5, 0.3, 0.2], [0.2, 0.5, 0.3], [0.3, 0.2, 0.5]])
classes = np.arange(T.shape[0])
for i in range(len(test_labels)):
    # Given this is the true label, use the corresponding noisy probs to get a noisy label.
    # The row is looked up from the untouched clean copy, as a plain int.
    test_labels[i] = rng.choice(classes, p=T[int(test_labels_clean[i])])

In [7]:
# Fraction of labels the noise left untouched; the diagonal of T is 0.5, so expect roughly 0.5
unchanged = test_labels_clean == test_labels
unchanged.sum() / len(test_labels)

tensor(0.4953)

In [8]:
# Here we want to use the test data as our `training data` for estimating the transition matrix.
# This is a plain alias, not a copy: `train_data` and `test_data` now name the same object,
# and it replaces the `train_data` returned by `load_data` earlier.
train_data = test_data

In [35]:
# See how our transition matrix estimator performs on this data
import argparse
import hashlib
import json
import os.path
import sys
from copy import copy
from datetime import datetime
from pathlib import Path
from typing import Callable, Iterable, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from datasets import load_data
from factories import BackboneFactory, EstimatorFactory, LossFactory
from loggers import JSONLLogger, Logger, StreamLogger, WandbLogger
from sklearn.metrics import accuracy_score
from termcolor import colored
from torch.utils.data import DataLoader
from utils import LabelSmoothingCrossEntropyLoss

from config import Backbone, Dataset, Estimator, LossCorrection

# Value stored under 'anchor_outlier_threshold' in the config below
# (presumably read by the anchor estimator -- TODO confirm).
anchor_trans_thresh = 0.9

# Adjust params here
config_dict = {
    'dataset': Dataset.CIFAR,
    'subset': 0.8,
    'seed': 0,
    'loss_correction': LossCorrection.FORWARD,
    'backbone': Backbone.MLP,
    'estimator': Estimator.ANCHOR,
    'freeze_estimator': True,
    'anchor_outlier_threshold': anchor_trans_thresh,
    'epochs': 0,
    'lr': 0.01,
    'backbone_pretrain_epochs': 150,
    'label_smoothing': 0.0,
    'batch_size': 32,
    'device': 'cpu',
    'id': '20201118_164657',
    'log_step': 32,
    # NOTE(review): machine-specific absolute path; replace with a path relative to
    # the repository before sharing this notebook.
    'results_dir': '/Users/nick/uni/comp5328/assignment2/comp-5328-assignment-two/code/results',
    'wandb': False,
}
# Expose the settings as attributes (config.batch_size, config.device, ...).
config = argparse.Namespace(**config_dict)


def generate_trans_cifar(config, data=None):
    """Pretrain a backbone on noisy labels, then build and print a transition-matrix estimator.

    Args:
        config: Namespace of run settings. This function reads
            ``backbone_pretrain_epochs``, ``device`` and ``batch_size`` directly and
            hands the whole object to the backbone factory, the estimator factory
            and ``pretrain_backbone``.
        data: Dataset exposing ``.tensors`` as ``(features, labels)``. When omitted,
            the notebook-level ``train_data`` is used, so existing
            ``generate_trans_cifar(config)`` calls behave as before.

    Returns:
        None. Progress and the resulting ``estimator.transitions`` are printed.
    """
    if data is None:
        # Backward-compatible default: fall back to the notebook global.
        data = train_data

    loggers = [StreamLogger()]
    # Per-sample feature shape (everything after the leading sample dimension).
    input_size = tuple(data.tensors[0].size()[1:])
    # Number of distinct label values present in the data.
    class_count = len(set(data.tensors[1].tolist()))
    criterion = torch.nn.CrossEntropyLoss()

    # Create backbone
    print(colored("backbone:", attrs=["bold"]))
    backbone_factory = BackboneFactory(input_size, class_count)
    backbone = backbone_factory.create(config)
    print(backbone)

    # Perform pretraining on noisy data without the transition matrix if necessary
    if config.backbone_pretrain_epochs > 0:
        print(colored("pretraining backbone:", attrs=["bold"]))
        backbone = backbone.to(config.device)
        pretrain_backbone(
            backbone,
            data,
            torch.optim.SGD(backbone.parameters(), lr=1e-3),
            criterion,
            loggers,
            config,
        )

    # pretrain_backbone switches the network to train mode and never switches it
    # back, so dropout would still be active while the estimator reads the
    # backbone's outputs. Put it in inference mode first (harmless if the
    # estimator already does this itself).
    backbone.eval()

    # Estimator could be None if we don't want to use a transition matrix
    # Create transition matrix
    print(colored("estimator:", attrs=["bold"]))
    estimator_factory = EstimatorFactory(class_count)
    estimator = estimator_factory.create(
        config,
        backbone,
        # Unshuffled pass over the same data used for pretraining.
        DataLoader(data, batch_size=config.batch_size, shuffle=False, num_workers=0),
    )
    if estimator is not None:
        print(
            f"Transition matrix to be usd by the Label Noise Robust Model:\n"
            f"{estimator.transitions=}"
        )






def pretrain_backbone(
    backbone: torch.nn.Module,
    data: torch.utils.data.Dataset,
    optimiser: torch.optim.Optimizer,
    criterion: Optional[Callable[..., torch.Tensor]],
    loggers: Iterable[Logger],
    config: argparse.Namespace,
):
    """Train ``backbone`` in place on ``data`` for ``config.backbone_pretrain_epochs`` epochs.

    Args:
        backbone: Module whose forward pass returns a pair; the second element is
            what gets passed to ``criterion``.
        data: Dataset exposing ``.tensors`` with the labels at index 1.
        optimiser: Optimiser already bound to the backbone's parameters.
        criterion: Loss called as ``criterion(second_output, labels)``. For
            ``BCEWithLogitsLoss`` the labels are one-hot encoded as float32 first.
        loggers: Callables that each receive the metrics dict.
        config: Reads ``batch_size``, ``device``, ``backbone_pretrain_epochs`` and
            ``log_step``.

    Returns:
        None. The backbone is left in training mode.
    """
    loader = DataLoader(data, batch_size=config.batch_size, shuffle=True, num_workers=0)
    n_classes = len(set(loader.dataset.tensors[1].tolist()))
    n_batches = len(loader)
    n_samples = len(loader.dataset)
    needs_one_hot = isinstance(criterion, torch.nn.BCEWithLogitsLoss)

    backbone.train()
    for epoch in range(config.backbone_pretrain_epochs):
        for step, (inputs, targets) in enumerate(loader):
            # Put the batch on the configured device.
            inputs = inputs.to(config.device)
            if needs_one_hot:
                # BCEWithLogitsLoss wants float one-hot targets rather than class indices.
                targets = F.one_hot(targets, num_classes=n_classes).type(torch.float32)
            targets = targets.to(config.device)

            optimiser.zero_grad()
            _, outputs = backbone(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimiser.step()

            # Report every `log_step` batches and always on the final batch of an epoch.
            on_log_step = step % config.log_step == config.log_step - 1
            on_last_batch = step == n_batches - 1
            if not (on_log_step or on_last_batch):
                continue

            record = {
                # Fractional epoch: whole epochs done plus the share of samples seen so far.
                "pretrain/epoch": epoch + step * loader.batch_size / n_samples,
                "pretrain/loss": loss.item(),
            }
            for emit in loggers:
                emit(record)
                    
                    
# Run the experiment with the `config` built earlier in this cell
# (anchor_outlier_threshold = anchor_trans_thresh = 0.9).
generate_trans_cifar(config)

[1mbackbone:[0m
MLPBackbone(
  (fc_0): Linear(in_features=3072, out_features=1537, bias=True)
  (act_0): ReLU()
  (dropout_0): Dropout(p=0.5, inplace=False)
  (fc_1): Linear(in_features=1537, out_features=1537, bias=True)
  (act_1): ReLU()
  (fc_2): Linear(in_features=1537, out_features=3, bias=True)
  (sm): Softmax(dim=-1)
)
[1mpretraining backbone:[0m
pretrain/epoch: [36m0.3307[0m pretrain/loss: [36m1.0928[0m
pretrain/epoch: [36m0.6720[0m pretrain/loss: [36m1.1056[0m
pretrain/epoch: [36m0.9920[0m pretrain/loss: [36m1.0942[0m
pretrain/epoch: [36m1.3307[0m pretrain/loss: [36m1.1005[0m
pretrain/epoch: [36m1.6720[0m pretrain/loss: [36m1.1003[0m
pretrain/epoch: [36m1.9920[0m pretrain/loss: [36m1.1045[0m
pretrain/epoch: [36m2.3307[0m pretrain/loss: [36m1.0971[0m
pretrain/epoch: [36m2.6720[0m pretrain/loss: [36m1.1055[0m
pretrain/epoch: [36m2.9920[0m pretrain/loss: [36m1.1051[0m
pretrain/epoch: [36m3.3307[0m pretrain/loss: [36m1.1137[0m
pretrain/e

pretrain/epoch: [36m41.3307[0m pretrain/loss: [36m1.0568[0m
pretrain/epoch: [36m41.6720[0m pretrain/loss: [36m1.0863[0m
pretrain/epoch: [36m41.9920[0m pretrain/loss: [36m1.1105[0m
pretrain/epoch: [36m42.3307[0m pretrain/loss: [36m1.0621[0m
pretrain/epoch: [36m42.6720[0m pretrain/loss: [36m1.0806[0m
pretrain/epoch: [36m42.9920[0m pretrain/loss: [36m1.0669[0m
pretrain/epoch: [36m43.3307[0m pretrain/loss: [36m1.0524[0m
pretrain/epoch: [36m43.6720[0m pretrain/loss: [36m1.0838[0m
pretrain/epoch: [36m43.9920[0m pretrain/loss: [36m1.0640[0m
pretrain/epoch: [36m44.3307[0m pretrain/loss: [36m1.1182[0m
pretrain/epoch: [36m44.6720[0m pretrain/loss: [36m1.1366[0m
pretrain/epoch: [36m44.9920[0m pretrain/loss: [36m1.0670[0m
pretrain/epoch: [36m45.3307[0m pretrain/loss: [36m1.0731[0m
pretrain/epoch: [36m45.6720[0m pretrain/loss: [36m1.0688[0m
pretrain/epoch: [36m45.9920[0m pretrain/loss: [36m1.0580[0m
pretrain/epoch: [36m46.3307[0m pretrai

pretrain/epoch: [36m84.3307[0m pretrain/loss: [36m1.0549[0m
pretrain/epoch: [36m84.6720[0m pretrain/loss: [36m1.0823[0m
pretrain/epoch: [36m84.9920[0m pretrain/loss: [36m1.0713[0m
pretrain/epoch: [36m85.3307[0m pretrain/loss: [36m1.1434[0m
pretrain/epoch: [36m85.6720[0m pretrain/loss: [36m1.0894[0m
pretrain/epoch: [36m85.9920[0m pretrain/loss: [36m1.0734[0m
pretrain/epoch: [36m86.3307[0m pretrain/loss: [36m1.0807[0m
pretrain/epoch: [36m86.6720[0m pretrain/loss: [36m1.1141[0m
pretrain/epoch: [36m86.9920[0m pretrain/loss: [36m1.1122[0m
pretrain/epoch: [36m87.3307[0m pretrain/loss: [36m1.0506[0m
pretrain/epoch: [36m87.6720[0m pretrain/loss: [36m1.0418[0m
pretrain/epoch: [36m87.9920[0m pretrain/loss: [36m1.0404[0m
pretrain/epoch: [36m88.3307[0m pretrain/loss: [36m1.1091[0m
pretrain/epoch: [36m88.6720[0m pretrain/loss: [36m1.0505[0m
pretrain/epoch: [36m88.9920[0m pretrain/loss: [36m1.0789[0m
pretrain/epoch: [36m89.3307[0m pretrai

pretrain/epoch: [36m126.6720[0m pretrain/loss: [36m1.0024[0m
pretrain/epoch: [36m126.9920[0m pretrain/loss: [36m1.0539[0m
pretrain/epoch: [36m127.3307[0m pretrain/loss: [36m1.0610[0m
pretrain/epoch: [36m127.6720[0m pretrain/loss: [36m1.1175[0m
pretrain/epoch: [36m127.9920[0m pretrain/loss: [36m1.0507[0m
pretrain/epoch: [36m128.3307[0m pretrain/loss: [36m1.0141[0m
pretrain/epoch: [36m128.6720[0m pretrain/loss: [36m1.0853[0m
pretrain/epoch: [36m128.9920[0m pretrain/loss: [36m1.0504[0m
pretrain/epoch: [36m129.3307[0m pretrain/loss: [36m1.0373[0m
pretrain/epoch: [36m129.6720[0m pretrain/loss: [36m1.0152[0m
pretrain/epoch: [36m129.9920[0m pretrain/loss: [36m1.0733[0m
pretrain/epoch: [36m130.3307[0m pretrain/loss: [36m1.0399[0m
pretrain/epoch: [36m130.6720[0m pretrain/loss: [36m1.0370[0m
pretrain/epoch: [36m130.9920[0m pretrain/loss: [36m1.0121[0m
pretrain/epoch: [36m131.3307[0m pretrain/loss: [36m0.9928[0m
pretrain/epoch: [36m131.

In [36]:
# Repeat the experiment with anchor_outlier_threshold = 0.97.
anchor_trans_thresh = 0.97
# Only the threshold differs between runs, so override that one key on the
# `config_dict` built in the first experiment cell rather than pasting the whole
# literal again; all other settings then stay in sync with that single definition.
config_dict = {**config_dict, 'anchor_outlier_threshold': anchor_trans_thresh}
config = argparse.Namespace()
vars(config).update(config_dict)

generate_trans_cifar(config)



[1mbackbone:[0m
MLPBackbone(
  (fc_0): Linear(in_features=3072, out_features=1537, bias=True)
  (act_0): ReLU()
  (dropout_0): Dropout(p=0.5, inplace=False)
  (fc_1): Linear(in_features=1537, out_features=1537, bias=True)
  (act_1): ReLU()
  (fc_2): Linear(in_features=1537, out_features=3, bias=True)
  (sm): Softmax(dim=-1)
)
[1mpretraining backbone:[0m
pretrain/epoch: [36m0.3307[0m pretrain/loss: [36m1.1194[0m
pretrain/epoch: [36m0.6720[0m pretrain/loss: [36m1.0971[0m
pretrain/epoch: [36m0.9920[0m pretrain/loss: [36m1.0907[0m
pretrain/epoch: [36m1.3307[0m pretrain/loss: [36m1.0950[0m
pretrain/epoch: [36m1.6720[0m pretrain/loss: [36m1.1030[0m
pretrain/epoch: [36m1.9920[0m pretrain/loss: [36m1.0971[0m
pretrain/epoch: [36m2.3307[0m pretrain/loss: [36m1.0922[0m
pretrain/epoch: [36m2.6720[0m pretrain/loss: [36m1.1094[0m
pretrain/epoch: [36m2.9920[0m pretrain/loss: [36m1.0973[0m
pretrain/epoch: [36m3.3307[0m pretrain/loss: [36m1.0894[0m
pretrain/e

pretrain/epoch: [36m41.3307[0m pretrain/loss: [36m1.0459[0m
pretrain/epoch: [36m41.6720[0m pretrain/loss: [36m1.0910[0m
pretrain/epoch: [36m41.9920[0m pretrain/loss: [36m1.1178[0m
pretrain/epoch: [36m42.3307[0m pretrain/loss: [36m1.0832[0m
pretrain/epoch: [36m42.6720[0m pretrain/loss: [36m1.0640[0m
pretrain/epoch: [36m42.9920[0m pretrain/loss: [36m1.0747[0m
pretrain/epoch: [36m43.3307[0m pretrain/loss: [36m1.0912[0m
pretrain/epoch: [36m43.6720[0m pretrain/loss: [36m1.0436[0m
pretrain/epoch: [36m43.9920[0m pretrain/loss: [36m1.0736[0m
pretrain/epoch: [36m44.3307[0m pretrain/loss: [36m1.0543[0m
pretrain/epoch: [36m44.6720[0m pretrain/loss: [36m1.0741[0m
pretrain/epoch: [36m44.9920[0m pretrain/loss: [36m1.0458[0m
pretrain/epoch: [36m45.3307[0m pretrain/loss: [36m1.1114[0m
pretrain/epoch: [36m45.6720[0m pretrain/loss: [36m1.0777[0m
pretrain/epoch: [36m45.9920[0m pretrain/loss: [36m1.0568[0m
pretrain/epoch: [36m46.3307[0m pretrai

pretrain/epoch: [36m84.3307[0m pretrain/loss: [36m1.0872[0m
pretrain/epoch: [36m84.6720[0m pretrain/loss: [36m1.0730[0m
pretrain/epoch: [36m84.9920[0m pretrain/loss: [36m1.0073[0m
pretrain/epoch: [36m85.3307[0m pretrain/loss: [36m1.0496[0m
pretrain/epoch: [36m85.6720[0m pretrain/loss: [36m1.1048[0m
pretrain/epoch: [36m85.9920[0m pretrain/loss: [36m1.0774[0m
pretrain/epoch: [36m86.3307[0m pretrain/loss: [36m1.0716[0m
pretrain/epoch: [36m86.6720[0m pretrain/loss: [36m1.0396[0m
pretrain/epoch: [36m86.9920[0m pretrain/loss: [36m1.0247[0m
pretrain/epoch: [36m87.3307[0m pretrain/loss: [36m1.0458[0m
pretrain/epoch: [36m87.6720[0m pretrain/loss: [36m1.0380[0m
pretrain/epoch: [36m87.9920[0m pretrain/loss: [36m1.0795[0m
pretrain/epoch: [36m88.3307[0m pretrain/loss: [36m1.1142[0m
pretrain/epoch: [36m88.6720[0m pretrain/loss: [36m1.1126[0m
pretrain/epoch: [36m88.9920[0m pretrain/loss: [36m1.0674[0m
pretrain/epoch: [36m89.3307[0m pretrai

pretrain/epoch: [36m126.6720[0m pretrain/loss: [36m1.0850[0m
pretrain/epoch: [36m126.9920[0m pretrain/loss: [36m1.0894[0m
pretrain/epoch: [36m127.3307[0m pretrain/loss: [36m1.0747[0m
pretrain/epoch: [36m127.6720[0m pretrain/loss: [36m1.0234[0m
pretrain/epoch: [36m127.9920[0m pretrain/loss: [36m1.0111[0m
pretrain/epoch: [36m128.3307[0m pretrain/loss: [36m0.9773[0m
pretrain/epoch: [36m128.6720[0m pretrain/loss: [36m1.0554[0m
pretrain/epoch: [36m128.9920[0m pretrain/loss: [36m1.1408[0m
pretrain/epoch: [36m129.3307[0m pretrain/loss: [36m1.0608[0m
pretrain/epoch: [36m129.6720[0m pretrain/loss: [36m1.1488[0m
pretrain/epoch: [36m129.9920[0m pretrain/loss: [36m0.9669[0m
pretrain/epoch: [36m130.3307[0m pretrain/loss: [36m1.0607[0m
pretrain/epoch: [36m130.6720[0m pretrain/loss: [36m1.1439[0m
pretrain/epoch: [36m130.9920[0m pretrain/loss: [36m1.0076[0m
pretrain/epoch: [36m131.3307[0m pretrain/loss: [36m1.0546[0m
pretrain/epoch: [36m131.

In [37]:
# Repeat the experiment with anchor_outlier_threshold = 1.
anchor_trans_thresh = 1
# Only the threshold differs between runs, so override that one key on the
# `config_dict` built in the first experiment cell rather than pasting the whole
# literal again; all other settings then stay in sync with that single definition.
config_dict = {**config_dict, 'anchor_outlier_threshold': anchor_trans_thresh}
config = argparse.Namespace()
vars(config).update(config_dict)

generate_trans_cifar(config)



[1mbackbone:[0m
MLPBackbone(
  (fc_0): Linear(in_features=3072, out_features=1537, bias=True)
  (act_0): ReLU()
  (dropout_0): Dropout(p=0.5, inplace=False)
  (fc_1): Linear(in_features=1537, out_features=1537, bias=True)
  (act_1): ReLU()
  (fc_2): Linear(in_features=1537, out_features=3, bias=True)
  (sm): Softmax(dim=-1)
)
[1mpretraining backbone:[0m
pretrain/epoch: [36m0.3307[0m pretrain/loss: [36m1.1017[0m
pretrain/epoch: [36m0.6720[0m pretrain/loss: [36m1.0939[0m
pretrain/epoch: [36m0.9920[0m pretrain/loss: [36m1.1065[0m
pretrain/epoch: [36m1.3307[0m pretrain/loss: [36m1.1111[0m
pretrain/epoch: [36m1.6720[0m pretrain/loss: [36m1.0958[0m
pretrain/epoch: [36m1.9920[0m pretrain/loss: [36m1.0994[0m
pretrain/epoch: [36m2.3307[0m pretrain/loss: [36m1.0983[0m
pretrain/epoch: [36m2.6720[0m pretrain/loss: [36m1.1027[0m
pretrain/epoch: [36m2.9920[0m pretrain/loss: [36m1.1050[0m
pretrain/epoch: [36m3.3307[0m pretrain/loss: [36m1.0993[0m
pretrain/e

pretrain/epoch: [36m41.3307[0m pretrain/loss: [36m1.0618[0m
pretrain/epoch: [36m41.6720[0m pretrain/loss: [36m1.0369[0m
pretrain/epoch: [36m41.9920[0m pretrain/loss: [36m1.0393[0m
pretrain/epoch: [36m42.3307[0m pretrain/loss: [36m1.1019[0m
pretrain/epoch: [36m42.6720[0m pretrain/loss: [36m1.0979[0m
pretrain/epoch: [36m42.9920[0m pretrain/loss: [36m1.0874[0m
pretrain/epoch: [36m43.3307[0m pretrain/loss: [36m1.0720[0m
pretrain/epoch: [36m43.6720[0m pretrain/loss: [36m1.0917[0m
pretrain/epoch: [36m43.9920[0m pretrain/loss: [36m1.0725[0m
pretrain/epoch: [36m44.3307[0m pretrain/loss: [36m1.0913[0m
pretrain/epoch: [36m44.6720[0m pretrain/loss: [36m1.0786[0m
pretrain/epoch: [36m44.9920[0m pretrain/loss: [36m1.0913[0m
pretrain/epoch: [36m45.3307[0m pretrain/loss: [36m1.0779[0m
pretrain/epoch: [36m45.6720[0m pretrain/loss: [36m1.0799[0m
pretrain/epoch: [36m45.9920[0m pretrain/loss: [36m1.0672[0m
pretrain/epoch: [36m46.3307[0m pretrai

KeyboardInterrupt: 