In [None]:
import gemmi 
import matplotlib.pyplot as plt
import test_model as model
import numpy as np 
import os
import logging
import pandas as pd

In [None]:
def mtz_to_map(mtz: gemmi.Mtz, resolution_cutoff: float = 3.0) -> gemmi.FloatGrid:
    """Build a real-space map from the FWT/PHWT columns, truncated at a resolution limit.

    Reflections whose d-spacing (as reported by ``mtz.make_d_array()``) is smaller
    than ``resolution_cutoff`` are dropped before the Fourier transform, so a larger
    cutoff gives a lower-resolution map.

    The caller's ``mtz`` is left unchanged: the full reflection table is restored
    after the transform. Repeated calls on the same object therefore behave the
    same regardless of the order of the cutoffs.

    Args:
        mtz: Loaded MTZ object containing "FWT" and "PHWT" columns.
        resolution_cutoff: Minimum d-spacing to keep (reflections with
            d >= cutoff are used).

    Returns:
        The grid produced by ``mtz.transform_f_phi_to_map("FWT", "PHWT")``.

    Raises:
        ValueError: If no reflection survives the cutoff.
    """
    # Explicit copy of the full table: set_data below replaces the MTZ's own
    # buffer, so a view would not be safe to restore from.
    original_data = np.array(mtz, copy=True)
    keep = mtz.make_d_array() >= resolution_cutoff
    if not keep.any():
        raise ValueError(
            f"resolution_cutoff={resolution_cutoff} removes every reflection"
        )

    mtz.set_data(original_data[keep])
    try:
        return mtz.transform_f_phi_to_map("FWT", "PHWT")
    finally:
        # Put the untruncated reflections back so later calls start from the full data.
        mtz.set_data(original_data)

In [None]:
# Generate one CCP4 map per resolution cutoff from the same MTZ file.
# Each map is written as "<cutoff>.map" (e.g. "2.5.map"); the scoring cell
# below recovers the cutoff from that file name.
mtz_file = "/home/jordan/dev/sugar_prediction_model/data/mtz/1d8g_phases.mtz"
output_dir = "/home/jordan/dev/sugar_prediction_model/data/tmp"
resolution_list = [
    1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0
]

# Make sure the destination exists before any map is written into it.
os.makedirs(output_dir, exist_ok=True)

mtz_obj = gemmi.read_mtz_file(mtz_file)

# Ascending order: each cutoff is then at least as strict as the previous one.
# The maps stay correct even if mtz_to_map trims mtz_obj in place.
for resolution in sorted(resolution_list):
    map_obj = mtz_to_map(mtz=mtz_obj, resolution_cutoff=resolution)
    ccp4 = gemmi.Ccp4Map()
    ccp4.grid = map_obj
    ccp4.update_ccp4_header()

    output_path = os.path.join(output_dir, f"{resolution}.map")
    ccp4.write_ccp4_map(output_path)

In [None]:
# Calculate the model's score for each generated map.

map_file_dir = "/home/jordan/dev/sugar_prediction_model/data/tmp"

model_path = "/home/jordan/dev/sugar_prediction_model/models/base_1.5A_model_1.best.hdf5"
pdb_code = "1d8g"
score_output_dir = "/home/jordan/dev/sugar_prediction_model/results/1d8g_res_test"

logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s - %(message)s"
)

# Hide CUDA devices (presumably to force CPU inference).
# NOTE(review): test_model is imported in the first cell, before this is set.
# If that module initialises a GPU backend at import time, this may have no
# effect; consider setting it before the import. TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# The directory is a scratch area: score only the ".map" files, in a
# deterministic (name-sorted) order.
with os.scandir(map_file_dir) as entries:
    map_entries = sorted(
        (entry for entry in entries if entry.is_file() and entry.name.endswith(".map")),
        key=lambda entry: entry.name,
    )

for map_file in map_entries:
    # "2.5.map" -> "2.5": the resolution cutoff the map was generated with.
    resolution = os.path.splitext(map_file.name)[0]
    # A fresh TestModel per map, as in the original.
    # NOTE(review): hoisting it out of the loop would be cheaper if it keeps
    # no per-prediction state. TODO confirm.
    test = model.TestModel(model_dir=model_path, use_cache=False)
    test.make_prediction(map_path=map_file.path, pdb_code=pdb_code)
    test.save_score(score_output_dir, suffix=resolution)

In [None]:
# Plot, per resolution cutoff, the percentage Positive / (Positive + FalseNegative)
# read from the score CSVs written by the previous cell.
res_test_dir = "/home/jordan/dev/sugar_prediction_model/results/1d8g_res_test/base_1.5A_model_1.best.hdf5"

resolutions = []
positives = []

# Only the score CSVs, in a deterministic (name-sorted) order.
with os.scandir(res_test_dir) as entries:
    score_files = sorted(
        (entry for entry in entries if entry.is_file() and entry.name.endswith(".csv")),
        key=lambda entry: entry.name,
    )

for score_file in score_files:
    df = pd.read_csv(score_file.path)

    # Text after the last "_" with the extension stripped, e.g. "..._2.5.csv" -> 2.5.
    # Presumably this is the suffix passed to save_score. TODO confirm naming.
    res = float(os.path.splitext(score_file.name.split("_")[-1])[0])
    positive = df["Positive"].values[0]
    false_negative = df["FalseNegative"].values[0]

    total = positive + false_negative
    if total == 0:
        # numpy division would silently yield nan/inf here; skip explicitly instead.
        logging.warning("Positive + FalseNegative is 0 in %s; skipping", score_file.name)
        continue

    resolutions.append(res)
    positives.append(100 * positive / total)

fig, ax = plt.subplots()
ax.bar(resolutions, positives, width=0.3, color='orange')
ax.set_xlabel("Resolution / A")
ax.set_ylabel("Correct points located / %")
ax.set_title("Correct points located vs. map resolution cutoff (1d8g)")
plt.show()