<a href="https://colab.research.google.com/github/CalculatedContent/ww-phys_theory/blob/master/random_labels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:

"""
cifar-10 dataset, with support for random labels
"""
import numpy as np

import torch
import torchvision.datasets as datasets


class CIFAR10RandomLabels(datasets.CIFAR10):
  """CIFAR10 dataset, with support for randomly corrupt labels.

  Params
  ------
  corrupt_prob: float
    Default 0.0. The probability of a label being replaced with
    random label.
  num_classes: int
    Default 10. The number of classes in the dataset.
  **kwargs
    Forwarded unchanged to the ``datasets.CIFAR10`` constructor.
  """
  def __init__(self, corrupt_prob=0.0, num_classes=10, **kwargs):
    super(CIFAR10RandomLabels, self).__init__(**kwargs)
    self.n_classes = num_classes
    if corrupt_prob > 0:
      self.corrupt_labels(corrupt_prob)

  def _label_attr(self):
    """Return the name of the attribute that holds this split's labels.

    Newer torchvision releases keep the labels of either split in
    ``targets``; older ones used ``train_labels`` / ``test_labels``.
    Supporting both keeps label corruption working across versions.
    """
    if hasattr(self, 'targets'):
      return 'targets'
    return 'train_labels' if self.train else 'test_labels'

  def corrupt_labels(self, corrupt_prob):
    """Replace each label, with probability ``corrupt_prob``, by a random class.

    The replacement class is drawn uniformly from ``range(self.n_classes)``
    and may coincide with the original label. The labels are written back
    in place of the dataset's existing label list.
    """
    attr = self._label_attr()
    labels = np.array(getattr(self, attr))
    # Fixed seed: the same samples get the same random labels on every run.
    # NOTE: this reseeds numpy's *global* RNG as a side effect.
    np.random.seed(12345)
    mask = np.random.rand(len(labels)) <= corrupt_prob
    rnd_labels = np.random.choice(self.n_classes, mask.sum())
    labels[mask] = rnd_labels
    # we need to explicitly cast the labels from npy.int64 to
    # builtin int type, otherwise pytorch will fail...
    setattr(self, attr, [int(x) for x in labels])

In [0]:
# Wide Resnet model adapted from https://github.com/xternalz/WideResNet-pytorch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
  """Residual unit made of two BatchNorm -> ReLU -> 3x3 conv stages.

  When ``in_planes != out_planes`` the skip path goes through a 1x1
  convolution (with the block's stride) applied to the *activated* input;
  otherwise the raw input is added back unchanged.

  Params
  ------
  in_planes: int
    Number of input channels.
  out_planes: int
    Number of output channels.
  stride: int
    Stride of the first 3x3 convolution (and of the 1x1 shortcut, if any).
  dropRate: float
    Default 0.0. Dropout probability applied between the two convolutions.
  """
  def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
    super(BasicBlock, self).__init__()
    # First stage: normalise + activate the input, then (possibly strided) conv.
    self.bn1 = nn.BatchNorm2d(in_planes)
    self.relu1 = nn.ReLU(inplace=True)
    self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                           stride=stride, padding=1, bias=False)
    # Second stage: always stride 1, channel count unchanged.
    self.bn2 = nn.BatchNorm2d(out_planes)
    self.relu2 = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                           stride=1, padding=1, bias=False)
    self.droprate = dropRate
    self.equalInOut = (in_planes == out_planes)
    # A projection shortcut is only created when the channel count changes.
    if self.equalInOut:
      self.convShortcut = None
    else:
      self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                    stride=stride, padding=0, bias=False)

  def forward(self, x):
    """Return ``shortcut(x) + residual(x)`` for the input feature map ``x``."""
    activated = self.relu1(self.bn1(x))
    out = self.conv1(activated)

    if self.droprate > 0:
      out = F.dropout(out, p=self.droprate, training=self.training)
    out = self.conv2(self.relu2(self.bn2(out)))

    # Identity skip uses the raw input; projection skip uses the activated one.
    if self.equalInOut:
      shortcut = x
    else:
      shortcut = self.convShortcut(activated)
    return torch.add(shortcut, out)



