<a href="https://colab.research.google.com/github/asalcedo31/CSC2516_project/blob/master/clean_pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# set up


In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
import torch.utils.model_zoo as model_zoo
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.nn.modules import Module
import torchvision.models.vgg as tv_vgg
import time
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torch.autograd import Variable
import time
import os
import copy
import math
import re
from collections import OrderedDict

In [2]:
transform = transforms.Compose(
    [transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=5,
                                         shuffle=False, num_workers=0)
# _,trainset = torch.utils.data.random_split(trainset,(49200,800))
# _,trainset = torch.utils.data.random_split(trainset,(49995,5))
# print(trainset.__len__())

# train_data, val_data = torch.utils.data.random_split(trainset,(int(0.8*len(trainset)),int(0.2*len(trainset))))
# print(train_data.__len__(),val_data.__len__() )

# trainloader = torch.utils.data.DataLoader(train_data, batch_size=5,
#                                           shuffle=True, num_workers=0)
# valloader = torch.utils.data.DataLoader(val_data, batch_size=5,
#                                           shuffle=True, num_workers=0)


Files already downloaded and verified
Files already downloaded and verified


In [0]:
# image_datasets= {'train': train_data,'val': val_data}
# dataloaders = {'train': trainloader, 'val': valloader}

# dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# class_names = image_datasets['train'].classes


In [0]:
def freeze_layers(model_ft, exclude=[]):
#   children = list(model_ft.named_children())
  for name,param in model_ft.named_parameters():   
    if(name not in  exclude):
      param.requires_grad = False

In [0]:
def countNonZeroWeights(model):
    nonzeros = 0
    weights = 0
    for name,param in model.named_parameters():
        if param is not None:
            nonzeros += torch.sum((param != 0).int()).data[0]
            weights += torch.sum(param).data[0]
    
    return nonzeros, weights

In [0]:
def set_threshold(model,prop=0.05):
  for child in model.named_children():    
    for child in child[1].named_children():
#       print(child)
      if type(child[1]) == MaskedLinear or type(child[1]) == MaskedConv: 
        child[1].set_threshold(prop=prop)
        print("layer {}  new threshold {:.4f}".format(child[0], child[1].threshold))        

In [0]:
def train_model_prune(model, dloaders, dataset_sizes, criterion, optimizer, scheduler,prop=0.05, num_epochs=25, device='cuda',pruning='threshold'):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    print(len(dloaders['train']))
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
                data_idx = 0
            else:
                model.eval()   # Set model to evaluate mode
                data_idx = 1

            running_loss = 0.0
            running_corrects = 0
            i=0
            
#             print(dloaders[phase].__iter__().next())
            # Iterate over data.
            for inputs, labels in dloaders[phase]:               
#                 print("batch {} phase {}".format(i, phase))
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    if pruning == 'L0':
                      loss = criterion(outputs, labels,model)
                    else:
                      loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        model.clamp_parameters()
                        exp_flops, exp_l0 = model.get_exp_flops_l0()
                i+=1
                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                           
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            if epoch % 5 == 0 and phase == 'train': 
              if pruning == 'threshold':
                set_threshold(model,prop=prop)
              elif pruning == 'L0':
                print(exp_flops.item(), exp_l0.item())
            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return epoch_loss

In [0]:
class Masked:
  def make_mask(self, threshold,mask=None):
    if mask is None:
      print("new mask",device)
      self.mask = torch.ones(self.weight.size(), requires_grad=False).to(device)
    else:
      self.mask = mask      
    self.zeros = torch.zeros(self.weight.size(), requires_grad=False).to(device)
    self.threshold = threshold
  def set_threshold(self,prop=0.05):
    unique_weights = torch.unique(self.weight*self.mask)
    mask_size = self.mask.reshape(-1).size()[0]
#     mask_size = mask_size[0]*mask_size[1]
    mask_nonzero = torch.sum(self.mask.view([mask_size]))
    mask_total = mask_size
    print('nonzero proportion: {:.4f}'.format(mask_nonzero/mask_total))
    self.threshold = torch.max(torch.topk(torch.abs(unique_weights),int(prop*unique_weights.size()[0]),largest=False)[0])    
  def make_threshold_mask(self):
    self.mask = torch.where(torch.abs(self.weight) >= self.threshold,self.mask,self.zeros).to(device)
#     self.mask.requires_grad_(requires_grad=False)
  def mask_weight(self):
    self.weight = torch.nn.Parameter(self.weight*self.mask).to(device) 

    

# L0 pruning


In [0]:
class MaskedLinear(torch.nn.Linear,Masked):
  def __init__(self, in_features, out_features, bias=True, threshold=0.001,mask=None):
    super(MaskedLinear, self).__init__(in_features,out_features)
    self.make_mask(threshold,mask)
  def forward(self, input):
    self.make_threshold_mask()
    self.mask_weight()
#     print(self.mask[125:135,125:135])
#     print(self.weight[125:135,125:135])
    return F.linear(input, self.weight, self.bias)

class MaskedConv(torch.nn.Conv2d,Masked):
  def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias=True,threshold=0.0001):
    super(MaskedConv,self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    self.make_mask(threshold)    
  def forward(self, input):
    self.mask_weight()
    return F.conv2d(input, self.weight, self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
  
limit_a, limit_b, epsilon = -.1, 1.1, 1e-6
device='cuda'

class LinearL0(Module):
  """Implementation of L0 regularization for the input units of a fully connected layer"""
  def __init__(self, in_features, out_features, bias=True, weight_decay=1., droprate_init=0.5, temperature=2./3.,
                 lamba=1., local_rep=False, qz_loga=None, **kwargs):
        """
        :param in_features: Input dimensionality
        :param out_features: Output dimensionality
        :param bias: Whether we use a bias
        :param weight_decay: Strength of the L2 penalty
        :param droprate_init: Dropout rate that the L0 gates will be initialize d to
        :param temperature: Temperature of the concrete distribution
        :param lamba: Strength of the L0 penalty
        :param local_rep: Whether we will use a separate gate sample per element in the minibatch
        """
        super(LinearL0, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_prec = weight_decay
        self.weight = torch.nn.Parameter(torch.Tensor(in_features, out_features).to(device))
#         self.qz_loga = torch.Tensor(in_features).to(device)
        self.qz_loga = torch.nn.Parameter(torch.Tensor(in_features).to(device))
        self.temperature = temperature
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.lamba = lamba
        self.use_bias = False
        self.local_rep = local_rep
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_features).to(device))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        if qz_loga is not None:
          self.qz_loga = qz_loga

  def reset_parameters(self):
      torch.nn.init.kaiming_normal(self.weight, mode='fan_out')
      self.qz_loga.data.normal_(math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2)

      if self.use_bias:
          self.bias.data.fill_(0)

  def constrain_parameters(self, **kwargs):
      self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

  def cdf_qz(self, x):
      """Implements the CDF of the 'stretched' concrete distribution"""
      xn = (x - limit_a) / (limit_b - limit_a)
      logits = math.log(xn) - math.log(1 - xn)
      return F.sigmoid(logits * self.temperature - self.qz_loga).clamp(min=epsilon, max=1 - epsilon).to(device)

  def quantile_concrete(self, x):
      """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
      y = F.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) / self.temperature).to(device)
      return y * (limit_b - limit_a) + limit_a

  def _reg_w(self):
      """Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty"""
      logpw_col = torch.sum(- (.5 * self.prior_prec * self.weight.pow(2)) - self.lamba, 1).to(device)
      logpw = torch.sum((1 - self.cdf_qz(0)) * logpw_col).to(device)
      logpb = 0 if not self.use_bias else - torch.sum(.5 * self.prior_prec * self.bias.pow(2)).to(device)
      return logpw + logpb

  def regularization(self):
      return self._reg_w()

  def count_expected_flops_and_l0(self):
      """Measures the expected floating point operations (FLOPs) and the expected L0 norm"""
      # dim_in multiplications and dim_in - 1 additions for each output neuron for the weights
      # + the bias addition for each neuron
      # total_flops = (2 * in_features - 1) * out_features + out_features
      ppos = torch.sum(1 - self.cdf_qz(0))
      expected_flops = (2 * ppos - 1) * self.out_features
      expected_l0 = ppos * self.out_features
      if self.use_bias:
          expected_flops += self.out_features
          expected_l0 += self.out_features
#       return expected_flops.data[0], expected_l0.data[0]
      return expected_flops, expected_l0

  def get_eps(self, size):
      """Uniform random numbers for the concrete distribution"""
      eps = self.floatTensor(size).uniform_(epsilon, 1-epsilon).to(device)
      eps = Variable(eps)
      return eps

  def sample_z(self, batch_size, sample=True):
      """Sample the hard-concrete gates for training and use a deterministic value for testing"""
      if sample:
          eps = self.get_eps(self.floatTensor(batch_size, self.in_features))
          z = self.quantile_concrete(eps)
          return F.hardtanh(z, min_val=0, max_val=1).to(device)
      else:  # mode
          pi = F.sigmoid(self.qz_loga).view(1, self.in_features).expand(batch_size, self.in_features).to(device)
          return F.hardtanh(pi * (limit_b - limit_a) + limit_a, min_val=0, max_val=1).to(device)

  def sample_weights(self):
      z = self.quantile_concrete(self.get_eps(self.floatTensor(self.in_features)))
      mask = F.hardtanh(z, min_val=0, max_val=1).to(device)
      return mask.view(self.in_features, 1) * self.weight

  def forward(self, input):
      if self.local_rep or not self.training:
          z = self.sample_z(input.size(0), sample=self.training)
          xin = input.mul(z)
          output = xin.mm(self.weight)
      else:
          weights = self.sample_weights()
          output = input.mm(weights)
      if self.use_bias:
          output.add_(self.bias)
      return output

  def __repr__(self):
      s = ('{name}({in_features} -> {out_features}, droprate_init={droprate_init}, '
           'lamba={lamba}, temperature={temperature}, weight_decay={prior_prec}, '
           'local_rep={local_rep}')
      if not self.use_bias:
          s += ', bias=False'
      s += ')'
      return s.format(name=self.__class__.__name__, **self.__dict__)


In [0]:
def mask_network(network,layers_to_mask, n=None, threshold=0.002, linear_masking=None,random_init=False, bias=True,masks=None):
  """"
  replaces linear layers with masked linear layers
  network is the initial sequential container
  layers is a list of layers to mask
  random init is a logical indicating whether to preserve the initial weights or to modify them
  """
  network.masked_layers=[]
  for name,layer in network.named_children():   
    if int(name) in layers_to_mask:
      layer_mask = None
      if masks is not None:
        if name in masks:
          layer_mask = masks.get(name)      
      if type(layer)== torch.nn.Linear and linear_masking is None:
        masked_layer = MaskedLinear(layer.in_features, layer.out_features, bias=bias,threshold=threshold,masks=layer_mask)
      elif type(layer)== torch.nn.Linear and linear_masking =='L0':
        masked_layer = LinearL0(layer.in_features, layer.out_features, bias=bias, lamba=0.5/n,qz_loga=layer_mask)
#         masked_layer = LinearL0(layer.in_features, layer.out_features, bias=bias)
        network.masked_layers.append(masked_layer)
      elif type(layer)== torch.nn.Conv2d:
        masked_layer = MaskedConv(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, layer.dilation,layer.groups, bias=bias, threshold=threshold)
      if random_init != True and linear_masking != 'L0':
        masked_layer.weight = copy.deepcopy(layer.weight)
        masked_layer.bias = copy.deepcopy(layer.bias)
      elif random_init != True and linear_masking == 'L0':
        weight_copy = torch.transpose(copy.deepcopy(layer.weight),0,1)
        masked_layer.weights = torch.nn.Parameter(weight_copy)
        
      network[int(name)] = masked_layer

In [0]:
class L0_Meta_Objective():
  def __init__(self,train_data,val_data):
    self.trainloader = torch.utils.data.DataLoader(train_data, batch_size=5,
                                            shuffle=True, num_workers=0)
    
    self.valloader = torch.utils.data.DataLoader(val_data, batch_size=5,
                                            shuffle=True, num_workers=0)
    
    self.dataloaders = {'train':self.trainloader, 'val': self.valloader}
    self.image_datasets= {'train': train_data,'val': val_data}
    self.dataset_sizes = {x: len(self.image_datasets[x]) for x in ['train', 'val']}
    self.model_ft = vgg16_L0(pretrained=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.model_ft = self.model_ft.to(device)    
    self.criterion = nn.CrossEntropyLoss()
        
    freeze_layers(self.model_ft.features, exclude=[])    
    
  def add_L0_layers(self,layers_to_prune,initial_masks=None,n=None):    
    mask_network(self.model_ft.classifier,layers_to_prune,n=n,linear_masking="L0",masks=initial_masks)
    self.model_ft.masked_layers = self.model_ft.classifier.masked_layers  
  
  def inner_loss_function(self,outputs,targets,model):
      loss = self.criterion(outputs,targets)
      loss += self.model_ft.regularize(200)
      return loss
        
  def inner_train_loop(self, inner_epochs):    
    # Observe that all parameters are being optimized
    # Decay LR by a factor of 0.1 every 7 epochs
    self.optimizer_ft = optim.Adam(self.model_ft.parameters(), lr=0.001)
    exp_lr_scheduler = lr_scheduler.StepLR(self.optimizer_ft, step_size=7, gamma=0.1)
    final_mod = train_model_prune(self.model_ft, self.dataloaders,self.dataset_sizes, self.inner_loss_function, self.optimizer_ft, exp_lr_scheduler,
                                    num_epochs=inner_epochs, pruning="L0")
    return final_mod
    
def get_initial_masks(model,layers_to_mask,droprate_init=0.5):
    initial_masks = {}    
    for name,layer in model.named_children():
      if int(name) in layers_to_mask:
        qz_loga = torch.nn.Parameter(torch.Tensor(layer.in_features).to(device))
        qz_loga.data.normal_(math.log(1 - droprate_init) - math.log(droprate_init), 1e-2)
        initial_masks[name] = qz_loga                             
    return initial_masks
      


In [0]:
class VGG_L0(tv_vgg.VGG):
  def regularization(self):
    regularization = 0.
    for layer in self.layers:
        regularization += - (1. / self.N) * layer.regularization()
    if torch.cuda.is_available():
        regularization = regularization.cuda()
    return regularization
  
  def regularize(self, N):
    regularization = 0.
    for layer in self.masked_layers:
          regularization += - (1. / N) * layer.regularization()          
    if torch.cuda.is_available():
        regularization = regularization.cuda()
    return regularization
    
  def clamp_parameters(self):
    for layer in self.masked_layers:
      layer.constrain_parameters()
  
  def get_exp_flops_l0(self):
    total_flops = 0
    total_l0 = 0
    for layer in self.masked_layers:
      exp_flops, exp_l0 = layer.count_expected_flops_and_l0()
      total_flops += exp_flops
      total_l0 += exp_l0
    return total_flops, total_l0  
          

def vgg16_L0(pretrained=False, **kwargs):
  """VGG 16-layer model (configuration "D")
  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  if pretrained:
      kwargs['init_weights'] = False
  model = VGG_L0(tv_vgg.make_layers(tv_vgg.cfg['D']), **kwargs)
  if pretrained:
      model.load_state_dict(model_zoo.load_url(tv_vgg.model_urls['vgg16']))
  return model


def run_normal_training_with_L0_pruning(this_trainset):
  print(this_trainset.__len__())  
  _,mytrainset = torch.utils.data.random_split(this_trainset,(49800,200))
  # _,trainset = torch.utils.data.random_split(trainset,(49995,5))
  print(mytrainset.__len__())

  mytrain_data, myval_data = torch.utils.data.random_split(mytrainset,(int(0.8*len(mytrainset)),int(0.2*len(mytrainset))))
  print(mytrain_data.__len__(),myval_data.__len__() )

  mytrainloader = torch.utils.data.DataLoader(mytrain_data, batch_size=5,
                                            shuffle=True, num_workers=0)
  myvalloader = torch.utils.data.DataLoader(myval_data, batch_size=5,
                                            shuffle=True, num_workers=0)
  mydataloaders = {'train': mytrainloader, 'val': myvalloader}
  image_datasets= {'train': mytrain_data,'val': myval_data}
  dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}  

  model_ft = vgg16_L0(pretrained=True)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model_ft = model_ft.to(device)

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model_ft = model_ft.to(device)

  freeze_layers(model_ft.features, exclude=[])
  mask_network(model_ft.classifier,[0,3,6],linear_masking="L0", n=200,random_init=False)
  model_ft.masked_layers = model_ft.classifier.masked_layers
  
  print(model_ft.classifier[0].weights.size())  

  criterion = nn.CrossEntropyLoss()

  def loss_function(outputs,targets, model):
    loss = criterion(outputs,targets)
    loss += model.regularize(200)
    return loss


  # Observe that all parameters are being optimized
  optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
