In [22]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image, ImageFile
import numpy as np
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [10]:
# ------------------------------------------------
# 1) Utility function to find common embryo IDs
# ------------------------------------------------
def get_common_embryo_ids(base_paths):
    """
    Returns a sorted list of folder names (embryo IDs)
    that appear in *all* the given directories.
    """
    sets_of_ids = []
    for path in base_paths:
        subfolders = [
            d for d in os.listdir(path)
            if os.path.isdir(os.path.join(path, d))
        ]
        sets_of_ids.append(set(subfolders))

    common_ids = set.intersection(*sets_of_ids)
    return sorted(list(common_ids))

In [11]:
# ------------------------------------------------
# 2) Define the UNet building blocks
# ------------------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)

In [12]:
class UNet(nn.Module):
    def __init__(self, in_channels=6, out_channels=1):
        """
        U-Net model that takes 6-channel input (one for each focal plane)
        and outputs a single-channel fused image.
        """
        super(UNet, self).__init__()
        
        # Encoder
        self.conv1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        
        # Bottleneck
        self.conv5 = DoubleConv(512, 1024)
        
        # Decoder
        self.up6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        
        # Output
        self.conv10 = nn.Conv2d(64, out_channels, kernel_size=1)
        
    def forward(self, x):
        # Encoder
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        
        # Bottleneck
        c5 = self.conv5(p4)
        
        # Decoder
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        
        # Output with Sigmoid (values in [0,1])
        output = self.conv10(c9)
        return torch.sigmoid(output)

In [23]:
# ------------------------------------------------
# 3) Create a Dataset that stacks 6 focal-plane images
# ------------------------------------------------
class EmbryoFocusStackDataset(Dataset):
    """
    Each item in this dataset is a (input_tensor, target_tensor) pair.
    input_tensor has shape (6, H, W), one channel per focal plane.
    """
    def __init__(self, base_paths, embryo_ids, transform=None):
        if len(base_paths) != 6:
            raise ValueError("We need exactly 6 focal-plane directories.")
        
        self.base_paths = base_paths
        self.embryo_ids = embryo_ids
        self.transform = transform

    def __len__(self):
        return len(self.embryo_ids)

    def __getitem__(self, idx):
        embryo_id = self.embryo_ids[idx]
        
        focal_images = []
        for path in self.base_paths:
            embryo_subfolder = os.path.join(path, embryo_id)
            image_files = sorted(
                [f for f in os.listdir(embryo_subfolder)
                 if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            )
            if not image_files:
                raise FileNotFoundError(f"No image found in {embryo_subfolder}")
            
            img_path = os.path.join(embryo_subfolder, image_files[0])
            
            # No try-except needed now because LOAD_TRUNCATED_IMAGES = True
            image = Image.open(img_path).convert('L')
            if self.transform:
                image = self.transform(image)
            
            focal_images.append(image)
        
        input_tensor = torch.cat(focal_images, dim=0)
        target = focal_images[2]  # dummy target
        return input_tensor, target




In [24]:
# ------------------------------------------------
# 4) Training function
# ------------------------------------------------
def train_model(model, train_loader, val_loader, num_epochs=5, device='cpu'):
    """
    Trains the U-Net using a simple BCELoss. 
    Saves the best model to 'embryo_unet.pth'.
    """
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_train_loss += loss.item() * inputs.size(0)
        
        train_loss = running_train_loss / len(train_loader.dataset)
        
        # Validation
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                running_val_loss += loss.item() * inputs.size(0)
        
        val_loss = running_val_loss / len(val_loader.dataset)
        
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'embryo_unet.pth')
            print(f"  [*] Model saved at epoch {epoch+1}")


