In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np

import ray
import models
import random,time
from time import sleep
import copy 
import datetime
import argparse
import sys


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import os
import shutil
from torch.utils.tensorboard import SummaryWriter
from filelock import FileLock

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def generate_train_loader(batch_size,kwargs):
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
    batch_size=batch_size, shuffle=True, **kwargs)
    return train_loader

def generate_test_loader(test_batch_size):
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),batch_size=test_batch_size, shuffle=True)
    return test_loader

def _get_params(model):
    bns = {}
    non_bns = {}
    param_count = 0.
    bn_param_count = 0.
    for name,param in model.named_parameters():
        param_count += len(param)
        if 'bn' in name:
            bns[name] = param
            bn_param_count += len(param)
        else:
            non_bns[name] = param
    print("bn params occupies: ",bn_param_count/param_count)
    return bns,non_bns

In [3]:
@ray.remote
class ParameterServer():
    def __init__(self,args,test_loader):
        self.model = models.__dict__["resnet"](dataset="cifar10",depth=args.depth)
        self.stalness_table = [0] * args.num_workers
        self.stalness_limit = args.stalness_limit 
        self.global_step = 0
        self.lr = args.lr
        self.args = args
        self.eva_model = models.__dict__["resnet"](dataset="cifar10",depth=args.depth)
        self.optimizer = optim.SGD(self.model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
        self.test_loader = test_loader
        self.model.cpu()
        self.eva_model.cpu()
        self.ps_writer = SummaryWriter(os.path.join(os.getcwd(),(args.tb_path+'/ps')))
        self.save_path = args.save
        self.num_workers = (int)(args.num_workers)
        
        # get point to all non_bns parameters
        self.non_bns = [param.data for name,param in self.model.named_parameters() if 'bn' not in name]
        self.bns = [param.data for name,param in self.model.named_parameters() if 'bn' in name]
        self.bns_sync = [None] * args.num_workers
        self.cfg = None
        self.finished = [0] * args.num_workers
        
        if args.refine:
            if os.path.isfile(args.refine):
                print('found pruned ckpt')
                checkpoint = torch.load(args.refine)
                self.cfg = checkpoint['cfg']
                self.model = models.__dict__['resnet'](dataset='cifar10', depth=args.depth, cfg=checkpoint['cfg'])
                
                if args.resume:
                    if os.path.isfile(args.resume):
                        print("=> loading checkpoint '{}'".format(args.resume))
                        checkpoint = torch.load(args.resume)
                        self.global_step = checkpoint['global_step']
                        if 'optimizer' in checkpoint:
                            self.optimizer.load_state_dict(checkpoint['optimizer'])
                        self.stalness_table = [self.global_step/args.num_workers] * args.num_workers
                    else:
                        print("=> no checkpoint found at '{}'".format(args.resume))
                        
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer = optim.SGD(self.model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

        elif args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                self.global_step = checkpoint['global_step']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.stalness_table = [self.global_step/args.num_workers] * args.num_workers
                print("=> loaded checkpoint '{}' (global step: {})".format(args.resume, checkpoint['global_step']))                
                if 'epoch' in checkpoint: print("epoch: {}".format(checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))                    
                
#     def apply_gradients(self, iter_diff, wk_idx, epoch):
#         if args.debug: print("applying gradients from the ",wk_idx, " worker")
        
#         # updata all params
#         for idx, p in enumerate(self.model.parameters()):
#             p.data -= iter_diff[idx]

# #         if wk_idx == 0:
# #             # updata all params
# #             for idx, p in enumerate(self.model.parameters()):
# #                 p.data -= iter_diff[idx]
# #         else:
# #             # only update non_bns params
# #             for idx, tensor in enumerate(self.non_bns):
# #                 tensor -= iter_diff[idx]
        
# #         # only update non_bns params
# #         for idx, tensor in enumerate(self.non_bns):
# #             tensor -= iter_diff[idx]
        
#         self.stalness_table[wk_idx] += 1
#         self.global_step += 1
#         if args.debug: print("finished applying gradients from the ",wk_idx, " worker")
#         if self.global_step % 1000 == 0:
# #             print("global_step: ",self.global_step," and prepare evaluate")
# #             self.evaluate()
#             self.save_ckpt({
#                 'epoch':epoch,
#                 'global_step':self.global_step,
#                 'state_dict':self.model.state_dict(),
#                 'optimizer':self.optimizer.state_dict()
#             },filepath=os.path.join(os.getcwd(),self.save_path))
    
    def apply_gradients_with_running_bn(self,iter_diff,wk_idx,epoch):
        state = self.model.state_dict()
        for idx, p in enumerate(state):
            state[p] -= iter_diff[idx]
            
        self.stalness_table[wk_idx] += 1
        self.global_step += 1
        if args.debug: print("finished applying gradients from the ",wk_idx, " worker")
        if self.global_step % 782 == 0:
            self.save_ckpt({
                'epoch':epoch,
                'global_step':self.global_step,
                'state_dict':self.model.state_dict(),
                'optimizer':self.optimizer.state_dict(),
                'cfg':self.cfg
            },filepath=os.path.join(os.getcwd(),self.save_path))
            
        if self.global_step % (782 * self.num_workers) == 0:
            print("global_step: ",self.global_step," and prepare evaluate")
            self.evaluate()

        # when all worker finished
        if all(self.finished):
            self.save_ckpt({
                'epoch':epoch,
                'global_step':self.global_step,
                'state_dict':self.model.state_dict(),
                'optimizer':self.optimizer.state_dict(),
                'cfg':self.cfg
            },filepath=os.path.join(os.getcwd(),self.save_path))
            
            print("All worker finished its job, and have saved the ckpt")
#     def apply_gradients_non_bns(self,iter_diff,wk_idx,epoch):
#         for i in range(len(self.non_bns)):
#             self.non_bns[i] -= iter_diff[i]
# #         print(wk_idx,"finished updating non bns on ps")

        
#     def apply_gradients_bns(self,iter_diff,wk_idx,epoch):
#         for idx, tensor in enumerate(self.bns):
# #             if idx == len(iter_diff) / 2:
# #                 print(wk_idx,"is in the middle of updating bns on ps")
#             tensor -= iter_diff[idx]
# #         print(wk_idx,"finished updating non bns on ps")
# #         self.stalness_table[wk_idx] += 1
# #         self.global_step += 1

#     def apply_gradients_partical_bns(self, iter_diff, wk_idx, epoch):
#         if wk_idx == 0 :
#             for i in range(0,(int)(len(self.bns)/self.num_workers)):
#                 self.bns[i] -= iter_diff[i]
#         elif wk_idx == self.num_workers:
#             for i in range((int)(len(self.bns) * wk_idx / self.num_workers + 1) , len(self.bns)):
#                 self.bns[i] -= iter_diff[i]
#         else:
#             for i in range((int)(len(self.bns) * wk_idx / self.num_workers + 1) , (int)(len(self.bns) * (wk_idx + 1) / self.num_workers)):
#                 self.bns[i] -= iter_diff[i]
#         self.stalness_table[wk_idx] += 1
#         self.global_step += 1
#         if self.global_step % 1000 == 0:
#             self.save_ckpt({
#                 'epoch':epoch,
#                 'global_step':self.global_step,
#                 'state_dict':self.model.state_dict(),
#                 'optimizer':self.optimizer.state_dict()
#             },filepath=os.path.join(os.getcwd(),self.save_path))
    def pull_cfg(self):
        return self.cfg
        
    def pull_weights(self):
        return self.model.state_dict()
    
    def pull_non_bn_weights(self):
        self.non_bns = [param.data for name,param in self.model.named_parameters() if 'bn' not in name]
        return copy.deepcopy(self.non_bns)
    
    def pull_bn_weights(self):
        self.bns = [param.data for name,param in self.model.named_parameters() if 'bn' in name]
        return copy.deepcopy(self.bns)
    
    def get_optim(self):
        return self.optimizer
    
    def pull_optimizer_state(self):
        return self.optimizer.state_dict()

    def check_stalness(self,wk_idx):
        min_iter = min(self.stalness_table)
        return self.stalness_table[wk_idx] - min_iter < self.stalness_limit
        
    def get_stalness(self):
        return min(self.stalness_table)
    
    def get_stalness_table(self):
        return self.stalness_table
    
    def get_global_step(self):
        return self.global_step
    
    def get_model(self):
        return self.model
    
    def save_ckpt(self,state,filepath):
        torch.save(state,os.path.join(filepath,'checkpoint.pth.tar'))
    
    def get_bns_ready(self):
        return any(self.bns_sync) == False

    def set_finished(self,worker_index):
        print('worker:',worker_index,'has finished its job')
        self.finished[worker_index] = 1
        
    def aggregate_bns(self,wk_bns,worker_index):
        self.bns_sync[worker_index] = wk_bns
        if all(self.bns_sync):
            for i in range(len(self.bns)):
                tmp = copy.deepcopy(self.bns_sync[0][i])
                for j in range(1, self.num_workers):
                    tmp += self.bns_sync[j][i]
                self.bns[i] = tmp / self.num_workers    
            self.bns_sync = [None] * self.num_workers
        return self.bns_sync
    
    def evaluate(self):
        print("going to evaluate")
        test_loss = 0.
        correct = 0.
        print("pulled weights")
        self.eva_model= copy.deepcopy(self.model)
        print("loaded weights")
        print("length of the test_loader dataset is : ",len(self.test_loader.dataset))
        self.eva_model.eval()
        count = 0
        for data,target in self.test_loader:
            count += 1
            if count % 20 == 0: print("in eval, the batch is: ",count)
            data, target = Variable(data,volatile=True),Variable(target)
            output = self.eva_model(data)
            batch_loss = F.cross_entropy(output, target, size_average=False).data
            test_loss += batch_loss
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        len_testset = len(self.test_loader.dataset)
        test_loss /= len_testset 
        accuracy = correct.float() / len_testset
        # log 
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f})\n'.format(
        test_loss, correct, len_testset,accuracy))

        self.ps_writer.add_scalar('Accuracy/eval', accuracy, self.global_step)
        self.ps_writer.add_scalar('Loss/eval',test_loss , self.global_step)
        


