In [39]:
import functools

import torch
import torch.nn as nn

In [15]:
import os
import pickle
import numpy as np
import torchvision.transforms as transforms

def _unpickle_batch(filename):
    # Load one pickled CIFAR-10 batch file and return the deserialized dict.
    # SECURITY: pickle.load can execute arbitrary code -- only use on trusted files.
    with open(filename, 'rb') as f:
        try:
            return pickle.load(f, encoding='bytes')
        except TypeError:
            return pickle.load(f)  # for python 2 (pickle.load has no `encoding` argument)


def _flat_to_nhwc(flat_images):
    # Reshape images: (n, 3072) -> (n, 3, 32, 32) -> (n, 32, 32, 3)
    return np.transpose(np.reshape(flat_images, [-1, 3, 32, 32]), [0, 2, 3, 1])


def read_data(path_dataset, num_train_batches=5):
    """Load CIFAR-10 python-format batches and standardize them per channel.

    Reads `data_batch_1` ... `data_batch_<num_train_batches>` and `test_batch` from
    `path_dataset`. Pixel values are divided by 255 (as float32) and then standardized
    with the per-channel mean/std of the *training* images. The test set reuses the
    training statistics.

    Args:
        path_dataset: directory containing the pickled batch files.
        num_train_batches: number of `data_batch_<i>` files to read (default 5).

    Returns:
        {'train': {'input': float32 array (n, 32, 32, 3), 'label': array (n,)},
         'test':  {'input': float32 array (m, 32, 32, 3), 'label': array (m,)}}

    Raises:
        ValueError: if num_train_batches < 1.
    """
    if num_train_batches < 1:
        raise ValueError('num_train_batches must be >= 1, got {}'.format(num_train_batches))
    data_key = 'data'.encode('utf-8')
    label_key = 'labels'.encode('utf-8')

    train_batches = [
        _unpickle_batch(os.path.join(path_dataset, 'data_batch_{}'.format(i + 1)))
        for i in range(num_train_batches)
    ]
    test_batch = _unpickle_batch(os.path.join(path_dataset, 'test_batch'))

    # Concatenate whole batches at once rather than extending a Python list row by row
    # and re-stacking it. Only the two keys used below are collected.
    train_images = _flat_to_nhwc(
        np.concatenate([np.asarray(b[data_key]) for b in train_batches], axis=0))
    train_labels = np.concatenate([np.asarray(b[label_key]) for b in train_batches], axis=0)
    test_images = _flat_to_nhwc(np.asarray(test_batch[data_key]))
    test_labels = np.asarray(test_batch[label_key])

    # Pre-processing (normalize) using training-set statistics only
    train_images = np.divide(train_images, 255, dtype=np.float32)
    test_images = np.divide(test_images, 255, dtype=np.float32)
    channel_mean = np.mean(train_images, axis=(0, 1, 2), dtype=np.float32, keepdims=True)
    channel_std = np.std(train_images, axis=(0, 1, 2), dtype=np.float32, keepdims=True)
    train_images = (train_images - channel_mean) / channel_std
    test_images = (test_images - channel_mean) / channel_std

    return {
        'train': {'input': train_images, 'label': train_labels},
        'test': {'input': test_images, 'label': test_labels},
    }

In [46]:
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# CIFAR-10 location, passed as `root` to torchvision's CIFAR10. download=False, so the
# files must already be there. Override via the CIFAR10_DATA_DIR environment variable
# instead of editing a machine-specific path; the default is the path this notebook
# was originally run with.
DATA_DIR = os.environ.get('CIFAR10_DATA_DIR', '/Users/matthewkolodner/Desktop/Stanford/CS330/Project/data/')
BATCH_SIZE = 128

