In [11]:
import os
import cv2
import torch
import numpy as np
from glob import glob
from tqdm import tqdm
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split

# ---- UNet Model ----
class UNet(nn.Module):
    def __init__(self, in_channels=6, out_channels=1):
        super(UNet, self).__init__()
        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            )
        self.enc1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = conv_block(256, 512)

        self.up1 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec1 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec3 = conv_block(128, 64)
        self.out = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))
        d1 = self.dec1(torch.cat([self.up1(e4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d1), e2], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d2), e1], dim=1))
        return self.out(d3)

# ---- CLAHE ----
def apply_clahe(img, clip_limit=3.0, tile_grid_size=(8, 8)):
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(img)

# ---- Helper: Find Common Embryos ----
def get_common_embryo_ids(base_paths):
    embryo_sets = []
    for path in base_paths:
        folders = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))]
        embryo_sets.append(set(folders))
    return sorted(list(set.intersection(*embryo_sets)))

# ---- Custom Dataset ----
class EmbryoFocusStackDataset(Dataset):
    def __init__(self, base_paths, embryo_ids, transform=None, enhance_method='clahe', enhance_params={}):
        self.base_paths = base_paths
        self.embryo_ids = embryo_ids
        self.transform = transform
        self.enhance_method = enhance_method
        self.enhance_params = enhance_params

    def __len__(self):
        return len(self.embryo_ids)

    def __getitem__(self, idx):
        embryo_id = self.embryo_ids[idx]
        images = []
        for base_path in self.base_paths:
            img_dir = os.path.join(base_path, embryo_id)
            img_files = sorted(glob(os.path.join(img_dir, "*.jpeg")))
            if len(img_files) == 0:
                raise FileNotFoundError(f"No images found in {img_dir}")
            img = cv2.imread(img_files[0], cv2.IMREAD_GRAYSCALE)

            # Apply enhancement
            if self.enhance_method == 'clahe':
                img = apply_clahe(img, **self.enhance_params)

            if self.transform:
                img = self.transform(img)
            else:
                img = torch.tensor(img / 255.0, dtype=torch.float32).unsqueeze(0)

            images.append(img)

        stack = torch.cat(images, dim=0)  # Shape: [6, H, W]
        return stack

# ---- Train Function ----
def train(model, loader, device):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()
    for epoch in range(5):
        total_loss = 0
        for x in tqdm(loader, desc=f"Epoch {epoch+1}"):
            x = x.to(device)
            y = x.mean(dim=1, keepdim=True)  # pseudo-label
            pred = model(x)
            loss = criterion(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} Loss: {total_loss / len(loader):.4f}")

# ---- Inference Function ----
def inference(model, loader, device, output_folder):
    model.eval()
    os.makedirs(output_folder, exist_ok=True)
    with torch.no_grad():
        for idx, x in enumerate(loader):
            x = x.to(device)
            output = model(x).squeeze().cpu().numpy()
            os.makedirs(f"{output_folder}/sample_{idx}", exist_ok=True)
            for i in range(6):
                ahe_img = x[0][i].cpu().numpy() * 255
                cv2.imwrite(f"{output_folder}/sample_{idx}/ahe_f{i+1}.png", ahe_img.astype(np.uint8))
            cv2.imwrite(f"{output_folder}/sample_{idx}/fused_output.png", (output * 255).astype(np.uint8))

# ---- Main ----
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)

    # 6 focal-plane directories
    base_paths = [
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F15\embryo_dataset_F15",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F-15\embryo_dataset_F-15",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F30\embryo_dataset_F30",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F-30\embryo_dataset_F-30",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F45\embryo_dataset_F45",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F-45\embryo_dataset_F-45"
    ]

    embryo_ids = get_common_embryo_ids(base_paths)
    print(f"🧬 Common Embryo IDs: {len(embryo_ids)}")

    # Preprocessing
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    enhance_params = {
        'clip_limit': 3.0,
        'tile_grid_size': (8, 8)
    }

    # Dataset
    full_dataset = EmbryoFocusStackDataset(
        base_paths=base_paths,
        embryo_ids=embryo_ids,
        transform=transform,
        enhance_method='clahe',
        enhance_params=enhance_params
    )

    print(f"📊 Total Samples: {len(full_dataset)}")

    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, test_dataset = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    print("🚀 Training started...")
    train(model, train_loader, device)

    print("🔍 Inference started...")
    inference(model, test_loader, device, output_folder="inference_results")