In [4]:
@ray.remote(num_gpus=1)
def worker_task(args,ps,worker_index, train_loader):
    # Initialize the model.
#     if args.debug: print(worker_index, " worker is going to sleep ",worker_index*5000)
#     time.sleep(worker_index * 5000)
    
    model = models.__dict__["resnet"](dataset="cifar10",depth=args.depth)
    local_step = 0
    
    wk_non_bns = [param.data for name,param in model.named_parameters() if 'bn' not in name]
    wk_bns = [param.data for name,param in model.named_parameters() if 'bn' in name]
    
    if args.cuda:
        starttime = datetime.datetime.now()
        model.cuda()
        endtime = datetime.datetime.now()
        time_cost = (endtime - starttime).seconds
        if args.debug: print("move model to gpu takes: ", time_cost, "seconds")
            
    # all workers owns the same init values
#     init_wei = ray.get(ps.pull_weights.remote())
#     model.load_state_dict(init_wei)
    
    
    optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)
    
            
    if args.refine:
        if os.path.isfile(args.refine):
            print('found pruned ckpt')
            checkpoint = torch.load(args.refine)
            model = models.__dict__['resnet'](dataset='cifar10', depth=args.depth, cfg=checkpoint['cfg'])
            model.load_state_dict(checkpoint['state_dict'])
            optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)
            if args.resume:
                print("args.resume filled! ")
                checkpoint = torch.load(args.resume)
                model.load_state_dict(checkpoint['state_dict'])
                local_step = int(checkpoint['global_step'] / args.num_workers)
                print("local_step=",local_step)
                args.start_epoch = checkpoint['epoch']
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                print("worker #",worker_index,"resumes from local_step: ",local_step)
                if 'epoch' in checkpoint:
                    args.start_epoch = checkpoint['epoch'] 
    elif args.resume:
        checkpoint = torch.load(args.resume)
        local_step = int(checkpoint['global_step'] / args.num_workers)
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("worker #",worker_index,"resumes from local_step: ",local_step)
        if 'epoch' in checkpoint:
            args.start_epoch = checkpoint['epoch']
            
    
    wk_writer = SummaryWriter(os.path.join(os.getcwd(),args.tb_path,('wk_'+str(worker_index))))
    print("worker #",worker_index," is online")
    
