In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image, ImageFile, ImageFilter
import numpy as np
import random
from torchvision.transforms import functional as TF
import torch.nn.functional as F
from torchmetrics.image import StructuralSimilarityIndexMeasure  # For SSIM metric
import lpips  # For predefined perceptual loss
from pytorch_msssim import SSIM  # For SSIM loss

In [5]:
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Utility function to find common embryo IDs across directories
def get_common_embryo_ids(base_paths):
    sets_of_ids = []
    for path in base_paths:
        subfolders = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
        sets_of_ids.append(set(subfolders))
    common_ids = set.intersection(*sets_of_ids)
    return sorted(list(common_ids))

In [6]:
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

In [7]:
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        def CBR(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )

        self.enc1 = CBR(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = CBR(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = CBR(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = CBR(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.dec3 = CBR(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.dec2 = CBR(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.dec1 = CBR(128, 64)

        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))

        b = self.bottleneck(self.pool3(e3))

        d3 = self.up3(b)
        e3 = self.att3(d3, e3)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self.up2(d3)
        e2 = self.att2(d2, e2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self.up1(d2)
        e1 = self.att1(d1, e1)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.final(d1)

In [8]:
class EmbryoFocusStackDataset(Dataset):
    def __init__(self, base_paths, f0_path, embryo_ids, is_train=False):
        self.base_paths = base_paths
        self.f0_path = f0_path
        self.embryo_ids = embryo_ids
        self.is_train = is_train
        self.preprocess = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.embryo_ids)

    def __getitem__(self, idx):
        embryo_id = self.embryo_ids[idx]
        focal_pil_images = []

        for path in self.base_paths:
            embryo_subfolder = os.path.join(path, embryo_id)
            image_files = sorted([f for f in os.listdir(embryo_subfolder) 
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            if not image_files:
                raise FileNotFoundError(f"No image found in {embryo_subfolder}")
            img_path = os.path.join(embryo_subfolder, image_files[0])
            try:
                image = Image.open(img_path).convert('L')
                focal_pil_images.append(image)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                focal_pil_images.append(Image.new('L', (256, 256), 0))

        f0_subfolder = os.path.join(self.f0_path, embryo_id)
        f0_image_files = sorted([f for f in os.listdir(f0_subfolder) 
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        if not f0_image_files:
            raise FileNotFoundError(f"No F0 image found in {f0_subfolder}")
        f0_img_path = os.path.join(f0_subfolder, f0_image_files[0])
        try:
            f0_image = Image.open(f0_img_path).convert('L')
        except Exception as e:
            print(f"Error loading F0 image {f0_img_path}: {e}")
            f0_image = Image.new('L', (256, 256), 0)

        if self.is_train:
            angle = random.uniform(-30, 30)
            flip_h = random.random() < 0.5
            flip_v = random.random() < 0.5
        else:
            angle = 0
            flip_h = False
            flip_v = False

        transformed_images = []
        for img in focal_pil_images:
            img = TF.rotate(img, angle)
            if flip_h:
                img = TF.hflip(img)
            if flip_v:
                img = TF.vflip(img)
            img = self.preprocess(img)
            transformed_images.append(img)

        f0_image = TF.rotate(f0_image, angle)
        if flip_h:
            f0_image = TF.hflip(f0_image)
        if flip_v:
            f0_image = TF.vflip(f0_image)
        target_tensor = self.preprocess(f0_image)

        focal_tensors = transformed_images
        input_tensor = torch.cat(focal_tensors, dim=0)

        return input_tensor, target_tensor

In [9]:
l1_loss_fn = nn.L1Loss()  # Predefined L1 Loss
perceptual_loss_fn = lpips.LPIPS(net='vgg')  # Predefined Perceptual Loss using LPIPS
ssim_loss_fn = SSIM(data_range=1.0, size_average=True, channel=1)  # Predefined SSIM Loss
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0)  # SSIM metric for evaluation

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: c:\Projects\Embryo\embryo_env\Lib\site-packages\lpips\weights\v0.1\vgg.pth


In [10]:
def train_model(model, train_loader, val_loader, num_epochs=100, device='cuda'):
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
    best_val_loss = float('inf')
    patience_counter = 0
    max_patience = 20

    # Move loss functions to device
    perceptual_loss_fn.to(device)
    ssim_loss_fn.to(device)
    ssim_metric.to(device)

    # Lists to store SSIM scores for plotting
    train_ssim_scores = []
    val_ssim_scores = []

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        running_train_ssim = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            # Compute losses
            l1_loss = l1_loss_fn(outputs, targets)
            perceptual_loss = perceptual_loss_fn(outputs, targets).mean()  # LPIPS returns a tensor
            ssim_loss = 1 - ssim_loss_fn(outputs, targets)  # SSIM loss (1 - SSIM to convert to loss)
            loss = 0.3 * l1_loss + 0.3 * perceptual_loss + 0.4 * ssim_loss  # Weighted combination

            loss.backward()
            optimizer.step()
            running_train_loss += loss.item() * inputs.size(0)

            # Compute SSIM metric
            ssim_value = ssim_metric(outputs, targets)
            running_train_ssim += ssim_value.item() * inputs.size(0)

        train_loss = running_train_loss / len(train_loader.dataset)
        train_ssim = running_train_ssim / len(train_loader.dataset)
        train_ssim_scores.append(train_ssim)

        model.eval()
        running_val_loss = 0.0
        running_val_ssim = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)

                # Compute losses
                l1_loss = l1_loss_fn(outputs, targets)
                perceptual_loss = perceptual_loss_fn(outputs, targets).mean()
                ssim_loss = 1 - ssim_loss_fn(outputs, targets)
                loss = 0.3 * l1_loss + 0.3 * perceptual_loss + 0.4 * ssim_loss
                running_val_loss += loss.item() * inputs.size(0)

                # Compute SSIM metric
                ssim_value = ssim_metric(outputs, targets)
                running_val_ssim += ssim_value.item() * inputs.size(0)

        val_loss = running_val_loss / len(val_loader.dataset)
        val_ssim = running_val_ssim / len(val_loader.dataset)
        val_ssim_scores.append(val_ssim)

        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train SSIM: {train_ssim:.4f}, Val SSIM: {val_ssim:.4f}")
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'embryo_unet_fusion.pth')
            print(f"  [*] Model saved at epoch {epoch+1}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= max_patience:
                print("Early stopping due to no improvement in validation loss.")
                break

    return train_ssim_scores, val_ssim_scores

In [23]:
def test_single_embryo(model, image_paths, transform, device='cuda'):
    model.eval()
    focal_tensors = []
    for path in image_paths:
        img = Image.open(path).convert('L')
        img_tensor = transform(img)
        focal_tensors.append(img_tensor)
    
    input_tensor = torch.cat(focal_tensors, dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    output_image = output.squeeze(0).cpu()
    fused_pil = transforms.ToPILImage()(output_image)
    
    fused_pil = fused_pil.filter(ImageFilter.UnsharpMask(radius=2, percent=200, threshold=3))
    fused_pil = TF.adjust_contrast(fused_pil, contrast_factor=1.5)
    return fused_pil


In [14]:
def main():
    base_paths = [
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F45",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45"
    ]
    f0_path = r"C:\Projects\Embryo\Dataset\embryo_dataset"
    
    embryo_ids = get_common_embryo_ids(base_paths + [f0_path])
    print(f"Found {len(embryo_ids)} embryo IDs: {embryo_ids[:5]} ...")
    
    train_ratio = 0.8
    train_size = int(train_ratio * len(embryo_ids))
    train_indices = random.sample(range(len(embryo_ids)), train_size)
    val_indices = [i for i in range(len(embryo_ids)) if i not in train_indices]
    embryo_ids_train = [embryo_ids[i] for i in train_indices]
    embryo_ids_val = [embryo_ids[i] for i in val_indices]
    
    train_dataset = EmbryoFocusStackDataset(base_paths, f0_path, embryo_ids_train, is_train=True)
    val_dataset = EmbryoFocusStackDataset(base_paths, f0_path, embryo_ids_val, is_train=False)
    
    train_loader = DataLoader(train_dataset, batch_size=8, num_workers=0, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, num_workers=0, shuffle=False)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    model = UNet(in_channels=6, out_channels=1).to(device)
    
    print("Starting training...")
    train_ssim_scores, val_ssim_scores = train_model(model, train_loader, val_loader, num_epochs=100, device=device)
    print("Training complete. Best model saved as 'embryo_unet_fusion.pth'.")

    # Print SSIM scores for plotting
    print("\nSSIM Scores for Plotting:")
    print("Train SSIM Scores:", train_ssim_scores)
    print("Validation SSIM Scores:", val_ssim_scores)

    

In [15]:
if __name__ == "__main__":
    main()

Found 704 embryo IDs: ['AA83-7', 'AAL839-6', 'AB028-6', 'AB91-1', 'AC264-1'] ...
Using device: cuda
Starting training...
Epoch [1/100] Train Loss: 0.3514, Val Loss: 0.2197, Train SSIM: 0.4308, Val SSIM: 0.6166
  [*] Model saved at epoch 1
Epoch [2/100] Train Loss: 0.2058, Val Loss: 0.2210, Train SSIM: 0.6260, Val SSIM: 0.6283
Epoch [3/100] Train Loss: 0.1709, Val Loss: 0.1619, Train SSIM: 0.6827, Val SSIM: 0.7007
  [*] Model saved at epoch 3
Epoch [4/100] Train Loss: 0.1596, Val Loss: 0.1567, Train SSIM: 0.6978, Val SSIM: 0.7004
  [*] Model saved at epoch 4
Epoch [5/100] Train Loss: 0.1527, Val Loss: 0.1059, Train SSIM: 0.7074, Val SSIM: 0.8305
  [*] Model saved at epoch 5
Epoch [6/100] Train Loss: 0.1159, Val Loss: 0.0953, Train SSIM: 0.8064, Val SSIM: 0.8520
  [*] Model saved at epoch 6
Epoch [7/100] Train Loss: 0.1258, Val Loss: 0.1086, Train SSIM: 0.7744, Val SSIM: 0.8124
Epoch [8/100] Train Loss: 0.1171, Val Loss: 0.1291, Train SSIM: 0.7977, Val SSIM: 0.7579
Epoch [9/100] Train Lo

In [25]:
test_image_paths = [
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F15\AM716-7\D2013.07.01_S0867_I132_WELL7_RUN209.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15\AM716-7\D2013.07.01_S0867_I132_WELL7_RUN209.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F30\AM716-7\D2013.07.01_S0867_I132_WELL7_RUN209.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30\AM716-7\D2013.07.01_S0867_I132_WELL7_RUN209.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F45\AM716-7\D2013.07.01_S0867_I132_WELL7_RUN209.jpeg",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45\AM716-7\D2013.07.01_S0867_I132_WELL7_RUN209.jpeg",
]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=6, out_channels=1).to(device)
model.load_state_dict(torch.load('embryo_unet_fusion.pth', map_location=device))
print("Loaded trained model weights for testing.")
fused_image = test_single_embryo(model, test_image_paths, transform, device)
fused_image.save("fused_output.jpg")
fused_image.show()

Loaded trained model weights for testing.
