In [56]:
import pandas as pd
import numpy as np
from datetime import datetime as dt
from IPython.display import display
from bokeh.io import show, output_notebook
import glob
import re
import os.path
import math
import functools
import pickle
from bokeh.plotting import figure, ColumnDataSource
from bokeh.models import HoverTool, ranges
# Silence pandas' SettingWithCopyWarning globally: code below assigns new
# columns to filtered slices of DataFrames.
# NOTE(review): this hides real chained-assignment mistakes too; prefer .copy() at the slice.
pd.options.mode.chained_assignment = None
# Send bokeh show() output into this notebook.
output_notebook()

In [57]:
def log(msg):
    """Print `msg` to stdout, prefixed by the current local date and time."""
    timestamp = dt.now()
    print(timestamp, msg)

class parentree:
    """A tree parsed from a nested, parenthesised, comma-separated string.

    For well-formed input such as ``"RPC(a, b(c, d))"`` the node name ``n`` is the text
    before the first parenthesis, and ``items`` holds one child ``parentree`` per
    top-level comma-separated argument inside the parentheses (empty for a leaf).
    """

    def __init__(self, s):
        parts = re.match("\\s*([^()]*)(\\((.*)\\))?\\s*", s)
        self.n = parts.group(1)
        inner = parts.group(3)
        if inner:
            self.items = [parentree(piece) for piece in self.parensplit(inner)]
        else:
            self.items = []

    def parensplit(self, s, sep = ","):
        """Split `s` on `sep`, ignoring separators nested inside parentheses."""
        pieces = []
        buf = []
        depth = 0
        for ch in s:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            if depth == 0 and ch == sep:
                pieces.append("".join(buf))
                buf = []
            else:
                buf.append(ch)
        pieces.append("".join(buf))
        return pieces

    def __str__(self):
        """Render back to ``name(child, child, ...)`` form (just ``name`` for a leaf)."""
        if not self.items:
            return self.n
        return self.n + "(" + ", ".join(map(str, self.items)) + ")"

    def debugstr(self, indent=""):
        """Return an outline of the tree, one node per line, two spaces per depth level."""
        lines = [indent + self.n + "\n"]
        lines.extend(child.debugstr(indent + "  ") for child in self.items)
        return "".join(lines)

    def __getitem__(self, idx):
        return self.items[idx]

    def __len__(self):
        return len(self.items)

