In [6]:
import os
import shutil
from collections import Counter
import numpy as np
import torch
import torch.utils.data
import pandas as pd
import argparse
import re
from helpers import makedir
import model
import save
from log import create_logger
import train_and_test as tnt
from helpers import makedir
import find_nearest
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import push

from preprocess import mean, std, preprocess_input_function
from settings import train_dir, test_dir, train_push_dir
from settings import img_size, num_classes, prototype_activation_function, add_on_layers_type

from bounding_box_metrics import bounding_box_overlap
from find_nearest import find_k_nearest_patches_to_prototypes
from settings import coefs



In [2]:
'''
Model initialization for vgg base; do not run this block when creating other models 
'''
# Sets up everything the training loop needs for the VGG-19 backbone:
# the network (plain and DataParallel handles), the three optimizers
# (warm-up, joint, last-layer), the joint LR scheduler, batch sizes,
# the push schedule, and the output directory / logger.

base_architecture = 'vgg19'
# NOTE(review): presumably (num_prototypes, channels, height, width) --
# confirm against model.construct_PPNet.
prototype_shape = (2000, 128, 1, 1)

ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)

ppnet = ppnet.cuda()
# Wrap exactly once; the original built a first DataParallel wrapper that was
# immediately discarded by this assignment.
ppnet_multi = torch.nn.DataParallel(ppnet).cuda()
class_specific = True

# Warm-up optimizer: updates only the add-on layers and the prototype vectors.
warm_optimizer_lrs = {'add_on_layers': 1e-3,
                      'prototype_vectors': 1e-3}
warm_optimizer_specs = \
    [{'params': ppnet.add_on_layers.parameters(), 'lr': warm_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
     {'params': ppnet.prototype_vectors, 'lr': warm_optimizer_lrs['prototype_vectors']},
    ]
warm_optimizer = torch.optim.Adam(warm_optimizer_specs)

# Joint optimizer: additionally updates the backbone features (smaller LR).
joint_optimizer_lrs = {'features': 1e-4,
                       'add_on_layers': 1e-3,
                       'prototype_vectors': 1e-3}
joint_optimizer_specs = \
    [{'params': ppnet_multi.module.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3}, # bias are now also being regularized
     {'params': ppnet_multi.module.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
     {'params': ppnet_multi.module.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
    ]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)

# Multiply the joint learning rates by `gamma` every `joint_lr_step_size`
# scheduler steps. The scheduler now reads `gamma` instead of a duplicated
# literal, so changing the variable actually takes effect.
joint_lr_step_size = 5
gamma = 0.5
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=gamma)

# Last-layer optimizer: updates only the final layer.
last_layer_optimizer_lr = 1e-4
last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

train_batch_size, test_batch_size, train_push_batch_size = 128, 128, 128
# Epochs at which the training loop runs the push step / saves the pushed model.
push_epochs = [20, 40, 60, 99]
push_saved_epochs = [20, 40, 60, 99]

# Checkpoints and train.log for this backbone go here.
model_dir = './ppnet_results/007_vgg19/'
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))

In [17]:
'''
Model initialization for resnet base; do not run this block when creating other models 
'''
# Sets up everything the training loop needs for the ResNet-34 backbone:
# the network (plain and DataParallel handles), the three optimizers
# (warm-up, joint, last-layer), the joint LR scheduler, batch sizes,
# the push schedule, and the output directory / logger.

base_architecture = 'resnet34'
# NOTE(review): presumably (num_prototypes, channels, height, width) --
# confirm against model.construct_PPNet.
prototype_shape = (2000, 256, 1, 1)

ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)

ppnet = ppnet.cuda()
# Wrap exactly once; the original built a first DataParallel wrapper that was
# immediately discarded by this assignment.
ppnet_multi = torch.nn.DataParallel(ppnet).cuda()
class_specific = True

# Warm-up optimizer: updates only the add-on layers and the prototype vectors.
warm_optimizer_lrs = {'add_on_layers': 1e-3,
                      'prototype_vectors': 1e-3}
