In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image, ImageFile, ImageFilter
import numpy as np
import random
from torchvision.transforms import functional as TF
import torch.nn.functional as F
from torchmetrics.image import StructuralSimilarityIndexMeasure
import lpips
from pytorch_msssim import SSIM
import matplotlib.pyplot as plt
import logging
from torch.cuda.amp import autocast, GradScaler

# Tell Pillow to decode image files whose data is truncated instead of
# raising an error while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Configure the root logger to write records at INFO level and above to
# 'training_image_log.txt', emitting only the message text of each record.
# NOTE(review): no logging calls appear in the visible code, which reports
# progress with print(); confirm whether this file log is still wanted.
logging.basicConfig(filename='training_image_log.txt', level=logging.INFO, format='%(message)s')

# Utility function to find common embryo IDs across directories
def get_common_embryo_ids(base_paths):
    """Return the sorted subfolder names that exist under every base path.

    Each immediate subdirectory of a base path is treated as one embryo ID;
    plain files are ignored.

    Args:
        base_paths: Iterable of directory paths to compare.

    Returns:
        Sorted list of subfolder names present in all of ``base_paths``.
        An empty list when ``base_paths`` is empty.

    Raises:
        OSError: If a base path cannot be listed (e.g. it does not exist).
    """
    common_ids = None
    for path in base_paths:
        subfolders = {
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        }
        # Intersect incrementally; the first directory seeds the result.
        common_ids = subfolders if common_ids is None else common_ids & subfolders

    # With no base paths there is nothing to intersect. The previous
    # ``set.intersection(*[])`` call raised TypeError in that case.
    if common_ids is None:
        return []
    return sorted(common_ids)

