In [None]:
# Loading the Pallas module, Numpy and Matplotlib
import pallas_trace as pallas
import numpy as np
import matplotlib.pyplot as plt


# Name of the trace to fetch: used to build the download URL and the local "<trace_name>.tgz" archive name.
# The trace is then opened from "<trace_name>/eztrace_log.pallas" (presumably the archive's top-level directory).
trace_name = "amg.16.4_trace"

In [None]:
%%time
# Downloading a trace from a trusted source and decompressing it
from urllib.request import urlretrieve
import tarfile
urlretrieve(f"http://stark.int-evry.fr/traces/pallas_traces/ABI_{pallas.get_ABI()}/{trace_name}.tgz", f"{trace_name}.tgz")
with tarfile.open(f"{trace_name}.tgz") as tarfile:
    tarfile.extractall(".")


In [None]:
%%time
if not trace_name.endswith(".pallas"):
    trace_name += "/eztrace_log.pallas"

trace = pallas.open_trace(trace_name)
# Checking out general information about the trace
print(f"Trace is located in {trace.dir_name}/{trace.trace_name}")
# This loads the archives
print(f"Trace contains {len(trace.archives)} archives")
# This loads the threads
print(f"Trace contains {sum([len(a.threads) for a in trace.archives])} threads")


In [None]:
%%time
# Maybe you want to parse the trace ?
# This code right here shows how to parse all the events in a thread

# This is just a pretty print to show a timestamp for Sequences and Events
def print_pallas_object(obj: pallas.Sequence| pallas.Loop | pallas.Event, index: int):
    match type(obj):
        case pallas.Sequence:
            print(f"{obj.timestamps[index]/1e9}\t{obj.guessName()}")
        case pallas.Loop:
            print(f"\tLoop {obj.id}")
        case pallas.Event:
            print(f"{obj.timestamps[index]/1e9}\t{obj.guessName()}")

# Then you can iterate over a thread !
# Token is the token you're currently seeing, and index is the numerotation of that token
# ie you've "seen" that token index-th time before
def print_thread(thread: pallas.Thread):
    for (token, index) in thread:
        print_pallas_object(token, index)

print_thread(trace.archives[0].threads[0])


In [None]:
%%time
# Maybe you want to create a communication matrix ?
# This is an example that uses Pallas' API to parse the grammar
size = len(trace.archives)
matrix = np.zeros((size, size), dtype=np.uint64)

# Then get all messages received
for receiver, archive in enumerate(trace.archives):
    for thread in archive.threads:
        for event in thread.get_events_from_record([pallas.Record.MPI_IRECV, pallas.Record.MPI_RECV]):
            data = event.data
            matrix[data['sender']][receiver] += data['msgLength'] * event.nb_occurrences
plt.matshow(matrix)

In [None]:
%%time
# But there's actually a helper function that does this for us !
faster_matrix = pallas.get_communication_matrix(trace)
plt.matshow(faster_matrix)
# It's also faster ! ( 7 ms vs 30 ms )

In [None]:
%%time 
# Maybe you want to see a histogram of the communications over time ?
# It's quite easy to do !
# First get your timestamps
timestamps = np.linspace(trace.starting_timestamp, trace.ending_timestamp)
width = timestamps[1] - timestamps[0] 

# Then get your histogram and plot !
histogram_data = pallas.get_communication_over_time(trace, timestamps)

plt.bar(timestamps[:-1]/1e9, histogram_data / (1024 * 1024 ), width=width / 1e9)
plt.xlabel("Time (s)")
plt.ylabel("Data exchanged ( MiB )")
plt.show()

# Maybe you're interested in checking the number of messages ?
histogram_count = pallas.get_communication_over_time(trace, timestamps, count_messages=True)
plt.bar(timestamps[:-1]/1e9, histogram_count / 1e3, width=width / 1e9)
plt.xlabel("Time (s)")
plt.ylabel("Messages exchanged / 1k")
plt.show()

In [None]:
%%time
# Plotting an histogram to see the time distribution of a certain sequence
main_thread = trace.archives[0].threads[0]
for s in main_thread.sequences:
    print(f"{s.id}\t{s.guessName()}\t{s.min_duration / 1e9}\t{s.max_duration/1e9}\t{s.mean_duration/1e9}\t{s.n_iterations}")


selected_sequence = main_thread.sequences[20]
plt.hist(selected_sequence.durations.as_numpy_array() / 1e9)