class reader:
    """Load Spark trace TSV files and expose filtered views of their events.

    Each entry of ``sources`` is a directory whose non-empty ``*.tsv`` files are
    concatenated into one sub-trace.  Each row's ``type`` string is parsed into a
    ``parentree`` held in the ``event`` column, and lookup tables are built between
    trace UUIDs, ``(host, port)`` pairs and service names.

    Main attributes:
        df          -- all rows, with added ``traceid``, ``event`` and ``name`` columns
        tracemap    -- UUID -> index of the source it was read from
        uuidrevmap  -- (host, port) -> UUID
        uuidmap     -- UUID -> (host, port)
        servicemap  -- (host, port) -> service name
        starttimes  -- UUID -> time of its first row in ``xevents`` (None if it has none)
        xrpcs, xevents, xprocesses -- RPC rows, other event rows, ProcessStart/ProcessEnd rows
        xservices   -- ((host, port), name) pairs, ordered by source then start time
    """
    eventcol = "event"

    def tscmp(self, a, b):
        """Three-way compare two timestamps (-1/0/1), treating None as infinitely far in the future."""
        if a is None and b is None: return 0
        if b is None: return -1
        if a is None: return 1
        if a == b: return 0
        return -1 if a < b else 1

    def svccmp(self, tida, tidb, tsa, tsb):
        """Three-way compare (trace id, timestamp) pairs so services of one trace stay together."""
        if tida < tidb: return -1
        if tida > tidb: return 1
        return self.tscmp(tsa, tsb)

    def rpctree_service(self, event, idx):
        """Return the (host, port) tuple of an RPC's sender (idx 0) or receiver (idx 1)."""
        assert(event.n == "RPC")
        assert(idx in (0, 1))
        # Keep the text after the last "/" and split it on ":" into host and port.
        return tuple(event[idx].n.split("/")[-1].split(":"))

    def event_initiator(self, uuid, event):
        """Return the (host, port) that initiated an event: the sender for RPCs, else the row's UUID host."""
        if event.n == "RPC": return self.rpctree_service(event, 0)
        else: return self.uuidmap[uuid]

    def service_name(self, host, name):
        """Format a display name from a (host, port) tuple and a service name."""
        return host[0] + " " + host[1] + " " + name

    def tree_service_name(self, subtree):
        """Return the display name of the service described by one host subtree of an RPC."""
        host = tuple(subtree.n.split("/")[-1].split(":"))
        return self.service_name(host, self.servicemap[host])

    def getstarts(self):
        """Map every known UUID to the time of its first row in ``xevents`` (None if it has none)."""
        ret = { uuid: None for uuid in self.uuidmap.keys() }
        # drop_duplicates keeps the first row per id in row order; this replaces a per-row iloc scan.
        firsts = self.xevents.drop_duplicates("id")
        for uuid, time in zip(firsts["id"].tolist(), firsts["time"].tolist()):
            ret[uuid] = time
        return ret

    def service_start(self, service):
        """Return the time of the first ``xevents`` row recorded by `service`'s UUID, or None."""
        targetid = self.uuidrevmap[service]
        times = self.xevents["time"][(self.xevents["id"] == targetid).values]
        return times.iloc[0] if len(times) > 0 else None

    def filter_apply_events(self, event, filters):
        """Return True if every filter's ``fevent`` keeps `event` (stops at the first rejection)."""
        return all(f.fevent(event) for f in filters)

    def filter_service_basic(self, host, filters):
        """Return True if `host` has a known service name and every filter's ``fservice`` keeps it."""
        if host not in self.servicemap: return False
        name = self.servicemap[host]
        return all(f.fservice(host[0], host[1], name) for f in filters)

    def filter_apply_services(self, uuid, event, filters):
        """Apply service filters to a row: both endpoints for RPCs, otherwise the row's own service."""
        if event.n == "RPC":
            return self.filter_service_basic(self.rpctree_service(event, 0), filters) and self.filter_service_basic(self.rpctree_service(event, 1), filters)
        else:
            if uuid not in self.uuidmap: return False
            return self.filter_service_basic(self.uuidmap[uuid], filters)

    def filter_apply_mutations(self, tree, filters):
        """Pass an event tree through every filter's ``mutate`` in order and return the result."""
        for f in filters: tree = f.mutate(tree)
        return tree

    def filter_one(self, df, filters):
        """Return a filtered, mutated copy of `df` with its ``name`` column recomputed."""
        if len(df) == 0: return df
        df = df[df.apply(lambda x: self.filter_apply_services(x["id"], x[self.eventcol], filters), 1)] # Filter the services
        # Stop once no rows remain: Series.apply on an empty column does not produce a boolean
        # mask, so indexing with it would select (zero) columns instead of rows.
        if len(df) == 0: return df
        df = df[df[self.eventcol].apply(lambda x: self.filter_apply_events(x, filters))] # Filter the events
        df = df.copy() # Assign new columns on a copy, not on a slice of the caller's frame
        df[self.eventcol] = df[self.eventcol].map(lambda x: self.filter_apply_mutations(x, filters)) # Apply mutations to the event
        df["name"] = df["id"].map(self.uuid2name)
        return df

    def uuid2name(self, uuid):
        """Return the display name of the service a trace UUID maps to."""
        service = self.uuidmap[uuid]
        return self.service_name(service, self.servicemap[service])

    def host2name(self, host):
        """Return the display name of the service described by a SparkHost parentree."""
        return self.tree_service_name(host)

    def makemaps(self):
        """Build uuidrevmap, uuidmap and servicemap from ``df``, then add its ``name`` column."""
        alluuids = self.df["id"].unique().tolist()
        rpc_rows = self.df[self.df[self.eventcol].apply(lambda x: x.n == "RPC")][["id", self.eventcol]].to_dict("records")
        # (host, port) of each RPC's receiver -> UUID of the trace that recorded the RPC
        self.uuidrevmap = { self.rpctree_service(row[self.eventcol], 1): row["id"] for row in rpc_rows }
        # JVMs that never received a message cannot be resolved to a host and port; map the
        # placeholder (uuid, "") to them so they still appear in plots.
        resolved_uuids = set(self.uuidrevmap.values())
        missing_uuids = [ uuid for uuid in alluuids if uuid not in resolved_uuids ]
        self.uuidrevmap.update({ (uuid, ""): uuid for uuid in missing_uuids })
        # Invert to UUID -> (host, port); when several endpoints share a UUID the last one wins.
        self.uuidmap = { val: key for key, val in self.uuidrevmap.items() }
        # "Service" events hold (service name, host): map the host's (host, port) to the name.
        service_events = self.df[self.df[self.eventcol].apply(lambda x: x.n == "Service")][self.eventcol].drop_duplicates().tolist()
        self.servicemap = { tuple(event[1].n.split("/")[-1].split(":")): event[0].n for event in service_events }
        # Unresolved placeholders get an empty service name so name lookups still succeed.
        self.servicemap.update({ (uuid, ""): "" for uuid in missing_uuids })
        self.df["name"] = self.df["id"].map(self.uuid2name)

    def __init__(self, sources):
        """Read every source directory, parse its events and build all derived tables.

        Raises ValueError if a source directory contains no non-empty ``*.tsv`` files.
        """
        subtraces = []
        self.tracemap = {}
        for source in sources:
            files = [ file for file in glob.glob(source + "/*.tsv") if os.path.getsize(file) > 0 ]
            if not files:
                raise ValueError("No non-empty .tsv trace files found in " + source)
            # Pass paths rather than open() handles so pandas closes each file after reading it.
            curtrace = pd.concat([ pd.read_csv(file, sep="\t") for file in files ])
            curtrace["traceid"] = len(subtraces)
            # Shift times so each source starts at 0, then interpret them as milliseconds.
            starttime = curtrace["time"].min()
            curtrace["time"] = pd.to_datetime(curtrace["time"] - starttime, unit="ms")
            for uuid in curtrace["id"].unique().tolist():
                self.tracemap[uuid] = len(subtraces)
            subtraces.append(curtrace)
        self.df = pd.concat(subtraces)
        self.df[self.eventcol] = self.df["type"].map(lambda x: parentree(x)) # Parse each type string into a tree
        self.makemaps()

        # Copies, so later column assignments never write through a slice of self.df.
        self.xrpcs = self.df[self.df[self.eventcol].apply(lambda x: x.n == "RPC")].copy()
        self.xrpcs["src"] = self.xrpcs[self.eventcol].map(lambda x: self.tree_service_name(x[0]))
        self.xrpcs["dst"] = self.xrpcs[self.eventcol].map(lambda x: self.tree_service_name(x[1]))
        self.xevents = self.df[self.df[self.eventcol].apply(lambda x: x.n not in ["RPC", "ProcessStart", "ProcessEnd"])].copy()
        self.xprocesses = self.df[self.df[self.eventcol].apply(lambda x: x.n in ["ProcessStart", "ProcessEnd"])].copy()
        self.starttimes = self.getstarts()
        # Order services by source trace, then by first event time (services with none go last).
        servicetmp = sorted([ ((service, name), self.tracemap[self.uuidrevmap[service]], self.starttimes[self.uuidrevmap[service]]) for (service, name) in self.servicemap.items() ], key=functools.cmp_to_key(lambda a, b: self.svccmp(a[1], b[1], a[2], b[2])))
        self.xservices = [ svc[0] for svc in servicetmp ]

    def filter(self, filters):
        """Apply `filters` in place to the service/UUID maps and to every row subset."""
        self.servicemap = { key: val for key, val in self.servicemap.items() if self.filter_service_basic(key, filters) } # Refresh the service map
        self.uuidmap = { val: key for key, val in self.uuidrevmap.items() if self.filter_service_basic(key, filters) } # Refresh UUID map so UUIDs map to non-excluded services when possible
        self.df = self.filter_one(self.df, filters)
        self.xrpcs = self.filter_one(self.xrpcs, filters)
        self.xevents = self.filter_one(self.xevents, filters)
        self.xservices = [ svc for svc in self.xservices if svc[0] in self.servicemap ] # FIXME Update xservices based on new start times
        self.xprocesses = self.filter_one(self.xprocesses, filters)

    def resolved(self):
        """Return a table of trace id, UUID, host, port and service name for every mapped UUID."""
        return pd.DataFrame([ (self.tracemap[uuid], uuid, hostport[0], hostport[1], self.servicemap[hostport]) for (uuid, hostport) in self.uuidmap.items() ], columns=["T", "UUID", "Host", "Port", "Service"])

    def rpcs(self):
        """Return a copy of the RPC rows with ``type`` set to the rendered event tree and ``event`` dropped."""
        ret = self.xrpcs.copy() # Copy so repeated calls (and later filter()) still see the event column
        ret["type"] = ret[self.eventcol].map(str)
        del ret[self.eventcol]
        return ret

    def events(self):
        """Return a copy of the non-RPC, non-process rows with ``type`` rendered and ``event`` dropped."""
        ret = self.xevents.copy() # Copy so repeated calls (and later filter()) still see the event column
        ret["type"] = ret[self.eventcol].map(str)
        del ret[self.eventcol]
        return ret

    def processes(self):
        """Pair ProcessStart/ProcessEnd rows (keyed by the event's first child) into one row per process.

        Returns a DataFrame with ``id``, ``start``, ``end``, ``name`` and ``type`` columns, where
        ``type`` renders the start event's second child; if there are no process rows, an empty
        frame with ``name``/``start``/``end`` columns.
        """
        ret = self.xprocesses.copy() # Copy so helper columns are not added to self.xprocesses
        if len(ret) == 0: return pd.DataFrame({"name": [], "start": [], "end": []})
        ret["pid"] = ret[self.eventcol].map(lambda x: x[0].n)
        ret["order"] = ret[self.eventcol].map(lambda x: 0 if x.n == "ProcessStart" else 1)
        # Keyword arguments: the positional form of DataFrame.pivot was removed in pandas 2.0.
        ret = ret.pivot(index="pid", columns="order")
        ret = pd.DataFrame({"id": ret["id", 0], "start": ret["time", 0], "end": ret["time", 1], "type": ret["type", 0], "event": ret[self.eventcol, 0]})
        ret["name"] = ret["id"].map(self.uuid2name)
        ret["type"] = ret[self.eventcol].map(lambda x: str(x[1]))
        del(ret[self.eventcol])
        return ret

    def services(self):
        """Return display names of the services in the trace, in ``xservices`` order."""
        return [ self.service_name(svc[0], svc[1]) for svc in self.xservices ]

    def timerange(self):
        """Return (earliest, latest) time over all rows of ``df``."""
        return (self.df["time"].min(), self.df["time"].max())

    def debug_display(self):
        """Display the internal lookup tables and the full frame."""
        display(self.resolved())
        display(self.tracemap) # UUID -> trace ID
        display(self.uuidrevmap) # (IP, port) -> UUID
        display(self.uuidmap) # UUID -> (IP, port)
        display(self.servicemap) # (IP, port) -> service name
        display(self.starttimes) # UUID -> start time
        display(self.df)

