In [1]:
import os
import shutil
from collections import Counter
import numpy as np
import torch
import torch.utils.data
import pandas as pd
import argparse
import re
from helpers import makedir
import model
import save
from log import create_logger
import train_and_test as tnt
from helpers import makedir
import find_nearest
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import push

from preprocess import mean, std, preprocess_input_function
from settings import train_dir, test_dir, train_push_dir
from settings import img_size, num_classes, prototype_activation_function, add_on_layers_type

from bounding_box_metrics import bounding_box_overlap
from find_nearest import find_k_nearest_patches_to_prototypes
from settings import coefs



In [8]:
'''
Model initialization for vgg base; do not run this block when creating other models.

Builds a VGG19-backed ProtoPNet, wraps it in DataParallel, and creates the
warm-up, joint and last-layer optimizers plus the joint LR schedule used by
the training loop. Run exactly one of the model-initialization cells.
'''

base_architecture = 'vgg19'
prototype_shape = (2000, 128, 1, 1)

ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)

ppnet = ppnet.cuda()
# Wrap exactly once. ppnet_multi.module is ppnet itself, so the optimizers below
# (built from ppnet's parameters) update the weights used by ppnet_multi.
ppnet_multi = torch.nn.DataParallel(ppnet).cuda()
class_specific = True

# Warm-up stage optimizer: only the add-on layers and the prototype vectors
# are in its parameter groups.
warm_optimizer_lrs = {'add_on_layers': 1e-3,
                      'prototype_vectors': 1e-3}
warm_optimizer_specs = [
    {'params': ppnet.add_on_layers.parameters(), 'lr': warm_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
    {'params': ppnet.prototype_vectors, 'lr': warm_optimizer_lrs['prototype_vectors']},
]
warm_optimizer = torch.optim.Adam(warm_optimizer_specs)

# Joint stage optimizer: the backbone features are fine-tuned with a 10x
# smaller learning rate than the add-on layers and prototypes.
joint_optimizer_lrs = {'features': 1e-4,
                       'add_on_layers': 1e-3,
                       'prototype_vectors': 1e-3}
joint_optimizer_specs = [
    {'params': ppnet.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3},  # bias are now also being regularized
    {'params': ppnet.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
    {'params': ppnet.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)

# Multiply the joint learning rates by `gamma` every `joint_lr_step_size`
# scheduler steps. Pass the named constant so editing it takes effect.
joint_lr_step_size = 5
gamma = 0.5
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=gamma)

# Last-layer stage optimizer: only ppnet.last_layer is in its parameter group.
last_layer_optimizer_lr = 1e-4
last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

train_batch_size, test_batch_size, train_push_batch_size = 128, 128, 128
# Epochs at which the training loop pushes prototypes. push_saved_epochs is
# checked by the loop inside those push epochs.
push_epochs = [20, 40, 60, 99]
push_saved_epochs = [20, 40, 60, 99]

model_dir = './ppnet_results/007_vgg19/'
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))

In [4]:
'''
Model initialization for resnet base; do not run this block when creating other models.

Builds a ResNet34-backed ProtoPNet, wraps it in DataParallel, and creates the
warm-up, joint and last-layer optimizers plus the joint LR schedule used by
the training loop. Run exactly one of the model-initialization cells.
'''

base_architecture = 'resnet34'
prototype_shape = (2000, 128, 1, 1)

ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)

ppnet = ppnet.cuda()
# Wrap exactly once. ppnet_multi.module is ppnet itself, so the optimizers below
# (built from ppnet's parameters) update the weights used by ppnet_multi.
ppnet_multi = torch.nn.DataParallel(ppnet).cuda()
class_specific = True

# Warm-up stage optimizer: only the add-on layers and the prototype vectors
# are in its parameter groups.
warm_optimizer_lrs = {'add_on_layers': 1e-3,
                      'prototype_vectors': 1e-3}
warm_optimizer_specs = [
    {'params': ppnet.add_on_layers.parameters(), 'lr': warm_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
    {'params': ppnet.prototype_vectors, 'lr': warm_optimizer_lrs['prototype_vectors']},
]
warm_optimizer = torch.optim.Adam(warm_optimizer_specs)

# Joint stage optimizer: the backbone features are fine-tuned with a 10x
# smaller learning rate than the add-on layers and prototypes.
joint_optimizer_lrs = {'features': 1e-4,
                       'add_on_layers': 1e-3,
                       'prototype_vectors': 1e-3}
joint_optimizer_specs = [
    {'params': ppnet.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3},  # bias are now also being regularized
    {'params': ppnet.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
    {'params': ppnet.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)

# Multiply the joint learning rates by `gamma` every `joint_lr_step_size`
# scheduler steps. Pass the named constant so editing it takes effect.
joint_lr_step_size = 5
gamma = 0.5
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=gamma)

# Last-layer stage optimizer: only ppnet.last_layer is in its parameter group.
last_layer_optimizer_lr = 1e-4
last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

train_batch_size, test_batch_size, train_push_batch_size = 128, 128, 128
# Epochs at which the training loop pushes prototypes. push_saved_epochs is
# checked by the loop inside those push epochs.
push_epochs = [20, 40, 60, 99]
push_saved_epochs = [20, 40, 60, 99]

# Checkpoint root defaults to the original cluster path. Set PPNET_MODEL_ROOT
# to run on another machine without editing the notebook.
saved_models_root = os.environ.get('PPNET_MODEL_ROOT', '/scratch/users/jiaxun1218/saved_models')
model_dir = os.path.join(saved_models_root, 'resnet34', '')  # trailing '' keeps the final path separator
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))


