In [1]:
import argparse
import collections
import sklearn.metrics as metrics
import tensorboardX as tb
import torch as th
import torch.nn.functional as F
import torch.nn.modules.loss as loss
import torch.optim as optim
import torch.utils as utils
import data
import my
import lenet
import resnet

In [2]:
# Experiment configuration. A plain Namespace is used in place of parsed CLI
# arguments so that the values can be edited directly in the notebook.
args = argparse.Namespace()
args.batch_size = 50       # batch size handed to data.BalancedDataLoader
args.gpu = 1               # CUDA device index; a negative value selects the CPU path
args.log_every = 1000      # evaluate and log once every this many iterations
args.n_iterations = 50000  # total number of optimisation steps

# Build the run identifier from every setting that is not in `excluded`,
# in sorted key order so the identifier is stable across runs.
keys = sorted(vars(args).keys())
excluded = ('gpu', 'log_every', 'n_iterations')
# The fixed prefix is joined together with the key/value parts so that it is
# separated from the first key by '-' (previously the two were fused, giving
# 'cifar10-9-1-cebatch_size-50'), and no dangling '-' appears when every key
# is excluded.
run_id = '-'.join(
    ['cifar10-9-1-ce'] +
    ['%s-%s' % (key, getattr(args, key)) for key in keys if key not in excluded]
)
# Scalars written during training end up under runs/<run_id>.
writer = tb.SummaryWriter('runs/' + run_id)

In [3]:
# A non-negative args.gpu enables CUDA and selects that device; a negative
# value leaves everything on the CPU.
cuda = args.gpu >= 0
if cuda:
    th.cuda.set_device(args.gpu)

# Label mapping passed to data.load_cifar10.
# NOTE(review): presumably each (start, stop) key is a range of original
# CIFAR-10 class ids collapsed into the mapped label -- confirm in data.py.
# Alternative six-way labelling kept for reference:
# labelling = {(0, 5) : 0, (5, 6) : 1, (6, 7) : 2, (7, 8) : 3, (8, 9) : 4, (9, 10) : 5}
labelling = {(0, 9) : 0, (9, 10) : 1}
# NOTE(review): the keyword is spelled `rbg` here; verify that this matches
# the parameter name declared by data.load_cifar10.
train_x, train_y, test_x, test_y = data.load_cifar10(labelling, rbg=True, torch=True)

# Datasets and large-batch loaders that are used only for evaluation.
train_set = utils.data.TensorDataset(train_x, train_y)
test_set = utils.data.TensorDataset(test_x, test_y)
train_loader = utils.data.DataLoader(train_set, batch_size=4096, drop_last=False)
test_loader = utils.data.DataLoader(test_set, batch_size=4096, drop_last=False)

# Source of training mini-batches; consumed with next(loader) in the training loop.
loader = data.BalancedDataLoader(train_x, train_y, args.batch_size, cuda)

# Number of classes derived from the span of label values in the training set.
n_classes = int(train_y.max() - train_y.min() + 1)

In [4]:
def global_scores(c, loader):
    """Score classifier `c` on `loader` and return the results by name.

    Four scorers are handed to `my.global_scores` as a tuple: plain accuracy,
    plus precision, recall and F1 computed with ``average='macro'``.
    NOTE(review): `my.global_scores` is defined outside this notebook; it is
    presumably responsible for running `c` over `loader` and calling each
    scorer on (labels, predictions) -- confirm in my.py.

    Args:
        c: the model to evaluate (forwarded untouched to `my.global_scores`).
        loader: the data loader to evaluate on (forwarded untouched).

    Returns:
        collections.OrderedDict mapping 'accuracy', 'precision', 'recall' and
        'f1' (in that order) to the `.item()` of the corresponding result.
    """
    def macro(metric):
        # Wrap a sklearn metric so that it is always called with average='macro'.
        return lambda y, y_bar: metric(y, y_bar, average='macro')

    scorers = (
        metrics.accuracy_score,
        macro(metrics.precision_score),
        macro(metrics.recall_score),
        macro(metrics.f1_score),
    )

    # Every returned result must expose .item(); unwrap all of them up front.
    unwrapped = []
    for result in my.global_scores(c, loader, scorers):
        unwrapped.append(result.item())

    names = ('accuracy', 'precision', 'recall', 'f1')
    return collections.OrderedDict(zip(names, unwrapped))