class displayfilter:
    """Base class for trace filters passed to ``reader.filter``; the defaults keep everything unchanged.

    The hooks take no ``self``.  They are static so they work whether called on the
    class itself or on an instance.
    """

    @staticmethod
    def fevent(event):
        """Return True to keep an event tree."""
        return True

    @staticmethod
    def fservice(host, port, name):
        """Return True to keep the service at host:port with the given name."""
        return True

    @staticmethod
    def mutate(tree):
        """Return the event tree to store for a kept event (unchanged by default)."""
        return tree

def seqplot(trace):
    """Build a bokeh sequence diagram of a trace and return the figure (render it with show()).

    The x axis has one category per service; time increases down the y axis.  Process
    lifetimes are drawn as lime bars with start/end markers.  Each RPC is drawn as a blue
    dot at its sender plus a horizontal blue line to its receiver.  Other events are red
    dots.  Hovering a glyph shows its ``type`` column.
    """
    rpcs = trace.rpcs()
    events = trace.events()
    processes = trace.processes()
    services = trace.services()
    timerange = trace.timerange()

    hover = HoverTool()
    # overflow-wrap: break-word lets long event descriptions wrap inside the 400px tooltip
    # (the previous "word-wrap: wrap-all" is not a valid CSS value and was ignored).
    hover.tooltips = "<div style='max-width: 400px; overflow-wrap: break-word'>@type</div>"
    p = figure(y_axis_type="datetime", x_range=services, tools=["ypan", "ywheel_zoom", hover, "reset"], active_scroll="ywheel_zoom")

    # One data source per frame, shared by all glyphs drawn from it.
    process_src = ColumnDataSource(processes)
    rpc_src = ColumnDataSource(rpcs)
    event_src = ColumnDataSource(events)

    p.segment(y0="start", y1="end", x0="name", x1="name", source=process_src, line_width=4, color="lime", alpha=0.6)
    p.triangle("name", "end", source=process_src, size=12, color="green")
    p.inverted_triangle("name", "start", source=process_src, size=8, color="lime")
    p.circle("src", "time", size=8, source=rpc_src, color="blue")
    p.segment(y0="time", y1="time", x0="src", x1="dst", source=rpc_src, color="blue")
    p.circle("name", "time", size=8, source=event_src, color="red")
    # Range from latest to earliest so time runs downward.
    p.y_range = ranges.Range1d(timerange[1], timerange[0])
    p.xaxis.major_label_orientation = math.pi/6
    p.sizing_mode = "scale_width"
    p.height = 400
    return p

