In [38]:
import re
import parser
import os
import shutil
import time
import math
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import torchvision

from mean_teacher import architectures, datasets, data, losses, ramps, cli
from mean_teacher.run_context import RunContext
from mean_teacher.data import NO_LABEL
from mean_teacher.utils import *

import json
from data.cifar import CIFAR10, CIFAR100
from data.mnist import MNIST


import resnet
from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter('runs/supervised_only+self+uniform_noise_40%')

# Logger shared by the training / evaluation helpers in this notebook.
LOG = logging.getLogger('main')

# Parsed options; None until one of the argument-parsing cells assigns it.
args = None
# Declared here; no code visible in this notebook updates it.
best_prec1 = 0
# Number of optimizer steps taken so far; train() increments it once per step.
global_step = 0


# Running per-sample predictions: one row per dataset index, one column per class.
# update_ensemble_preds() blends new logits into the rows selected by `indexes`.
# NOTE(review): 50000 x 10 presumably matches the CIFAR-10 training set - confirm for other datasets.
ensemble_preds = torch.Tensor(np.zeros((50000,10))).cuda()
# Per-class running count of labels that passed the filter (see counter_labels()).
how_many_labels_filtered = torch.Tensor(np.zeros((10))).cuda()

In [39]:
def create_data_loaders(train_transformation,
                        eval_transformation,
                        datadir,
                        args):
    """Build the noisy-label training loader and the evaluation loader.

    Args:
        train_transformation: transform handed to the training dataset.
        eval_transformation: transform handed to the test dataset.
        datadir: unused; the datasets are always loaded from ``./data/``.
            Kept so existing callers that pass it keep working.
        args: namespace providing ``dataset``, ``noise_type``, ``batch_size``
            and ``num_workers``. It is mutated: ``noise_rate``, ``top_bn``,
            ``epoch_decay_start`` and ``n_epoch`` are overwritten.

    Returns:
        Tuple ``(train_loader, eval_loader)``.

    Raises:
        ValueError: if ``args.dataset`` is not 'mnist', 'cifar10' or 'cifar100'
            (previously this surfaced later as an unbound-variable error).
    """
    # NOTE(review): this hard-coded value overrides any --noise_rate given on
    # the command line; kept because the recorded runs were produced with it.
    args.noise_rate = 0.4

    # Pick the dataset class and its per-dataset settings.
    if args.dataset == 'mnist':
        dataset_cls, download, epoch_decay_start = MNIST, True, 80
    elif args.dataset == 'cifar10':
        dataset_cls, download, epoch_decay_start = CIFAR10, None, 80
    elif args.dataset == 'cifar100':
        dataset_cls, download, epoch_decay_start = CIFAR100, True, 100
    else:
        raise ValueError(
            "Unsupported dataset {!r}; expected 'mnist', 'cifar10' or 'cifar100'".format(args.dataset))

    args.top_bn = False
    args.epoch_decay_start = epoch_decay_start
    args.n_epoch = 200

    def build_dataset(train, transform):
        # Both splits share every option except the split flag and transform.
        return dataset_cls(root='./data/',
                           download=download,
                           train=train,
                           transform=transform,
                           noise_type=args.noise_type,
                           noise_rate=args.noise_rate)

    train_dataset = build_dataset(True, train_transformation)
    test_dataset = build_dataset(False, eval_transformation)

    print('loading dataset...')
    # drop_last=True keeps every training batch at exactly batch_size samples.
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               shuffle=True,
                                               pin_memory=True)

    eval_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              num_workers=2*args.num_workers,
                                              drop_last=False,
                                              shuffle=False,
                                              pin_memory=True)

    return train_loader, eval_loader

In [40]:
def parse_dict_args(**kwargs):
    """Parse keyword arguments as if they had been given on the command line.

    Single-character keys become short options (``b=64`` -> ``-b 64``); longer
    keys become long options with underscores turned into dashes
    (``ema_decay=0.9`` -> ``--ema-decay 0.9``).

    The parsed namespace is stored in the module-level ``args`` and also
    returned.
    """
    global args

    def to_cmdline_kwarg(key, value):
        if len(key) == 1:
            key = "-{}".format(key)
        else:
            key = "--{}".format(re.sub(r"_", "-", key))
        value = str(value)
        return key, value

    cmdline_args = []
    for key, value in kwargs.items():
        cmdline_args.extend(to_cmdline_kwarg(key, value))

    # The module-level name `parser` comes from `import parser`, i.e. it is a
    # module and not an argparse.ArgumentParser, so `parser.parse_args` cannot
    # work. Build the real parser with create_parser(), which is defined in a
    # later cell and therefore must have been run before this is called.
    args = create_parser().parse_args(cmdline_args)
    return args
    
def adjust_learning_rate(optimizer, epoch, step_in_epoch, total_steps_in_epoch):
    """Set the learning rate of every parameter group for the current step.

    The position in training is expressed as a fractional epoch
    (``epoch + step_in_epoch / total_steps_in_epoch``). The rate is ramped
    from ``args.initial_lr`` towards ``args.lr`` using ``ramps.linear_rampup``
    and, when ``args.lr_rampdown_epochs`` is non-zero, additionally scaled by
    ``ramps.cosine_rampdown``.
    """
    fractional_epoch = epoch + step_in_epoch / total_steps_in_epoch

    # LR warm-up to handle large minibatch sizes, see https://arxiv.org/abs/1706.02677
    warmup = ramps.linear_rampup(fractional_epoch, args.lr_rampup)
    current_lr = warmup * (args.lr - args.initial_lr) + args.initial_lr

    # Single-cycle cosine ramp-down, see https://arxiv.org/abs/1608.03983
    rampdown_epochs = args.lr_rampdown_epochs
    if rampdown_epochs:
        assert rampdown_epochs >= args.epochs
        current_lr *= ramps.cosine_rampdown(fractional_epoch, rampdown_epochs)

    for group in optimizer.param_groups:
        group['lr'] = current_lr
        