#     if local_step in [5500,7000]:
#     for param_group in optimizer.param_groups:
#         print("previous lr: ",param_group['lr'])
#         param_group['lr'] = 0.001
#         print("new lr: ",param_group['lr'])

    
    for epoch in range(args.start_epoch,args.epochs):
        avg_loss = 0.
        train_correct = 0.
        for batch_idx,(data,target) in enumerate(train_loader):
            if args.cuda:
                starttime = datetime.datetime.now()
                data,target = data.cuda(),target.cuda()
                mid = datetime.datetime.now()
                if args.debug: print("move data to gpu takes: ", (mid - starttime).seconds, "seconds")
                model.cuda()
                endtime = datetime.datetime.now()
                time_cost = (endtime - starttime).seconds
                if args.debug: print("move model to gpu takes: ", time_cost, "seconds")
                
            while(local_step - ray.get(ps.get_stalness.remote()) > args.stalness_limit):
                sleep(1)
            

            # Get all weights from the parameter server.
            if args.debug: print("the ",worker_index," pulls wei from ps.")
            init_wei = ray.get(ps.pull_weights.remote())
            model.load_state_dict(init_wei)
            model.cpu()
#           # This doesn't contain BN running mean and var
#             old_tensors = copy.deepcopy([p.data for p in model.parameters()]) 

            old_tensors = copy.deepcopy([param.data for name, param in model.state_dict().items()])

            model.cuda()
#             if args.debug: print("the ",worker_index," loaded the latest wei from ps.")
#             model.cpu()    
#             wk_non_bns = [param.data for name,param in model.named_parameters() if 'bn' not in name]
#             wk_bns = [param.data for name,param in model.named_parameters() if 'bn' in name]
#             model.cuda()
#             old_tensors_non_bns = copy.deepcopy(wk_non_bns)
#             old_tensors_bns = copy.deepcopy(wk_bns)
                
#             # Get only non-bn weights from the parameter server.
#             ps_non_bns = ray.get(ps.pull_non_bn_weights.remote())
#             print(worker_index,"pulled non bns from ps")
#             assert len(ps_non_bns) == len(wk_non_bns)
#             for i in range(len(ps_non_bns)):
#                 if i == len(ps_non_bns) / 2:
#                     print(worker_index,"is in the middle of updating non bns")
#                 wk_non_bns[i] = ps_non_bns[i]
#             print(worker_index,"updated non bns from ps and is going to pull bns ")
            
#             # Get only bn weights from the parameter server.
#             ps_bns = ray.get(ps.pull_bn_weights.remote())
#             print(worker_index,"pulled bns from ps")
#             assert len(ps_bns) == len(wk_bns)
#             for i in range(len(ps_bns)):
#                 if i == len(ps_bns) / 2 :
#                     print(worker_index, "is in the middle of updating bns")
#                 wk_bns[i] = ps_bns[i]
#             print(worker_index,"updated bns from ps")

                
            # Compute an update and push it to the parameter server.        
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            if args.debug: print(worker_index,' is generating output')
            output = model(data)
            if args.debug: print(worker_index,' generated output done and going to calculate loss')
            loss = F.cross_entropy(output,target)
            avg_loss += loss
            pred = output.data.max(1,keepdim=True)[1]
            batch_correct = pred.eq(target.data.view_as(pred)).cpu().sum()
            train_correct += batch_correct
            if args.debug: print(worker_index,' calculated loss and going to bp')
            loss.backward()
            if args.debug: print(worker_index,' bp done')
            
            if(args.sr):
                # additional subgradient descent on the sparsity-induced penalty term
                for m in model.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.weight.grad.data.add_(args.s*torch.sign(m.weight.data))  # L1

            
            # calculate difference for this iteration
            optimizer.step()
            model.cpu()
