<a href="https://colab.research.google.com/github/Structurebiology-BNL/ResEM/blob/main/Colab_ResEM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Upload your half maps or use half maps from EMDB: {display-mode: "form"}
#@markdown - You will be first asked to mount the Google drive with your google account.
from google.colab import drive
drive.mount('/content/drive')
from google.colab import files
#@markdown - Upload could take a while depending on the size of the map
#@markdown  - Please upload one half map at a time
#@markdown  - Cancel upload to download EMDB map instead
half_map_1, half_map_2 = None, None
# Start from an empty list (not None) so that cancelling the first upload and
# providing only the second one does not fail on `+=`. An empty list is still
# falsy, so later `if uploaded_map:` checks behave as before.
uploaded_map = []
print("\nChoose the first half map to upload:")
half_map_1 = files.upload()
if half_map_1:
    uploaded_map += list(half_map_1.keys())
print("\nChoose the second half map to upload:")
half_map_2 = files.upload()
if half_map_2:
    uploaded_map += list(half_map_2.keys())
if uploaded_map:
    print("\n Uploading finished")
#@markdown - Alternatively, you can choose to download half-maps from EMDB instead. {display-mode: "form"}
#@markdown  - Make sure that the corresponding EMDB entry has half maps
EMDB_ID = 23274  #@param {type: "number"}
#@markdown After starting the session, make sure to change the `runtime type` in the `Runtime` menu to use GPU for faster performance.

In [30]:
#@title Install and import packages {display-mode: "form"}
#@markdown Run this cell and Colab will: 
#@markdown - Install necessary python packages
#@markdown - Download ResEM model
#@markdown - Preprocess the uploaded maps or download maps from EMDB
#%%capture --no-stderr --no-display
!pip install -q mrcfile torchio gemmi
!pip install -q torch torchvision -f https://download.pytorch.org/whl/cu116/torch_stable.html
print("Package installazation finished")
!git clone https://github.com/Structurebiology-BNL/ResEM
%cd ResEM
!wget -O model_weights.pt 'https://docs.google.com/uc?export=download&id=1hCaEbYxQV56JIpN2c2iJSiiKAgRu7TT6&confirm=t'
print("ResEM download finished")
import torch
import torchio
from tqdm import tqdm
import mrcfile
from utils.utils import download_half_maps
from models.map_splitter import reconstruct_maps, map_resample
from models.unet import UNetRes

# Build the network input either from the half map(s) uploaded in the first
# cell or, when nothing was uploaded, from half maps downloaded from EMDB.
# On success this defines raw_map, input_data and meta_data for the next cells.
# Failures raise instead of only printing, so that later cells never run on
# missing variables or on stale data left over from a previous run.
if uploaded_map:
    # NOTE(review): assumes the uploads were saved under /content (the working
    # directory at upload time) - confirm if the upload cell changes directory.
    file_path_1 = "/content/" + uploaded_map[0]
    if len(uploaded_map) == 1:
        print(
            "only one half map is uploaded, will proceed to enhance a single half map"
        )
        input_map = mrcfile.open(file_path_1, mode="r")
        raw_map, input_data, meta_data = map_resample(input_map)
    elif len(uploaded_map) == 2:
        file_path_2 = "/content/" + uploaded_map[1]
        # Separate names: half_map_1/half_map_2 hold the upload results from the
        # first cell and are not reused here for the opened MRC files.
        half_mrc_1 = mrcfile.open(file_path_1, mode="r")
        half_mrc_2 = mrcfile.open(file_path_2, mode="r")
        raw_map, input_data, meta_data = map_resample(
            half_mrc_1, input_map_2=half_mrc_2
        )
    else:
        raise ValueError("no valid half maps are provided")
