# Denoising-Reconstruction Autoencoder (DRACO) Denoising Visualizer
In this demo, we will show how to denoise micrographs using a pretrained DRACO model.

### Setup

In [None]:
from pathlib import Path
import sys

import matplotlib.pyplot as plt
import numpy as np
from omegaconf import DictConfig
from PIL import Image
import torch
import torchvision.transforms.v2 as v2
from process import preprocess

sys.path.append(str(Path.cwd().parent))

from draco.configuration import CfgNode
from draco.model import (
    build_model,
    load_pretrained
)

In [None]:
# Type alias for values the plotting helpers accept: a NumPy array or a PIL image.
ImageType = np.ndarray | Image.Image

def set_one_image(image: ImageType, title: str = "") -> None:
    """Draw ``image`` on the current matplotlib axes as a grayscale picture.

    The pixel values are min-max normalised to ``[0, 1]`` before plotting.
    The figure is neither created nor shown here; the caller does both.

    Args:
        image: Image to draw, as a NumPy array or a PIL image.
        title: Title placed above the axes.
    """
    pixels = np.asarray(image)

    lowest = pixels.min()
    value_range = pixels.max() - lowest
    if value_range == 0:
        # A constant image has no contrast to stretch; dividing by the zero
        # range would fill the array with NaN and emit a runtime warning.
        normalized = np.zeros(pixels.shape, dtype=np.float32)
    else:
        normalized = (pixels - lowest) / value_range

    plt.imshow(normalized, cmap='gray')
    plt.title(title)
    plt.axis("off")

def show_one_image(image: ImageType, title: str = "") -> None:
    """Open a new 5x5-inch figure, draw ``image`` in it and display it.

    Args:
        image: Image to display, as a NumPy array or a PIL image.
        title: Title placed above the image.
    """
    figure_size = (5, 5)
    plt.figure(figsize=figure_size)
    set_one_image(image, title=title)
    plt.show()

def show_denoising(origin: ImageType, denoised: ImageType) -> None:
    """Plot the input, the denoised output and their difference side by side.

    Args:
        origin: Input image; it must be broadcast-compatible with a single
            ``(H, W)`` plane of ``denoised`` for the difference panel.
        denoised: Denoised result, either a 2-D ``(H, W)`` image or an array
            with a trailing channel axis, in which case channel 0 is used
            for the difference panel.
    """
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    set_one_image(origin, "Origin")

    plt.subplot(1, 3, 2)
    set_one_image(denoised, "Denoised")

    plt.subplot(1, 3, 3)
    denoised_plane = np.asarray(denoised)
    if denoised_plane.ndim > 2:
        # Only strip the channel axis when there is one; a plain 2-D result
        # (e.g. a single-channel PIL image) is used as is instead of raising
        # IndexError.
        denoised_plane = denoised_plane[:, :, 0]
    difference = np.asarray(origin) - denoised_plane
    set_one_image(difference, "Difference")

    plt.show()

def set_seed(seed: int) -> None:
    """Seed the NumPy, PyTorch and PyTorch-CUDA random number generators.

    Args:
        seed: Value handed to each generator's seeding function.
    """
    # Same three seeders, applied in the same order.
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

### Visualizer
`DRACODenoiser` performs denoising on one micrograph at a time and then plots the result. To run inference on batched inputs or save the denoised micrograph, you can modify the `inference` function within `DRACODenoiser`.

In [None]:
class DRACODenoiser:
    """Denoise one image at a time with a pretrained DRACO model and plot it.

    The model is built from the configuration, moved to CUDA when available
    (CPU otherwise), put in eval mode and then passed through
    ``load_pretrained`` together with the checkpoint path.
    """

    def __init__(self,
        cfg: DictConfig,
        ckpt_path: Path,
    ) -> None:
        """Build the model, load its weights and record the patch size.

        Args:
            cfg: Model configuration; ``cfg.MODEL.PATCH_SIZE`` is read here
                and the whole object is handed to ``build_model``.
            ckpt_path: Checkpoint path handed to ``load_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.transform = self.build_transform()
        self.model = build_model(cfg).to(self.device).eval()
        self.model = load_pretrained(self.model, ckpt_path, self.device)
        self.patch_size = cfg.MODEL.PATCH_SIZE

    def patchify(self, image: torch.Tensor) -> torch.Tensor:
        """Split a ``(B, C, H, W)`` batch into flattened square patches.

        ``H`` and ``W`` are zero-padded at the bottom/right up to the next
        multiple of the patch size ``P`` when necessary.

        Args:
            image: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, L, P * P * C)`` where ``L`` is the number
            of patches; within a patch the channel axis varies fastest.
        """
        B, C, H, W = image.shape
        P = self.patch_size

        pad_bottom = (P - H % P) % P
        pad_right = (P - W % P) % P
        if pad_bottom or pad_right:
            # F.pad takes (left, right, top, bottom) for the last two axes.
            image = torch.nn.functional.pad(
                image, (0, pad_right, 0, pad_bottom), mode='constant', value=0
            )

        # (B, C, H/P, W/P, P, P) -> (B, H/P, W/P, P, P, C) -> (B, L, P*P*C)
        patches = image.unfold(2, P, P).unfold(3, P, P)
        patches = patches.permute(0, 2, 3, 4, 5, 1)
        patches = patches.reshape(B, -1, P * P * C)
        return patches

    def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Reassemble flattened patches into images; inverse of ``patchify``.

        Args:
            patches: Tensor whose first axis is the batch and whose remaining
                elements are patches laid out as produced by ``patchify``.
            H: Height of the image before any padding.
            W: Width of the image before any padding.

        Returns:
            Tensor of shape ``(B, C, H, W)``; rows and columns that exist
            only because of padding are cropped away.
        """
        B = patches.shape[0]
        P = self.patch_size

        # Number of patches along each axis, rounding up for padded inputs.
        rows = (H + P - 1) // P
        cols = (W + P - 1) // P

        # (B, rows, cols, P, P, C) -> (B, C, rows, P, cols, P) -> (B, C, rows*P, cols*P)
        images = patches.reshape(B, rows, cols, P, P, -1)
        images = images.permute(0, 5, 1, 3, 2, 4)
        images = images.reshape(B, -1, rows * P, cols * P)
        images = images[..., :H, :W]
        return images

    @classmethod
    def build_transform(cls) -> v2.Compose:
        """Return the input transform: ``ToImage`` then float32 with scaling enabled."""
        return v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True)
        ])

    @torch.inference_mode()
    def inference(self, image: Image.Image) -> None:
        """Denoise one PIL image and plot input, output and difference.

        Args:
            image: Image to denoise; its ``size`` gives ``(W, H)``.
        """
        W, H = image.size

        x = self.transform(image).unsqueeze(0).to(self.device)
        # NOTE(review): assumes the model returns flattened patches in the
        # layout produced by `patchify` -- confirm against the model code.
        y = self.model(x)

        # The input used to be patchified and copied to the CPU here, but the
        # result was never read; that dead work has been removed.
        # (B, C, H, W) -> (H, W, C) NumPy array for plotting.
        denoised = self.unpatchify(y, H, W).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

        show_denoising(image, denoised)
        

