# Prototype hybrid embedding: data-parallel frequent categories and model-parallel infrequent categories

In [None]:
class color:
    """ANSI escape sequences used to colour and emphasise terminal output.

    Wrap text as ``f'{color.BOLD}{color.RED}text{color.END}'``; ``END``
    resets all attributes again.
    """

    # text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'

    # foreground colours
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

import matplotlib
import matplotlib.pyplot as plt  # `plt` is used right below, so bind it here

# Default figure size and font for the plots in this notebook.
plt.rcParams['figure.figsize'] = [10, 6]
# Only weight and size are set: 'normal' is not a font family matplotlib
# knows, it only produces "findfont" warnings before falling back to the
# default family.
font = {'weight': 'normal',
        'size': 16}
matplotlib.rc('font', **font)

In [None]:
# utility functions
from copy import deepcopy

def flatten_data(data):
    """Flatten a list of iterations into one array of global category ids.

    Each ``data[i]`` is a pair ``(embedding_sizes, samples_data)`` where
    ``samples_data`` has one row per embedding table and one column per
    sample; the embedding sizes are taken from ``data[0]``.  Per-table
    category indices are shifted by the total size of all preceding tables,
    so every category gets an id that is unique across tables.

    Returns:
        int32 array of length ``num_tables * num_samples`` laid out table by
        table: entry ``j * num_samples + i`` is sample ``i`` of table ``j``,
        with the samples of all iterations concatenated in order.
    """
    # concatenate all iterations along the sample axis; np.concatenate
    # already returns a new array, so the inputs need not be copied first
    samples_data = np.concatenate([batch[1] for batch in data], axis=1)
    num_tables = samples_data.shape[0]

    # offset of table j = sum of the sizes of tables 0 .. j-1
    sizes = np.asarray(data[0][0], dtype=np.int64)[:num_tables]
    offsets = np.cumsum(sizes) - sizes

    # a C-order ravel of (num_tables, num_samples) yields index j*num_samples + i
    return (samples_data + offsets[:, None]).reshape(-1).astype(np.int32)

# Calibration - communication measurements

In [None]:
# Node counts for which all-to-all calibration data is provided in the next
# cell; entry i of node_list matches entry i of D_a2a_data / T_a2a_data.
node_list = [2, 4, 8, 16]

# All-reduce calibration, per gpu:
#   D_ar : message size in bytes per gpu (0.25 MiB ... 1 GiB, doubling)
#   T_ar : measured all-reduce time in microseconds for each D_ar entry
# NOTE(review): this single all-reduce curve is used for every entry of
# node_list; confirm which node count it was measured on.
D_ar = np.array([0.25, 0.5, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] )* 1024*1024
T_ar = np.array([
    31.940742,
    36.368742,
    46.126742,
    77.696742,
    103.154742,
    124.293742,
    191.450742,
    331.715742,
    611.883742,
    1199.531742,
    2225.175742,
    4391.540742,
    8586.129742])


In [None]:
# All-to-all calibration data, one entry per node count in node_list
# (appended below in the order 2, 4, 8, 16 nodes):
#   D_a2a_data[i] : message size in bytes per gpu, i.e. the total all-to-all
#                   message (number of ranks x per-rank size)
#   T_a2a_data[i] : measured all-to-all time in microseconds

D_a2a_data = []
T_a2a_data = []

# 2 nodes:
T_a2a_data.append(
    np.array(
        [69.71, 69.98, 67.27, 68.65, 67.98, 
         68.32, 66.96, 68.03, 67.47, 68.88, 
         69.02, 69.39, 71.76, 75.59, 84.35, 
         115.3, 166.4 ,261, 450.7], dtype=np.float64))
D_a2a_data.append(
    np.array(
        [64, 128, 256, 512, 1024, 2048, 4096, 
         8192, 16384, 32768, 65536, 131072, 
         262144, 524288, 1048576, 2097152, 
         4194304, 8388608, 16777216], dtype=np.float64))

# 4 nodes:
D_a2a_data.append(np.array(
    [128, 256, 512, 1024, 2048, 4096, 8192, 16384,
     32768, 65536, 131072, 262144, 524288, 1048576, 
     2097152, 4194304, 8388608, 16777216 
                 ], dtype=np.float64))
# Raw 4-node times (kept for reference).  In the array actually used below
# the 270 us point at 131072 bytes is replaced by 105 (presumably a
# measurement outlier -- confirm).
# T_a2a = np.array([116, 101, 101, 112, 101, 99, 103, 100, 102, 101, 270, 109, 107, 117, 159, 230, 369, 690])
T_a2a_data.append(np.array(
    [116, 101, 101, 112, 101, 99, 103, 100,
     102, 101, 105, 109, 107, 117, 159, 230,
     369, 690], dtype=np.float64))

# 8 nodes:
# Raw 8-node measurements (kept for reference).  The first two sizes report
# sub-microsecond times (presumably no real transfer) and are dropped,
# together with their sizes, in the arrays used below; the 392.8 us point at
# 131072 bytes is replaced by 195 (presumably an outlier -- confirm).
# D_a2a_data.append(np.array(
#     [64, 128, 256, 512, 1024, 2048, 4096, 8192, 
#      16384, 32768, 65536, 131072, 262144, 524288, 
#      1048576, 2097152, 4194304, 8388608, 16777216], dtype=np.float64))
# T_a2a_data.append(np.array(
#     [0.14, 0.13, 212.2, 230.9, 201.9, 207.5, 190.4, 
#      193, 194.8, 187.7, 198.4, 392.8, 190.4, 190, 
#      212.5, 245.2, 376.4, 487.2, 858.1], dtype=np.float64))
D_a2a_data.append(np.array(
    [256, 512, 1024, 2048, 4096, 8192, 
     16384, 32768, 65536, 131072, 262144, 524288, 
     1048576, 2097152, 4194304, 8388608, 16777216], dtype=np.float64))