#           # This doesn't contain BN running mean and var            
#             new_tensors = [p.data for p in model.parameters()]
            new_tensors = [param.data for name, param in model.state_dict().items()]

            iter_diff = [(old_tensor - new_tensor)/args.num_workers for (old_tensor, new_tensor) in zip(old_tensors,new_tensors)]
            model.cuda()
            # print("passing sizes: ",sys.getsizeof(iter_diff))
#             ps.apply_gradients.remote(iter_diff,worker_index,epoch)
            ps.apply_gradients_with_running_bn.remote(iter_diff,worker_index,epoch)
#             else:
#                 old_tensors_non_bns = copy.deepcopy([param.data for name,param in model.named_parameters() if 'bn' not in name])
#                 optimizer.step()
#                 wk_non_bns = [param.data for name,param in model.named_parameters() if 'bn' not in name]
#                 iter_diff_non_bns = [(old_tensor_non_bns - new_tensor_non_bns)/args.num_workers for (old_tensor_non_bns, new_tensor_non_bns) in zip(old_tensors_non_bns,wk_non_bns)]
#                 ps.apply_gradients_non_bns.remote(iter_diff_non_bns,worker_index,epoch)
        

                
            
#             if worker_index == 0:
# #                 print(worker_index, "passing all params")
#                 # calculate difference for this iteration
#                 old_tensors = copy.deepcopy([p.data for p in model.parameters()])    
#                 optimizer.step()
#                 new_tensors = [p.data for p in model.parameters()]
#                 iter_diff = [(old_tensor - new_tensor)/args.num_workers for (old_tensor, new_tensor) in zip(old_tensors,new_tensors)]
#                 # print("passing sizes: ",sys.getsizeof(iter_diff))
#                 ps.apply_gradients.remote(iter_diff,worker_index,epoch)
#             else:
# #                 print(worker_index, "passing non bn params")
#                 #calculate only non-bn parameters difference 
#                 old_tensors = copy.deepcopy([param.data for name,param in model.named_parameters() if 'bn' not in name])
#                 optimizer.step()
#                 new_tensors = [param.data for name,param in model.named_parameters() if 'bn' not in name]
#                 iter_diff = [(old_tensor - new_tensor)/args.num_workers for (old_tensor, new_tensor) in zip(old_tensors,new_tensors)]
#                 # print("passing sizes: ",sys.getsizeof(iter_diff))
#                 ps.apply_gradients.remote(iter_diff,worker_index,epoch)

#             # calculate bns and non_bns parameters difference
#             optimizer.step()
#             model.cpu()
#             new_tensors_non_bns = [param.data for name,param in model.named_parameters() if 'bn' not in name]
#             new_tensors_bns = [param.data for name,param in model.named_parameters() if 'bn' in name]
#             model.cuda()
#             iter_diff_non_bns = [(old_tensor_non_bns - new_tensor_non_bns)/args.num_workers for (old_tensor_non_bns, new_tensor_non_bns) in zip(old_tensors_non_bns,new_tensors_non_bns)]
#             iter_diff_bns = [(old_tensor_bns - new_tensor_bns)/args.num_workers for (old_tensor_bns, new_tensor_bns) in zip(old_tensors_bns,new_tensors_bns)]
#             # print("passing sizes: ",sys.getsizeof(iter_diff))
#             ps.apply_gradients_non_bns.remote(iter_diff_non_bns,worker_index,epoch)
#             print(worker_index, "pushed non bns to ps and is going to push bns to ps")
#             ps.apply_gradients_partical_bns.remote(iter_diff_bns,worker_index,epoch)
#             print(worker_index, "pushed bns to ps")
            
            
#             # aggregate and sync bn parameters    
#             if sync_bns_flag:
#                 print("SYNC BNS: worker #",worker_index," is checking to pull bns from ps ")
#                 if ray.get(ps.get_bns_ready.remote()):
#                     print("SYNC BNS: worker #",worker_index," is pulling bns from ps ")
#                     ps_wei = ray.get(ps.pull_weights.remote())
#                     model.load_state_dict(ps_wei)
#                     sync_bns_flag = False
#             if local_step % args.sync_bns == 0:
#                 sync_bns_flag = True
#                 model.cpu()
#                 if args.debug: print("SYNC BNS: goint to aggregate bns")
#                 wk_bns = [param.data for name,param in model.named_parameters() if 'bn' in name]
#                 bns_sync = ray.get(ps.aggregate_bns.remote(wk_bns,worker_index))
#                 print("SYNC BNS: bns of worker #",worker_index," have been pushed")
                
#                 if any(bns_sync) == False: # when all workers have pushed their own bns parameters to ps
#                     print("SYNC BNS: worker #",worker_index," is last worker pushed its bns")
#                     ps_wei = ray.get(ps.pull_weights.remote())
#                     model.load_state_dict(ps_wei)
#                     sync_bns_flag = False
#                 if args.cuda: model.cuda()
            
                
                
                
            
