In [None]:
import os
import re
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from skimage.io import imread
from skimage.transform import resize
from skimage.filters import threshold_otsu
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
from segmentation_models_pytorch import Unet
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

# Pattern capturing the digits that follow a 'RUN' marker in a filename,
# e.g. '..._RUN12.jpeg' -> '12'. Matching is case-insensitive.
frame_re = re.compile(r"RUN(\d+)", re.IGNORECASE)

def extract_frame_number(filename):
    """Return the integer following the first 'RUN' marker in *filename*.

    Returns None when the name contains no 'RUN<digits>' sequence.
    """
    match = frame_re.search(filename)
    if match is None:
        return None
    return int(match.group(1))

# Utility: get common embryo IDs
def get_common_embryo_ids(annotations_path, image_dirs):
    """Return the sorted embryo IDs present in the annotations and in every image dir.

    An ID is taken from each '<id>_phases.csv' file in *annotations_path* and is
    kept only if a sub-directory with that exact name exists in every directory
    listed in *image_dirs*.

    Args:
        annotations_path: directory holding the '<id>_phases.csv' files.
        image_dirs: iterable of directories, each expected to contain one
            sub-directory per embryo ID.

    Returns:
        A sorted list of IDs; empty when *image_dirs* is empty or when any of
        its entries is not an existing directory.

    Raises:
        OSError: if *annotations_path* cannot be listed.
    """
    suffix = '_phases.csv'
    # Strip the suffix from the end only; str.replace would also remove an
    # occurrence of the suffix text in the middle of an ID.
    csv_ids = {fname[:-len(suffix)]
               for fname in os.listdir(annotations_path)
               if fname.endswith(suffix)}

    common_dirs = None
    for d in image_dirs:
        if not os.path.isdir(d):
            # A missing image directory contains no embryos, so no ID can be
            # common to all directories. Skipping it instead would return IDs
            # whose images are absent from this directory.
            return []
        subs = {dn for dn in os.listdir(d)
                if os.path.isdir(os.path.join(d, dn))}
        common_dirs = subs if common_dirs is None else common_dirs & subs

    if common_dirs is None:
        # No image directories were supplied at all.
        return []
    return sorted(csv_ids & common_dirs)

# Paths & config
# Everything lives under one dataset root: the phase annotations, the
# directory used as ground truth, and one directory per focal plane.
base = r'C:\Projects\Embryo\Dataset'
annotations_path = os.path.join(base, 'embryo_dataset_annotations')
gt_path = os.path.join(base, 'embryo_dataset')
focal_dirs = {
    'F15': os.path.join(base, 'embryo_dataset_F15'),
    'F-15': os.path.join(base, 'embryo_dataset_F-15'),
    'F30': os.path.join(base, 'embryo_dataset_F30'),
    'F-30': os.path.join(base, 'embryo_dataset_F-30'),
    'F45': os.path.join(base, 'embryo_dataset_F45'),
    'F-45': os.path.join(base, 'embryo_dataset_F-45'),
}
# Key order of focal_dirs fixes the order of the focal planes.
focal_planes = list(focal_dirs)
# Target height and width every image is resized to.
H = W = 256
# Gather IDs
# Keep only embryos present in the annotations, the GT directory and every
# focal-plane directory.
dirs = [gt_path, *focal_dirs.values()]
ids = get_common_embryo_ids(annotations_path, dirs)
if not ids:
    raise ValueError("No common embryo IDs found.")
# The split is done per embryo ID (80/20) with a fixed seed.
train_ids, val_ids = train_test_split(ids, test_size=0.2, random_state=42)