# Upsample the 32x32 images and normalize with the standard ImageNet channel statistics,
# presumably to match the ImageNet-pretrained ResNet-18 used in later cells.
transform = transforms.Compose([
    transforms.Resize(224),  # Optional, depending on the model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
trainset = datasets.CIFAR10(root=DATA_DIR, train=True, download=False, transform=transform)
testset = datasets.CIFAR10(root=DATA_DIR, train=False, download=False, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [20]:
from torchvision.models import resnet18, ResNet18_Weights

class Model(object):
    """Pruning model wrapper, partially ported from a TensorFlow 1.x graph implementation.

    NOTE(review): `construct_model` still calls TF1 graph APIs through a module-level
    `tf` name. It also refers to `prn_keys`, `net`, `self.inputs`, `self.is_train` and
    `self.compress`. None of these are defined in the visible notebook cells, so calling
    it raises NameError.

    The logic appears to be SNIP-style pruning: score each weight by
    |d loss / d mask| at mask = 1, normalize the scores, and keep the top
    (1 - target_sparsity) fraction as a binary mask.
    """
    def __init__(self,
                 datasource,
                 arch,
                 num_classes,
                 target_sparsity,
                 optimizer,
                 lr_decay_type,
                 lr,
                 decay_boundaries,
                 decay_values,
                 initializer_w_bp,
                 initializer_b_bp,
                 initializer_w_ap,
                 initializer_b_ap,
                 **kwargs):
        """Store the model / pruning / optimization hyperparameters as attributes.

        Extra keyword arguments (`kwargs`) are accepted and ignored. Of the stored
        attributes, only `target_sparsity`, `optimizer`, `lr_decay_type`, `lr`,
        `decay_boundaries` and `decay_values` are read (by `construct_model`). The rest
        are not read anywhere in this class.
        """
        self.datasource = datasource
        self.arch = arch
        self.num_classes = num_classes
        self.target_sparsity = target_sparsity
        self.optimizer = optimizer
        self.lr_decay_type = lr_decay_type
        self.lr = lr
        self.decay_boundaries = decay_boundaries
        self.decay_values = decay_values
        self.initializer_w_bp = initializer_w_bp
        self.initializer_b_bp = initializer_b_bp
        self.initializer_w_ap = initializer_w_ap
        self.initializer_b_ap = initializer_b_ap

    def construct_model(self):
        """Build the masked network, loss, optimizer, outputs and summaries (TF1 style).

        Sets `self.net`, `self.train_op`, `self.outputs`, `self.sparsity` and
        `self.summ_op`.

        NOTE(review): not runnable as written (see the class docstring).
        TODO: finish porting to PyTorch, or remove in favour of the standalone cells.
        """
        # Base-learner
        self.net = resnet18(weights=ResNet18_Weights.DEFAULT)
        
        weights = self.net.state_dict()
        # NOTE(review): redundant double assignment. `tf` and `prn_keys` are undefined in
        # the visible cells (NameError here), and the next line overwrites `mask_init` anyway.
        mask_init = mask_init = {k: var_no_train(tf.ones(weights[k].shape)) for k in prn_keys}
        # One all-ones mask per state_dict entry (var_no_train is defined in a later cell).
        mask_init = {k: var_no_train(weights[k].shape) for k in weights}
        mask_prev = {k: var_no_train(weights[k].shape) for k in weights}
        
        # Model

        def get_sparse_mask():
            """Score weights by |d loss / d mask| on the current batch and build a sparse mask."""
            w_mask = apply_mask(weights, mask_init)
            # NOTE(review): uses `net`, not `self.net`. torchvision's ResNet has no
            # `forward_pass` method -- TODO confirm the intended network object.
            logits = net.forward_pass(w_mask, self.inputs['input'],
                self.is_train, trainable=False)
            loss = tf.reduce_mean(compute_loss(self.inputs['label'], logits))
            grads = tf.gradients(loss, [mask_init[k] for k in prn_keys])
            gradients = dict(zip(prn_keys, grads))
            # Absolute gradients scaled to sum to 1 across all keys (normalize_dict).
            cs = normalize_dict({k: tf.abs(v) for k, v in gradients.items()})
            return create_sparse_mask(cs, self.target_sparsity)

        # Recompute the mask only when self.compress is true; otherwise reuse mask_prev.
        mask = tf.cond(self.compress, lambda: get_sparse_mask(), lambda: mask_prev)
        # Store the chosen mask into mask_prev before the masked weights are built.
        with tf.control_dependencies([tf.assign(mask_prev[k], v) for k,v in mask.items()]):
            w_final = apply_mask(weights, mask)

        # Forward pass
        logits = net.forward_pass(w_final, self.inputs['input'], self.is_train)

        # Loss
        opt_loss = tf.reduce_mean(compute_loss(self.inputs['label'], logits))
        # L2 regularization on the masked weights (coefficient 0.00025).
        reg = 0.00025 * tf.reduce_sum([tf.reduce_sum(tf.square(v)) for v in w_final.values()])
        opt_loss = opt_loss + reg

        # Optimization
        optim, lr, global_step = prepare_optimization(opt_loss, self.optimizer, self.lr_decay_type,
            self.lr, self.decay_boundaries, self.decay_values)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # TF version issue
        with tf.control_dependencies(update_ops):
            self.train_op = optim.minimize(opt_loss, global_step=global_step)

        # Outputs
        output_class = tf.argmax(logits, axis=1, output_type=tf.int32)
        output_correct_prediction = tf.equal(self.inputs['label'], output_class)
        output_accuracy_individual = tf.cast(output_correct_prediction, tf.float32)
        output_accuracy = tf.reduce_mean(output_accuracy_individual)
        # NOTE(review): key 'los' looks like a typo for 'loss' -- check consumers before renaming.
        self.outputs = {
            'logits': logits,
            'los': opt_loss,
            'acc': output_accuracy,
            'acc_individual': output_accuracy_individual,
        }
        # Fraction of exact zeros among the prn_keys weights after masking.
        self.sparsity = compute_sparsity(w_final, prn_keys)

        # Summaries
        tf.summary.scalar('loss', opt_loss)
        tf.summary.scalar('accuracy', output_accuracy)
        tf.summary.scalar('lr', lr)
        self.summ_op = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES))

def compute_loss(labels, logits):
    """Per-example softmax cross-entropy for integer class labels (TF1 API).

    `logits` must have exactly one more dimension than `labels`. Its last dimension is
    used as the number of classes for one-hot encoding the labels.

    NOTE(review): relies on a module-level `tf`, which is not imported in the visible cells.
    """
    label_rank = len(labels.shape)
    assert label_rank + 1 == len(logits.shape)
    depth = logits.shape.as_list()[-1]
    one_hot_labels = tf.one_hot(labels, depth, dtype=tf.float32)
    return tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_labels, logits=logits)

