In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np

import ray
import resnet.models as models
import random,time
from time import sleep
import copy 
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import argparse
import numpy as np
import os
import shutil
from torch.utils.tensorboard import SummaryWriter


# Reconnect cleanly if this cell is re-run in the same kernel.
if ray.is_initialized():
    ray.shutdown()
# Cluster head address. Set the RAY_ADDRESS environment variable to point the
# notebook at a different cluster without editing code.
ray.init(address=os.environ.get("RAY_ADDRESS", "10.21.5.172:11572"))


In [None]:
import time


@ray.remote
def f():
    """Return the IP address of the node this task ran on.

    Used below: 1000 short tasks are launched and the distinct addresses are
    collected into a set to show which nodes accept work.
    """
    time.sleep(0.01)
    # `ray.services` was moved to a private module in later Ray releases; prefer
    # the public `ray.util.get_node_ip_address` and fall back for older Ray.
    get_ip = getattr(getattr(ray, "util", None), "get_node_ip_address", None)
    if get_ip is None:
        get_ip = ray.services.get_node_ip_address
    return get_ip()


# Get a list of the IP addresses of the nodes that have joined the cluster.
set(ray.get([f.remote() for _ in range(1000)]))


In [None]:
import os


@ray.remote(num_gpus=1, max_calls=1)
def use_gpu():
    """Print the GPU ids Ray assigned to this task and the CUDA_VISIBLE_DEVICES value it sees."""
    assigned_gpus = ray.get_gpu_ids()
    print("ray.get_gpu_ids(): {}".format(assigned_gpus))
    visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
    print("CUDA_VISIBLE_DEVICES: {}".format(visible_devices))


# Launch three probe tasks; the list of object refs is the cell output.
[use_gpu.remote() for _probe in range(3)]

In [None]:
# --- model / dataset ---
arch = "resnet"
depth = 56
dataset = "cifar10"
batch_size = 64
test_batch_size = 100

# --- device (uncomment the override to force CPU) ---
cuda = torch.cuda.is_available()
print("cuda is ready: ", cuda)
# cuda = False
# Extra DataLoader options, only used when a GPU is available.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

# --- output locations ---
save = "./logs"
# Checkpoint directory (created by the next cell and passed to save_checkpoint).
ckpt_ssp_resnet = "./ckpt_ssp_resnet"

# --- optimization ---
lr = 0.1
momentum = 0.9
weight_decay = 1e-4
log_interval = 100
start_epoch = 0
epochs = 160
# NOTE(review): `seed` is not applied to any RNG in the visible cells, and the
# Ray workers/actor run in separate processes — seed them there if needed.
seed = 1

In [None]:
# Create the log and checkpoint directories unless something already exists at those paths.
for _out_dir in (save, ckpt_ssp_resnet):
    if not os.path.exists(_out_dir):
        os.makedirs(_out_dir)
    

In [None]:
# Root directory for the CIFAR-10 download/cache.
cifar10_root = './data.cifar10'
# Normalize with per-channel mean and std; shared by the train and test pipelines.
cifar10_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))


def make_cifar10_train_loader():
    """Build a shuffled CIFAR-10 training DataLoader with pad / random-crop / flip augmentation.

    Uses the notebook-level ``batch_size`` and ``kwargs`` (extra DataLoader options).
    """
    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        cifar10_normalize,
    ])
    return torch.utils.data.DataLoader(
        datasets.CIFAR10(cifar10_root, train=True, download=True, transform=train_transform),
        batch_size=batch_size, shuffle=True, **kwargs)


# One independent loader per worker. NOTE: each loader iterates over the FULL
# training set — the data is not sharded between workers.
train_loader1 = make_cifar10_train_loader()
train_loader2 = make_cifar10_train_loader()

# The test loader uses default (single-process) loading. It is shuffled, so an
# evaluation that reads only the first batch sees a different random sample each time.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(cifar10_root, train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        cifar10_normalize,
    ])),
    batch_size=test_batch_size, shuffle=True)

In [None]:
# One training loader per worker, indexed by worker id (worker i reads train_loader[i]).
train_loader = list((train_loader1, train_loader2))