#   [print(p) for p in model_ft.parameters()]
#   return

  # Decay LR by a factor of 0.1 every 7 epochs
  exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

  MO = L0_Meta_Objective(mytrain_data,myval_data)
  MO.add_L0_layers([0,3,6],n=200)

#   model_ft = train_model_prune(model_ft, mydataloaders,dataset_sizes, loss_function, optimizer_ft, exp_lr_scheduler,
#                        num_epochs=20, pruning="L0")

  model_ft = MO.inner_train_loop(20)
  
# run_normal_training_with_L0_pruning(trainset)

In [13]:
def gd_step(cost, params, lrate):
    """Perform one gradient descent step on the given cost function with learning
    rate lrate. Returns a new set of parameters, and (IMPORTANT) does not modify
    the input parameters."""

    cost_grad_fun = torch.autograd.grad(cost)
    grads = cost_grad_fun(params)
    opt_params = {}
    for p in params.keys():
        opt_params[p] = params[p] - grads[p]*lrate
    return opt_params

def set_params(model,inparams):
  for name,layer in model.named_children():
    for name_p, p in layer.named_parameters():
      key = str(name)+'.'+name_p
      if key in inparams.keys():
#         print(layer._parameters)
        layer._parameters[name_p] = inparams[key]
#         print(key)
#         print(inparams[key])
  return



