In [20]:
import os
import shutil

import torch
import torch.utils.data
import pandas as pd
# import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import argparse
import re

from helpers import makedir
import model
import push
import prune
import train_and_test as tnt
import save
from log import create_logger
from preprocess import mean, std, preprocess_input_function

from bounding_box_metrics import bounding_box_overlap
from find_nearest import find_k_nearest_patches_to_prototypes
from settings import coefs
from settings import num_train_epochs, num_warm_epochs, push_start, push_epochs, push_saved_epochs

In [25]:
# Data pipeline + last-layer optimizer.
# NOTE(review): this cell reads `ppnet`, which is created further down (the
# checkpoint-loading cell or the "Start from scratch" cell). On a fresh kernel,
# run one of those first, otherwise this cell raises NameError.
from settings import train_dir, test_dir, train_push_dir, img_size
from settings import last_layer_optimizer_lr

# Optimizer used for last-layer-only fine-tuning after each prototype push.
last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

# Batch sizes for the loaders below. They are defined here so this cell does not
# depend on a later cell; the values match the ones the optimizer cell sets.
train_batch_size, test_batch_size, train_push_batch_size = 80, 80, 80

normalize = transforms.Normalize(mean=mean, std=std)


def _image_loader(root, batch_size, shuffle, normalized):
    """Build an ImageFolder dataset over `root` and a DataLoader for it.

    Images are resized to (img_size, img_size) and converted with ToTensor.
    If `normalized` is True, the shared `normalize` transform is appended.
    The push set must stay unnormalized in [0, 1] (see the push_prototypes
    call in the training loop), so it is built with normalized=False.

    Returns:
        (dataset, loader) tuple.
    """
    steps = [transforms.Resize(size=(img_size, img_size)), transforms.ToTensor()]
    if normalized:
        steps.append(normalize)
    dataset = datasets.ImageFolder(root, transforms.Compose(steps))
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=1, pin_memory=False)
    return dataset, loader


# train set (shuffled, normalized)
train_dataset, train_loader = _image_loader(
    train_dir, train_batch_size, shuffle=True, normalized=True)

# push set (unnormalized, fixed order)
train_push_dataset, train_push_loader = _image_loader(
    train_push_dir, train_push_batch_size, shuffle=False, normalized=False)

# test set (normalized, fixed order)
test_dataset, test_loader = _image_loader(
    test_dir, test_batch_size, shuffle=False, normalized=True)

In [23]:
# Resume from a saved, pruned checkpoint and wrap it for (multi-)GPU execution.
# NOTE(review): torch.load unpickles a whole nn.Module here (it is wrapped in
# DataParallel below), so only load files you trust. On torch >= 2.6 the default
# weights_only=True may reject such a file; weights_only=False would then be needed.
ppnet = torch.load(r'./A3C_results/joint_004_pruned_10_2_0.7262')
ppnet_multi = torch.nn.DataParallel(ppnet)
# .cuda() on the wrapper moves the wrapped module in place, so `ppnet` is on the GPU too.
ppnet_multi = ppnet_multi.cuda()
class_specific = True

In [24]:
from settings import joint_optimizer_lrs, joint_lr_step_size
# Joint-phase optimizer for the loaded checkpoint.
# NOTE(review): the learning rates and step size below are hard-coded; the
# joint_optimizer_lrs / joint_lr_step_size values imported above are not used.
# Every group (biases included) gets weight_decay=1e-4.
joint_optimizer_specs = [
    {'params': group_params, 'lr': group_lr, 'weight_decay': 1e-4}
    for group_params, group_lr in (
        (ppnet_multi.module.features.parameters(), 1e-3),
        (ppnet_multi.module.add_on_layers.parameters(), 3e-3),
        (ppnet_multi.module.prototype_vectors, 3e-3),
    )
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)
# Halve every group's learning rate once per 5 scheduler steps.
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=5, gamma=0.5)
# Batch sizes read by the data-loader cell.
train_batch_size = test_batch_size = train_push_batch_size = 80

In [17]:
# Start from scratch: build a fresh ProtoPNet and (re)create every optimizer
# that holds references to its parameters.
from settings import joint_optimizer_lrs, joint_lr_step_size, last_layer_optimizer_lr
from settings import base_architecture, img_size, prototype_shape, num_classes, \
                     prototype_activation_function, add_on_layers_type, experiment_run

# Local overrides of the values imported from settings above.
base_architecture = 'resnet34'
prototype_shape = (2000, 256, 1, 1)

ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)
ppnet = ppnet.cuda()
ppnet_multi = torch.nn.DataParallel(ppnet)
class_specific = True