class EmbryoT4Dataset(Dataset):
    """
    Per-frame dataset of 't4'-phase images.

    Each item is a pair (x, y):
      * x - one channel per focal plane (in the key order of ``focal_dirs``),
        each image resized to (H, W) and divided by 255, as a float tensor of
        shape (n_planes, H, W);
      * y - a binary mask of shape (1, H, W), obtained by Otsu-thresholding
        the image with the same frame number found under ``gt_path``.

    Args:
        embryo_ids: iterable of embryo IDs (sub-directory / CSV base names).
        annotations_path: directory with the '<id>_phases.csv' files.
        focal_dirs: mapping focal-plane name -> directory containing one
            sub-directory per embryo.
        gt_path: directory containing one GT sub-directory per embryo.
        transform: optional albumentations pipeline called as
            ``transform(image=..., mask=...)``; it must return tensors
            (e.g. end with ToTensorV2).

    Raises:
        ValueError: if no t4 frame is found for any of the given embryos.
    """
    def __init__(self, embryo_ids, annotations_path, focal_dirs, gt_path, transform=None):
        self.transform = transform
        self.samples = []  # list of (eid, frame_number, filename)
        self.focal_dirs = focal_dirs
        self.gt_path = gt_path
        # Channel order comes from the mapping this instance was given, not
        # from a module-level list, so the dataset is self-contained.
        self.focal_planes = list(focal_dirs)
        # eid -> {frame_number: GT filename}; filled lazily on first access.
        self._gt_index = {}
        # Build file list for each embryo
        for eid in embryo_ids:
            # read phase CSV; columns are assigned positionally
            csv = os.path.join(annotations_path, f"{eid}_phases.csv")
            if not os.path.exists(csv):
                continue
            df = pd.read_csv(csv, names=['phase','start','end'])
            t4 = df[df['phase']=='t4']
            if t4.empty:
                continue
            # Only the first 't4' row defines the inclusive frame range.
            s, e = int(t4['start'].iloc[0]), int(t4['end'].iloc[0])
            # The first focal plane's directory provides the list of frames.
            sample_dir = focal_dirs[self.focal_planes[0]]
            files = [f for f in os.listdir(os.path.join(sample_dir, eid))
                     if f.lower().endswith(('.jpg','.jpeg','.png'))]
            for fname in files:
                fr = extract_frame_number(fname)
                # Compare against None explicitly: frame number 0 is valid
                # and must not be dropped by a truthiness test.
                if fr is not None and s <= fr <= e:
                    self.samples.append((eid, fr, fname))
        if not self.samples:
            raise ValueError("No T4 samples found.")

    def __len__(self):
        """Return the number of (embryo, frame) samples."""
        return len(self.samples)

    def _gt_frames(self, eid):
        """Return {frame_number: filename} for *eid*'s GT directory, listing it once."""
        index = self._gt_index.get(eid)
        if index is None:
            index = {}
            for g in os.listdir(os.path.join(self.gt_path, eid)):
                num = extract_frame_number(g)
                if num is not None:
                    # Keep the first file seen for a frame number.
                    index.setdefault(num, g)
            self._gt_index[eid] = index
        return index

    def __getitem__(self, idx):
        """Load the focal stack and GT mask for sample *idx*.

        Raises:
            FileNotFoundError: if a focal-plane image or the GT image for the
                sample's frame number is missing.
        """
        eid, fr, fname = self.samples[idx]
        # build stack: use the same fname across all focal dirs
        stack = []
        for fp in self.focal_planes:
            p = os.path.join(self.focal_dirs[fp], eid, fname)
            if not os.path.exists(p):
                raise FileNotFoundError(p)
            im = imread(p, as_gray=True)
            # NOTE(review): dividing by 255 assumes imread returned 0-255
            # values (files stored as grayscale); colour files come back from
            # as_gray=True already scaled to [0, 1] - confirm the input format.
            im = resize(im, (H, W), preserve_range=True)/255.0
            stack.append(im)
        # resize() yields float64; cast so both branches below give the model
        # a float32 tensor (ToTensorV2 keeps the array's dtype).
        inp = np.stack(stack, -1).astype(np.float32)
        # GT mask: find corresponding GT filename by frame number
        gt_fname = self._gt_frames(eid).get(fr)
        if gt_fname is None:
            raise FileNotFoundError(f"GT for frame {fr} not found for {eid}")
        g = imread(os.path.join(self.gt_path, eid, gt_fname), as_gray=True)
        g = resize(g, (H, W), preserve_range=True)
        # Pixels brighter than the Otsu threshold form the foreground.
        mask = (g>threshold_otsu(g)).astype(np.float32)
        # augment
        if self.transform:
            aug = self.transform(image=inp, mask=mask)
            x = aug['image']
            # The mask comes back as (H, W); add the channel axis.
            y = aug['mask'].unsqueeze(0)
        else:
            # HWC -> CHW
            x = torch.tensor(inp).permute(2,0,1).float()
            y = torch.tensor(mask).unsqueeze(0)
        return x, y