T_a2a_data.append(np.array(
    [212.2, 230.9, 201.9, 207.5, 190.4, 
     193, 194.8, 187.7, 198.4, 195, 190.4, 190, 
     212.5, 245.2, 376.4, 487.2, 858.1], dtype=np.float64))

# 16 nodes:
# Raw 16-node measurements (kept for reference).  The first three sizes
# report sub-microsecond times (presumably no real transfer) and are dropped,
# together with their sizes, in the arrays used below.
# D_a2a_data.append(np.array(
#     [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 
#      32768, 65536, 131072, 262144, 524288, 1048576, 
#      2097152, 4194304, 8388608, 16777216], dtype=np.float64
#      ))
# T_a2a_data.append(np.array(
#     [0.12, 0.12, 0.13, 445.7, 496.8, 387.9, 397.3, 400.4, 
#      403.8, 391.3, 408.6, 402.8, 390.1, 406.1, 408.2, 432.2, 
#      481.4, 1136.2, 1210.3], dtype=np.float64
#      ))
D_a2a_data.append(np.array(
    [512, 1024, 2048, 4096, 8192, 16384, 
     32768, 65536, 131072, 262144, 524288, 1048576, 
     2097152, 4194304, 8388608, 16777216], dtype=np.float64
     ))
T_a2a_data.append(np.array(
    [445.7, 496.8, 387.9, 397.3, 400.4, 
     403.8, 391.3, 408.6, 402.8, 390.1, 406.1, 408.2, 432.2, 
     481.4, 1136.2, 1210.3], dtype=np.float64
     ))

import matplotlib.pyplot as plt

# Bandwidths (bytes/s) used for the speed-of-light (SOL) reference curves.
B_IB = 200e9
B_AR = 130e9

for i, num_nodes in enumerate(node_list):
    D_a2a = D_a2a_data[i]
    T_a2a = T_a2a_data[i]

    # SOL times in microseconds.  For all-to-all, (num_nodes-1)/num_nodes of
    # the data leaves the node; the factor 8 presumably accounts for 8 gpus
    # per node sharing B_IB (assumption -- confirm).
    T_sol_a2a = 8 * D_a2a * (num_nodes-1) / num_nodes / B_IB * 1e6
    T_sol_ar = D_ar / B_AR * 1e6

    # only draw the SOL curves where they exceed 10 microseconds
    mask_a2a = T_sol_a2a > 10
    mask_ar = T_sol_ar > 10

    plt.title(f'calibration data {num_nodes} nodes')
    # raw strings: '\m' is not a valid escape sequence in a normal literal
    plt.plot(D_a2a, T_a2a, 'b.-', label=r'all-to-all latencies ($\mu s$)')
    plt.plot(D_a2a[mask_a2a], T_sol_a2a[mask_a2a], 'b--', label='all-to-all SOL')
    plt.plot(D_ar, T_ar, 'k.-', label=r'all reduce latencies ($\mu s$)')
    plt.plot(D_ar[mask_ar], T_sol_ar[mask_ar], 'k--', label='all-reduce SOL')
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('number of bytes per gpu')
    plt.ylabel(r'communication time ($\mu s$)')
    plt.legend()
    plt.show()


# Initialize frequent categories

In [None]:
from copy import deepcopy
import numpy as np

def interpolate_T(D, T, bytes_in):
    """Estimate the communication time for a message of ``bytes_in`` bytes.

    Calibration measurements are noisy, so instead of joining the points a
    straight line ``t = a * D + b`` is fitted locally with a Gaussian weight
    centred on ``bytes_in`` and evaluated at ``D = bytes_in``.

    Args:
        D: calibration message sizes in bytes (array-like, ascending; ``D[0]``
            is treated as the smallest calibrated size).
        T: measured times for each entry of ``D``.
        bytes_in: message size at which to evaluate the fit.

    Returns:
        The interpolated time in the units of ``T``.  ``0.`` when
        ``bytes_in`` is (almost) zero or below half the smallest calibrated
        size, because then no communication is assumed to happen.
    """
    D = np.asarray(D, dtype=np.float64)
    T = np.asarray(T, dtype=np.float64)

    # if (almost) zero bytes, no communication will be performed: t = 0
    epsilon = 0.5
    if bytes_in < epsilon or bytes_in < D[0] / 2:
        return 0.

    # The Gaussian width scales with bytes_in: sigma is half the distance
    # between bytes_in * width_log and bytes_in / width_log.
    width_log = 2.
    sigma = (bytes_in * width_log - bytes_in / width_log) / 2.
    kernel = np.exp(-(D - bytes_in)**2 / (2 * sigma * sigma))

    # Weighted linear regression: solve  kernel * (a D + b) = kernel * T  in
    # the least-squares sense.  lstsq (SVD based) replaces the explicit
    # inverse of M^T M: the two columns differ by many orders of magnitude
    # (bytes vs. 1), which makes the normal equations badly conditioned, and
    # lstsq still returns a solution when the kernel underflows for most
    # points and M^T M becomes singular.
    M = np.column_stack((kernel * D, kernel))
    (a, b), *_ = np.linalg.lstsq(M, kernel * T, rcond=None)

    return a * bytes_in + b

