### Asynchronous SGD updates using Distributed Cache Redis

1. Create spot instances
2. Mount EFS (Not required for our project)
3. Start redis server in all the instances and create a cluster
4. Pull the code from github repo (https://github.com/SrujithPoondla/vanilla-hogwild.git)
5. If need to divide the dataset between the nodes run the specific cell
6. Run the scripts to start training
7. After training ends close the instances

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from aws_setup import *
from argparse import ArgumentParser
import boto3
import os 
sys.argv = ['foo']

#### Define parameters

In [3]:
parser = ArgumentParser(description='Asynchronous SGD updates using Redis')
parser.add_argument('--n-nodes', type=int, default=1, metavar='N',
                    help='how many aws instances to start')
parser.add_argument('--is-redis', action='store_true', default = True,
                    help="Choose whether the model to be trained using redis or not."
                    "If not using Redis, model will be trained on single process")


_StoreTrueAction(option_strings=['--is-redis'], dest='is_redis', nargs=0, const=True, default=True, type=None, choices=None, help='Choose whether the model to be trained using redis or not.If not using Redis, model will be trained on single process', metavar=None)

In [4]:
args = parser.parse_args()
vpc_name='vpc-1b056b60'
args.n_nodes = 4
n_instances = args.n_nodes
instance_type = 'm4.10xlarge'
ami_sr = 'ami-0b991974'
a_zone = 'us-east-1a'

if args.n_nodes is 2 and args.is_redis:
    print('Cant create a cluster with 2 redis nodes. Chose either a 3 node cluster or single instance')

#### Get Existing VPC by tag name

In [5]:
vpc = get_vpc(vpc_name); vpc

ec2.Vpc(id='vpc-1b056b60')

#### Create EFS (if you haven't already)

In [6]:
# efs_tag = f'{vpc_name}-efs'

In [7]:
# efs = create_efs(efs_tag, vpc, performance_mode='maxIO')

#### Request Spot instance

In [8]:
instance_name = f'{vpc_name}-instance'
# Recommend a high compute instance as we need to do multi-threaded resizing later on

In [9]:
spot_price = get_spot_prices()[instance_type]
bid_price = "%.4f" % (float(spot_price)*3)
print(f'Spot price: {spot_price}, Bid price: {bid_price}')

Spot price: 0.601800, Bid price: 1.8054


In [10]:
launch_specs = LaunchSpecs(vpc, instance_type=instance_type, ami= ami_sr,availability_zone=a_zone).build()

In [11]:
# launch_specs['BlockDeviceMappings'][0]['Ebs']['VolumeSize'] = 1000

In [12]:
launch_specs

{'BlockDeviceMappings': [{'DeviceName': '/dev/sda1',
   'Ebs': {'DeleteOnTermination': True,
    'VolumeSize': 20,
    'VolumeType': 'gp2'}}],
 'ImageId': 'ami-0b991974',
 'InstanceType': 'm4.10xlarge',
 'KeyName': 'aws-key-spot-instance',
 'NetworkInterfaces': [{'AssociatePublicIpAddress': True,
   'DeviceIndex': 0,
   'Groups': ['sg-2624da6f'],
   'SubnetId': 'subnet-10d9d04d'}]}

In [13]:
ec2 = boto3.resource('ec2')
filters = [
    {
        'Name': 'instance-state-name',
        'Values': ['running']
    }
]
# filter the instances based on filters() above
ec2_instances = list(ec2.instances.filter(Filters=filters))
instances = []
for instance in ec2_instances:
    instances.append(instance)
print(instances)
instances_to_request = n_instances-len(instances)
print(instances_to_request)
instances = create_multiple_spot_instance(instance_name, launch_specs,instance_count=args.n_nodes, spot_price=bid_price)
print(instances)

[]
4
Created keypair
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-0ac17f00f29b3a07a
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-0cd12e7fccad34e74
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-0e9fe1c86bf4c30ea
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-014728ebfc5b5aca1
Rebooting...
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@34.228.215.131
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@18.232.132.227
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@35.168.12.100
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@54.172.189.24
[ec2.Instance(id='i-014728ebfc5b5aca1'), ec2.Instance(id='i-0e9fe1c86bf4c30ea'), ec2.Instance(id='i-0ac17f00f29b3a07a'), ec2.Instance(id='i-0cd12e7fccad34e74')]


In [14]:
# instance = get_instance(instance_name); instance
ssh_commands = []
public_ip_list = []
private_ip_list = []
for instance in instances:
    # for each instance, append to lists
    private_ip_list.append(instance.private_ip_address)
    public_ip_list.append(instance.public_ip_address)
    ssh_commands.append(get_ssh_command(instance))
print(ssh_commands, public_ip_list, private_ip_list)

['ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@34.228.215.131', 'ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@18.232.132.227', 'ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@35.168.12.100', 'ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@54.172.189.24'] ['34.228.215.131', '18.232.132.227', '35.168.12.100', '54.172.189.24'] ['10.0.0.7', '10.0.0.8', '10.0.0.9', '10.0.0.6']


#### Terminating instances

In [15]:
# for instance in instances:
#     print(instance.terminate())

#### Activate Conda Environment in all the instances and check whether we need to create a cluster or not. Then start redis using conf files.

In [None]:
# for client,ip in zip(clients,private_ip_list):
#     if 'redis' not in run_command(client,'tmux ls'):
#         redis = TmuxSession(client, 'redis-sess')
#     run_command(client, 'cd ~/vanilla-hogwild && git stash && git pull && git checkout stable')
#     if (args.n_nodes >= 3):
#         redis.run_command('rm dump.rdb && rm appendonly.aof && rm nodes-6379.conf')
#         redis.run_command('echo bind '+ip+'>> ~/redis-conf/redis_cluster.conf')
#         print(redis.run_command('nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis_cluster.conf &'))
#         redis.run_command('-d')
#     else:
#         redis.run_command('echo bind '+ip+'>> ~/redis-conf/redis_cluster.conf')
#         redis.run_command('rm dump.rdb && rm appendonly.aof && rm nodes-6379.conf')
#         print(redis.run_command('~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis.conf'))
#         redis.run_command('-d')

# for client,ip in zip(clients,private_ip_list):
#     run_command(client, 'cd ~/vanilla-hogwild && git stash && git pull && git checkout stable')
#     if (args.n_nodes >= 3):
#         run_command(client,'cd /home/ubuntu/ && rm dump.rdb && rm appendonly.aof && rm nodes-6379.conf')
#         run_command(client,'echo bind '+ip+'>> ~/redis-conf/redis_cluster.conf')
#         print(run_command(client, 'nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis_cluster.conf &'))
#     else:
#         run_command(client,'cd /home/ubuntu/ && rm dump.rdb && rm appendonly.aof && rm nodes-6379.conf')
#         print(run_command(client, 'nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis.conf &'))

commands = ['cd ~/vanilla-hogwild', 'git stash', 'git pull' ,'git checkout stable',
            'rm -rf ~/redis-data', 'cd ~','ls' ,'mkdir redis-data']

for instance,pr_ip in zip(instances,private_ip_list):
    client = connect_to_instance(instance)
    if (args.n_nodes >= 3):
        indep_commands = ['echo bind '+pr_ip+'>> ~/redis-conf/redis_cluster.conf',
                          'echo requirepass lsmldeeplearning >> ~/redis-conf/redis_cluster.conf',
            'nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis_cluster.conf >\
                          /home/ubuntu/redis-data/redis-log.out &']
    else:
        indep_commands = ['echo requirepass lsmldeeplearning >> ~/redis-conf/redis.conf','nohup ~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis.conf > \
                          /home/ubuntu/redis-data/redis-log.out &']
    commands = commands + indep_commands
    inp, out,err = client.exec_command("\n".join(commands))
#     sleep(5)
    inp.write('\n')
    inp.flush()
    output = out.read()
    # Close down
    out.close()
    inp.close()
    client.close()
    print(output)
    
        


Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Exception: timed out Retrying...
Exception: timed out Retrying...
Connected!
b"No local changes to save\nAlready up-to-date.\nYour branch is up-to-date with 'origin/stable'.\ncifar10_data\nlargescale.yml\nminiconda3\nmnist_data\nredis-4.0.9\nredis-conf\nredis-ml\nvanilla-hogwild\n"
Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Connected!
b"No local changes to save\nAlready up-to-date.\nYour branch is up-to-date with 'origin/stable'.\ncifar10_data\nlargescale.yml\nminiconda3\nmnist_data\nredis-4.0.9\nredis-conf\nredis-ml\nvanilla-hogwild\n"
Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Exception: [Errno None] Unable to connect to port 22 on 35.168.12.100 Retrying...


#### Create redis cluster

In [None]:
# if args.n_nodes >= 3:
#     ip_str = ''
#     for ip in private_ip_list:
#         ip_str = ip_str+ ip +":6379 "
#     print(ip_str)
#     redis = TmuxSession(clients[0],'redis-serv-sess')
#     redis.run_command('cd /home/ubuntu/redis-4.0.9/src')
#     redis.run_command('./redis-trib.rb create '+ ip_str)
#     redis.run_command('yes')
#     redis.run_command('-d')

if args.n_nodes >= 3:
    ip_str = ''
    for ip in private_ip_list:
        ip_str = ip_str+ ip +":6379 "
    print(ip_str)
    instance = instances[0]
    print(instance)
    client = connect_to_instance(instance)
    print(client)
    commands = ['cd /home/ubuntu/redis-4.0.9/src', './redis-trib.rb create --password lsmldeeplearning '+ip_str]
    ssh_input = ['yes']
    inp, out,err = client.exec_command("\n".join(commands))
    inp.write('yes\n')
    inp.flush()
    output = out.read()
    # Close down
    out.close()
    inp.close()
    client.close()
    print(output)

#### Creating Arguments String

In [None]:
#Model parameters
batch_size = 128
epochs = 1
lr = 0.01
momentum = 0.5
log_interval = 50
num_processes = 1
nnet_arch = 'LeNet'
dataset = 'MNIST'
args.is_redis = True


In [None]:
hosts = ''
for ip in private_ip_list:
    hosts = hosts+ip+','
hosts = hosts.strip(',')

In [None]:
#dataset can be 'MNIST' or 'cifar10'
#architecture can be 'LeNet' or 'ResNet'(still working on this)
#num_processes should be either 1 or 2
#batch size 128,256,512,1024,2048

arg_str = '--is-redis='+str(args.is_redis)+' --dataset='+dataset+' --nnet-arch='+nnet_arch+' --num-processes='+ str(num_processes) + ' --batch-size='+str(batch_size) +' --lr='+str(lr) + ' --hosts='+hosts +' --epochs='+str(epochs)

print(arg_str)

#### Chose the log file name

In [None]:
if args.is_redis:
    log_file = dataset+'-'+nnet_arch+'-'+str(batch_size)+'-'+str(num_processes)+'-'+ 'redis'
else:
    log_file = dataset+'-'+nnet_arch+'-'+str(batch_size)+'-'+str(num_processes)
print('Log file name: '+log_file)

In [None]:
count = 0
running = True
# for sess,ip in zip(tsess,private_ip_list):
#     sess.run_command('source activate largescale')
#     sess.run_command('~/miniconda3/envs/largescale/bin/redis-cli -h '+str(ip) +' flushall')
#     sess.run_command('python3 -u ~/vanilla-hogwild/main.py '+arg_str+ ' 2>&1 | tee '+log_file)
#     sess.run_command('-d')
import select
path = "/Users/srujithpoondla/lsml_results/"+str(args.n_nodes)+"/"
if not os.path.exists(path):
    os.mkdir(path)
commands = ['source ~/miniconda3/bin/activate largescale']
for i,(instance,ip) in enumerate(zip(instances,private_ip_list)):
    redis_cli_flush = ['~/miniconda3/envs/largescale/bin/redis-cli -h -a lsmldeeplearning'+str(ip) +' flushall']
#     train = ['nohup python3 -u ~/vanilla-hogwild/main.py '+arg_str+ ' 2>&1 | tee '+log_file+' &']
    train = ['python3 -u ~/vanilla-hogwild/main.py '+arg_str+ '>>'+log_file+' &']

    client = connect_to_instance(instance)
#     transport = client.get_transport()
#     channel = transport.open_session()
    print(commands+redis_cli_flush+train)
    client.exec_command("\n".join(commands+redis_cli_flush+train))
#     while True:
#         if channel.exit_status_ready():
#             break
# #         rl, wl, xl = select.select([channel], [], [], 0.0)
# #         if len(rl) > 0:
# #             print(channel.recv(1024))
    
#     sftp = client.open_sftp()
#     print(path+log_file+str(i))
#     sftp.get(log_file, path+log_file+str(i))
#     print('copied README back here')
    

#### Terminate all instances

In [None]:
# # sleep(100)
# while(running):
#     client = connect_to_instance(instances[0])
#     inp,out,err=client.exec_command('pidof python3')
#     out = out.read()[0]
#     inp.close()
#     out.close()
#     client.close()
#     if len(out.strip(' ').split(' ')[0])<1:
#         for instance in instances:
#             try:
#                 client = connect_to_instance(instance)
#                 out=client.exec_command('pidof python3')
#                 print(out.read().strip(' ').split(' '))
#                 if len(out.read().strip(' ').split(' '))>1:
#                     continue
#                 else:
#                     count = count+1
#             except Exception as e:
#                 print(e)
#             if count==len(instances):
#                 running = False
#                 break
#     else:
#         sleep(30)
        
# path = "/Users/srujithpoondla/lsml_results/"+str(args.n_nodes)+"/"
# if not os.path.exists(path):
#     os.mkdir(path)
# for i,instance in enumerate(instances):
#     client = connect_to_instance(instance)
#     sftp = client.open_sftp()
#     print(path+log_file+str(i))
#     sftp.get(log_file, path+log_file+str(i))
#     print('copied README back here')
#     client.close()

In [None]:
# for instance in instances:
#     print(instance.terminate())
# ec2.KeyPair('aws-key-spot-instance').delete()