def get_current_consistency_weight(epoch):
    """Return ``args.consistency`` scaled by the sigmoid ramp-up value for ``epoch``.

    Consistency ramp-up schedule from https://arxiv.org/abs/1610.02242.
    """
    ramp = ramps.sigmoid_rampup(epoch, args.consistency_rampup)
    return args.consistency * ramp

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor, one row per sample and one column per class.
        target: 1-D tensor of class indices; entries equal to ``NO_LABEL`` are
            treated as unlabeled and excluded from the denominator.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element tensor per k, holding the percentage of
        labeled samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    # Number of labeled samples as a float tensor, floored at 1e-8 so that a
    # batch without any label still divides safely. (The builtin max() used
    # before returned a Python float in that case and `.float()` then failed.)
    labeled_minibatch_size = target.ne(NO_LABEL).sum().float().clamp(min=1e-8)

    # Indices of the maxk best classes per sample, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # Broadcast the targets to (maxk, batch) and compare element-wise.
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: `correct` can inherit the non-contiguous
        # layout of the transposed `pred`, which view(-1) rejects for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / labeled_minibatch_size))
    return res

In [41]:
def counter_labels(clean_idxs):
    """Add the per-class counts of ``clean_idxs`` to the global tally.

    Despite its name, the argument is used as a collection of class labels:
    for each class id 0..9 the number of occurrences is added to
    ``how_many_labels_filtered[class_id]``.

    Returns:
        The updated global ``how_many_labels_filtered`` tensor.
    """
    global how_many_labels_filtered

    label_values = clean_idxs.tolist()
    for class_id in range(10):
        occurrences = label_values.count(class_id)
        how_many_labels_filtered[class_id] += occurrences

    return how_many_labels_filtered

def filtering(indexes, labels):
    """Return the in-batch positions whose ensemble prediction matches the label.

    Args:
        indexes: dataset indices of the samples in the batch; they select the
            rows of the global ``ensemble_preds``.
        labels: labels of the same samples, in the same order.

    Returns:
        1-D CPU LongTensor with the positions (0 .. len(labels)-1) at which
        ``argmax(ensemble_preds[index])`` equals the label. May be empty.
    """
    global ensemble_preds

    # Explicit torch ops replace the former np.argmax/np.argwhere calls on
    # tensors followed by `[0]`, which depended on how NumPy dispatches to
    # torch methods.
    ensemble_labels = ensemble_preds[indexes].detach().cpu().argmax(dim=1)
    agree = ensemble_labels == labels.detach().cpu()
    clean_idxs = torch.arange(agree.numel())[agree]

    # The SummaryWriter creation is commented out at the top of the notebook,
    # so only log when a module-level `writer` actually exists.
    summary_writer = globals().get('writer')
    if summary_writer is not None:
        summary_writer.add_scalar('len_clean_idxs', len(clean_idxs), global_step=global_step)
    return clean_idxs

def update_ensemble_preds(logit1, indexes, alpha=0.99):
    """Blend the latest logits into the running per-sample ensemble predictions.

    For the rows selected by ``indexes``:
    ``ensemble_preds <- alpha * ensemble_preds + (1 - alpha) * logit1``.

    Args:
        logit1: logits for the current batch, one row per entry of ``indexes``.
        indexes: dataset indices of the samples in the batch.
        alpha: weight kept from the previous ensemble value (default 0.99).
    """
    global ensemble_preds

    # Detach before writing: assigning a tensor that requires grad into the
    # global buffer would make the buffer part of the autograd graph and keep
    # every batch's graph reachable through it.
    new_logits = logit1.detach().to(ensemble_preds.device)
    ensemble_preds[indexes] = alpha * ensemble_preds[indexes] + (1 - alpha) * new_logits

def loss_by_filtered_samples(logit1, labels, indexes):
    """Compute the classification loss on the samples that pass the filter.

    Updates the ensemble predictions with ``logit1``, keeps the samples whose
    ensemble prediction agrees with ``labels`` (see ``filtering``), tallies
    the kept labels via ``counter_labels`` and returns the mean cross-entropy
    over the kept samples.

    Returns:
        Tuple ``(class_loss, logit1_update, labels_update)`` where the last two
        are the logits and labels of the kept samples.
    """
    global how_many_labels_filtered

    update_ensemble_preds(logit1, indexes)
    clean_idx = filtering(indexes, labels)

    logit1_update = logit1[clean_idx]
    labels_update = labels[clean_idx]
    counter_labels(labels_update)

    size = len(clean_idx)  # number of kept samples, 0..batch size
    if size == 0:
        # Nothing passed the filter: return a zero loss that is still attached
        # to the graph instead of the NaN that 0 / 0 would produce.
        class_loss = logit1.sum() * 0.0
    else:
        # reduction='sum' is the current spelling of size_average=False.
        class_criterion = nn.CrossEntropyLoss(reduction='sum').cuda()
        class_loss = class_criterion(logit1_update, labels_update.cuda()) / size

#     writer.add_histogram('how_many_labels_filtered', how_many_labels_filtered, global_step=global_step)
    return class_loss, logit1_update, labels_update

In [42]:
import re
import argparse
import logging

from mean_teacher.cli import architectures, datasets

# Re-binds the same 'main' logger that the first cell already created.
LOG = logging.getLogger('main')

# NOTE(review): 'parse_cmd_args' is not defined anywhere in this notebook; the
# function defined below is named parse_commandline_args.
__all__ = ['parse_cmd_args', 'parse_dict_args']