class stat:
    """Base class for a trace statistic; subclasses override ``name`` and ``extract``.

    The hooks are static so they work whether called on the class itself or on an instance.
    """

    @staticmethod
    def name():
        """Display name of the statistic."""
        return "<default>"

    @staticmethod
    def extract(events):
        """Return the statistic's values for `events`; the default is an empty float Series."""
        # An explicit dtype avoids pandas' warning about the default dtype of an empty Series.
        return pd.Series(dtype=float)

def filt(events, function):
    """Return the rows of `events` for which `function(row)` returns True.

    `function` is applied row-wise (``DataFrame.apply(..., axis=1)``) and must return a bool.
    An empty frame is returned unchanged: on an empty frame apply() does not produce a
    boolean row mask, and indexing with its result would select columns instead of rows.
    """
    if len(events) == 0:
        return events
    mask = events.apply(function, axis=1)
    return events[mask]

def timedelta(events, partitioner):
    """Return, for each partition key, the seconds between the two rows with that key.

    `partitioner(row)` returns a hashable key for each row of `events`, or None to ignore the
    row.  Keys with exactly two rows contribute ``(second time - first time)`` in seconds,
    taken in row order; keys with any other number of rows are skipped.  The result is a
    float Series indexed by key, in order of each key's first appearance.
    """
    if len(events) == 0:
        return pd.Series(dtype=float)
    parts = events.apply(partitioner, axis=1).tolist()
    # Group times by key in a single pass (the old version rescanned the frame per key and
    # its "is not None" filter tested whole rows, so it never removed anything).
    times_by_part = {}
    for part, time in zip(parts, events["time"].tolist()):
        if part is None or (isinstance(part, float) and part != part): # None or NaN: no partition
            continue
        times_by_part.setdefault(part, []).append(time)
    durations = { part: (times[1] - times[0]).total_seconds() for part, times in times_by_part.items() if len(times) == 2 }
    return pd.Series(list(durations.values()), index=list(durations.keys()), dtype=float)