In [None]:
for i_nodes, num_nodes in enumerate(node_list):
    print(f"interpolation of calibration data for {num_nodes} nodes")

    D_a2a = D_a2a_data[i_nodes]
    T_a2a = T_a2a_data[i_nodes]

    # Evaluate the interpolation on a geometric grid: the axes below are
    # logarithmic and a linear grid would put nearly all points in the last
    # decade of message sizes.
    N = 300
    data_a2a = np.geomspace(D_a2a[0], D_a2a[-1], N)
    t_a2a_int = np.array([interpolate_T(D_a2a, T_a2a, d) for d in data_a2a])
    data_ar = np.geomspace(D_ar[0], D_ar[-1], N)
    t_ar_int = np.array([interpolate_T(D_ar, T_ar, d) for d in data_ar])

    # left: calibration points together with their interpolation
    plt.subplot(121)
    plt.plot(D_a2a, T_a2a, 'k.-', label='calibration all-to-all')
    plt.plot(data_a2a, t_a2a_int, label='interpolation all-to-all')

    plt.plot(D_ar, T_ar, 'b.-', label='calibration all-reduce')
    plt.plot(data_ar, t_ar_int, 'r', label='interpolation all reduce')

    plt.title('calibration data')
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('number of bytes per gpu')
    plt.ylabel(r'communication time ($\mu s$)')
    plt.legend()

    # right: the interpolation alone
    plt.subplot(122)
    plt.plot(data_a2a, t_a2a_int, label='interpolation all-to-all')
    plt.plot(data_ar, t_ar_int, 'r', label='interpolation all reduce')

    plt.title('interpolation of data')
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('number of bytes per gpu')
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
def initialize_frequent_categories(data, num_nodes, embedding_parameters, calibration_data,
                                   gpus_per_node=8):
    """Choose how many of the most frequent categories to treat data-parallel.

    The most frequent categories are replicated on every gpu (their gradients
    are all-reduced); all other categories are model-parallel and their
    embedding vectors travel through all-to-all.  For every candidate number
    ``k`` of frequent categories the cost ``t_all_reduce(k) + 2 *
    t_all_to_all(k)`` is estimated from the calibration data (two
    all-to-alls per iteration, presumably forward and backward) and the
    cheapest ``k`` is selected.

    Args:
        data: list of iterations ``(embedding_sizes, samples)`` as consumed by
            ``flatten_data``; ``samples`` has one row per table and one column
            per sample of the batch.
        num_nodes: number of nodes.
            NOTE(review): fails (division by zero) for ``num_nodes == 1``.
        embedding_parameters: provides ``embedding_vec_size`` and
            ``data_element_size`` (bytes per element).
        calibration_data: provides ``D_ar``/``T_ar`` and ``D_a2a``/``T_a2a``
            (bytes per gpu / microseconds).
        gpus_per_node: gpus per node over which the all-to-all traffic of a
            node is split (default 8).

    Returns:
        ``(num_frequent, categories_sort, counts_sort, communication_time,
        comm_time_ar, comm_time_a2a)``: categories and their counts sorted
        from most to least frequent, and the three time arrays indexed by
        the number of frequent categories ``k``.

        NOTE(review): ``num_frequent`` is ``argmin(communication_time) + 1``,
        one more than the ``k`` that minimises the cost.  Callers read the
        minimum as ``communication_time[num_frequent - 1]``, so the value is
        kept for compatibility; confirm whether the extra category is
        intended.
    """
    num_batches = len(data)
    batch_size = data[0][1].shape[1]
    embedding_sizes = data[0][0]
    num_tables = embedding_sizes.size

    # get samples and category counts
    samples = flatten_data(data)
    categories, counts = np.unique(samples, return_counts=True)

    # sort counts and categories from most frequent to least frequent
    # (fancy indexing already returns copies)
    index_count_sort = np.argsort(counts)[::-1]
    categories_sort = categories[index_count_sort]
    counts_sort = counts[index_count_sort]

    num_samples = num_tables * num_batches * batch_size

    # embedding parameters
    embedding_vec_size = embedding_parameters.embedding_vec_size
    data_element_size = embedding_parameters.data_element_size
    bytes_per_vector = embedding_vec_size * data_element_size

    # calibration data
    D_a2a = calibration_data.D_a2a  # bytes per gpu: total all-to-all message size (# ranks x size)
    T_a2a = calibration_data.T_a2a  # microseconds
    D_ar = calibration_data.D_ar    # bytes per gpu
    T_ar = calibration_data.T_ar    # microseconds

    # some theoretical maxima (bytes/s)
    B_a2a_max = 190e9
    B_ar_max = 230e9

    # Node occupancy n_c of each category; counts are sorted descending, so
    # the candidates are the leading categories with n_c >= n_max.
    n_max = 0.1 * B_a2a_max / B_ar_max * num_nodes / (num_nodes-1)
    n_c = counts_sort / (num_batches * num_nodes)
    below = np.flatnonzero(n_c < n_max)
    # Candidates k = 0 .. first_below.  If no category falls below the
    # threshold all of them are candidates (np.argmax would have returned 0
    # and silently restricted the search to k = 0).
    first_below = below[0] if below.size else n_c.size
    num_frequent_max = int(first_below) + 1

    # frequent_samples[k] = samples covered by the k most frequent categories
    frequent_samples = np.concatenate(([0], np.cumsum(counts_sort)))

    # communication times for every number of frequent categories k
    communication_time = np.zeros(num_frequent_max)
    comm_time_ar = np.zeros(num_frequent_max)
    comm_time_a2a = np.zeros(num_frequent_max)
    for k in range(num_frequent_max):
        # all-to-all: one embedding vector per sample of an infrequent
        # category, per iteration and per gpu
        bytes_a2a = (num_samples - frequent_samples[k]) * bytes_per_vector
        bytes_a2a_gpu = bytes_a2a / (num_batches * num_nodes) / gpus_per_node

        # all-reduce: one embedding vector per frequent category
        bytes_ar = k * bytes_per_vector

        t_ar = interpolate_T(D_ar, T_ar, bytes_ar)
        t_a2a = 2 * interpolate_T(D_a2a, T_a2a, bytes_a2a_gpu)

        comm_time_ar[k] = t_ar
        comm_time_a2a[k] = t_a2a
        communication_time[k] = t_ar + t_a2a

    num_frequent = int(np.argmin(communication_time) + 1)
    return num_frequent, categories_sort, counts_sort, communication_time, comm_time_ar, comm_time_a2a

