# Multi-Cell TFM + U-Net Pipeline
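This notebook runs the TFM pipeline on every cell folder, stacks each cell's outputs into a single `.npy` file, indexes those files in a CSV, and trains a UNeXt (U-Net-style) model on the resulting dataset.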

In [None]:
import os
import numpy as np
import tifffile
import torch
import csv
from tfm import TFM_Image_registration, TFM_displacement_tools, TFM_tools
from utils import data_processing, UNeXt
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim


In [None]:
# Input root: every subdirectory below it holds the raw data for a single cell.
# Output root: receives one stacked .npy per cell plus dataset.csv.
base_input_dir, output_dir = "input_data", "example_dataset"
os.makedirs(output_dir, exist_ok=True)  # no-op when the folder already exists


## Steps 1–2: Process Each Cell and Save a Stacked `.npy`

In [None]:
# Channel name and source file for each plane of the per-cell stack, listed in
# stack order (axis 0).
# NOTE(review): CellDataset presumably indexes channels by this position, so
# keep this order in sync with it.
CHANNEL_FILES = [
    ("ux", "disp_u_0.tif"),
    ("uy", "disp_v_0.tif"),
    ("fx", "fx_0.tif"),
    ("fy", "fy_0.tif"),
    ("mask", "cellmask.tif"),
    ("forcemask", "forcemask.tif"),
    ("zyxin", "zyxin.tif"),
    ("actin", "actin.tif"),
]

npy_filenames = []  # saved per-cell stack names, relative to output_dir

for folder in sorted(os.listdir(base_input_dir)):
    cell_path = os.path.join(base_input_dir, folder)
    if not os.path.isdir(cell_path):
        continue  # skip stray files; only subfolders are treated as cells

    print(f"Processing {folder}...")

    # TFM steps, run in order on the cell folder.
    # NOTE(review): these presumably write the .tif files read below into
    # cell_path; confirm against the tfm package.
    TFM_Image_registration.TFM_Image_registration(cell_path)
    TFM_displacement_tools.TFM_optical_flow(cell_path)
    TFM_tools.TFM_calculation(cell_path)
    TFM_tools.cellmask_threshold(cell_path)

    # Read every channel and verify they share one shape before stacking.
    # This way a mismatch names the offending cell and channels instead of
    # surfacing as a bare np.stack error.
    channels = [
        tifffile.imread(os.path.join(cell_path, filename))
        for _, filename in CHANNEL_FILES
    ]
    shapes = {name: arr.shape for (name, _), arr in zip(CHANNEL_FILES, channels)}
    if len(set(shapes.values())) > 1:
        raise ValueError(f"{folder}: channel shapes differ, cannot stack: {shapes}")

    stack = np.stack(channels, axis=0)
    npy_name = f"{folder}.npy"
    np.save(os.path.join(output_dir, npy_name), stack)
    npy_filenames.append(npy_name)


## Step 3: Create `dataset.csv`

In [None]:
# Index file for the dataset: a "filename" header followed by one row per
# saved .npy stack, in processing order.
csv_path = os.path.join(output_dir, "dataset.csv")
rows = [["filename"]]
rows.extend([name] for name in npy_filenames)
with open(csv_path, "w", newline='') as f:
    csv.writer(f).writerows(rows)
print("CSV created at:", csv_path)


## Step 4: Train the UNeXt Model on All Cells

In [None]:
# Dataset of the stacked per-cell arrays listed in dataset.csv.
# NOTE(review): the loop below expects each batch to be a dict with "image"
# and "label" entries; confirm which stacked channels CellDataset maps to each.
dataset = data_processing.CellDataset(output_dir, csv_path)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNeXt.UNeXt().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

epochs = 3
model.train()
for epoch in range(epochs):
    running_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        x = batch["image"].to(device, dtype=torch.float32)
        y = batch["label"].to(device, dtype=torch.float32)

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean per-batch loss. A raw sum grows with the number of
    # cells, so it is not comparable across datasets or batch sizes.
    # max(..., 1) guards against division by zero if no batches were produced.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{epochs} - Mean loss: {mean_loss:.4f}")