class statcol:
    """Base class for a summary column of the statistics table; subclasses override ``name`` and ``calc``.

    The hooks are static so they work whether called on the class itself or on an instance.
    """

    @staticmethod
    def name():
        """Column heading."""
        return "<default>"

    @staticmethod
    def calc(vals):
        """Summarise the extracted values into one cell (None by default)."""
        return None

def calcstats(trace, stats, cols):
    """Compute a table of summary statistics for every sub-trace of `trace`.

    Parameters:
        trace -- object whose ``df`` DataFrame has a ``traceid`` column
        stats -- statistic classes providing ``name()`` and ``extract(events)``
        cols  -- column classes providing ``name()`` and ``calc(values)``

    Returns a DataFrame indexed by (traceid, statistic name) with one column per entry
    of `cols`, in that order; index and column level names are cleared for display.
    """
    data = []
    for traceid in trace.df["traceid"].unique().tolist():
        subtrace = trace.df[trace.df["traceid"] == traceid]
        extracted = [ s.extract(subtrace) for s in stats ]
        data.extend([ [col.name(), traceid] + [ col.calc(s) for s in extracted ] for col in cols ])
    ret = pd.DataFrame(data, columns=["stat", "traceid"] + [ s.name() for s in stats ])
    # Pivot to rows (traceid, statistic) x columns (summary column).
    ret = ret.set_index(["traceid", "stat"]).unstack("traceid").transpose().swaplevel().sort_index()
    ret.columns.name = None
    # MultiIndex(labels=...) and .labels were removed in pandas 1.0; set_names clears the
    # level names on any version.
    ret.index = ret.index.set_names([None, None])
    return ret[[ c.name() for c in cols ]]