In [None]:
# test code
class EmbeddingParameters:
    """Embedding layout used to size the communication buffers.

    Attributes:
        embedding_vec_size: number of elements per embedding vector.
        data_element_size: bytes per element.
    """

    def __init__(self, embedding_vec_size=128, data_element_size=2):
        self.embedding_vec_size, self.data_element_size = embedding_vec_size, data_element_size

class CalibrationData:
    """Communication calibration curves for one node count.

    Attributes:
        D_ar, T_ar: all-reduce message sizes (bytes per gpu) and times (us).
        D_a2a, T_a2a: all-to-all message sizes (bytes per gpu) and times (us).
    """

    def __init__(self, D_ar, T_ar, D_a2a, T_a2a):
        self.D_ar, self.T_ar = D_ar, T_ar
        self.D_a2a, self.T_a2a = D_a2a, T_a2a

import matplotlib
import matplotlib.pyplot as plt  # bind `plt` here as well, it is used right below

# Larger figures for the two-panel plots that follow.
plt.rcParams['figure.figsize'] = [14, 8]
# Only weight and size: 'normal' is not a valid font family and merely
# triggers "findfont" warnings before matplotlib falls back to the default.
font = {'weight': 'normal',
        'size': 16}
matplotlib.rc('font', **font)

In [None]:
for i_nodes, num_nodes in enumerate(node_list):

    D_a2a = D_a2a_data[i_nodes]
    T_a2a = T_a2a_data[i_nodes]

    print()
    print()
    print(f'{color.BOLD}{color.GREEN}Hybrid embedding communication optimization on {num_nodes} nodes{color.END}')

    # left panel: calibration data and its interpolation, evaluated on a
    # geometric grid to match the logarithmic axes
    N = 300
    data_a2a = np.geomspace(D_a2a[0], D_a2a[-1], N)
    t_a2a_int = np.array([interpolate_T(D_a2a, T_a2a, d) for d in data_a2a])
    data_ar = np.geomspace(D_ar[0], D_ar[-1], N)
    t_ar_int = np.array([interpolate_T(D_ar, T_ar, d) for d in data_ar])

    plt.subplot(121)
    plt.plot(D_a2a, T_a2a, 'k.-', label='calibration all-to-all')
    plt.plot(data_a2a, t_a2a_int, label='interpolation all-to-all')

    plt.plot(D_ar, T_ar, 'b.-', label='calibration all-reduce')
    plt.plot(data_ar, t_ar_int, 'r', label='interpolation all reduce')

    plt.title('calibration data')
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('number of bytes per gpu')
    plt.ylabel(r'communication time ($\mu s$)')
    plt.legend()

    # right panel: estimated communication time vs number of frequent categories
    embedding_vec_size = 128
    data_element_size = 2
    num_frequent_categories, frequent_categories, counts_frequent_categories, communication_time, comm_time_ar, comm_time_a2a \
        = initialize_frequent_categories(
        data, num_nodes,
        EmbeddingParameters(embedding_vec_size=embedding_vec_size, data_element_size=data_element_size),
        CalibrationData(D_ar, T_ar, D_a2a, T_a2a))
    num_frequent = num_frequent_categories
    num_frequent_max = communication_time.size
    # initialize_frequent_categories returns argmin + 1, so the minimum
    # communication time sits at index num_frequent - 1
    t_min = communication_time[num_frequent-1]

    plt.subplot(122)
    plt.title('initialization frequent categories - communication time')
    plt.text(num_frequent-1, t_min+100, f'comm time = {t_min:4.0f} microseconds')
    plt.plot(range(num_frequent_max), communication_time, 'k-', label=f'communication time, num_frequent = {num_frequent:3,d}')
    plt.plot(range(num_frequent_max), comm_time_ar, 'r--', label='communication time all-reduce')
    plt.plot(range(num_frequent_max), comm_time_a2a, 'b--', label='communication time all-to-all')
    plt.plot(range(num_frequent_max), t_min*np.ones(num_frequent_max), 'k--')
    plt.xlabel('number of frequent categories')
    plt.ylabel(r'communication time ($\mu s$)')
    plt.legend()
    plt.tight_layout()
    plt.show()

    # summary of the chosen split
    counts = counts_frequent_categories
    num_batches = len(data)
    batch_size = data[0][1].shape[1]
    num_tables = data[0][1].shape[0]
    num_samples = batch_size * num_batches * num_tables
    percentage_samples_ar = np.sum(counts[:num_frequent_categories]) / num_samples * 100
    print(f'speedup hybrid model vs model-parallel : {color.BOLD}{communication_time[0] / communication_time[num_frequent_categories-1]:4.2f} X{color.END}')
    print()
    print(f'number of frequent categories = {color.BOLD}{color.BLUE}{num_frequent_categories:3,d}{color.END}')
    print(f'total communication time      = {color.BOLD}{color.BLUE}{communication_time[num_frequent_categories-1]:4.0f} microseconds {color.END}(vs {communication_time[0]:4.0f} microseconds)')
    print()
    print(f'samples covered by all-reduce{color.BOLD} (data-parallel){color.END}  = {percentage_samples_ar:2.1f} %')
    print(f'samples covered by all-to-all{color.BOLD} (model-parallel){color.END} = {100-percentage_samples_ar:2.1f} %')
    print()
    bytes_ar = num_frequent_categories * embedding_vec_size * data_element_size
    print(f'bytes per gpu into all-reduce : {int(bytes_ar):8,d} bytes per gpu')
    bytes_a2a = (num_samples - np.sum(counts[:num_frequent_categories]) ) * embedding_vec_size * data_element_size
    bytes_a2a_gpu = bytes_a2a / (num_batches * num_nodes) / 8
    # bytes_a2a_gpu is the per-gpu all-to-all message (D_a2a units: number of
    # ranks x per-rank size), so the per-rank size divides it by the
    # num_nodes * 8 ranks.
    print(f'bytes per gpu into all-to-all : {int(bytes_a2a_gpu):8,d} bytes per gpu, equivalent rank size = {int(bytes_a2a_gpu/(num_nodes*8)):6,d} bytes')
    print()
    print()
    t_ar = interpolate_T(D_ar, T_ar, bytes_ar)
    latency_ar = (t_ar*1e-6 - bytes_ar / 130e9)*1e6
    print(f'all-reduce communication time : {t_ar:3.1f} microseconds,\x1b[31m latency = {latency_ar:3.1f} microseconds\x1b[0m (assuming 130 GB/s algo bandwidth)')
    t_a2a = interpolate_T(D_a2a, T_a2a, bytes_a2a_gpu)
    latency_a2a = (t_a2a*1e-6 - bytes_a2a_gpu*(num_nodes-1)/num_nodes / 24e9)*1e6
    print(f'all-to-all communication time : {t_a2a:3.1f} microseconds,\x1b[31m latency = {latency_a2a:3.1f} microseconds\x1b[0m (assuming 24 GB/s NIC-IB bandwidth)')
    print()
    print(f'latency all-reduce + 2 x latency all-to-all = {color.BOLD}\x1b[31m {latency_ar + 2*latency_a2a:3.1f} microseconds latency total\x1b[0m on {num_nodes}{color.END} nodes')