def create_parser():
    """Build the argparse parser holding every training / dataset option.

    Returns:
        A configured ``argparse.ArgumentParser``. Help texts state the defaults
        that are actually in effect.
    """
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training')

    # --- data locations -----------------------------------------------------
    parser.add_argument('--train-subdir', type=str, default='train+val',
                        help='the subdirectory inside the data directory that contains the training data')
    parser.add_argument('--eval-subdir', type=str, default='test',
                        help='the subdirectory inside the data directory that contains the evaluation data')
    parser.add_argument('--labels', default='data-local/labels/cifar10/1000_balanced_labels/00.txt', type=str, metavar='FILE',
                        help='list of image labels (default: data-local/labels/cifar10/1000_balanced_labels/00.txt)')

    # --- model and schedule -------------------------------------------------
    parser.add_argument('--arch', '-a', metavar='ARCH', default='cifar_shakeshake26',
                        choices=architectures.__all__,
                        help='model architecture: ' +
                            ' | '.join(architectures.__all__))
    parser.add_argument('-j', '--num_workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=600, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=128, type=int,
                        metavar='N', help='mini-batch size (default: 128)')

    # --- optimizer ----------------------------------------------------------
    parser.add_argument('--lr', '--learning-rate', default=0.05, type=float,
                        metavar='LR', help='max learning rate')
    parser.add_argument('--initial-lr', default=0.0, type=float,
                        metavar='LR', help='initial learning rate when using linear rampup')
    parser.add_argument('--lr-rampup', default=0, type=int, metavar='EPOCHS',
                        help='length of learning rate rampup in the beginning')
    parser.add_argument('--lr-rampdown-epochs', default=700, type=int, metavar='EPOCHS',
                        help='length of learning rate cosine rampdown (>= length of training)')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--nesterov', default=True, type=str2bool,
                        help='use nesterov momentum', metavar='BOOL')
    # 1e-4 => 2e-4 according to the paper
    parser.add_argument('--weight-decay', '--wd', default=2e-4, type=float,
                        metavar='W', help='weight decay (default: 2e-4)')
    # 0.999 => 0.97 according to the paper
    parser.add_argument('--ema-decay', default=0.97, type=float, metavar='ALPHA',
                        help='ema variable decay rate (default: 0.97)')

    # --- consistency loss ---------------------------------------------------
    parser.add_argument('--consistency', default=100.0, type=float, metavar='WEIGHT',
                        help='use consistency loss with given weight (default: 100.0)')
    parser.add_argument('--consistency-type', default="mse", type=str, metavar='TYPE',
                        choices=['mse', 'kl'],
                        help='consistency loss type to use')
    parser.add_argument('--consistency-rampup', default=5, type=int, metavar='EPOCHS',
                        help='length of the consistency loss ramp-up')
    parser.add_argument('--logit-distance-cost', default=-1, type=float, metavar='WEIGHT',
                        help='let the student model have two outputs and use an MSE loss between the logits with the given weight (default: only have one output)')

    # --- bookkeeping --------------------------------------------------------
    parser.add_argument('--checkpoint-epochs', default=20, type=int,
                        metavar='EPOCHS', help='checkpoint frequency in epochs, 0 to turn checkpointing off (default: 20)')
    parser.add_argument('--evaluation-epochs', default=1, type=int,
                        metavar='EPOCHS', help='evaluation frequency in epochs, 0 to turn evaluation off (default: 1)')
    # '--print-freq' and '--print_freq' used to be two separate actions writing
    # to the same destination; argparse applied only the first default (10),
    # so the second default (50) never took effect. One action with both
    # spellings keeps both options and the effective default.
    parser.add_argument('--print-freq', '--print_freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', type=str2bool,
                        help='evaluate model on evaluation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')

    # --- noisy-label options ------------------------------------------------
    parser.add_argument('--dataset', type=str, help='mnist, cifar10, or cifar100', default='cifar10')
    parser.add_argument('--noise_rate', type=float, help='corruption rate, should be less than 1', default=0.0)
    parser.add_argument('--forget_rate', type=float, help='forget rate', default=None)
    parser.add_argument('--noise_type', type=str, help='[pairflip, symmetric]', default='symmetric')
    parser.add_argument('--num_gradual', type=int, default=10,
                        help='how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.')
    parser.add_argument('--exponent', type=float, default=1,
                        help='exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.')
    parser.add_argument('--top_bn', action='store_true')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_iter_per_epoch', type=int, default=400)
    parser.add_argument('--epoch_decay_start', type=int, default=80)

    return parser

def parse_commandline_args():
    """Parse the process command line with the parser from ``create_parser()``."""
    cli_parser = create_parser()
    return cli_parser.parse_args()


def parse_dict_args(**kwargs):
    """Turn keyword arguments into command-line tokens and parse them.

    One-character keys become short flags (``-k value``); longer keys become
    long flags with underscores replaced by dashes (``--some-key value``).
    The token list is logged and then parsed by a fresh ``create_parser()``.
    """
    cmdline_args = []
    for name, raw_value in kwargs.items():
        if len(name) == 1:
            flag = "-{}".format(name)
        else:
            flag = "--{}".format(re.sub(r"_", "-", name))
        cmdline_args.append(flag)
        cmdline_args.append(str(raw_value))

    LOG.info("Using these command line args: %s", " ".join(cmdline_args))

    return create_parser().parse_args(cmdline_args)

def str2bool(v):
    """Convert a textual boolean (case-insensitive) to ``True`` / ``False``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised spelling.
    """
    normalized = v.lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str2epochs(v):
    """Parse a comma-separated list of epoch numbers.

    Args:
        v: text such as ``"10,20,30"``; the empty string yields ``[]``.

    Returns:
        The epochs as a list of ints.

    Raises:
        argparse.ArgumentTypeError: if an item is not an integer, or if
            consecutive epochs are not positive and strictly increasing.
    """
    try:
        if len(v) == 0:
            epochs = []
        else:
            epochs = [int(string) for string in v.split(",")]
    except (AttributeError, TypeError, ValueError):
        # Only conversion problems are translated; a bare `except:` would also
        # have swallowed KeyboardInterrupt / SystemExit.
        raise argparse.ArgumentTypeError(
            'Expected comma-separated list of integers, got "{}"'.format(v))
    if not all(0 < epoch1 < epoch2 for epoch1, epoch2 in zip(epochs[:-1], epochs[1:])):
        raise argparse.ArgumentTypeError(
            'Expected the epochs to be listed in increasing order')
    return epochs


In [43]:
# Entry cell: configure logging and populate the module-level `args`.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # parse_known_args returns (namespace, leftover argv) instead of failing on
    # unrecognised arguments; both positional Nones are its defaults.
    # NOTE(review): presumably used so the kernel's own argv does not break
    # parsing inside Jupyter - confirm.
    args, argv = create_parser().parse_known_args(None, None)
#     args = cli.parse_commandline_args()
#     main(RunContext(__file__, 0))

In [9]:
# NOTE(review): leftover debugging cell. `indexes` is not defined at module
# level, so this raises NameError (see the recorded output); candidate for removal.
indexes

