# Inference 

This notebook allows you to:
- ✅ **Inference**: Inference refers to using a trained model to make predictions on new, unseen data. This is the final step after training and evaluation, where the model is applied to real-world data.  
   - ✅ **Define Data**: Set the batch size for inference. No separate data path is needed, because the tomogram is generated from the data used for training.  
   - ✅ **Choose Model**: Choose a trained model which you would like to use for making predictions.  
   - ✅ **Make Predictions**: Make predictions on the provided data using the selected model.

---  

## Setup and Imports

> *Execute the cell below to import external libraries, which simplify the implementation of the notebook.*

```python
# Import dependencies
from IPython.display import display
from deepEM.Utils import create_text_widget, print_info
from src.Inferencer import Inference
```

## 3.1. Define Data

Unlike other deep learning methods, this approach requires training a separate model for each tomogram you wish to reconstruct. The model is deliberately "overfitted" to the tilt series used for training, so it cannot be reused for other tilt series. Hence, make sure to run `1_Development.ipynb` for each tomogram you want to generate.

As a result, you do not need to provide data for inference: the tomogram is based on the data provided for training in `1_Development.ipynb`.

> *Execute the cell below to display a text form for setting the batch size for inference.*

```python
# Text form for the inference batch size (default: 64)
batch_widget = create_text_widget("Batch Size:", 64, "Please set the batch size for inference. A larger batch size can speed up computation but may cause out-of-memory (OOM) errors.")
display(*batch_widget)
```


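The batch size set above determines how many inputs are passed to the model in a single forward pass. The sketch below illustrates the trade-off from the hint, assuming inference simply iterates over all inputs in fixed-size chunks. `num_batches` and `iter_batches` are hypothetical helpers for illustration, not part of `deepEM` or `src.Inferencer`.

```python
import math
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def num_batches(n_items: int, batch_size: int) -> int:
    """Number of forward passes needed for n_items (the last batch may be smaller)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    return math.ceil(n_items / batch_size)


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most batch_size items (hypothetical helper)."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Doubling the batch size halves the number of forward passes,
# but each pass needs roughly twice as much (GPU) memory.
print(num_batches(1000, 64))   # 16
print(num_batches(1000, 128))  # 8
```

If a run fails with an OOM error, lowering the batch size (e.g. from 64 to 32) is usually the first thing to try.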

> *Execute the cell below to set the batch size according to your input in the text form above.*

```python
batch_size = int(batch_widget[0].value)
print_info(f"Use batch size of {batch_size} for inference.")
```

```text
[INFO]::Use batch size of 64 for inference.
```
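`int(batch_widget[0].value)` raises a bare `ValueError` for non-numeric input and silently accepts `0` or negative values. A minimal sketch of a stricter parser is shown below; `parse_batch_size` is a hypothetical helper and not part of `deepEM.Utils`.

```python
def parse_batch_size(text: str) -> int:
    """Parse the batch-size text field and require a positive integer."""
    try:
        value = int(text)  # int() already tolerates surrounding whitespace
    except ValueError:
        raise ValueError(f"Batch size must be an integer, got {text!r}") from None
    if value <= 0:
        raise ValueError(f"Batch size must be positive, got {value}")
    return value
```

In the notebook it would replace the direct conversion, e.g. `batch_size = parse_batch_size(batch_widget[0].value)`.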


## 3.2. Choose Model

Select the model to use for generating the tomogram. Make sure this model was trained on the tilt series for which you want to generate the tomogram.

> *Execute the cell below to display a text form for entering the path to the trained model to use for inference.*

```python
# Text form for the path to the trained model (checkpoint directory or file)
model_widget = create_text_widget("Model Path:", "", "Enter the path to a trained model (e.g. logs/synthetic_2025-04-01_08-38-23/TrainingRun/checkpoints) which you'd like to use for inference.")
display(*model_widget)
```

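When a `checkpoints` directory is given, the log printed during tomogram generation ("Found model checkpoint at .../checkpoints/best_model.pth") suggests it is resolved to a `best_model.pth` file inside it. The sketch below checks the entered path before running inference, under that assumption. `resolve_checkpoint` is hypothetical and not necessarily how `Inference` locates the checkpoint; only the file name is taken from the log.

```python
from pathlib import Path


def resolve_checkpoint(model_path: str, filename: str = "best_model.pth") -> Path:
    """Return the checkpoint file for model_path.

    Accepts either a checkpoint file or a directory (such as
    .../TrainingRun/checkpoints) that contains `filename`.
    """
    if not model_path.strip():
        raise ValueError("Please enter a model path.")
    path = Path(model_path).expanduser()
    if path.is_file():
        return path
    if path.is_dir():
        candidate = path / filename
        if candidate.is_file():
            return candidate
        raise FileNotFoundError(f"No '{filename}' found in directory {path}")
    raise FileNotFoundError(f"Model path does not exist: {path}")
```

Usage, e.g. `resolve_checkpoint(model_widget[0].value)`, fails early with a readable message instead of failing inside the inference call.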

## 3.3. Make Prediction

For best results, we recommend adjusting the brightness and contrast of the resulting tomogram in [ImageJ](https://imagej.net/ij/). To do so, open the tomogram in ImageJ and press `Ctrl+Shift+C` (or go to `Image` > `Adjust` > `Brightness/Contrast`). This opens a small dialog for adjusting brightness and contrast.

![ImageJ](./images/brightness-contrast.jpg)

Click the `Auto` button, or set brightness and contrast manually to your liking.
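If you prefer to adjust contrast in Python instead of ImageJ, a percentile-based contrast stretch is a simple alternative. This is a sketch assuming the tomogram is already loaded as a NumPy array. It is not ImageJ's `Auto` algorithm, and the 0.5/99.5 percentile defaults are illustrative choices, not values from this notebook.

```python
import numpy as np


def stretch_contrast(volume: np.ndarray, low_pct: float = 0.5, high_pct: float = 99.5) -> np.ndarray:
    """Linearly map the [low_pct, high_pct] percentile range to [0, 1], clipping outliers."""
    lo, hi = np.percentile(volume, [low_pct, high_pct])
    if hi <= lo:  # constant volume: nothing to stretch
        return np.zeros_like(volume, dtype=np.float32)
    out = (volume.astype(np.float32) - lo) / (hi - lo)
    return np.clip(out, 0.0, 1.0)
```

Clipping a small fraction of the darkest and brightest voxels keeps a few outliers from compressing the contrast of the rest of the volume.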

> *Execute the cell below to generate the tomogram from the training data using the model selected above. Results are stored within the data folder used for training; the output path is printed when the cell runs. You can open the tomogram using [ImageJ](https://imagej.net/ij/).*

```python
model_path = model_widget[0].value
# The second argument is None: no separate inference data is required,
# since the tomogram is based on the training data (see 3.1)
inferencer = Inference(model_path, None, batch_size)
inferencer.inference()
```

```text
[INFO]::Found model checkpoint at logs/final/synthetic_2025-04-02_06-27-45/TrainingRun/checkpoints/best_model.pth
[INFO]::Will save results to ./data/synthetic/results-synthetic_2025-04-02_06-27-45/2025-04-07_10-30-49.
Generate Tomogram:   2%|▏         | 21877/1074219 [00:20<15:46, 1111.29it/s]
```
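To picture where the batch size enters during tomogram generation, the sketch below queries a model on every voxel coordinate of a 3D grid, one batch at a time. It assumes a coordinate-based model, i.e. a function mapping (x, y, z) coordinates to intensities. This assumption is not confirmed by this notebook. `predict_volume` and `toy_model` are hypothetical stand-ins, not the code in `src.Inferencer`, and the [-1, 1] coordinate range is an illustrative convention.

```python
from typing import Callable

import numpy as np


def predict_volume(
    model: Callable[[np.ndarray], np.ndarray],
    shape: tuple[int, int, int],
    batch_size: int = 64,
) -> np.ndarray:
    """Evaluate `model` on every voxel coordinate of a (D, H, W) grid, batch by batch."""
    depth, height, width = shape
    # Normalized coordinates in [-1, 1] along each axis (illustrative convention).
    axes = [np.linspace(-1.0, 1.0, n, dtype=np.float32) for n in shape]
    zz, yy, xx = np.meshgrid(*axes, indexing="ij")
    coords = np.stack([xx, yy, zz], axis=-1).reshape(-1, 3)  # (D*H*W, 3)

    values = np.empty(coords.shape[0], dtype=np.float32)
    for start in range(0, coords.shape[0], batch_size):
        stop = start + batch_size
        values[start:stop] = model(coords[start:stop])  # one forward pass per batch
    return values.reshape(depth, height, width)


# Toy stand-in for a trained network: a soft sphere centred in the volume.
def toy_model(xyz: np.ndarray) -> np.ndarray:
    return np.exp(-4.0 * np.sum(xyz ** 2, axis=1))


volume = predict_volume(toy_model, shape=(16, 32, 32), batch_size=64)
print(volume.shape)  # (16, 32, 32)
```

The result does not depend on the batch size. Only the number of forward passes and the peak memory per pass change.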