In [67]:
# FILTERS

class remove_cruft(displayfilter):
    """Drop heartbeat RPCs, noisy bookkeeping events, and unresolved or uninteresting services."""

    def fevent(event):
        if event.n == "RPC":
            message = event[2]
            if message.n == "HeartbeatResponse":
                return False
            if message.n == "RequestMessage" and message[2].n == "Heartbeat":
                return False
        # BMMUpdate seems to fire on every block put/get/delete; TrackerRegisterShuffle duplicates
        # RegisterShuffle; Service events only exist to resolve service names.
        return event.n not in ("BMMUpdate", "TrackerRegisterShuffle", "Service")

    def fservice(host, port, name):
        # An empty name marks a service that could not be resolved.
        return name not in ("", "driverPropsFetcher")

class clean_rpcs(displayfilter):
    """Replace each RPC tree with its message (unwrapping a RequestMessage to its payload)."""

    def mutate(tree):
        if tree.n != "RPC":
            return tree
        message = tree[2]
        return message[2] if message.n == "RequestMessage" else message

class only_tasks(displayfilter):
    """Placeholder filter: overrides no hooks, so it behaves exactly like displayfilter.

    NOTE(review): the name suggests it should keep only task events; not implemented yet.
    """

class remove_events(displayfilter):
    """Keep only RPCs; every other kind of event is dropped."""

    def fevent(event):
        return event.n == "RPC"

class remove_rpcs(displayfilter):
    """Complement of remove_events: keep every event except RPCs."""

    def fevent(event):
        kept_by_remove_events = remove_events.fevent(event)
        return not kept_by_remove_events

class events_only_block(displayfilter):
    """Keep RPCs plus shuffle-registration and block/block-manager events; drop everything else."""

    def fevent(event):
        block_events = (
            "TrackerRegisterShuffle", "RegisterShuffle", "UnregisterShuffle",
            "BlockFetch", "BlockUpload", "GetBlock", "GetBlockData", "PutBlock",
            "DeleteBlock", "FreeBlock",
            "BMMRegister", "BMMUpdate", "BMMRemoveBlock", "BMMRemoveRDD",
            "BMMRemoveShuffle", "BMMRemoveBroadcast",
        )
        return event.n == "RPC" or event.n in block_events

class only_management(displayfilter):
    """Keep registration/stop RPCs, lifecycle events, and a fixed set of setup processes."""

    def fevent(event):
        kind = event.n
        if kind == "RPC":
            return event[2].n.startswith(("Register", "Stop"))
        management_events = (
            "DebugMessage", "JVMStart", "MainStart", "MainEnd", "SpawnExecutor", "StartYarnClient",
            "SubmittedApplication", "SubmitTaskSet", "SubmittedTaskSet", "ExecutorDone",
            "DagSchedulerEvent",
        )
        if kind in management_events:
            return True
        if kind in ("ProcessStart", "ProcessEnd"):
            setup_processes = (
                "DebugProcess", "JVMStart", "CreateSparkContext", "CreateSparkEnv", "YarnAllocate",
                "FetchDriverProps",
            )
            return event[1].n in setup_processes
        return False

# COLUMNS

class stat_count(statcol):
    """Column: number of extracted values (number of rows when given a DataFrame)."""

    def name():
        return "Count"

    def calc(vals):
        n_values = len(vals)
        return n_values

