In [1]:
import argparse
import collections
import sklearn.metrics as metrics
import tensorboardX as tb
import torch as th
import torch.nn.functional as F
import torch.nn.modules.loss as loss
import torch.optim as optim
import torch.utils as utils
import data
import my
import lenet
import resnet

In [2]:
args = argparse.Namespace()
args.batch_size = 500
args.gpu = 1
args.log_every = 1000
args.n_iterations = 10000

keys = sorted(vars(args).keys())
excluded = ('gpu', 'log_every', 'n_iterations')
run_id = 'cifar100-relabelled-ce' + '-'.join('%s-%s' % (key, str(getattr(args, key))) for key in keys if key not in excluded)
writer = tb.SummaryWriter('runs/' + run_id)

In [3]:
if args.gpu < 0:
    cuda = False
else:
    cuda = True
    th.cuda.set_device(args.gpu)

labelling = {(0, 99) : 0, (99, 100) : 1}
train_x, train_y, test_x, test_y = data.load_cifar100(labelling, rbg=True, torch=True)

train_set = utils.data.TensorDataset(train_x, train_y)
train_loader = utils.data.DataLoader(train_set, 4096, drop_last=False)
test_set = utils.data.TensorDataset(test_x, test_y)
test_loader = utils.data.DataLoader(test_set, 4096, drop_last=False)

loader = data.BalancedDataLoader(train_x, train_y, args.batch_size, cuda)

n_classes = int(train_y.max() - train_y.min() + 1)

In [4]:
def global_scores(c, loader):
    keys = ('accuracy', 'precision', 'recall', 'f1')
    scores = (
        metrics.accuracy_score,
        lambda y, y_bar: metrics.precision_score(y, y_bar, average='macro'),
        lambda y, y_bar: metrics.recall_score(y, y_bar, average='macro'),
        lambda y, y_bar: metrics.f1_score(y, y_bar, average='macro'),
    )
    values = [value.item() for value in my.global_scores(c, loader, scores)]
    return collections.OrderedDict(zip(keys, values))

In [5]:
# c = lenet.LeNet(3, n_classes)
c = resnet.ResNet(18, n_classes)

if cuda:
    c.cuda()
    
# optimizer = optim.SGD(c.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
optimizer = optim.Adam(c.parameters(), amsgrad=True)

for key, value in global_scores(c, test_loader).items():
    print(key, value)

accuracy 0.99
precision 0.5
recall 0.495
f1 0.49748743718592964


  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


In [6]:
n_iterations = 0
for i in range(args.n_iterations):
    x, y = next(loader)
    ce = loss.CrossEntropyLoss()(c(x), y)
    optimizer.zero_grad()
    ce.backward()
    optimizer.step()

    if (i + 1) % args.log_every == 0:
        train_scores = global_scores(c, train_loader)
        test_scores = global_scores(c, test_loader)

        prefix = '0' * (len(str(args.n_iterations)) - len(str(i + 1)))
        print('[iteration %s%d]' % (prefix, i + 1) + \
              ' | '.join('%s %0.3f/%0.3f' % (key, value, test_scores[key]) for key, value in train_scores.items()))

        for key, value in train_scores.items():
            writer.add_scalar('train-' + key, value, i)

        for key, value in test_scores.items():
            writer.add_scalar('test-' + key, value, i)

[iteration 01000]accuracy 0.992/0.992 | precision 0.639/0.624 | recall 0.878/0.853 | f1 0.702/0.683
[iteration 02000]accuracy 0.999/0.992 | precision 0.943/0.704 | recall 0.988/0.844 | f1 0.965/0.756
[iteration 03000]accuracy 1.000/0.993 | precision 0.998/0.724 | recall 0.998/0.879 | f1 0.998/0.781
[iteration 04000]accuracy 1.000/0.993 | precision 1.000/0.739 | recall 0.998/0.878 | f1 0.999/0.793
[iteration 05000]accuracy 1.000/0.993 | precision 0.999/0.719 | recall 0.999/0.876 | f1 0.999/0.777
[iteration 06000]accuracy 1.000/0.993 | precision 0.999/0.734 | recall 1.000/0.876 | f1 0.999/0.788
[iteration 07000]accuracy 1.000/0.993 | precision 0.999/0.729 | recall 0.999/0.874 | f1 0.999/0.784
[iteration 08000]accuracy 1.000/0.993 | precision 0.999/0.724 | recall 0.999/0.879 | f1 0.999/0.781
[iteration 09000]accuracy 1.000/0.993 | precision 1.000/0.719 | recall 0.998/0.883 | f1 0.999/0.779
[iteration 10000]accuracy 1.000/0.993 | precision 1.000/0.719 | recall 0.998/0.883 | f1 0.999/0.779


In [7]:
print(my.global_scores(c, train_loader, metrics.classification_report))
print(my.global_scores(c, test_loader, metrics.classification_report))

             precision    recall  f1-score   support

          0       1.00      1.00      1.00     49498
          1       1.00      1.00      1.00       502

avg / total       1.00      1.00      1.00     50000

             precision    recall  f1-score   support

          0       1.00      0.99      1.00      9943
          1       0.44      0.77      0.56        57

avg / total       1.00      0.99      0.99     10000

