### Asynchronous SGD updates using Distributed Cache Redis

1. Create spot instances
2. Mount EFS (Not required for our project)
3. Start redis server in all the instances and create a cluster
4. Pull the code from github repo (https://github.com/SrujithPoondla/vanilla-hogwild.git)
5. If need to divide the dataset between the nodes run the specific cell
6. Run the scripts to start training
7. After training ends close the instances

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from aws_setup import *
from argparse import ArgumentParser
import boto3
import os 
sys.argv = ['foo']

#### Define parameters

In [3]:
parser = ArgumentParser(description='Asynchronous SGD updates using Redis')
parser.add_argument('--n-nodes', type=int, default=1, metavar='N',
                    help='how many aws instances to start')
parser.add_argument('--is-redis', action='store_true', default = True,
                    help="Choose whether the model to be trained using redis or not."
                    "If not using Redis, model will be trained on single process")


_StoreTrueAction(option_strings=['--is-redis'], dest='is_redis', nargs=0, const=True, default=True, type=None, choices=None, help='Choose whether the model to be trained using redis or not.If not using Redis, model will be trained on single process', metavar=None)

In [4]:
args = parser.parse_args()
vpc_name='vpc-1b056b60'
args.n_nodes = 4
n_instances = args.n_nodes
instance_type = 'm4.10xlarge'
ami_sr = 'ami-d70f8ea8'
a_zone = 'us-east-1a'

if args.n_nodes is 2 and args.is_redis:
    print('Cant create a cluster with 2 redis nodes. Chose either a 3 node cluster or single instance')

#### Get Existing VPC by tag name

In [5]:
vpc = get_vpc(vpc_name); vpc

ec2.Vpc(id='vpc-1b056b60')

#### Create EFS (if you haven't already)

In [6]:
# efs_tag = f'{vpc_name}-efs'

In [7]:
# efs = create_efs(efs_tag, vpc, performance_mode='maxIO')

#### Request Spot instance

In [8]:
instance_name = f'{vpc_name}-instance'
# Recommend a high compute instance as we need to do multi-threaded resizing later on

In [9]:
spot_price = get_spot_prices()[instance_type]
bid_price = "%.4f" % (float(spot_price)*3)
print(f'Spot price: {spot_price}, Bid price: {bid_price}')

Spot price: 0.601800, Bid price: 1.8054


In [10]:
launch_specs = LaunchSpecs(vpc, instance_type=instance_type, ami= ami_sr,availability_zone=a_zone).build()

In [11]:
# launch_specs['BlockDeviceMappings'][0]['Ebs']['VolumeSize'] = 1000

In [12]:
launch_specs

{'BlockDeviceMappings': [{'DeviceName': '/dev/sda1',
   'Ebs': {'DeleteOnTermination': True,
    'VolumeSize': 20,
    'VolumeType': 'gp2'}}],
 'ImageId': 'ami-d70f8ea8',
 'InstanceType': 'm4.10xlarge',
 'KeyName': 'aws-key-spot-instance',
 'NetworkInterfaces': [{'AssociatePublicIpAddress': True,
   'DeviceIndex': 0,
   'Groups': ['sg-2624da6f'],
   'SubnetId': 'subnet-10d9d04d'}]}

In [13]:
ec2 = boto3.resource('ec2')
filters = [
    {
        'Name': 'instance-state-name',
        'Values': ['running']
    }
]
# filter the instances based on filters() above
ec2_instances = list(ec2.instances.filter(Filters=filters))
instances = []
for instance in ec2_instances:
    instances.append(instance)
instances_to_request = n_instances-len(instances)
instances = create_multiple_spot_instance(instance_name, launch_specs,instance_count=args.n_nodes, spot_price=bid_price)
print(instances)

Created keypair
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-021239995f82d19bf
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-02c139173d5bbb5c5
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-02eb3d0346eab2fd6
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-0fc085b5b608a075e
Rebooting...
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@54.89.57.254
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@34.207.140.5
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@18.232.158.114
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@18.232.98.229
[ec2.Instance(id='i-021239995f82d19bf'), ec2.Instance(id='i-02eb3d0346eab2fd6'), ec2.Instance(id='i-0fc085b5b608a075e'), ec2.Instance(id='i-02c139173d5bbb5c5')]


In [14]:
# instance = get_instance(instance_name); instance
ssh_commands = []
public_ip_list = []
private_ip_list = []
for instance in instances:
    # for each instance, append to lists
    private_ip_list.append(instance.private_ip_address)
    public_ip_list.append(instance.public_ip_address)
    ssh_commands.append(get_ssh_command(instance))
print(ssh_commands, public_ip_list, private_ip_list)

['ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@54.89.57.254', 'ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@34.207.140.5', 'ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@18.232.158.114', 'ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@18.232.98.229'] ['54.89.57.254', '34.207.140.5', '18.232.158.114', '18.232.98.229'] ['10.0.0.12', '10.0.0.6', '10.0.0.9', '10.0.0.4']


#### Terminating instances

In [15]:
# for instance in instances:
#     print(instance.terminate())

#### Activate Conda Environment in all the instances and check whether we need to create a cluster or not. Then start redis using conf files.

In [16]:
# for client,ip in zip(clients,private_ip_list):
#     if 'redis' not in run_command(client,'tmux ls'):
#         redis = TmuxSession(client, 'redis-sess')
#     run_command(client, 'cd ~/vanilla-hogwild && git stash && git pull && git checkout stable')
#     if (args.n_nodes >= 3):
#         redis.run_command('rm dump.rdb && rm appendonly.aof && rm nodes-6379.conf')
#         redis.run_command('echo bind '+ip+'>> ~/redis-conf/redis_cluster.conf')
#         print(redis.run_command('nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis_cluster.conf &'))
#         redis.run_command('-d')
#     else:
#         redis.run_command('echo bind '+ip+'>> ~/redis-conf/redis_cluster.conf')
#         redis.run_command('rm dump.rdb && rm appendonly.aof && rm nodes-6379.conf')
#         print(redis.run_command('~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis.conf'))
#         redis.run_command('-d')

# for client,ip in zip(clients,private_ip_list):
#     run_command(client, 'cd ~/vanilla-hogwild && git stash && git pull && git checkout stable')
#     if (args.n_nodes >= 3):
#         run_command(client,'cd /home/ubuntu/ && rm dump.rdb && rm appendonly.aof && rm nodes-6379.conf')
#         run_command(client,'echo bind '+ip+'>> ~/redis-conf/redis_cluster.conf')
#         print(run_command(client, 'nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis_cluster.conf &'))
#     else:
#         run_command(client,'cd /home/ubuntu/ && rm dump.rdb && rm appendonly.aof && rm nodes-6379.conf')
#         print(run_command(client, 'nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis.conf &'))

commands = ['cd ~/vanilla-hogwild', 'git stash', 'git pull' ,'git checkout stable',
            'rm -rf ~/redis-data', 'cd ~','ls' ,'mkdir redis-data']

for instance,pr_ip in zip(instances,private_ip_list):
    client = connect_to_instance(instance)
    if (args.n_nodes >= 3):
        indep_commands = ['echo bind '+pr_ip+'>> ~/redis-conf/redis_cluster.conf',
                          'requirepass lsmldeeplearning',
            'nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis_cluster.conf >\
                          /home/ubuntu/redis-data/redis-log.out &']
    else:
        indep_commands = ['requirepass lsmldeeplearning','nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis.conf > \
                          /home/ubuntu/redis-data/redis-log.out &']
    commands = commands + indep_commands
    inp, out,err = client.exec_command("\n".join(commands))
#     sleep(5)
    inp.write('\n')
    inp.flush()
    output = out.read()
    # Close down
    out.close()
    inp.close()
    client.close()
    print(output)
    
        


Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Exception: timed out Retrying...
Exception: timed out Retrying...
Connected!
b"No local changes to save\nUpdating 2c6955b..0f8465d\nFast-forward\n aws_setup.py         |   3 +-\n lsml_testbench.ipynb | 503 +++++++++++++++++++++++----------------------------\n train.py             |   4 +-\n 3 files changed, 230 insertions(+), 280 deletions(-)\nYour branch is up-to-date with 'origin/stable'.\ncifar10_data\nlargescale.yml\nminiconda3\nmnist_data\nredis-4.0.9\nredis-conf\nredis-ml\nvanilla-hogwild\n"
Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Connected!
b"No local changes to save\nUpdating 2c6955b..0f8465d\nFast-forward\n aws_setup.py         |   3 +-\n lsml_testbench.ipynb | 503 +++++++++++++++++++++++----------------------------\n train.py             |   4 +-\n 3 files changed, 230 insertions(+), 280 deletions(-)\nYour branch is up-to-date with 'origin/stable'.

#### Create redis cluster

In [17]:
# if args.n_nodes >= 3:
#     ip_str = ''
#     for ip in private_ip_list:
#         ip_str = ip_str+ ip +":6379 "
#     print(ip_str)
#     redis = TmuxSession(clients[0],'redis-serv-sess')
#     redis.run_command('cd /home/ubuntu/redis-4.0.9/src')
#     redis.run_command('./redis-trib.rb create '+ ip_str)
#     redis.run_command('yes')
#     redis.run_command('-d')

if args.n_nodes >= 3:
    ip_str = ''
    for ip in private_ip_list:
        ip_str = ip_str+ ip +":6379 "
    print(ip_str)
    instance = instances[0]
    print(instance)
    client = connect_to_instance(instance)
    print(client)
    commands = ['cd /home/ubuntu/redis-4.0.9/src', './redis-trib.rb create '+ip_str]
    ssh_input = ['yes']
    inp, out,err = client.exec_command("\n".join(commands))
    inp.write('yes\n')
    inp.flush()
    output = out.read()
    # Close down
    out.close()
    inp.close()
    client.close()
    print(output)

10.0.0.12:6379 10.0.0.6:6379 10.0.0.9:6379 10.0.0.4:6379 
ec2.Instance(id='i-021239995f82d19bf')
Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Connected!
<paramiko.client.SSHClient object at 0x10e538828>
b">>> Creating cluster\n>>> Performing hash slots allocation on 4 nodes...\nUsing 4 masters:\n10.0.0.12:6379\n10.0.0.6:6379\n10.0.0.9:6379\n10.0.0.4:6379\nM: 462df4695b4d2d7084a7abe80334376a517bb6af 10.0.0.12:6379\n   slots:0-4095 (4096 slots) master\nM: de01f37ad5259f1491633940ff7469937ed9078d 10.0.0.6:6379\n   slots:4096-8191 (4096 slots) master\nM: aeaf8666b9daa44243558b513ccb71ae08bd5d4c 10.0.0.9:6379\n   slots:8192-12287 (4096 slots) master\nM: d9ce7d22cd1f2f1076fc2703eb5e3f7bdef76410 10.0.0.4:6379\n   slots:12288-16383 (4096 slots) master\nCan I set the above configuration? (type 'yes' to accept): >>> Nodes configuration updated\n>>> Assign a different config epoch to each node\n>>> Sending CLUSTER MEET messages to join the cluster\nWaiting 

#### Creating Arguments String

In [18]:
#Model parameters
batch_size = 128
epochs = 1
lr = 0.01
momentum = 0.5
log_interval = 50
num_processes = 1
nnet_arch = 'LeNet'
dataset = 'MNIST'
args.is_redis = True


In [19]:
hosts = ''
for ip in private_ip_list:
    hosts = hosts+ip+','
hosts = hosts.strip(',')

In [20]:
#dataset can be 'MNIST' or 'cifar10'
#architecture can be 'LeNet' or 'ResNet'(still working on this)
#num_processes should be either 1 or 2
#batch size 128,256,512,1024,2048

arg_str = '--is-redis='+str(args.is_redis)+' --dataset='+dataset+' --nnet-arch='+nnet_arch+' --num-processes='+ str(num_processes) + ' --batch-size='+str(batch_size) +' --lr='+str(lr) + ' --hosts='+hosts +' --epochs='+str(epochs)

print(arg_str)

--is-redis=True --dataset=MNIST --nnet-arch=LeNet --num-processes=1 --batch-size=128 --lr=0.01 --hosts=10.0.0.12,10.0.0.6,10.0.0.9,10.0.0.4 --epochs=1


#### Chose the log file name

In [21]:
if args.is_redis:
    log_file = dataset+'-'+nnet_arch+'-'+str(batch_size)+'-'+str(num_processes)+'-'+ 'redis'
else:
    log_file = dataset+'-'+nnet_arch+'-'+str(batch_size)+'-'+str(num_processes)
print('Log file name: '+log_file)

Log file name: MNIST-LeNet-128-1-redis


In [22]:
count = 0
running = True
# for sess,ip in zip(tsess,private_ip_list):
#     sess.run_command('source activate largescale')
#     sess.run_command('~/miniconda3/envs/largescale/bin/redis-cli -h '+str(ip) +' flushall')
#     sess.run_command('python3 -u ~/vanilla-hogwild/main.py '+arg_str+ ' 2>&1 | tee '+log_file)
#     sess.run_command('-d')
import select
path = "/Users/srujithpoondla/lsml_results/"+str(args.n_nodes)+"/"
if not os.path.exists(path):
    os.mkdir(path)
commands = ['source ~/miniconda3/bin/activate largescale']
for i,(instance,ip) in enumerate(zip(instances,private_ip_list)):
    redis_cli_flush = ['~/miniconda3/envs/largescale/bin/redis-cli -h '+str(ip) +' flushall']
#     train = ['nohup python3 -u ~/vanilla-hogwild/main.py '+arg_str+ ' 2>&1 | tee '+log_file+' &']
    train = ['python3 -u ~/vanilla-hogwild/main.py '+arg_str+ '>>'+log_file+' &']

    client = connect_to_instance(instance)
#     transport = client.get_transport()
#     channel = transport.open_session()
    print(commands+redis_cli_flush+train)
    client.exec_command("\n".join(commands+redis_cli_flush+train))
#     while True:
#         if channel.exit_status_ready():
#             break
# #         rl, wl, xl = select.select([channel], [], [], 0.0)
# #         if len(rl) > 0:
# #             print(channel.recv(1024))
    
#     sftp = client.open_sftp()
#     print(path+log_file+str(i))
#     sftp.get(log_file, path+log_file+str(i))
#     print('copied README back here')
    

Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Connected!
['source ~/miniconda3/bin/activate largescale', '~/miniconda3/envs/largescale/bin/redis-cli -h 10.0.0.12 flushall', 'python3 -u ~/vanilla-hogwild/main.py --is-redis=True --dataset=MNIST --nnet-arch=LeNet --num-processes=1 --batch-size=128 --lr=0.01 --hosts=10.0.0.12,10.0.0.6,10.0.0.9,10.0.0.4 --epochs=1>>MNIST-LeNet-128-1-redis &']
Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Connected!
['source ~/miniconda3/bin/activate largescale', '~/miniconda3/envs/largescale/bin/redis-cli -h 10.0.0.6 flushall', 'python3 -u ~/vanilla-hogwild/main.py --is-redis=True --dataset=MNIST --nnet-arch=LeNet --num-processes=1 --batch-size=128 --lr=0.01 --hosts=10.0.0.12,10.0.0.6,10.0.0.9,10.0.0.4 --epochs=1>>MNIST-LeNet-128-1-redis &']
Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Connected!
['source ~/miniconda3/bin/activate largescale

#### Terminate all instances

{'TerminatingInstances': [{'CurrentState': {'Code': 32, 'Name': 'shutting-down'}, 'InstanceId': 'i-021239995f82d19bf', 'PreviousState': {'Code': 16, 'Name': 'running'}}], 'ResponseMetadata': {'RequestId': '6528483e-4ee6-4165-b0d9-684331ac9127', 'HTTPStatusCode': 200, 'HTTPHeaders': {'content-type': 'text/xml;charset=UTF-8', 'transfer-encoding': 'chunked', 'vary': 'Accept-Encoding', 'date': 'Mon, 07 May 2018 02:50:20 GMT', 'server': 'AmazonEC2'}, 'RetryAttempts': 0}}
{'TerminatingInstances': [{'CurrentState': {'Code': 32, 'Name': 'shutting-down'}, 'InstanceId': 'i-02eb3d0346eab2fd6', 'PreviousState': {'Code': 16, 'Name': 'running'}}], 'ResponseMetadata': {'RequestId': '44ef67bd-990d-4727-ad10-fa5a4209e0c6', 'HTTPStatusCode': 200, 'HTTPHeaders': {'content-type': 'text/xml;charset=UTF-8', 'transfer-encoding': 'chunked', 'vary': 'Accept-Encoding', 'date': 'Mon, 07 May 2018 02:50:20 GMT', 'server': 'AmazonEC2'}, 'RetryAttempts': 0}}
{'TerminatingInstances': [{'CurrentState': {'Code': 32, 'N

{'ResponseMetadata': {'HTTPHeaders': {'content-length': '227',
   'content-type': 'text/xml;charset=UTF-8',
   'date': 'Mon, 07 May 2018 02:50:21 GMT',
   'server': 'AmazonEC2'},
  'HTTPStatusCode': 200,
  'RequestId': 'f03446fc-b327-4fa7-bb17-c082d684f126',
  'RetryAttempts': 0}}

In [None]:
sleep(100)
while(running):
    client = connect_to_instance(instances[0])
    inp,out,err=client.exec_command('pidof python3')
    out = out.read()[0]
    inp.close()
    out.close()
    client.close()
    if len(out.strip(' ').split(' ')[0])<1:
        for instance in instances:
            try:
                client = connect_to_instance(instance)
                out=client.exec_command('pidof python3')
                if len(out.read().strip(' ').split(' ')[0])>1:
                    continue
                else:
                    count = count+1
            except Exception as e:
                print(e)
            if count==len(instances):
                running = False
                break
    else:
        sleep(30)
        
path = "/Users/srujithpoondla/lsml_results/"+str(args.n_nodes)+"/"
if not os.path.exists(path):
    os.mkdir(path)
for i,instance in enumerate(instances):
    client = connect_to_instance(instance)
    sftp = client.open_sftp()
    print(path+log_file+str(i))
    sftp.get(log_file, path+log_file+str(i))
    print('copied README back here')
    client.close()