#MAML attempt
def train_meta_prune_L0(trainset, valset, layers_to_prune,outer_steps=5, inner_steps=5, num_samples=800, inner_lr=0.001,outer_lr=0.001, device='cuda'):  
  shuffled_train = torch.utils.data.RandomSampler(trainset)
  train_sample_list = list(torch.utils.data.BatchSampler(shuffled_train,num_samples,False))
  shuffled_train = [x for x in shuffled_train]
  #sample model for dimensions
  model_ft = vgg16_L0(pretrained=True)
  initial_masks = get_initial_masks(model_ft.classifier, [0])
#   outer_optimizer = optim.RMSprop(initial_masks.values(), lr=0.001, momentum=0.9)
#   print(initial_masks)
  
  for i in range(outer_steps):
#     for phase in ['train','val']
    train_sample = [trainset[j] for j in train_sample_list[i]] 
    train_data, val_data = torch.utils.data.random_split(train_sample,(int(0.8*num_samples),int(0.2*num_samples)))
    
#     outer_optimizer.zero_grad()
    #meta obejctive object contains model and loaders
    MO = L0_Meta_Objective(train_data,val_data)
        
    #make separate inital masks for each layer    
    MO.add_L0_layers(layers_to_prune,initial_masks)
    print("new outer")
    losses_ta
    
    for i in range(5):
      print("i ",i)
      inputs,labels = MO.trainloader.__iter__().next()
      inputs = inputs.to(device)
      labels = labels.to(device)
      preds = MO.model_ft(inputs)
      
      loss =  F.cross_entropy(preds,labels)+MO.model_ft.regularize(640)
      vars = dict(MO.model_ft.classifier.named_parameters())
      grad = torch.autograd.grad(loss,vars.values())
