In [1]:
# pytorch optimizer 让动量参与计算，以及手动修改lr

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np

import ray
import resnet.models as models
import random,time
from time import sleep
import copy 
import datetime
import argparse


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import os
import shutil
from torch.utils.tensorboard import SummaryWriter
from filelock import FileLock

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def generate_train_loader(batch_size,kwargs):
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
    batch_size=batch_size, shuffle=True, **kwargs)
    return train_loader

def generate_test_loader(test_batch_size):
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),batch_size=test_batch_size, shuffle=True)
    return test_loader

@ray.remote
class ParameterServer():
    def __init__(self,args,test_loader):
        self.model = models.__dict__["resnet"](dataset="cifar10",depth=args.depth)
        self.stalness_table = [0] * args.num_workers
        self.stalness_limit = args.stalness_limit 
        self.global_step = 0
        self.lr = args.lr
        self.args = args
        self.eva_model = models.__dict__["resnet"](dataset="cifar10",depth=args.depth)
        self.optimizer = optim.SGD(self.model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
        self.test_loader = test_loader
        self.model.cpu()
        self.eva_model.cpu()
        self.ps_writer = SummaryWriter(os.path.join(os.getcwd(),(args.tb_path+'/ps')))
        self.save_path = args.save
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                self.global_step = checkpoint['global_step']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.stalness_table = [self.global_step/args.num_workers] * args.num_workers
                print("=> loaded checkpoint '{}' (global step: {})".format(args.resume, checkpoint['global_step']))                
                if 'epoch' in checkpoint: print("epoch: {}".format(checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

    def apply_gradients(self, iter_diff, wk_idx, epoch):
        if args.debug: print("applying gradients from the ",wk_idx, " worker")    
            
#         if epoch == (int)(self.args.epochs * 0.5):
#             self.optimizer.param_groups[0]['lr'] = 0.01
#             print("lr has been changed to: ",self.optimizer.param_groups[0]['lr'])
    
        for idx, p in enumerate(self.model.parameters()):
            p.data -= iter_diff[idx]
            
        self.stalness_table[wk_idx] += 1
        self.global_step += 1
        if args.debug: print("finished applying gradients from the ",wk_idx, " worker")
        if self.global_step % 1000 == 0:
#             print("global_step: ",self.global_step," and prepare evaluate")
#             self.evaluate()
            self.save_ckpt({
                'epoch':epoch,
                'global_step':self.global_step,
                'state_dict':self.model.state_dict(),
                'optimizer':self.optimizer.state_dict()
            },filepath=os.path.join(os.getcwd(),self.save_path))
            
        
    def pull_weights(self):
        return self.model.state_dict()
    
    def get_optim(self):
        return self.optimizer
    
    def pull_optimizer_state(self):
        return self.optimizer.state_dict()

    def check_stalness(self,wk_idx):
        min_iter = min(self.stalness_table)
        return self.stalness_table[wk_idx] - min_iter < self.stalness_limit
        
    def get_stalness(self):
        return min(self.stalness_table)
    
    def get_stalness_table(self):
        return self.stalness_table
    
    def get_global_step(self):
        return self.global_step
    
    def save_ckpt(self,state,filepath):
        torch.save(state,os.path.join(filepath,'checkpoint.pth.tar'))
        
    def evaluate(self):
        print("going to evaluate")
        test_loss = 0.
        correct = 0.
        print("pulled weights")
        self.eva_model.load_state_dict(copy.deepcopy(self.model.state_dict()))
        print("loaded weights")
        print("length of the test_loader dataset is : ",len(self.test_loader.dataset))
        self.eva_model.eval()
        count = 0
        for data,target in self.test_loader:
            count += 1
            if count % 20 == 0: print("in eval, the batch is: ",count)
            data, target = Variable(data,volatile=True),Variable(target)
            output = self.eva_model(data)
            batch_loss = F.cross_entropy(output, target, size_average=False).data
            test_loss += batch_loss
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        len_testset = len(self.test_loader.dataset)
        test_loss /= len_testset 
        accuracy = correct / len_testset
        # log 
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len_testset,accuracy))

        self.ps_writer.add_scalar('Accuracy/eval', accuracy, self.global_step)
        self.ps_writer.add_scalar('Loss/eval',test_loss , self.global_step)
        
        



In [3]:
@ray.remote(num_gpus=1)
def worker_task(args,ps,worker_index, train_loader):
    # Initialize the model.
#     if args.debug: print(worker_index, " worker is going to sleep ",worker_index*5000)
#     time.sleep(worker_index * 5000)
    
    model = models.__dict__["resnet"](dataset="cifar10",depth=args.depth)
    local_step = 0
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    if args.cuda:
        starttime = datetime.datetime.now()
        model.cuda()
        endtime = datetime.datetime.now()
        time_cost = (endtime - starttime).seconds
        if args.debug: print("move model to gpu takes: ", time_cost, "seconds")
    if args.resume:
        checkpoint = torch.load(args.resume)
        local_step = checkpoint['global_step'] / args.num_workers
        optimizer.load_state_dict(checkpoint['optimizer'])
        if 'epoch' in checkpoint:
            args.start_epoch = checkpoint['epoch']

    wk_writer = SummaryWriter(os.path.join(os.getcwd(),args.tb_path,('wk_'+str(worker_index))))
    
    for epoch in range(args.start_epoch,args.epochs):
        avg_loss = 0.
        train_acc = 0.
        for batch_idx,(data,target) in enumerate(train_loader):
            if args.cuda:
                starttime = datetime.datetime.now()
                data,target = data.cuda(),target.cuda()
                mid = datetime.datetime.now()
                if args.debug: print("move data to gpu takes: ", (mid - starttime).seconds, "seconds")
                model.cuda()
                endtime = datetime.datetime.now()
                time_cost = (endtime - starttime).seconds
                if args.debug: print("move model to gpu takes: ", time_cost, "seconds")
                
            while(local_step - ray.get(ps.get_stalness.remote()) > args.stalness_limit):
                print(worker_index," works too fast")
                sleep(1)
            # Get the current weights from the parameter server.
            if args.debug: print("the ",worker_index," pulls wei from ps.")
            init_wei = ray.get(ps.pull_weights.remote())
            model.load_state_dict(init_wei)
            if args.debug: print("the ",worker_index," loaded the latest wei from ps.")
            # Compute an update and push it to the parameter server.        
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            if args.debug: print(worker_index,' is generating output')
            output = model(data)
            if args.debug: print(worker_index,' generated output done and going to calculate loss')
            loss = F.cross_entropy(output,target)
            avg_loss += loss
            pred = output.data.max(1,keepdim=True)[1]
            batch_acc = pred.eq(target.data.view_as(pred)).cpu().sum()
            train_acc += batch_acc
            if args.debug: print(worker_index,' calculated loss and going to bp')
            loss.backward()
            if args.debug: print(worker_index,' bp done')
            starttime = datetime.datetime.now()
            model.cpu()
            endtime = datetime.datetime.now()
            time_cost = (endtime - starttime).seconds
            if args.debug: print("move model to cpu takes: ", time_cost, "seconds")
            old_tensors = copy.deepcopy([p.data for p in model.parameters()])    
            optimizer.step()
            new_tensors = [p.data for p in model.parameters()]
            local_step += 1
            iter_diff = [old_tensor - new_tensor for (old_tensor, new_tensor) in zip(old_tensors,new_tensors)]
            ps.apply_gradients.remote(iter_diff,worker_index,epoch)
            if batch_idx % args.log_interval == 0:
                print('The {} worker, Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                worker_index, epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data))
                wk_writer.add_scalar("Loss/worker_train",loss,local_step)
                wk_writer.add_scalar("Accuracy/worker_train",batch_acc,local_step)
        print("The {} worker finished its {} epoch with loss: {} and accuracy: {}".format(
            worker_index,
            epoch,
            avg_loss / len(train_loader),
            train_acc / float(len(train_loader)
        )))

In [4]:
parser = argparse.ArgumentParser(description='Distributed SSP CIFAR-10 Restnet train with network slimming')
parser.add_argument('--ray-master',type=str,default='127.0.0.1')
parser.add_argument('--redis-port',type=str,default='6379')
parser.add_argument('--batch-size',type=int,default=64)
parser.add_argument('--test-batch-size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=160)
parser.add_argument('--start-epoch', default=0, type=int)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float)
parser.add_argument('--resume', default=None, type=str) 
parser.add_argument('--no-cuda', action='store_true', default=False)
parser.add_argument('--save', default='./logs', type=str)
parser.add_argument('--depth', default=164, type=int)
parser.add_argument('--tb-path', default='./logs', type=str)
parser.add_argument('--log-interval', type=int, default=100)
parser.add_argument('--num-workers',type=int,default=1)
parser.add_argument('--stalness-limit',type=int,default=5)
parser.add_argument('--debug',action='store_true',default=False)

args = parser.parse_args(args=['--num-workers=3','--resume=/userhome/34/gyu/logs/checkpoint.pth.tar'])
# '--resume=/userhome/34/gyu/logs/checkpoint.pth.tar',
args.cuda = not args.no_cuda and torch.cuda.is_available()

In [5]:
if ray.is_initialized():
    ray.shutdown()

In [6]:
ray.init(address=args.ray_master+':'+args.redis_port)

    

{'node_ip_address': '10.21.5.171',
 'redis_address': '10.21.5.171:6379',
 'object_store_address': '/tmp/ray/session_2019-11-25_15-37-23_318826_32312/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2019-11-25_15-37-23_318826_32312/sockets/raylet',
 'webui_url': 'http://10.21.5.171:8080/?token=4d52497a8914e8a2b0d5eb7019ae2ed3b1d99039d5c83c00',
 'session_dir': '/tmp/ray/session_2019-11-25_15-37-23_318826_32312'}

In [7]:
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

test_loader = generate_test_loader(args.test_batch_size)
train_loaders = [generate_train_loader(args.batch_size,kwargs) for _ in range(args.num_workers)]

resume_from_ckpt = args.resume if (args.resume and os.path.isfile(args.resume)) else None

ps = ParameterServer.remote(args,test_loader)



Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
[2m[36m(pid=30727, ip=10.21.5.172)[0m => loading checkpoint '/userhome/34/gyu/logs/checkpoint.pth.tar'
[2m[36m(pid=30727, ip=10.21.5.172)[0m => loaded checkpoint '/userhome/34/gyu/logs/checkpoint.pth.tar' (global step: 46000)
[2m[36m(pid=30727, ip=10.21.5.172)[0m epoch: 16


In [8]:
worker_tasks = [worker_task.remote(args,ps,idx,train_loaders[idx]) for idx in range(args.num_workers)]

[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast


In [9]:
print(1)

1
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=30729, ip=10.21.5.172)[0m 1  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=30729, ip=10.21.5.172)[0m 1  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=30729, ip=10.21.5.172)[0m 1  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=30729, ip=10.21.5.172)[0m 1  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=30729, ip=10.21.5.172)[0m 1  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=30729, ip=10.21.5.172)[0m 1  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too fast
[2m[36m(pid=30729, ip=10.21.5.172)[0m 1  works too fast
[2m[36m(pid=17473, ip=10.21.5.173)[0m 0  works too 

In [None]:
# ray.get(ps.get_optim.remote())

In [13]:
local_test_loader = generate_test_loader(256)

# test batch size = 256 , model.train(),  acc = 85%
# test batch size = 256 , model.eval(),  acc = 11%
# test batch size = 64 , model.train(),  acc = 84%
# test batch size = 64 , model.eval(),  acc = 11%

local_test_model = models.__dict__["resnet"](dataset="cifar10",depth=args.depth)
checkpoint = torch.load('/userhome/34/gyu/logs/checkpoint.pth.tar')
local_test_model.load_state_dict(checkpoint['state_dict'])
local_test_model.cuda()
local_test_model.eval()
test_loss = 0
correct = 0
batch_count = 0
for data, target in local_test_loader:
    data,target = data.cuda(),target.cuda()
    batch_count += 1
    print("this is in ",batch_count, " batch")
    data, target = Variable(data, volatile=True), Variable(target)
    output = local_test_model(data)
    test_loss += F.cross_entropy(output, target, size_average=False).data # sum up batch loss
    pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
    correct += pred.eq(target.data.view_as(pred)).sum()

test_loss /= len(local_test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
    test_loss, correct, len(local_test_loader.dataset),
    100. * correct / len(local_test_loader.dataset)))
print(correct / float(len(local_test_loader.dataset)))


this is in  1  batch




this is in  2  batch
this is in  3  batch
this is in  4  batch
this is in  5  batch
this is in  6  batch
this is in  7  batch
this is in  8  batch
this is in  9  batch
this is in  10  batch
this is in  11  batch
this is in  12  batch
this is in  13  batch
this is in  14  batch
this is in  15  batch
this is in  16  batch
this is in  17  batch
this is in  18  batch
this is in  19  batch
this is in  20  batch
this is in  21  batch
this is in  22  batch
this is in  23  batch
this is in  24  batch
this is in  25  batch
this is in  26  batch
this is in  27  batch
this is in  28  batch
this is in  29  batch
this is in  30  batch
this is in  31  batch
this is in  32  batch
this is in  33  batch
this is in  34  batch
this is in  35  batch
this is in  36  batch
this is in  37  batch
this is in  38  batch
this is in  39  batch
this is in  40  batch

Test set: Average loss: 22.0370, Accuracy: 1155/10000 (11.0%)

tensor(0, device='cuda:0')