# Utility function to get frames for an embryo
def get_common_frames_for_embryo(embryo_id, base_paths, f0_path):
    """List the image frames of one embryo that exist in every focal directory.

    The F0 subfolder ``f0_path/embryo_id`` is the reference: its ``.jpg``,
    ``.jpeg`` and ``.png`` files (extension matched case-insensitively) are
    taken in sorted order, and a frame is kept only if a file with the same
    name exists under ``path/embryo_id`` for every ``path`` in ``base_paths``.

    Args:
        embryo_id: Name of the embryo subfolder.
        base_paths: Directories holding the other focal planes.
        f0_path: Directory holding the reference (F0) images.

    Returns:
        Sorted list of frame file names shared by all directories; empty when
        the F0 subfolder contains no image files.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png')
    reference_dir = os.path.join(f0_path, embryo_id)
    reference_frames = sorted(
        name for name in os.listdir(reference_dir)
        if name.lower().endswith(image_suffixes)
    )

    return [
        name for name in reference_frames
        if all(
            os.path.exists(os.path.join(root, embryo_id, name))
            for root in base_paths
        )
    ]

# Attention Block for U-Net
class AttentionBlock(nn.Module):
    """Gate that rescales a feature map ``x`` using a second feature map ``g``.

    Both inputs are projected to ``F_int`` channels with a 1x1 convolution
    followed by batch norm, summed, passed through ReLU, and reduced to a
    single-channel sigmoid mask that multiplies ``x``.

    Args:
        F_g: Channel count of ``g``.
        F_l: Channel count of ``x``.
        F_int: Channel count of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch, *tail):
            # 1x1 convolution (stride 1, no padding, with bias) + batch norm,
            # optionally followed by extra layers.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1),
                nn.BatchNorm2d(out_ch),
                *tail,
            )

        self.W_g = pointwise(F_g, F_int)
        self.W_x = pointwise(F_l, F_int)
        self.psi = pointwise(F_int, 1, nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the mask computed from ``g`` and ``x``.

        The projections of ``g`` and ``x`` are added element-wise, so the two
        inputs must have matching batch and spatial dimensions.
        """
        joint = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(joint)

# U-Net with Attention Gates
class UNet(nn.Module):
    """Three-level U-Net whose skip connections pass through attention gates.

    The encoder uses 64/128/256 channels with a 512-channel bottleneck. Each
    decoder level upsamples by 2 with a transposed convolution, gates the
    matching encoder features with an ``AttentionBlock``, concatenates both
    and applies a double 3x3 convolution. A final 1x1 convolution maps to
    ``out_channels``; no output activation is applied.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder path.
        self.enc1 = self._double_conv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = self._double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = self._double_conv(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = self._double_conv(256, 512)

        # Decoder path: upsample, gate the skip connection, fuse.
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.dec3 = self._double_conv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.dec2 = self._double_conv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.dec1 = self._double_conv(128, 64)

        self.final = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Build two 3x3 conv -> batch norm -> ReLU stages (size-preserving)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the network on ``x`` and return the ``out_channels`` map.

        The input is pooled by 2 three times and upsampled by 2 three times,
        so height and width are expected to be multiples of 8 for the
        upsampled maps to line up with the skip connections.
        """
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))

        features = self.bottleneck(self.pool3(skip3))

        # Deepest level first; each entry pairs the decoder modules with the
        # encoder features of the same resolution.
        decoder_levels = (
            (self.up3, self.att3, self.dec3, skip3),
            (self.up2, self.att2, self.dec2, skip2),
            (self.up1, self.att1, self.dec1, skip1),
        )
        for upsample, gate, fuse, skip in decoder_levels:
            features = upsample(features)
            gated_skip = gate(features, skip)
            features = fuse(torch.cat([features, gated_skip], dim=1))

        return self.final(features)

# Dataset for Embryo Image Fusion (Handles all frames with sampling)
class EmbryoFocusStackDataset(Dataset):
    """Dataset of (focal stack, F0 target) pairs, one sample per embryo frame.

    For every embryo ID the frames shared by all focal directories are
    collected; when an embryo has more than ``max_frames_per_embryo`` frames
    a random subset of that size is kept.

    Args:
        base_paths: Directories of the input focal planes; one grayscale
            channel is produced per directory, in this order.
        f0_path: Directory of the target (F0) images.
        embryo_ids: Embryo subfolder names to include.
        is_train: When True, a random rotation in [-30, 30] degrees and
            random horizontal/vertical flips are applied, identically to
            the inputs and the target of a sample.
        max_frames_per_embryo: Upper bound on frames kept per embryo.
    """

    def __init__(self, base_paths, f0_path, embryo_ids, is_train=False, max_frames_per_embryo=10):
        self.base_paths = base_paths
        self.f0_path = f0_path
        self.is_train = is_train
        self.max_frames_per_embryo = max_frames_per_embryo
        # Every image is resized to 128x128 and converted to a tensor.
        self.preprocess = transforms.Compose([
            transforms.Resize((128, 128)),  # Reduced image size
            transforms.ToTensor(),
        ])

        self.data_pairs = []
        for embryo_id in embryo_ids:
            frames = get_common_frames_for_embryo(embryo_id, base_paths, f0_path)
            if len(frames) > self.max_frames_per_embryo:
                frames = random.sample(frames, self.max_frames_per_embryo)
            self.data_pairs.extend((embryo_id, frame) for frame in frames)

        print(f"Total samples (embryo-frame pairs): {len(self.data_pairs)}")

    def __len__(self):
        """Return the number of (embryo, frame) samples."""
        return len(self.data_pairs)

    @staticmethod
    def _open_gray(img_path, is_f0):
        """Open ``img_path`` as a grayscale PIL image.

        On any loading error the problem is printed and a black 128x128
        image is returned instead.
        """
        try:
            return Image.open(img_path).convert('L')
        except Exception as e:
            if is_f0:
                print(f"Error loading F0 image {img_path}: {e}")
            else:
                print(f"Error loading {img_path}: {e}")
            return Image.new('L', (128, 128), 0)

    def __getitem__(self, idx):
        """Return ``(input_tensor, target_tensor, (embryo_id, frame))``.

        ``input_tensor`` concatenates one preprocessed grayscale image per
        entry of ``base_paths`` along dim 0; ``target_tensor`` is the
        preprocessed F0 image of the same frame.
        """
        embryo_id, frame = self.data_pairs[idx]

        focal_images = [
            self._open_gray(os.path.join(root, embryo_id, frame), False)
            for root in self.base_paths
        ]
        reference = self._open_gray(
            os.path.join(self.f0_path, embryo_id, frame), True
        )

        # One set of augmentation parameters per sample, so the inputs and
        # the target stay geometrically aligned.
        if self.is_train:
            angle = random.uniform(-30, 30)
            flip_h = random.random() < 0.5
            flip_v = random.random() < 0.5
        else:
            angle, flip_h, flip_v = 0, False, False

        def augment_and_preprocess(image):
            image = TF.rotate(image, angle)
            if flip_h:
                image = TF.hflip(image)
            if flip_v:
                image = TF.vflip(image)
            return self.preprocess(image)

        focal_tensors = [augment_and_preprocess(image) for image in focal_images]
        target_tensor = augment_and_preprocess(reference)
        input_tensor = torch.cat(focal_tensors, dim=0)

        return input_tensor, target_tensor, (embryo_id, frame)

# Custom collate function
def custom_collate(batch):
    """Collate dataset samples into batched tensors plus a list of IDs.

    Args:
        batch: Sequence of samples, each indexable as
            ``(input_tensor, target_tensor, embryo_frame_pair)``.

    Returns:
        Tuple ``(inputs, targets, embryo_frame_pairs)`` where the tensors of
        all samples are stacked along a new leading dimension and the pairs
        are returned as a plain list in sample order.
    """
    def column(position):
        return [sample[position] for sample in batch]

    inputs = torch.stack(column(0))
    targets = torch.stack(column(1))
    return inputs, targets, column(2)

# Loss Functions
# Mean absolute error between prediction and target.
l1_loss_fn = nn.L1Loss()
# SSIM module from pytorch_msssim, configured for single-channel images with
# values in [0, 1]; train_model uses ``1 - ssim_loss_fn(...)`` as a loss term.
ssim_loss_fn = SSIM(data_range=1.0, size_average=True, channel=1)
# SSIM metric from torchmetrics, used only for reporting the per-epoch
# training/validation SSIM scores.
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0)

# Training Function with Optimizations
def train_model(model, train_loader, val_loader, base_paths, f0_path, num_epochs=1, device='cuda'):
    """Train ``model`` with an equal-weight L1 + (1 - SSIM) loss.

    Uses Adam (lr 1e-4), ``ReduceLROnPlateau`` on the validation loss, and
    early stopping after 20 epochs without improvement. Whenever the
    validation loss improves, the weights are written to
    ``'embryo_unet_fusion.pth'``.

    Args:
        model: Network to optimise; it must already be on ``device``.
        train_loader: Loader yielding ``(inputs, targets, ids)`` batches.
        val_loader: Loader with the same batch layout, used for validation.
        base_paths: Unused; kept so existing callers keep working.
        f0_path: Unused; kept so existing callers keep working.
        num_epochs: Maximum number of epochs.
        device: Device (string or ``torch.device``) the batches are moved to.

    Returns:
        Tuple ``(train_ssim_scores, val_ssim_scores)`` with one sample-weighted
        mean SSIM per completed epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
    best_val_loss = float('inf')
    patience_counter = 0
    max_patience = 20

    # Mixed precision only makes sense on CUDA. On CPU the CUDA autocast /
    # GradScaler pair is explicitly disabled instead of relying on the
    # library to warn and fall back.
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    ssim_loss_fn.to(device)
    ssim_metric.to(device)

    train_ssim_scores = []
    val_ssim_scores = []

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        running_train_ssim = 0.0
        total_batches = len(train_loader)
        for batch_idx, (inputs, targets, _) in enumerate(train_loader):
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}/{total_batches}")
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)
            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                outputs = model(inputs)
            # The loss is evaluated outside autocast, in float32 (see helper).
            loss = _fusion_loss(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_train_loss += loss.item() * batch_size

            # Detach before computing the reporting metric so it neither
            # builds nor retains an autograd graph.
            ssim_value = ssim_metric(outputs.detach().float(), targets)
            running_train_ssim += ssim_value.item() * batch_size

        train_loss = running_train_loss / len(train_loader.dataset)
        train_ssim = running_train_ssim / len(train_loader.dataset)
        train_ssim_scores.append(train_ssim)

        model.eval()
        running_val_loss = 0.0
        running_val_ssim = 0.0
        with torch.no_grad():
            for inputs, targets, _ in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                batch_size = inputs.size(0)
                with autocast(enabled=use_amp):
                    outputs = model(inputs)
                loss = _fusion_loss(outputs, targets)
                running_val_loss += loss.item() * batch_size

                ssim_value = ssim_metric(outputs.float(), targets)
                running_val_ssim += ssim_value.item() * batch_size

        val_loss = running_val_loss / len(val_loader.dataset)
        val_ssim = running_val_ssim / len(val_loader.dataset)
        val_ssim_scores.append(val_ssim)

        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train SSIM: {train_ssim:.4f}, Val SSIM: {val_ssim:.4f}")
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'embryo_unet_fusion.pth')
            print(f"  [*] Model saved at epoch {epoch+1}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= max_patience:
                print("Early stopping due to no improvement in validation loss.")
                break

    return train_ssim_scores, val_ssim_scores


def _fusion_loss(outputs, targets):
    """Return ``0.5 * L1 + 0.5 * (1 - SSIM)`` computed in float32.

    Must be called outside an autocast region: the SSIM windowing is built on
    convolutions, which autocast would run in half precision, where the small
    SSIM stabilisation constants lose accuracy.
    """
    outputs = outputs.float()
    targets = targets.float()
    l1_loss = l1_loss_fn(outputs, targets)
    ssim_loss = 1 - ssim_loss_fn(outputs, targets)
    return 0.5 * l1_loss + 0.5 * ssim_loss

# Test Function with Visualization
def test_single_embryo(model, image_paths, f0_path, transform, device='cuda'):
    """Fuse one focal stack with ``model`` and save a side-by-side comparison.

    The images in ``image_paths`` are loaded as grayscale, transformed and
    concatenated along the channel dimension in the given order, so that
    order must match the channel order the model was trained with. The
    F0 target is looked up under ``f0_path`` using the embryo folder name
    and file name of ``image_paths[0]``.

    Args:
        model: Trained network, already on ``device``.
        image_paths: Paths of the focal-plane images of one frame.
        f0_path: Directory of the target (F0) images.
        transform: Callable turning a PIL image into a ``(1, H, W)`` tensor.
        device: Device the input tensor is moved to.

    Returns:
        The model output as a PIL image after unsharp masking and a contrast
        boost. A three-panel figure (first input, target, output) is written
        to ``comparison.png``.
    """
    model.eval()
    to_pil = transforms.ToPILImage()

    focal_tensors = []
    for path in image_paths:
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(path) as img:
            focal_tensors.append(transform(img.convert('L')))

    input_tensor = torch.cat(focal_tensors, dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)

    # The network has no output activation, so values can fall outside
    # [0, 1]. ToPILImage scales floats by 255 and casts to uint8, which
    # corrupts out-of-range pixels; clamp first.
    output_image = output.squeeze(0).float().clamp(0.0, 1.0).cpu()
    fused_pil = to_pil(output_image)

    fused_pil = fused_pil.filter(ImageFilter.UnsharpMask(radius=2, percent=200, threshold=3))
    fused_pil = TF.adjust_contrast(fused_pil, contrast_factor=1.5)

    # Show exactly the first channel fed to the model instead of re-reading
    # and re-transforming the file.
    input_pil = to_pil(focal_tensors[0].squeeze(0).cpu())

    first_path = image_paths[0]
    embryo_id = os.path.basename(os.path.dirname(first_path))
    frame = os.path.basename(first_path)
    f0_img_path = os.path.join(f0_path, embryo_id, frame)
    with Image.open(f0_img_path) as f0_img:
        target_tensor = transform(f0_img.convert('L'))
    target_pil = to_pil(target_tensor.squeeze(0).cpu())

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (input_pil, "Input (F15)"),
        (target_pil, "Target (F0)"),
        (fused_pil, "Model Output"),
    )
    for axis, (image, title) in zip(axes, panels):
        axis.imshow(image, cmap='gray')
        axis.set_title(title)
        axis.axis('off')
    plt.savefig("comparison.png")
    plt.close(fig)

    return fused_pil