In [10]:
'''
Model initialization for densenet; do not run this block when creating other models.

Builds a DenseNet161-backed ProtoPNet, wraps it in DataParallel, and creates
the warm-up, joint and last-layer optimizers plus the joint LR schedule used
by the training loop. Run exactly one of the model-initialization cells.
'''

base_architecture = 'densenet161'
prototype_shape = (2000, 128, 1, 1)

ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)

ppnet = ppnet.cuda()
# Wrap exactly once. ppnet_multi.module is ppnet itself, so the optimizers below
# (built from ppnet's parameters) update the weights used by ppnet_multi.
ppnet_multi = torch.nn.DataParallel(ppnet).cuda()
class_specific = True

# Warm-up stage optimizer: only the add-on layers and the prototype vectors
# are in its parameter groups.
warm_optimizer_lrs = {'add_on_layers': 1e-3,
                      'prototype_vectors': 1e-3}
warm_optimizer_specs = [
    {'params': ppnet.add_on_layers.parameters(), 'lr': warm_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
    {'params': ppnet.prototype_vectors, 'lr': warm_optimizer_lrs['prototype_vectors']},
]
warm_optimizer = torch.optim.Adam(warm_optimizer_specs)

# Joint stage optimizer: the backbone features are fine-tuned with a 10x
# smaller learning rate than the add-on layers and prototypes.
joint_optimizer_lrs = {'features': 1e-4,
                       'add_on_layers': 1e-3,
                       'prototype_vectors': 1e-3}
joint_optimizer_specs = [
    {'params': ppnet.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3},  # bias are now also being regularized
    {'params': ppnet.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
    {'params': ppnet.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)

# Multiply the joint learning rates by `gamma` every `joint_lr_step_size`
# scheduler steps. Pass the named constant so editing it takes effect.
joint_lr_step_size = 5
gamma = 0.5
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=gamma)

# Last-layer stage optimizer: only ppnet.last_layer is in its parameter group.
last_layer_optimizer_lr = 1e-4
last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

train_batch_size, test_batch_size, train_push_batch_size = 128, 128, 128
# Epochs at which the training loop pushes prototypes. push_saved_epochs is
# checked by the loop inside those push epochs.
push_epochs = [20, 40, 60, 99]
push_saved_epochs = [20, 40, 60, 99]

# Checkpoint root defaults to the original cluster path. Set PPNET_MODEL_ROOT
# to run on another machine without editing the notebook.
saved_models_root = os.environ.get('PPNET_MODEL_ROOT', '/scratch/users/jiaxun1218/saved_models')
model_dir = os.path.join(saved_models_root, 'densenet161', '')  # trailing '' keeps the final path separator
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))


Downloading: "https://download.pytorch.org/models/densenet161-8d451a50.pth" to ./pretrained_models/densenet161-8d451a50.pth
100%|██████████| 110M/110M [00:01<00:00, 116MB/s] 


In [11]:
# Datasets and loaders for training, prototype push, and testing.
# Every split is resized to img_size x img_size and converted to a tensor.
# Train and test are additionally normalized with (mean, std). The push split
# is left unnormalized because the push step expects a [0, 1] loader and
# normalizes via preprocess_input_function itself.
normalize = transforms.Normalize(mean=mean, std=std)


def _image_folder(root, normalized):
    """Return an ImageFolder over `root` with resize + ToTensor (+ normalize if `normalized`)."""
    pipeline = [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
    ]
    if normalized:
        pipeline.append(normalize)
    return datasets.ImageFolder(root, transforms.Compose(pipeline))


def _loader(dataset, batch_size, shuffle):
    """Wrap `dataset` in a DataLoader using this notebook's worker settings."""
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=2, pin_memory=False)


# train set
train_dataset = _image_folder(train_dir, normalized=True)
train_loader = _loader(train_dataset, train_batch_size, shuffle=True)

# push set (kept in [0, 1])
train_push_dataset = _image_folder(train_push_dir, normalized=False)
train_push_loader = _loader(train_push_dataset, train_push_batch_size, shuffle=False)

# test set
test_dataset = _image_folder(test_dir, normalized=True)
test_loader = _loader(test_dataset, test_batch_size, shuffle=False)

In [12]:
# Training schedule:
#   * epochs [0, num_warm_epochs): warm-up (tnt.warm_only + warm_optimizer)
#   * later epochs: joint training (tnt.joint + joint_optimizer), one LR
#     scheduler step per joint epoch
#   * every epoch in push_epochs: push prototypes, then (for non-linear
#     prototype activations) fine-tune only the last layer
# A test pass and conditional checkpoint close every epoch.
num_train_epochs = 100
num_warm_epochs = 5
num_last_layer_iterations = 8
target_accu = 0.76  # save.save_model_w_condition is passed this threshold