class stat_min(statcol):
    """Column: smallest extracted value, or None when there are no values."""

    def name():
        return "Min"

    def calc(vals):
        if len(vals) == 0:
            return None
        return vals.min()

class stat_max(statcol):
    """Column: largest extracted value, or None when there are no values."""

    def name():
        return "Max"

    def calc(vals):
        if len(vals) == 0:
            return None
        return vals.max()

class stat_mean(statcol):
    """Column: arithmetic mean of the extracted values, or None when there are no values."""

    def name():
        return "Average"

    def calc(vals):
        if len(vals) == 0:
            return None
        return vals.mean()

class stat_median(statcol):
    """Column: median of the extracted values, or None when there are no values."""

    def name():
        return "50%"

    def calc(vals):
        if len(vals) == 0:
            return None
        return vals.median()

class stat_25p(statcol):
    """Column: 25th percentile of the extracted values, or None when there are no values."""

    def name():
        return "25%"

    def calc(vals):
        if len(vals) == 0:
            return None
        return vals.quantile(0.25)

class stat_75p(statcol):
    """Column: 75th percentile of the extracted values, or None when there are no values."""

    def name():
        return "75%"

    def calc(vals):
        if len(vals) == 0:
            return None
        return vals.quantile(0.75)

class stat_argmin(statcol):
    """Column: index label of the smallest extracted value, or None when there are no values."""
    def name(): return "Min at"
    def calc(vals):
        # Series.argmin returns an integer position since pandas 1.0; idxmin returns the label
        # (e.g. the service name), which is what this column is meant to show.
        return vals.idxmin() if len(vals) > 0 else None

class stat_argmax(statcol):
    """Column: index label of the largest extracted value, or None when there are no values."""
    def name(): return "Max at"
    def calc(vals):
        # Series.argmax returns an integer position since pandas 1.0; idxmax returns the label
        # (e.g. the service name), which is what this column is meant to show.
        return vals.idxmax() if len(vals) > 0 else None

# STATISTICS

class data_jvmstart(stat):
    """Statistic: per service, seconds between the two ProcessStart/ProcessEnd rows of its JVMStart process."""
    def name(): return "JVM start time"
    def extract(events):
        def partition(row):
            ev = row["event"]
            if ev.n not in ["ProcessStart", "ProcessEnd"] or ev[1].n != "JVMStart": return None
            # reader frames carry a "name" column (the row's id mapped through uuid2name), so the
            # service name comes from the row itself instead of the notebook-global `trace`.
            return row["name"]
        return timedelta(events, partition)

class data_execlife(stat):
    """Statistic: per executor service, seconds between its MainEnd row and its JVMStart ProcessEnd row (in row order)."""
    def name(): return "Executor lifetime"
    def extract(events):
        def partition(row):
            if "sparkExecutor" not in row["name"]: return None
            ev = row["event"]
            if not ((ev.n == "MainEnd") or (ev.n == "ProcessEnd" and ev[1].n == "JVMStart")): return None
            # The row's "name" column already holds its service name (checked above), so no
            # lookup through the notebook-global `trace` is needed.
            return row["name"]
        return timedelta(events, partition)

class data_nrpcs(stat):
    """Statistic: number of RPC rows per sending service."""

    def name():
        return "RPCs sent"

    def extract(events):
        is_rpc = events.apply(lambda row: row["event"].n == "RPC", axis=1)
        rpc_rows = events[is_rpc]
        # NOTE(review): relies on the notebook-global `trace` reader to name the sender host;
        # the frame has no column holding the sender's service name.
        rpc_rows["sender"] = rpc_rows["event"].apply(lambda ev: trace.host2name(ev[0]))
        per_sender = rpc_rows.groupby("sender").count()
        return per_sender["id"]

class data_tasklife(stat):
    """Statistic: seconds between paired BeginEvent/CompletionEvent scheduler events of each task."""

    def name():
        return "Task duration"

    def extract(events):
        def partition(row):
            ev = row["event"]
            is_task_edge = ev.n == "DagSchedulerEvent" and ev[0].n in ["BeginEvent", "CompletionEvent"]
            if not is_task_edge:
                return None
            # Key each task by the names of the first two children of the event's task node.
            task = ev[0][0]
            return (task[0].n, task[1].n)
        return timedelta(events, partition)

