In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.models import vgg16
from pytorch_msssim import ms_ssim
from PIL import Image, ImageFile
import numpy as np
import pandas as pd
import random
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler

# Allow loading of truncated images: tell PIL to decode image files whose data
# ends early instead of raising. This is a process-wide PIL setting.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ─────────────────────────────────────────────────────────────────────────────
# 1. Reproducibility
# ─────────────────────────────────────────────────────────────────────────────
def set_seed(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), then switches cuDNN to deterministic mode with the
    auto-tuner (benchmark) turned off.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU first, then every CUDA device).
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Trade cuDNN speed for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# ─────────────────────────────────────────────────────────────────────────────
# 2. Utility: common embryo IDs
# ─────────────────────────────────────────────────────────────────────────────
def get_common_embryo_ids(base_paths):
    """Return the embryo IDs that exist under every directory in *base_paths*.

    An embryo ID is the name of a sub-directory. Plain files that happen to
    sit next to the embryo folders are ignored, because the rest of the
    pipeline lists each ID as a directory of frames.

    Args:
        base_paths: Iterable of directory paths, one per focal plane.

    Returns:
        Sorted list of sub-directory names present in all of the given
        directories. Empty when *base_paths* is empty.

    Raises:
        OSError: If one of the directories cannot be listed.
    """
    common = None
    for base in base_paths:
        with os.scandir(base) as entries:
            ids = {entry.name for entry in entries if entry.is_dir()}
        common = ids if common is None else common & ids
    # ``common`` is still None when no path was supplied at all.
    return sorted(common) if common else []

# ─────────────────────────────────────────────────────────────────────────────
# 3. U-Net + Residual Head
# ─────────────────────────────────────────────────────────────────────────────
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    goes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for src_ch in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.net(x)