In [25]:
# ------------------------------------------------
# 5) Main script
# ------------------------------------------------
def main():
    # Paths to your 6 focal-plane directories (edit as needed)
    base_paths = [
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F45",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45"
    ]
    
    # Find embryo IDs that exist in all 6 directories
    embryo_ids = get_common_embryo_ids(base_paths)
    print(f"Found {len(embryo_ids)} embryo IDs: {embryo_ids[:5]} ...")
    
    # Define transforms (resize + ToTensor for example)
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    
    # Create the dataset
    dataset = EmbryoFocusStackDataset(base_paths, embryo_ids, transform=transform)
    
    # Split into train/val
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
    
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize model
    model = UNet(in_channels=6, out_channels=1).to(device)
    
    # Train the model
    print("Starting training...")
    train_model(model, train_loader, val_loader, num_epochs=5, device=device)
    
    print("Training complete. Best model saved as 'embryo_unet.pth'.")

if __name__ == "__main__":
    main()

Found 704 embryo IDs: ['AA83-7', 'AAL839-6', 'AB028-6', 'AB91-1', 'AC264-1'] ...
Using device: cuda
Starting training...
Epoch [1/5] Train Loss: 0.4550, Val Loss: 0.4427
  [*] Model saved at epoch 1
Epoch [2/5] Train Loss: 0.4481, Val Loss: 0.4479
Epoch [3/5] Train Loss: 0.4482, Val Loss: 0.4416
  [*] Model saved at epoch 3
Epoch [4/5] Train Loss: 0.4480, Val Loss: 0.4441
Epoch [5/5] Train Loss: 0.4478, Val Loss: 0.4454
Training complete. Best model saved as 'embryo_unet.pth'.


In [32]:
import torch
from PIL import Image
from torchvision import transforms

def test_single_embryo(model, image_paths, transform, device='cpu'):
    """
    Loads 6 images from different focal planes, stacks them as 6-channel input,
    and returns the fused output image (as a PIL Image).
    
    Args:
        model (nn.Module): Trained UNet model with 6 input channels.
        image_paths (list[str]): List of 6 image file paths (one per focal plane).
        transform (callable): Same transform used in training (resize, ToTensor, etc.).
        device (str): 'cpu' or 'cuda'.
        
    Returns:
        PIL.Image: The fused output image (single-channel).
    """
    # Make sure the model is in eval mode
    model.eval()
    
    # Load and transform each focal-plane image
    focal_tensors = []
    for path in image_paths:
        # Open and convert to grayscale
        img = Image.open(path).convert('L')
        # Apply the same transform (e.g., resize, ToTensor)
        img_tensor = transform(img)
        focal_tensors.append(img_tensor)
    
    # Stack along channel dimension: shape (6, H, W)
    input_tensor = torch.cat(focal_tensors, dim=0)
    # Add batch dimension: shape (1, 6, H, W)
    input_tensor = input_tensor.unsqueeze(0).to(device)
    
    # Forward pass through the model
    with torch.no_grad():
        output = model(input_tensor)  # shape (1, 1, H, W)
    
    # Convert the output tensor back to a PIL image
    output_image = output.squeeze(0).cpu()  # shape (1, H, W)
    fused_pil = transforms.ToPILImage()(output_image)
    
    return fused_pil


In [35]:
def main_test():
    import os
    
    # 1) Paths to your 6 focal-plane images for one embryo
    test_image_paths = [
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F15\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F45\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F30\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg"
    ]
    
    # 2) Define the same transform used in training
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    
    # 3) Load the trained model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(in_channels=6, out_channels=1).to(device)
    
    if os.path.exists('embryo_unet.pth'):
        model.load_state_dict(torch.load('embryo_unet.pth', map_location=device))
        print("Loaded trained model weights.")
    else:
        print("Warning: No trained model found. Using random weights.")
    
    # 4) Generate the fused image
    fused_image = test_single_embryo(model, test_image_paths, transform, device)
    
    # 5) Save or show the fused image
    fused_image.save("fused_output.jpg")
    fused_image.show()

if __name__ == "__main__":
    main_test()


Loaded trained model weights.