# Initialize data-structures

In [None]:
# configure nodes and gpus

class Gpu:
    """State held by one simulated gpu of the hybrid embedding.

    Every attribute is declared here (as ``None``) and filled in later:
    ``gpu_id``/``node`` by ``Node``, the lookup tables and samples by the
    set-up cells, and the frequent-embedding buffers by
    ``init_embedding_cache``.
    """

    def __init__(self):
        self.gpu_id = None                      # index of this gpu within its node
        self.node = None                        # owning Node
        self.num_frequent = None                # number of cached frequent categories
        self.samples = None                     # flattened samples of the current iteration
        self.frequent_categories = None
        self.category_frequent_index = None     # category -> frequent cache index (sentinel if infrequent)
        self.frequent_embedding_vectors = None
        self.frequent_partial_gradients = None
        self.category_location = None           # category -> location of an infrequent category

    def init_embedding_cache(self, num_frequent, embedding_vec_size):
        """Allocate zeroed, flat float32 buffers of ``num_frequent *
        embedding_vec_size`` elements for the frequent embedding vectors and
        their partial gradients."""
        self.num_frequent = num_frequent
        size = num_frequent * embedding_vec_size
        self.frequent_embedding_vectors = np.zeros(size, dtype=np.float32)
        self.frequent_partial_gradients = np.zeros(size, dtype=np.float32)
        
class Node:
    """A simulated node owning ``num_gpus`` Gpu objects with ids 0 .. num_gpus-1."""

    def __init__(self, num_gpus, node_id=None):
        # node_id may also be assigned after construction by the set-up code
        self.node_id = node_id
        self.gpus = [Gpu() for _ in range(num_gpus)]
        for gpu_id, gpu in enumerate(self.gpus):
            gpu.gpu_id = gpu_id
            gpu.node = self  # back-reference to this node

class Network:
    """Placeholder for the network connecting ``nodes``.

    The collective operations are not simulated yet; both methods are no-ops.
    """

    def __init__(self, nodes):
        self.nodes = nodes

    def all_reduce(self):
        """Not implemented yet (no-op)."""
        return None

    def all_to_all(self):
        """Not implemented yet (no-op)."""
        return None

In [None]:
# setup nodes, gpus and network (index 1 of node_list -> 4 nodes)
i_node = 1
num_nodes = node_list[i_node]

nodes = []
for i in range(num_nodes):
    node = Node(8)
    node.node_id = i
    nodes.append(node)

gpus = []
for node in nodes:
    gpus.extend(node.gpus)
network = Network(nodes)

for node in nodes:
    print(f"Node {node.node_id} with gpu's {[gpu.gpu_id for gpu in node.gpus]}, reporting for duty!")