joint_optimizer_specs = \
[{'params': ppnet_multi.module.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3}, # bias are now also being regularized
 {'params': ppnet_multi.module.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
 {'params': ppnet_multi.module.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=0.5)

# The last-layer optimizer must reference *this* network's last layer. An
# optimizer built earlier for a previously created/loaded `ppnet` would
# otherwise keep stepping the old model during last-layer fine-tuning.
last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

In [26]:
# Run directory: the training loop saves checkpoints into model_dir, and the
# logger is pointed at <model_dir>/train.log.
model_dir = './joint_results/004_pruned/'
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))

# Epochs at which prototypes are pushed; push_saved_epochs additionally get a
# post-push test + save. Both override the values imported from settings.
# NOTE(review): the training loop runs 60 epochs, so epoch 99 is never reached.
push_epochs = [20, 50, 99]
push_saved_epochs = push_epochs.copy()

In [27]:
# Training loop:
#   - warm-up epochs (tnt.warm_only), then joint training (tnt.joint);
#   - at epochs in push_epochs (and >= push_start): push prototypes, then
#     fine-tune the last layer if the activation function is not 'linear';
#   - every epoch ends with a test pass and a conditional checkpoint save.
# NOTE(review): the recorded run hit CUDA OOM on the first joint epoch (when the
# backbone is unfrozen) with batch size 80. Lower train_batch_size if this recurs.
NUM_EPOCHS = 60            # NOTE(review): push epoch 99 in push_epochs is unreachable
NUM_WARM_EPOCHS = 5        # hard-coded; settings.num_warm_epochs is imported but unused
NUM_LAST_LAYER_ITERS = 8
TARGET_ACCU = 0.74         # accuracy threshold passed to save.save_model_w_condition

for epoch in range(NUM_EPOCHS):
    print('epoch: \t{0}'.format(epoch))

    if epoch < NUM_WARM_EPOCHS:
        tnt.warm_only(model=ppnet_multi, log=log)
    else:
        tnt.joint(model=ppnet_multi, log=log)

    # joint_optimizer is used in both phases.
    # NOTE(review): the scheduler also steps during warm-up epochs, so the joint
    # learning rates may already have decayed when joint training starts.
    _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=joint_optimizer,
                  class_specific=class_specific, coefs=coefs, log=log)
    joint_lr_scheduler.step()

    if epoch >= push_start and epoch in push_epochs:
        # The push itself is identical for every push epoch. Only the post-push
        # test/save below depends on push_saved_epochs.
        bounding_box_tracker = push.push_prototypes(
            train_push_loader, # pytorch dataloader (must be unnormalized in [0,1])
            prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors
            class_specific=class_specific,
            preprocess_input_function=preprocess_input_function, # normalize if needed
            prototype_layer_stride=1,
            root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here
            epoch_number=epoch, # if not provided, prototypes saved previously will be overwritten
            prototype_img_filename_prefix=None,
            prototype_self_act_filename_prefix=None,
            proto_bound_boxes_filename_prefix=None,
            save_prototype_class_identity=True,
            log=log,
            bounding_box_tracker=None)
        # TODO: bounding_box_tracker is currently unused; the bounding_box_overlap
        # analysis is disabled.

        # Only evaluate and save the freshly pushed model in push_saved_epochs.
        if epoch in push_saved_epochs:
            accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                            class_specific=class_specific, log=log)
            save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'push', accu=accu,
                                        target_accu=TARGET_ACCU, log=log)

        if prototype_activation_function != 'linear':
            # Fine-tune only the last layer after the push, testing after each iteration.
            tnt.last_only(model=ppnet_multi, log=log)
            for i in range(NUM_LAST_LAYER_ITERS):
                log('iteration: \t{0}'.format(i))
                _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=last_layer_optimizer,
                              class_specific=class_specific, coefs=coefs, log=log)
                accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                                class_specific=class_specific, log=log)
                print("Test accuracy: ", accu)
                save.save_model_w_condition(model=ppnet, model_dir=model_dir,
                                            model_name=str(epoch) + '_' + str(i) + 'push', accu=accu,
                                            target_accu=TARGET_ACCU, log=log)

    # End-of-epoch evaluation and conditional save.
    accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                    class_specific=class_specific, log=log)
    print("Test accuracy: ", accu)
    save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch), accu=accu,
                                target_accu=TARGET_ACCU, log=log)

epoch: 	0
	warm
	train


1405it [46:50,  2.00s/it]


	time: 	2810.54070019722
	cross ent: 	0.009459355266021525
	cluster: 	0.0085663544610673
	separation:	0.10174093587967001
	avg separation:	1.211847614987465
	accu: 		99.87209431653876%
	l1: 		1154.874755859375
	p dist pair: 	2.073502779006958
	test