#             if local_step in [5500,7000]:
#                 for param_group in optimizer.param_groups:
#                     print("previous lr: ",param_group['lr'])
#                     param_group['lr'] *= 0.1
#                     print("new lr: ",param_group['lr'])
                
            local_step += 1

            if batch_idx % args.log_interval == 0:
                print('The {} worker, Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                worker_index, epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data))
                
                for name,param in model.named_parameters():
                    wk_writer.add_histogram(name, param, local_step)

                wk_writer.add_scalar("Loss/worker_train",loss,local_step)
                wk_writer.add_scalar("Accuracy/worker_train",batch_correct.float()/len(data),local_step)
                
        print("The {} worker finished its {} epoch with loss: {} and accuracy: {}".format(
            worker_index,
            epoch,
            avg_loss / float(len(train_loader.dataset)),
            train_correct.float() / float(len(train_loader.dataset))
        ))
    print("worker #",worker_index," has finished its job, going offline")
    ps.set_finished.remote(worker_index)

In [5]:
parser = argparse.ArgumentParser(description='Distributed SSP CIFAR-10 Restnet train with network slimming')
parser.add_argument('--ray-master',type=str,default='127.0.0.1')
parser.add_argument('--redis-port',type=str,default='6379')
parser.add_argument('--batch-size',type=int,default=64)
parser.add_argument('--test-batch-size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=3)
parser.add_argument('--start-epoch', default=0, type=int)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float)
parser.add_argument('--resume', default=None, type=str) 
parser.add_argument('--refine', default=None, type=str) 
parser.add_argument('--no-cuda', action='store_true', default=False)
parser.add_argument('--save', default='./logs', type=str)
parser.add_argument('--depth', default=164, type=int)
parser.add_argument('--tb-path', default='./logs', type=str)
parser.add_argument('--log-interval', type=int, default=100)
parser.add_argument('--num-workers',type=int,default=1)
parser.add_argument('--stalness-limit',type=int,default=5)
parser.add_argument('--debug',action='store_true',default=False)
parser.add_argument('--sync-bns',type=int, default=194)
parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',
                    help='train with channel sparsity regularization')
parser.add_argument('--s', type=float, default=0.0001,
                    help='scale sparse rate (default: 0.0001)')


args = parser.parse_args(args=['--num-workers=3','--tb-path=/userhome/34/gyu/logs_sr/3wk_p1/train34/',
                               '--save=/userhome/34/gyu/logs_sr/3wk_p1/train34/',
                               '--epochs=10',
                               '--refine=/userhome/34/gyu/logs_sr/3wk_p1/train24/checkpoint.pth.tar',
                               '--lr=0.001',
                               '-sr','--s=0.00001'])

# '--resume=/userhome/34/gyu/logs_sr/checkpoint.pth.tar'
# '--tb-path=logs_no_bns','--save=logs_no_bns'
# '-sr','--s=0.00001'

# '--refine=/userhome/34/gyu/logs_sr/3wk_p1/prune_1st/pruned.pth.tar',
# '--resume=/userhome/34/gyu/logs_sr/3wk_p1/prune_1st/refine1/checkpoint.pth.tar',
args.cuda = not args.no_cuda and torch.cuda.is_available()

In [6]:
if ray.is_initialized():
    ray.shutdown()

In [7]:
ray.init(address=args.ray_master+':'+args.redis_port)

    

{'node_ip_address': '10.21.5.171',
 'redis_address': '10.21.5.171:6379',
 'object_store_address': '/tmp/ray/session_2019-12-30_15-42-27_520221_8838/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2019-12-30_15-42-27_520221_8838/sockets/raylet',
 'webui_url': 'http://10.21.5.171:8080/?token=f97b60ba77ae3f54cbf67e4af32807b34f1cd5a2730770f6',
 'session_dir': '/tmp/ray/session_2019-12-30_15-42-27_520221_8838'}

In [8]:
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

test_loader = generate_test_loader(args.test_batch_size)
train_loaders = [generate_train_loader(args.batch_size,kwargs) for _ in range(args.num_workers)]

resume_from_ckpt = args.resume if (args.resume and os.path.isfile(args.resume)) else None

