# Jupyter cell In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from skimage import exposure
from skimage.util import img_as_ubyte

# === Apply Adaptive Histogram Equalization ===
def apply_ahe(image):
    """Return a contrast-enhanced copy of ``image`` as an 8-bit PIL image.

    Uses scikit-image's ``equalize_adapthist`` (contrast-limited adaptive
    histogram equalization) with ``clip_limit=0.03``.

    Args:
        image: a PIL image, or anything ``equalize_adapthist`` accepts directly
            (e.g. a NumPy array). PIL images are converted to arrays first.

    Returns:
        A new ``PIL.Image.Image`` built from the equalized result rescaled to uint8.
    """
    # PIL images must become arrays before scikit-image can process them.
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    equalized = exposure.equalize_adapthist(pixels, clip_limit=0.03)
    # equalize_adapthist yields floats; img_as_ubyte rescales them to 0..255 uint8.
    return Image.fromarray(img_as_ubyte(equalized))

# === Simple U-Net ===
class SimpleUNet(nn.Module):
    """Small fully-convolutional encoder/decoder network.

    Despite the name there are no skip connections and no down/up-sampling:
    every convolution is 3x3 with padding 1, so the output has the same height
    and width as the input. A final sigmoid squashes outputs into [0, 1].

    The attribute names ``encoder``/``decoder`` and the layer order determine
    the ``state_dict`` keys (``encoder.0.weight`` ...), so keep them stable or
    previously saved checkpoints will no longer load.

    Args:
        in_channels: number of stacked input channels (default 6).
        out_channels: number of predicted channels (default 1).
    """

    def __init__(self, in_channels=6, out_channels=1):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, out_channels, 3, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W)."""
        features = self.encoder(x)
        return self.decoder(features)

# === Custom Dataset: one grayscale channel per input folder ===
class MyDataset(Dataset):
    """Stack aligned grayscale images from several folders into one tensor.

    For every sample, the file with the same *relative path* is read from each
    folder in ``input_dirs``. Each file is:

    1. converted to grayscale ("L");
    2. contrast-enhanced with ``apply_ahe``;
    3. turned into a ``(1, H, W)`` tensor.

    The tensors are concatenated into a ``(len(input_dirs), H, W)`` input.
    The target is the file at the same relative path in ``label_dirs[0]``,
    converted to grayscale without AHE.

    Folders may contain sub-folders (e.g. one folder per embryo). Image files
    are discovered recursively and directories themselves are never opened.
    Opening a directory with ``Image.open`` is what produced the Windows
    ``PermissionError``. Only relative paths present in every input folder
    *and* in ``label_dirs[0]`` are kept, so all channels of a sample line up.

    Args:
        input_dirs: input folders; each contributes one channel.
        label_dirs: label folders; only ``label_dirs[0]`` is read.
        transform: stored on the instance but not applied by ``__getitem__``.

    Raises:
        FileNotFoundError: if a required folder does not exist.
        ValueError: if no image file is present in all required folders.
    """

    # Filenames are matched case-insensitively against these suffixes.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")

    def __init__(self, input_dirs, label_dirs, transform=None):
        self.input_dirs = input_dirs
        self.label_dirs = label_dirs
        # NOTE(review): kept for API compatibility; __getitem__ does not use it.
        self.transform = transform

        # __getitem__ reads from every input folder and from label_dirs[0] only.
        required_dirs = [*input_dirs, label_dirs[0]]
        common = set.intersection(*(set(self._list_images(d)) for d in required_dirs))
        self.image_names = sorted(common)
        if not self.image_names:
            raise ValueError(
                f"No image files with matching relative paths found in all of: {required_dirs}"
            )

    @classmethod
    def _list_images(cls, root):
        """Return paths, relative to ``root``, of all image files below ``root``."""
        if not os.path.isdir(root):
            raise FileNotFoundError(f"Image folder not found: {root}")
        found = []
        for dirpath, _dirnames, filenames in os.walk(root):
            for filename in filenames:
                if filename.lower().endswith(cls.IMAGE_EXTENSIONS):
                    found.append(os.path.relpath(os.path.join(dirpath, filename), root))
        return found

    @staticmethod
    def _load_gray(path):
        """Open ``path`` and return a grayscale copy; the file handle is closed."""
        with Image.open(path) as img:
            return img.convert("L")

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        rel_path = self.image_names[idx]
        to_tensor = transforms.ToTensor()

        channels = [
            to_tensor(apply_ahe(self._load_gray(os.path.join(folder, rel_path))))
            for folder in self.input_dirs
        ]
        # One (1, H, W) tensor per folder -> (len(input_dirs), H, W).
        input_tensor = torch.cat(channels, dim=0)

        label_img = self._load_gray(os.path.join(self.label_dirs[0], rel_path))
        label_tensor = to_tensor(label_img)

        return input_tensor, label_tensor

# === Train and Save Model ===
def train_and_save_model(epochs=5, batch_size=4, lr=1e-3, model_path="trained_model.pth"):
    """Train a SimpleUNet on the hard-coded dataset folders and save its weights.

    Inputs are the stacked images from ``input_dirs``. The regression target
    (MSE loss) is the matching image in ``label_dirs[0]``.

    Args:
        epochs: number of passes over the dataset (default 5).
        batch_size: DataLoader batch size (default 4).
        lr: Adam learning rate (default 1e-3).
        model_path: where the ``state_dict`` is written (default "trained_model.pth").
    """
    input_dirs = [
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F15\embryo_dataset_F15",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F30\embryo_dataset_F30",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F45\embryo_dataset_F45"
    ]
    # NOTE(review): MyDataset only reads label_dirs[0]; the other two are unused.
    label_dirs = [
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F-15\embryo_dataset_F-15",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F-30\embryo_dataset_F-30",
        r"E:\EmbryoAnalysis\EmbryoAnalysis\Dataset\embryo_dataset_F-45\embryo_dataset_F-45"
    ]

    dataset = MyDataset(input_dirs, label_dirs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # SimpleUNet defaults to 6 input channels, but the dataset yields one
    # grayscale channel per input folder. Size the model from a real sample
    # so the first convolution accepts the batches.
    in_channels = dataset[0][0].shape[0]
    model = SimpleUNet(in_channels=in_channels).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
        print(f"[Epoch {epoch+1}/{epochs}] Loss: {total_loss/max(len(dataloader), 1):.4f}")

    torch.save(model.state_dict(), model_path)
    print(f"[✓] Model saved as {model_path}")

# === Test the Saved Model ===
def test_model(model_path="trained_model.pth"):
    """Run the saved model on aligned test images and write the predictions.

    For each image file found (recursively) at the same relative path in all of
    the test folders, the images are:

    1. converted to grayscale;
    2. contrast-enhanced with ``apply_ahe``;
    3. stacked one channel per folder and passed through the model.

    The prediction is written under ``output/`` as ``out_<filename>``,
    mirroring any sub-folder structure.

    Args:
        model_path: ``state_dict`` file to load (default "trained_model.pth").
    """
    input_dirs = [
        "test/F15", "test/F30", "test/F45"
    ]
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")

    def list_images(root):
        # Only image files (case-insensitive suffix match), never directories.
        # Opening a directory with Image.open raises PermissionError on Windows.
        if not os.path.isdir(root):
            raise FileNotFoundError(f"Test folder not found: {root}")
        found = set()
        for dirpath, _dirnames, filenames in os.walk(root):
            for filename in filenames:
                if filename.lower().endswith(image_exts):
                    found.add(os.path.relpath(os.path.join(dirpath, filename), root))
        return found

    # Keep only relative paths present in every folder so the channels line up.
    test_names = sorted(set.intersection(*(list_images(d) for d in input_dirs)))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # One grayscale channel per folder is stacked below, so size the model to
    # match instead of relying on SimpleUNet's default of 6 channels.
    model = SimpleUNet(in_channels=len(input_dirs)).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    to_tensor = transforms.ToTensor()
    os.makedirs("output", exist_ok=True)

    for name in test_names:
        inputs = []
        for folder in input_dirs:
            with Image.open(os.path.join(folder, name)) as img:
                gray = img.convert("L")
            inputs.append(to_tensor(apply_ahe(gray)))
        input_tensor = torch.cat(inputs, dim=0).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)

        # Round (not truncate) when mapping [0, 1] predictions to 0..255.
        pixels = np.clip(np.round(output.squeeze().cpu().numpy() * 255), 0, 255)
        out_path = os.path.join("output", os.path.dirname(name), "out_" + os.path.basename(name))
        # Nested test files need their sub-folder created under output/.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        Image.fromarray(pixels.astype(np.uint8)).save(out_path)
        print(f"[✓] Saved: {out_path}")

# === Run ===
# Script entry point: train and save the model, then run it on the test images.
if __name__ == "__main__":
    for stage in (train_and_save_model, test_model):
        stage()

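# Error observed when this cell was run (notebook output):
# PermissionError: [Errno 13] Permission denied: 'E:\\EmbryoAnalysis\\EmbryoAnalysis\\Dataset\\embryo_dataset_F15\\embryo_dataset_F15\\RLFS800-2'
# Likely cause: 'RLFS800-2' is a sub-folder, not an image. os.listdir() returns
# directories too, and Image.open() on a directory raises PermissionError on Windows.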