for gpu in gpus:
    print(f"Gpu {gpu.gpu_id} on node {gpu.node.node_id}, reporting for duty!")

print(f"network with nodes {[node.node_id for node in network.nodes]} reporting for duty!")

In [None]:
print(f'{color.BOLD}{color.GREEN}Hybrid embedding communication optimization on {num_nodes} nodes{color.END}')
print()
print(f'Setting up data structures for run on {num_nodes} nodes')

# calibration curves for the selected node count
D_a2a, T_a2a = D_a2a_data[i_node], T_a2a_data[i_node]

print(f'Initializing frequent categories..')
embedding_vec_size, data_element_size = 128, 2
embedding_parameters = EmbeddingParameters(embedding_vec_size=embedding_vec_size,
                                           data_element_size=data_element_size)
calibration = CalibrationData(D_ar, T_ar, D_a2a, T_a2a)
(num_frequent_categories, frequent_categories, counts_frequent_categories,
 communication_time, comm_time_ar, comm_time_a2a) = initialize_frequent_categories(
    data, num_nodes, embedding_parameters, calibration)

In [None]:
# initialize_frequent_categories returns argmin + 1, hence the "- 1" index
t_min = communication_time[num_frequent_categories - 1]
num_frequent = num_frequent_categories
print(
    f"Number of frequent categories = {color.BOLD}{num_frequent:5,d}{color.END}, "
    f"embedding communication time {color.BOLD}{t_min:4.0f}{color.END} microseconds"
)

## category_frequent_index

In [None]:
embedding_sizes = data[0][0]
num_tables = embedding_sizes.size
num_categories = np.sum(embedding_sizes)
print(f'Total number of categories : {num_categories:8,d}, category_frequent_index array size : {num_categories*4/(1024*1024):4.2f} MB')

# category_frequent_index[c] = position of category c in the frequent cache,
# or num_categories (sentinel) when c is not frequent.  np.full keeps the
# int32 dtype the size estimate above assumes; multiplying np.ones(int32) by
# the np.int64 num_categories can promote the array to int64 (NumPy >= 2).
category_frequent_index = np.full(num_categories, num_categories, dtype=np.int32)
frequent_categories = frequent_categories[:num_frequent]

# initializing category_frequent_index :
category_frequent_index[frequent_categories] = np.arange(num_frequent, dtype=np.int32)

# this array is identical on all gpu's :
for gpu in gpus:
    gpu.category_frequent_index = category_frequent_index

n_display = 20
print(f'{color.BOLD}{color.RED}category          |-> frequent category cache index{color.END}')
for category in range(n_display):
    frequent_category_cache_index = category_frequent_index[category]
    if frequent_category_cache_index < num_categories:
        print(f'category {color.BOLD}{category:3d}{color.END}      |->  cache index {color.BOLD}{color.BLUE}{frequent_category_cache_index:6,d}{color.END}')
    else:
        print(f'category {color.BOLD}{category:3d}{color.END}      |->  cache index    {color.BOLD}{color.RED}END{color.END}')

In [None]:
# shape of the lookup table, and how many categories map into the frequent cache
category_frequent_index.shape
(category_frequent_index < num_categories).sum()

In [None]:
# allocate the frequent embedding vectors and frequent partial gradients
# on every gpu
for gpu in gpus:
    gpu.init_embedding_cache(num_frequent=num_frequent, embedding_vec_size=embedding_vec_size)

In [None]:
%%time
# category_location: where each infrequent category is stored.
# Vectorized -- the per-category Python loop took ~15 minutes because there
# are many infrequent categories.
#
# The k-th infrequent category (in increasing category id) gets
# [k // 8, k % 8]; frequent categories keep the sentinel
# [num_categories, num_categories].
# NOTE(review): column 0 is k // 8, which grows with the number of
# infrequent categories instead of being a node id in 0..num_nodes-1, yet
# later cells compare it with node_id; confirm the intended placement.
num_infrequent = num_categories - num_frequent
category_location = np.full((num_categories, 2), num_categories, dtype=np.int32)
infrequent_categories = np.flatnonzero(category_frequent_index == num_categories)
if infrequent_categories.size != num_infrequent:
    raise ValueError(
        f'expected {num_infrequent} infrequent categories, found {infrequent_categories.size}')
infrequent_rank = np.arange(num_infrequent)
category_location[infrequent_categories, 0] = infrequent_rank // 8
category_location[infrequent_categories, 1] = infrequent_rank % 8

In [None]:
# every gpu references the same category -> location table (no copy)
for gpu in gpus:
    gpu.category_location = category_location

In [None]:
n_display = 20
print(f'{color.BOLD}{color.RED}category          |->  category location {color.END}')
for category in range(n_display):
    location = category_location[category, :]
    # shared left part of both message variants
    prefix = f'category {color.BOLD}{category:3d}{color.END}      |->  category location '
    if location[0] < num_categories:
        print(f'{prefix}{color.BOLD}{color.GREEN}{location}{color.END}')
    else:
        print(f'{prefix}  {color.BOLD}{color.RED}END{color.END}')

In [None]:
# multi-node: node_id, gpu_id
# single-node: gpu_id, category_model_index

# linear location index:
#
# multi-node: node_id * 8 + gpu_id
# single-node: gpu_id * max_model_size + category_model_index

# sample 3 in network 5: category 7 stored in model (gpu) 8

In [None]:
# HybridEmbedding: 
#
#   frequent_categories_
#   category_frequent_index_
#   category_location_
#
#   frequent_embedding_;
#   infrequent_embedding_;