🧬 Common Embryo IDs: 704
📊 Total Samples: 704
🚀 Training started...


Epoch 1: 100%|██████████| 563/563 [01:11<00:00,  7.92it/s]


Epoch 1 Loss: 0.0024


Epoch 2: 100%|██████████| 563/563 [00:50<00:00, 11.08it/s]


Epoch 2 Loss: 0.0000


Epoch 3: 100%|██████████| 563/563 [00:53<00:00, 10.53it/s]


Epoch 3 Loss: 0.0000


Epoch 4: 100%|██████████| 563/563 [00:54<00:00, 10.28it/s]


Epoch 4 Loss: 0.0000


Epoch 5: 100%|██████████| 563/563 [01:17<00:00,  7.31it/s]


Epoch 5 Loss: 0.0000
🔍 Inference started...


In [5]:
import os

dataset_path = "E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F15\embryo_dataset_F15"
embryo_dirs = [f for f in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, f))]
print("Number of embryo folders:", len(embryo_dirs))
print("First 5:", embryo_dirs[:5])

Number of embryo folders: 704
First 5: ['AA83-7', 'AAL839-6', 'AB028-6', 'AB91-1', 'AC264-1']


  dataset_path = "E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F15\embryo_dataset_F15"


In [None]:
import os
import glob

dataset_path = r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F15\embryo_dataset_F15"
embryo_dirs = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, f))]

# Check one folder
sample = embryo_dirs[0]
input_images = glob.glob(os.path.join(sample, "*input*.png"))
target_images = glob.glob(os.path.join(sample, "*target*.png"))

print("Sample folder:", sample)
print("Input images found:", input_images)
print("Target images found:", target_images)


Sample folder: E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F15\embryo_dataset_F15\AA83-7
Input images found: []
Target images found: []


In [7]:
from glob import glob
import os

folder = r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F15\embryo_dataset_F15\AA83-7"
files = glob(os.path.join(folder, "*.jpeg"))
print("Found files:", files)


Found files: ['E:\\EmbryoAnalysis\\EmbryoAnalysis\\Dataset\\embryo_dataset_F15\\embryo_dataset_F15\\AA83-7\\D2013.01.28_S0717_I132_WELL7_RUN1.jpeg', 'E:\\EmbryoAnalysis\\EmbryoAnalysis\\Dataset\\embryo_dataset_F15\\embryo_dataset_F15\\AA83-7\\D2013.01.28_S0717_I132_WELL7_RUN10.jpeg', 'E:\\EmbryoAnalysis\\EmbryoAnalysis\\Dataset\\embryo_dataset_F15\\embryo_dataset_F15\\AA83-7\\D2013.01.28_S0717_I132_WELL7_RUN100.jpeg', 'E:\\EmbryoAnalysis\\EmbryoAnalysis\\Dataset\\embryo_dataset_F15\\embryo_dataset_F15\\AA83-7\\D2013.01.28_S0717_I132_WELL7_RUN101.jpeg', 'E:\\EmbryoAnalysis\\EmbryoAnalysis\\Dataset\\embryo_dataset_F15\\embryo_dataset_F15\\AA83-7\\D2013.01.28_S0717_I132_WELL7_RUN102.jpeg', 'E:\\EmbryoAnalysis\\EmbryoAnalysis\\Dataset\\embryo_dataset_F15\\embryo_dataset_F15\\AA83-7\\D2013.01.28_S0717_I132_WELL7_RUN103.jpeg', 'E:\\EmbryoAnalysis\\EmbryoAnalysis\\Dataset\\embryo_dataset_F15\\embryo_dataset_F15\\AA83-7\\D2013.01.28_S0717_I132_WELL7_RUN104.jpeg', 'E:\\EmbryoAnalysis\\EmbryoAna

In [1]:
pip install opencv-python numpy torch torchvision matplotlib tqdm

Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
print(f"Total samples in dataset: {len(full_dataset)}")

Total samples in dataset: 0


In [None]:
import os
import cv2
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ---- UNet Model ----
class UNet(nn.Module):
    def __init__(self, in_channels=6, out_channels=1):
        super(UNet, self).__init__()
        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            )
        self.enc1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = conv_block(256, 512)

        self.up1 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec1 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec3 = conv_block(128, 64)
        self.out = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))
        d1 = self.dec1(torch.cat([self.up1(e4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d1), e2], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d2), e1], dim=1))
        return self.out(d3)