ps = ParameterServer.remote(args,test_loader)



Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
[2m[36m(pid=8873)[0m found pruned ckpt


In [9]:
worker_tasks = [worker_task.remote(args,ps,idx,train_loaders[idx]) for idx in range(args.num_workers)]

[2m[36m(pid=8868)[0m found pruned ckpt
[2m[36m(pid=8868)[0m worker # 0  is online
[2m[36m(pid=5929, ip=10.21.5.174)[0m found pruned ckpt
[2m[36m(pid=5929, ip=10.21.5.174)[0m worker # 1  is online
[2m[36m(pid=11586, ip=10.21.5.173)[0m found pruned ckpt
[2m[36m(pid=11586, ip=10.21.5.173)[0m worker # 2  is online
[2m[36m(pid=5929, ip=10.21.5.174)[0m The 1 worker finished its 0 epoch with loss: 0.0016215710202232003 and accuracy: 0.9642999768257141
[2m[36m(pid=8868)[0m The 0 worker finished its 0 epoch with loss: 0.0015970575623214245 and accuracy: 0.9648600220680237
[2m[36m(pid=8873)[0m global_step:  2346  and prepare evaluate
[2m[36m(pid=8873)[0m going to evaluate
[2m[36m(pid=8873)[0m pulled weights
[2m[36m(pid=8873)[0m loaded weights
[2m[36m(pid=8873)[0m length of the test_loader dataset is :  10000
[2m[36m(pid=8873)[0m in eval, the batch is:  20
[2m[36m(pid=8873)[0m in eval, the batch is:  40
[2m[36m(pid=8873)[0m in eval, the batch is:  6

[2m[36m(pid=8868)[0m The 0 worker finished its 2 epoch with loss: 0.0014956232625991106 and accuracy: 0.9674400091171265
[2m[36m(pid=5929, ip=10.21.5.174)[0m The 1 worker finished its 2 epoch with loss: 0.0015332846669480205 and accuracy: 0.9666600227355957
[2m[36m(pid=8873)[0m global_step:  7038  and prepare evaluate
[2m[36m(pid=8873)[0m going to evaluate
[2m[36m(pid=8873)[0m pulled weights
[2m[36m(pid=8873)[0m loaded weights
[2m[36m(pid=8873)[0m length of the test_loader dataset is :  10000
[2m[36m(pid=8873)[0m in eval, the batch is:  20
[2m[36m(pid=8873)[0m in eval, the batch is:  40
[2m[36m(pid=8873)[0m in eval, the batch is:  60
[2m[36m(pid=8873)[0m in eval, the batch is:  80
[2m[36m(pid=8873)[0m in eval, the batch is:  100
[2m[36m(pid=8873)[0m in eval, the batch is:  120
[2m[36m(pid=8873)[0m in eval, the batch is:  140
[2m[36m(pid=8873)[0m 
[2m[36m(pid=8873)[0m Test set: Average loss: 0.3682, Accuracy: 8982/10000 (0.8982)
[2m[36m(

[2m[36m(pid=5929, ip=10.21.5.174)[0m The 1 worker finished its 4 epoch with loss: 0.0014382307417690754 and accuracy: 0.968999981880188
[2m[36m(pid=8868)[0m The 0 worker finished its 4 epoch with loss: 0.0014279981842264533 and accuracy: 0.9696199893951416
[2m[36m(pid=8873)[0m global_step:  11730  and prepare evaluate
[2m[36m(pid=8873)[0m going to evaluate
[2m[36m(pid=8873)[0m pulled weights
[2m[36m(pid=8873)[0m loaded weights
[2m[36m(pid=8873)[0m length of the test_loader dataset is :  10000
[2m[36m(pid=8873)[0m in eval, the batch is:  20
[2m[36m(pid=8873)[0m in eval, the batch is:  40
[2m[36m(pid=8873)[0m in eval, the batch is:  60
[2m[36m(pid=8873)[0m in eval, the batch is:  80
[2m[36m(pid=8873)[0m in eval, the batch is:  100
[2m[36m(pid=8873)[0m in eval, the batch is:  120
[2m[36m(pid=8873)[0m in eval, the batch is:  140
[2m[36m(pid=8873)[0m 
[2m[36m(pid=8873)[0m Test set: Average loss: 0.3735, Accuracy: 8965/10000 (0.8965)
[2m[36m(

[2m[36m(pid=5929, ip=10.21.5.174)[0m The 1 worker finished its 6 epoch with loss: 0.0014155225362628698 and accuracy: 0.9690399765968323
[2m[36m(pid=8873)[0m global_step:  16422  and prepare evaluate
[2m[36m(pid=8873)[0m going to evaluate
[2m[36m(pid=8873)[0m pulled weights
[2m[36m(pid=8873)[0m loaded weights
[2m[36m(pid=8873)[0m length of the test_loader dataset is :  10000
[2m[36m(pid=8873)[0m in eval, the batch is:  20
[2m[36m(pid=8873)[0m in eval, the batch is:  40
[2m[36m(pid=8873)[0m in eval, the batch is:  60
[2m[36m(pid=8873)[0m in eval, the batch is:  80
[2m[36m(pid=8873)[0m in eval, the batch is:  100
[2m[36m(pid=8873)[0m in eval, the batch is:  120
[2m[36m(pid=8873)[0m in eval, the batch is:  140
[2m[36m(pid=8873)[0m 
[2m[36m(pid=8873)[0m Test set: Average loss: 0.3747, Accuracy: 8989/10000 (0.8989)
[2m[36m(pid=8873)[0m 
[2m[36m(pid=11586, ip=10.21.5.173)[0m The 2 worker finished its 6 epoch with loss: 0.0013929714914411306 




[2m[33m(pid=raylet, ip=10.21.5.173)[0m     @           0x4de776  RedisAsioClient::handle_read()
[2m[33m(pid=raylet, ip=10.21.5.173)[0m     @           0x4dd9a8  boost::asio::detail::reactive_null_buffers_op<>::do_complete()
[2m[33m(pid=raylet, ip=10.21.5.173)[0m     @           0x425bcd  boost::asio::detail::scheduler::run()
[2m[33m(pid=raylet, ip=10.21.5.173)[0m     @           0x40fb1d  main
[2m[33m(pid=raylet, ip=10.21.5.173)[0m     @     0x7f2ba8438b97  __libc_start_main
[2m[33m(pid=raylet, ip=10.21.5.173)[0m     @           0x4207e1  (unknown)
[2m[33m(pid=raylet, ip=10.21.5.174)[0m F0102 20:54:09.338240  5908 node_manager.cc:481]  Check failed: client_id != gcs_client_->client_table().GetLocalClientId() Exiting because this node manager has mistakenly been marked dead by the monitor.
[2m[33m(pid=raylet, ip=10.21.5.174)[0m *** Check failure stack trace: ***
[2m[33m(pid=raylet, ip=10.21.5.174)[0m     @           0x6f8d1a  google::LogMessage::Fail()
[2m[



[2m[36m(pid=11586, ip=10.21.5.173)[0m 
[2m[36m(pid=11586, ip=10.21.5.173)[0m Traceback (most recent call last):
[2m[36m(pid=11586, ip=10.21.5.173)[0m   File "/userhome/34/gyu/anaconda3/envs/pytorch_env/lib/python3.7/site-packages/ray/workers/default_worker.py", line 105, in <module>
[2m[36m(pid=11586, ip=10.21.5.173)[0m     job_id=None)
[2m[36m(pid=11586, ip=10.21.5.173)[0m   File "/userhome/34/gyu/anaconda3/envs/pytorch_env/lib/python3.7/site-packages/ray/utils.py", line 67, in push_error_to_driver
[2m[36m(pid=11586, ip=10.21.5.173)[0m     worker.raylet_client.push_error(job_id, error_type, message, time.time())
[2m[36m(pid=11586, ip=10.21.5.173)[0m   File "python/ray/_raylet.pyx", line 327, in ray._raylet.RayletClient.push_error
[2m[36m(pid=11586, ip=10.21.5.173)[0m   File "python/ray/_raylet.pyx", line 98, in ray._raylet.check_status
[2m[36m(pid=11586, ip=10.21.5.173)[0m ray.exceptions.RayletError: The Raylet died with this message: [RayletClient] Connecti

[2m[36m(pid=8868)[0m Traceback (most recent call last):
[2m[36m(pid=8868)[0m   File "/userhome/34/gyu/anaconda3/envs/pytorch_env/lib/python3.7/site-packages/ray/worker.py", line 936, in _process_task
[2m[36m(pid=8868)[0m     outputs = function_executor(*arguments)
[2m[36m(pid=8868)[0m   File "<ipython-input-4-869321c4bd9a>", line 86, in worker_task
[2m[36m(pid=8868)[0m   File "/userhome/34/gyu/anaconda3/envs/pytorch_env/lib/python3.7/site-packages/ray/actor.py", line 148, in remote
[2m[36m(pid=8868)[0m     return self._remote(args, kwargs)
[2m[36m(pid=8868)[0m   File "/userhome/34/gyu/anaconda3/envs/pytorch_env/lib/python3.7/site-packages/ray/actor.py", line 169, in _remote
[2m[36m(pid=8868)[0m     return invocation(args, kwargs)
[2m[36m(pid=8868)[0m   File "/userhome/34/gyu/anaconda3/envs/pytorch_env/lib/python3.7/site-packages/ray/actor.py", line 163, in invocation
[2m[36m(pid=8868)[0m     num_return_vals=num_return_vals)
[2m[36m(pid=8868)[0m   File "

In [None]:
print(1)

In [None]:
ray.get(ps.get_stalness_table.remote())

In [None]:
# prune model 
%cd /userhome/34/gyu/git-repo/rethinking-network-pruning/cifar/network-slimming/
!python resprune_modified.py --dataset cifar10 --depth 164 --percent 0.1 --model /userhome/34/gyu/logs_sr/3wk_p1/train4/checkpoint.pth.tar --save /userhome/34/gyu/logs_sr/3wk_p1/prune4/
%cd /userhome/34/gyu

In [None]:
# ps.evaluate.remote()

In [None]:
local_test_loader = generate_test_loader(64)
local_train_loader = generate_train_loader(64,{'num_workers': 1, 'pin_memory': True})


In [None]:

# test_writer = SummaryWriter()


checkpoint = torch.load('/userhome/34/gyu/logs_sr/3wk_p1/train24/checkpoint.pth.tar')
local_test_model = models.__dict__["resnet"](dataset="cifar10",depth=164,cfg=checkpoint['cfg'])
local_test_model.load_state_dict(checkpoint['state_dict'])
local_test_model.cuda()


In [None]:
# load pruned model
local_test_loader = generate_test_loader(64)
local_train_loader = generate_train_loader(64,{'num_workers': 1, 'pin_memory': True})

test_writer = SummaryWriter()


checkpoint = torch.load(args.refine)
local_test_model = models.__dict__['resnet'](dataset='cifar10', depth=args.depth, cfg=checkpoint['cfg'])

checkpoint = torch.load('/userhome/34/gyu/tmp/prune_40/ssp_refine_1wk_runningbn/checkpoint.pth.tar')
local_test_model.load_state_dict(checkpoint['state_dict'])
local_test_model.cuda()


In [None]:
from torchsummary import summary
summary(local_test_model, input_size=(3, 32, 32))


In [None]:
data=[]
state = local_test_model.state_dict()
for idx, p in enumerate(state):
    data.append(state[p])

In [None]:
data[0][0][0][0][0] = 100

In [None]:
data2=[]
state2 = local_test_model.state_dict()
for idx, p in enumerate(state2):
    data2.append(state2[p])

In [None]:
data2[0][0][0][0][0]

In [None]:
state1 = local_test_model.state_dict()

In [None]:
state1['layer1.0.bn1.running_mean']

In [None]:
state1['layer1.0.bn1.running_var']

In [None]:
state1['layer1.0.bn1.weight']


In [None]:
state1['layer1.0.bn1.bias']

In [None]:
local_test_model = ray.get(ps.get_model.remote())


In [None]:
# optim= ray.get(ps.get_optim.remote())
# local_test_model = ray.get(ps.get_model.remote())
# cfg = ray.get(ps.pull_cfg.remote())

# torch.save({'epoch':3,
#                 'global_step':8601,
#                 'state_dict':local_test_model.state_dict(),
#                 'optimizer':optim.state_dict(),
#                 'cfg':cfg
#             },'/userhome/34/gyu/logs_sr/3wk_p1/prune_2nd/checkpoint.pth.tar')

In [None]:
def test_model(model, test_dataloader):
    local_test_model = model
    local_test_loader = test_dataloader
    # local_test_model.train()
    local_test_model.eval()
    # test dataset loader 
    test_loss = 0.
    correct = 0.
    batch_count = 0.
    for data, target in local_test_loader:
        data,target = data.cuda(),target.cuda()
        batch_count += 1
        data, target = Variable(data, volatile=True), Variable(target)
        output = local_test_model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).data # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        batch_correct = pred.eq(target.data.view_as(pred)).sum()
        correct += batch_correct
#         if batch_count % 100  == 0:
#             print("        with model.eval(), batch num: ",batch_count, " with correct: ",int(batch_correct.data), " / ",len(data))

    test_loss /= len(local_test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.6f})\n'.format(
        test_loss, correct, len(local_test_loader.dataset),
        correct / float(len(local_test_loader.dataset))))

In [None]:
local_test_model.eval()
local_test_model.cuda()
test_model(local_test_model,local_test_loader)


In [None]:
# 经过一个epoch 的 test set
state2 = local_test_model.state_dict()


In [None]:
state2['layer1.0.bn1.running_mean']

In [None]:
state2['layer1.0.bn1.running_var']

In [None]:
state2['layer1.0.bn1.weight']

In [None]:
state2['layer1.0.bn1.bias']

In [None]:
# bn层的4个参数都没有变

In [None]:
# train dataset loader, but set model.eval(), acc=11% for one epoch 
# local_test_model.eval()
local_test_model.train()
# train dataset loader

test_loss = 0.
correct = 0.
train_batch_count = 0.
num_batch = 0
for data, target in local_train_loader:
#     if train_batch_count % 30 == 0:
#         print(train_batch_count)
#         local_test_model.eval()
#         test_model(local_test_model,local_test_loader)
        
    if train_batch_count==150:
        break
        
    num_batch += 1   
    local_test_model.train()
    data,target = data.cuda(),target.cuda()
    train_batch_count += 1
    data, target = Variable(data, volatile=True), Variable(target)
    output = local_test_model(data)
    test_loss += F.cross_entropy(output, target, size_average=False).data # sum up batch loss
    pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
    batch_correct = pred.eq(target.data.view_as(pred)).sum()
    correct += batch_correct
    
    for name,param in local_test_model.named_parameters():
        test_writer.add_histogram(name, param, num_batch)

    
#     if train_batch_count % 100  == 0:
#         print("With model.train(), batch num: ",train_batch_count, " , with correct: ",int(batch_correct.data), " / ", len(data))

test_loss /= len(local_train_loader.dataset)
print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.6f})\n'.format(
    test_loss, correct, len(local_train_loader.dataset),
    correct / float(len(local_train_loader.dataset))))

In [None]:
state3 = local_test_model.state_dict()
print(state3['layer1.0.bn1.running_mean'])
print(state3['layer1.0.bn1.running_var'])
print(state3['layer1.0.bn1.weight'])
print(state3['layer1.0.bn1.bias'])

In [None]:
bn_name = []
bn_li = []
for m in local_test_model.modules():
    if isinstance(m, nn.BatchNorm2d):
        bn_name.append(m.name)
        bn_li.append(m.weight.data)


In [None]:
# tmp_model = ray.get(ps.get_model.remote())
# state = {
#     'epoch':3,
#     'global_step':6795,
#     'state_dict':tmp_model.state_dict()
# }
# torch.save(state,'/userhome/34/gyu/logs_sr/3wk_p1/prune_1st/refine1/checkpoint_tmp.pth.tar')

In [None]:
local_test_model2
test_model(local_test_model2,local_test_loader)


In [None]:
def _get_params(model):
    bns = {}
    non_bns = {}
    param_count = 0.
    bn_param_count = 0.
    for name,param in model.named_parameters():
        param_count += len(param)
        if 'bn' in name:
            bns[name] = param
            bn_param_count += len(param)
        else:
            non_bns[name] = param
    print("bn params occupies: ",bn_param_count/param_count)
    return bns,non_bns

In [None]:
_get_params(local_test_model)


In [None]:
local_bns2

In [None]:
state_dict = ray.get(ps.pull_weights.remote())

In [None]:
non_bn={}
for ele in state_dict:
    if 'bn' not in ele:
        non_bn[ele]=state_dict[ele]


In [None]:
init_wei = ray.get(ps.pull_weights.remote())

In [None]:
init_wei[]

In [None]:
import _to_remove_resnet.models as tmpmodels

In [None]:
res56 = tmpmodels.__dict__["resnet"](dataset="cifar10",depth=56)
res56.cuda()

In [None]:
from torchsummary import summary
summary(res56, input_size=(3, 32, 32))


In [None]:
import torch
wide_50 = torch.hub.load('pytorch/vision:v0.4.2', 'wide_resnet50_2', pretrained=True)
wide_50.cuda()

In [None]:
summary(wide_50, input_size=(3, 32, 32))


In [None]:
resnext50 = torch.hub.load('pytorch/vision:v0.4.2', 'resnext50_32x4d', pretrained=True)
resnext50.cuda()

In [None]:
summary(resnext50, input_size=(3, 64, 64))