### Build Visualizer
To build the visualizer, provide the model parameter `.yaml` file and the corresponding checkpoint. By default, the model used is `DRACO-base`. To switch to `DRACO-large`, simply change the `VIT_SCALE` parameter in `denoise.yaml` to `large`. Note that the `large` model may require a graphics card with more than 16 GB of video memory for inference.

In [None]:
# Load the model configuration from denoise.yaml (path is relative to the working directory).
cfg = CfgNode.load_yaml_with_base(Path("denoise.yaml"))
# An empty override list is merged here; presumably entries are OmegaConf-style
# "KEY=value" strings -- confirm against CfgNode.merge_with_dotlist.
CfgNode.merge_with_dotlist(cfg, [])
# Placeholder: replace with the path to the pretrained checkpoint.
ckpt_path = Path("CHECKPOINT_PATH")
# Builds the model and loads the checkpoint; reused by the inference cell.
visualizer = DRACODenoiser(cfg, ckpt_path)

### Load Data
The network input should be normalized micrographs. By default, our data is in `.h5` format. In our customized `.h5` data format, the mean and standard deviation of the micrograph are pre-calculated and stored in the header, allowing direct normalization of the data. For raw `.mrc` files, we have also implemented an input processing function for you.

In [None]:
# h5 file
import h5py
# Placeholder: replace with the path to the .h5 micrograph.
img_path = "H5_FILE_PATH"
with h5py.File(img_path, 'r') as hdf5_file:
    full_micrograph = hdf5_file["micrograph"]
    # Read the dataset from disk once and reuse it for the statistics
    # fallbacks and for the normalisation below.
    img = full_micrograph[:].astype(np.float32)
    # Prefer the statistics stored as dataset attributes; compute them from
    # the pixel data only when an attribute is missing.
    mean = full_micrograph.attrs["mean"] if "mean" in full_micrograph.attrs else img.mean()
    std = full_micrograph.attrs["std"] if "std" in full_micrograph.attrs else img.std()
    # Standardise to zero mean and unit standard deviation.
    img = (img - mean) / std


In [None]:
# mrc file
import mrcfile as mrc
# Placeholder: replace with the path to a raw .mrc micrograph.
img_path = "YOUR_MRC_FILE_PATH"
# NOTE(review): permissive=True presumably lets mrcfile open files with
# non-standard headers instead of raising -- confirm against the mrcfile docs.
with mrc.open(img_path, permissive=True) as m:
    # Copy the pixel data out of the file object and convert it to float32.
    img = m.data.copy().astype(np.float32)
    # `preprocess` comes from the local `process` module and returns an image
    # plus the mean and std used below -- NOTE(review): confirm what
    # preprocessing it applies to the image.
    img, mean, std = preprocess(img)
    # Standardise with the returned statistics.
    img = (img - mean) / std

### Inference
Set the mask ratio to 0 during inference. Since inference on a full-resolution bin 1 micrograph can be time-consuming, a 1024×1024 crop is applied in this demo, though you may input the entire micrograph if desired.

Note that the input micrograph’s height and width must be multiples of the patch size, which is 16 in our models. If the micrograph dimensions do not meet this requirement, apply padding or cropping as needed.

In [None]:
# Seed the NumPy and PyTorch random number generators before running the model.
set_seed(0)
# Take the 1024x1024 window at rows/columns 1024:2048, wrap it in a PIL image
# and run the denoiser, which plots the input, the output and their difference.
visualizer.inference(Image.fromarray(img[1024:2048,1024:2048]))