In [1]:
from functools import partial

import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from config import config
from datasets import prepare_data_loaders
from models.utils import seed_fn
from models.modules import EMAWeightOptimizer, ContentBCE, SimpleNet
from main import train_cycle

In [2]:
# Seed RNGs first so model initialisation and loader shuffling are reproducible.
seed_fn()

# --- Run configuration ------------------------------------------------------
unsup_weight = 3            # passed to train_cycle as the weight of the unsupervised loss term
learning_rate = 0.001       # peak LR of the one-cycle schedule (OneCycleLR max_lr)
num_epoch = 30              # single source of truth for run length (loop AND scheduler)
weight_decay = 5e-3         # Adam weight decay for the student
ema_alpha = 0.99            # alpha handed to EMAWeightOptimizer for the teacher update
model_name = 'simplenet_arcface_relu'
logdir = 'logs/' + model_name   # derived from model_name so the two cannot drift apart

# --- Data --------------------------------------------------------------------
# A single split name yields one loader; a list of splits yields a dict keyed by split.
mixed_train_loaders = prepare_data_loaders(config['mixed'], 'train')
svhn_loaders = prepare_data_loaders(config['svhn'], ['valid', 'train'])
mnist_valid_loader = prepare_data_loaders(config['mnist'], 'valid')
loaders = {
    'train': mixed_train_loaders,
    'source_train': svhn_loaders['train'],
    'source_valid': svhn_loaders['valid'],
    'target_valid': mnist_valid_loader,
}

# --- Models ------------------------------------------------------------------
pre_model = partial(SimpleNet, num_classes=10)

student_model = pre_model().cuda()
teacher_model = pre_model().cuda()

student_params = list(student_model.parameters())
teacher_params = list(teacher_model.parameters())

# The teacher is never trained by backprop; its weights are updated only by
# the EMA optimizer below, so exclude it from autograd.
for param in teacher_params:
    param.requires_grad = False

# --- Optimisation ------------------------------------------------------------
student_optimizer = optim.Adam(student_params, lr=learning_rate, weight_decay=weight_decay)
teacher_optimizer = EMAWeightOptimizer(teacher_model, student_model, alpha=ema_alpha)

# OneCycleLR raises once it is stepped more than epochs * steps_per_epoch times,
# so both values must come from the real training setup rather than literals.
# NOTE(review): assumes the scheduler is stepped once per training batch inside
# train_cycle and that the mixed train loader supports len() (the recorded
# progress bar shows 235 batches/epoch, matching the previous hard-coded value).
steps_per_epoch = len(mixed_train_loaders)
scheduler = optim.lr_scheduler.OneCycleLR(
    student_optimizer,
    max_lr=learning_rate,
    epochs=num_epoch,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.3,
    div_factor=10,
)

# --- Losses ------------------------------------------------------------------
supervised_criterion = nn.CrossEntropyLoss()
# NOTE(review): the meaning of 0.9 is defined by models.modules.ContentBCE — confirm.
unsupervised_criterion = ContentBCE(0.9)

In [3]:
# Run the full training loop; train_cycle receives a TensorBoard writer that
# logs under `logdir`.
summary_writer = SummaryWriter(logdir)
try:
    train_cycle(num_epoch, loaders, model_name, student_model, teacher_model,
                student_optimizer, teacher_optimizer,
                supervised_criterion, unsupervised_criterion,
                scheduler=scheduler,
                metrics=config['metrics'],
                summary_writer=summary_writer,
                unsup_weight=unsup_weight)
finally:
    # Always flush buffered events and release the event file, even if training
    # is interrupted; re-running this cell otherwise leaks an open writer each time.
    summary_writer.close()

Start epoch 0


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 1.842
accuracy: 0.38614166666666666


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 1.777
accuracy: 0.44676959198438376


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 1.711
accuracy: 0.4958896742470805


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 1.827
accuracy: 0.3852
Previous best score:  0 Current score:  0.3852

Start epoch 1


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.842
accuracy: 0.8017833333333333


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.585
accuracy: 0.8529969832234462


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.536
accuracy: 0.875


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.990
accuracy: 0.7106
Previous best score:  0.3852 Current score:  0.7106

Start epoch 2


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.530
accuracy: 0.8742833333333333


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.320
accuracy: 0.906411673969723


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.276
accuracy: 0.9215580823601721


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.618
accuracy: 0.8151
Previous best score:  0.7106 Current score:  0.8151

Start epoch 3


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.456
accuracy: 0.89325


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.265
accuracy: 0.9231199748829464


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.224
accuracy: 0.9355792870313461


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.397
accuracy: 0.8822
Previous best score:  0.8151 Current score:  0.8822

Start epoch 4


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.419
accuracy: 0.9057833333333334


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.239
accuracy: 0.9321976056895587


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.199
accuracy: 0.9438767670559312


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.232
accuracy: 0.9337
Previous best score:  0.8822 Current score:  0.9337

Start epoch 5


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.401
accuracy: 0.9103083333333334


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.223
accuracy: 0.9370572095499407


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.186
accuracy: 0.9478718500307314


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.153
accuracy: 0.9569
Previous best score:  0.9337 Current score:  0.9569

Start epoch 6


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.373
accuracy: 0.9154083333333334


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.208
accuracy: 0.9414799950857938


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.174
accuracy: 0.9524431468961279


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.102
accuracy: 0.9688
Previous best score:  0.9569 Current score:  0.9688

Start epoch 7


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.357
accuracy: 0.91945


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.199
accuracy: 0.9439507487339094


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.166
accuracy: 0.9548248309772588


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.081
accuracy: 0.9762
Previous best score:  0.9688 Current score:  0.9762