In [None]:
@ray.remote
class ParameterServer():
    """Ray actor that owns the global model for stale-synchronous-parallel (SSP) training.

    Workers push gradients with ``apply_gradients`` and pull the current weights
    with ``pull_weights``. ``stalness_table[i]`` counts the gradient pushes applied
    from worker ``i``. ``get_stalness`` returns the minimum entry, i.e. the slowest
    worker's progress, which workers compare against to bound how far ahead they run.

    Args:
        lr: learning rate of the server-side SGD optimizer.
        num_workers: number of workers (size of the staleness table).
        stalness_limit: allowed lead over the slowest worker, used by ``check_stalness``.
        test_loader: iterable of ``(data, target)`` batches used by ``evaluate``.
        resume_from_ckpt: path of a saved ``state_dict`` to load into the model,
            or a falsy value to start from freshly initialized weights.
    """

    def __init__(self, lr, num_workers, stalness_limit, test_loader, resume_from_ckpt):
        self.lr = lr
        self.model = models.__dict__[arch](dataset=dataset, depth=depth)
        self.stalness_table = [0] * num_workers
        self.stalness_limit = stalness_limit
        self.global_step = 0
        # Separate copy used only for evaluation, so eval() mode never touches the training model.
        self.eva_model = models.__dict__[arch](dataset=dataset, depth=depth)
        # momentum / weight_decay come from the notebook's config cell.
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=lr,
                                   momentum=momentum,
                                   weight_decay=weight_decay)
        self.test_loader = test_loader
        # The actor requests no GPU, so the server keeps everything on CPU.
        self.model.cpu()
        self.eva_model.cpu()
        # TensorBoard logger (default log directory).
        self.ps_writer = SummaryWriter()

        if resume_from_ckpt:
            # map_location keeps checkpoints saved from GPU tensors loadable on this CPU actor.
            self.model.load_state_dict(torch.load(resume_from_ckpt, map_location="cpu"))

    def apply_gradients(self, gradients, wk_idx):
        """Apply one worker's gradients to the global model using the server optimizer.

        Evaluates the model every 10 applied updates.

        Args:
            gradients: one tensor (or ``None``) per model parameter, in
                ``self.model.parameters()`` order.
            wk_idx: index of the pushing worker in ``stalness_table``.

        Raises:
            ValueError: if ``gradients`` does not contain one entry per parameter.
        """
        print("applying gradients from the ", wk_idx, " worker")
        params = list(self.model.parameters())
        if len(gradients) != len(params):
            raise ValueError("expected {} gradients, got {}".format(len(params), len(gradients)))
        for p, g in zip(params, gradients):
            # None = no gradient on the worker for this parameter; SGD skips params whose grad is None.
            # clone() gives the optimizer its own writable copy of the received tensor.
            p.grad = None if g is None else g.detach().to(device=p.device, dtype=p.dtype).clone()
        # Step the configured optimizer so momentum and weight decay take effect;
        # a plain `p -= lr * g` update would silently ignore both.
        self.optimizer.step()
        self.stalness_table[wk_idx] += 1
        self.global_step += 1
        print("finished applying gradients from the ", wk_idx, " worker")
        if self.global_step % 10 == 0:
            print("global_step: ", self.global_step, " and prepare evaluate")
            self.evaluate()

    def pull_weights(self):
        """Return the global model's ``state_dict`` (parameters and buffers)."""
        return self.model.state_dict()

    def check_stalness(self, wk_idx):
        """Return True if worker ``wk_idx`` leads the slowest worker by less than ``stalness_limit``."""
        min_iter = min(self.stalness_table)
        return self.stalness_table[wk_idx] - min_iter < self.stalness_limit

    def get_stalness(self):
        """Return the number of updates applied from the slowest worker."""
        return min(self.stalness_table)

    def evaluate(self, num_batches=1):
        """Evaluate a snapshot of the global model on ``num_batches`` test batches.

        The default of one batch keeps evaluation cheap. Results are logged to
        TensorBoard at ``self.global_step`` and printed.

        Args:
            num_batches: number of test batches to average over; ``None`` uses the
                whole test loader.

        Returns:
            ``(avg_loss, accuracy_percent)``, or ``None`` if no batch was evaluated.
        """
        print("going to evaluate")
        # NOTE(review): the server only ever receives parameter gradients, so
        # buffers such as BatchNorm running statistics (if the model has any) keep
        # their initial values; eval()-mode results may be skewed by that — confirm.
        # load_state_dict copies the tensors, so no extra deepcopy is needed.
        self.eva_model.load_state_dict(self.model.state_dict())
        self.eva_model.eval()
        total_loss = 0.0
        correct = 0
        seen = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_loader):
                if num_batches is not None and batch_idx >= num_batches:
                    break
                output = self.eva_model(data)
                # Sum per-sample losses so batches of different sizes are weighted correctly.
                total_loss += F.cross_entropy(output, target, reduction="sum").item()
                pred = output.argmax(dim=1)  # index of the max logit
                correct += pred.eq(target).sum().item()
                seen += target.size(0)
        if seen == 0:
            print("evaluate: test loader yielded no data")
            return None
        test_loss = total_loss / seen
        accuracy = 100.0 * correct / seen
        # log tensorboard
        self.ps_writer.add_scalar('Accuracy/eval', accuracy, self.global_step)
        self.ps_writer.add_scalar('Loss/eval', test_loss, self.global_step)

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
            test_loss,
            correct,
            seen,
            accuracy))
        return test_loss, accuracy