class UNet(nn.Module):
    """Five-level U-Net built from ``DoubleConv`` blocks.

    The encoder halves the resolution four times with 2x2 max-pooling; the
    decoder doubles it back with stride-2 transposed convolutions and
    concatenates the matching encoder feature map at every level. Input
    height and width must therefore be multiples of 16 so that the skip
    connections line up.

    The up-sampling layers are registered sub-modules (``up4`` .. ``up1``) so
    their weights are trained, saved in the state dict and moved together
    with the model by ``.to(device)``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=6, out_ch=1):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(in_ch, 64)
        self.pool = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.enc5 = DoubleConv(512, 1024)

        # Decoder: each transposed conv halves the channel count, so after
        # concatenation with the skip tensor the DoubleConv sees 2 * skip_ch.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.dec4 = DoubleConv(512 + 512, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec3 = DoubleConv(256 + 256, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2 = DoubleConv(128 + 128, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = DoubleConv(64 + 64, 64)

        # 1x1 projection to the requested number of output channels.
        self.outc = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        """Map an ``(N, in_ch, H, W)`` tensor to ``(N, out_ch, H, W)``."""
        # Contracting path (keep every level for the skip connections).
        c1 = self.enc1(x)
        c2 = self.enc2(self.pool(c1))
        c3 = self.enc3(self.pool(c2))
        c4 = self.enc4(self.pool(c3))
        c5 = self.enc5(self.pool(c4))

        # Expanding path.
        d4 = self.dec4(torch.cat([self.up4(c5), c4], 1))
        d3 = self.dec3(torch.cat([self.up3(d4), c3], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), c2], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), c1], 1))
        return self.outc(d1)

class UNetResidual(UNet):
    """U-Net that predicts a correction on top of the channel-mean image."""

    def forward(self, x):
        """Return the channel mean of ``x`` plus the U-Net output, clamped to [0, 1].

        The mean over dim 1 (the focal stack) serves as the baseline and the
        parent network supplies the residual that is added to it.
        """
        residual = super().forward(x)                  # (N,1,H,W) residual
        baseline = torch.mean(x, dim=1, keepdim=True)  # (N,1,H,W) average focal stack
        return (baseline + residual).clamp(0.0, 1.0)

# ─────────────────────────────────────────────────────────────────────────────
# 4. Dataset
# ─────────────────────────────────────────────────────────────────────────────
class EmbryoT4Dataset(Dataset):
    """Frames from the ``t4`` phase of embryos, as (focal stack, target) pairs.

    Each sample is one frame file name that is read from all six focal-plane
    directories (stacked as input channels) and from the target directory.

    Phase windows come from ``<phase_csv_dir>/<embryo_id>_phases.csv``, read
    without a header as ``phase,start,end`` rows. Frames are numbered from 1
    by their position in the sorted file listing of the first focal
    directory; a frame is used when ``start <= position <= end``.
    NOTE(review): the listing is sorted as plain strings, so this numbering
    presumably relies on zero-padded or otherwise sortable file names -
    confirm against the data.

    Args:
        base_paths: The six focal-plane root directories.
        target_path: Root directory of the target images.
        phase_csv_dir: Directory holding the per-embryo phase CSV files.
        embryo_ids: Embryo folder names to consider.
        transform: Optional callable invoked as
            ``transform(image=HxWx6 uint8, target=HxWx1 uint8)`` that returns
            a mapping with tensor entries ``'image'`` and ``'target'``. When
            ``None`` the arrays are converted to channel-first float tensors
            scaled to [0, 1].
        num_t4_embryos: Maximum number of embryos with a ``t4`` window to
            sample.
        num_other_embryos: Maximum number of remaining embryos to sample; two
            random frames are taken from each.

    Raises:
        RuntimeError: If no sample could be collected.
    """

    def __init__(self, base_paths, target_path, phase_csv_dir,
                 embryo_ids, transform=None,
                 num_t4_embryos=50, num_other_embryos=0):
        assert len(base_paths)==6, "Need 6 focal dirs"
        self.base_paths, self.target_path = base_paths, target_path
        self.phase_dir, self.transform = phase_csv_dir, transform

        # Collect (embryo, start, end) for every embryo with a usable t4 row.
        t4s = []
        for eid in embryo_ids:
            csv = os.path.join(phase_csv_dir, f"{eid}_phases.csv")
            if not os.path.exists(csv):
                continue
            df = pd.read_csv(csv, header=None, names=['phase', 's', 'e'])
            r = df[df.phase == 't4']
            if not r.empty and r.s.iloc[0] <= r.e.iloc[0]:
                t4s.append((eid, r.s.iloc[0], r.e.iloc[0]))
        self.t4s = random.sample(t4s, min(num_t4_embryos, len(t4s)))
        t4ids = {eid for eid, _, _ in self.t4s}
        others = [eid for eid in embryo_ids if eid not in t4ids]
        self.others = random.sample(others, min(num_other_embryos, len(others)))

        # Build the flat (embryo, file name) sample list.
        self.samples = []
        for eid, s, e in self.t4s:
            files = self._list_frames(eid)
            self.samples.extend(
                (eid, fn) for i, fn in enumerate(files, 1) if s <= i <= e
            )
        for eid in self.others:
            files = self._list_frames(eid)
            pick = random.sample(files, min(2, len(files)))
            self.samples.extend((eid, fn) for fn in pick)

        if not self.samples:
            raise RuntimeError("No samples found.")

    def _list_frames(self, eid):
        """Return the sorted frame file names of *eid* in the first focal dir."""
        return sorted(os.listdir(os.path.join(self.base_paths[0], eid)))

    @staticmethod
    def _load_gray(path):
        """Read *path* as an 8-bit grayscale array, closing the file promptly."""
        with Image.open(path) as img:
            return np.array(img.convert('L'))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        eid, fn = self.samples[idx]
        # Six focal planes become the channels of the input: H,W,6.
        x = np.stack(
            [self._load_gray(os.path.join(p, eid, fn)) for p in self.base_paths],
            axis=-1,
        )
        # Target image with an explicit channel axis: H,W,1.
        tgt = self._load_gray(os.path.join(self.target_path, eid, fn))[..., None]

        if self.transform is None:
            # No pipeline supplied: fall back to plain [0, 1] CHW tensors.
            inp = torch.from_numpy(x).permute(2, 0, 1).float() / 255.0
            out = torch.from_numpy(tgt).permute(2, 0, 1).float() / 255.0
            return inp, out

        aug = self.transform(image=x, target=tgt)
        inp = aug['image'].float()     # 6,H,W
        out = aug['target'].float()    # 1,H,W
        return inp, out

# ─────────────────────────────────────────────────────────────────────────────
# 5. Losses: L1 + MS-SSIM + Perceptual
# ─────────────────────────────────────────────────────────────────────────────
# VGG feature extractor
def make_vgg(device):
    """Build a frozen VGG16 feature extractor on *device*.

    Only the first nine layers of the pretrained ``features`` stack are kept
    (up to conv2_2). The module is put in eval mode and none of its
    parameters require gradients.
    """
    extractor = vgg16(pretrained=True).features[:9]
    extractor = extractor.to(device).eval()
    for weight in extractor.parameters():
        weight.requires_grad = False
    return extractor

def perceptual_loss(pred, tgt, vgg):
    """Mean absolute difference between ``vgg`` features of *pred* and *tgt*.

    Both inputs are repeated three times along dim 1 before being passed to
    ``vgg`` (prediction first, then target).
    """
    pred_feats, tgt_feats = (
        vgg(img.repeat(1, 3, 1, 1)) for img in (pred, tgt)
    )
    return F.l1_loss(pred_feats, tgt_feats)

def combined_fusion_loss(pred, tgt, vgg):
    """Weighted sum of the three fusion losses.

    ``L1 + 0.5 * (1 - MS-SSIM) + 0.1 * perceptual``, where MS-SSIM is
    evaluated with ``data_range=1.0`` and the perceptual term uses *vgg* as
    feature extractor.
    """
    pixel_term = F.l1_loss(pred, tgt)
    structural_term = 1 - ms_ssim(pred, tgt, data_range=1.0, size_average=True)
    feature_term = perceptual_loss(pred, tgt, vgg)

    total = pixel_term + 0.5 * structural_term
    total = total + 0.1 * feature_term
    return total

# ─────────────────────────────────────────────────────────────────────────────
# 6. Training loop
# ─────────────────────────────────────────────────────────────────────────────
def train_model(model, train_loader, val_loader,
                epochs=50, device='cuda', lr=1e-4, use_amp=True,
                checkpoint_path='embryo_unet_t4_residual.pth'):
    """Train *model* with the combined fusion loss and keep the best weights.

    Uses Adam, halves the learning rate after 3 epochs without validation
    improvement (``ReduceLROnPlateau``), clips the gradient norm to 1.0 and
    writes ``model.state_dict()`` to *checkpoint_path* whenever the
    validation loss reaches a new minimum.

    Args:
        model: Network to optimise; expected to live on *device* already.
        train_loader: Loader yielding ``(input, target)`` training batches.
        val_loader: Loader yielding ``(input, target)`` validation batches.
        epochs: Number of passes over the training data.
        device: Device (string or ``torch.device``) batches are moved to.
        lr: Initial learning rate.
        use_amp: Request mixed-precision training. It only takes effect on a
            CUDA device; elsewhere training silently runs in full precision.
        checkpoint_path: Destination file for the best state dict.

    Returns:
        The lowest validation loss observed (``inf`` if ``epochs`` < 1).

    Raises:
        ValueError: If either loader's dataset is empty.
    """
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        # Averages below divide by these sizes.
        raise ValueError("train_loader and val_loader must both be non-empty")

    # CUDA autocast / GradScaler are pointless (and warn) on other devices.
    amp_enabled = bool(use_amp) and torch.device(device).type == 'cuda'

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     patience=3,
                                                     factor=0.5)
    # With enabled=False the scaler is a transparent pass-through, so a single
    # code path serves both the AMP and the full-precision case.
    scaler = GradScaler(enabled=amp_enabled)
    vgg    = make_vgg(device)

    best_val = float('inf')
    for ep in range(1, epochs+1):
        #---- train ----
        model.train()
        running = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            with autocast(enabled=amp_enabled):
                pred = model(x)
                loss = combined_fusion_loss(pred, y, vgg)
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            # Weight by batch size so a smaller last batch is averaged fairly.
            running += loss.item() * x.size(0)
        train_loss = running / n_train

        #---- validate ----
        model.eval()
        running = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x)
                running += combined_fusion_loss(pred, y, vgg).item() * x.size(0)
        val_loss = running / n_val
        scheduler.step(val_loss)

        print(f"Epoch {ep}/{epochs}  Train: {train_loss:.4f}  Val: {val_loss:.4f}")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("  [*] Saved new best model")

    return best_val

# ─────────────────────────────────────────────────────────────────────────────
# 7. Main entry
# ─────────────────────────────────────────────────────────────────────────────
def main_train(seed=42, data_root=r"C:\Projects\Embryo\Dataset"):
    """Build the data pipeline and model, then run training.

    Args:
        seed: Seed passed to ``set_seed`` before anything random happens.
        data_root: Directory containing the ``embryo_dataset_*`` folders
            (six focal planes, the F0 target and the phase annotations).
    """
    import copy  # local: only needed to clone the dataset for validation

    set_seed(seed)

    focal_planes = ("F15", "F-15", "F30", "F-30", "F45", "F-45")
    base_paths = [
        os.path.join(data_root, f"embryo_dataset_{plane}")
        for plane in focal_planes
    ]
    target_path   = os.path.join(data_root, "embryo_dataset_F0")
    phase_csv_dir = os.path.join(data_root, "embryo_dataset_annotations")

    embryo_ids = get_common_embryo_ids(base_paths)
    print(f"Found {len(embryo_ids)} embryos")

    # Both pipelines end by scaling uint8 pixels to [0, 1]. That range is what
    # the model's clamp(0, 1) and the MS-SSIM data_range=1.0 expect, and unlike
    # a 6-value Normalize it is valid for the 1-channel target as well.
    to_unit_tensor = [
        A.Resize(256,256),
        A.ToFloat(max_value=255.0),
        ToTensorV2(),
    ]
    train_transform = A.Compose([
        A.RandomRotate90(),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.1, rotate_limit=15, p=0.5),
        *to_unit_tensor,
    ], additional_targets={'target': 'image'})
    # Validation must be deterministic: no random augmentation.
    eval_transform = A.Compose(
        list(to_unit_tensor), additional_targets={'target': 'image'}
    )

    dataset = EmbryoT4Dataset(
        base_paths, target_path, phase_csv_dir,
        embryo_ids, train_transform,
        num_t4_embryos=50, num_other_embryos=0
    )
    train_len = int(0.8*len(dataset))
    val_len   = len(dataset)-train_len
    train_ds, val_split = random_split(dataset, [train_len, val_len])

    # Same samples and indices as the split, but read through a shallow copy
    # of the dataset that carries the augmentation-free transform.
    eval_dataset = copy.copy(dataset)
    eval_dataset.transform = eval_transform
    val_ds = torch.utils.data.Subset(eval_dataset, val_split.indices)

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=4, pin_memory=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    model = UNetResidual(in_ch=6, out_ch=1).to(device)

    train_model(model, train_loader, val_loader,
                epochs=50, device=device, lr=1e-4, use_amp=True)
    print("Training complete.")

# Entry point: start a training run with seed 121 when this file (or notebook
# cell) is executed as the main module rather than imported.
if __name__ == "__main__":
    main_train(seed=121)
    

Found 704 embryos


  original_init(self, **validated_kwargs)


Device: cuda


  scaler = GradScaler() if use_amp else None
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\parth/.cache\torch\hub\checkpoints\vgg16-397923af.pth
100.0%