#       grad = torch.autograd.grad(loss,[MO.model_ft.classifier[0].weights,MO.model_ft.classifier[3].weight, MO.model_ft.classifier[6].weight, MO.model_ft.classifier[0].qz_loga] )
      params = {}
      
      for p in zip(grad,vars.values(),vars.keys()):
#         print(p[1]-p[0])
        params[p[2]] = p[1] - inner_lr* p[0]
      
#       print(params['3.weight'])  
#       params = dict(p[0]= p[1] - inner_lr* p[0], zip(grad,vars.values(),vars.keys()))
      set_params( MO.model_ft.classifier,params)
#       loss += self.model_ft.regularize(640)
      print('final inner loss {:.4f}'.format(loss))
    
    inputs,labels = MO.dataloaders['train'].__iter__().next()

    inputs = inputs.to(device)
    labels = labels.to(device)
    final_preds = MO.model_ft(inputs)
    final_loss =  F.cross_entropy(final_preds,labels)+MO.model_ft.regularize(640)
    print('final outer loss {:.4f}'.format(final_loss))
    outer_grad = torch.autograd.grad(final_loss,initial_masks.values())
    for p in zip(outer_grad,initial_masks.values(),initial_masks.keys()):
#           print(p[1]-p[0])
        initial_masks[p[2]] = torch.nn.Parameter(p[1] - outer_lr* p[0])