In [0]:
class NetworkBlock(nn.Module):
  """A stack of ``nb_layers`` residual blocks applied one after another.

  Only the first block sees ``in_planes`` input channels and the requested
  ``stride``; every following block maps ``out_planes -> out_planes`` with
  stride 1.

  Params
  ------
  nb_layers: int
    Number of blocks in the stack.
  in_planes: int
    Input channels of the first block.
  out_planes: int
    Output channels of every block.
  block: callable
    Block constructor, called as ``block(in_planes, out_planes, stride, dropRate)``.
  stride: int
    Stride of the first block.
  dropRate: float
    Default 0.0. Passed to every block.
  """
  def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
    super(NetworkBlock, self).__init__()
    self.layer = self._make_layer(
        block, in_planes, out_planes, nb_layers, stride, dropRate)

  def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
    """Build the ``nn.Sequential`` holding the ``nb_layers`` blocks."""
    layers = []
    for i in range(nb_layers):
      is_first = (i == 0)
      # Real conditional expressions: the old ``cond and a or b`` form
      # silently swapped in the fallback whenever ``a`` was falsy (e.g. 0).
      layers.append(block(in_planes if is_first else out_planes,
                          out_planes,
                          stride if is_first else 1,
                          dropRate))
    return nn.Sequential(*layers)

  def forward(self, x):
    """Run ``x`` through all blocks in order."""
    return self.layer(x)


In [0]:
class WideResNet(nn.Module):
  """Wide residual network: a stem conv, three NetworkBlock stages, BN/ReLU,
  average pooling and a linear classifier.

  Params
  ------
  depth: int
    Total depth; ``(depth - 4)`` must be divisible by 6 (asserted).
  num_classes: int
    Size of the classifier output.
  widen_factor: int
    Default 1. Multiplier for the stage widths 16/32/64.
  drop_rate: float
    Default 0.0. Dropout probability passed to every block.
  init_scale: float
    Default 1.0. Multiplier on the standard deviation used to initialise
    convolution and linear weights.
  """
  def __init__(self, depth, num_classes, widen_factor=1, drop_rate=0.0, init_scale=1.0):
    super(WideResNet, self).__init__()

    assert (depth - 4) % 6 == 0
    blocks_per_stage = (depth - 4) // 6
    widths = [16] + [base * widen_factor for base in (16, 32, 64)]
    block = BasicBlock

    # Stem convolution, applied before any residual stage.
    self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1,
                           padding=1, bias=False)
    # Three residual stages; the 2nd and 3rd halve the spatial resolution.
    self.block1 = NetworkBlock(
        blocks_per_stage, widths[0], widths[1], block, 1, drop_rate)
    self.block2 = NetworkBlock(
        blocks_per_stage, widths[1], widths[2], block, 2, drop_rate)
    self.block3 = NetworkBlock(
        blocks_per_stage, widths[2], widths[3], block, 2, drop_rate)
    # Final normalisation/activation and the classifier head.
    self.bn1 = nn.BatchNorm2d(widths[3])
    self.relu = nn.ReLU(inplace=True)
    self.fc = nn.Linear(widths[3], num_classes)
    self.nChannels = widths[3]

    self._init_weights(init_scale)

  def _init_weights(self, init_scale):
    """Initialise all conv / batch-norm / linear parameters in module order."""
    for module in self.modules():
      if isinstance(module, nn.Conv2d):
        # Normal init with std sqrt(2 / (kh * kw * out_channels)), scaled.
        kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
        fan = kernel_h * kernel_w * module.out_channels
        module.weight.data.normal_(0, init_scale * math.sqrt(2. / fan))
      elif isinstance(module, nn.BatchNorm2d):
        module.weight.data.fill_(1)
        module.bias.data.zero_()
      elif isinstance(module, nn.Linear):
        module.bias.data.zero_()

        # Rows are outputs, columns are inputs of the linear layer.
        fan_out, fan_in = module.weight.size()[0], module.weight.size()[1]
        std = math.sqrt(2.0 / (fan_in + fan_out))
        module.weight.data.normal_(0.0, init_scale * std)

  def forward(self, x):
    """Return the class scores for the input batch ``x``."""
    features = self.forward_repr(x)
    return self.fc(features)

  def forward_repr(self, x):
    """Return the pooled feature vector that is fed to the classifier.

    The pooling window is fixed at 8, which covers the whole feature map
    only for 32x32 inputs (assumption based on the two stride-2 stages).
    """
    out = self.block3(self.block2(self.block1(self.conv1(x))))
    out = self.relu(self.bn1(out))
    out = F.avg_pool2d(out, 8)
    return out.view(-1, self.nChannels)