# Transforms
# Every probability and limit is passed by keyword. Positionally, the first
# parameter of these transforms is not `p`, and the fourth parameter of
# ShiftScaleRotate is `interpolation`, so a bare 0.5 there is rejected by
# argument validation ("Input should be 0, 1, 2, 3 or 4").
train_tf = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=10, p=0.5),
    ToTensorV2(),
])
# Validation applies no augmentation, only the array -> tensor conversion.
val_tf   = A.Compose([ToTensorV2()])

# Datasets & loaders
# One dataset per split; each gets the transform pipeline for its split.
train_ds = EmbryoT4Dataset(
    train_ids, annotations_path, focal_dirs, gt_path, transform=train_tf
)
val_ds = EmbryoT4Dataset(
    val_ids, annotations_path, focal_dirs, gt_path, transform=val_tf
)
# Only the training loader shuffles. num_workers=0 loads data in the main
# process.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=4, num_workers=0)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=4, num_workers=0)

# Model & training setup
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# U-Net with a ResNet-34 encoder: one input channel per focal plane and a
# single output channel.
model = Unet('resnet34', in_channels=len(focal_planes), classes=1)
model.to(device)
opt = Adam(model.parameters(), lr=1e-4)
# Binary cross-entropy computed directly on the model's raw outputs.
crit = BCEWithLogitsLoss()

def dice(pred, tgt, smooth=1e-6):
    """Return the Dice coefficient between two tensors as a 0-dim tensor.

    Computes (2 * sum(pred * tgt) + smooth) / (sum(pred) + sum(tgt) + smooth)
    over all elements.

    Args:
        pred: prediction tensor.
        tgt: target tensor with the same number of elements as *pred*.
        smooth: small constant added to numerator and denominator so the
            ratio is defined (and equals 1) when both tensors are all zeros.
    """
    # reshape() rather than view(): view(-1) raises on non-contiguous tensors
    # (e.g. after a transpose or slicing), reshape copies only when needed.
    p = pred.reshape(-1)
    t = tgt.reshape(-1)
    i = (p*t).sum()
    return (2*i+smooth)/(p.sum()+t.sum()+smooth)

# Train loop
# Runs up to 50 epochs; training stops early once the validation loss has
# failed to improve for `patience` consecutive epochs.
best_loss = float('inf')
patience = 10
wait = 0
for epoch in range(50):
    # ---- training pass ----
    model.train()
    tloss = 0
    for x, y in tqdm(train_loader, desc=f"Train {epoch+1}"):
        x = x.to(device)
        y = y.to(device)
        opt.zero_grad()
        out = model(x)
        loss = crit(out, y)
        loss.backward()
        opt.step()
        tloss += loss.item()
    # Mean loss per batch.
    tloss /= len(train_loader)

    # ---- validation pass (no gradients) ----
    model.eval()
    vloss = 0
    vdice = 0
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            vloss += crit(out, y).item()
            # Dice is measured on the mask obtained by thresholding the
            # sigmoid of the model output at 0.5.
            hard = (torch.sigmoid(out) > 0.5).float()
            vdice += dice(hard, y).item()
    vloss /= len(val_loader)
    vdice /= len(val_loader)
    print(f"Epoch {epoch+1}: Train {tloss:.4f} | Val {vloss:.4f} | Dice {vdice:.4f}")

    # ---- checkpointing / early stopping ----
    if vloss < best_loss:
        # New best: remember it, save the weights, reset the counter.
        best_loss = vloss
        torch.save(model.state_dict(), 'best_model.pth')
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping.")
            break

# Load best
# Restore the weights saved at the epoch with the lowest validation loss.
best_state = torch.load('best_model.pth')
model.load_state_dict(best_state)
print("Done, saved best_model.pth")


ValueError: 1 validation error for InitSchema
interpolation
  Input should be 0, 1, 2, 3 or 4 [type=literal_error, input_value=0.5, input_type=float]
    For further information visit https://errors.pydantic.dev/2.10/v/literal_error