def get_optimizer(optimizer, lr):
    """Build a TF1 optimizer from its name.

    'sgd' -> plain gradient descent; 'momentum' -> momentum optimizer with momentum 0.9.
    Any other name raises NotImplementedError.
    """
    if optimizer == 'sgd':
        return tf.train.GradientDescentOptimizer(lr)
    if optimizer == 'momentum':
        return tf.train.MomentumOptimizer(lr, 0.9)
    raise NotImplementedError

def prepare_optimization(loss, optimizer, lr_decay_type, learning_rate, boundaries, values):
    """Create the global step, the learning-rate tensor and the optimizer (TF1 API).

    `loss` is accepted but not used here. `lr_decay_type` is 'constant' (use
    `learning_rate` as is) or 'piecewise' (`values[i]` between `boundaries`; needs one
    more value than boundaries).

    Returns:
        (optimizer, learning_rate_tensor, global_step)

    Raises:
        NotImplementedError: for an unknown decay type or optimizer name.
    """
    step = tf.Variable(0, trainable=False)
    if lr_decay_type == 'piecewise':
        assert len(boundaries) + 1 == len(values)
        lr_tensor = tf.train.piecewise_constant(step, boundaries, values)
    elif lr_decay_type == 'constant':
        lr_tensor = tf.constant(learning_rate)
    else:
        raise NotImplementedError
    return get_optimizer(optimizer, lr_tensor), lr_tensor, step