class data_execs(stat):
    """Statistic: DagSchedulerEvent rows whose first child is ExecutorAdded."""

    def name():
        return "Executors started"

    def extract(events):
        def is_executor_added(row):
            ev = row["event"]
            return ev.n == "DagSchedulerEvent" and ev[0].n == "ExecutorAdded"
        return filt(events, is_executor_added)

class data_jobs(stat):
    """Statistic: DagSchedulerEvent rows whose first child is JobSubmitted."""

    def name():
        return "Jobs"

    def extract(events):
        def is_job_submitted(row):
            ev = row["event"]
            return ev.n == "DagSchedulerEvent" and ev[0].n == "JobSubmitted"
        return filt(events, is_job_submitted)

class data_tasks(stat):
    """Statistic: DagSchedulerEvent rows whose first child is BeginEvent."""

    def name():
        return "Tasks"

    def extract(events):
        def is_task_begin(row):
            ev = row["event"]
            return ev.n == "DagSchedulerEvent" and ev[0].n == "BeginEvent"
        return filt(events, is_task_begin)

class data_blockupdates(stat):
    """Statistic: BlockOperation rows whose first child is updateBlockInfo."""

    def name():
        return "Block updates"

    def extract(events):
        def is_block_update(row):
            ev = row["event"]
            return ev.n == "BlockOperation" and ev[0].n == "updateBlockInfo"
        return filt(events, is_block_update)

In [73]:
# Directory holding the tracer's *.tsv files; every entry of `infiles` is read as one sub-trace.
base = "/tmp/spark-trace"
#base = "/home/matt/code/spark-tracing/runs/remote"
infiles = [base]
checkpoint = base + ".pkl" # Pickled (trace, stats) written by the refresh branch below
refresh = True # True: re-read the TSVs and recompute; False: load `checkpoint` instead

if refresh:
    log("Reading input")
    # `trace` must be a global: some data_* statistics read it directly during calcstats.
    trace = reader(infiles)
    log("Applying preliminary filters")
    trace.filter([remove_cruft])
    log("Calculating statistics")
    # Two tables: simple counts, then distributions with min/quantiles/max and where they occur.
    stats = [calcstats(trace,
        [data_execs, data_jobs, data_tasks, data_blockupdates],
        [stat_count]),
    calcstats(trace,
        [data_jvmstart, data_nrpcs, data_tasklife],
        [stat_count, stat_min, stat_25p, stat_median, stat_75p, stat_max, stat_argmin, stat_argmax])]
    log("Writing output")
    with open(checkpoint, "wb") as outf:
        pickle.dump(trace, outf)
        pickle.dump(stats, outf)
else:
    log("Reading checkpoint")
    # Only load checkpoints you wrote yourself: unpickling can execute arbitrary code.
    with open(checkpoint, "rb") as inf:
        trace = pickle.load(inf)
        stats = pickle.load(inf)
log("Applying filters")
# NOTE(review): an empty filter list appears to leave the trace unchanged; confirm it is needed.
trace.filter([])
# Replace each RPC's event tree with its message before plotting.
trace.filter([clean_rpcs])
log("Displaying results")
#trace.debug_display()
for s in stats: display(s)
show(seqplot(trace))

2017-09-05 13:53:13.123213 Reading input
2017-09-05 13:53:13.166381 Applying preliminary filters
2017-09-05 13:53:13.208254 Calculating statistics
2017-09-05 13:53:13.348501 Writing output
2017-09-05 13:53:13.354795 Applying filters
2017-09-05 13:53:13.434647 Displaying results


Unnamed: 0,Unnamed: 1,Count
0,Executors started,0
0,Jobs,0
0,Tasks,0
0,Block updates,0


Unnamed: 0,Unnamed: 1,Count,Min,25%,50%,75%,Max,Min at,Max at
0,JVM start time,3,0.47,0.5465,0.623,0.738,0.853,10.50.108.70 40477 sparkDriver,10.50.108.70 52484 sparkExecutor
0,RPCs sent,3,6.0,20.5,35.0,42.0,49.0,10.50.108.70 52476 sparkYarnAM,10.50.108.70 52484 sparkExecutor
0,Task duration,0,,,,,,,
