# DeepDeWedge Tutorial

This is a minimal example of how to apply DeepDeWedge.

In [None]:
import torch
import math
from dataset import setup_fitting_and_val_dataset
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from utils.mrctools import load_mrc_data
from matplotlib import pyplot as plt

from model import Unet3D
from utils.fitting import masked_loss, get_avg_model_input_mean_and_var
from utils.mrctools import save_mrc_data
from utils.visualization import plot_tomo_slices
import datetime
import shutil
from utils.missing_wedge import fft_3d
from utils.dataloader import MultiEpochsDataLoader
from torchsummary import summary
from refine_tomogram import refine_tomogram

## Download the tutorial dataset
We apply DeepDeWedge to the Wiener-Filter CTF corrected FBP reconstruction of tilt series 05 of EMPIAR-10045. First, we download the tutorial dataset which contains FBP reconstructions from the even, odd and full tilt series. We binned these reconstructions by a factor of 6 using average pooling, which results in a physical voxel size of 13.02 Angstroms. For CTF correction, we used the Wiener-like filter implemented in IsoNet (https://github.com/IsoNet-cryoET/IsoNet).

The following two lines of code download the data as a .zip archive and unzip it. Upon successful completion, you will find a new subdirectory `tutorial_data` in the current working directory.


In [None]:
# Download the tutorial data from Google Drive with wget.
# The inner wget call (inside `$(...)`) requests the download page, stores the session cookies in
# /tmp/cookies.txt and pipes the page to sed, which extracts the `confirm=` token from it.
# The outer wget call re-uses the cookies and the token to save the archive as tutorial_data.zip.
# The cookie file is removed afterwards if the download succeeded.
# NOTE: `--no-check-certificate` disables TLS certificate verification for the inner request.
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1_yuI2Xu2ISnuBKT3FS_9cqdAXvC2Shh8' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1_yuI2Xu2ISnuBKT3FS_9cqdAXvC2Shh8" -O tutorial_data.zip && rm -rf /tmp/cookies.txt
# Unpack the downloaded archive into the current directory.
!unzip tutorial_data.zip

In [None]:
import os  # `os` is not imported in the notebook's import cell, so import it here

# List the contents of the unpacked archive to check that the download succeeded.
print(os.listdir("tutorial_data"))

In [None]:
# Load the FBP reconstruction of the full tilt series, plot slices of it
# and report its shape. `tomo_full` is refined at the end of the tutorial.
full_tomo_file = "./tutorial_data/IS002_291013_006_FullFBP_Bin6_Wiener.mrc"
tomo_full = load_mrc_data(full_tomo_file)

plot_tomo_slices(tomo_full, figsize=(20, 10))
print(tomo_full.shape)

## Setup datasets for model fitting
We first load the dataset for model fitting. The function `setup_fitting_and_val_dataset` returns two torch datasets, one for model fitting and one for validation to prevent overfitting. The tomograms corresponding to the filepaths in `tomo0_files` are used to construct the model inputs, while the ones in `tomo1_files` are used to construct the targets. 

Both the fitting and the validation dataset return model inputs and targets with shape `subtomo_size x subtomo_size x subtomo_size`. These subtomograms are extracted from the tomograms using x, y and z direction strides specified in `extraction_strides`. To reduce RAM consumption during model fitting, all subtomograms are saved to the directory `save_subtomos_to`, which is created if it does not exist.

The number of elements in the validation set is at most `val_fraction` times the number of total extracted subtomograms. It may also contain fewer subtomograms since we randomly sample the validation subtomograms such that they have no overlap with the ones used for model fitting. If the sampling procedure was unable to sample enough validation subtomograms, the function prints a warning.

In [None]:
# Paths used in this cell.
subtomo_dir = "./subtomos/"
even_tomo_file = "./tutorial_data/IS002_291013_006_EvenFBP_Bin6_Wiener.mrc"
odd_tomo_file = "./tutorial_data/IS002_291013_006_OddFBP_Bin6_Wiener.mrc"

# Remove subtomograms left over from a previous run; a missing directory is not an error.
shutil.rmtree(subtomo_dir, ignore_errors=True)

# The even-tilt reconstruction provides the model inputs (`tomo0_files`),
# the odd-tilt reconstruction provides the targets (`tomo1_files`).
fitting_dataset, val_dataset = setup_fitting_and_val_dataset(
    tomo0_files=[even_tomo_file],
    tomo1_files=[odd_tomo_file],
    subtomo_size=80,
    extraction_strides=[40, 40, 40],
    mw_angle=60,
    val_fraction=0.2,
    save_subtomos_to=subtomo_dir,
)

num_fitting_subtomos = len(fitting_dataset)
num_val_subtomos = len(val_dataset)
print(f"Number of subtomograms for model fitting: {num_fitting_subtomos}")
print(f"Number of subtomos for validation: {num_val_subtomos}")

Both the fitting and the validation dataset return dictionaries containing the following items:
* `model_input`: A model input $\tilde{\mathbf{v}}_\varphi^0$ with two missing wedges
* `model_target`: A model target $\tilde{\mathbf{v}}_\varphi^1$ with only one missing wedge
* `rot_mw_mask`: The rotated missing wedge mask $\mathbf{M}_\varphi$
* `mw_mask`: The original missing wedge mask $\mathbf{M}$. This mask is the same for all elements in the dataset.

**Note**: The rotation angles $\varphi$ in the training set are always random, and are re-sampled every time an item is queried. For the validation dataset, we only sample random rotation angles once and every item always has its fixed rotation.


Let's now have a look at the real and Fourier domain representation of some of the model inputs:

In [None]:
# Plot the first three model inputs of the fitting dataset after removing their mean.
for k in range(3):
    model_input = fitting_dataset[k]["model_input"]
    # Subtract the mean out of place so the tensor returned by the dataset is never
    # modified, and plot exactly the tensor that was centered.
    plot_tomo_slices(model_input - model_input.mean())

We create a fitting and a validation dataloader which return batches of elements from the fitting and validation sets:

In [None]:
batch_size = 5
num_workers = 10  # number of worker processes per dataloader; reduce this on machines with few CPU cores

# Shuffle the fitting set so that the composition and order of the batches change from epoch to epoch.
fitting_dataloader = torch.utils.data.DataLoader(
    dataset=fitting_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    pin_memory=True,
)
# The validation set is iterated in a fixed order.
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=True,
)