for epoch in range(num_train_epochs):
    # Use log() like the rest of the loop: its output is echoed to the cell
    # output (see the "iteration:" lines), and epoch markers also reach the
    # logger created for train.log.
    log('epoch: \t{0}'.format(epoch))

    if epoch < num_warm_epochs:
        tnt.warm_only(model=ppnet_multi, log=log)
        _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=warm_optimizer,
                      class_specific=class_specific, coefs=coefs, log=log)
    else:
        tnt.joint(model=ppnet_multi, log=log)
        _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=joint_optimizer,
                      class_specific=class_specific, coefs=coefs, log=log)
        # Step after the optimizer updates of this epoch (PyTorch >= 1.1 order).
        joint_lr_scheduler.step()

    if epoch in push_epochs:
        # Every push epoch projects the prototypes. Previously the push itself
        # was nested under push_saved_epochs, so a push epoch missing from that
        # list would skip the push yet still fine-tune the last layer.
        bounding_box_tracker = push.push_prototypes(
            train_push_loader, # pytorch dataloader (must be unnormalized in [0,1])
            prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors
            class_specific=class_specific,
            preprocess_input_function=preprocess_input_function, # normalize if needed
            prototype_layer_stride=1,
            root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here
            epoch_number=epoch, # if not provided, prototypes saved previously will be overwritten
            prototype_img_filename_prefix=None,
            prototype_self_act_filename_prefix=None,
            proto_bound_boxes_filename_prefix=None,
            save_prototype_class_identity=True,
            log=log,
            bounding_box_tracker=None)

        # Only evaluate/save the freshly pushed model in push_saved_epochs.
        if epoch in push_saved_epochs:
            accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                            class_specific=class_specific, log=log)
            save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'push', accu=accu,
                                        target_accu=target_accu, log=log)

        if prototype_activation_function != 'linear':
            # Fine-tune the last layer only. Each iteration already tests and
            # saves '<epoch>_<i>push', so no extra save is needed after the loop
            # (the old post-loop save rewrote the last file and relied on the
            # leaked loop variable `i`).
            tnt.last_only(model=ppnet_multi, log=log)
            for i in range(num_last_layer_iterations):
                log('iteration: \t{0}'.format(i))
                _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=last_layer_optimizer,
                              class_specific=class_specific, coefs=coefs, log=log)
                accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                                class_specific=class_specific, log=log)
                print("Test accuracy: ", accu)
                save.save_model_w_condition(model=ppnet, model_dir=model_dir,
                                            model_name=str(epoch) + '_' + str(i) + 'push', accu=accu,
                                            target_accu=target_accu, log=log)

    # End-of-epoch evaluation and checkpoint.
    accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                    class_specific=class_specific, log=log)
    print("Test accuracy: ", accu)
    save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch), accu=accu,
                                target_accu=target_accu, log=log)

epoch: 	0
	warm
	train


1405it [08:39,  2.71it/s]


	time: 	519.3797466754913
	cross ent: 	4.7905920799092465
	cluster: 	1.7534360656216474
	separation:	1.4305785773567246
	avg separation:	9.857729778086163
	accu: 		7.18385051718385%
	l1: 		201000.0
	p dist pair: 	19.467283248901367
	test


46it [00:17,  2.62it/s]


	time: 	17.71340012550354
	cross ent: 	4.0388519038324775
	cluster: 	0.15035044045551962
	separation:	0.02786378910684067
	avg separation:	9.562138951343039
	accu: 		13.030721435968243%
	l1: 		201000.0
	p dist pair: 	19.467283248901367
Test accuracy:  0.13030721435968243
epoch: 	1
	warm
	train


1405it [08:42,  2.69it/s]


	time: 	522.4186036586761
	cross ent: 	3.127825588348497
	cluster: 	0.09601140156441312
	separation:	0.02537955978746092
	avg separation:	8.990663071628992
	accu: 		25.176843510176845%
	l1: 		201000.0
	p dist pair: 	17.13971519470215
	test


46it [00:17,  2.69it/s]


	time: 	17.358643531799316
	cross ent: 	2.760948289995608
	cluster: 	0.10254242433154065
	separation:	0.02658904412680346
	avg separation:	8.345132060672926
	accu: 		30.89402830514325%
	l1: 		201000.0
	p dist pair: 	17.13971519470215
Test accuracy:  0.3089402830514325
epoch: 	2
	warm
	train


1405it [08:44,  2.68it/s]


	time: 	524.251060962677
	cross ent: 	1.960029955181787
	cluster: 	0.07127862664940518
	separation:	0.030742590134678362
	avg separation:	7.7159495085583885
	accu: 		47.14325436547659%
	l1: 		201000.0
	p dist pair: 	14.90956974029541
	test