In [0]:
from __future__ import print_function

import os
import logging
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.optim



def get_data_loaders(args, shuffle_train=True):
  """Build the training and validation ``DataLoader`` pair for ``args.data``.

  Reads ``args.data``, ``args.data_augmentation``, ``args.num_classes``,
  ``args.label_corrupt_prob`` and ``args.batch_size``. Only ``'cifar10'``
  is supported; anything else raises ``Exception``. The same label
  corruption probability is applied to both splits.

  Returns a ``(train_loader, val_loader)`` tuple.
  """
  if args.data != 'cifar10':
    raise Exception('Unsupported dataset: {0}'.format(args.data))

  # Per-channel statistics given on the 0-255 scale, rescaled to 0-1.
  channel_mean = [x/255.0 for x in [125.3, 123.0, 113.9]]
  channel_std = [x / 255.0 for x in [63.0, 62.1, 66.7]]
  normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

  # Optional augmentation goes in front of the tensor conversion.
  train_steps = []
  if args.data_augmentation:
    train_steps.append(transforms.RandomCrop(32, padding=4))
    train_steps.append(transforms.RandomHorizontalFlip())
  train_steps.extend([transforms.ToTensor(), normalize])
  transform_train = transforms.Compose(train_steps)

  transform_test = transforms.Compose([transforms.ToTensor(), normalize])

  loader_options = {'num_workers': 1, 'pin_memory': True}

  train_set = CIFAR10RandomLabels(
      root='./data', train=True, download=True,
      transform=transform_train, num_classes=args.num_classes,
      corrupt_prob=args.label_corrupt_prob)
  train_loader = torch.utils.data.DataLoader(
      train_set, batch_size=args.batch_size, shuffle=shuffle_train,
      **loader_options)

  # The validation split is never downloaded here and never shuffled.
  val_set = CIFAR10RandomLabels(
      root='./data', train=False,
      transform=transform_test, num_classes=args.num_classes,
      corrupt_prob=args.label_corrupt_prob)
  val_loader = torch.utils.data.DataLoader(
      val_set, batch_size=args.batch_size, shuffle=False,
      **loader_options)

  return train_loader, val_loader


def get_model(args):
  """Create the network selected by ``args.arch``.

  * ``'wide-resnet'``: a ``WideResNet`` built from ``args.wrn_depth``,
    ``args.num_classes``, ``args.wrn_widen_factor`` and ``args.wrn_droprate``.
  * ``'mlp'``: an ``MLP`` whose layer sizes are the flattened input size
    (32*32*3), the ``'x'``-separated hidden sizes in ``args.mlp_spec`` and
    ``args.num_classes``; wrapped in ``DataParallel`` and moved to the GPU.

  Raises ``ValueError`` for any other architecture name (previously this
  fell through and failed with an ``UnboundLocalError``).
  """
  if args.arch == 'wide-resnet':
    model = WideResNet(args.wrn_depth, args.num_classes,
                       args.wrn_widen_factor,
                       drop_rate=args.wrn_droprate)
  elif args.arch == 'mlp':
    hidden_units = [int(x) for x in args.mlp_spec.split('x')]
    # input dim, hidden dims, output dim
    n_units = [32*32*3] + hidden_units + [args.num_classes]
    model = MLP(n_units)

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # NOTE(review): as in the original, the GPU wrapping only happens on the
    # mlp branch; it looks like it may have been meant for every architecture
    # -- confirm before moving it, since that changes where the model lives.
    model = torch.nn.DataParallel(model).cuda()
  else:
    raise ValueError('Unsupported architecture: {0}'.format(args.arch))

  return model

  