# ---- AHE Preprocessing ----
def apply_clahe(img):
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return clahe.apply(img)

# ---- Custom Dataset ----
class EmbryoDataset(Dataset):
    def __init__(self, folder_path):
        self.folder_path = folder_path
        self.samples = self._group_samples()

    def _group_samples(self):
        # Search all .jpeg files recursively
        files = sorted(glob(os.path.join(self.folder_path, "**", "*.jpeg"), recursive=True))
        embryo_dict = {}

        for f in files:
            base = os.path.basename(f)
            parts = base.split("_")
            if len(parts) >= 3:
                embryo_id = "_".join(parts[:2])  # e.g., D2013.01.28_S0717
                embryo_dict.setdefault(embryo_id, []).append(f)

        # Only keep samples with exactly 6 frames
        grouped = [sorted(imgs) for imgs in embryo_dict.values() if len(imgs) == 6]
        return grouped

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        paths = self.samples[idx]
        imgs = []
        for path in paths:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            img = apply_clahe(img)
            img = img / 255.0
            imgs.append(torch.tensor(img, dtype=torch.float32))
        stacked = torch.stack(imgs)  # [6, H, W]
        return stacked

# ---- Train Function ----
def train(model, loader, device):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()
    for epoch in range(5):
        total_loss = 0
        for x in tqdm(loader, desc=f"Epoch {epoch+1}"):
            x = x.to(device)
            y = x.mean(dim=1, keepdim=True)  # pseudo-label = average of inputs
            pred = model(x)
            loss = criterion(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} Loss: {total_loss / len(loader):.4f}")

# ---- Inference Function ----
def inference(model, loader, device, output_folder):
    model.eval()
    os.makedirs(output_folder, exist_ok=True)
    with torch.no_grad():
        for idx, x in enumerate(loader):
            x = x.to(device)
            output = model(x).squeeze().cpu().numpy()
            os.makedirs(f"{output_folder}/sample_{idx}", exist_ok=True)
            for i in range(6):
                ahe_img = x[0][i].cpu().numpy() * 255
                cv2.imwrite(f"{output_folder}/sample_{idx}/ahe_f{i+1}.png", ahe_img.astype(np.uint8))
            cv2.imwrite(f"{output_folder}/sample_{idx}/fused_output.png", (output * 255).astype(np.uint8))

# ---- Main ----
if __name__ == "__main__":
    base_paths = [
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F45",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45"
    ]
    
    # Find embryo IDs that exist in all 6 directories
    embryo_ids = get_common_embryo_ids(base_paths)
    print(f"Found {len(embryo_ids)} embryo IDs: {embryo_ids[:5]} ...")
    
    # Define transforms (resize + ToTensor for example)
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    
    # Enhancement parameters
    enhance_params = {
        'clip_limit': 3.0,      # For CLAHE
        'tile_grid_size': (8, 8),  # For CLAHE
        'kernel_size': 32       # For AHE
    }
    
    # Create the dataset with enhancement
    dataset = EmbryoFocusStackDataset(
        base_paths, 
        embryo_ids, 
        transform=transform,
        enhance_method='clahe',  # or 'ahe'
        enhance_params=enhance_params
    )
    
    # Split into train/val
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
    
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize model
    model = UNet(in_channels=6, out_channels=1).to(device)
    
    # Train the model
    print("Starting training with adaptive histogram equalization...")
    train_model(model, train_loader, val_loader, num_epochs=50, device=device)
    
    print("Training complete. Best model saved as 'embryo_unet_ahe.pth'.")
    print("🚀 Training started...")
    train(model, train_loader, device)

    print("🧪 Inference started...")
    inference(model, test_loader, device, output_folder="inference_results")