46it [00:17,  2.66it/s]


	time: 	17.52150869369507
	cross ent: 	2.0574652811755305
	cluster: 	0.08947994138883508
	separation:	0.0357327349079044
	avg separation:	7.160227392030799
	accu: 		44.33897134967207%
	l1: 		201000.0
	p dist pair: 	14.90956974029541
Test accuracy:  0.44338971349672074
epoch: 	3
	warm
	train


1405it [08:42,  2.69it/s]


	time: 	522.8884618282318
	cross ent: 	1.3332104642196059
	cluster: 	0.0639174494189724
	separation:	0.037986899254797195
	avg separation:	6.638148860320502
	accu: 		62.72327883438995%
	l1: 		201000.0
	p dist pair: 	13.03901195526123
	test


46it [00:17,  2.67it/s]


	time: 	17.38951873779297
	cross ent: 	1.8480012572329978
	cluster: 	0.08699104708174
	separation:	0.03921900616715784
	avg separation:	6.172654390335083
	accu: 		49.17155678287884%
	l1: 		201000.0
	p dist pair: 	13.03901195526123
Test accuracy:  0.4917155678287884
epoch: 	4
	warm
	train


1405it [08:42,  2.69it/s]


	time: 	522.2639009952545
	cross ent: 	0.987990308656387
	cluster: 	0.06283804499859064
	separation:	0.04386117627241009
	avg separation:	5.78716905668537
	accu: 		72.04982760538317%
	l1: 		201000.0
	p dist pair: 	11.57022476196289
	test


46it [00:17,  2.67it/s]


	time: 	17.48765540122986
	cross ent: 	1.7281557010567707
	cluster: 	0.08760713753492935
	separation:	0.04142333575240944
	avg separation:	5.423294730808424
	accu: 		53.072143596824304%
	l1: 		201000.0
	p dist pair: 	11.57022476196289
Test accuracy:  0.5307214359682431
epoch: 	5
	joint
	train


1405it [15:46,  1.48it/s]


	time: 	946.9248020648956
	cross ent: 	0.43360619147265084
	cluster: 	0.051461943885927945
	separation:	0.05135943494286401
	avg separation:	5.368987154366707
	accu: 		88.00077855633411%
	l1: 		201000.0
	p dist pair: 	11.107673645019531
	test


46it [00:17,  2.65it/s]


	time: 	17.638038396835327
	cross ent: 	1.2611286886360333
	cluster: 	0.071277578158871
	separation:	0.04994214471915494
	avg separation:	5.210546980733457
	accu: 		67.55264066275457%
	l1: 		201000.0
	p dist pair: 	11.107673645019531
Test accuracy:  0.6755264066275457
epoch: 	6
	joint
	train


1405it [15:45,  1.49it/s]


	time: 	945.7870688438416
	cross ent: 	0.12774042166583902
	cluster: 	0.03730304170386647
	separation:	0.06086874375139691
	avg separation:	5.0777796422883705
	accu: 		96.75675675675676%
	l1: 		201000.0
	p dist pair: 	10.5148344039917
	test


46it [00:17,  2.67it/s]


	time: 	17.35269784927368
	cross ent: 	1.3003558594247568
	cluster: 	0.075099627851792
	separation:	0.055130296589239784
	avg separation:	4.929481475249581
	accu: 		69.0714532274767%
	l1: 		201000.0
	p dist pair: 	10.5148344039917
Test accuracy:  0.690714532274767
epoch: 	7
	joint
	train


1405it [15:44,  1.49it/s]


	time: 	944.5445446968079
	cross ent: 	0.07014070653300268
	cluster: 	0.031897193315402465
	separation:	0.06521914045581614
	avg separation:	4.846441993509748
	accu: 		98.24602380157935%
	l1: 		201000.0
	p dist pair: 	10.14432430267334
	test


46it [00:18,  2.48it/s]


	time: 	18.822028875350952
	cross ent: 	1.2573473764502483
	cluster: 	0.0714864682244218
	separation:	0.059819871154816254
	avg separation:	4.743803542593251
	accu: 		70.33137728684846%
	l1: 		201000.0
	p dist pair: 	10.14432430267334
Test accuracy:  0.7033137728684846
epoch: 	8
	joint
	train


1405it [15:42,  1.49it/s]


	time: 	942.9451603889465
	cross ent: 	0.07343651868429472
	cluster: 	0.03309611863725126
	separation:	0.07182535831932496
	avg separation:	4.670403616962908
	accu: 		98.06306306306305%
	l1: 		201000.0
	p dist pair: 	9.857073783874512
	test


46it [00:17,  2.64it/s]


	time: 	17.64574360847473
	cross ent: 	1.3240969828937366
	cluster: 	0.07549974864915661
	separation:	0.06565197152288063
	avg separation:	4.609466386877972
	accu: 		70.46945115636866%
	l1: 		201000.0
	p dist pair: 	9.857073783874512
Test accuracy:  0.7046945115636866
epoch: 	9
	joint
	train


1405it [15:39,  1.50it/s]


	time: 	939.2335751056671
	cross ent: 	0.07139614568658677
	cluster: 	0.03326503044239567
	separation:	0.07963073640213318
	avg separation:	4.566268227193704
	accu: 		98.11589367144923%
	l1: 		201000.0
	p dist pair: 	9.674647331237793
	test