In [5]:
# Classifier under training. A LeNet variant is kept as an alternative:
# c = lenet.LeNet(3, n_classes)
c = resnet.ResNet(18, n_classes)
if cuda:
    c.cuda()

# Adam with the AMSGrad variant enabled and otherwise default settings.
# The previously used SGD configuration is kept for reference:
# optimizer = optim.SGD(c.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
optimizer = optim.Adam(c.parameters(), amsgrad=True)

# Baseline: test-set scores of the freshly initialised model, before any training.
initial_scores = global_scores(c, test_loader)
for name in initial_scores:
    print(name, initial_scores[name])

accuracy 0.1
precision 0.5
recall 0.05
f1 0.09090909090909091


  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


In [6]:
# NOTE(review): this counter is never read or updated in this cell; it is kept
# only in case a cell outside this view refers to it.
n_iterations = 0

# The loss module is stateless across steps, so build it once rather than
# constructing a new instance on every iteration.
criterion = loss.CrossEntropyLoss()

# Width used to zero-pad the printed iteration counter (e.g. 01000 of 50000).
counter_width = len(str(args.n_iterations))

for i in range(args.n_iterations):
    # One optimisation step on the next mini-batch from the training loader.
    x, y = next(loader)
    ce = criterion(c(x), y)
    optimizer.zero_grad()
    ce.backward()
    optimizer.step()

    # Every args.log_every steps: score the model on the full train and test
    # sets, print a train/test summary line and write the scalars to TensorBoard.
    if (i + 1) % args.log_every == 0:
        train_scores = global_scores(c, train_loader)
        test_scores = global_scores(c, test_loader)

        summary = ' | '.join(
            '%s %0.3f/%0.3f' % (key, value, test_scores[key])
            for key, value in train_scores.items()
        )
        print('[iteration %0*d]' % (counter_width, i + 1) + summary)

        # Scalars are logged at the zero-based step i, whereas the printed
        # counter above is one-based (i + 1).
        for key, value in train_scores.items():
            writer.add_scalar('train-' + key, value, i)

        for key, value in test_scores.items():
            writer.add_scalar('test-' + key, value, i)

[iteration 01000]accuracy 0.948/0.948 | precision 0.795/0.795 | recall 0.895/0.893 | f1 0.836/0.835
[iteration 02000]accuracy 0.965/0.961 | precision 0.897/0.882 | recall 0.907/0.898 | f1 0.902/0.889
[iteration 03000]accuracy 0.973/0.965 | precision 0.890/0.863 | recall 0.952/0.934 | f1 0.918/0.894
[iteration 04000]accuracy 0.978/0.970 | precision 0.931/0.907 | recall 0.943/0.924 | f1 0.937/0.915
[iteration 05000]accuracy 0.981/0.972 | precision 0.939/0.906 | recall 0.956/0.936 | f1 0.948/0.921
[iteration 06000]accuracy 0.984/0.974 | precision 0.952/0.916 | recall 0.959/0.935 | f1 0.955/0.925
[iteration 07000]accuracy 0.988/0.976 | precision 0.960/0.920 | recall 0.974/0.946 | f1 0.967/0.933
[iteration 08000]accuracy 0.989/0.974 | precision 0.955/0.897 | recall 0.983/0.955 | f1 0.969/0.924
[iteration 09000]accuracy 0.991/0.976 | precision 0.967/0.909 | recall 0.985/0.952 | f1 0.976/0.929
[iteration 10000]accuracy 0.994/0.977 | precision 0.978/0.911 | recall 0.989/0.955 | f1 0.983/0.932


In [7]:
# Print sklearn's classification report for the train split first, then the
# test split, using the same helper that computes the summary scores.
for split_loader in (train_loader, test_loader):
    print(my.global_scores(c, split_loader, metrics.classification_report))

             precision    recall  f1-score   support

          0       1.00      1.00      1.00     45000
          1       1.00      1.00      1.00      5000

avg / total       1.00      1.00      1.00     50000

             precision    recall  f1-score   support

          0       0.99      0.98      0.99      9084
          1       0.85      0.93      0.89       916

avg / total       0.98      0.98      0.98     10000

