In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import torchnet as tnt
from torch.autograd import Variable
from tqdm import tqdm
import argparse

import numpy as np
import collections
from PIL import Image
import json

In [None]:
class Expert(nn.Module):
    """A single expert: a torchvision backbone whose last layer is resized.

    Args:
        model_type: Either ``"resnet18"`` or ``"vgg16"``; anything else is
            rejected with ``ValueError``.
        num_classes: Output width of the replacement final linear layer.

    The backbone is constructed with ``pretrained=True`` and stored on
    ``self.model``.
    """

    def __init__(self, model_type="resnet18", num_classes=10):
        super().__init__()
        # Reject unknown architectures up front, before building anything.
        if model_type not in ("resnet18", "vgg16"):
            raise ValueError("Unsupported model type. Choose 'resnet18' or 'vgg16'")

        if model_type == "resnet18":
            backbone = models.resnet18(pretrained=True)
            # ResNet exposes its classification head as ``fc``.
            backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        else:
            backbone = models.vgg16(pretrained=True)
            # VGG keeps its head in ``classifier``; only the last entry is swapped.
            head = backbone.classifier
            head[-1] = nn.Linear(head[-1].in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Returns the softmax (over the last dimension) of the backbone output."""
        logits = self.model(x)
        return F.softmax(logits, dim=-1)

class Gate(nn.Module):
    """Two-layer gating network that emits one weight per expert.

    Args:
        n_experts: Number of outputs (one per expert).
        input_dim: Size of the last dimension of the input fed to ``fc1``.
    """

    def __init__(self, n_experts, input_dim=512):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, n_experts)

    def forward(self, x):
        """Returns softmax-normalised expert weights along the last dimension."""
        hidden = F.relu(self.fc1(x))
        scores = self.fc2(hidden)
        return F.softmax(scores, dim=-1)

class MOEModel(nn.Module):
    """Mixture of experts: a gate-weighted sum of the experts' outputs.

    Args:
        n_experts: How many ``Expert`` instances to create.
        model_type: Architecture name forwarded to every ``Expert``.
        num_classes: Output width forwarded to every ``Expert``.
        lr: Learning rate of the Adam optimizer created over all parameters.

    NOTE(review): ``forward`` hands the same ``x`` to the gate and to the
    experts, while the gate is built with an input width of 512 / 4096.
    Presumably the gate was meant to see backbone features rather than the
    raw input -- confirm against the intended design.
    """

    def __init__(self, n_experts=3, model_type="resnet18", num_classes=10, lr=0.001):
        super().__init__()
        expert_list = []
        for _ in range(n_experts):
            expert_list.append(Expert(model_type, num_classes))
        self.experts = nn.ModuleList(expert_list)

        gate_width = 512 if model_type == "resnet18" else 4096
        self.gate = Gate(n_experts, gate_width)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Combines the expert outputs using the gate's weights."""
        # Gate weights get a trailing axis so they can act as a column vector.
        weights = self.gate(x).unsqueeze(-1)
        # Expert outputs are stacked along a new last axis (one slot per expert).
        per_expert = [expert(x) for expert in self.experts]
        stacked = torch.stack(per_expert, dim=-1)
        mixed = torch.matmul(stacked, weights)
        return mixed.squeeze(-1)

    def probabilities(self, x, y):
        """Sums ``y * forward(x)`` over dimension 1."""
        mixture = self.forward(x)
        weighted = y * mixture
        return torch.sum(weighted, dim=1)

    def calculate_loss(self, x, y):
        """Mean negative log of ``probabilities`` (1e-9 added before the log)."""
        log_probs = torch.log(self.probabilities(x, y) + 1e-9)
        return -log_probs.mean()

    def grad(self, x, y):
        """Clears old gradients, back-propagates the loss and returns it."""
        self.optimizer.zero_grad()
        loss = self.calculate_loss(x, y)
        loss.backward()
        return loss

    def step(self, x, y):
        """Runs ``grad`` followed by one optimizer update; returns the loss."""
        loss = self.grad(x, y)
        self.optimizer.step()
        return loss


In [None]:
import torch
import torch.nn as nn

class SparsePruner:
    """Performs magnitude pruning on the experts of a mixture-of-experts model.

    Masks are stored per prunable layer (``nn.Conv2d`` / ``nn.Linear``) in a
    dict keyed by ``(expert_idx, layer_idx)``, where ``layer_idx`` is the
    position of the layer in ``expert.modules()``. A mask entry of ``0`` marks
    a pruned (free) weight; an entry ``k > 0`` marks a weight owned by
    dataset ``k``.

    Args:
        model: Object exposing an iterable ``experts`` attribute of modules.
        prune_perc: Fraction of the current dataset's weights to prune per layer.
        previous_masks: Dict of mask tensors as described above (non-empty).
        train_bias: When false, bias gradients are zeroed in ``make_grads_zero``.
        train_bn: Stored on the instance; not consulted by any method here.
    """

    def __init__(self, model, prune_perc, previous_masks, train_bias, train_bn):
        self.model = model
        self.prune_perc = prune_perc
        self.train_bias = train_bias
        self.train_bn = train_bn
        self.current_masks = None
        self.previous_masks = previous_masks

        # The current dataset index is the largest value found in any one mask.
        # It is kept as a plain int (not a 0-dim tensor) so that it prints,
        # compares and serialises cleanly.
        valid_key = list(previous_masks.keys())[0]
        self.current_dataset_idx = int(previous_masks[valid_key].max().item())

    def _expert_layers(self):
        """Yields ``((expert_idx, layer_idx), module)`` for every prunable layer."""
        for expert_idx, expert in enumerate(self.model.experts):
            for layer_idx, module in enumerate(expert.modules()):
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    yield (expert_idx, layer_idx), module

    def _bias_layers(self):
        """Yields ``(key, module)`` for prunable layers that carry a bias.

        Uses ``model.shared`` (integer keys) when the model has one, and
        otherwise the experts (``(expert_idx, layer_idx)`` keys), so that bias
        handling works on the same models as the rest of this class.
        """
        shared = getattr(self.model, 'shared', None)
        if shared is not None:
            candidates = enumerate(shared.modules())
        else:
            candidates = self._expert_layers()
        for key, module in candidates:
            if isinstance(module, (nn.Conv2d, nn.Linear)) and module.bias is not None:
                yield key, module

    def pruning_mask(self, weights, previous_mask, layer_idx):
        """Computes a pruning mask based on weight magnitudes.

        Among the weights owned by the current dataset, the smallest
        ``prune_perc`` fraction (by absolute value) is marked as pruned (0).

        Args:
            weights: Weight tensor of the layer.
            previous_mask: Mask tensor with the same shape as ``weights``.
            layer_idx: Only used for the progress message.

        Returns:
            The updated mask, on the same device as ``weights``. When
            ``previous_mask`` already lives on that device it is modified
            in place.
        """
        # Follow the weights' device instead of assuming CUDA.
        previous_mask = previous_mask.to(weights.device)
        owned = previous_mask.eq(self.current_dataset_idx)
        tensor = weights[owned]
        num_owned = tensor.numel()

        # kthvalue needs 1 <= k <= numel; outside that range it raises.
        cutoff_rank = min(round(self.prune_perc * num_owned), num_owned)
        if cutoff_rank >= 1:
            cutoff_value = tensor.abs().view(-1).cpu().kthvalue(cutoff_rank)[0].item()
            remove_mask = weights.abs().le(cutoff_value) & owned
            previous_mask[remove_mask] = 0
        mask = previous_mask

        num_zero = int(mask.eq(0).sum().item())
        percent = 100 * num_zero / num_owned if num_owned else 0.0
        print(f'Layer #{layer_idx}, pruned {num_zero}/{num_owned} ({percent:.2f}%)')
        return mask

    def prune(self):
        """Prunes only the expert networks while keeping the gating network untouched."""
        print(f'Pruning experts for dataset idx: {self.current_dataset_idx}')
        assert self.current_masks is None, 'Pruning twice?'
        self.current_masks = {}

        print(f'Pruning each expert layer by removing {100 * self.prune_perc:.2f}% of values')

        for key, module in self._expert_layers():
            mask = self.pruning_mask(module.weight.data, self.previous_masks[key], key[1])
            self.current_masks[key] = mask
            # Pruned weights are zeroed immediately.
            module.weight.data[mask.eq(0)] = 0.0

    def make_grads_zero(self):
        """Sets gradients of weights not owned by the current dataset to zero."""
        assert self.current_masks is not None

        for key, module in self._expert_layers():
            if module.weight.grad is None:
                continue
            grad = module.weight.grad.data
            layer_mask = self.current_masks[key].to(grad.device)
            grad[layer_mask.ne(self.current_dataset_idx)] = 0
            # A bias may exist without having received a gradient yet.
            if (not self.train_bias and module.bias is not None
                    and module.bias.grad is not None):
                module.bias.grad.data.fill_(0)

    def make_pruned_zero(self):
        """Forces pruned weights to remain zero in the expert networks."""
        assert self.current_masks is not None

        for key, module in self._expert_layers():
            weight = module.weight.data
            weight[self.current_masks[key].to(weight.device).eq(0)] = 0.0

    def apply_mask(self, dataset_idx):
        """Zeroes weights that are pruned or owned by a dataset after ``dataset_idx``."""
        for key, module in self._expert_layers():
            weight = module.weight.data
            mask = self.previous_masks[key].to(weight.device)
            weight[mask.eq(0)] = 0.0
            weight[mask.gt(dataset_idx)] = 0.0

    def restore_biases(self, biases):
        """Use the given biases (as returned by ``get_biases``) to replace existing biases."""
        for key, module in self._bias_layers():
            module.bias.data.copy_(biases[key])

    def get_biases(self):
        """Gets a copy of the current biases, keyed like ``_bias_layers``."""
        return {key: module.bias.data.clone() for key, module in self._bias_layers()}

    def make_finetuning_mask(self):
        """Allows previously pruned weights to be trainable for the new dataset.

        Advances ``current_dataset_idx`` and relabels every pruned (0) entry of
        the previous masks with it, in place; the previous masks then become
        the current masks.
        """
        assert self.previous_masks is not None
        self.current_dataset_idx += 1

        for key, _ in self._expert_layers():
            mask = self.previous_masks[key]
            mask[mask.eq(0)] = self.current_dataset_idx

        self.current_masks = self.previous_masks


In [None]:
def step_lr(epoch, base_lr, lr_decay_every, lr_decay_factor, optimizer):
    """Applies step decay to every parameter group of ``optimizer``.

    The learning rate is ``base_lr * lr_decay_factor ** floor((epoch - 1) / lr_decay_every)``.

    Args:
        epoch: Epoch number; epoch 1 yields an exponent of 0.
        base_lr: Undecayed learning rate.
        lr_decay_every: Number of epochs between decay steps.
        lr_decay_factor: Multiplier applied at each decay step.
        optimizer: Object with a ``param_groups`` list of dicts; updated in place.

    Returns:
        The same ``optimizer`` object.
    """
    completed_steps = np.floor((epoch - 1) / lr_decay_every)
    new_lr = base_lr * np.power(lr_decay_factor, completed_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    print('Set lr to ', new_lr)
    return optimizer

In [None]:
# Command-line options for training, pruning and evaluation.
FLAGS = argparse.ArgumentParser()
# 'resnet18' is listed because it is one of the two architectures Expert accepts.
FLAGS.add_argument('--arch',
                   choices=['vgg16', 'vgg16bn', 'resnet18', 'resnet50', 'densenet121'],
                   help='Architectures')
FLAGS.add_argument('--mode',
                   choices=['finetune', 'prune', 'check', 'eval'],
                   help='Run mode')
FLAGS.add_argument('--finetune_layers',
                   choices=['all', 'fc', 'classifier'], default='all',
                   help='Which layers to finetune, fc only works with vgg')
FLAGS.add_argument('--num_outputs', type=int, default=-1,
                   help='Num outputs for dataset')
# Optimization options.
FLAGS.add_argument('--lr', type=float,
                   help='Learning rate')
FLAGS.add_argument('--lr_decay_every', type=int,
                   help='Step decay every this many epochs')
FLAGS.add_argument('--lr_decay_factor', type=float,
                   help='Multiply lr by this much every step of decay')
FLAGS.add_argument('--finetune_epochs', type=int,
                   help='Number of initial finetuning epochs')
FLAGS.add_argument('--batch_size', type=int, default=32,
                   help='Batch size')
FLAGS.add_argument('--weight_decay', type=float, default=0.0,
                   help='Weight decay')
# Paths.
FLAGS.add_argument('--dataset', type=str, default='',
                   help='Name of dataset')
FLAGS.add_argument('--train_path', type=str, default='',
                   help='Location of train data')
FLAGS.add_argument('--test_path', type=str, default='',
                   help='Location of test data')
FLAGS.add_argument('--save_prefix', type=str, default='../checkpoints/',
                   help='Location to save model')
FLAGS.add_argument('--loadname', type=str, default='',
                   help='Location to load model from')
# Pruning options.
FLAGS.add_argument('--prune_method', type=str, default='sparse',
                   choices=['sparse'],
                   help='Pruning method to use')
# argparse %-formats help strings, so a literal percent sign must be doubled.
FLAGS.add_argument('--prune_perc_per_layer', type=float, default=0.5,
                   help='%% of neurons to prune per layer')
FLAGS.add_argument('--post_prune_epochs', type=int, default=0,
                   help='Number of epochs to finetune for after pruning')
FLAGS.add_argument('--disable_pruning_mask', action='store_true', default=False,
                   help='use masking or not')
FLAGS.add_argument('--train_biases', action='store_true', default=False,
                   help='use separate biases or not')
FLAGS.add_argument('--train_bn', action='store_true', default=False,
                   help='train batch norm or not')
# Other.
FLAGS.add_argument('--cuda', action='store_true', default=True,
                   help='use CUDA')
# --cuda defaults to True and is a store_true flag, so on its own it can never
# be switched off; --no_cuda writes False to the same destination.
FLAGS.add_argument('--no_cuda', dest='cuda', action='store_false',
                   help='do not use CUDA')
FLAGS.add_argument('--init_dump', action='store_true', default=False,
                   help='Initial model dump.')

In [None]:
def test_loader_caffe(path, batch_size, num_workers=4, pin_memory=False):
    """Legacy test loader for models that were loaded from caffe.

    Images are resized to 256 x 256 (no crop, no normalisation) so that a
    mean of the same size can be subtracted afterwards.

    Args:
        path: Root directory handed to ``datasets.ImageFolder``.
        batch_size: Batch size of the returned loader.
        num_workers: Forwarded to ``data.DataLoader``.
        pin_memory: Forwarded to ``data.DataLoader``.

    Returns:
        A non-shuffling ``data.DataLoader`` over the image folder.
    """
    preprocessing = transforms.Compose([
        Scale((256, 256)),
        transforms.ToTensor(),
    ])
    image_folder = datasets.ImageFolder(path, preprocessing)
    return data.DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory)


def train_loader_cropped(path, batch_size, num_workers=4, pin_memory=False):
    """Builds the shuffling training loader over an image folder.

    Each image is resized to 224 x 224, randomly flipped horizontally,
    converted to a tensor and normalised per channel.

    Args:
        path: Root directory handed to ``datasets.ImageFolder``.
        batch_size: Batch size of the returned loader.
        num_workers: Forwarded to ``data.DataLoader``.
        pin_memory: Forwarded to ``data.DataLoader``.

    Returns:
        A ``data.DataLoader`` with ``shuffle=True``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
    preprocessing = transforms.Compose([
        Scale((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    image_folder = datasets.ImageFolder(path, preprocessing)
    return data.DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory)


def test_loader_cropped(path, batch_size, num_workers=4, pin_memory=False):
    """Builds the non-shuffling test loader over an image folder.

    Each image is resized to 224 x 224, converted to a tensor and normalised
    per channel; no random augmentation is applied.

    Args:
        path: Root directory handed to ``datasets.ImageFolder``.
        batch_size: Batch size of the returned loader.
        num_workers: Forwarded to ``data.DataLoader``.
        pin_memory: Forwarded to ``data.DataLoader``.

    Returns:
        A ``data.DataLoader`` with ``shuffle=False``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
    preprocessing = transforms.Compose([
        Scale((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    image_folder = datasets.ImageFolder(path, preprocessing)
    return data.DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory)


# Note: This might not be needed anymore given that this functionality exists in
# the newer PyTorch versions.
class Scale(object):
    """Rescale the input PIL.Image to the given size.
    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # The ABC aliases on ``collections`` (e.g. ``collections.Iterable``)
        # were removed in Python 3.10; ``collections.abc`` is their home.
        # Imported locally so the block does not depend on a file-level import.
        from collections.abc import Iterable

        assert isinstance(size, int) or (
            isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.
        Returns:
            PIL.Image: Rescaled image.
        """
        if not isinstance(self.size, int):
            # Explicit (w, h): resize directly, ignoring the aspect ratio.
            return img.resize(self.size, self.interpolation)

        w, h = img.size
        # The smaller edge already has the requested length: nothing to do.
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        if w < h:
            ow = self.size
            oh = int(self.size * h / w)
        else:
            oh = self.size
            ow = int(self.size * w / h)
        return img.resize((ow, oh), self.interpolation)

In [None]:
class Manager(object):
    """Handles training, evaluation, checkpointing and pruning of a model.

    Args:
        args: Options object; the attributes read here are ``cuda``, ``mode``,
            ``train_path``, ``test_path``, ``batch_size``,
            ``prune_perc_per_layer``, ``train_biases``, ``train_bn``,
            ``disable_pruning_mask``, ``dataset``, ``lr``, ``lr_decay_every``,
            ``lr_decay_factor``, ``weight_decay``, ``post_prune_epochs`` and
            ``save_prefix``.
        model: The network to train / prune.
        previous_masks: Mask dict handed to ``SparsePruner``.
        dataset2idx: Dict mapping dataset name to dataset index (updated on save).
        dataset2biases: Dict mapping dataset name to saved biases (updated on save).
    """

    # The annotation is a string so that it is not evaluated when the class is
    # defined: ``Args`` is not defined before this class in the notebook.
    def __init__(self, args: "Args", model, previous_masks, dataset2idx, dataset2biases):
        self.args = args
        # Fall back to the CPU when CUDA was requested but is not available,
        # instead of failing on the first ``.cuda()`` call.
        self.cuda = bool(args.cuda) and torch.cuda.is_available()
        if args.cuda and not self.cuda:
            print('CUDA requested but not available; running on CPU.')
        self.model = model
        self.dataset2idx = dataset2idx
        self.dataset2biases = dataset2biases

        if args.mode != 'check':
            # Set up data loader, criterion, and pruner.
            self.train_data_loader = train_loader_cropped(
                args.train_path, args.batch_size, pin_memory=self.cuda)
            self.test_data_loader = test_loader_cropped(
                args.test_path, args.batch_size, pin_memory=self.cuda)
            # NOTE(review): CrossEntropyLoss expects raw logits, while
            # MOEModel.forward in this notebook returns softmax-mixed
            # probabilities -- confirm the loss matches the model passed in.
            self.criterion = nn.CrossEntropyLoss()

            self.pruner = SparsePruner(
                self.model, self.args.prune_perc_per_layer, previous_masks,
                self.args.train_biases, self.args.train_bn)

    def _set_train_mode(self):
        """Puts the model in training mode; batch norm stays in eval mode unless train_bn."""
        if self.args.train_bn:
            self.model.train()
        elif hasattr(self.model, 'train_nobn'):
            # Models that provide their own implementation keep using it.
            self.model.train_nobn()
        else:
            self.model.train()
            batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                               nn.SyncBatchNorm)
            for module in self.model.modules():
                if isinstance(module, batchnorm_types):
                    module.eval()

    def _checked_layers(self):
        """Yields ``(label, module)`` for the modules inspected by ``check``.

        Uses ``model.shared`` when the model has one, otherwise the experts.
        """
        shared = getattr(self.model, 'shared', None)
        if shared is not None:
            for layer_idx, module in enumerate(shared.modules()):
                yield 'Layer #%d' % layer_idx, module
        else:
            for expert_idx, expert in enumerate(self.model.experts):
                for layer_idx, module in enumerate(expert.modules()):
                    yield 'Expert #%d, Layer #%d' % (expert_idx, layer_idx), module

    def eval(self, dataset_idx, biases=None):
        """Performs evaluation on the test loader.

        Args:
            dataset_idx: Dataset index whose mask is applied before evaluating
                (skipped when ``disable_pruning_mask`` is set).
            biases: Optional biases to restore before evaluating.

        Returns:
            The error values reported by the error meter (top-1, plus top-5
            when the model has more than 5 outputs).

        Raises:
            ValueError: If the test loader yields no batches.
        """
        if not self.args.disable_pruning_mask:
            self.pruner.apply_mask(dataset_idx)
        if biases is not None:
            self.pruner.restore_biases(biases)

        self.model.eval()
        error_meter = None
        topk = [1]

        print('Performing eval...')
        # ``Variable(..., volatile=True)`` no longer disables autograd;
        # ``torch.no_grad()`` is what avoids building a graph during evaluation.
        with torch.no_grad():
            for batch, label in tqdm(self.test_data_loader, desc='Eval'):
                if self.cuda:
                    batch = batch.cuda()

                output = self.model(batch)

                # Init error meter.
                if error_meter is None:
                    if output.size(1) > 5:
                        topk.append(5)
                    error_meter = tnt.meter.ClassErrorMeter(topk=topk)
                error_meter.add(output.data, label)

        self._set_train_mode()
        if error_meter is None:
            raise ValueError('Test data loader produced no batches; nothing to evaluate.')

        errors = error_meter.value()
        print('Error: ' + ', '.join('@%s=%.2f' %
                                    t for t in zip(topk, errors)))
        return errors

    def do_batch(self, optimizer, batch, label):
        """Runs one optimisation step on a single batch."""
        if self.cuda:
            batch = batch.cuda()
            label = label.cuda()

        # Set grads to 0.
        self.model.zero_grad()

        # Do forward-backward.
        output = self.model(batch)
        self.criterion(output, label).backward()

        # Set fixed param grads to 0.
        if not self.args.disable_pruning_mask:
            self.pruner.make_grads_zero()

        # Update params.
        optimizer.step()

        # Set pruned weights to 0.
        if not self.args.disable_pruning_mask:
            self.pruner.make_pruned_zero()

    def do_epoch(self, epoch_idx, optimizer):
        """Trains model for one epoch."""
        for batch, label in tqdm(self.train_data_loader, desc='Epoch: %d ' % (epoch_idx)):
            self.do_batch(optimizer, batch, label)

    def save_model(self, epoch, best_accuracy, errors, savename):
        """Saves the model, masks and bookkeeping to ``savename + '.pt'``."""
        base_model = self.model

        # Prepare the ckpt.
        self.dataset2idx[self.args.dataset] = self.pruner.current_dataset_idx
        self.dataset2biases[self.args.dataset] = self.pruner.get_biases()
        ckpt = {
            'args': self.args,
            'epoch': epoch,
            'accuracy': best_accuracy,
            'errors': errors,
            'dataset2idx': self.dataset2idx,
            'previous_masks': self.pruner.current_masks,
            'model': base_model,
        }
        if self.args.train_biases:
            ckpt['dataset2biases'] = self.dataset2biases

        # Save to file.
        torch.save(ckpt, savename + '.pt')

    def train(self, epochs, optimizer, save=True, savename='', best_accuracy=0):
        """Performs training.

        Args:
            epochs: Number of epochs to run.
            optimizer: Optimizer whose learning rate is step-decayed each epoch.
            save: Whether to checkpoint whenever top-1 accuracy improves.
            savename: Path prefix for the ``.json`` history and ``.pt`` checkpoint.
            best_accuracy: Accuracy a new checkpoint has to beat.
        """
        error_history = []

        if self.cuda:
            self.model = self.model.cuda()

        for idx in range(epochs):
            epoch_idx = idx + 1
            print('Epoch: %d' % (epoch_idx))

            optimizer = step_lr(epoch_idx, self.args.lr, self.args.lr_decay_every,
                                self.args.lr_decay_factor, optimizer)
            self._set_train_mode()
            self.do_epoch(epoch_idx, optimizer)
            errors = self.eval(self.pruner.current_dataset_idx)
            error_history.append(errors)
            accuracy = 100 - errors[0]  # Top-1 accuracy.

            # Save performance history and stats.
            with open(savename + '.json', 'w') as fout:
                json.dump({
                    'error_history': error_history,
                    'args': vars(self.args),
                }, fout)

            # Save best model, if required.
            if save and accuracy > best_accuracy:
                print('Best model so far, Accuracy: %0.2f%% -> %0.2f%%' %
                      (best_accuracy, accuracy))
                best_accuracy = accuracy
                self.save_model(epoch_idx, best_accuracy, errors, savename)

        print('Finished finetuning...')
        print('Best error/accuracy: %0.2f%%, %0.2f%%' %
              (100 - best_accuracy, best_accuracy))
        print('-' * 16)

    def prune(self):
        """Perform pruning, then optional post-prune finetuning."""
        print('Pre-prune eval:')
        self.eval(self.pruner.current_dataset_idx)

        self.pruner.prune()
        self.check(True)

        print('\nPost-prune eval:')
        errors = self.eval(self.pruner.current_dataset_idx)
        accuracy = 100 - errors[0]  # Top-1 accuracy.
        self.save_model(-1, accuracy, errors,
                        self.args.save_prefix + '_postprune')

        # Do final finetuning to improve results on pruned network.
        if self.args.post_prune_epochs:
            print('Doing some extra finetuning...')
            optimizer = optim.SGD(self.model.parameters(),
                                  lr=self.args.lr, momentum=0.9,
                                  weight_decay=self.args.weight_decay)
            self.train(self.args.post_prune_epochs, optimizer, save=True,
                       savename=self.args.save_prefix + '_final', best_accuracy=accuracy)

        print('-' * 16)
        print('Pruning summary:')
        self.check(True)
        print('-' * 16)

    def check(self, verbose=False):
        """Reports, per Conv2d/Linear layer, how many weights are exactly zero."""
        print('Checking...')
        for label, module in self._checked_layers():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                weight = module.weight.data
                num_params = weight.numel()
                # Plain ints keep the %-formatting independent of tensor types.
                num_zero = int(weight.view(-1).eq(0).sum().item())
                if verbose:
                    print('%s: Pruned %d/%d (%.2f%%)' %
                          (label, num_zero, num_params,
                           100 * num_zero / max(num_params, 1)))

In [None]:
# NOTE(review): `Args` is not defined in any cell visible above this one --
# presumably it came from a cell that has since been removed. Confirm where it
# is defined before relying on Restart Kernel -> Run All.
args = Args()