# Main Execution
def main():
    """Train the fusion U-Net on the embryo dataset and run one test fusion.

    Splits the embryo IDs 80/20 into training and validation sets, trains
    for one epoch, reloads the best saved weights, fuses a fixed test frame
    and writes ``fused_output.jpg`` and ``comparison.png``.
    """
    # The order of this list defines the channel order of the model input.
    base_paths = [
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F45",
        r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45"
    ]
    f0_path = r"C:\Projects\Embryo\Dataset\embryo_dataset"

    embryo_ids = get_common_embryo_ids(base_paths + [f0_path])
    print(f"Found {len(embryo_ids)} embryo IDs: {embryo_ids[:5]} ...")

    # Split by embryo so frames of one embryo never land in both sets.
    train_ratio = 0.8
    train_size = int(train_ratio * len(embryo_ids))
    train_indices = random.sample(range(len(embryo_ids)), train_size)
    train_index_set = set(train_indices)  # O(1) membership for the split
    embryo_ids_train = [embryo_ids[i] for i in train_indices]
    embryo_ids_val = [
        embryo_id for i, embryo_id in enumerate(embryo_ids)
        if i not in train_index_set
    ]

    train_dataset = EmbryoFocusStackDataset(base_paths, f0_path, embryo_ids_train, is_train=True, max_frames_per_embryo=10)
    val_dataset = EmbryoFocusStackDataset(base_paths, f0_path, embryo_ids_val, is_train=False, max_frames_per_embryo=10)

    train_loader = DataLoader(train_dataset, batch_size=8, num_workers=4, shuffle=True, collate_fn=custom_collate)
    val_loader = DataLoader(val_dataset, batch_size=4, num_workers=4, shuffle=False, collate_fn=custom_collate)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = UNet(in_channels=len(base_paths), out_channels=1).to(device)

    print("Starting training...")
    train_ssim_scores, val_ssim_scores = train_model(model, train_loader, val_loader, base_paths, f0_path, num_epochs=1, device=device)
    print("Training complete. Best model saved as 'embryo_unet_fusion.pth'.")

    print("\nSSIM Scores for Plotting:")
    print("Train SSIM Scores:", train_ssim_scores)
    print("Validation SSIM Scores:", val_ssim_scores)

    # Build the test stack from base_paths so the channel order at inference
    # is the one used for training. The previous hand-written list swapped
    # the F30 and F45 pairs, feeding the model misordered channels.
    test_embryo_id = "AM716-7"
    test_frame = "D2013.07.01_S0867_I132_WELL7_RUN209.jpeg"
    test_image_paths = [
        os.path.join(path, test_embryo_id, test_frame) for path in base_paths
    ]

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
    ])

    model.load_state_dict(torch.load('embryo_unet_fusion.pth', map_location=device))
    print("Loaded trained model weights for testing.")
    fused_image = test_single_embryo(model, test_image_paths, f0_path, transform, device)
    fused_image.save("fused_output.jpg")
    fused_image.show()
    print("Comparison plot saved as 'comparison.png'.")

# Run the training/test pipeline only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()

Found 704 embryo IDs: ['AA83-7', 'AAL839-6', 'AB028-6', 'AB91-1', 'AC264-1'] ...
Total samples (embryo-frame pairs): 5630
Total samples (embryo-frame pairs): 1410
Using device: cuda
Starting training...


  scaler = GradScaler()