def vectorize_dict(x, sortkeys=None):
    """Flatten a dict of TF tensors into one 1-D tensor, plus a function to undo it.

    Args:
        x: dict of tensors with statically known shapes.
        sortkeys: order in which the tensors are concatenated (default: `x.keys()`).

    Returns:
        (x_vec, restore_fn). `restore_fn(v)` splits a vector laid out like `x_vec`
        back into a dict keyed by `sortkeys`, with each piece reshaped to its original shape.
    """
    assert isinstance(x, dict)
    if sortkeys is None:
        sortkeys = x.keys()

    def restore(v, x_shape, sortkeys):
        # Element count per key. The initializer 1 makes a 0-d (scalar) shape count as one
        # element instead of raising TypeError on an empty shape.
        split_sizes = [functools.reduce(lambda a, b: a * b, x_shape[key], 1) for key in sortkeys]
        v_splits = tf.split(v, num_or_size_splits=split_sizes)
        return {key: tf.reshape(piece, x_shape[key]) for key, piece in zip(sortkeys, v_splits)}

    # vectorized dictionary
    x_vec = tf.concat([tf.reshape(x[k], [-1]) for k in sortkeys], axis=0)
    # restore function
    x_shape = {k: x[k].shape.as_list() for k in sortkeys}
    restore_fn = functools.partial(restore, x_shape=x_shape, sortkeys=sortkeys)
    return x_vec, restore_fn

def normalize_dict(x):
    """Scale a dict of TF tensors so that all their entries together sum to 1."""
    flat, unflatten = vectorize_dict(x)
    total = tf.reduce_sum(flat)
    return unflatten(tf.divide(flat, total))

def compute_sparsity(weights, target_keys):
    """Fraction of exactly-zero entries across `weights[k]` for every k in `target_keys`."""
    assert isinstance(weights, dict)
    selected = {key: weights[key] for key in target_keys}
    flat, _unused_restore = vectorize_dict(selected)
    return tf.nn.zero_fraction(flat)

def create_sparse_mask(mask, target_sparsity):
    """Binary mask keeping the top round(N * (1 - target_sparsity)) scores.

    `mask` holds the scores: either a 1-D tensor or a dict of tensors. A dict is
    flattened, thresholded globally across all keys, and restored to the same dict layout.
    Kept positions get 1.0; all others get 0.0.
    """
    def keep_top_scores(scores, sparsity):
        total = scores.shape.as_list()[0]
        keep = int(round(total * (1. - sparsity)))
        _values, keep_idx = tf.nn.top_k(scores, k=keep, sorted=True)
        return tf.sparse_to_dense(keep_idx, tf.shape(scores),
            tf.ones_like(keep_idx, dtype=tf.float32), validate_indices=False)

    if not isinstance(mask, dict):
        return keep_top_scores(mask, target_sparsity)
    flat, unflatten = vectorize_dict(mask)
    return unflatten(keep_top_scores(flat, target_sparsity))

def apply_mask(weights, mask):
    """Return a new dict with `mask[k] * weights[k]` for every key k of `mask`.

    Keys of `weights` with no mask entry are passed through unchanged. The result follows
    the key order of `weights`, so iteration order is deterministic.

    Raises:
        KeyError: if `mask` contains a key that is not in `weights`.
    """
    # Fail the same way as a direct weights[k] lookup would for an unknown mask key.
    for k in mask:
        if k not in weights:
            raise KeyError(k)
    # Iterate over `weights` rather than a set difference: set order for string keys
    # depends on the hash seed and can change between runs.
    return {k: (mask[k] * weights[k] if k in mask else weights[k]) for k in weights}

In [45]:
model = resnet18(weights=ResNet18_Weights.DEFAULT)
target_sparsity = 0.95  # fraction of scored entries to prune (keep the top 5%)

def var_no_train(shape):
    """Return an all-ones float32 tensor of `shape` that does not track gradients."""
    return torch.ones(shape, dtype=torch.float32, requires_grad=False)

criterion = nn.CrossEntropyLoss()
weights = model.state_dict()
# NOTE(review): mask_init, mask_prev and w_mask are not used by any later visible cell.
mask_init = {k: var_no_train(weights[k].shape) for k in weights}
mask_prev = {k: var_no_train(weights[k].shape) for k in weights}
w_mask = apply_mask(weights, mask_init)