46it [00:17,  2.62it/s]


	time: 	17.734500646591187
	cross ent: 	1.296534319286761
	cluster: 	0.07351782245804435
	separation:	0.0725031751005546
	avg separation:	4.5172032066013506
	accu: 		72.00552295478082%
	l1: 		201000.0
	p dist pair: 	9.674647331237793
Test accuracy:  0.7200552295478081
epoch: 	10
	joint
	train


1405it [15:37,  1.50it/s]


	time: 	938.1361243724823
	cross ent: 	0.008019843794025558
	cluster: 	0.01583986459134206
	separation:	0.10293395336836683
	avg separation:	4.523675139688512
	accu: 		99.81870759648538%
	l1: 		201000.0
	p dist pair: 	9.581328392028809
	test


46it [00:17,  2.64it/s]


	time: 	17.623018503189087
	cross ent: 	1.1266653829294702
	cluster: 	0.06530033319216708
	separation:	0.1032672216710837
	avg separation:	4.517140564711197
	accu: 		76.2167759751467%
	l1: 		201000.0
	p dist pair: 	9.581328392028809
Test accuracy:  0.762167759751467
	above 76.00%
epoch: 	11
	joint
	train


1405it [15:35,  1.50it/s]


	time: 	935.3936235904694
	cross ent: 	0.01723439512032045
	cluster: 	0.016217535646678716
	separation:	0.13754049606912924
	avg separation:	4.508335390803652
	accu: 		99.57457457457457%
	l1: 		201000.0
	p dist pair: 	9.491765975952148
	test


46it [00:17,  2.70it/s]


	time: 	17.290948152542114
	cross ent: 	1.3413943350315094
	cluster: 	0.07986576280192188
	separation:	0.10387205883212712
	avg separation:	4.456866824108621
	accu: 		72.04004142216085%
	l1: 		201000.0
	p dist pair: 	9.491765975952148
Test accuracy:  0.7204004142216086
epoch: 	12
	joint
	train


1405it [15:36,  1.50it/s]


	time: 	937.0760126113892
	cross ent: 	0.018855178781853013
	cluster: 	0.0175221060346348
	separation:	0.13615676788248626
	avg separation:	4.445093722360414
	accu: 		99.53564675786897%
	l1: 		201000.0
	p dist pair: 	9.40796184539795
	test


46it [00:17,  2.65it/s]


	time: 	17.53559374809265
	cross ent: 	1.3058302143345708
	cluster: 	0.08102401872367962
	separation:	0.12454107257982959
	avg separation:	4.444021214609561
	accu: 		74.11114946496376%
	l1: 		201000.0
	p dist pair: 	9.40796184539795
Test accuracy:  0.7411114946496375
epoch: 	13
	joint
	train


1405it [15:35,  1.50it/s]


	time: 	935.7076163291931
	cross ent: 	0.02672057153029016
	cluster: 	0.019816217415061285
	separation:	0.14714150037417634
	avg separation:	4.408021684558366
	accu: 		99.31598264931598%
	l1: 		201000.0
	p dist pair: 	9.320932388305664
	test


46it [00:17,  2.63it/s]


	time: 	17.699159145355225
	cross ent: 	1.321306864204614
	cluster: 	0.07777145699314449
	separation:	0.11828590297828549
	avg separation:	4.37506215468697
	accu: 		73.88677942699344%
	l1: 		201000.0
	p dist pair: 	9.320932388305664
Test accuracy:  0.7388677942699344
epoch: 	14
	joint
	train


1405it [15:37,  1.50it/s]


	time: 	937.5758759975433
	cross ent: 	0.009042269284867501
	cluster: 	0.013258638818254462
	separation:	0.18137242352835226
	avg separation:	4.417545116329532
	accu: 		99.79423868312757%
	l1: 		201000.0
	p dist pair: 	9.322993278503418
	test


46it [00:17,  2.63it/s]


	time: 	17.625316381454468
	cross ent: 	1.238804420699244
	cluster: 	0.0914674985991872
	separation:	0.17009928595760596
	avg separation:	4.4405820784361465
	accu: 		75.81981360027615%
	l1: 		201000.0
	p dist pair: 	9.322993278503418
Test accuracy:  0.7581981360027614
epoch: 	15
	joint
	train


1405it [15:40,  1.49it/s]


	time: 	940.4532306194305
	cross ent: 	0.002120751785625278
	cluster: 	0.008956051254611847
	separation:	0.27168805168616816
	avg separation:	4.466067482483345
	accu: 		99.95884773662551%
	l1: 		201000.0
	p dist pair: 	9.360575675964355
	test


46it [00:17,  2.62it/s]


	time: 	17.798328399658203
	cross ent: 	1.117713992362437
	cluster: 	0.1185679331259883
	separation:	0.2770588184180467
	avg separation:	4.50194033332493
	accu: 		78.32240248532966%
	l1: 		201000.0
	p dist pair: 	9.360575675964355