In [None]:
@ray.remote(num_gpus=1)
def worker_task(ps, worker_index, stale_limit, train_loader, lr, momentum, weight_decay, batch_size=50, epochs=1):
    """SSP worker: compute gradients on local batches and push them to the parameter server.

    Before every step the worker waits while it is more than ``stale_limit``
    updates ahead of the slowest worker. It then pulls the latest global weights,
    back-propagates the cross-entropy loss of one batch, and sends the gradients
    (as CPU tensors) to ``ps.apply_gradients``.

    Args:
        ps: ``ParameterServer`` actor handle.
        worker_index: this worker's slot in the server's staleness table.
        stale_limit: maximum allowed lead over the slowest worker.
        train_loader: iterable of ``(data, target)`` batches.
        lr, momentum, weight_decay: kept for interface compatibility. The
            parameter server owns the optimizer; a local update would be
            overwritten by the next weight pull anyway.
        batch_size: unused; batching is controlled by ``train_loader``.
        epochs: number of passes over ``train_loader``.

    Returns:
        The number of gradient updates this worker pushed.
    """
    # Respect the driver's `cuda` flag, but fall back to CPU if this Ray worker sees no GPU.
    use_cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Move the model to the device once; its weights are refreshed in place every step.
    model = models.__dict__[arch](dataset=dataset, depth=depth).to(device)
    model.train()
    local_step = 0
    wk_writer = SummaryWriter("ssp_resnet_runs/wk_" + str(worker_index))
    try:
        for _epoch in range(epochs):
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                # SSP barrier: wait while this worker leads the slowest one by more than stale_limit.
                while local_step - ray.get(ps.get_stalness.remote()) > stale_limit:
                    print(worker_index, " works too fast")
                    sleep(1)
                # Get the current weights from the parameter server (copied onto `device`).
                model.load_state_dict(ray.get(ps.pull_weights.remote()))
                model.zero_grad()
                output = model(data)
                loss = F.cross_entropy(output, target)
                loss.backward()
                # Ship CPU copies of the gradients; the model itself stays on `device`.
                grads = [None if p.grad is None else p.grad.detach().cpu()
                         for p in model.parameters()]
                local_step += 1
                ps.apply_gradients.remote(grads, worker_index)
                wk_writer.add_scalar("Loss/worker_train", loss.item(), local_step)
    finally:
        wk_writer.close()
    return local_step
#         print("the ",worker_index," has finished its ",local_step," update")


In [None]:
num_worker = 2
stalness_limit = 4
# Not read by any visible cell; the ParameterServer actor keeps its own live table.
stalness_table = [0] * num_worker

# Start the parameter-server actor from freshly initialized weights (no checkpoint).
ps = ParameterServer.remote(lr, num_worker, stalness_limit, test_loader,
                            resume_from_ckpt=None)


In [None]:
# Block until the ParameterServer actor has finished __init__ (an actor serves
# method calls only after construction). This replaces a fixed 30 s sleep and
# surfaces a failed actor start here as an error instead of later in the workers.
ray.get(ps.get_stalness.remote());

In [None]:
# Launch one SSP worker per training loader; keep the refs so completion can be awaited.
worker_tasks = [
    worker_task.remote(
        ps,                  # parameter-server actor handle
        wk,                  # worker index / staleness-table slot
        stalness_limit,
        train_loader[wk],    # this worker's data loader
        lr, momentum, weight_decay,
    )
    for wk in range(num_worker)
]

In [None]:
def save_checkpoint(state, is_best, filepath, filename='checkpoint.pth.tar'):
    """Save ``state`` into the checkpoint directory ``filepath``.

    Writes ``<filepath>/<filename>``; if ``is_best``, also copies it to
    ``<filepath>/model_best.pth.tar``.

    Args:
        state: object to serialize with ``torch.save`` (here, a model ``state_dict``).
        is_best: whether to additionally keep a ``model_best.pth.tar`` copy.
        filepath: directory to write checkpoints into (created if missing).
        filename: checkpoint file name inside ``filepath``.

    Returns:
        Path of the written checkpoint file.
    """
    # `filepath` is a directory (the best-copy below joins file names onto it), so
    # the checkpoint must go to a file inside it; torch.save on the directory path fails.
    os.makedirs(filepath, exist_ok=True)
    ckpt_path = os.path.join(filepath, filename)
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(filepath, 'model_best.pth.tar'))
    return ckpt_path

import datetime

ckpt_interval_s = 600  # seconds between periodic checkpoints

# Periodically checkpoint the parameter-server weights. Each wait ends early once
# every worker task has finished; one final checkpoint is then saved and the loop exits.
pending_workers = list(worker_tasks)
while True:
    wei = ray.get(ps.pull_weights.remote())
    save_checkpoint(wei, False, ckpt_ssp_resnet)
    print("saved ckpt at: ", datetime.datetime.now())
    if not pending_workers:
        break  # all workers are done and the final weights were just saved
    _, pending_workers = ray.wait(pending_workers,
                                  num_returns=len(pending_workers),
                                  timeout=ckpt_interval_s)