In [1]:
import argparse
import collections
import sklearn.metrics as metrics
import tensorboardX as tb
import torch as th
import torch.nn.functional as F
import torch.nn.modules.loss as loss
import torch.optim as optim
import torch.utils as utils
import data
import my
import lenet
import resnet

In [2]:
args = argparse.Namespace()
args.batch_size = 50
args.gpu = 1
args.log_every = 1000
args.n_iterations = 50000

keys = sorted(vars(args).keys())
excluded = ('gpu', 'log_every', 'n_iterations')
run_id = 'cifar10-9-1-ce' + '-'.join('%s-%s' % (key, str(getattr(args, key))) for key in keys if key not in excluded)
writer = tb.SummaryWriter('runs/' + run_id)

In [3]:
if args.gpu < 0:
    cuda = False
else:
    cuda = True
    th.cuda.set_device(args.gpu)

labelling = {(0, 9) : 0, (9, 10) : 1}
# labelling = {(0, 5) : 0, (5, 6) : 1, (6, 7) : 2, (7, 8) : 3, (8, 9) : 4, (9, 10) : 5}
train_x, train_y, test_x, test_y = data.load_cifar10(labelling, rbg=True, torch=True)

train_set = utils.data.TensorDataset(train_x, train_y)
train_loader = utils.data.DataLoader(train_set, 4096, drop_last=False)
test_set = utils.data.TensorDataset(test_x, test_y)
test_loader = utils.data.DataLoader(test_set, 4096, drop_last=False)

loader = data.BalancedDataLoader(train_x, train_y, args.batch_size, cuda)

n_classes = int(train_y.max() - train_y.min() + 1)

In [4]:
def global_scores(c, loader, average='binary'):
    keys = ('accuracy', 'precision', 'recall', 'f1')
    scores = (
        metrics.accuracy_score,
        lambda y, y_bar: metrics.precision_score(y, y_bar, average=average),
        lambda y, y_bar: metrics.recall_score(y, y_bar, average=average),
        lambda y, y_bar: metrics.f1_score(y, y_bar, average=average),
    )
    values = [value.item() for value in my.global_scores(c, loader, scores)]
    return collections.OrderedDict(zip(keys, values))

In [5]:
# c = lenet.LeNet(3, n_classes)
c = resnet.ResNet(18, n_classes)

if cuda:
    c.cuda()
    
# optimizer = optim.SGD(c.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
optimizer = optim.Adam(c.parameters(), amsgrad=True)

for key, value in global_scores(c, test_loader).items():
    print(key, value)

accuracy 0.9001
precision 0.001
recall 1.0
f1 0.0019980019980019984


In [6]:
n_iterations = 0
for i in range(args.n_iterations):
    x, y = next(loader)
    ce = loss.CrossEntropyLoss()(c(x), y)
    optimizer.zero_grad()
    ce.backward()
    optimizer.step()

    if (i + 1) % args.log_every == 0:
        train_scores = global_scores(c, train_loader)
        test_scores = global_scores(c, test_loader)

        prefix = '0' * (len(str(args.n_iterations)) - len(str(i + 1)))
        print('[iteration %s%d]' % (prefix, i + 1) + \
              ' | '.join('%s %0.3f/%0.3f' % (key, value, test_scores[key]) for key, value in train_scores.items()))

        for key, value in train_scores.items():
            writer.add_scalar('train-' + key, value, i)

        for key, value in test_scores.items():
            writer.add_scalar('test-' + key, value, i)

[iteration 01000]accuracy 0.952/0.948 | precision 0.636/0.615 | recall 0.849/0.820 | f1 0.727/0.703
[iteration 02000]accuracy 0.966/0.960 | precision 0.785/0.772 | recall 0.858/0.820 | f1 0.820/0.795
[iteration 03000]accuracy 0.968/0.962 | precision 0.777/0.735 | recall 0.890/0.862 | f1 0.830/0.793
[iteration 04000]accuracy 0.975/0.965 | precision 0.845/0.785 | recall 0.895/0.849 | f1 0.869/0.816
[iteration 05000]accuracy 0.981/0.971 | precision 0.877/0.808 | recall 0.929/0.891 | f1 0.902/0.847
[iteration 06000]accuracy 0.985/0.972 | precision 0.894/0.820 | recall 0.949/0.891 | f1 0.920/0.854
[iteration 07000]accuracy 0.985/0.971 | precision 0.905/0.815 | recall 0.944/0.889 | f1 0.924/0.850
[iteration 08000]accuracy 0.990/0.972 | precision 0.953/0.853 | recall 0.951/0.867 | f1 0.952/0.860
[iteration 09000]accuracy 0.991/0.974 | precision 0.954/0.856 | recall 0.956/0.878 | f1 0.955/0.867
[iteration 10000]accuracy 0.991/0.972 | precision 0.947/0.833 | recall 0.962/0.881 | f1 0.954/0.857


In [7]:
print(my.global_scores(c, train_loader, metrics.classification_report))
print(my.global_scores(c, test_loader, metrics.classification_report))

             precision    recall  f1-score   support

          0       1.00      1.00      1.00     45000
          1       1.00      1.00      1.00      5000

avg / total       1.00      1.00      1.00     50000

             precision    recall  f1-score   support

          0       0.99      0.98      0.99      9061
          1       0.86      0.92      0.89       939

avg / total       0.98      0.98      0.98     10000