warm_optimizer_specs = \
    [{'params': ppnet.add_on_layers.parameters(), 'lr': warm_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
     {'params': ppnet.prototype_vectors, 'lr': warm_optimizer_lrs['prototype_vectors']},
    ]
warm_optimizer = torch.optim.Adam(warm_optimizer_specs)

# Joint optimizer: additionally updates the backbone features (smaller LR).
joint_optimizer_lrs = {'features': 1e-4,
                       'add_on_layers': 1e-3,
                       'prototype_vectors': 1e-3}
joint_optimizer_specs = \
    [{'params': ppnet_multi.module.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3}, # bias are now also being regularized
     {'params': ppnet_multi.module.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
     {'params': ppnet_multi.module.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
    ]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)

# Multiply the joint learning rates by `gamma` every `joint_lr_step_size`
# scheduler steps. The scheduler now reads `gamma` instead of a duplicated
# literal, so changing the variable actually takes effect.
joint_lr_step_size = 5
gamma = 0.5
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=gamma)

# Last-layer optimizer: updates only the final layer.
last_layer_optimizer_lr = 1e-4
last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

train_batch_size, test_batch_size, train_push_batch_size = 128, 128, 128
# Epochs at which the training loop runs the push step / saves the pushed model.
push_epochs = [20, 40, 60, 99]
push_saved_epochs = [20, 40, 60, 99]

# Checkpoints and train.log for this backbone go here.
model_dir = './ppnet_results/005_resnet34/'
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))


In [11]:
'''
Model initialization for densenet; do not run this block when creating other models 
'''
# Sets up everything the training loop needs for the DenseNet-121 backbone:
# the network (plain and DataParallel handles), the three optimizers
# (warm-up, joint, last-layer), the joint LR scheduler, batch sizes,
# the push schedule, and the output directory / logger.

base_architecture = 'densenet121'
# NOTE(review): presumably (num_prototypes, channels, height, width) --
# confirm against model.construct_PPNet.
prototype_shape = (2000, 128, 1, 1)

ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)

ppnet = ppnet.cuda()
# Wrap exactly once; the original built a first DataParallel wrapper that was
# immediately discarded by this assignment.
ppnet_multi = torch.nn.DataParallel(ppnet).cuda()
class_specific = True

# Warm-up optimizer: updates only the add-on layers and the prototype vectors.
warm_optimizer_lrs = {'add_on_layers': 1e-3,
                      'prototype_vectors': 1e-3}
warm_optimizer_specs = \
    [{'params': ppnet.add_on_layers.parameters(), 'lr': warm_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
     {'params': ppnet.prototype_vectors, 'lr': warm_optimizer_lrs['prototype_vectors']},
    ]
warm_optimizer = torch.optim.Adam(warm_optimizer_specs)

# Joint optimizer: additionally updates the backbone features (smaller LR).
joint_optimizer_lrs = {'features': 1e-4,
                       'add_on_layers': 1e-3,
                       'prototype_vectors': 1e-3}
joint_optimizer_specs = \
    [{'params': ppnet_multi.module.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3}, # bias are now also being regularized
     {'params': ppnet_multi.module.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
     {'params': ppnet_multi.module.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
    ]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)

# Multiply the joint learning rates by `gamma` every `joint_lr_step_size`
# scheduler steps. The scheduler now reads `gamma` instead of a duplicated
# literal, so changing the variable actually takes effect.
joint_lr_step_size = 5
gamma = 0.5
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=gamma)

# Last-layer optimizer: updates only the final layer.
last_layer_optimizer_lr = 1e-4
last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

train_batch_size, test_batch_size, train_push_batch_size = 128, 128, 128
# Epochs at which the training loop runs the push step / saves the pushed model.
push_epochs = [20, 40, 60, 99]
push_saved_epochs = [20, 40, 60, 99]

# Checkpoints and train.log for this backbone go here.
model_dir = './ppnet_results/006_densenet121/'
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))


In [3]:
# Build the three ImageFolder datasets and their loaders.
# Every image is resized to (img_size, img_size) and converted to a tensor.
# The train and test sets are additionally normalized with `mean` / `std`;
# the push set is left unnormalized.
normalize = transforms.Normalize(mean=mean, std=std)