else:
    # The form field is a generic number, so reject fractional or non-numeric
    # values explicitly before using the ID in the download and the file name.
    if isinstance(EMDB_ID, float) and not EMDB_ID.is_integer():
        raise ValueError("EMDB ID must be an integer")
    try:
        emdb_id = int(EMDB_ID)
    except (TypeError, ValueError):
        raise ValueError("EMDB ID must be an integer") from None
    download_succesful = download_half_maps(emdb_id)
    if not download_succesful:
        raise RuntimeError("No valid half maps found for EMDB-{}".format(emdb_id))
    print("Read downloaded map EMDB-{} and resample".format(emdb_id))
    input_map = mrcfile.open("averaged_map_{}.ccp4".format(emdb_id), mode="r")
    raw_map, input_data, meta_data = map_resample(input_map)
print("\nResampled map ready")

In [None]:
# @title Run prediction {display-mode: "form"}
# @markdown - The prediction should be done in less than a minute if using GPU
# Always a torch.device, whichever backend is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetRes(n_blocks=2, act_mode="R")
# Load onto the CPU first so the checkpoint also opens on CPU-only runtimes.
checkpoint = torch.load("/content/ResEM/model_weights.pt", map_location="cpu")
model.load_state_dict(checkpoint)
model = model.to(device)

batch_size = 16
torch.backends.cudnn.benchmark = True
model.eval()
with torch.no_grad():
    print("Start prediction...")
    # Collect per-batch outputs and concatenate once at the end instead of
    # re-copying the growing result tensor on every iteration.
    pred_batches = []
    for indx in tqdm(range(0, input_data.shape[0], batch_size)):
        x_partial = input_data[indx : indx + batch_size].unsqueeze(dim=1).to(device)
        y_pred_partial = model(x_partial)
        pred_batches.append(y_pred_partial.squeeze(dim=1).detach().cpu())
    # Keep the original empty-tensor result when there is nothing to predict.
    y_pred = torch.cat(pred_batches, dim=0) if pred_batches else torch.tensor(())
    # NOTE(review): the cell lengths are used as the voxel grid shape; this
    # presumably relies on map_resample producing 1 unit-per-voxel maps - confirm.
    original_shape = (
        int(meta_data.cella.x),
        int(meta_data.cella.y),
        int(meta_data.cella.z),
    )
    y_pred_recon = reconstruct_maps(
        y_pred.numpy(),
        original_shape,
    )
    # overwrite=True: mrcfile.new refuses to replace an existing file by
    # default, which made this cell fail when run a second time.
    with mrcfile.new("pred_map.mrc", overwrite=True) as mrc:
        mrc.set_data(y_pred_recon)
        # Carry the cell dimensions and grid origin over from the input map.
        mrc.header.cella.x = meta_data.cella.x
        mrc.header.cella.y = meta_data.cella.y
        mrc.header.cella.z = meta_data.cella.z
        mrc.header.nxstart = meta_data.nxstart
        mrc.header.nystart = meta_data.nystart
        mrc.header.nzstart = meta_data.nzstart
    print("\nPrediction done!")

In [None]:
#@title Visualization {display-mode: "form"}
#@markdown - The upper rows are the raw map (resampled and normalized to [0, 1])
#@markdown - The lower rows are the enhanced map
# Work on new names instead of overwriting raw_map: the original reassignment
# turned raw_map into a ScalarImage, so running this cell twice failed.
raw_min = raw_map.min()
raw_range = raw_map.max() - raw_min
# A constant map has zero range; show it as all zeros rather than dividing by 0.
raw_map_norm = (raw_map - raw_min) / raw_range if raw_range > 0 else raw_map - raw_min
raw_image = torchio.ScalarImage(tensor=torch.from_numpy(raw_map_norm).unsqueeze(dim=0))
raw_image.plot(
    figsize=(12, 9),
    radiological=False,
)
pred_map = torchio.ScalarImage(tensor=torch.from_numpy(y_pred_recon).unsqueeze(dim=0))
pred_map.plot(
    figsize=(12, 9),
    radiological=False,
)

In [None]:
#@title Download enhanced map {display-mode: "form"}
#@markdown Please wait for a while for the pop up window of the browser to show up 
# Send the enhanced map to the browser. The prediction cell writes pred_map.mrc
# to its working directory, which is expected to be /content/ResEM.
enhanced_map_path = "/content/ResEM/pred_map.mrc"
files.download(enhanced_map_path)