# train_meta_prune_L0(trainset,testset,[0], outer_steps=10,inner_lr=0.1, outer_lr=0.01)


def set_params_from_t(model,inparams):
  for name,layer in model.named_children():
    for name_p, p in layer.named_parameters():
      key = str(name)+'.'+name_p
      if key in inparams.keys():        
        p_new = inparams[key].data
#         p_new.requires_grad = True        
#           model[int(name)].weight = p
#         print(model[int(name)]._parameters[name_p])
        model[int(name)]._parameters[name_p].data.copy_(p_new)
#   print(model[6]._parameters[name_p])
  return 

def meta_prune_reptile(trainset, valset, layers_to_prune,outer_steps=10, inner_steps=5, num_samples=200, inner_lr=0.001,outer_lr=0.001, device='cuda'):  
  shuffled_train = torch.utils.data.RandomSampler(trainset)
  train_sample_list = list(torch.utils.data.BatchSampler(shuffled_train,num_samples,False))
  
  train_sample = train_sample_list[0]
  train_data, val_data = torch.utils.data.random_split(train_sample,(int(0.8*num_samples),int(0.2*num_samples)))
  
  MO = L0_Meta_Objective(train_data,val_data)
  
  #make separate inital masks for each layer    
  MO.add_L0_layers(layers_to_prune,n=200)
  initial_p = {}
  for name, param in OrderedDict(MO.model_ft.classifier.named_parameters()).items():
