In [2]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from skimage.io import imread
from skimage.transform import resize
from skimage.filters import threshold_otsu
import albumentations as A
from albumentations.pytorch import ToTensorV2
from segmentation_models_pytorch import Unet
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Utility function to find common embryo IDs
def get_common_embryo_ids(base_paths):
    """
    Returns a sorted list of folder names (embryo IDs)
    that appear in *all* the given directories.
    """
    sets_of_ids = []
    for path in base_paths:
        subfolders = [
            d for d in os.listdir(path)
            if os.path.isdir(os.path.join(path, d))
        ]
        sets_of_ids.append(set(subfolders))
    common_ids = set.intersection(*sets_of_ids)
    return sorted(list(common_ids))

# Manual Dice coefficient implementation
def dice_coefficient(pred, target, smooth=1e-6):
    pred = pred.contiguous().view(-1)
    target = target.contiguous().view(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

# Dataset paths
base_path = 'Dataset'
annotations_path = os.path.join(base_path, r"C:\Projects\Embryo\Dataset\embryo_dataset_annotations")
gt_path = os.path.join(base_path, r"C:\Projects\Embryo\Dataset")
focal_planes = ['F15', 'F30', 'F45', 'F-15', 'F-30', 'F-45']
focal_paths = {fp: os.path.join(base_path, f'C:\Projects\Embryo\Dataset\embryo_dataset_{fp}') for fp in focal_planes}

# Image size
H, W = 256, 256
N_channels = len(focal_planes)  # 6 focal planes

# Get common embryo IDs across all directories
all_paths = [annotations_path, gt_path] + list(focal_paths.values())
embryo_ids = get_common_embryo_ids(all_paths)
print(f'Found {len(embryo_ids)} common embryo IDs')

# Split into training and validation sets
train_ids, val_ids = train_test_split(embryo_ids, test_size=0.2, random_state=42)

# Custom Dataset class
class EmbryoDataset(Dataset):
    def __init__(self, embryo_ids, annotations_path, focal_paths, gt_path, transform=None):
        self.embryo_ids = embryo_ids
        self.annotations_path = annotations_path
        self.focal_paths = focal_paths
        self.gt_path = gt_path
        self.transform = transform

    def __len__(self):
        return len(self.embryo_ids)

    def __getitem__(self, idx):
        embryo_id = self.embryo_ids[idx]
        # Read annotation to get t4 frame
        csv_file = os.path.join(self.annotations_path, f'{embryo_id}.csv')
        df = pd.read_csv(csv_file)
        t4_frame = int(df['t4'].iloc[0])  # Use t4 frame
        # Load 6 focal plane images for t4 frame
        input_stack = []
        for fp in focal_planes:
            img_file = os.path.join(self.focal_paths[fp], f'{embryo_id}_frame_{t4_frame:03d}.png')
            if not os.path.exists(img_file):
                raise FileNotFoundError(f'Image {img_file} not found')
            img = imread(img_file, as_gray=True)
            img = resize(img, (H, W), preserve_range=True) / 255.0
            input_stack.append(img)
        input_stack = np.stack(input_stack, axis=-1)  # Shape: (H, W, 6)
        # Load ground truth image and convert to binary mask
        gt_file = os.path.join(self.gt_path, f'{embryo_id}_gt.png')
        gt_img = imread(gt_file, as_gray=True)
        gt_img = resize(gt_img, (H, W), preserve_range=True)
        thresh = threshold_otsu(gt_img)
        gt_mask = (gt_img > thresh).astype(np.float32)  # Shape: (H, W)
        # Apply transformations
        if self.transform:
            augmented = self.transform(image=input_stack, mask=gt_mask)
            input_stack = augmented['image']
            gt_mask = augmented['mask']
        else:
            input_stack = torch.tensor(input_stack).permute(2, 0, 1)  # (6, H, W)
            gt_mask = torch.tensor(gt_mask).unsqueeze(0)  # (1, H, W)
        return input_stack, gt_mask

# Define transformations
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=10, p=0.5),
    ToTensorV2(),
])

val_transform = A.Compose([
    ToTensorV2(),
])

# Create datasets and data loaders
train_dataset = EmbryoDataset(train_ids, annotations_path, focal_paths, gt_path, transform=train_transform)
val_dataset = EmbryoDataset(val_ids, annotations_path, focal_paths, gt_path, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

# Define U-Net model
model = Unet(encoder_name='resnet34', in_channels=N_channels, classes=1)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Loss and optimizer
criterion = BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 50
best_val_loss = float('inf')
patience = 10
trigger_times = 0

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    model.eval()
    val_loss = 0.0
    val_dice = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            preds = (outputs > 0.5).float()
            val_dice += dice_coefficient(preds, targets).item()
    val_loss /= len(val_loader)
    val_dice /= len(val_loader)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}')

    # Early stopping and model checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        trigger_times = 0
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print('Early stopping triggered')
            break

# Load best model
model.load_state_dict(torch.load('best_model.pth'))
model.eval()
print('Training completed. Best model saved as best_model.pth')

  focal_paths = {fp: os.path.join(base_path, f'C:\Projects\Embryo\Dataset\embryo_dataset_{fp}') for fp in focal_planes}
  focal_paths = {fp: os.path.join(base_path, f'C:\Projects\Embryo\Dataset\embryo_dataset_{fp}') for fp in focal_planes}


Found 0 common embryo IDs


ValueError: With n_samples=0, test_size=0.2 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.