In [None]:
import csv
import matplotlib.pyplot as plt
import os

# Positions of each (encoding, dimension) combination inside `all_results`
# and `encoding_names`.
OLD_2D, OLD_3D, NEW_2D, NEW_3D = range(4)

# Directory that is scanned for result CSV files.
dirpath = "./results"

# Legend label for each combination, ordered to match the index constants.
encoding_names = [
    "Original 2D Encoding",
    "Original 3D Encoding",
    "New 2D Encoding",
    "New 3D Encoding",
]

# One independent accumulator dict per combination; filled in while the
# result files are read.
all_results = [{} for _ in encoding_names]

In [None]:
# Read every result CSV in `dirpath` and, for each one:
#   * compare its contact count with the first file seen for the same model,
#     recording the model in `invalid` (counts differ) or `valid` (counts match);
#   * add its timings to the per-length accumulators in `all_results`.
invalid = []
valid = []
seen_contacts = {}

# Start from empty accumulators so that re-running this cell does not
# double-count files that were already loaded.
for results in all_results:
    results.clear()

# os.listdir returns entries in arbitrary order; sorting makes "first file seen
# for a model" (and the model stored for each length) reproducible.
for filepath in sorted(os.listdir(dirpath)):
    if not filepath.endswith(".csv"):
        continue
    with open(os.path.join(dirpath, filepath)) as f:
        lines = f.readlines()

    # The first line is a header. Drop blank lines (e.g. a trailing newline) so
    # they are not parsed as rows, and ignore files with no data rows at all
    # (completely empty, or header only).
    data_lines = [line for line in lines[1:] if line.strip()]
    if not data_lines:
        continue

    # NOTE(review): assumes the last 8 characters of the filename are a fixed
    # suffix that is not part of the model name - confirm against the naming
    # scheme of the result files.
    model_name = filepath[:-8]
    old = "old" in filepath

    if "_2d" in filepath:
        result_type = OLD_2D if old else NEW_2D
    else:
        result_type = OLD_3D if old else NEW_3D

    rows = [list(map(float, line.split(","))) for line in data_lines]

    # Check if the new encoding is the same as the old encoding: the third
    # column of the first row must match the value recorded for this model.
    contacts = int(rows[0][2])
    if model_name not in seen_contacts:
        seen_contacts[model_name] = contacts
    elif contacts != seen_contacts[model_name]:
        invalid.append(model_name)
    else:
        valid.append(model_name)

    # Add or increment the result for each (length, time) row.
    results = all_results[result_type]
    for row in rows:
        length = row[0]
        time = row[1]

        if length not in results:
            # Count the first sample too, so total_time / count is the true
            # mean over every sample recorded for this length.
            results[length] = {
                "total_time": time,
                "count": 1,
                "model": model_name
            }
        else:
            results[length]["total_time"] += time
            results[length]["count"] += 1

In [None]:
# Sort each list in place and print it: mismatching models first, then the
# matching ones.
for model_names in (invalid, valid):
    model_names.sort()
    print(model_names)

In [None]:
# Turn the accumulated timings into one (length, mean time) series per
# encoding, print the series, and plot them against sequence length.
all_data = []

for results in all_results:
    data = []
    for key in results:
        result = results[key]
        # Skip lengths with a zero sample count and lengths whose recorded
        # model was flagged as invalid.
        if result["count"] == 0 or result["model"] in invalid:
            continue
        # Keep the mean in a local instead of writing it back into the
        # accumulator, so re-running this cell does not divide a second time.
        mean_time = result["total_time"] / result["count"]
        data.append((int(key), round(mean_time, 4)))
    all_data.append(data)

print("data:")
for data in all_data:
    data.sort(key=lambda x: x[0])
    print(data)

# Graph the data: one line per encoding.
for label, data in zip(encoding_names, all_data):
    names = [x[0] for x in data]
    values = [x[1] for x in data]
    plt.plot(names, values, marker="o", label=label)

# Put a tick at every length that appears in any series, not only the lengths
# of whichever series happened to be plotted last.
all_lengths = sorted({x[0] for data in all_data for x in data})

plt.grid(True, which="minor")
plt.xlabel("Sequence length")
plt.xticks(all_lengths)
plt.ylabel("Time(s)")
plt.yscale("log")
plt.legend()
plt.show()