# training set (shuffled)
train_dataset = datasets.ImageFolder(
    root=train_dir,
    transform=transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ]),
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# push set (no normalization, fixed order)
train_push_dataset = datasets.ImageFolder(
    root=train_push_dir,
    transform=transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
    ]),
)
train_push_loader = torch.utils.data.DataLoader(
    dataset=train_push_dataset,
    batch_size=train_push_batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=False,
)

# test set (normalized, fixed order)
test_dataset = datasets.ImageFolder(
    root=test_dir,
    transform=transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ]),
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=False,
)

In [None]:
# Main training loop over epochs 20..99.
# Per epoch: one training pass (warm-up or joint), then -- on push epochs --
# the prototype push followed by last-layer fine-tuning, then a final
# evaluation and conditional checkpoint.
for epoch in range(20, 100):
    print('epoch: \t{0}'.format(epoch))

    #if epoch in range(0, 10):
    #    tnt.joint_warm(model=ppnet_multi, log=log)
    #else:

    # Warm-up branch for epochs 0-4. It is never taken while the loop starts
    # at epoch 20; it only matters if the start of the range is lowered.
    if epoch in range(0, 5):
        tnt.warm_only(model=ppnet_multi, log=log)
        _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=warm_optimizer,
                      class_specific=class_specific, coefs=coefs, log=log)

    else:
        tnt.joint(model=ppnet_multi, log=log)
        _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=joint_optimizer,
                      class_specific=class_specific, coefs=coefs, log=log)
        # The joint LR schedule advances only on joint-training epochs.
        joint_lr_scheduler.step()

#     accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
#                     class_specific=class_specific, log=log)
#     save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'nopush', accu=accu,
#                                 target_accu=0.70, log=log)

    if epoch in push_epochs:
        # Run the push on every push epoch. Previously the push itself was
        # nested under the `push_saved_epochs` check, so an epoch listed only
        # in `push_epochs` would have skipped the push but still run the
        # last-layer fine-tuning below.
        bounding_box_tracker = push.push_prototypes(
            train_push_loader, # pytorch dataloader (must be unnormalized in [0,1])
            prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors
            class_specific=class_specific,
            preprocess_input_function=preprocess_input_function, # normalize if needed
            prototype_layer_stride=1,
            root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here
            epoch_number=epoch, # if not provided, prototypes saved previously will be overwritten
            prototype_img_filename_prefix=None,
            prototype_self_act_filename_prefix=None,
            proto_bound_boxes_filename_prefix=None,
            save_prototype_class_identity=True,
            log=log,
            bounding_box_tracker=None)

        # Only save the model in push_saved_epochs
        if epoch in push_saved_epochs:
            accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                            class_specific=class_specific, log=log)
            save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'push', accu=accu,
                                        target_accu=0.72, log=log)

        if prototype_activation_function != 'linear':
            tnt.last_only(model=ppnet_multi, log=log)
            # Fine tune the last layers
            for i in range(8):
                log('iteration: \t{0}'.format(i))
                _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=last_layer_optimizer,
                              class_specific=class_specific, coefs=coefs, log=log)
                accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                                class_specific=class_specific, log=log)
                print("Test accuracy: ", accu)
                save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + '_' + str(i) + 'push', accu=accu,
                                            target_accu=0.72, log=log)

            # Save the last with the final layer fine-tuned.
            # Reuses `i` and `accu` from the final iteration above, with a
            # lower threshold (0.60) than the per-iteration saves (0.72).
            save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + '_' + str(i) + 'push', accu=accu,
                                        target_accu=0.60, log=log)

    # End-of-epoch evaluation and conditional checkpoint (runs every epoch).
    accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                    class_specific=class_specific, log=log)
    print("Test accuracy: ", accu)
    save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch), accu=accu,
                                target_accu=0.72, log=log)

epoch: 	20
	joint
	train