Start epoch 8


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.345
accuracy: 0.9217416666666667


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.199
accuracy: 0.9435139304093807


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.167
accuracy: 0.9534035033804549


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.065
accuracy: 0.9811
Previous best score:  0.9762 Current score:  0.9811

Start epoch 9


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.339
accuracy: 0.9232166666666667


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.191
accuracy: 0.947131332159384


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.159
accuracy: 0.9559772587584512


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.054
accuracy: 0.9842
Previous best score:  0.9811 Current score:  0.9842

Start epoch 10


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.319
accuracy: 0.92765


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.192
accuracy: 0.9463259483735342


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.161
accuracy: 0.9562845728334358


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.051
accuracy: 0.9846
Previous best score:  0.9842 Current score:  0.9846

Start epoch 11


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.317
accuracy: 0.9264916666666667


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.185
accuracy: 0.9483598836971211


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.154
accuracy: 0.9573985863552551


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.044
accuracy: 0.9867
Previous best score:  0.9846 Current score:  0.9867

Start epoch 12


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.312
accuracy: 0.92775


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.184
accuracy: 0.9486465457225931


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.156
accuracy: 0.9574370006146281


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.042
accuracy: 0.9871
Previous best score:  0.9867 Current score:  0.9871

Start epoch 13


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.308
accuracy: 0.9296583333333334


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.179
accuracy: 0.9497931938244809


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.154
accuracy: 0.9587046711739398


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.037
accuracy: 0.9883
Previous best score:  0.9871 Current score:  0.9883

Start epoch 14


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.300
accuracy: 0.930575


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.177
accuracy: 0.9506395293282554


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.151
accuracy: 0.9602028272894899


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.035
accuracy: 0.9883
Previous best score:  0.9883 Current score:  0.9883

Start epoch 15


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.303
accuracy: 0.9299166666666666


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.172
accuracy: 0.952277598045238


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.148
accuracy: 0.9597418561770129


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.033
accuracy: 0.9893
Previous best score:  0.9883 Current score:  0.9893

Start epoch 16


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.290
accuracy: 0.9337916666666667


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.174
accuracy: 0.9518953820112753


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.149
accuracy: 0.9601259987707437


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.032
accuracy: 0.9889
Previous best score:  0.9893 Current score:  0.9889

Start epoch 17


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.297
accuracy: 0.9316


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.173
accuracy: 0.9528236209508989


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.150
accuracy: 0.9595881991395205


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.031
accuracy: 0.9897
Previous best score:  0.9893 Current score:  0.9897

Start epoch 18


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.287
accuracy: 0.9350833333333334


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.165
accuracy: 0.9548302551292027


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.140
accuracy: 0.9625460971112477


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.030
accuracy: 0.9904
Previous best score:  0.9897 Current score:  0.9904

Start epoch 19


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.285
accuracy: 0.9349583333333333


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.173
accuracy: 0.9516223705584449


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.148
accuracy: 0.9590119852489244


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.029
accuracy: 0.9904
Previous best score:  0.9904 Current score:  0.9904

Start epoch 20


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.287
accuracy: 0.9331916666666666


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.162
accuracy: 0.9557038917782601


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.140
accuracy: 0.96269975414874


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.029
accuracy: 0.9906
Previous best score:  0.9904 Current score:  0.9906

Start epoch 21


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.284
accuracy: 0.9346666666666666


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.163
accuracy: 0.9554991331886372


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.145
accuracy: 0.9611631837738168


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.028
accuracy: 0.991
Previous best score:  0.9906 Current score:  0.991

Start epoch 22


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.272
accuracy: 0.9389166666666666


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.166
accuracy: 0.9541340759244851


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.144
accuracy: 0.9619314689612785


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.028
accuracy: 0.9902
Previous best score:  0.991 Current score:  0.9902

Start epoch 23


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.286
accuracy: 0.9342416666666666


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.172
accuracy: 0.9521820440367473


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.147
accuracy: 0.959319299323909


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.027
accuracy: 0.9916
Previous best score:  0.991 Current score:  0.9916

Start epoch 24


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.279
accuracy: 0.9371166666666667


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.174
accuracy: 0.9517998280027847


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.149
accuracy: 0.9592808850645359


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.026
accuracy: 0.992
Previous best score:  0.9916 Current score:  0.992

Start epoch 25


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.275
accuracy: 0.938125


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.163
accuracy: 0.9559223009405244


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.143
accuracy: 0.9611631837738168


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.026
accuracy: 0.9919
Previous best score:  0.992 Current score:  0.9919

Start epoch 26


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.272
accuracy: 0.93965


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.160
accuracy: 0.955362627462222


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.139
accuracy: 0.9628918254456054


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.025
accuracy: 0.992
Previous best score:  0.992 Current score:  0.992

Start epoch 27


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.267
accuracy: 0.9408583333333334


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.160
accuracy: 0.9557994457867508


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.141
accuracy: 0.9617778119237861


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.025
accuracy: 0.9927
Previous best score:  0.992 Current score:  0.9927

Start epoch 28


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.270
accuracy: 0.9397166666666666


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.179
accuracy: 0.9499433501235377


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.152
accuracy: 0.9584357713583282


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.025
accuracy: 0.9925
Previous best score:  0.9927 Current score:  0.9925

Start epoch 29


HBox(children=(IntProgress(value=0, max=235), HTML(value='')))


train loss: 0.266
accuracy: 0.9412166666666667


HBox(children=(IntProgress(value=0, max=153), HTML(value='')))


source_train loss: 0.170
accuracy: 0.9534242461471258


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


source_valid loss: 0.149
accuracy: 0.9586278426551936


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))


target_valid loss: 0.024
accuracy: 0.9929
Previous best score:  0.9927 Current score:  0.9929

