In [15]:
import os
import shutil

import torch
import torch.utils.data
import pandas as pd
# import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import argparse
import re

from helpers import makedir
import model
import push
import prune
import train_and_test as tnt
import save
from log import create_logger
from preprocess import mean, std, preprocess_input_function

from bounding_box_metrics import bounding_box_overlap
from find_nearest import find_k_nearest_patches_to_prototypes
from settings import coefs, img_size
from settings import num_train_epochs, num_warm_epochs, push_start, push_epochs, push_saved_epochs

In [14]:
from settings import train_dir, test_dir, train_push_dir
from settings import last_layer_optimizer_lr

# Optimizer used for last-layer-only fine-tuning after each prototype push.
# NOTE(review): `ppnet` is only defined by a later cell, so this raises NameError
# under Restart & Run All. The optimizer also stays bound to whichever `ppnet`
# existed when this cell ran; rebuild it whenever the model is replaced.
last_layer_optimizer_specs = [
    {'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr},
]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

train_batch_size = test_batch_size = train_push_batch_size = 120
push_epochs = [20, 50, 70, 100]
push_saved_epochs = [20, 50, 70, 100]

# Channel normalization with the statistics exported by `preprocess`.
normalize = transforms.Normalize(mean=mean, std=std)


def _image_folder(root, normalized):
    """Build an ImageFolder dataset rooted at `root`.

    Images are resized to (img_size, img_size) and converted to tensors. They are
    then passed through `normalize` only when `normalized` is true.
    """
    pipeline = [transforms.Resize(size=(img_size, img_size)), transforms.ToTensor()]
    if normalized:
        pipeline.append(normalize)
    return datasets.ImageFolder(root, transforms.Compose(pipeline))


def _loader(dataset, batch_size, shuffle):
    """Wrap `dataset` in a DataLoader using the worker settings shared by every split."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=2, pin_memory=False)


# Training split: normalized and shuffled.
train_dataset = _image_folder(train_dir, normalized=True)
train_loader = _loader(train_dataset, train_batch_size, shuffle=True)

# Push split: kept unnormalized in [0, 1]. The push step normalizes its inputs
# itself through preprocess_input_function.
train_push_dataset = _image_folder(train_push_dir, normalized=False)
train_push_loader = _loader(train_push_dataset, train_push_batch_size, shuffle=False)

# Test split: normalized, fixed order.
test_dataset = _image_folder(test_dir, normalized=True)
test_loader = _loader(test_dataset, test_batch_size, shuffle=False)

In [8]:
# Resume from a checkpoint stored under saved_models/resnet34. The loaded object is
# used directly as a module below, so the file holds a whole pickled network
# rather than a state_dict.
# NOTE(review): from PyTorch 2.6, torch.load defaults to weights_only=True, which
# rejects pickled modules; on newer versions pass weights_only=False (trusted files only).
checkpoint_path = r'/scratch/users/jiaxun1218/saved_models/resnet34/r2_014.pth'
ppnet = torch.load(checkpoint_path)
# Module.cuda() moves parameters in place, so `ppnet` and `ppnet_multi.module`
# are the same GPU-resident network.
ppnet_multi = torch.nn.DataParallel(ppnet).cuda()
class_specific = True

In [10]:
# Joint optimizer for fine-tuning the loaded checkpoint. Every trainable group uses
# the same hand-picked learning rate and weight decay; the values imported from
# settings are not used by this cell.
from settings import joint_optimizer_lrs, joint_lr_step_size
_net = ppnet_multi.module
joint_optimizer_specs = [
    {'params': group, 'lr': 1e-4, 'weight_decay': 1e-4}  # bias are now also being regularized
    for group in (_net.features.parameters(),
                  _net.add_on_layers.parameters(),
                  _net.prototype_vectors)
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)
# Multiply every group's learning rate by 0.1 once every 3 scheduler steps.
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=3, gamma=0.1)
# NOTE(review): the DataLoaders were already built with their own batch size;
# rebinding these names does not change the batch size those loaders use.
train_batch_size = test_batch_size = train_push_batch_size = 128
prototype_activation_function = 'log'

In [17]:
# Start from scratch: build a fresh ProtoPNet from the settings module (with the
# architecture and prototype shape overridden below), wrap it for multi-GPU
# training, and create every optimizer that must track this new model.
from settings import joint_optimizer_lrs, joint_lr_step_size, last_layer_optimizer_lr
from settings import base_architecture, img_size, prototype_shape, num_classes, \
                     prototype_activation_function, add_on_layers_type, experiment_run
base_architecture = 'resnet34'
prototype_shape = (2000, 256, 1, 1)
ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)
ppnet = ppnet.cuda()
ppnet_multi = torch.nn.DataParallel(ppnet)
class_specific = True

# Joint phase: backbone and add-on layers are weight-decayed (biases included);
# the prototype vectors are not.
joint_optimizer_specs = [
    {'params': ppnet_multi.module.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3},
    {'params': ppnet_multi.module.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
    {'params': ppnet_multi.module.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=0.5)

# Rebuild the last-layer optimizer for THIS model. Any previously created
# last_layer_optimizer holds the old network's last_layer parameters. Reusing it
# would step parameters that no longer take part in training, so last-layer
# fine-tuning would silently do nothing.
last_layer_optimizer_specs = [
    {'params': ppnet_multi.module.last_layer.parameters(), 'lr': last_layer_optimizer_lr},
]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)

# NOTE(review): this overrides the settings value AFTER the model above was built
# with it. The training loop's `!= 'linear'` check reads this variable, so the two
# can disagree if settings does not use 'log' — confirm intent.
prototype_activation_function = 'log'

In [11]:
# Output directory for this run's checkpoints and training log.
model_dir = '/scratch/users/jiaxun1218/saved_models/resnet34/'
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))

# Epochs at which prototypes are pushed, and the epochs after which the pushed
# model is tested and saved (here the same epochs, kept as a separate list).
# NOTE(review): 99 exceeds the 60-epoch training loop, so it is never reached.
push_epochs = [30, 50, 99]
push_saved_epochs = list(push_epochs)

In [None]:
NUM_EPOCHS = 60  # NOTE(review): settings.num_train_epochs is imported but not used here
LAST_LAYER_ITERATIONS = 8


def _push_prototypes(epoch):
    """Run the prototype push step for `epoch` and return push.push_prototypes' result.

    Pushes using the unnormalized push loader. No directory or filename prefixes
    are given for saving prototypes.
    """
    return push.push_prototypes(
        train_push_loader,  # pytorch dataloader (must be unnormalized in [0,1])
        prototype_network_parallel=ppnet_multi,  # pytorch network with prototype_vectors
        class_specific=class_specific,
        preprocess_input_function=preprocess_input_function,  # normalize if needed
        prototype_layer_stride=1,
        root_dir_for_saving_prototypes=None,  # if not None, prototypes will be saved here
        epoch_number=epoch,  # if not provided, prototypes saved previously will be overwritten
        prototype_img_filename_prefix=None,
        prototype_self_act_filename_prefix=None,
        proto_bound_boxes_filename_prefix=None,
        save_prototype_class_identity=True,
        log=log,
        bounding_box_tracker=None)


def _evaluate_and_maybe_save(model_name, target_accu, report=True):
    """Test ppnet_multi, optionally print the accuracy, then hand `ppnet` to save_model_w_condition.

    Returns the test accuracy. The print (when `report` is true) happens between
    testing and saving, matching the original output order.
    """
    accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                    class_specific=class_specific, log=log)
    if report:
        print("Test accuracy: ", accu)
    save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=model_name, accu=accu,
                                target_accu=target_accu, log=log)
    return accu


for epoch in range(NUM_EPOCHS):
    print('epoch: \t{0}'.format(epoch))

    # Joint phase every epoch (warm-up phases are disabled for this run).
    tnt.joint(model=ppnet_multi, log=log)
    _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=joint_optimizer,
                  class_specific=class_specific, coefs=coefs, log=log)
    joint_lr_scheduler.step()

    if epoch >= push_start and epoch in push_epochs:
        # The returned tracker was used by the (currently disabled) bounding-box
        # overlap analysis; it is kept so that analysis can be re-enabled.
        bounding_box_tracker = _push_prototypes(epoch)
        # Only test/save the freshly pushed model in push_saved_epochs.
        if epoch in push_saved_epochs:
            accu = _evaluate_and_maybe_save(str(epoch) + 'push', target_accu=0.75, report=False)

        if prototype_activation_function != 'linear':
            # Fine-tune the last layer. tnt.last_only presumably leaves only the
            # last layer trainable; the next epoch's tnt.joint call precedes the
            # next joint training pass.
            tnt.last_only(model=ppnet_multi, log=log)
            for i in range(LAST_LAYER_ITERATIONS):
                log('iteration: \t{0}'.format(i))
                _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=last_layer_optimizer,
                              class_specific=class_specific, coefs=coefs, log=log)
                accu = _evaluate_and_maybe_save(str(epoch) + '_' + str(i) + 'push', target_accu=0.77)

    # End-of-epoch evaluation/checkpoint.
    accu = _evaluate_and_maybe_save(str(epoch), target_accu=0.77)

epoch: 	0
	joint
	train


1499it [08:18,  3.01it/s]


	time: 	499.2157759666443
	cross ent: 	0.09301090578494557
	cluster: 	0.0328410022287537
	separation:	0.1893728561450872
	avg separation:	4.103647775535507
	accu: 		97.55255255255256%
	l1: 		201000.0
	p dist pair: 	7.582414150238037
	test


49it [00:15,  3.17it/s]


	time: 	15.615715026855469
	cross ent: 	1.5259946499552046
	cluster: 	0.11866801422165364
	separation:	0.15300414848084354
	avg separation:	3.967268778353321
	accu: 		70.22782188470832%
	l1: 		201000.0
	p dist pair: 	7.582414150238037
Test accuracy:  0.7022782188470832
epoch: 	1
	joint
	train


1499it [08:28,  2.95it/s]


	time: 	508.9400553703308
	cross ent: 	0.05910449005917194
	cluster: 	0.025506318680410228
	separation:	0.18606568195964593
	avg separation:	3.9403770701578256
	accu: 		98.42676009342675%
	l1: 		201000.0
	p dist pair: 	7.264782905578613
	test


49it [00:15,  3.20it/s]


	time: 	15.691144466400146
	cross ent: 	1.4428664664832913
	cluster: 	0.11277881105031286
	separation:	0.1569143927523068
	avg separation:	3.808889320918492
	accu: 		72.38522609596134%
	l1: 		201000.0
	p dist pair: 	7.264782905578613
Test accuracy:  0.7238522609596134
epoch: 	2
	joint
	train


1499it [08:18,  3.01it/s]


	time: 	498.98002314567566
	cross ent: 	0.0470308025474615
	cluster: 	0.02162777987110607
	separation:	0.19038645950971722
	avg separation:	3.572617657905105
	accu: 		98.76988099210321%
	l1: 		201000.0
	p dist pair: 	7.464593887329102
	test


49it [00:15,  3.19it/s]


	time: 	15.639282941818237
	cross ent: 	1.434429293992568
	cluster: 	0.10671873763203621
	separation:	0.15109529483075046
	avg separation:	3.00745989838425
	accu: 		71.93648602002071%
	l1: 		201000.0
	p dist pair: 	7.464593887329102
Test accuracy:  0.7193648602002071
epoch: 	3
	joint
	train


1499it [08:22,  2.98it/s]


	time: 	502.66923451423645
	cross ent: 	0.005579702288379001
	cluster: 	0.012881534829551853
	separation:	0.21432932439925592
	avg separation:	3.1121532419190716
	accu: 		99.87987987987988%
	l1: 		201000.0
	p dist pair: 	7.515518665313721
	test


49it [00:15,  3.13it/s]


	time: 	15.856647729873657
	cross ent: 	1.2371226123401098
	cluster: 	0.10272008439107817
	separation:	0.18674197367259435
	avg separation:	3.1773139214029116
	accu: 		76.13047980669658%
	l1: 		201000.0
	p dist pair: 	7.515518665313721
Test accuracy:  0.7613047980669658
epoch: 	4
	joint
	train


1499it [08:27,  2.95it/s]


	time: 	507.51051330566406
	cross ent: 	0.0007719358820117544
	cluster: 	0.009478704108259553
	separation:	0.2669762259845021
	avg separation:	3.38395884643005
	accu: 		99.994994994995%
	l1: 		201000.0
	p dist pair: 	7.5808210372924805
	test


49it [00:15,  3.16it/s]


	time: 	15.738267660140991
	cross ent: 	1.2639934378010886
	cluster: 	0.12754193092791402
	separation:	0.239099633632874
	avg separation:	3.606801378483675
	accu: 		76.45840524680703%
	l1: 		201000.0
	p dist pair: 	7.5808210372924805
Test accuracy:  0.7645840524680704
epoch: 	5
	joint
	train


1499it [08:30,  2.94it/s]


	time: 	510.78336215019226
	cross ent: 	0.0005103944300539275
	cluster: 	0.007808296216789725
	separation:	0.3311720712730454
	avg separation:	3.8892710870548117
	accu: 		99.99721944166389%
	l1: 		201000.0
	p dist pair: 	7.619817733764648
	test


49it [00:15,  3.17it/s]


	time: 	15.75843358039856
	cross ent: 	1.3494512402281469
	cluster: 	0.18452609756163188
	separation:	0.32843781171404585
	avg separation:	4.153429469283746
	accu: 		74.88781498101484%
	l1: 		201000.0
	p dist pair: 	7.619817733764648
Test accuracy:  0.7488781498101484
epoch: 	6
	joint
	train


1499it [08:44,  2.86it/s]


	time: 	524.2636785507202
	cross ent: 	0.0002955712004160239
	cluster: 	0.006843575441388487
	separation:	0.3738470971783135
	avg separation:	4.184563363210769
	accu: 		99.99833166499833%
	l1: 		201000.0
	p dist pair: 	7.620518207550049
	test


49it [00:15,  3.09it/s]


	time: 	16.080308198928833
	cross ent: 	1.3717252539128673
	cluster: 	0.20051757778440202
	separation:	0.3497306993421243
	avg separation:	4.244123624295605
	accu: 		74.85329651363479%
	l1: 		201000.0
	p dist pair: 	7.620518207550049
Test accuracy:  0.7485329651363479
epoch: 	7
	joint
	train


878it [04:52,  2.70it/s]

In [9]:
# Save the unwrapped network (not the DataParallel wrapper) as a whole pickled module.
# Use the absolute /scratch path, like every other path in this notebook. A leading
# './' would resolve against the kernel's working directory instead.
# NOTE(review): the directory says densenet121, but the model built/loaded in this
# notebook is resnet34 — confirm the intended destination.
final_model_path = os.path.join('/scratch/users/jiaxun1218/saved_models/densenet121', 'r3_0.75.pth')
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(final_model_path), exist_ok=True)
torch.save(ppnet_multi.module, final_model_path)

In [10]:
# Evaluate the current model on the test split, writing the metrics to stdout.
acc = tnt.test(
    model=ppnet_multi,
    dataloader=test_loader,
    class_specific=True,
    log=print,
)

	test


46it [00:17,  2.67it/s]


	time: 	17.44526433944702
	cross ent: 	1.3075031968562498
	cluster: 	0.055004217135517494
	separation:	0.09643075874318248
	avg separation:	3.4655522833699766
	accu: 		74.05937176389368%
	l1: 		201000.0
	p dist pair: 	5.022886276245117
