In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFile, ImageFilter
import numpy as np
import random
from torchvision.transforms import functional as TF
import torch.nn.functional as F

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Utility function to find common embryo IDs across directories
def get_common_embryo_ids(base_paths):
    sets_of_ids = []
    for path in base_paths:
        subfolders = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
        sets_of_ids.append(set(subfolders))
    common_ids = set.intersection(*sets_of_ids)
    return sorted(list(common_ids))

# Attention Block for U-Net
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

# DoubleConv Block for U-Net
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)

# U-Net with Attention Gates
class UNet(nn.Module):
    def __init__(self, in_channels=6, out_channels=1):
        super(UNet, self).__init__()
        self.conv1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        self.up6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att6 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att7 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att8 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att9 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def forward(self, x):
        c1 = self.conv1(x)  # 256x256
        p1 = self.pool1(c1)  # 128x128
        c2 = self.conv2(p1)  # 128x128
        p2 = self.pool2(c2)  # 64x64
        c3 = self.conv3(p2)  # 64x64
        p3 = self.pool3(c3)  # 32x32
        c4 = self.conv4(p3)  # 32x32
        p4 = self.pool4(c4)  # 16x16
        c5 = self.conv5(p4)  # 16x16

        up_6 = self.up6(c5)  # 16x16 -> 32x32
        att_c4 = self.att6(up_6, c4)
        merge6 = torch.cat([up_6, att_c4], dim=1)
        c6 = self.conv6(merge6)

        up_7 = self.up7(c6)  # 32x32 -> 64x64
        att_c3 = self.att7(up_7, c3)
        merge7 = torch.cat([up_7, att_c3], dim=1)
        c7 = self.conv7(merge7)

        up_8 = self.up8(c7)  # 64x64 -> 128x128
        att_c2 = self.att8(up_8, c2)
        merge8 = torch.cat([up_8, att_c2], dim=1)
        c8 = self.conv8(merge8)

        up_9 = self.up9(c8)  # 128x128 -> 256x256
        att_c1 = self.att9(up_9, c1)
        merge9 = torch.cat([up_9, att_c1], dim=1)
        c9 = self.conv9(merge9)

        output = self.conv10(c9)
        return torch.sigmoid(output)

