In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
from datetime import datetime
import re

%matplotlib inline

In [6]:
def parse_results(file_name):
    """Parse one benchmark log file.

    Scans every line of ``file_name`` for three kinds of entries:

    * ``... Total time <t> (ms)``            -> the run's total time; if several
      such lines appear, the last one wins.
    * ``... Client writing message size <n>`` -> one client message size.
    * ``... Server writing message size <n>`` -> one server message size.

    Args:
        file_name: path of the log file to read.

    Returns:
        ``(runtime, client_comms, server_comms)``: ``runtime`` is a float in
        milliseconds (per the ``(ms)`` suffix in the log line), and the two
        lists hold the logged message sizes as floats, in file order.

    Raises:
        AssertionError: if no ``Total time`` line is found. It is raised
        explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    # Compile once instead of re-parsing the pattern for every line.
    re_time = re.compile(r'.*Total time (.*) \(ms\)')
    re_client_comm = re.compile(r'.*Client writing message size (.*)')
    re_server_comm = re.compile(r'.*Server writing message size (.*)')

    client_comms = []
    server_comms = []
    runtime = None

    with open(file_name) as f:
        for raw_line in f:
            line = raw_line.strip()
            # Each pattern is matched once and the match object reused; the
            # patterns are tried in priority order (time, client, server).
            match = re_time.match(line)
            if match:
                runtime = float(match.group(1))
                continue
            match = re_client_comm.match(line)
            if match:
                client_comms.append(float(match.group(1)))
                continue
            match = re_server_comm.match(line)
            if match:
                server_comms.append(float(match.group(1)))

    if runtime is None:
        print('could not find runtime in ', file_name)
        raise AssertionError('could not find runtime in ' + str(file_name))

    return runtime, client_comms, server_comms


In [24]:
def _load_runtimes(folder, prefix, n_runs):
    """Return the runtimes of the logs ``<folder><prefix><i>.txt`` for i = 1..n_runs.

    Uses ``parse_results`` and keeps only its runtime; the per-run
    communication lists are discarded.
    """
    return [parse_results(folder + prefix + str(i) + '.txt')[0]
            for i in range(1, n_runs + 1)]


def _fmt_line(times, batch_size):
    """Format the trailing cells of one LaTeX table row.

    Yields ``<batch> & <amortized ms> & <mean s> pm <std s> tabularnewline``
    (with LaTeX backslashes), where ``times`` are per-run runtimes in ms,
    the amortized value is mean(times) / batch_size, and mean/std are in
    seconds. All numbers are rounded to 2 decimals.
    """
    times_sec = [t / 1000. for t in times]
    amortized_time_ms = np.round(np.mean(times) / batch_size, 2)
    # The backslash before 'pm' is escaped explicitly: a bare '\p' is an
    # invalid escape sequence (DeprecationWarning; SyntaxWarning on 3.12+).
    return (str(batch_size) + ' & '
            + str(amortized_time_ms) + ' & '
            + str(np.round(np.mean(times_sec), 2))
            + ' \\pm ' + str(np.round(np.std(times_sec), 2))
            + ' \\tabularnewline')


def _bytes_to_mb(n):
    """Convert a byte count to decimal megabytes (1 MB = 10**6 bytes)."""
    return float(n) / (1000. * 1000.)


def make_cryptonets_relu_table(folder='./results/', n_runs=10):
    """Print the CryptoNets-ReLU timing table rows and communication summary.

    Reads ``n_runs`` numbered result logs per configuration from ``folder``
    and prints:

    1. four LaTeX table rows, each starting with hard-coded label columns
       followed by batch size, amortized ms, and mean +- std seconds;
    2. the throughput of the ``wan_*`` runs at batch size 2048 (per second,
       since runtimes are in ms);
    3. the online client/server communication parsed from ``comm.txt``.

    Args:
        folder: directory prefix holding the result files. File names are
            concatenated onto it, so it must end with a path separator.
        n_runs: number of numbered runs (1..n_runs) per configuration.
    """
    # Same load order as before: wan, omp1_complex, omp1, clientless_nt24.
    wan_times = _load_runtimes(folder, 'wan_', n_runs)
    omp1_complex_times = _load_runtimes(folder, 'omp1_complex_', n_runs)
    omp1_times = _load_runtimes(folder, 'omp1_', n_runs)
    clientless24_times = _load_runtimes(folder, 'clientless_nt24_', n_runs)

    table_rows = [
        '1 & \\xmark & localhost & ' + _fmt_line(omp1_times, 1024),
        '1 & \\cmark & localhost & ' + _fmt_line(omp1_complex_times, 2048),
        '24 & \\cmark & localhost & ' + _fmt_line(clientless24_times, 2048),
        # NOTE(review): the wan_* runs are labelled "LAN" here — confirm.
        '24 & \\cmark & LAN & ' + _fmt_line(wan_times, 2048),
    ]
    print('\n'.join(table_rows))

    # Throughput: images per second (runtimes are in ms).
    batch_size = 2048.
    throughput = batch_size / np.mean(wan_times) * 1000.
    print('throughput', throughput)

    # Communication costs
    _, client_comms, server_comms = parse_results(folder + 'comm.txt')
    print('client_comms', client_comms)
    print('server_comms', server_comms)

    # NOTE(review): "online" messages are selected purely by position in
    # comm.txt (client: 4th onward; server: 3rd and 4th), which assumes a fixed
    # message order in the log. The recorded output shows identical client and
    # server sizes for these messages — check that the same transfer is not
    # being counted on both sides.
    online_client_comms = client_comms[3:]
    online_server_comms = server_comms[2:4]
    print('online_client_comms', online_client_comms)
    print('online_server_comms', online_server_comms)

    total_client_comms = _bytes_to_mb(sum(online_client_comms))
    total_server_comms = _bytes_to_mb(sum(online_server_comms))
    total_comms = total_client_comms + total_server_comms
    total_comms_per_image = total_comms / batch_size

    print('client comms', total_client_comms, ' MB')
    print('server_comms', total_server_comms, ' MB')
    print('total_comms', total_comms, 'MB')
    print('total_comms_per_image', total_comms_per_image, 'MB')
    
# Print the LaTeX table rows, throughput and communication summary for the logs in ./results/.
make_cryptonets_relu_table()


1 & \xmark & localhost & 1024 & 2.72 & 2.79 \pm 0.06 \tabularnewline
1 & \cmark & localhost & 2048 & 1.44 & 2.94 \pm 0.04 \tabularnewline
24 & \cmark & localhost & 2048 & 0.24 & 0.5 \pm 0.04 \tabularnewline
24 & \cmark & LAN & 2048 & 0.34 & 0.69 \pm 0.04 \tabularnewline
throughput 2959.109955208785
client_comms [32868.0, 32916.0, 25747371.0, 27750672.0, 3284127.0]
server_comms [52.0, 35.0, 27750672.0, 3284127.0, 328437.0]
online_client_comms [27750672.0, 3284127.0]
online_server_comms [27750672.0, 3284127.0]
client comms 31.034799  MB
server_comms 31.034799  MB
total_comms 62.069598 MB
total_comms_per_image 0.0303074208984375 MB
