# IFIN WallerDataset Usage Template

This notebook is a general-purpose template for running IFIN on WallerDataset.
- Uses the same WallerDataset class and PSF construction (`_build_psf` from `runner`) as the training code.
- Runs a single forward pass as an example; full training is not required. No checkpoint is loaded, so the outputs come from freshly constructed model weights.
- Record your own qualitative or quantitative results in the result cells below.

In [None]:
from pathlib import Path
import sys

def find_project_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* that contains ``src/`` and ``configs/``.

    The search begins at ``start.resolve()`` and walks upward through its parents,
    so the notebook works when launched from any subdirectory of the project.

    Raises:
        RuntimeError: if no directory between *start* and the filesystem root
            contains both a ``src`` and a ``configs`` directory.
    """
    current = start.resolve()
    for candidate in (current, *current.parents):
        # is_dir() rather than exists(): a stray *file* named src or configs
        # must not be mistaken for the project layout.
        if (candidate / 'src').is_dir() and (candidate / 'configs').is_dir():
            return candidate
    raise RuntimeError(
        'Could not find project root containing src/ and configs/ '
        f'(searched upward from {current}).'
    )

PROJECT_ROOT = find_project_root(Path.cwd())
SRC_ROOT = PROJECT_ROOT.joinpath('src')

# Put src/ at the front of sys.path so the project packages imported in the
# next cell (config, models, data, runner) resolve. The membership check keeps
# re-runs of this cell from stacking duplicate entries.
if not str(SRC_ROOT) in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

print('PROJECT_ROOT =', PROJECT_ROOT)
print('SRC_ROOT =', SRC_ROOT)

In [None]:
import torch
from torchvision import transforms

from config import load_config
from models.ifin import IFINNet, RB
from data.waller import WallerDataset
from runner import _build_psf

config = load_config(str(PROJECT_ROOT.joinpath('configs', 'default.yaml')))
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device =', device)

# Uncomment and edit these if your Waller data lives somewhere else
# (the values shown are example relative paths):
# config['data']['waller_path'] = '../../wallerlab/dataset'
# config['data']['psf_path'] = '../../wallerlab/dataset/psf.tiff'

# Select the Waller dataset. Both the train and eval sections are set to zero workers.
# NOTE(review): presumably this is so any DataLoader built from this config loads
# in the main process (safer inside notebooks); confirm against runner.
config['data']['dataset'] = 'waller'
config['train']['num_workers'] = config['eval']['num_workers'] = 0

In [None]:
# Measurements and labels both go through ToTensor only; no other preprocessing
# is applied in this notebook.
transform_raw = transforms.Compose([transforms.ToTensor()])
transform_lab = transforms.Compose([transforms.ToTensor()])

waller_path = config['data']['waller_path']
dataset = WallerDataset(
    waller_path,
    train=False,
    transform_raw=transform_raw,
    transform_lab=transform_lab,
)
print('dataset length =', len(dataset))

# Fail with an actionable message instead of an opaque error from dataset[0]
# when the configured path is wrong or the train=False split is empty.
if len(dataset) == 0:
    raise RuntimeError(
        f"WallerDataset(train=False) at {waller_path!r} is empty; "
        "set config['data']['waller_path'] in the config cell above."
    )

# Each item unpacks into two tensors, used below as (measurement, target).
meas_input, img_target = dataset[0]
print('sample measurement shape =', tuple(meas_input.shape))
print('sample target shape =', tuple(img_target.shape))

In [None]:
psf = _build_psf(config, device)

# All constructor arguments except the PSF tensor and the block class (RB)
# come from the `model` section of the config. No checkpoint is loaded in this
# cell, so the forward pass below runs on freshly constructed weights.
model_cfg = config['model']
model = IFINNet(
    in_channels=model_cfg['in_channels'],
    out_channels=model_cfg['out_channels'],
    psf=psf,
    height=model_cfg['height'],
    width=model_cfg['width'],
    dim=model_cfg['dim'],
    depth=model_cfg['depth'],
    block_cls=RB,
    exchange=model_cfg['exchange'],
    k=model_cfg['k'],
    repeat=model_cfg['repeat_psf'],
    random=model_cfg['random_init_psf'],
).to(device)

# Inference only: eval() puts modules in evaluation mode, and no_grad() skips
# autograd bookkeeping.
model.eval()
with torch.no_grad():
    # unsqueeze(0) adds a leading batch dimension of size 1 to the single sample.
    input_batch = meas_input.unsqueeze(0).to(device)
    img_inverse, img_forward, iso_out = model(input_batch)

print('img_inverse shape =', tuple(img_inverse.shape))
print('img_forward shape =', tuple(img_forward.shape))
print('iso_out shape =', tuple(iso_out.shape))

## Result Notes (Fill Manually)

- Date:
- Commit:
- Config overrides:
- Observation summary:
- Any known issues:

In [None]:
# Optional: add your own visualization or metric code below, e.g. save the
# outputs from the forward-pass cell (img_inverse, img_forward, iso_out) or
# compute PSNR/SSIM against img_target.
# NOTE(review): confirm which output is the reconstruction before comparing.
pass