# Robust Trainers for Noisy Labels

## Setup

In [None]:
# Suppress warnings caused by future changes in packages
import torchvision
import warnings
# Install a blanket 'ignore' filter so that all warnings are hidden from here on
warnings.filterwarnings('ignore')
# Additionally switch off torchvision's beta-transforms warning
torchvision.disable_beta_transforms_warning()

# Import necessary objects and functions
from algorithm.datasets import fashion_mnist_05, fashion_mnist_06, cifar, ImageDataset, get_test_loader
from algorithm.classifiers import ResNet34
from algorithm.eval import eval_metrics
from algorithm.experiments import set_seed, Experiment
from algorithm.trainers import CoTeaching, ForwardLossCorrection

## Individual Experiments

### Known Flip Rate

In [None]:
# Fix the random seed so that this run can be repeated exactly
seed = 32
set_seed(seed)

# Pick the dataset and split it into training, validation, and test sets
dataset = fashion_mnist_05()
(Xtr, Str), (Xval, Sval), (Xts, Yts) = dataset.load_data(random_state=seed)

# Training-set mean / standard deviation, plus the transition matrix of the noisy labels
mean, std, T = dataset.mean, dataset.std, dataset.T

# The size of the last axis of the training array is used as the input dimension / channel
input_dim = Xtr.shape[-1]

# Wrap every split in a PyTorch Dataset object
# (is_augment=True is passed for the training split only)
dataset_tr = ImageDataset(Xtr, Str, mean, std, is_augment=True)
dataset_val = ImageDataset(Xval, Sval, mean, std)
dataset_ts = ImageDataset(Xts, Yts, mean, std)

# Data loader over the test split, used later for evaluation
tsLoader = get_test_loader(dataset_ts)

In [None]:
# Build a ResNet34 sized to the input dimension
model = ResNet34(input_dim)

# Forward Loss Correction trainer, trained with the known transition matrix T
trainer = ForwardLossCorrection()
trainer.train(
    model,
    dataset_tr,
    dataset_val,
    epochs=200,
    T=T,
)

# Score the trained model on the test loader and report the four metrics
acc, precision, recall, f1 = eval_metrics(model, tsLoader)
print(acc, precision, recall, f1)

In [None]:
# Co-Teaching needs two networks; both are sized to the input dimension
model1, model2 = ResNet34(input_dim), ResNet34(input_dim)

# Co-Teaching trainer: both models are handed to a single train() call
trainer = CoTeaching()
trainer.train(
    model1,
    model2,
    dataset_tr,
    dataset_val,
    epochs=200,
    T=T,
)

# Evaluate each model on the test loader, then print one line of metrics per model
acc_1, precision_1, recall_1, f1_1 = eval_metrics(model1, tsLoader)
acc_2, precision_2, recall_2, f1_2 = eval_metrics(model2, tsLoader)
print(acc_1, precision_1, recall_1, f1_1)
print(acc_2, precision_2, recall_2, f1_2)

### Unknown Flip Rate

In [None]:
# Fix the random seed so that this run can be repeated exactly
seed = 32
set_seed(seed)

# Pick the dataset and split it into training, validation, and test sets
dataset = fashion_mnist_06()
(Xtr, Str), (Xval, Sval), (Xts, Yts) = dataset.load_data(random_state=seed)

# Training-set mean / standard deviation, plus the transition matrix of the noisy labels
mean, std, T = dataset.mean, dataset.std, dataset.T

# The size of the last axis of the training array is used as the input dimension / channel
input_dim = Xtr.shape[-1]

# Wrap every split in a PyTorch Dataset object
# (is_augment=True is passed for the training split only)
dataset_tr = ImageDataset(Xtr, Str, mean, std, is_augment=True)
dataset_val = ImageDataset(Xval, Sval, mean, std)
dataset_ts = ImageDataset(Xts, Yts, mean, std)

# Data loader over the test split, used later for evaluation
tsLoader = get_test_loader(dataset_ts)

# Build a ResNet34 sized to the input dimension
model = ResNet34(input_dim)

# Forward Loss Correction trainer; T=None means no transition matrix is supplied
trainer = ForwardLossCorrection()
trainer.train(
    model,
    dataset_tr,
    dataset_val,
    epochs=500,
    T=None,
)

# Show the trainer's T attribute after training
# NOTE(review): presumably the transition matrix estimated during training -- confirm
print(trainer.T)

# Score the trained model on the test loader and report the four metrics
acc, precision, recall, f1 = eval_metrics(model, tsLoader)
print(acc, precision, recall, f1)

## Intensive Experiments

In [None]:
# Ten seeds for reproducibility: the powers of two from 2**1 up to 2**10
seeds = [2 ** exponent for exponent in range(1, 11)]

In [None]:
# Dataset for the multi-seed experiment
dataset1 = fashion_mnist_05()

# Run the experiment over every seed with the Co-Teaching method
df1 = Experiment(
    seeds,
    dataset1,
    classifier='ResNet34',
    robust_method='co_teaching',  # Options: 'loss_correction', 'jocor', 'o2u_net'
    epochs=300,
    save_best_model=False,
)

# Display the result as the cell output
df1

In [None]:
# Dataset for the multi-seed experiment
dataset3 = cifar()

# Run the experiment over every seed; this call is unpacked into two results
df2, transition_matrix = Experiment(
    seeds,
    dataset3,
    classifier='ResNet34',
    robust_method='loss_correction',  # Only 'loss_correction' is available for transition matrix estimation
    epochs=300,
    save_best_model=False,
)

# Display the first result as the cell output
df2

In [None]:
# Display the transition_matrix value as the cell output
transition_matrix