Test accuracy:  0.7832240248532966
	above 76.00%
epoch: 	16
	joint
	train


1405it [15:46,  1.49it/s]


	time: 	946.4413831233978
	cross ent: 	0.0014824204533277522
	cluster: 	0.008721500928132559
	separation:	0.45631919810780425
	avg separation:	4.588345516109806
	accu: 		99.97497497497497%
	l1: 		201000.0
	p dist pair: 	9.59024429321289
	test


46it [00:17,  2.64it/s]


	time: 	17.61433434486389
	cross ent: 	1.2306682784920153
	cluster: 	0.1902751277970231
	separation:	0.4668898452883181
	avg separation:	4.689864251924598
	accu: 		77.01070072488781%
	l1: 		201000.0
	p dist pair: 	9.59024429321289
Test accuracy:  0.7701070072488782
	above 76.00%
epoch: 	17
	joint
	train


1405it [15:42,  1.49it/s]


	time: 	942.484277009964
	cross ent: 	0.0034060751280219756
	cluster: 	0.014269860580669603
	separation:	0.8390265022732609
	avg separation:	5.0393443209420745
	accu: 		99.9460571682794%
	l1: 		201000.0
	p dist pair: 	10.28971004486084
	test


46it [00:17,  2.66it/s]


	time: 	17.47099804878235
	cross ent: 	1.2641144537407418
	cluster: 	0.31306956449280615
	separation:	0.9431581898875858
	avg separation:	5.59945681820745
	accu: 		77.26958923023818%
	l1: 		201000.0
	p dist pair: 	10.28971004486084
Test accuracy:  0.7726958923023818
	above 76.00%
epoch: 	18
	joint
	train


1405it [15:40,  1.49it/s]


	time: 	940.9686245918274
	cross ent: 	0.008584913739096043
	cluster: 	0.029048151681961964
	separation:	2.064534478000899
	avg separation:	7.178990892498518
	accu: 		99.90156823490157%
	l1: 		201000.0
	p dist pair: 	11.839831352233887
	test


46it [00:17,  2.61it/s]


	time: 	17.856024742126465
	cross ent: 	1.2716398297444633
	cluster: 	0.549814271538154
	separation:	2.2995940317278323
	avg separation:	8.47033201093259
	accu: 		77.70107007248878%
	l1: 		201000.0
	p dist pair: 	11.839831352233887
Test accuracy:  0.7770107007248878
	above 76.00%
epoch: 	19
	joint
	train


1405it [15:36,  1.50it/s]


	time: 	937.1266820430756
	cross ent: 	0.013208897933082197
	cluster: 	0.03835959378033346
	separation:	5.062068303878621
	avg separation:	11.731820972619108
	accu: 		99.96774552330108%
	l1: 		201000.0
	p dist pair: 	15.043126106262207
	test


46it [00:17,  2.65it/s]


	time: 	17.509530782699585
	cross ent: 	1.4441641711670419
	cluster: 	1.0807057554307191
	separation:	5.175519200770752
	avg separation:	13.475024866021197
	accu: 		78.32240248532966%
	l1: 		201000.0
	p dist pair: 	15.043126106262207
Test accuracy:  0.7832240248532966
	above 76.00%
epoch: 	20
	joint
	train


1405it [15:40,  1.49it/s]


	time: 	940.4572038650513
	cross ent: 	0.015672081507397715
	cluster: 	0.03119476279939833
	separation:	8.950711749881188
	avg separation:	16.91864326585654
	accu: 		99.99110221332444%
	l1: 		201000.0
	p dist pair: 	17.380887985229492
	push
	Executing push ...
	push time: 	29.181363582611084
	test


46it [00:16,  2.71it/s]


	time: 	17.20949077606201
	cross ent: 	11.06087589263916
	cluster: 	1.1135991261057232
	separation:	0.021658909993003243
	avg separation:	16.433472052864406
	accu: 		38.298239558163615%
	l1: 		201000.0
	p dist pair: 	20.344970703125
	last layer
iteration: 	0
	train


1405it [08:40,  2.70it/s]


	time: 	520.8059225082397
	cross ent: 	0.07533810576762166
	cluster: 	0.020323812925783766
	separation:	0.02246109310975066
	avg separation:	17.63270058275542
	accu: 		98.57079301523746%
	l1: 		154039.453125
	p dist pair: 	20.344970703125
	test


46it [00:17,  2.68it/s]


	time: 	17.35026979446411
	cross ent: 	2.5032392753207167
	cluster: 	1.0986662500578424
	separation:	0.02142780225562013
	avg separation:	16.379355865976084
	accu: 		71.21159820503969%
	l1: 		154039.453125
	p dist pair: 	20.344970703125
Test accuracy:  0.712115982050397
iteration: 	1
	train


1405it [08:40,  2.70it/s]


	time: 	520.7200820446014
	cross ent: 	0.01168931233005812
	cluster: 	0.020427018357606547
	separation:	0.02245367330323335
	avg separation:	17.633374923488848
	accu: 		99.92881770659548%
	l1: 		104364.703125
	p dist pair: 	20.344970703125
	test


