In [1]:
import os
import shutil

import torch
import torch.utils.data
import pandas as pd
# import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import argparse
import re

from helpers import makedir
import model
import push
import prune
import train_and_test as tnt
import save
from log import create_logger
from preprocess import mean, std, preprocess_input_function

from bounding_box_metrics import bounding_box_overlap
from find_nearest import find_k_nearest_patches_to_prototypes
from settings import coefs
from settings import num_train_epochs, num_warm_epochs, push_start, push_epochs, push_saved_epochs

In [14]:
from settings import train_dir, test_dir, train_push_dir, \
                     train_batch_size, test_batch_size, train_push_batch_size
from settings import base_architecture, img_size, prototype_shape, num_classes, \
                     prototype_activation_function, add_on_layers_type, experiment_run

normalize = transforms.Normalize(mean=mean,
                                 std=std)
train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        normalize,
]))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True,
    num_workers=2, pin_memory=False)

# push set
train_push_dataset = datasets.ImageFolder(
    train_push_dir,
    transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
]))

train_push_loader = torch.utils.data.DataLoader(
    train_push_dataset, batch_size=train_push_batch_size, shuffle=False,
    num_workers=2, pin_memory=False)

# test set
test_dataset = datasets.ImageFolder(
    test_dir,
    transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        normalize,
]))

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False,
    num_workers=2, pin_memory=False)

In [20]:
ppnet = torch.load(r'./A3C_results/005_ppnet.pth')
ppnet_multi = ppnet.cuda()
#ppnet = torch.nn.DataParallel(ppnet)
class_specific = True

In [21]:
from settings import joint_optimizer_lrs, joint_lr_step_size
joint_optimizer_specs = \
[{'params': ppnet.module.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3}, # bias are now also being regularized
 {'params': ppnet.module.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
 {'params': ppnet.module.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=0.1)

In [22]:
model_dir = './joint_results/001/'
makedir(model_dir)
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))

In [27]:
for epoch in range(10, 20):
    print('epoch: \t{0}'.format(epoch))

    
    tnt.joint(model=ppnet_multi, log=log)
    joint_lr_scheduler.step()
    _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=joint_optimizer,
                  class_specific=class_specific, coefs=coefs, log=log)

#     accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
#                     class_specific=class_specific, log=log)
#     save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'nopush', accu=accu,
#                                 target_accu=0.70, log=log)

    if epoch >= push_start and epoch in push_epochs:
        # Only save the model in in push_saved_epochs
        if epoch in push_saved_epochs:
            bounding_box_tracker = push.push_prototypes(
                train_push_loader, # pytorch dataloader (must be unnormalized in [0,1])
                prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors
                class_specific=class_specific,
                preprocess_input_function=preprocess_input_function, # normalize if needed
                prototype_layer_stride=1,
                root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here
                epoch_number=epoch, # if not provided, prototypes saved previously will be overwritten
                prototype_img_filename_prefix=prototype_img_filename_prefix,
                prototype_self_act_filename_prefix=prototype_self_act_filename_prefix,
                proto_bound_boxes_filename_prefix=proto_bound_boxes_filename_prefix,
                save_prototype_class_identity=True,
                log=log,
                bounding_box_tracker=bounding_box_tracker)
            accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                            class_specific=class_specific, log=log)
            save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'push', accu=accu,
                                        target_accu=0.70, log=log)

            # updated bounding_box_tracker is used  
            #overlap = bounding_box_overlap(bounding_box_tracker, prototype_shape[0], num_classes)
            #overlaps["epoch"].append(epoch)
            #overlaps["overlap"].append(overlap)   

            #pd.DataFrame(overlaps).to_csv(model_dir + "overlap_" + str(epoch) + ".csv")

        else:
            bounding_box_tracker = push.push_prototypes(
                train_push_loader, # pytorch dataloader (must be unnormalized in [0,1])
                prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors
                class_specific=class_specific,
                preprocess_input_function=preprocess_input_function, # normalize if needed
                prototype_layer_stride=1,
                root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here
                epoch_number=None, # if not provided, prototypes saved previously will be overwritten
                prototype_img_filename_prefix="",
                prototype_self_act_filename_prefix="",
                proto_bound_boxes_filename_prefix=proto_bound_boxes_filename_prefix="",
                save_prototype_class_identity=True,
                log=log,
                bounding_box_tracker=bounding_box_tracker)
#             accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
#                             class_specific=class_specific, log=log)
#             save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'push', accu=accu,
#                                         target_accu=0.70, log=log)

            #overlap = bounding_box_overlap(bounding_box_tracker, prototype_shape[0], num_classes)
            #overlaps["epoch"].append(epoch)
            #overlaps["overlap"].append(overlap)

        if prototype_activation_function != 'linear':
            tnt.last_only(model=ppnet_multi, log=log)
            # Fine tune the last layers
            for i in range(8):
                log('iteration: \t{0}'.format(i))
                _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=last_layer_optimizer,
                              class_specific=class_specific, coefs=coefs, log=log)

            # Save the last with the final layer fine-tuned
            if epoch in push_saved_epochs:
                accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
                                class_specific=class_specific, log=log)
                print("Test accuracy: ", accu)
                save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + '_' + str(i) + 'push', accu=accu,
                                            target_accu=0.70, log=log)

epoch: 	10
	joint
	train


2248it [13:02,  2.87it/s]


	time: 	782.7445023059845
	cross ent: 	0.00116299423359295
	cluster: 	0.007484521454587924
	separation:	0.24203524141973448
	avg separation:	1.4643946545085873
	accu: 		99.99721944166389%
	l1: 		1273.353515625
	p dist pair: 	2.5812838077545166


NameError: name 'prototype_img_filename_prefix' is not defined

In [25]:
torch.save(ppnet_multi.module, r'./joint_results/001_ppnet_72.07.pth')

In [28]:
acc = tnt.test(model=ppnet_multi, dataloader=test_loader, class_specific=True, log=print)

	test


58it [00:16,  3.49it/s]


	time: 	16.795644998550415
	cross ent: 	1.3844326472487942
	cluster: 	0.14498716457907496
	separation:	0.20993644397320418
	avg separation:	1.4998664485997166
	accu: 		71.50500517777012%
	l1: 		1273.353515625
	p dist pair: 	2.5812838077545166