1405it [12:10,  1.92it/s]


	time: 	731.0269372463226
	cross ent: 	0.01065046897030529
	cluster: 	0.007841696317417024
	separation:	0.028849512831808408
	avg separation:	2.6800737301225763
	accu: 		99.77699922144366%
	l1: 		201000.0
	p dist pair: 	5.810579776763916
	push
	Executing push ...
	push time: 	37.34978795051575
	test


46it [00:17,  2.60it/s]


	time: 	17.887080192565918
	cross ent: 	2.5526346927103787
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		61.9951674145668%
	l1: 		201000.0
	p dist pair: 	1.227476716041565
	last layer
iteration: 	0
	train


1405it [08:31,  2.75it/s]


	time: 	511.83247327804565
	cross ent: 	0.014254061758080729
	cluster: 	0.0056659161706464995
	separation:	0.016999168599892766
	avg separation:	0.4899969080801112
	accu: 		99.63574685796908%
	l1: 		190433.546875
	p dist pair: 	1.227476716041565
	test


46it [00:17,  2.61it/s]


	time: 	17.80571150779724
	cross ent: 	1.6779367535010627
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		70.46945115636866%
	l1: 		190433.546875
	p dist pair: 	1.227476716041565
Test accuracy:  0.7046945115636866
iteration: 	1
	train


1405it [08:32,  2.74it/s]


	time: 	512.2095513343811
	cross ent: 	0.01312662246529331
	cluster: 	0.005665772756581523
	separation:	0.01699928095120128
	avg separation:	0.48999571882957244
	accu: 		99.72138805472139%
	l1: 		178743.828125
	p dist pair: 	1.227476716041565
	test


46it [00:17,  2.56it/s]


	time: 	18.25442099571228
	cross ent: 	1.69286937687708
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		69.69278564031757%
	l1: 		178743.828125
	p dist pair: 	1.227476716041565
Test accuracy:  0.6969278564031757
iteration: 	2
	train


1405it [09:10,  2.55it/s]


	time: 	550.7097837924957
	cross ent: 	0.014655202777734775
	cluster: 	0.00566584013097545
	separation:	0.016999275828595686
	avg separation:	0.4899966882429089
	accu: 		99.73417862306752%
	l1: 		167131.71875
	p dist pair: 	1.227476716041565
	test


46it [00:19,  2.42it/s]


	time: 	19.385153770446777
	cross ent: 	1.6867472192515498
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		69.72730410769762%
	l1: 		167131.71875
	p dist pair: 	1.227476716041565
Test accuracy:  0.6972730410769762
iteration: 	3
	train


1405it [09:13,  2.54it/s]


	time: 	553.6571452617645
	cross ent: 	0.017434012177615723
	cluster: 	0.005665822116677977
	separation:	0.016999104398955654
	avg separation:	0.4899963152578293
	accu: 		99.71805138471805%
	l1: 		156200.765625
	p dist pair: 	1.227476716041565
	test


46it [00:17,  2.63it/s]


	time: 	17.735252857208252
	cross ent: 	1.722397911807765
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		68.89886089057646%
	l1: 		156200.765625
	p dist pair: 	1.227476716041565
Test accuracy:  0.6889886089057646
iteration: 	4
	train


1405it [09:25,  2.49it/s]


	time: 	565.5263206958771
	cross ent: 	0.020450494058837854
	cluster: 	0.005665848427037858
	separation:	0.016999189953362814
	avg separation:	0.4899963175911072
	accu: 		99.71082193304416%
	l1: 		146103.5
	p dist pair: 	1.227476716041565
	test


46it [00:18,  2.54it/s]


	time: 	18.374956130981445
	cross ent: 	1.761491931003073
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		68.74352778736625%
	l1: 		146103.5
	p dist pair: 	1.227476716041565
Test accuracy:  0.6874352778736624
iteration: 	5
	train


1405it [08:41,  2.69it/s]


	time: 	521.982168674469
	cross ent: 	0.02229824583485285
	cluster: 	0.005665862403173676
	separation:	0.016999166604675008
	avg separation:	0.4899963261818122
	accu: 		99.6924702480258%
	l1: 		136789.71875
	p dist pair: 	1.227476716041565
	test