46it [00:17,  2.66it/s]


	time: 	17.503476858139038
	cross ent: 	2.357899529778439
	cluster: 	1.1054347274096117
	separation:	0.021602307810731556
	avg separation:	16.409164926280145
	accu: 		69.91715567828788%
	l1: 		104364.703125
	p dist pair: 	20.344970703125
Test accuracy:  0.6991715567828788
iteration: 	2
	train


1405it [08:40,  2.70it/s]


	time: 	520.507354259491
	cross ent: 	0.017736146645703322
	cluster: 	0.02036053176929518
	separation:	0.022470266708699834
	avg separation:	17.632692551528006
	accu: 		99.8387276165054%
	l1: 		53887.49609375
	p dist pair: 	20.344970703125
	test


46it [00:17,  2.70it/s]


	time: 	17.28618049621582
	cross ent: 	2.4691255597964576
	cluster: 	1.108552991696026
	separation:	0.021838139094736263
	avg separation:	16.42232888677846
	accu: 		64.03175698998965%
	l1: 		53887.49609375
	p dist pair: 	20.344970703125
Test accuracy:  0.6403175698998964
iteration: 	3
	train


1405it [08:41,  2.69it/s]


	time: 	521.8863065242767
	cross ent: 	0.01789360020592543
	cluster: 	0.020412542800886353
	separation:	0.022450176078918988
	avg separation:	17.63320653243421
	accu: 		99.83761539317095%
	l1: 		19727.0
	p dist pair: 	20.344970703125
	test


46it [00:17,  2.63it/s]


	time: 	17.58963131904602
	cross ent: 	2.653589769549992
	cluster: 	1.1095303996749546
	separation:	0.02155235600050377
	avg separation:	16.411412550055463
	accu: 		61.59820503969624%
	l1: 		19727.0
	p dist pair: 	20.344970703125
Test accuracy:  0.6159820503969624
iteration: 	4
	train


1405it [08:40,  2.70it/s]


	time: 	521.07732462883
	cross ent: 	0.010810019342237884
	cluster: 	0.020412149020467365
	separation:	0.022455505549483452
	avg separation:	17.632981928217877
	accu: 		99.89322655989322%
	l1: 		14981.8759765625
	p dist pair: 	20.344970703125
	test


46it [00:17,  2.68it/s]


	time: 	17.391732215881348
	cross ent: 	2.672717485738837
	cluster: 	1.0980235409477483
	separation:	0.021191646749882595
	avg separation:	16.35634708404541
	accu: 		62.01242664825681%
	l1: 		14981.8759765625
	p dist pair: 	20.344970703125
Test accuracy:  0.6201242664825681
iteration: 	5
	train


1405it [08:45,  2.67it/s]


	time: 	525.8438637256622
	cross ent: 	0.01083962180017292
	cluster: 	0.020457200712633728
	separation:	0.022495614407166468
	avg separation:	17.6328410355646
	accu: 		99.89155822489157%
	l1: 		12991.3212890625
	p dist pair: 	20.344970703125
	test


46it [00:17,  2.65it/s]


	time: 	17.538774728775024
	cross ent: 	2.6001186526339986
	cluster: 	1.108976864296457
	separation:	0.021417711335031883
	avg separation:	16.418416624483854
	accu: 		64.46323783224025%
	l1: 		12991.3212890625
	p dist pair: 	20.344970703125
Test accuracy:  0.6446323783224025
iteration: 	6
	train


1405it [08:41,  2.70it/s]


	time: 	521.4331614971161
	cross ent: 	0.012458923517819992
	cluster: 	0.02040315902397514
	separation:	0.02247177708456525
	avg separation:	17.633507437383578
	accu: 		99.88710933155377%
	l1: 		11489.009765625
	p dist pair: 	20.344970703125
	test


46it [00:17,  2.69it/s]


	time: 	17.263211011886597
	cross ent: 	2.597791884256446
	cluster: 	1.1124284069823183
	separation:	0.0219725603642671
	avg separation:	16.43863114066746
	accu: 		67.19019675526407%
	l1: 		11489.009765625
	p dist pair: 	20.344970703125
Test accuracy:  0.6719019675526406
iteration: 	7
	train


1405it [08:40,  2.70it/s]


	time: 	520.4599702358246
	cross ent: 	0.014061652977794062
	cluster: 	0.020400022146593633
	separation:	0.02246989172309955
	avg separation:	17.632274339886322
	accu: 		99.88377266155044%
	l1: 		10302.90234375
	p dist pair: 	20.344970703125
	test


46it [00:17,  2.68it/s]


	time: 	17.434003353118896
	cross ent: 	2.59794549708781
	cluster: 	1.1018755361437798
	separation:	0.02129922612853672
	avg separation:	16.391666826994523
	accu: 		68.32930617880567%
	l1: 		10302.90234375
	p dist pair: 	20.344970703125
Test accuracy:  0.6832930617880566
	test