#     print(param.data)
    i_p = torch.ones(param.data.size(),requires_grad=False).to(device)
    i_p.copy_(param.data)
#     i_p
#     i_p.requires_grad = False
    initial_p[name] = i_p
#   print(initial_p)
  for i in range(1,outer_steps):
    print("new outer ", i)
    train_sample = [trainset[j] for j in train_sample_list[i]] 
    train_data, val_data = torch.utils.data.random_split(train_sample,(int(0.8*num_samples),int(0.2*num_samples)))
    MO = L0_Meta_Objective(train_data,val_data)
    MO.add_L0_layers(layers_to_prune,n=200)
    print('initial',initial_p['6.qz_loga'][125:130])
#     print(dataloaders['train'].__iter__().next())
#     set_params_from_t(MO.model_ft.classifier, initial_p)
  
#     print('weights',MO.model_ft.classifier[6].weight[125:130,125:130])
    new_mod = MO.inner_train_loop(5)
#     print('after', MO.model_ft.classifier[6].weight[125:130,125:130])
#     print('initial after',initial_p['6.weight'][125:130,125:130])
    new_p = OrderedDict(MO.model_ft.classifier.named_parameters())
    
    for name, p in initial_p.items():
      update_p = (p - new_p[name].data)/outer_steps*outer_lr
#       if name == '6.weight':
#         print('p',p[125:130,125:130])
#         print('update',update_p[125:130,125:130])
      initial_p[name] = p+update_p.data
  
  


#   for i in range(outer_steps):
meta_prune_reptile(trainset,testset,[0,3,6], outer_steps=20,inner_lr=0.1, outer_lr=0.1)    



new outer  1
initial tensor([-0.0002, -0.0122,  0.0169, -0.0040, -0.0074], device='cuda:0')
32
Epoch 0/4
----------




train Loss: 1304.4893 Acc: 0.1750
204578928.0 102298656.0
val Loss: 1292.5214 Acc: 0.2500

Epoch 1/4
----------
train Loss: 1286.8639 Acc: 0.3438
val Loss: 1281.4763 Acc: 0.2750

Epoch 2/4
----------
train Loss: 1276.7361 Acc: 0.5813
val Loss: 1273.8656 Acc: 0.3500

Epoch 3/4
----------
train Loss: 1267.9896 Acc: 0.7563
val Loss: 1265.4070 Acc: 0.2750

