In [None]:
from pathlib import Path

import numpy as np
import tifffile
import tomopy
import tomosipo as ts
import torch
from msd_pytorch import MSDRegressionModel
from torch.utils.data import DataLoader
from tqdm import tqdm

from noise2inverse import tiffs, noise, fig
from noise2inverse.datasets import (
    TiffDataset,
    Noise2InverseDataset,
)

In [None]:
# Parameters
# Reconstructions of each split, read as `<train_dir>/<split index>/*.tif`.
train_dir = Path("reconstructions")
# Directory that receives the saved network weights.
output_dir = Path("weights")

# Number of splits; one sub-directory of `train_dir` is read per split.
num_splits = 4
# Noise2Inverse pairing strategy, passed to Noise2InverseDataset.
strategy = "X:1"
# Nominal epoch count; the training cell divides it by `num_splits`.
epochs = 100
batch_size = 16
# Passed as `parallel=` to MSDRegressionModel (multi-GPU training).
multi_gpu = True
# Scale pixel intensities during training such that its values roughly occupy the range [0,1].
# This improves convergence.
data_scaling = 200
# Seed the RNGs so that network initialisation and batch shuffling are reproducible
# under Restart & Run All (some GPU kernels may still be non-deterministic).
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# One TiffDataset per split: split k is read from <train_dir>/<k>/*.tif.
datasets = [
    TiffDataset(train_dir / str(split_idx) / "*.tif")
    for split_idx in range(num_splits)
]
# Combine the per-split datasets into input/target training pairs.
train_ds = Noise2InverseDataset(*datasets, strategy=strategy)

# Display (number of slices, number of splits) as a quick sanity check.
train_ds.num_slices, train_ds.num_splits

In [None]:
# Show the first training pair side by side. The sample is fetched once so both
# panels come from the same __getitem__ call (and the TIFFs are read only once).
# Intensities here are unscaled (data_scaling is not applied), hence the small vmax.
sample_input, sample_target = train_ds[0]
fig.plot_imgs(
    input=sample_input.detach().squeeze(),
    target=sample_target.detach().squeeze(),
    vmin=0,
    vmax=0.008,
)

In [None]:
# Dataloader and network:
# pin_memory=True places batches in page-locked host memory, which is required for the
# `.cuda(non_blocking=True)` copies in the training loop to actually be asynchronous.
dl = DataLoader(train_ds, batch_size, shuffle=True, pin_memory=True)
# Positional arguments are presumably (c_in, c_out, depth, width) per msd_pytorch — confirm:
# 1 input channel, 1 output channel, depth 100, width 1.
model = MSDRegressionModel(1, 1, 100, 1, parallel=multi_gpu)

In [None]:
# Create the weights directory, including any missing parent directories, so that a
# nested `output_dir` also works; an already-existing directory is not an error.
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# The dataset contains multiple input-target pairs for each slice.
# Therefore, we divide by the number of splits to obtain the effective number of epochs.
# max(..., 1) guarantees at least one epoch even when epochs < num_splits.
train_epochs = max(epochs // num_splits, 1)

# Training loop
for epoch in range(train_epochs):
    for inp, tgt in tqdm(dl, desc=f"epoch {epoch + 1}/{train_epochs}"):
        # Move the batch to the GPU and rescale intensities to roughly [0, 1]
        # (see `data_scaling` in the parameters cell).
        inp = inp.cuda(non_blocking=True) * data_scaling
        tgt = tgt.cuda(non_blocking=True) * data_scaling

        # Supervised step: forward pass, loss against the target, backprop, update.
        # Output and loss are stored on the model so they can be inspected afterwards.
        model.output = model.net(inp)
        model.loss = model.criterion(model.output, tgt)
        model.optimizer.zero_grad()
        model.loss.backward()
        model.optimizer.step()

    # Keep a checkpoint for every epoch.
    model.save(output_dir / f"weights_epoch_{epoch}.torch", epoch)

# Also store the final network under a fixed name. Because train_epochs >= 1, `epoch`
# is always bound here and refers to the last completed epoch.
model.save(output_dir / "weights.torch", epoch)