In [0]:
def train_model(args, model, train_loader, val_loader,
                start_epoch=None, epochs=None):
  """Train ``model`` with SGD and log per-epoch accuracy and loss.

  Params
  ------
  args
    Namespace providing ``learning_rate``, ``momentum``, ``weight_decay``,
    ``epochs`` and ``eval_full_trainset``.
  model
    The network to optimise.
  train_loader, val_loader
    Iterables of ``(input, target)`` batches.
  start_epoch: int or None
    First epoch index; ``None`` means 0.
  epochs: int or None
    Exclusive upper bound of the epoch range; ``None`` means ``args.epochs``.
    An explicit 0 is honoured and runs no epochs.

  Returns nothing; results are reported through ``logging.info``.
  """
  cudnn.benchmark = True

  # define loss function (criterion) and optimizer
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

  # ``is None`` rather than ``or``: a caller-supplied 0 must not be replaced.
  if start_epoch is None:
    start_epoch = 0
  if epochs is None:
    epochs = args.epochs

  for epoch in range(start_epoch, epochs):
    adjust_learning_rate(optimizer, epoch, args)

    # train for one epoch
    tr_loss, tr_prec1 = train_epoch(train_loader, model, criterion, optimizer, epoch, args)

    # evaluate on validation set
    val_loss, val_prec1 = validate_epoch(val_loader, model, criterion, epoch, args)

    # Optionally replace the running training statistics with a fresh pass
    # over the training data using the model as it stands after this epoch.
    if args.eval_full_trainset:
      tr_loss, tr_prec1 = validate_epoch(train_loader, model, criterion, epoch, args)

    logging.info('%03d: Acc-tr: %6.2f, Acc-val: %6.2f, L-tr: %6.4f, L-val: %6.4f',
                 epoch, tr_prec1, val_prec1, tr_loss, val_loss)



In [0]:
def train_epoch(train_loader, model, criterion, optimizer, epoch, args):
  """Train for one epoch on the training set.

  Puts ``model`` in training mode and performs one optimizer step per batch
  yielded by ``train_loader``. ``epoch`` and ``args`` are accepted for
  signature parity with the other epoch helpers and are not used here.

  Returns ``(average loss, average top-1 precision in percent)``, both
  sample-weighted running averages over the batches seen during the epoch.
  """
  losses = AverageMeter()
  top1 = AverageMeter()

  # switch to train mode
  model.train()

  # Batches are used on whatever device the loader yields them; no explicit
  # GPU transfer is done here.
  for inputs, target in train_loader:
    # compute output
    output = model(inputs)
    loss = criterion(output, target)

    # measure accuracy and record loss (detached: no gradient needed here)
    prec1 = accuracy(output.detach(), target, topk=(1,))[0]
    batch_size = inputs.size(0)
    losses.update(loss.item(), batch_size)
    top1.update(prec1.item(), batch_size)

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  return losses.avg, top1.avg


In [0]:
def validate_epoch(val_loader, model, criterion, epoch, args):
  """Perform validation on the validation set.

  Puts ``model`` in evaluation mode and scores every batch yielded by
  ``val_loader`` without updating any parameters. ``epoch`` and ``args``
  are accepted for signature parity and are not used here.

  Returns ``(average loss, average top-1 precision in percent)``, both
  sample-weighted over all batches.
  """
  losses = AverageMeter()
  top1 = AverageMeter()

  # switch to evaluate mode
  model.eval()

  # ``volatile=True`` no longer has any effect; ``torch.no_grad()`` is the
  # supported way to skip autograd bookkeeping during evaluation.
  with torch.no_grad():
    for inputs, target in val_loader:
      # compute output
      output = model(inputs)
      loss = criterion(output, target)

      # measure accuracy and record loss
      prec1 = accuracy(output, target, topk=(1,))[0]
      batch_size = inputs.size(0)
      losses.update(loss.item(), batch_size)
      top1.update(prec1.item(), batch_size)

  return losses.avg, top1.avg