## Fit the model

First, we set up a 3D U-Net. The architecture we implemented in `unet.py` is the same as the one used in the official IsoNet implementation.
Before we can start model fitting, we calculate the average mean and variance of the model inputs in the fitting dataset. We use these values to normalize the input to the U-Net during model fitting.

In [None]:
# Hyperparameters of the 3D U-Net; `LitUnet3D` forwards them to `Unet3D` as keyword arguments.
unet_params = dict(
    in_chans=1,
    out_chans=1,
    chans=64,
    num_pool_layers=3,
    drop_prob=0.0,
)

# Average mean and variance of the model inputs, computed over three passes
# through the fitting dataloader.
avg_model_input_mean, avg_model_input_var = get_avg_model_input_mean_and_var(
    fitting_dataloader,
    batches=3 * len(fitting_dataloader),
    verbose=True,
)

# The U-Net normalizes its input with the average mean and standard deviation.
unet_params.update(
    normalization_loc=avg_model_input_mean,
    normalization_scale=math.sqrt(avg_model_input_var),
)

For model fitting, we use the PyTorch lightning framework for convenience. Below, we define the class `LitUnet3D`. The class takes parameters for a U-Net and the `torch.optim.Adam` optimizer as input and can then be used to fit the U-Net. The important methods of this class are:
* `training_step`: This step handles passing the model inputs provided by the fitting dataloader through the model and calculating the loss. 
* `validation_step`: In this step, we can implement any validation routine we like. For simplicity, we just calculate the loss on the validation set to monitor overfitting. Depending on which logger we use, we can also log plots of the model output.

In [None]:
class LitUnet3D(pl.LightningModule):
    """Lightning module that fits a `Unet3D` with the masked loss and the Adam optimizer.

    Args:
        unet_params: Keyword arguments used to construct the `Unet3D`.
        adam_params: Keyword arguments passed to `torch.optim.Adam`.
    """

    def __init__(self, unet_params, adam_params):
        super().__init__()
        self.unet_params = unet_params
        self.adam_params = adam_params
        self.unet = Unet3D(**unet_params)
        self.save_hyperparameters()

    def forward(self, x):
        """Apply the U-Net to a batch `x` that has no channel dimension."""
        with_channel = x.unsqueeze(1)  # insert the channel dimension at position 1
        prediction = self.unet(with_channel)
        return prediction.squeeze(1)  # drop the channel dimension again

    def _masked_loss_on_batch(self, batch):
        """Pass the model inputs of `batch` through the model and return the masked loss."""
        return masked_loss(
            model_output=self(batch["model_input"]),
            target=batch["model_target"],
            rot_mw_mask=batch["rot_mw_mask"],
            mw_mask=batch["mw_mask"],
        )

    def training_step(self, batch, batch_idx):
        """Compute, log and return the loss on a batch of the fitting set."""
        loss = self._masked_loss_on_batch(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the loss on a batch of the validation set."""
        loss = self._masked_loss_on_batch(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, configured with `adam_params`."""
        return torch.optim.Adam(self.parameters(), **self.adam_params)

In [None]:
# Wrap the U-Net in the Lightning module; Adam is configured with a learning rate of 4e-4.
adam_params = {"lr": 4e-4}
unet = LitUnet3D(unet_params=unet_params, adam_params=adam_params)

PyTorch Lightning's ``Trainer`` will do the heavy lifting for us. In this tutorial, we only specify the bare minimum of parameters such as the number of epochs for model fitting `max_epochs`, and the GPU used for fitting. 

In [None]:
# Fit on the GPU with index 1 when several GPUs are visible; on a single-GPU
# machine index 1 does not exist, so fall back to index 0.
gpu_index = 1 if torch.cuda.device_count() > 1 else 0

trainer = pl.Trainer(
    max_epochs=50,  # you can fit for longer but 50 epochs should already yield a decent result
    accelerator="gpu",
    devices=[gpu_index],
    check_val_every_n_epoch=1,  # evaluate the validation loss after every epoch
    logger=pl.loggers.CSVLogger("csv_logs", name="tutorial"),  # the logger creates a folder "csv_logs" and saves all logs as csv files there
)
trainer.fit(unet, fitting_dataloader, val_dataloader)

## Refine the full FBP reconstruction

In [None]:
# Refine on the GPU with index 1 when several GPUs are visible; on a single-GPU
# machine index 1 does not exist, so fall back to index 0.
gpu_index = 1 if torch.cuda.device_count() > 1 else 0

# Move both the tomogram and the fitted model to the same GPU before refinement.
tomo_ref = refine_tomogram(
    tomo=tomo_full.cuda(gpu_index),
    lightning_model=unet.cuda(gpu_index),
    subtomo_size=80,
    extraction_strides=[40, 40, 40],
    batch_size=10,
)

# Move the refined tomogram back to the CPU for plotting.
plot_tomo_slices(tomo_ref.cpu(), figsize=(20, 10))