# Forward: sample |-> batch embedding vectors in mlp (concatenation)
#
#    data-parallel (all-reduce), model-parallel (all-to-all)
#

#
# sample : 26 fields categorical feature => category
#          embedding vector category 0, embedding vector category 1, embedding vector category 2, ...

# EmbeddingFrequent ( data-parallel, all-reduce )
#   
#   frequent_embedding_vectors
#   frequent_partial_gradients
#
#   stores the frequent embedding 
#   update: reduces locally the frequent gradients into the frequent_partial_gradients array
#      frequent sample indices: update_sgd(sample_indices_frequent, samples_frequent_index, gradients_samples)
#
#   all-reduce : frequent_partial_gradients
#   update the frequent_embedding_vectors
#

# EmbeddingInfrequent => 
#
#   store the infrequent categories
#
#   infrequent_embedding_vectors
#
#   
#   # local batch, global batch
#   
#   all-to-all forward
#      send buffer: 
#         (I) entire batch |-> mark the samples whose categories are placed here
#         create list of indices for entire batch, of samples' embedding vector to send
#         create offset per destination
#      receive buffer: 
#         (II) local mlp batch |-> sort by source (doesn't need to be sorted)
#
#   all-to-all backward
#      send buffer: (II)
#      receive buffer: (I)
#
#   update infrequent embedding vectors
#      (I) |-> categories |-> where stored? samples_infrequent_index
#      update_sgd(sample_indices_infrequent, samples_infrequent_index, infrequent_embedding_vectors)


# Index calculations

## Embedding

In [None]:
# Flatten the samples of one iteration and hand the same array to every gpu.
iteration = 0
num_gpus = len(gpus)

iteration_data = [data[iteration]]
samples = flatten_data(iteration_data)
for gpu in gpus:
    gpu.samples = samples

In [None]:
# samples of all iterations, flattened table by table
samples = flatten_data(data=data)

def get_node_gpu(node_id, gpu_id, node_pool=None):
    """Look up a node and one of its gpus by id.

    A linear search (fine for a handful of nodes) over ``node_pool``, which
    defaults to the notebook's global ``nodes`` list.

    Returns:
        ``(node, gpu)``; ``(None, None)`` when no node has ``node_id``, and
        ``(node, None)`` when that node has no gpu with ``gpu_id``.
    """
    candidates = nodes if node_pool is None else node_pool
    node = next((n for n in candidates if n.node_id == node_id), None)
    if node is None:
        return None, None
    gpu = next((g for g in node.gpus if g.gpu_id == gpu_id), None)
    return node, gpu

def cub_DeviceSelect(gpu, samples, node_id, gpu_id):
    """Select the samples whose category is stored on (``node_id``, ``gpu_id``).

    ``samples`` are category ids; each is looked up in
    ``gpu.category_location`` and kept when its row equals
    ``[node_id, gpu_id]``.  Judging by the name this mimics a CUB
    DeviceSelect (flagged selection) step.

    Returns:
        A new array with the selected category ids, in their original order.
    """
    samples_category_location = gpu.category_location[samples, :]
    samples_mask = ((samples_category_location[:, 0] == node_id)
                    & (samples_category_location[:, 1] == gpu_id))
    # boolean-mask indexing already returns a copy
    return samples[samples_mask]

# indices on the model side: each gpu 
def calculate_model_indices(samples, node_id, gpu_id):
    """Work in progress: model-side index calculation for one gpu.

    NOTE(review): this sketch is incomplete and does not run as written:
    the nested ``get_network`` is a stub that returns ``None`` and takes three
    parameters but is called with one; ``offset`` and
    ``network_offset_model_indices`` are computed but never used; the
    ``samples`` argument is ignored in favour of ``gpu.samples``; ``num_gpus``
    is read from the notebook's global scope; and nothing is returned.
    """
    node, gpu = get_node_gpu(node_id, gpu_id)

    def get_network(samples, num_gpus, i_sample):
        # NOTE(review): stub -- presumably meant to map a sample index to the
        # (mlp) network that consumes it; currently returns None.
        num_samples = samples.size

        # i_sample = 

    # category ids of gpu.samples stored on this (node_id, gpu_id)
    sample_model_indices = cub_DeviceSelect(gpu, gpu.samples, node_id, gpu_id) # 
    network_offset_model_indices = np.zeros(num_gpus, dtype=np.int32)
    # cuda kernel:
    for i in range(sample_model_indices.size):
        if i == 0 or get_network(sample_model_indices[i]) != get_network(sample_model_indices[i-1]):
            offset = sample_model_indices[i]

    # selected_infrequent_samples[gid] / mlp_network_size # => mlp-network id for this sample
    # selected_infrequent_samples[gid] / mlp_network_size

    # sample_indices_network = np.zeros(samples, )

In [None]:
# embedding model forward: mark the samples of iteration 0 whose category is
# stored on model (node_id, gpu_id)

node_id = 0
gpu_id = 1

samples = flatten_data([data[0]])
location_samples = category_location[samples,:]
# Linear location index node*8 + gpu per sample; samples of frequent
# categories (sentinel num_categories in column 0) map to num_categories.
# Vectorized replacement of a per-sample Python loop; the product is formed
# in int64 so the discarded sentinel rows cannot overflow int32.
location64 = location_samples.astype(np.int64)
sample_lin_location_index = np.where(
    location64[:, 0] < num_categories,
    location64[:, 0] * 8 + location64[:, 1],
    num_categories,
).astype(np.int32)
model_lin_index = node_id * 8 + gpu_id
samples_mask = (sample_lin_location_index == model_lin_index)
samples_mask
np.sum(samples_mask)