# Dataset for Embryo Image Fusion
class EmbryoFocusStackDataset(Dataset):
    def __init__(self, base_paths, embryo_ids, is_train=False):
        self.base_paths = base_paths
        self.embryo_ids = embryo_ids
        self.is_train = is_train
        self.preprocess = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.embryo_ids)

    def __getitem__(self, idx):
        embryo_id = self.embryo_ids[idx]
        focal_pil_images = []

        for path in self.base_paths:
            embryo_subfolder = os.path.join(path, embryo_id)
            image_files = sorted([f for f in os.listdir(embryo_subfolder) 
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            if not image_files:
                raise FileNotFoundError(f"No image found in {embryo_subfolder}")
            img_path = os.path.join(embryo_subfolder, image_files[0])
            try:
                image = Image.open(img_path).convert('L')
                focal_pil_images.append(image)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                focal_pil_images.append(Image.new('L', (256, 256), 0))

        if self.is_train:
            angle = random.uniform(-30, 30)
            flip_h = random.random() < 0.5
            flip_v = random.random() < 0.5
        else:
            angle = 0
            flip_h = False
            flip_v = False

        transformed_images = []
        for img in focal_pil_images:
            img = TF.rotate(img, angle)
            if flip_h:
                img = TF.hflip(img)
            if flip_v:
                img = TF.vflip(img)
            img = self.preprocess(img)
            transformed_images.append(img)

        focal_tensors = transformed_images
        input_tensor = torch.cat(focal_tensors, dim=0)  # [6, H, W]

        mean_target = torch.mean(torch.stack(focal_tensors), dim=0)  # [1, H, W]
        blurred = F.avg_pool2d(mean_target.unsqueeze(0), kernel_size=3, stride=1, padding=1).squeeze(0)
        sharp_target = mean_target + (mean_target - blurred) * 1.5  # Unsharp masking

        return input_tensor, sharp_target

# Loss Function
l1_loss_fn = nn.L1Loss()

# Training Function
def train_model(model, train_loader, val_loader, num_epochs=6, device='cuda'):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
    best_val_loss = float('inf')
    patience_counter = 0
    max_patience = 3  # Stricter early stopping

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = l1_loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item() * inputs.size(0)

        train_loss = running_train_loss / len(train_loader.dataset)

        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = l1_loss_fn(outputs, targets)
                running_val_loss += loss.item() * inputs.size(0)

        val_loss = running_val_loss / len(val_loader.dataset)

        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'embryo_unet_fusion.pth')
            print(f"  [*] Model saved at epoch {epoch+1}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= max_patience:
                print("Early stopping due to no improvement in validation loss.")
                break

# Test Function with Enhanced Post-Processing
def test_single_embryo(model, image_paths, transform, device='cuda'):
    model.eval()
    focal_tensors = []
    for path in image_paths:
        img = Image.open(path).convert('L')
        img_tensor = transform(img)
        focal_tensors.append(img_tensor)
    
    input_tensor = torch.cat(focal_tensors, dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    output_image = output.squeeze(0).cpu()
    fused_pil = transforms.ToPILImage()(output_image)
    
    # Refined unsharp masking
    fused_pil = fused_pil.filter(ImageFilter.UnsharpMask(radius=3, percent=300, threshold=1))
    fused_pil = TF.adjust_contrast(fused_pil, contrast_factor=1.5)
    return fused_pil

# Main Execution
def main():
    base_paths = [
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F45",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45"
    ]
    
    embryo_ids = get_common_embryo_ids(base_paths)
    print(f"Found {len(embryo_ids)} embryo IDs: {embryo_ids[:5]} ...")
    
    random.seed(42)
    train_size = int(0.8 * len(embryo_ids))
    train_indices = random.sample(range(len(embryo_ids)), train_size)
    val_indices = [i for i in range(len(embryo_ids)) if i not in train_indices]
    embryo_ids_train = [embryo_ids[i] for i in train_indices]
    embryo_ids_val = [embryo_ids[i] for i in val_indices]
    
    train_dataset = EmbryoFocusStackDataset(base_paths, embryo_ids_train, is_train=True)
    val_dataset = EmbryoFocusStackDataset(base_paths, embryo_ids_val, is_train=False)
    
    train_loader = DataLoader(train_dataset, batch_size=16, num_workers=0, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, num_workers=0, shuffle=False)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    model = UNet(in_channels=6, out_channels=1).to(device)
    
    print("Starting training...")
    train_model(model, train_loader, val_loader, num_epochs=10, device=device)
    print("Training complete. Best model saved as 'embryo_unet_fusion.pth'.")

    test_image_paths = [
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F15\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F45\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F30\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30\AB91-1\D2013.01.29_S0719_I132_WELL1_RUN169.jpeg"
    ]
    
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    
    model.load_state_dict(torch.load('embryo_unet_fusion.pth', map_location=device))
    print("Loaded trained model weights for testing.")
    fused_image = test_single_embryo(model, test_image_paths, transform, device)
    fused_image.save("fused_output.jpg")
    fused_image.show()

if __name__ == "__main__":
    main()

Found 704 embryo IDs: ['AA83-7', 'AAL839-6', 'AB028-6', 'AB91-1', 'AC264-1'] ...
Using device: cuda
Starting training...
Epoch [1/10] Train Loss: 0.0875, Val Loss: 0.0384
  [*] Model saved at epoch 1
Epoch [2/10] Train Loss: 0.0372, Val Loss: 0.0324
  [*] Model saved at epoch 2
Epoch [3/10] Train Loss: 0.0342, Val Loss: 0.0189
  [*] Model saved at epoch 3
Epoch [4/10] Train Loss: 0.0242, Val Loss: 0.0150
  [*] Model saved at epoch 4
Epoch [5/10] Train Loss: 0.0254, Val Loss: 0.0192
Epoch [6/10] Train Loss: 0.0225, Val Loss: 0.0182
Epoch [7/10] Train Loss: 0.0240, Val Loss: 0.0111
  [*] Model saved at epoch 7
Epoch [8/10] Train Loss: 0.0230, Val Loss: 0.0224
Epoch [9/10] Train Loss: 0.0221, Val Loss: 0.0138
Epoch [10/10] Train Loss: 0.0170, Val Loss: 0.0088
  [*] Model saved at epoch 10
Training complete. Best model saved as 'embryo_unet_fusion.pth'.
Loaded trained model weights for testing.
