In [1]:
from torchvision import datasets
from torchvision.transforms import transforms
from torch.utils.data import Subset, DataLoader

# from cifar10tools import test, config_model_proper, VGG, make_layers, make_model, calculate_params, get_model_filters_config

from svhn_tools import test, config_model_proper, VGG, make_layers, make_model, calculate_params, get_model_filters_config

from fvcore.nn import FlopCountAnalysis

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import pickle

# Largest tolerated accuracy drop (base-model accuracy minus pruned-model
# accuracy, in the same units as the accuracies returned by `test`).
# GeneticAlgo.assign_score gives any individual whose drop is strictly below
# this value a flat +1000 score bonus; individuals at or above it are
# penalised in proportion to their drop.
ACC_THRESH = 2

class Individual():
    """A single candidate in the genetic search: a gene vector plus its score card."""

    def __init__(self, bitstring) -> None:
        # Binary gene vector. GeneticAlgo.assign_score counts entries equal
        # to 0 as dropped filters.
        self.bitstring = bitstring
        # Evaluation results; None means "not evaluated yet".
        self.metrics = None

    def mutate(self):
        """Flip genes in place, each independently with a small probability.

        Any previously stored metrics are discarded because they no longer
        describe the mutated bitstring.
        """
        # Candidate per-gene flip probabilities; one is drawn at random on
        # every call (the list currently holds a single value).
        candidate_rates = [0.0002]
        chosen = np.random.randint(len(candidate_rates))
        rate = candidate_rates[chosen]

        self.metrics = None
        flip_mask = np.random.binomial(1, rate, len(self.bitstring))
        # XOR with a 0/1 mask toggles exactly the selected genes.
        self.bitstring ^= flip_mask
    