In [None]:
# raw categories of iteration 0: one row per embedding table, one column per sample
data[0][1]

# Forward send

In [None]:
# calculate 

# Forward receive

# Backward reduce

# Backward send


# Backward receive

In [None]:
# Back-of-the-envelope estimate for nn nodes.
# NOTE(review): 26 tables, 128-element embedding vectors and 2-byte elements
# match values used elsewhere in this notebook; the meaning of 55*1024,
# nn*nn*64 and 4.4 is not documented here -- confirm.
nn = 16
55*1024*26*128*2/(nn*nn*64)/4.4

# Read 15 batches of 64k samples

In [None]:
import os
import numpy as np

def read_variable(lines, indx):
    """Parse one variable record of a dump file starting at ``lines[indx]``.

    The header line is ``<name> <count> [value]``.  With ``count == 1`` the
    value is either the third token of the header or the next line;
    otherwise ``count`` values follow, one per line.

    Returns:
        ``(name, value, next_indx)``: ``value`` is an ``np.int64`` scalar for
        ``count == 1`` and an int64 array otherwise; ``next_indx`` is the
        index of the first line after the record.
    """
    tokens = lines[indx].split()
    name = tokens[0]
    count = int(tokens[1])

    if count != 1:
        values = np.empty(count, dtype=np.int64)
        values[:] = [np.int64(lines[indx + 1 + k]) for k in range(count)]
        return name, values, indx + 1 + count

    if len(tokens) == 3:
        # value given inline on the header line
        return name, np.int64(tokens[2]), indx + 1
    return name, np.int64(lines[indx + 1]), indx + 2


def read_dlrm_data(folder_name):
    """Read the per-iteration, per-gpu DLRM category dumps in ``folder_name``.

    File names are split on ``_``: field 2 is the iteration number and field
    4 (up to the first ``.``) the gpu id.  Each file holds the records
    ``num_samples``, ``slot_num``, ``size_embedding`` (one size per slot)
    and the raw categories, interleaved per sample (all slots of sample 0,
    then of sample 1, ...) and offset by the sizes of the preceding slots.

    Returns:
        dict ``{(iteration, gpu_id): (num_samples, size_embeddings,
        categories)}``; ``categories`` has shape (slot_num, num_samples) and
        holds per-slot category indices with the slot offsets removed.
    """
    data = {}

    for file_name in os.listdir(folder_name):
        print(file_name)
        split_list = file_name.split("_")
        iteration = int(split_list[2])
        gpu_id = int(split_list[4].split(".")[0])

        # parse file; the context manager closes the handle
        with open(os.path.join(folder_name, file_name), "r") as fobj:
            lines = fobj.readlines()
        indx = 0
        _, num_samples, indx = read_variable(lines, indx)
        _, slot_num, indx = read_variable(lines, indx)
        _, size_embedding, indx = read_variable(lines, indx)
        size_embeddings = size_embedding.astype(np.int64)
        _, categories_raw, indx = read_variable(lines, indx)

        if slot_num > 1:
            # raw entry i belongs to slot i % slot_num and sample i // slot_num;
            # slot s is offset by the total size of slots 0 .. s-1
            offsets = np.cumsum(size_embeddings) - size_embeddings
            raw_index = np.arange(categories_raw.size)
            slots = raw_index % slot_num
            categories = np.zeros((slot_num, num_samples), dtype=np.int64)
            categories[slots, raw_index // slot_num] = categories_raw - offsets[slots]
        else:
            categories = categories_raw.reshape((1, categories_raw.size))

        data[(iteration, gpu_id)] = (num_samples, size_embeddings, categories)

    return data

In [None]:
# folder holding the per-iteration, per-gpu DLRM dump files
data_folder = "/mnt/c/Users/dabel/Documents/mlperf/data/"
data = read_dlrm_data(data_folder)

In [None]:
# group the gpu ids available for every iteration
iterations = {}
for iteration, gpu in data:
    iterations.setdefault(iteration, []).append(gpu)

iteration_numbers = np.sort(list(iterations))

# One (embedding_sizes, categories) pair per iteration, obtained by stacking
# the tables of all gpus in gpu-id order.  The number of gpus, tables and
# samples is derived from the data instead of being hard-coded.
samples_iteration = []
for iteration in iteration_numbers:
    gpu_ids = sorted(iterations[iteration])
    # a gpu with a single table stores its size as a scalar; atleast_1d
    # handles both cases
    embedding_sizes = np.concatenate(
        [np.atleast_1d(data[(iteration, gpu)][1]) for gpu in gpu_ids])
    data_iteration = np.concatenate(
        [data[(iteration, gpu)][2] for gpu in gpu_ids], axis=0).astype(np.int64)
    samples_iteration.append((embedding_sizes, data_iteration))

In [None]:
# number of iterations read
len(samples_iteration)

In [None]:
# categories of the first iteration: (number of tables, number of samples)
samples_iteration[0][1].shape

In [None]:
from copy import deepcopy
# all iterations side by side along the sample axis
samples = np.concatenate([deepcopy(samples_iteration[i][1]) for i in range(len(samples_iteration))], axis=1)

In [None]:
samples

In [None]:
samples.shape

In [None]:
# From here on `data` is the list of (embedding_sizes, categories) pairs,
# one per iteration, as consumed by flatten_data.
data = samples_iteration