NameError: name 'indexes' is not defined

In [28]:
import numpy as np

# Scratch experiment: keep the entries of `b` whose flag in `a` equals 1.
a = [0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
b = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
c = []
for i, value in enumerate(b):
    if a[i] == 1:
        c.append(value)
c

[3, 4, 6, 8, 9, 10]

In [44]:
def train(train_loader, model, optimizer, epoch, log):
    """Train ``model`` for one epoch.

    Epoch 0 is a plain supervised warm-up on every sample of each batch.
    From epoch 1 on, each batch first blends its logits into the running
    ensemble predictions and is then restricted to the samples whose ensemble
    prediction agrees with the given label; loss and accuracy are computed on
    that subset only.

    Args:
        train_loader: iterable yielding ``((input, _), labels, indexes)``;
            ``indexes`` are the dataset indices of the batch samples.
        model: network to optimise.
        optimizer: optimizer whose learning rate is set every step.
        epoch: zero-based epoch number.
        log: object with a ``record(step, dict)`` method for the metrics.

    Notes:
        The previous ``epoch > 0`` branch could not run: it read
        ``filtered_labels``, ``size``, ``loss``, ``labels_update`` and
        ``model_out_update`` before any assignment, tried ``del tensor[i]``,
        and reused the batch counter ``i`` in an inner loop. The extra pass
        over the loader at the end of each epoch discarded its results and
        was dropped.
    """
    global global_step

    meters = AverageMeterSet()
    model.train()

    # reduction='sum' is the current spelling of size_average=False.
    class_criterion = nn.CrossEntropyLoss(reduction='sum').cuda()
    steps_per_epoch = len(train_loader)

    for i, ((input, _), labels, indexes) in enumerate(train_loader):
        adjust_learning_rate(optimizer, epoch, i, steps_per_epoch)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        target_var = labels.cuda(non_blocking=True)
        model_out = model(input)

        if epoch == 0:
            # Warm-up: use the whole batch.
            logits_used, labels_used = model_out, target_var
        else:
            # Detach so the global ensemble buffer never joins the autograd graph.
            update_ensemble_preds(model_out.detach(), indexes)
            clean_idx = filtering(indexes, target_var)
            logits_used = model_out[clean_idx]
            labels_used = target_var[clean_idx]
            counter_labels(labels_used)

        minibatch_size = len(labels_used)
        if minibatch_size == 0:
            # No sample passed the filter; there is nothing to optimise on.
            continue

        class_loss = class_criterion(logits_used, labels_used) / minibatch_size
        meters.update('class_loss', class_loss.item())

        loss = class_loss
        assert not (np.isnan(loss.item()) or loss.item() > 1e5), 'Loss explosion: {}'.format(loss.item())
        meters.update('loss', loss.item())

        prec1, prec5 = accuracy(logits_used.detach(), labels_used, topk=(1, 5))
        meters.update('top1', prec1[0], minibatch_size)
        meters.update('error1', 100. - prec1[0], minibatch_size)
        meters.update('top5', prec5[0], minibatch_size)
        meters.update('error5', 100. - prec5[0], minibatch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1

        if i % args.print_freq == 0:
            LOG.info(
                'Epoch: [{0}][{1}/{2}]\t'
                'Class {meters[class_loss]:.4f}\t'
                'Prec@1 {meters[top1]:.3f}\t'
                'Prec@5 {meters[top5]:.3f}'.format(
                    epoch, i, steps_per_epoch, meters=meters))
            log.record(epoch + i / steps_per_epoch, {
                'step': global_step,
                **meters.values(),
                **meters.averages(),
                **meters.sums()
            })

#     writer.add_scalar('loss', meters['loss'].avg, global_step=global_step)

def validate(eval_loader, model, log, global_step, epoch):
    """Evaluate ``model`` on ``eval_loader`` and return the average top-1 precision.

    Args:
        eval_loader: iterable yielding ``(input, target, _)`` batches.
        model: network to evaluate; switched to eval mode.
        log: object with a ``record(step, dict)`` method for the metrics.
        global_step: step counter stored with the logged metrics (this
            parameter shadows the module-level ``global_step``).
        epoch: value used as the record key in ``log``.

    Returns:
        ``meters['top1'].avg`` - the running average of precision@1.
    """
    # reduction='sum' is the current spelling of size_average=False.
    class_criterion = nn.CrossEntropyLoss(reduction='sum').cuda()
    meters = AverageMeterSet()

    model.eval()

    # torch.no_grad() replaces the removed `volatile=True` Variable flag.
    with torch.no_grad():
        for i, (input, target, _) in enumerate(eval_loader):
            target_var = target.cuda(non_blocking=True)
            minibatch_size = len(target_var)

            output1 = model(input)
            # The criterion applies log-softmax itself, so raw logits go in.
            class_loss = class_criterion(output1, target_var) / minibatch_size

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output1, target_var, topk=(1, 5))
            meters.update('class_loss', class_loss.item(), minibatch_size)
            meters.update('top1', prec1[0], minibatch_size)
            meters.update('error1', 100.0 - prec1[0], minibatch_size)
            meters.update('top5', prec5[0], minibatch_size)
            meters.update('error5', 100.0 - prec5[0], minibatch_size)

            if i % args.print_freq == 0:
                LOG.info(
                    'Test: [{0}/{1}]\t'
                    'Class {meters[class_loss]:.4f}\t'
                    'Prec@1 {meters[top1]:.3f}\t'
                    'Prec@5 {meters[top5]:.3f}'.format(
                        i, len(eval_loader), meters=meters))

        LOG.info(' * Prec@1 {top1.avg:.3f}\tPrec@5 {top5.avg:.3f}'
              .format(top1=meters['top1'], top5=meters['top5']))
        log.record(epoch, {
            'step': global_step,
            **meters.values(),
            **meters.averages(),
            **meters.sums()
        })
#         writer.add_scalar('test_acc', meters['top1'].avg, global_step=global_step)
    return meters['top1'].avg



In [45]:
from datetime import datetime
from collections import defaultdict
import threading
import time
import logging
import os
import pyarrow

from pandas import DataFrame
from collections import defaultdict


class TrainLog:
    """Collects per-step metrics in memory and periodically saves them to disk.

    Records are kept as ``{step: {column: value}}``. ``save()`` converts them
    to a pandas DataFrame, serializes it with pyarrow and writes the bytes to
    ``<directory>/<name>.msgpack`` (the extension is historical; the content
    is pyarrow-serialized, not msgpack).
    """

    # Minimum number of seconds between two saves triggered by recording.
    INCREMENTAL_UPDATE_TIME = 300

    def __init__(self, directory, name):
        self.log_file_path = "{}/{}.msgpack".format(directory, name)
        self._log = defaultdict(dict)
        self._log_lock = threading.RLock()
        # Back-dated so that the very first record triggers a save.
        self._last_update_time = time.time() - self.INCREMENTAL_UPDATE_TIME

    def record_single(self, step, column, value):
        """Store a single ``column: value`` pair for ``step``."""
        self._record(step, {column: value})

    def record(self, step, col_val_dict):
        """Merge ``col_val_dict`` into the record for ``step``."""
        self._record(step, col_val_dict)

    def save(self):
        """Serialize the current log and write it to ``log_file_path``."""
        df = self._as_dataframe()
        # df.to_msgpack(self.log_file_path, compress='zlib')
        # NOTE(review): pyarrow's custom serialization API is deprecated in
        # newer pyarrow releases - confirm against the pinned version.
        context = pyarrow.default_serialization_context()
        payload = context.serialize(df).to_buffer().to_pybytes()
        # The serialized bytes used to be discarded, so nothing was ever
        # saved; persist them to the log file.
        with open(self.log_file_path, 'wb') as log_file:
            log_file.write(payload)

    def _record(self, step, col_val_dict):
        with self._log_lock:
            self._log[step].update(col_val_dict)
            if time.time() - self._last_update_time >= self.INCREMENTAL_UPDATE_TIME:
                self._last_update_time = time.time()
                self.save()

    def _as_dataframe(self):
        with self._log_lock:
            return DataFrame.from_dict(self._log, orient='index')

class RunContext:
    """Creates the result directories for a run and hands out training logs.

    The result directory is ``results/<runner name>/<timestamp>/<run_idx>``
    with a ``transient`` sub-directory, e.g.
    ``results/main_supervised+self/2020-07-16_14:00:57/0``.
    """

    def __init__(self, runner_file, run_idx):
        logging.basicConfig(level=logging.INFO, format='%(message)s')

        # Runner name = file name up to its first dot, e.g. 'main_supervised+self'.
        file_name = os.path.basename(runner_file)
        fields = dict(
            root='results',
            runner_name=file_name.split(".")[0],
            date=datetime.now(),
            run_idx=run_idx,
        )
        self.result_dir = "{root}/{runner_name}/{date:%Y-%m-%d_%H:%M:%S}/{run_idx}".format(**fields)
        self.transient_dir = self.result_dir + "/transient"

        # os.makedirs raises if a directory already exists.
        for directory in (self.result_dir, self.transient_dir):
            os.makedirs(directory)

    def create_train_log(self, name):
        """Return a ``TrainLog`` named ``name`` that lives in the result directory."""
        return TrainLog(self.result_dir, name)

## main

In [46]:
# NOTE(review): `global` statements at module level have no effect; they look
# like leftovers from a former main() function.
global global_step
global best_prec1
global ensemble_preds

# Creates results/<runner name>/<timestamp>/0 and its transient sub-directory.
context = RunContext('main_supervised+self', 0)

training_log = context.create_train_log("training")
validation_log = context.create_train_log("validation")

# Look up the dataset factory by name in mean_teacher.datasets and call it.
# `num_classes` is removed from the config before the remaining entries are
# passed on as keyword arguments.
# NOTE(review): requires `args` to have been populated by an earlier cell.
dataset_config = datasets.__dict__[args.dataset]()
num_classes = dataset_config.pop('num_classes')
train_loader, eval_loader = create_data_loaders(**dataset_config, args=args)

9 10
50000
Actual noise 0.40
[[0.6        0.04444444 0.04444444 0.04444444 0.04444444 0.04444444
  0.04444444 0.04444444 0.04444444 0.04444444]
 [0.04444444 0.6        0.04444444 0.04444444 0.04444444 0.04444444
  0.04444444 0.04444444 0.04444444 0.04444444]
 [0.04444444 0.04444444 0.6        0.04444444 0.04444444 0.04444444
  0.04444444 0.04444444 0.04444444 0.04444444]
 [0.04444444 0.04444444 0.04444444 0.6        0.04444444 0.04444444
  0.04444444 0.04444444 0.04444444 0.04444444]
 [0.04444444 0.04444444 0.04444444 0.04444444 0.6        0.04444444
  0.04444444 0.04444444 0.04444444 0.04444444]
 [0.04444444 0.04444444 0.04444444 0.04444444 0.04444444 0.6
  0.04444444 0.04444444 0.04444444 0.04444444]
 [0.04444444 0.04444444 0.04444444 0.04444444 0.04444444 0.04444444
  0.6        0.04444444 0.04444444 0.04444444]
 [0.04444444 0.04444444 0.04444444 0.04444444 0.04444444 0.04444444
  0.04444444 0.6        0.04444444 0.04444444]
 [0.04444444 0.04444444 0.04444444 0.04444444 0.04444444 0

In [47]:
# Build the network and wrap it for multi-GPU execution.
model = resnet.ResNet34()
model = nn.DataParallel(model).cuda()

LOG.info(parameters_string(model))

# SGD with the momentum / weight-decay / nesterov settings from `args`.
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
# Let cuDNN benchmark and pick convolution algorithms.
cudnn.benchmark = True

INFO:main:
List of model parameters:
module.conv1.weight                            64 * 3 * 3 * 3 =       1,728
module.bn1.weight                                          64 =          64
module.bn1.bias                                            64 =          64
module.layer1.0.conv1.weight                  64 * 64 * 3 * 3 =      36,864
module.layer1.0.bn1.weight                                 64 =          64
module.layer1.0.bn1.bias                                   64 =          64
module.layer1.0.conv2.weight                  64 * 64 * 3 * 3 =      36,864
module.layer1.0.bn2.weight                                 64 =          64
module.layer1.0.bn2.bias                                   64 =          64
module.layer1.1.conv1.weight                  64 * 64 * 3 * 3 =      36,864
module.layer1.1.bn1.weight                                 64 =          64
module.layer1.1.bn1.bias                                   64 =          64
module.layer1.1.conv2.weight                  64 * 

In [48]:
# `epoch` was only assigned in a later cell, so this call failed with a
# NameError on a fresh kernel; define it here before training the first epoch.
epoch = 0
train(train_loader, model, optimizer, epoch, training_log)

INFO:main:Epoch: [0][0/390]	Class 2.4569 (2.4569)	Prec@1 9.375 (9.375)	Prec@5 52.344 (52.344)
INFO:main:Epoch: [0][10/390]	Class 3.5175 (4.6813)	Prec@1 12.500 (9.659)	Prec@5 49.219 (49.148)
INFO:main:Epoch: [0][20/390]	Class 2.3872 (3.6623)	Prec@1 8.594 (10.417)	Prec@5 55.469 (49.888)
INFO:main:Epoch: [0][30/390]	Class 2.3182 (3.2466)	Prec@1 10.938 (10.207)	Prec@5 55.469 (50.630)
INFO:main:Epoch: [0][40/390]	Class 2.3306 (3.0268)	Prec@1 10.938 (10.575)	Prec@5 50.000 (50.934)
INFO:main:Epoch: [0][50/390]	Class 2.3411 (2.8878)	Prec@1 13.281 (10.738)	Prec@5 53.906 (51.670)
INFO:main:Epoch: [0][60/390]	Class 2.3157 (2.7921)	Prec@1 12.500 (11.053)	Prec@5 56.250 (52.049)
INFO:main:Epoch: [0][70/390]	Class 2.3817 (2.7249)	Prec@1 11.719 (11.466)	Prec@5 49.219 (52.740)
INFO:main:Epoch: [0][80/390]	Class 2.2927 (2.6696)	Prec@1 15.625 (11.834)	Prec@5 55.469 (53.877)
INFO:main:Epoch: [0][90/390]	Class 2.2649 (2.6266)	Prec@1 21.875 (12.354)	Prec@5 62.500 (54.636)
INFO:main:Epoch: [0][100/390]	Class

In [49]:
# Evaluate after epoch 0 when the evaluation frequency says so.
epoch = 0
if args.evaluation_epochs and (epoch + 1) % args.evaluation_epochs == 0:
    LOG.info("Evaluating the model:")
    # The record key passed to validate() is the 1-based epoch number.
    prec1 = validate(eval_loader, model, validation_log, global_step, epoch+1)

INFO:main:Evaluating the model:
INFO:main:Test: [0/79]	Class 2.0580 (2.0580)	Prec@1 27.344 (27.344)	Prec@5 78.906 (78.906)
INFO:main:Test: [10/79]	Class 2.0436 (2.0536)	Prec@1 25.781 (26.776)	Prec@5 80.469 (79.332)
INFO:main:Test: [20/79]	Class 2.0688 (2.0510)	Prec@1 23.438 (26.376)	Prec@5 78.125 (79.167)
INFO:main:Test: [30/79]	Class 2.0170 (2.0453)	Prec@1 27.344 (27.092)	Prec@5 78.125 (79.183)
INFO:main:Test: [40/79]	Class 2.0287 (2.0409)	Prec@1 28.125 (27.134)	Prec@5 75.781 (79.402)
INFO:main:Test: [50/79]	Class 2.0927 (2.0432)	Prec@1 19.531 (27.206)	Prec@5 72.656 (78.952)
INFO:main:Test: [60/79]	Class 1.9894 (2.0406)	Prec@1 31.250 (27.228)	Prec@5 85.938 (79.278)
INFO:main:Test: [70/79]	Class 2.0722 (2.0379)	Prec@1 24.219 (27.212)	Prec@5 75.000 (79.335)
INFO:main: * Prec@1 27.200	Prec@5 79.540


In [98]:
if args.evaluate:
    # Evaluation-only mode. A bare `return` is a SyntaxError outside a
    # function, so the training loop lives in the `else` branch instead.
    LOG.info("Evaluating the primary model:")
    validate(eval_loader, model, validation_log, global_step, args.start_epoch)
else:
    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, model, optimizer, epoch, training_log)

        if args.evaluation_epochs and (epoch + 1) % args.evaluation_epochs == 0:
            LOG.info("Evaluating the model:")
            prec1 = validate(eval_loader, model, validation_log, global_step, epoch+1)

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       

In [127]:
# Reset the filtering state to zeros (same shapes as in the first cell).
# NOTE(review): duplicates the initialisation at the top of the notebook;
# re-running this cell discards the accumulated ensemble predictions.
ensemble_preds = torch.Tensor(np.zeros((50000,10))).cuda()
how_many_labels_filtered = torch.Tensor(np.zeros((10))).cuda()

In [124]:
def counter_labels(clean_idxs):
    """Accumulate per-class counts into the global ``how_many_labels_filtered``.

    Args:
        clean_idxs: tensor/array of class ids to tally. Despite the name, the
            caller in this notebook passes the *labels* of the samples that
            survived filtering, not their positions.

    The number of classes is taken from the length of the global counter
    instead of being hard-coded to 10. Values outside ``[0, num_classes)``
    are ignored.
    """
    global how_many_labels_filtered

    # Flatten first so a 0-d input still produces a list.
    kept = clean_idxs.reshape(-1).tolist()
    for cls in range(len(how_many_labels_filtered)):
        how_many_labels_filtered[cls] += kept.count(cls)

#     return how_many_labels_filtered

def filtering(indexes, labels):
    """Return minibatch positions whose ensemble prediction matches the label.

    Args:
        indexes: row ids into the global ``ensemble_preds`` buffer, one per
            sample of the minibatch.
        labels: class id per sample, aligned with ``indexes``.

    Returns:
        1-D CPU LongTensor of positions (0 .. len(indexes)-1) where
        ``argmax(ensemble_preds[indexes], dim=1) == labels``. Empty when no
        sample agrees.

    The previous version used ``np.argwhere(...)[0]``. That only yields the
    full index list if argwhere returns a (1, N) layout; with the standard
    (N, 1) layout it returns just the first match and raises IndexError when
    nothing matches. Staying in torch removes that dependency.
    """
    global ensemble_preds

    # detach(): only the values are needed, not the autograd history.
    predicted = ensemble_preds[indexes].detach().argmax(dim=1).cpu()
    agrees = predicted == labels.cpu()
    # nonzero() gives shape (k, 1); flatten to a plain list of positions.
    return torch.nonzero(agrees).reshape(-1)

def update_ensemble_preds(logit1, indexes, alpha=0.99):
    """Exponential-moving-average update of the global ``ensemble_preds`` rows.

    Args:
        logit1: model outputs for the minibatch, one row per entry of
            ``indexes``.
        indexes: row ids into ``ensemble_preds`` for the minibatch samples.
        alpha: weight kept from the previous value (default 0.99, the value
            that used to be hard-coded).

    The logits are detached before being written. Without that, the buffer
    picks up a ``grad_fn`` and keeps accumulating autograd history from the
    batches written into it.
    """
    global ensemble_preds

    ensemble_preds[indexes] = alpha * ensemble_preds[indexes] + (1 - alpha) * logit1.detach()

def loss_by_filtered_samples(logit1, labels, indexes):
    """Classification loss restricted to samples the ensemble agrees with.

    Steps: update the running ensemble predictions with this minibatch, keep
    the positions where the ensemble's argmax equals the given label, tally
    the kept labels per class, and compute mean cross-entropy over the kept
    samples.

    Args:
        logit1: model outputs for the minibatch.
        labels: class ids for the minibatch.
        indexes: row ids of the minibatch samples in ``ensemble_preds``.

    Returns:
        (class_loss, logit1_update, labels_update): the scalar loss plus the
        kept logits and labels.

    Previously ``class_loss`` was returned without ever being assigned,
    because its computation was commented out. When no sample is kept, the
    loss is a zero that is still connected to ``logit1``, so ``backward()``
    works and the mean over zero samples does not produce NaN.
    """
    global how_many_labels_filtered

    update_ensemble_preds(logit1, indexes)
    clean_idx = filtering(indexes, labels)

    logit1_update = logit1[clean_idx]
    labels_update = labels[clean_idx]
    counter_labels(labels_update)

    if labels_update.numel() == 0:
        class_loss = logit1.sum() * 0.0
    else:
        # Mean reduction == summed loss divided by the number of kept samples.
        class_loss = F.cross_entropy(logit1_update, labels_update.to(logit1_update.device))

    return class_loss, logit1_update, labels_update

In [129]:
# Run one training pass by hand with the epoch index fixed to 1.
# NOTE(review): train_loader, model, optimizer, training_log, eval_loader,
# validation_log and args must already exist in the kernel from earlier cells.
epoch = 1

train(train_loader, model, optimizer, epoch, training_log)

# Validate only when evaluation is enabled and (epoch + 1) is a multiple of
# args.evaluation_epochs.
if args.evaluation_epochs and (epoch + 1) % args.evaluation_epochs == 0:
    LOG.info("Evaluating the model:")
    prec1 = validate(eval_loader, model, validation_log, global_step, epoch+1)

INFO:main:Epoch: [1][0/390]	Class 1.4875 (1.4875)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][10/390]	Class 0.0943 (0.4985)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][20/390]	Class 0.1103 (0.2829)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][30/390]	Class 0.0236 (0.1997)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][40/390]	Class 0.0052 (0.1561)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][50/390]	Class 0.0127 (0.1283)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][60/390]	Class 0.0077 (0.1113)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][70/390]	Class 0.0100 (0.0966)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][80/390]	Class 0.0228 (0.0867)	Prec@1 100.000 (100.000)	Prec@5 100.000 (100.000)
INFO:main:Epoch: [1][90/390]	Class 0.0003 (0.0779)	Prec@1 100.000 (100.000)	Prec@5 100.000 (

In [126]:
# Debug peek: rows of the running prediction buffer selected by `indexes`.
# NOTE(review): `indexes` is not assigned by any visible cell; presumably it
# is left over in the kernel from a training step. Confirm before relying on
# this under Restart & Run All.
ensemble_preds[indexes]

tensor([[ 0.0043,  0.0101, -0.0080,  ..., -0.0066,  0.0083,  0.0114],
        [ 0.0014,  0.0094, -0.0057,  ..., -0.0054,  0.0050,  0.0081],
        [-0.0060, -0.0022,  0.0014,  ..., -0.0020, -0.0051, -0.0010],
        ...,
        [-0.0077, -0.0072,  0.0037,  ...,  0.0017, -0.0087, -0.0048],
        [-0.0060, -0.0004, -0.0002,  ..., -0.0019, -0.0045,  0.0003],
        [-0.0072, -0.0057,  0.0023,  ..., -0.0013, -0.0072, -0.0046]],
       device='cuda:0', grad_fn=<IndexBackward>)

In [130]:
# Debug peek: display the current contents of `indexes`.
# NOTE(review): this name is not assigned by any visible cell; presumably it
# holds the sample ids of the last minibatch. Confirm where it comes from.
indexes

tensor([32743,  8250, 31858, 19383, 41196,  7054, 48020,  8557, 34022, 29341,
        38399, 22226, 44515, 34395,  3632, 41024, 39514,  1840, 32524, 24091,
        47342, 39475, 39291, 18804, 26030, 44980, 20492, 39384, 43416,  2575,
         1347, 10859, 11452,  3277, 34050, 47797, 28936, 32067, 32960, 12385,
        16871, 33949, 33086, 47683, 26134, 28527,  9977, 25147, 32906, 16865,
        42593, 11241, 27625,  7066, 48149,  1760, 21059, 33762, 47180, 31280,
        30291, 13692, 26787, 41059, 30957, 46348, 48303, 12305, 35750,  4301,
        44994, 26290, 15899,  9853, 24717, 25914, 30066, 47994,  5769, 49577,
         1761, 26136,  1102, 37889, 30093,  3354, 16862,  4116,  1435, 44375,
        47254, 16819,  7368, 38798, 19646, 39492, 41857, 47842, 49149, 29390,
         3890, 19391, 24556,  2670, 45707,   646, 44451, 30504, 22877, 38470,
        46557, 39190,  9380,  7325,   656, 25804, 36983, 35469, 14846, 30568,
        22274, 17851, 42821, 33971, 37306,  9060, 31706, 10835])

In [126]:
# Debug peek: rows of the running prediction buffer selected by `indexes`.
# NOTE(review): `indexes` is not assigned by any visible cell; presumably it
# is left over in the kernel from a training step. Confirm before relying on
# this under Restart & Run All.
ensemble_preds[indexes]

tensor([[ 0.0043,  0.0101, -0.0080,  ..., -0.0066,  0.0083,  0.0114],
        [ 0.0014,  0.0094, -0.0057,  ..., -0.0054,  0.0050,  0.0081],
        [-0.0060, -0.0022,  0.0014,  ..., -0.0020, -0.0051, -0.0010],
        ...,
        [-0.0077, -0.0072,  0.0037,  ...,  0.0017, -0.0087, -0.0048],
        [-0.0060, -0.0004, -0.0002,  ..., -0.0019, -0.0045,  0.0003],
        [-0.0072, -0.0057,  0.0023,  ..., -0.0013, -0.0072, -0.0046]],
       device='cuda:0', grad_fn=<IndexBackward>)

In [None]:
def counter_labels(clean_idxs):
    """Accumulate per-class counts into the global ``how_many_labels_filtered``.

    Args:
        clean_idxs: tensor/array of class ids to tally. Despite the name, the
            caller in this notebook passes the *labels* of the samples that
            survived filtering, not their positions.

    The number of classes is taken from the length of the global counter
    instead of being hard-coded to 10. Values outside ``[0, num_classes)``
    are ignored.
    """
    global how_many_labels_filtered

    # Flatten first so a 0-d input still produces a list.
    kept = clean_idxs.reshape(-1).tolist()
    for cls in range(len(how_many_labels_filtered)):
        how_many_labels_filtered[cls] += kept.count(cls)

def filtering(indexes, labels):
    """Find minibatch samples whose ensemble prediction matches their label.

    Args:
        indexes: row ids into the global ``ensemble_preds`` buffer, one per
            sample of the minibatch.
        labels: class id per sample, aligned with ``indexes``.

    Returns:
        1-D CPU LongTensor of minibatch positions where
        ``argmax(ensemble_preds[indexes], dim=1) == labels``. Empty when no
        sample agrees.

    Side effect: sets ``filtered_labels[sample_id] = 1`` for every agreeing
    sample. The key is the sample's id from ``indexes``, not its position
    inside the minibatch, because ``filtering_`` looks the flag up by the
    values of ``indexes``.

    The previous version used ``np.argwhere(...)[0]``. That only yields the
    full index list if argwhere returns a (1, N) layout; with the standard
    (N, 1) layout it returns just the first match and raises IndexError when
    nothing matches.
    """
    global ensemble_preds
    global filtered_labels

    # detach(): only the values are needed, not the autograd history.
    predicted = ensemble_preds[indexes].detach().argmax(dim=1).cpu()
    clean_idxs = torch.nonzero(predicted == labels.cpu()).reshape(-1)

    # Translate minibatch positions back to sample ids before flagging.
    for sample_id in torch.as_tensor(indexes)[clean_idxs].reshape(-1).tolist():
        filtered_labels[sample_id] = 1
    return clean_idxs

def filtering_(indexes, labels):
    """Keep only the sample ids already flagged in the global ``filtered_labels``.

    Args:
        indexes: iterable of sample ids (tensor or list).
        labels: unused; kept so the signature stays as it was.

    Returns:
        List of the ids from ``indexes`` (as Python ints, original order)
        for which ``filtered_labels[id] == 1``.

    Fixes over the previous draft: the body was over-indented (an
    IndentationError), it appended ``indexes[component]`` although
    ``component`` is already the element of ``indexes``, and the result was
    only bound to a local name so callers never received it.
    """
    global filtered_labels

    filtered = []
    for component in indexes:
        sample_id = int(component)
        if filtered_labels[sample_id] == 1:
            filtered.append(sample_id)
    return filtered

def update_ensemble_preds(logit1, indexes, alpha=0.99):
    """Exponential-moving-average update of the global ``ensemble_preds`` rows.

    Args:
        logit1: model outputs for the minibatch, one row per entry of
            ``indexes``.
        indexes: row ids into ``ensemble_preds`` for the minibatch samples.
        alpha: weight kept from the previous value (default 0.99, the value
            that used to be hard-coded).

    The logits are detached before being written. Without that, the buffer
    picks up a ``grad_fn`` and keeps accumulating autograd history from the
    batches written into it.
    """
    global ensemble_preds

    ensemble_preds[indexes] = alpha * ensemble_preds[indexes] + (1 - alpha) * logit1.detach()

def loss_by_filtered_samples(logit1, labels, indexes):
    """Update the ensemble with this minibatch and return the labels it keeps.

    Calls ``update_ensemble_preds`` with the minibatch logits, selects the
    kept positions with ``filtering``, tallies the kept labels through
    ``counter_labels`` and returns them. Unlike the earlier variant of this
    function, no loss is computed here.
    """
    global how_many_labels_filtered

    update_ensemble_preds(logit1, indexes)

    kept_positions = filtering(indexes, labels)
    kept_labels = labels[kept_positions]
    counter_labels(kept_labels)

    return kept_labels

In [45]:
# Scratch experiment: drop the zeros from a list while walking it by index,
# compensating for the shrinking list with a running deletion count.
a = [0, 1, 1, 0, 1, 0, 0]
num = 0  # elements removed so far
for i in range(len(a)):
    i = i - num  # shift the original position left by the deletions so far
    print(i, a[i])
    if a[i] != 0:
        continue
    del a[i]
    num = num + 1
a

0 0
0 1
1 1
2 0
2 1
3 0
3 0


[1, 1, 1]

In [41]:
# Scratch check: `and` returns its first falsy operand, so this evaluates to
# 0 (not False) and 4 is never evaluated.
0 and 4

0