In [None]:
import matplotlib.pyplot as plt
import os

# Positions of each (encoding, dimension) combination inside ``all_results``
# and ``encoding_names``.
OLD_2D, OLD_3D, NEW_2D, NEW_3D = range(4)

# Directory containing the result CSV files.
dirpath = "./results"

# Human-readable labels, ordered to match the index constants above.
encoding_names = [
    "Original 2D Encoding",
    "Original 3D Encoding",
    "New 2D Encoding",
    "New 3D Encoding",
]

# One independent, initially empty dict per encoding.
all_results = [{} for _ in encoding_names]

In [None]:
# Store a list of the models where the new encoding differed from the old encoding
invalid = []
valid = []
# Contact count of the first results file seen for each model; the second file
# seen for the same model is compared against it.
seen_contacts = {}

# Empty the per-encoding accumulators in place so that re-running this cell
# does not append every timing a second time.
for results in all_results:
    results.clear()

# Result files are named "<sequence>_<2d|3d>_<old|new>.csv".
# Sorted so the processing order does not depend on the filesystem.
for filepath in sorted(os.listdir(dirpath)):
    if not filepath.endswith(".csv"):
        continue
    with open(os.path.join(dirpath, filepath)) as f:
        lines = f.readlines()

    # The first line is skipped as a header. Blank lines (e.g. a trailing
    # newline at the end of the file) are dropped instead of crashing float().
    rows = [list(map(float, line.split(","))) for line in lines[1:] if line.strip()]

    # Ignore the file if it holds no data rows (completely empty or header only)
    if not rows:
        continue

    # Drop the 8-character "_old.csv" / "_new.csv" suffix
    model_name = filepath[:-8]
    # Match on the suffixes rather than on substrings so that a sequence name
    # containing "old" or "_2d" cannot be misclassified.
    old = filepath.endswith("_old.csv")

    if model_name.endswith("_2d"):
        result_type = OLD_2D if old else NEW_2D
    else:
        result_type = OLD_3D if old else NEW_3D

    # Check if the new encoding is the same as the old encoding
    # (third column of the first data row is the contact count)
    contacts = int(rows[0][2])
    if model_name not in seen_contacts:
        seen_contacts[model_name] = contacts
    elif contacts != seen_contacts[model_name]:
        invalid.append(model_name)
    else:
        valid.append(model_name)

    # Group the timings of this file by sequence length (first column);
    # the second column is the time.
    results = all_results[result_type]
    for row in rows:
        length = row[0]
        results.setdefault(length, []).append({
            "model_name": model_name,
            "time": row[1]
        })

In [None]:
# Drop the three-character dimension suffix from every invalid model name and
# de-duplicate, leaving one entry per sequence.
invalid_sequences = list({name[:-3] for name in invalid})

# A sequence that is invalid in both dimensions contributes two entries to
# ``invalid``; any other ratio means one of the dimensions agreed.
if len(invalid) / 2 != len(invalid_sequences):
    print("There were some sequences that were correct for one dimension but not the other")


def get_sequence_from_inputs(filename: str) -> str:
    """Return the first line of ``input/<filename>`` without its line ending.

    Args:
        filename: Name of the file inside the ``input`` directory (relative to
            the current working directory).

    Returns:
        The first line of the file with any trailing newline removed, or an
        empty string if the file is empty.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(f"input/{filename}") as f:
        # Strip only the line terminator: slicing off the last character would
        # eat a real character when the file has no trailing newline.
        return f.readline().rstrip("\r\n")


def get_contacts_from_results(filename: str) -> int:
    """Return the contact count from the first data row of a results CSV.

    The first line of ``results/<filename>`` is skipped as a header; the last
    comma-separated field of the second line is returned as an integer.

    Args:
        filename: Name of the CSV file inside the ``results`` directory
            (relative to the current working directory).

    Returns:
        The last field of the first data row, as an ``int``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file has no data row or the field is not numeric.
    """
    with open(f"results/{filename}") as f:
        f.readline()  # skip the header line
        first_row = f.readline()
    last_field = first_row.split(",")[-1].strip()
    try:
        return int(last_field)
    except ValueError:
        # Also accept float-formatted counts such as "7.0"; a non-numeric or
        # empty field still raises ValueError from float().
        return int(float(last_field))

    
# Report every mismatching sequence, followed by the contact count recorded
# in each of its four result files.
for seq_name in sorted(invalid_sequences):
    print(seq_name)
    print(f"{seq_name}: {get_sequence_from_inputs(seq_name)}")
    for variant in ("2d_old", "2d_new", "3d_old", "3d_new"):
        # The variable must stay named ``contacts``: the ``=`` specifier below
        # prints the expression text as part of the output.
        contacts = get_contacts_from_results(f"{seq_name}_{variant}.csv")
        print(f"{variant} {contacts = }")
    

# print("Invalid sequences: ")
# invalid.sort()
# print(invalid)

# print("Valid sequences: ")
# valid.sort()
# print(valid)

In [None]:
# For every encoding, build a list of (sequence length, mean time) pairs.
all_data = []

for results in all_results:
    data = []

    for length, models in results.items():
        # Only models whose old and new encodings agreed count towards the mean
        times = [
            model["time"]
            for model in models
            if model["model_name"] not in invalid
        ]

        # Skip lengths where every model was filtered out
        if not times:
            continue

        mean_time = sum(times) / len(times)
        data.append((int(length), round(mean_time, 2)))

    all_data.append(data)

# Order each encoding's points by sequence length and print them under the
# encoding's display name.
for label, points in zip(encoding_names, all_data):
    points.sort(key=lambda point: point[0])
    print(label)
    print(points)
    print()

# Graph the data: one line per encoding, mean time against sequence length
for label, series in zip(encoding_names, all_data):
    lengths = [length for length, _ in series]
    mean_times = [mean_time for _, mean_time in series]
    plt.plot(lengths, mean_times, marker="o", label=label)

# Tick every sequence length that occurs in ANY encoding. Reusing the loop
# variable of the last plotted series would drop lengths that only the other
# encodings have, and would fail outright if nothing had been plotted.
all_lengths = sorted({length for series in all_data for length, _ in series})

plt.grid(True, which="minor")
plt.xlabel("Sequence length")
plt.xticks(all_lengths)
plt.ylabel("Time(s)")
plt.yscale("log")
plt.legend()
plt.show()