In [0]:
class AverageMeter(object):
  """Tracks the latest value and the weighted running mean of a quantity.

  Attributes: ``val`` (last value passed in), ``sum`` (weighted total),
  ``count`` (total weight) and ``avg`` (``sum / count``).
  """
  def __init__(self):
    self.reset()

  def reset(self):
    """Zero every statistic."""
    self.val = self.avg = self.sum = self.count = 0

  def update(self, val, n=1):
    """Record ``val`` as observed ``n`` times and refresh the mean."""
    self.val = val
    weighted = val * n
    self.sum += weighted
    self.count += n
    self.avg = self.sum / self.count

In [0]:
def adjust_learning_rate(optimizer, epoch, args):
  """Apply the step schedule: ``args.learning_rate`` times one factor of 0.1
  per completed 150 epochs and one more per completed 225 epochs.

  The resulting rate is written to the ``'lr'`` entry of every parameter
  group of ``optimizer``.
  """
  decay_per_150 = 0.1 ** (epoch // 150)
  decay_per_225 = 0.1 ** (epoch // 225)
  new_lr = args.learning_rate * decay_per_150 * decay_per_225
  for group in optimizer.param_groups:
    group['lr'] = new_lr

In [0]:
def accuracy(output, target, topk=(1,)):
  """Computes the precision@k for the specified values of k.

  Params
  ------
  output
    Score tensor indexed as (sample, class); ``topk`` is taken along dim 1.
  target
    Tensor of true class indices, one per sample.
  topk: tuple of int
    The k values to report.

  Returns a list with one 0-dim tensor per k, holding the percentage of
  samples whose true class is among the k highest-scoring classes.
  """
  maxk = max(topk)
  batch_size = target.size(0)

  # Indices of the maxk best classes per sample, transposed to (maxk, batch)
  # so that row j holds every sample's (j+1)-th best guess.
  _, pred = output.topk(maxk, 1, True, True)
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    # ``reshape`` instead of ``view``: for maxk > 1 the slice is not
    # contiguous and ``view(-1)`` raises on it.
    correct_k = correct[:k].reshape(-1).float().sum(0)
    res.append(correct_k.mul_(100.0 / batch_size))
  return res

In [0]:
def setup_logging(args):
  """Send log records to a dated file under ``runs/<args.exp_name>`` and to the console.

  The experiment directory is created if missing. The file is opened in
  write mode and receives DEBUG and above; the added console handler shows
  INFO and above. Finally the directory being logged into is printed.
  """
  import datetime

  exp_dir = os.path.join('runs', args.exp_name)
  if not os.path.isdir(exp_dir):
    os.makedirs(exp_dir)

  # One log file per calendar day, named with a two-digit year/month/day stamp.
  date_stamp = datetime.date.today().strftime("%y%m%d")
  log_fn = os.path.join(exp_dir, "LOG.{0}.txt".format(date_stamp))
  logging.basicConfig(filename=log_fn, filemode='w', level=logging.DEBUG)

  # also log into console
  console = logging.StreamHandler()
  console.setLevel(logging.INFO)
  root_logger = logging.getLogger('')
  root_logger.addHandler(console)

  print('Logging into %s...' % exp_dir)

In [0]:
import argparse

def _str2bool(value):
  """Parse a command-line boolean.

  ``type=bool`` would call ``bool()`` on the raw string, which is True for
  every non-empty string -- including ``'False'`` and ``'0'``. This accepts
  the usual spellings instead and rejects anything else.
  """
  if isinstance(value, bool):
    return value
  text = value.strip().lower()
  if text in ('1', 'true', 't', 'yes', 'y', 'on'):
    return True
  # The empty string stays falsy, as it was under ``bool('')``.
  if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
    return False
  raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()

parser.add_argument('--eval-full-trainset', type=_str2bool, default=True,
                    help='Whether to re-evaluate the full train set on a fixed model, or simply ' +
                    'report the running average of training statistics')

# Task and data options.
parser.add_argument('--command', default='train', choices=['train'])
parser.add_argument('--data', default='cifar10', choices=['cifar10'])
parser.add_argument('--num-classes', type=int, default=10)
parser.add_argument('--data-augmentation', type=_str2bool, default=False)
parser.add_argument('--label-corrupt-prob', type=float, default=0.0)

# Optimisation options.
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--learning-rate', type=float, default=0.01)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight-decay', type=float, default=1e-4)

# Architecture options.
parser.add_argument('--arch', default='wide-resnet', choices=['wide-resnet', 'mlp'])

parser.add_argument('--wrn-depth', type=int, default=28)
parser.add_argument('--wrn-widen-factor', type=int, default=1)
parser.add_argument('--wrn-droprate', type=float, default=0.0)

parser.add_argument('--mlp-spec', default='512',
                    help='mlp spec: e.g. 512x128x512 indicates 3 hidden layers')

parser.add_argument('--name', default='', help='Experiment name')


def format_experiment_name(args):
  """Compose the experiment name from the run's settings.

  Pieces, in order: optional ``args.name`` prefix, dataset, optional label
  corruption tag, architecture with its size spec, learning rate and
  momentum, weight-decay tag, and ``_NoAug`` when augmentation is off.
  """
  pieces = []

  if args.name != '':
    pieces.append(args.name + '_')

  pieces.append(args.data + '_')
  if args.label_corrupt_prob > 0:
    pieces.append('corrupt%g_' % args.label_corrupt_prob)

  pieces.append(args.arch)
  if args.arch == 'wide-resnet':
    # Dropout only shows up in the name when it is actually enabled.
    if args.wrn_droprate == 0:
      dropmark = ''
    else:
      dropmark = '-dr%g' % args.wrn_droprate
    pieces.append('{0}-{1}{2}'.format(args.wrn_depth, args.wrn_widen_factor, dropmark))
  elif args.arch == 'mlp':
    pieces.append(args.mlp_spec)

  pieces.append('_lr{0}_mmt{1}'.format(args.learning_rate, args.momentum))
  if args.weight_decay > 0:
    pieces.append('_Wd{0}'.format(args.weight_decay))
  else:
    pieces.append('_NoWd')
  if not args.data_augmentation:
    pieces.append('_NoAug')

  return ''.join(pieces)


In [14]:
# Parse an empty argument string so that every option keeps its default
# (the notebook has no real command line), then display the namespace.
args = parser.parse_args("")
args

Namespace(arch='wide-resnet', batch_size=16, command='train', data='cifar10', data_augmentation=False, epochs=300, eval_full_trainset=True, label_corrupt_prob=0.0, learning_rate=0.01, mlp_spec='512', momentum=0.9, name='', num_classes=10, weight_decay=0.0001, wrn_depth=28, wrn_droprate=0.0, wrn_widen_factor=1)

In [0]:
# Derive the experiment name from the parsed settings and store it on args.
args.exp_name = format_experiment_name(args)


In [16]:
 # Configure logging for this experiment, using args.exp_name set above.
 setup_logging(args)


Logging into runs/cifar10_wide-resnet28-1_lr0.01_mmt0.9_Wd0.0001_NoAug...


In [0]:
# Build the network selected by args.arch.
model = get_model(args)

In [18]:
# Create the data loaders; the training loader is shuffled.
train_loader, val_loader = get_data_loaders(args, shuffle_train=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data


In [0]:
# Run training with the default epoch range; one log line is emitted per epoch.
train_model(args, model, train_loader, val_loader)

  del sys.path[0]
  
000: Acc-tr:  64.58, Acc-val:  63.59, L-tr: 0.9940, L-val: 1.0248
001: Acc-tr:  73.01, Acc-val:  70.62, L-tr: 0.7791, L-val: 0.8576
002: Acc-tr:  79.62, Acc-val:  76.31, L-tr: 0.5867, L-val: 0.6949
003: Acc-tr:  83.24, Acc-val:  79.28, L-tr: 0.4840, L-val: 0.6050
004: Acc-tr:  84.97, Acc-val:  80.20, L-tr: 0.4381, L-val: 0.5886
005: Acc-tr:  86.47, Acc-val:  80.76, L-tr: 0.3887, L-val: 0.5665
006: Acc-tr:  88.66, Acc-val:  82.35, L-tr: 0.3275, L-val: 0.5388


In [0]:
  # nn.Module has no ``save_state_dict`` method; serialise the parameter
  # dictionary with ``torch.save`` instead.
  torch.save(model.state_dict(), 'model.pt')