# SNIP connection sensitivity from a single mini-batch. For weights w_eff = c * w with
# c = 1, dL/dc = dL/dw * w, so the saliency is |grad * w|, not |grad|. This matches
# Model.construct_model, which differentiates the loss w.r.t. the mask.
# NOTE(review): every named parameter is scored here (including BatchNorm and biases);
# the TF version restricts scoring to `prn_keys`, whose contents are not visible.
# The dict keeps its original name because later cells read `absolute_gradients`.
inputs, labels = next(iter(trainloader))
outputs = model(inputs)
loss = criterion(outputs, labels)
model.zero_grad()
loss.backward()
with torch.no_grad():
    absolute_gradients = {
        name: (param.grad * param).abs()
        for name, param in model.named_parameters()
        if param.grad is not None  # parameters unused by the forward pass have no grad
    }

In [101]:
def restore(v, x_shape, sortkeys):
    """Split a flat tensor back into a dict of tensors with the given shapes.

    Args:
        v: 1-D tensor whose length equals the total element count of all shapes.
        x_shape: mapping key -> shape (torch.Size or a sequence of ints).
        sortkeys: iterable of keys giving the order of the segments inside `v`.

    Returns:
        dict mapping each key in `sortkeys` (in that order) to its segment of `v`,
        reshaped to `x_shape[key]`.
    """
    # The initializer 1 makes a 0-d (scalar) shape count as one element instead of
    # raising TypeError from reduce() on an empty sequence.
    split_sizes = [functools.reduce(lambda a, b: a * b, x_shape[key], 1) for key in sortkeys]
    v_splits = torch.split(v, split_size_or_sections=split_sizes)
    return {key: torch.reshape(piece, x_shape[key]) for key, piece in zip(sortkeys, v_splits)}

In [102]:
# Record each score tensor's shape, flatten all of them into one vector (in dict order),
# and keep a function that maps a vector of that length back to per-parameter tensors.
x_shape = {name: scores.shape for name, scores in absolute_gradients.items()}
mask_v = torch.cat([scores.view(-1) for scores in absolute_gradients.values()], dim=0)
restore_fn = functools.partial(restore, x_shape=x_shape, sortkeys=absolute_gradients)

In [103]:
# Global top-k over all scores: 1.0 at the kappa largest entries, 0.0 everywhere else.
num_params = len(mask_v)
kappa = int(round(num_params * (1. - target_sparsity)))
topk, ind = torch.topk(mask_v, kappa)  # defaults: largest=True, sorted=True
mask_sparse_v = torch.zeros_like(mask_v, dtype=torch.float32)
mask_sparse_v[ind] = 1.0

In [115]:
# Map the flat 0/1 vector back to one mask tensor per parameter name.
mask_final = restore_fn(v=mask_sparse_v)

In [116]:
num_epochs = 10
# Plain SGD over all model parameters (library defaults: no momentum, no weight decay);
# any other torch optimizer could be swapped in here.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

In [117]:
# Fine-tune the pruned network. `mask_final` maps parameter *names* to 0/1 tensors, so
# look masks up by name. Zipping model.parameters() with the dict would pair each
# parameter with a key string and fail on `grad *= str`.
masked_params = [
    (param, mask_final[name].to(param.device))
    for name, param in model.named_parameters()
    if name in mask_final
]

def _apply_masks(use_grad):
    """Zero the pruned entries of every masked parameter, or of its gradient if use_grad."""
    with torch.no_grad():
        for param, mask in masked_params:
            target = param.grad if use_grad else param
            if target is not None:
                target.mul_(mask)

# Prune up front: masking only the gradients would leave the pretrained values in the
# pruned positions, so the network would never actually be sparse.
_apply_masks(use_grad=False)
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in trainloader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        _apply_masks(use_grad=True)   # no gradient flows into pruned weights
        optimizer.step()
        _apply_masks(use_grad=False)  # momentum / weight decay cannot revive pruned weights

KeyboardInterrupt: 