46it [00:17,  2.68it/s]


	time: 	17.41022825241089
	cross ent: 	2.59794549708781
	cluster: 	1.1018755361437798
	separation:	0.02129922612853672
	avg separation:	16.391666826994523
	accu: 		68.32930617880567%
	l1: 		10302.90234375
	p dist pair: 	20.344970703125
Test accuracy:  0.6832930617880566
epoch: 	21
	joint
	train


1405it [15:35,  1.50it/s]


	time: 	935.5695149898529
	cross ent: 	0.008283310331928752
	cluster: 	0.04282246598728611
	separation:	10.15593119473644
	avg separation:	22.177039144387024
	accu: 		99.97942386831275%
	l1: 		10302.90234375
	p dist pair: 	22.479202270507812
	test


46it [00:17,  2.63it/s]


	time: 	17.56853723526001
	cross ent: 	1.7125517058631647
	cluster: 	2.3748333065406135
	separation:	10.203246982201286
	avg separation:	23.16922079998514
	accu: 		78.40869865377978%
	l1: 		10302.90234375
	p dist pair: 	22.479202270507812
Test accuracy:  0.7840869865377977
	above 76.00%
epoch: 	22
	joint
	train


1405it [15:25,  1.52it/s]


	time: 	925.4764578342438
	cross ent: 	0.016585131929491338
	cluster: 	0.03435154660104433
	separation:	15.75455356679353
	avg separation:	27.557077209635562
	accu: 		99.99332665999333%
	l1: 		10302.90234375
	p dist pair: 	26.365982055664062
	test


46it [00:17,  2.65it/s]


	time: 	17.565930366516113
	cross ent: 	1.753012239933014
	cluster: 	2.9043262626813804
	separation:	12.835137673046278
	avg separation:	27.461534126945164
	accu: 		78.7193648602002%
	l1: 		10302.90234375
	p dist pair: 	26.365982055664062
Test accuracy:  0.787193648602002
	above 76.00%
epoch: 	23
	joint
	train


1405it [15:24,  1.52it/s]


	time: 	924.4130697250366
	cross ent: 	0.019660185280113457
	cluster: 	0.030041946213546597
	separation:	19.033275571667
	avg separation:	31.553987108006595
	accu: 		99.98998998999%
	l1: 		10302.90234375
	p dist pair: 	30.1768798828125
	test


46it [00:17,  2.69it/s]


	time: 	17.304030418395996
	cross ent: 	1.8384615843710692
	cluster: 	3.532467546670333
	separation:	15.261504473893538
	avg separation:	31.130150380341902
	accu: 		78.51225405591991%
	l1: 		10302.90234375
	p dist pair: 	30.1768798828125
Test accuracy:  0.7851225405591992
	above 76.00%
epoch: 	24
	joint
	train


1405it [15:27,  1.52it/s]


	time: 	927.578439950943
	cross ent: 	0.020750985849396827
	cluster: 	0.027220674960797792
	separation:	21.984906144328814
	avg separation:	34.70466137543268
	accu: 		99.99332665999333%
	l1: 		10302.90234375
	p dist pair: 	33.84162139892578
	test


46it [00:17,  2.61it/s]


	time: 	17.820271015167236
	cross ent: 	1.8695558478002963
	cluster: 	4.0299888771513235
	separation:	17.526421712792438
	avg separation:	34.06955038982889
	accu: 		78.52951328960994%
	l1: 		10302.90234375
	p dist pair: 	33.84162139892578
Test accuracy:  0.7852951328960994
	above 76.00%
epoch: 	25
	joint
	train


1405it [15:27,  1.51it/s]


	time: 	928.0487377643585
	cross ent: 	0.01799282976475899
	cluster: 	0.020991196190388178
	separation:	24.240046093217842
	avg separation:	36.67074537989932
	accu: 		99.99666332999666%
	l1: 		10302.90234375
	p dist pair: 	35.664306640625
	test


46it [00:17,  2.69it/s]


	time: 	17.409372806549072
	cross ent: 	1.9463025810925856
	cluster: 	4.487239733986232
	separation:	18.759257907452792
	avg separation:	35.53036407802416
	accu: 		78.58129099068%
	l1: 		10302.90234375
	p dist pair: 	35.664306640625
Test accuracy:  0.7858129099068001
	above 76.00%
epoch: 	26
	joint
	train


1405it [15:29,  1.51it/s]


	time: 	929.8788900375366
	cross ent: 	0.01822818226422916
	cluster: 	0.01943254183805498
	separation:	25.666421083877943
	avg separation:	37.813054847038515
	accu: 		99.99721944166389%
	l1: 		10302.90234375
	p dist pair: 	37.45183181762695
	test


46it [00:17,  2.68it/s]


	time: 	17.385366678237915
	cross ent: 	1.9813822948414346
	cluster: 	4.820735039918319
	separation:	19.920073529948358
	avg separation:	36.80878394582997
	accu: 		78.82292026234036%
	l1: 		10302.90234375
	p dist pair: 	37.45183181762695
Test accuracy:  0.7882292026234036
	above 76.00%
epoch: 	27
	joint
	train


1353it [14:56,  1.51it/s]


KeyboardInterrupt: 