class GeneticAlgo():
    def __init__(self,
                 num_genes,
                 num_parents,
                 base_model,
                 ) -> None:
        self.num_genes = num_genes
        self.num_parents = num_parents
        self.population = [Individual(np.random.binomial(1, 0.99, num_genes)) for _ in range(50)]
        self.base_model = base_model
        # self.dataset_cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.Compose([
        #     transforms.ToTensor(),
        #     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # ]))
        # self.dataset_cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([
        #     transforms.ToTensor(),
        #     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # ]))

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        self.dataset_train = datasets.SVHN(root='./data', split='train', transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True)

        self.dataset_test = datasets.SVHN(root='./data', split='test', transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]), download=True)
    
    def assign_score(self, indiv, accuracy):
        metrics = {}
        accuracy = accuracy
        accuracy_drop = self.base_accuracy - accuracy
        num_filters_dropped = np.count_nonzero(indiv.bitstring == 0)
        per_filters_dropped = num_filters_dropped * 100 / self.num_genes

        if accuracy_drop < ACC_THRESH:
            score = per_filters_dropped + 1000
        else:
            score = per_filters_dropped/(accuracy_drop+10) - (accuracy_drop)*100

        metrics['score'] = score
        metrics['accuracy_drop'] = accuracy_drop
        metrics['num_filters_dropped'] = num_filters_dropped
        metrics['per_filters_dropped'] = per_filters_dropped
        metrics['accuracy'] = accuracy
        indiv.metrics = metrics
        print(metrics)
    
    def get_accuracy(self, indiv_batch):
        """Evaluate a batch of individuals in one pass over ``self.dataloader``.

        A model is configured from the base model for every individual's
        bitstring, and all of them are handed to ``test`` together.

        Returns:
            Whatever ``test`` returns for the list of models; callers index
            it by the individual's position in ``indiv_batch``.
        """
        candidate_models = []
        for member in indiv_batch:
            configured = config_model_proper(member.bitstring, self.base_model)
            candidate_models.append(configured)
        return test(self.dataloader, candidate_models)
    
    def evaluate_batches(self, batches):
        for batch in batches:
            batch_accuracies = self.get_accuracy(batch)
            for idx, individual in enumerate(batch):
                accuracy = batch_accuracies[idx]
                self.assign_score(individual, accuracy)
    
    def evaluate_population(self, population):
        eval_batches = []
        eval_batch = []
        item_counter = 0
        chunk_size = 100
        for individual in population:
            ## if the individual is not already assigned a metric
            if not individual.metrics:
                eval_batch.append(individual)
                individual.metrics = {}
                item_counter += 1
                if item_counter % chunk_size == 0:
                    eval_batches.append(eval_batch)
                    eval_batch = []
            else:
                pass
        if len(eval_batch) > 0 and len(eval_batch) < chunk_size:
            eval_batches.append(eval_batch)
        self.evaluate_batches(eval_batches)
    
    def crossover(self, parent1, parent2):
        '''Single Crossover'''
        rand_point = np.random.randint(self.num_genes)
        child1 = Individual(np.concatenate((parent1.bitstring[:rand_point], parent2.bitstring[rand_point:])))
        child2 = Individual(np.concatenate((parent2.bitstring[:rand_point], parent1.bitstring[rand_point:])))
        child1.mutate()
        child2.mutate()
        return child1, child2
    
    def get_next_population(self):
        offspring_list = list()
        sorted_score = sorted(self.population, key=lambda x: x.metrics['score'], reverse=True)[:self.num_parents]
        for idx1 in range(len(sorted_score)-1):
            for idx2 in range(idx1+1, len(sorted_score)):
                offspring_list += self.crossover(sorted_score[idx1], sorted_score[idx2])
        self.evaluate_population(offspring_list)
        self.population += offspring_list

    def run(self, save_acc_drop=[], save_fil_drop=[], save_flop_drop=[], save_params_drop=[], gen=0):
        H, W = 86, 64
        gen = gen
        save_acc_drop = save_acc_drop
        save_fil_drop = save_fil_drop
        save_flop_drop = save_flop_drop
        save_params_drop = save_params_drop

        while True:
            print("==========")
            print(f"gen:{gen}, popsize:{len(self.population)}")
            print("==========")
            subset_indices = np.arange(len(self.dataset_train))
            # np.random.shuffle(subset_indices)
            ## slice subset indices to create smaller subset
            subset = Subset(self.dataset_train, subset_indices)
            self.dataloader = DataLoader(
                subset,
                batch_size=2048, shuffle=True,
                num_workers=4, pin_memory=True)
            self.base_accuracy = test(self.dataloader, [self.base_model])[0]

            self.evaluate_population(self.population)
            sorted_by_score = sorted(self.population, key=lambda x: x.metrics['score'], reverse=True)[:300]
            del self.population
            self.population = sorted_by_score

            print(f"GEN:{gen}::@@@@@@@@@@@@@@@@@@@@")
            print(sorted_by_score[0].metrics)
            
            print(sorted_by_score[1].metrics)
            print(sorted_by_score[2].metrics)
            print(sorted_by_score[3].metrics)
            print(sorted_by_score[4].metrics)
            print(f"{gen}::@@@@@@@@@@@@@@@@@@@@@")
            #-----------
            indiv0 = np.array(sorted_by_score[0].bitstring).reshape(H,W)
            indiv1 = np.array(sorted_by_score[1].bitstring).reshape(H,W)
            indiv2 = np.array(sorted_by_score[2].bitstring).reshape(H,W)
            indiv3 = np.array(sorted_by_score[3].bitstring).reshape(H,W)
            indiv4 = np.array(sorted_by_score[4].bitstring).reshape(H,W)

            for i, indiv in [('score0', indiv0), ('score1', indiv1), ('score2', indiv2), ('score3', indiv3), ('score4', indiv4)]:
                img = np.zeros((H*4,W*4,3), np.uint8)
                for x in range(indiv.shape[0]):
                    for y in range(indiv.shape[1]):
                        if indiv[x,y] == 1:
                            img[x*4:x*4+4,y*4:y*4+4] = [255,0,0]
                        else:
                            img[x*4:x*4+4,y*4:y*4+4] = [0,255,0]

                cv2.imwrite(f'./debug/visualize/this_{i}.jpg',img)
            #-----------
            
            self.get_next_population()
            # for p in self.population:
            #     p.metrics = None
            SAVE_STEP = 2
            if gen % SAVE_STEP == 0:
                ## SAVE THE PLOTS AFTER EVERY 2 iterations
                ##------------------------
                ## SAVE THE ACC DROP AND FILTERS DROPPED
                fig, axs = plt.subplots(4, 1, figsize=(15,15), dpi=400)
                axs[0].plot(np.arange(0,len(save_acc_drop)*SAVE_STEP,SAVE_STEP), save_acc_drop, linestyle='--', marker='o', color='b')
                axs[0].set_xlabel('generations')
                axs[0].set_ylabel('accuracy drop %')
                
                axs[1].plot(np.arange(0,len(save_fil_drop)*SAVE_STEP,SAVE_STEP), save_fil_drop, linestyle='--', marker='o', color='b')
                axs[1].set_xlabel('generations')
                axs[1].set_ylabel('filters dropped %')

                save_acc_drop.append(sorted_by_score[0].metrics['accuracy_drop'])
                save_fil_drop.append(sorted_by_score[0].metrics['per_filters_dropped'])

                ##------------------------
                ## SAVE THE FLOP DROP

                pruned_model = make_model(sorted_by_score[0].bitstring, self.base_model)

                random_input = self.dataset_test[np.random.randint(len(self.dataset_test))][0]
                random_input = random_input.unsqueeze(0)
                random_input = random_input.to('cuda')

                base_flops_analysis = FlopCountAnalysis(self.base_model, random_input)
                pruned_flops_analysis = FlopCountAnalysis(pruned_model, random_input)

                base_flops = base_flops_analysis.total()
                pruned_flops = pruned_flops_analysis.total()

                # print('BASE   FLOPS:',base_flops)
                # print('PRUNED FLOPS:',pruned_flops)

                # print('FLOP DROP:', f'{(base_flops - pruned_flops)*100/base_flops}%')

                save_flop_drop.append((base_flops - pruned_flops)*100/base_flops)

                axs[2].plot(np.arange(0,len(save_flop_drop)*SAVE_STEP,SAVE_STEP), save_flop_drop, linestyle='--', marker='o', color='b')
                axs[2].set_xlabel('generations')
                axs[2].set_ylabel('FLOP dropped %')

                ##------------------------
                ## SAVE THE PARAM DROP

                base_params_count = calculate_params(self.base_model)
                pruned_params_count = calculate_params(pruned_model)
                # print((base_params_count - pruned_params_count)*100 / base_params_count)
                save_params_drop.append((base_params_count - pruned_params_count)*100 / base_params_count)

                axs[3].plot(np.arange(0,len(save_params_drop)*SAVE_STEP,SAVE_STEP), save_params_drop, linestyle='--', marker='o', color='b')
                axs[3].set_xlabel('generations')
                axs[3].set_ylabel('Parameters dropped %')
            
                plt.savefig('./debug/save_svhn.png', bbox_inches='tight')
                plt.close()

                pickle.dump(self.population, open('./checkpoints/pruned_population_ckpt/last_svhn_gen.p', 'wb'))
                pickle.dump((save_acc_drop,save_fil_drop,save_flop_drop,save_params_drop), open('./checkpoints/pruned_plot_ckpt/last_svhn_plot.p', 'wb'))

            gen += 1

