In [None]:
# Loading the Pallas module, Numpy and Matplotlib
import pallas_trace as pallas
import numpy as np
import matplotlib.pyplot as plt


# Base name of the trace to analyse. The download cell fetches `<trace_name>.tgz`,
# and the open cell expects it to unpack to `<trace_name>/eztrace_log.pallas`.
trace_name = "amg.4.16_trace"

In [None]:
%%time
# Downloading a trace from a trusted source and decompressing it
from urllib.request import urlretrieve
import tarfile
urlretrieve(f"http://stark2.int-evry.fr/traces/pallas_traces/ABI_{pallas.get_ABI()}/{trace_name}.tgz", f"{trace_name}.tgz")
with tarfile.open(f"{trace_name}.tgz") as tarfile:
    tarfile.extractall(".")


In [None]:
%%time
if not trace_name.endswith(".pallas"):
    trace_name += "/eztrace_log.pallas"

trace = pallas.open_trace(trace_name)
# Checking out general information about the trace
print(f"Trace is located in {trace.dir_name}/{trace.trace_name}")
# This loads the archives
print(f"Trace contains {len(trace.archives)} archives")
# This loads the threads
print(f"Trace contains {sum([len(a.threads) for a in trace.archives])} threads")


In [None]:
%%time
# Manually
def print_pallas_object(obj: pallas.Sequence| pallas.Loop | pallas.Event, index: int):
    match type(obj):
        case pallas.Sequence:
            print(f"{obj.timestamps[index]/1e9}\t{obj.guessName()}")
#        case pallas.Loop:
#            print(f"\tLoop {obj.id}")
        case pallas.Event:
            print(f"{obj.timestamps[index]/1e9}\t{obj.guessName()}")

def _token_key(token) -> tuple:
    """Return the ``counter`` key for ``token``: its Python type plus its ``id``."""
    # NOTE(review): the type is part of the key in case ids are only unique per token
    # kind (Event vs Sequence vs Loop) — TODO confirm against the Pallas API.
    # If ids are globally unique, this key behaves exactly like ``token.id``.
    return (type(token), token.id)


def print_sequence(s: pallas.Sequence, counter: dict):
    """Recursively print the tokens of ``s`` in order, one timestamp per occurrence.

    :param s: the sequence whose ``content`` is walked.
    :param counter: maps a token key to the number of occurrences of that token
        already printed. It selects the index into ``timestamps`` and is updated
        in place, so pass the same dict for a whole thread.
    """
    for token in s.content:
        key = _token_key(token)
        occurrence = counter.get(key, 0)
        print_pallas_object(token, occurrence)
        counter[key] = occurrence + 1

        if isinstance(token, pallas.Sequence):
            print_sequence(token, counter)
        elif isinstance(token, pallas.Loop):
            # The loop's body sequence is printed and walked `nb_iterations` times.
            # Each repetition counts as a new occurrence of that body sequence.
            body = token.sequence
            body_key = _token_key(body)
            for _ in range(token.nb_iterations):
                print_pallas_object(body, counter.get(body_key, 0))
                print_sequence(body, counter)
                counter[body_key] = counter.get(body_key, 0) + 1

def print_thread(thread: pallas.Thread):
    """Print every token of ``thread`` by walking its first sequence recursively.

    A fresh occurrence counter is created per call, so timestamp indices start
    from the beginning of the thread.
    """
    first_sequence = thread.sequences[0]  # presumably the thread's root sequence
    print_sequence(first_sequence, {})


print_thread(trace.archives[0].threads[0])
print("Done")

In [None]:
%%time
# Or using the built-in iterator
def print_thread(thread: pallas.Thread):
    for (token, index) in thread:
        print_pallas_object(token, index)

print_thread(trace.archives[0].threads[0])
print("Done ! It's definitely faster.")

In [None]:
%%time
# Creating a communication matrix
matrix = np.zeros((len(trace.archives), len(trace.archives)))

for sender, archive in enumerate(trace.archives):
    for thread in archive.threads:
        for event in thread.get_events_from_record(pallas.Record.MPI_ISEND):
            data = event.data
            matrix[sender][data['receiver']] += data['msgLength']

plt.matshow(matrix)

In [None]:
%%time
# Plotting an histogram to see the time distribution of a certain sequence
main_thread = trace.archives[0].threads[0]
for s in main_thread.sequences:
    print(f"{s.id}\t{s.guessName()}\t{s.min_duration / 1e9}\t{s.max_duration/1e9}\t{s.mean_duration/1e9}\t{s.n_iterations}")


selected_sequence = main_thread.sequences[20]
plt.hist(selected_sequence.durations.as_numpy_array() / 1e9)