46it [00:17,  2.64it/s]


	time: 	17.723021030426025
	cross ent: 	1.761248124682385
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		68.58819468415602%
	l1: 		136789.71875
	p dist pair: 	1.227476716041565
Test accuracy:  0.6858819468415602
iteration: 	6
	train


1405it [09:00,  2.60it/s]


	time: 	540.605010509491
	cross ent: 	0.024016307264539908
	cluster: 	0.005665794321836314
	separation:	0.016999026333575147
	avg separation:	0.48999664724085257
	accu: 		99.67856745634523%
	l1: 		128021.25
	p dist pair: 	1.227476716041565
	test


46it [00:17,  2.65it/s]


	time: 	17.61374855041504
	cross ent: 	1.7647510067276333
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		67.91508457024507%
	l1: 		128021.25
	p dist pair: 	1.227476716041565
Test accuracy:  0.6791508457024508
iteration: 	7
	train


1405it [08:40,  2.70it/s]


	time: 	520.4415755271912
	cross ent: 	0.024629724375718004
	cluster: 	0.005665851899445905
	separation:	0.016999194721999542
	avg separation:	0.48999574977732213
	accu: 		99.70136803470136%
	l1: 		119617.9453125
	p dist pair: 	1.227476716041565
	test


46it [00:17,  2.60it/s]


	time: 	17.907306432724
	cross ent: 	1.7587055836034857
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		67.69071453227477%
	l1: 		119617.9453125
	p dist pair: 	1.227476716041565
Test accuracy:  0.6769071453227476
	above 60.00%
	test


46it [00:18,  2.49it/s]


	time: 	18.63805603981018
	cross ent: 	1.7587055836034857
	cluster: 	0.022421653615310788
	separation:	0.015334570193258316
	avg separation:	0.4866242739169494
	accu: 		67.69071453227477%
	l1: 		119617.9453125
	p dist pair: 	1.227476716041565
Test accuracy:  0.6769071453227476
epoch: 	21
	joint
	train


1405it [12:40,  1.85it/s]


	time: 	760.7598741054535
	cross ent: 	0.020514712565714407
	cluster: 	0.007363760205535914
	separation:	0.02196989439799056
	avg separation:	0.5143205379039791
	accu: 		99.6385274163052%
	l1: 		119617.9453125
	p dist pair: 	1.2727779150009155
	test


46it [00:16,  2.71it/s]


	time: 	17.252188444137573
	cross ent: 	1.5993035992850428
	cluster: 	0.02432983724967293
	separation:	0.02003438151239053
	avg separation:	0.5151119575552319
	accu: 		70.38315498791854%
	l1: 		119617.9453125
	p dist pair: 	1.2727779150009155
Test accuracy:  0.7038315498791854
epoch: 	22
	joint
	train


1405it [12:49,  1.83it/s]


	time: 	769.386013507843
	cross ent: 	0.012061042207839245
	cluster: 	0.006724150413216433
	separation:	0.022818939330792088
	avg separation:	0.514905491672801
	accu: 		99.79757535313091%
	l1: 		119617.9453125
	p dist pair: 	1.2713173627853394
	test


46it [00:17,  2.66it/s]


	time: 	17.50869607925415
	cross ent: 	1.541394910086756
	cluster: 	0.024243998644954485
	separation:	0.021119763629268044
	avg separation:	0.509473873221356
	accu: 		70.19330341732827%
	l1: 		119617.9453125
	p dist pair: 	1.2713173627853394
Test accuracy:  0.7019330341732827
epoch: 	23
	joint
	train


1405it [12:53,  1.82it/s]


	time: 	773.5918669700623
	cross ent: 	0.010428066860330938
	cluster: 	0.006468707886507927
	separation:	0.023036045847573315
	avg separation:	0.5153071215568488
	accu: 		99.81147814481149%
	l1: 		119617.9453125
	p dist pair: 	1.269964337348938
	test


46it [00:16,  2.77it/s]


	time: 	16.86357617378235
	cross ent: 	1.592056541339211
	cluster: 	0.023552046704065542
	separation:	0.020023153406446396
	avg separation:	0.5171437367149021
	accu: 		69.91715567828788%
	l1: 		119617.9453125
	p dist pair: 	1.269964337348938