In [2]:
# pruned_model, indiv = pickle.load(open('./models/pruned_models/pruned_vgg16_cifar10_model', 'rb'))

In [3]:
# Layer plan handed to make_layers. The integer entries sum to 5504, which is
# the num_genes passed to GeneticAlgo below.
# NOTE(review): presumably integers are conv output-channel counts and 'M'
# marks a pooling stage -- confirm against make_layers in svhn_tools.
VGG19_LAYER_CFG = [
    64, 64, 'M',
    128, 128, 'M',
    256, 256, 256, 256, 'M',
    512, 512, 512, 512, 'M',
    512, 512, 512, 512, 'M',
]

base_model = VGG(make_layers(VGG19_LAYER_CFG))

# torch.load unpickles the file: only load checkpoints from a trusted source.
base_model.load_state_dict(torch.load('models/vgg19_svhn_base_model'))
base_model = base_model.to('cuda')

myGA = GeneticAlgo(num_genes=5504, num_parents=10, base_model=base_model)

# To resume from backed-up checkpoints instead of starting fresh, use:
# acc_drop, fil_drop, flop_drop, params_drop = pickle.load(open('./checkpoints/pruned_plot_ckpt/last_plot_bckp.p','rb'))
# myGA.population = pickle.load(open('./checkpoints/pruned_population_ckpt/last_gen_bckp.p','rb'))
# gen = int(open('./checkpoints/last_gen','r').read())
# myGA.run(save_acc_drop=acc_drop, save_fil_drop=fil_drop, save_flop_drop=flop_drop, save_params_drop=params_drop, gen=gen)
# NOTE(review): run() as written here does not create './checkpoints/last_gen'
# or the '*_bckp.p' files; they have to be produced separately.

myGA.run()

Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat
gen:0, popsize:50
{'score': -221.19247276422587, 'accuracy_drop': 2.2127578251907636, 'num_filters_dropped': 56, 'per_filters_dropped': 1.0174418604651163, 'accuracy': 96.38259824999659}
{'score': 1001.0537790697674, 'accuracy_drop': 1.1534733882086385, 'num_filters_dropped': 58, 'per_filters_dropped': 1.0537790697674418, 'accuracy': 97.44188268697872}
{'score': -403.0288963248117, 'accuracy_drop': 4.031014101041535, 'num_filters_dropped': 56, 'per_filters_dropped': 1.0174418604651163, 'accuracy': 94.56434197414582}
{'score': -273.75489345364826, 'accuracy_drop': 2.738304871889369, 'num_filters_dropped': 53, 'per_filters_dropped': 0.9629360465116279, 'accuracy': 95.85705120329798}
{'score': 1000.9265988372093, 'accuracy_drop': 1.579371254624121, 'num_filters_dropped': 51, 'per_filters_dropped': 0.9265988372093024, 'accuracy': 97.01598482056323}
{'score': -629.77053670931