46it [00:40,  1.12it/s]


	time: 	41.017781496047974
	cross ent: 	1.278269174306289
	cluster: 	0.09025181101068207
	separation:	0.09849840807526009
	avg separation:	1.2947259985882302
	accu: 		72.10907835692095%
	l1: 		1154.874755859375
	p dist pair: 	2.073502779006958
Test accuracy:  0.7210907835692095
epoch: 	1
	warm
	train


1405it [21:32,  1.09it/s]


	time: 	1292.314367055893
	cross ent: 	0.010019275824724965
	cluster: 	0.009936128923291626
	separation:	0.11783326220469967
	avg separation:	1.3835304765938865
	accu: 		99.93882771660549%
	l1: 		1154.874755859375
	p dist pair: 	2.2729530334472656
	test


46it [00:41,  1.12it/s]


	time: 	41.22726225852966
	cross ent: 	1.3041704711706743
	cluster: 	0.09255143279290717
	separation:	0.10125618282219638
	avg separation:	1.5050663948059082
	accu: 		71.29789437348981%
	l1: 		1154.874755859375
	p dist pair: 	2.2729530334472656
Test accuracy:  0.7129789437348981
epoch: 	2
	warm
	train


1405it [21:30,  1.09it/s]


	time: 	1290.8692982196808
	cross ent: 	0.012403770076985938
	cluster: 	0.009917787496941794
	separation:	0.12100571819259603
	avg separation:	1.6611461922791504
	accu: 		99.93771549327106%
	l1: 		1154.874755859375
	p dist pair: 	2.7793469429016113
	test


46it [00:41,  1.11it/s]


	time: 	41.419177770614624
	cross ent: 	1.3147545858569767
	cluster: 	0.09385850264326386
	separation:	0.10259575558745343
	avg separation:	1.8790126131928486
	accu: 		71.12530203658957%
	l1: 		1154.874755859375
	p dist pair: 	2.7793469429016113
Test accuracy:  0.7112530203658958
epoch: 	3
	warm
	train


1405it [21:33,  1.09it/s]


	time: 	1293.34783244133
	cross ent: 	0.013945816727388266
	cluster: 	0.00965549912555574
	separation:	0.12219147120082081
	avg separation:	2.098136682612192
	accu: 		99.93382271160048%
	l1: 		1154.874755859375
	p dist pair: 	3.5718705654144287
	test


46it [00:41,  1.11it/s]


	time: 	41.494428396224976
	cross ent: 	1.333665167507918
	cluster: 	0.09484866998441842
	separation:	0.10328083760712457
	avg separation:	2.4607459410377173
	accu: 		70.69382119433897%
	l1: 		1154.874755859375
	p dist pair: 	3.5718705654144287
Test accuracy:  0.7069382119433897
epoch: 	4
	warm
	train


1405it [21:34,  1.09it/s]


	time: 	1294.4294912815094
	cross ent: 	0.014486726229924219
	cluster: 	0.00924782876150676
	separation:	0.12227709598812768
	avg separation:	2.6826651289793944
	accu: 		99.93715938160382%
	l1: 		1154.874755859375
	p dist pair: 	4.46205997467041
	test


46it [00:42,  1.09it/s]


	time: 	42.38503456115723
	cross ent: 	1.338297135156134
	cluster: 	0.09534132869347282
	separation:	0.10340326668127724
	avg separation:	3.1150668807651685
	accu: 		70.71108042802899%
	l1: 		1154.874755859375
	p dist pair: 	4.46205997467041
Test accuracy:  0.7071108042802899
epoch: 	5
	joint
	train


0it [00:02, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 784.00 MiB (GPU 0; 11.90 GiB total capacity; 5.14 GiB already allocated; 175.69 MiB free; 5.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [37]:
# Save the unwrapped model object (not just a state_dict). Loading it back later
# with torch.load therefore unpickles the full nn.Module.
# The target directory is not created by any other cell (model_dir points at
# 004_pruned), so create it first; otherwise torch.save fails on a missing path.
checkpoint_path = r'./joint_results/005/first_50.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(ppnet_multi.module, checkpoint_path)

In [41]:
# Re-evaluate the current in-memory model on the test set, logging to stdout.
acc = tnt.test(
    model=ppnet_multi,
    dataloader=test_loader,
    class_specific=True,
    log=print,
)

	test


46it [00:18,  2.47it/s]


	time: 	18.917237520217896
	cross ent: 	2.556407788525457
	cluster: 	6.45002179560454
	separation:	0.5268320557863816
	avg separation:	20.61585994388746
	accu: 		60.303762512944424%
	l1: 		1331.3865966796875
	p dist pair: 	20.12684440612793