Test accuracy:  0.6991715567828788
epoch: 	24
	joint
	train


1405it [12:51,  1.82it/s]


	time: 	771.9182658195496
	cross ent: 	0.0030058866722825783
	cluster: 	0.005213481776111385
	separation:	0.023288502286973798
	avg separation:	0.5145283875092068
	accu: 		99.96274051829607%
	l1: 		119617.9453125
	p dist pair: 	1.266046166419983
	test


46it [00:18,  2.44it/s]


	time: 	19.081137895584106
	cross ent: 	1.4726302831069282
	cluster: 	0.021235314248453662
	separation:	0.019841372035443783
	avg separation:	0.5106389490158662
	accu: 		71.74663444943045%
	l1: 		119617.9453125
	p dist pair: 	1.266046166419983
Test accuracy:  0.7174663444943045
epoch: 	25
	joint
	train


1405it [12:57,  1.81it/s]


	time: 	777.4048755168915
	cross ent: 	0.0024745123485073006
	cluster: 	0.004558156453258627
	separation:	0.02394903534194753
	avg separation:	0.5131698849786643
	accu: 		99.96996996996998%
	l1: 		119617.9453125
	p dist pair: 	1.2623423337936401
	test


46it [00:19,  2.40it/s]


	time: 	19.309679746627808
	cross ent: 	1.5028188785781031
	cluster: 	0.021781712082093178
	separation:	0.02038401984812125
	avg separation:	0.5092129189035167
	accu: 		71.33241284086986%
	l1: 		119617.9453125
	p dist pair: 	1.2623423337936401
Test accuracy:  0.7133241284086986
epoch: 	26
	joint
	train


1405it [12:56,  1.81it/s]


	time: 	776.9197595119476
	cross ent: 	0.004710202839692021
	cluster: 	0.00502723846054268
	separation:	0.023896078111809344
	avg separation:	0.511698853418072
	accu: 		99.93382271160048%
	l1: 		119617.9453125
	p dist pair: 	1.25922691822052
	test


46it [00:18,  2.46it/s]


	time: 	18.941793203353882
	cross ent: 	1.5337617682374043
	cluster: 	0.02245438495731872
	separation:	0.02059364691376686
	avg separation:	0.5097284848275392
	accu: 		70.34863652053849%
	l1: 		119617.9453125
	p dist pair: 	1.25922691822052
Test accuracy:  0.7034863652053849
epoch: 	27
	joint
	train


1405it [12:57,  1.81it/s]


	time: 	777.8798303604126
	cross ent: 	0.005228199890741807
	cluster: 	0.0048925696776059594
	separation:	0.023805537521573997
	avg separation:	0.5104322350534256
	accu: 		99.91825158491825%
	l1: 		119617.9453125
	p dist pair: 	1.258913516998291
	test


46it [00:17,  2.68it/s]


	time: 	17.29952645301819
	cross ent: 	1.5018197272134863
	cluster: 	0.021338048269567284
	separation:	0.02035822885353928
	avg separation:	0.5096861674733784
	accu: 		70.7801173627891%
	l1: 		119617.9453125
	p dist pair: 	1.258913516998291
Test accuracy:  0.707801173627891
epoch: 	28
	joint
	train


1405it [12:56,  1.81it/s]


	time: 	777.1735219955444
	cross ent: 	0.005728962366532177
	cluster: 	0.005215982739147448
	separation:	0.023774554490300685
	avg separation:	0.5098204904179556
	accu: 		99.90824157490825%
	l1: 		119617.9453125
	p dist pair: 	1.254766583442688
	test


46it [00:17,  2.70it/s]


	time: 	17.336747646331787
	cross ent: 	1.4870114274646924
	cluster: 	0.020855757660920852
	separation:	0.020043581643182297
	avg separation:	0.5054347158774085
	accu: 		71.26337590610977%
	l1: 		119617.9453125
	p dist pair: 	1.254766583442688
Test accuracy:  0.7126337590610977
epoch: 	29
	joint
	train


814it [07:27,  1.88it/s]