Epoch 4/4
----------
train Loss: 1259.9328 Acc: 0.7438
val Loss: 1257.2574 Acc: 0.3000

Training complete in 0m 24s
Best val Acc: 0.350000
new outer  2
initial tensor([ 0.0003, -0.0121,  0.0174, -0.0036, -0.0070], device='cuda:0')
32
Epoch 0/4
----------
train Loss: 1305.4203 Acc: 0.1125
204580896.0 102299640.0
val Loss: 1292.2144 Acc: 0.2000

Epoch 1/4
----------
train Loss: 1286.9439 Acc: 0.3500
val Loss: 1281.8355 Acc: 0.2500

Epoch 2/4
----------
train Loss: 1277.0829 Acc: 0.5813
val Loss: 1273.6596 Acc: 0.3250

Epoch 3/4
----------
train Loss: 1268.2806 Acc: 0.7812
val Loss: 1265.3886 Acc: 0.3250

Epoch 4/4
----------
train Loss: 12

# run

In [0]:
def run_normal_training_with_pruning(this_trainset):
  _,mytrainset = torch.utils.data.random_split(this_trainset,(49200,800))

  mytrain_data, myval_data = torch.utils.data.random_split(mytrainset,(int(0.8*len(mytrainset)),int(0.2*len(mytrainset))))
  print(mytrain_data.__len__(),myval_data.__len__() )

  mytrainloader = torch.utils.data.DataLoader(mytrain_data, batch_size=5,
                                            shuffle=True, num_workers=0)
  myvalloader = torch.utils.data.DataLoader(myval_data, batch_size=5,
                                            shuffle=True, num_workers=0)
  mydataloaders = {'train': mytrainloader, 'val': myvalloader}
  image_datasets= {'train': mytrain_data,'val': myval_data}
  dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

  model_ft = models.vgg16(pretrained=True)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model_ft = model_ft.to(device)

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model_ft = model_ft.to(device)

  freeze_layers(model_ft.features, exclude=[])
  mask_network(model_ft.classifier,[0],threshold=0.0001)
  set_threshold(model_ft)
  
  criterion = nn.CrossEntropyLoss()  
     
  # Observe that all parameters are being optimized
  optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

  # Decay LR by a factor of 0.1 every 7 epochs
  exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
  
  
  model_ft = train_model_prune(model_ft, mydataloaders,dataset_sizes, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=2)
  
# run_normal_training_with_pruning(trainset)

In [0]:
def train_meta_prune(model,trainset, outer_steps, num_samples=800, device='cuda'):
  mask_dict = {'0':torch.ones(model.classifier[0].weight.size()).to(device)}
  shuffled_train = torch.utils.data.RandomSampler(trainset)
  train_sample_list = list(torch.utils.data.BatchSampler(shuffled_train,num_samples,False))
  shuffled_train = [x for x in shuffled_train]
  for i in range(outer_steps):
#     train_sample = [trainset[j] for j in train_sample_list[i]] 
    
#     print(len(train_sample))
    _,train_sample = torch.utils.data.random_split(trainset,(49200,800))
    train_data, val_data = torch.utils.data.random_split(train_sample,(int(0.8*num_samples),int(0.2*num_samples)))

    trainloader = torch.utils.data.DataLoader(train_data, batch_size=5,
                                            shuffle=True, num_workers=0)
    valloader = torch.utils.data.DataLoader(val_data, batch_size=5,
                                            shuffle=True, num_workers=0)
    
    subdataloaders = {'train': trainloader, 'val': valloader}
    image_datasets= {'train': train_data,'val': val_data}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    
    model_ft = models.vgg16(pretrained=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_ft = model_ft.to(device)

    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    # freeze_layers(model_ft.features, exclude=['28.weight'])
    freeze_layers(model_ft.features)   
    mask_network(model_ft.classifier,[0],threshold=0.0001,masks=mask_dict)
    model_ft = train_model_prune(model_ft, subdataloaders, dataset_sizes, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=10, prop=0.1)
    mask_dict = {'0':model_ft.classifier[0].mask}
#     set_threshold(model_ft)

#     cost = meta_objective({'train':trainloader, 'val':valoader}, model, optimizer, inner_epochs)


# model_ft = models.vgg16(pretrained=True)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_ft = model_ft.to(device)

# train_meta_prune(model_ft,trainset,15)