### Asynchronous SGD updates using Distributed Cache Redis

1. Create spot instances
2. Mount EFS (Not required for our project)
3. Start redis server in all the instances and create a cluster
4. Pull the code from github repo (https://github.com/SrujithPoondla/vanilla-hogwild.git)
5. If the dataset needs to be divided between the nodes, run the corresponding cell
6. Run the scripts to start training
7. After training ends, terminate the instances

In [31]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [32]:
# The star import stays first so that the explicit imports below win on any name clash.
from aws_setup import *

import sys
import time
from argparse import ArgumentParser

import boto3
# Replace the process argument list with a bare placeholder program name so that the
# argument-less parser.parse_args() call later in the notebook finds no options and
# falls back to the declared defaults.
sys.argv = ['foo']

#### Define parameters

In [33]:
# Options controlling how many instances are started and whether Redis is used.
parser = ArgumentParser(description='Asynchronous SGD updates using Redis')
parser.add_argument('--n-nodes', type=int, default=1, metavar='N',
                    help='how many aws instances to start')
# `--is-redis` is a store_true flag that already defaults to True, so by itself it can
# never turn Redis off. `--without-redis` writes False to the same destination so the
# non-Redis path is actually reachable from the command line.
parser.add_argument('--without-redis', dest='is_redis', action='store_false', default=True,
                    help='Train without Redis (single process)')
parser.add_argument('--is-redis', action='store_true', default=True,
                    help="Choose whether the model to be trained using redis or not. "
                    "If not using Redis, model will be trained on single process")


_StoreTrueAction(option_strings=['--is-redis'], dest='is_redis', nargs=0, const=True, default=True, type=None, choices=None, help='Choose whether the model to be trained using redis or not.If not using Redis, model will be trained on single process', metavar=None)

In [34]:
# Parse the options (sys.argv was overridden earlier, so the defaults apply) and pin the
# AWS settings used by the cells below.
args = parser.parse_args()
vpc_name='vpc-1b056b60'
n_instances = args.n_nodes
instance_type = 'm5.2xlarge'
ami_sr = 'ami-2663e759'
a_zone = 'us-east-1a'

# Compare by value: `is 2` tests object identity, which only happens to work through
# CPython's small-integer caching and raises a SyntaxWarning on Python 3.8+.
if args.n_nodes == 2 and args.is_redis:
    print('Cant create a cluster with 2 redis nodes. Chose either a 3 node cluster or single instance')

#### Get Existing VPC by tag name

In [35]:
# Look up the VPC object for `vpc_name`, then echo it as the cell's output.
vpc = get_vpc(vpc_name)
vpc

ec2.Vpc(id='vpc-1b056b60')

#### Create EFS (if you haven't already)

In [36]:
# efs_tag = f'{vpc_name}-efs'

In [37]:
# efs = create_efs(efs_tag, vpc, performance_mode='maxIO')

#### Request Spot instance

In [38]:
# Name handed to create_spot_instance() when the spot instances are requested below.
instance_name = f'{vpc_name}-instance'
# Recommend a high compute instance as we need to do multi-threaded resizing later on
# NOTE(review): no resizing step is visible in this notebook; the remark above may be a
# leftover from another workflow -- confirm or remove.

In [39]:
# Current spot price for the chosen instance type; the bid is three times that price,
# formatted to four decimal places.
spot_price = get_spot_prices()[instance_type]
bid_price = "%.4f" % (3 * float(spot_price))
print(f'Spot price: {spot_price}, Bid price: {bid_price}')

Spot price: 0.137100, Bid price: 0.4113


In [40]:
# Assemble the launch specification from the VPC, instance type, AMI and availability zone.
spec_builder = LaunchSpecs(
    vpc,
    instance_type=instance_type,
    ami=ami_sr,
    availability_zone=a_zone,
)
launch_specs = spec_builder.build()

In [41]:
# launch_specs['BlockDeviceMappings'][0]['Ebs']['VolumeSize'] = 1000

In [42]:
# Show the generated launch specification so it can be checked before requesting instances.
launch_specs

{'BlockDeviceMappings': [{'DeviceName': '/dev/sda1',
   'Ebs': {'DeleteOnTermination': True,
    'VolumeSize': 20,
    'VolumeType': 'gp2'}}],
 'ImageId': 'ami-2663e759',
 'InstanceType': 'm5.2xlarge',
 'KeyName': 'aws-key-spot-instance',
 'NetworkInterfaces': [{'AssociatePublicIpAddress': True,
   'DeviceIndex': 0,
   'Groups': ['sg-2624da6f'],
   'SubnetId': 'subnet-10d9d04d'}]}

In [43]:
ec2 = boto3.resource('ec2')
# Restrict the lookup to instances whose state is "running".
filters = [{'Name': 'instance-state-name', 'Values': ['running']}]
ec2_instances = list(ec2.instances.filter(Filters=filters))

# NOTE(review): the filter matches every running instance these credentials can see, not
# only ones started by this notebook, and the final cell terminates everything in
# `instances` -- confirm this is intended.
if ec2_instances:
    # Something is already running: reuse it instead of requesting new spot instances.
    instances = ec2_instances
else:
    # Nothing running yet: request `n_instances` spot instances at the bid price.
    instances = [
        create_spot_instance(instance_name, launch_specs, spot_price=bid_price)
        for _ in range(n_instances)
    ]
print(instances)

Keypair exists
Waiting on spot fullfillment...
Fulfillment completed. InstanceId: i-0f47026c195c01e98
i-0f47026c195c01e98
Rebooting...
Completed. SSH:  ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@35.171.189.81
[ec2.Instance(id='i-0f47026c195c01e98')]


In [44]:
# instance = get_instance(instance_name); instance
# Gather each instance's private IP, public IP and SSH command, all in `instances` order.
private_ip_list = [node.private_ip_address for node in instances]
public_ip_list = [node.public_ip_address for node in instances]
ssh_commands = [get_ssh_command(node) for node in instances]
print(ssh_commands, public_ip_list, private_ip_list)

['ssh -i ~/.ssh/aws-key-spot-instance.pem ubuntu@35.171.189.81'] ['35.171.189.81'] ['10.0.0.8']


#### Terminating instances

In [15]:
# for instance in ec2_instances:
#     print(instance.terminate())

### SSH

In [45]:
# Open one SSH connection per instance; `clients[i]` belongs to `instances[i]`.
clients = [connect_to_instance(node) for node in instances]
print(clients)

Connecting to SSH...
Got client
/Users/srujithpoondla/.ssh/aws-key-spot-instance.pem
Exception: [Errno None] Unable to connect to port 22 on 35.171.189.81 Retrying...
Connected!
[<paramiko.client.SSHClient object at 0x115cf5c50>]


#### Mount EFS

In [17]:
# efs_addr = get_efs_address('fast-ai-efs'); efs_addr

In [18]:
# _ = run_command(client, 'mkdir ~/efs_mount')

In [19]:
# efs_mount_cmd = f'sudo mount -t nfs -o nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2 {efs_addr}:/ ~/efs_mount'
# _ = run_command(client, efs_mount_cmd)

In [20]:
# _ = run_command(client, 'ls efs_mount') # no reformatting

## Tmux

In [46]:
# Create a tmux session called 'sess' on each client unless `tmux ls` reports one.
# NOTE(review): elsewhere in this notebook run_command() prints as a (stdout, stderr)
# tuple, so the `in` test below looks for a tuple element equal to 'sess' rather than a
# substring of the listing -- it is probably never true. Confirm the intended check.
tsess = []
for client in clients:
    tmux_listing = run_command(client, 'tmux ls')
    if 'sess' in tmux_listing:
        continue
    tsess.append(TmuxSession(client, 'sess'))
print(tsess)

[<aws_setup.TmuxSession object at 0x11540d4e0>]


#### Activate Conda Environment in all the instances and check whether we need to create a cluster or not. Then start redis using conf files.

In [47]:
# For every (client, tmux session) pair: make sure a 'redis-sess' tmux session exists,
# refresh the code checkout, then start redis-server with the matching conf file.
# NOTE(review): run_command() prints as a (stdout, stderr) tuple elsewhere in this
# notebook, so `'redis' not in ...` is an element test, not a substring test. If it were
# ever false, `redis` would be unbound (first client) or left over from the previous client.
redis_sess = []
for client, sess in zip(clients, tsess):
    tmux_listing = run_command(client, 'tmux ls')
    if 'redis' not in tmux_listing:
        redis = TmuxSession(client, 'redis-sess')
        redis_sess.append(redis)

    print(run_command(client, '. ~/miniconda3/bin/activate largescale'))
    print(run_command(client, 'cd ~/vanilla-hogwild && git stash && git pull && git checkout stable'))

    # More than three nodes -> cluster configuration, otherwise the standalone one.
    # NOTE(review): the earlier warning mentions a "3 node cluster", yet exactly three
    # nodes take the standalone branch here -- confirm the threshold.
    if args.n_nodes > 3:
        server_cmd = '~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis_cluster.conf'
    else:
        server_cmd = '~/miniconda3/envs/largescale/bin/redis-server ~/redis-conf/redis.conf'
    print(redis.run_command(server_cmd))

('', '')
("No local changes to save\r\nremote: Counting objects: 21, done.\x1b[K\r\nremote: Compressing objects:   7% (1/14)   \x1b[K\rremote: Compressing objects:  14% (2/14)   \x1b[K\rremote: Compressing objects:  21% (3/14)   \x1b[K\rremote: Compressing objects:  28% (4/14)   \x1b[K\rremote: Compressing objects:  35% (5/14)   \x1b[K\rremote: Compressing objects:  42% (6/14)   \x1b[K\rremote: Compressing objects:  50% (7/14)   \x1b[K\rremote: Compressing objects:  57% (8/14)   \x1b[K\rremote: Compressing objects:  64% (9/14)   \x1b[K\rremote: Compressing objects:  71% (10/14)   \x1b[K\rremote: Compressing objects:  78% (11/14)   \x1b[K\rremote: Compressing objects:  85% (12/14)   \x1b[K\rremote: Compressing objects:  92% (13/14)   \x1b[K\rremote: Compressing objects: 100% (14/14)   \x1b[K\rremote: Compressing objects: 100% (14/14), done.\x1b[K\r\nremote: Total 21 (delta 13), reused 15 (delta 7), pack-reused 0\x1b[K\r\nUnpacking objects:   4% (1/21)   \rUnpacking objects:   9% (2/21) 

In [48]:
# Smoke test against the first node: write key 'a', then read it back.
redis_cli_checks = (
    '~/miniconda3/envs/largescale/bin/redis-cli set \'a\' 1',
    '~/miniconda3/envs/largescale/bin/redis-cli get \'a\'',
)
for check_cmd in redis_cli_checks:
    print(run_command(clients[0], check_cmd))

('OK\r\n', '')
('"1"\r\n', '')


#### Create redis cluster

In [49]:
# Join the redis servers into one cluster, driven from the first node.
# NOTE(review): the earlier warning mentions a "3 node cluster", yet this only runs for
# more than three nodes -- confirm the threshold.
if args.n_nodes > 3:
    # One "<private-ip>:6379" endpoint per node (the original loop never added the IP).
    ip_str = ' '.join(f'{ip}:6379' for ip in private_ip_list)
    # Fixes: use the `clients` list (a single client is not indexable) and make the
    # command an f-string so the endpoints are actually substituted.
    # Other cells show run_command() returning a (stdout, stderr) pair; star-unpacking
    # also tolerates a leading extra element in case the input-feeding form returns one.
    *_, out, err = run_command(
        clients[0],
        f'cd redis-stable/src && ./redis-trib.rb create {ip_str}',
        ['yes'],
    )
    print(out)

#### Creating Arguments String

In [50]:
# Model / training parameters for the remote training run.
# NOTE(review): the cell that builds `arg_str` only passes batch_size, lr, num_processes,
# nnet_arch and dataset; epochs, momentum and log_interval are not forwarded there --
# confirm whether they should be.
batch_size = 128
epochs = 20
lr = 0.01
momentum = 0.5
log_interval = 50
num_processes = 2
nnet_arch = 'LeNet'
dataset = 'cifar10'
# Force Redis mode on, overriding whatever value the argument parser produced.
args.is_redis = True

In [51]:
# Allowed values:
#   dataset       -> 'MNIST' or 'cifar10'
#   architecture  -> 'LeNet' or 'ResNet' (still working on this)
#   num_processes -> either 1 or 2
#   batch size    -> 128, 256, 512, 1024, 2048

# One "--flag=value" token per option, joined with single spaces.
cli_flags = [
    f'--is-redis={args.is_redis}',
    f'--dataset={dataset}',
    f'--nnet-arch={nnet_arch}',
    f'--num-processes={num_processes}',
    f'--batch-size={batch_size}',
    f'--lr={lr}',
]
arg_str = ' '.join(cli_flags)

print(arg_str)

--is-redis=True --dataset=cifar10 --nnet-arch=LeNet --num-processes=2 --batch-size=128 --lr=0.01


#### Choose the log file name

In [52]:
# Log name is "<dataset>-<arch>-<batch size>-<processes>", with a "-redis" suffix when
# training through Redis.
log_file = '-'.join([dataset, nnet_arch, str(batch_size), str(num_processes)])
if args.is_redis:
    log_file += '-redis'
print('Log file name: '+log_file)

Log file name: cifar10-LeNet-128-2-redis


In [59]:
# Number of logged-in users on the first node, read from the header line of `w`.
# Pick the comma-separated field that mentions "user" instead of a fixed position: once
# uptime reaches a day the header gains an extra comma ("up 2 days,  3:04,  4 users") and
# the second field is no longer the user count.
w_header = run_command(clients[0], 'w')[0].strip().splitlines()[0]
out = next(field for field in w_header.split(',') if 'user' in field).split()[0]

4


In [60]:
# Launch training inside every tmux session.
for sess in tsess:
    sess.run_command('source activate largescale')
    # The space before `2>&1` matters: without it the "2" is glued onto the last argument
    # (--lr=0.01 becomes --lr=0.012) and stderr is no longer redirected into the log.
    sess.run_command('python3 -u ~/vanilla-hogwild/main.py ' + arg_str + ' 2>&1 | tee ' + log_file)
    # Sent right after the training command -- presumably closes the session once
    # training returns, which is what lowers the user count polled below.
    sess.run_command('exit')


def _logged_in_users(client):
    """Return the user count reported in the header line of `w` on `client`."""
    header = run_command(client, 'w')[0].strip().splitlines()[0]
    # Select the field by content: with uptime of a day or more the header has an extra
    # comma and the count is no longer the second field.
    users_field = next(field for field in header.split(',') if 'user' in field)
    return int(users_field.split()[0])


# A node is treated as still training while more than this many users are logged in
# (threshold kept from the original code -- TODO confirm it matches the session layout).
IDLE_USER_THRESHOLD = 3
POLL_INTERVAL_SECONDS = 10

# Wait until no node is busy. Each pass queries every client itself (the original always
# asked clients[0]) and starts from scratch (the original counter accumulated across
# passes, so one finished node could end the wait for all). Sleep between passes instead
# of hammering the SSH connections.
while any(_logged_in_users(client) > IDLE_USER_THRESHOLD for client in clients):
    time.sleep(POLL_INTERVAL_SECONDS)

# Tear everything down once training has finished everywhere.
for instance in instances:
    print(instance.terminate())
# NOTE(review): `ec2c` is not defined in any cell visible here (presumably supplied by
# `from aws_setup import *`); also confirm these public IPs are releasable addresses.
for ip in public_ip_list:
    ec2c.release_address(PublicIp=ip)