In [1]:
WORKING_DIR = "/home/xavier/Documents/DAE_project"  # project root; every path in CONFIG is built from it

# Get features and reconstructed movies

In [3]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import dnnlib
import legacy
import pickle
import shutil
import re
import functools
from training.networks_stylegan2 import SynthesisLayer, MinibatchStdLayer
from concurrent.futures import ThreadPoolExecutor

# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
# Dataset sub-folder name; it is interpolated into every path in CONFIG below.
target = "WT"  # "Roy_training"
CONFIG = {
    # --- Core Paths ---
    # Images are discovered as <root_dir>/<run folder>/<scope folder>/<image files>,
    # where the folder names must match run_id_regex / scope_id_regex.
    "root_dir": f"{WORKING_DIR}/dataset/{target}/images",
    # Per-scope features.npz files are written to <out_dir>/RunXXXX/ScopeYY/.
    "out_dir": f"{WORKING_DIR}/tmp/{target}_features",
    "videos_dir": f"{WORKING_DIR}/videos/{target}",
    "reconstruction_dir": f"{WORKING_DIR}/reconstructions/{target}",
    "cropped_original_dir": f"{WORKING_DIR}/cropped_original/{target}",
    # A "-patched" copy of this pickle is cached next to it on first load.
    "model_pkl_path": f"{WORKING_DIR}/models/network-snapshot-001512.pkl",

    # --- Discovery & Parsing ---
    # The first capture group of each regex is parsed as the integer run / scope id.
    "run_id_regex": r'Run(\d+)',
    "scope_id_regex": r'Scope(\d+)',

    # --- Processing Settings ---
    "network_idx": 13,  # in this file only used to name the generated video files
    "batch_size": 16,  # frames per encoder / generator forward pass
    "resize_by": 1.0,  # scale factor applied before the center crop (1.0 = no resize)
    "resolution": 512,  # side length of the square center crop fed to the encoder
    "frames_to_process": 1441,  # per-scope frame cap; also capped at (highest frame index + 1)
    "num_workers": 8,  # threads used to load/preprocess images

    # --- Output Settings ---
    "save_video": False,
    "smooth_video": True,  # only used when save_video is True
    "save_reconstructions": False,
    "save_cropped_original": True,
    "selected_frames": [0, 1440],  # frames saved as reconstructions / cropped originals
    # Falsy values (None or an empty list) both mean "every frame".
    "selected_feature_frames": [0, 1440],  # None for all, or a list like [0, 100, 1440]
    "reconstruction_seed": 42,  # seed for the generator's constant noise before reconstructions
}

# Set environment variables for StyleGAN2
# (compilers for its custom C++/CUDA extensions). NOTE(review): these are set
# after the `training` imports above; presumably fine because the extensions
# are built lazily on first use — confirm.
os.environ['CC'] = "/usr/bin/gcc-9"
os.environ['CXX'] = "/usr/bin/g++-9"


# ==============================================================================
# 2. HELPER FUNCTIONS
# ==============================================================================
def reset_noise_const(G, seed):
    """Overwrite the ``noise_const`` buffer of every SynthesisLayer in ``G.synthesis``.

    Side effect: seeds torch's global RNG with ``seed``. It then draws one
    ``[resolution, resolution]`` standard-normal tensor per layer, in traversal order.

    Only two levels are inspected: the direct children of ``G.synthesis``, and
    their direct children. Layers are matched by class *name* rather than by
    ``isinstance``. NOTE(review): presumably this is because unpickled networks
    carry their own class objects — confirm.
    """
    torch.manual_seed(seed)

    synthesis_layers = [
        child
        for block in G.synthesis.children()
        for child in block.children()
        if type(child).__name__ == "SynthesisLayer"
    ]

    with torch.no_grad():
        for layer in synthesis_layers:
            side = layer.resolution
            layer.noise_const.copy_(torch.randn([side, side]))


def resize_crop(img_name, strain_dir, resize_by=1.0, resolution=512, brightness_norm=True, brightness_mean=107.2):
    """Load one image as grayscale, optionally rescale it, center-crop it and shift its brightness.

    Args:
        img_name: File name inside ``strain_dir``.
        strain_dir: Directory containing the image.
        resize_by: Scale factor applied to both axes before cropping (1.0 = no resize).
        resolution: Side length of the square center crop.
        brightness_norm: If True, add ``brightness_mean - mean(crop)`` to every pixel.
            ``cv2.add`` saturates the result to 0..255.
        brightness_mean: Target mean intensity used by the brightness shift.

    Returns:
        The ``(resolution, resolution)`` grayscale crop, or ``None`` in three cases:
        the file does not exist, OpenCV cannot decode it, or it is smaller than
        ``resolution`` in either dimension after resizing.
    """
    img_path = os.path.join(strain_dir, img_name)
    if not os.path.exists(img_path):
        return None
    # Force load as grayscale to ensure 1 channel consistently.
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return None

    if resize_by != 1.0:
        height, width = img.shape
        new_size = (int(width * resize_by), int(height * resize_by))  # cv2 expects (width, height)
        # `interpolation` must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, so a positional flag is silently ignored.
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_LANCZOS4)

    height, width = img.shape
    if height < resolution or width < resolution:
        # A negative start offset would wrap around in numpy slicing and produce a
        # wrongly sized crop, which later breaks batching and video writing.
        return None

    start_y = (height - resolution) // 2
    start_x = (width - resolution) // 2
    img = img[start_y:start_y + resolution, start_x:start_x + resolution]

    if brightness_norm:
        shift = brightness_mean - np.mean(img)
        # cv2.add saturates, keeping values within 0-255.
        img = cv2.add(img, shift)
    return img


def smooth_array(arr, window_size=10):
    """Smooth a 1D array, or each column of a 2D array, with a centred moving average.

    The window is aligned like ``np.convolve(..., mode='same')``: output sample ``i``
    averages input samples ``i - window_size // 2`` .. ``i + (window_size - 1) // 2``.

    Unlike a plain zero-padded convolution, each output is divided by the number of
    *valid* samples inside its window. As a result:

    * samples near the start/end are not pulled toward zero;
    * NaN entries (e.g. frames whose features were never extracted) are ignored, and
      an output is NaN only if its whole window is NaN;
    * the output always has the input's length, even when the input is shorter than
      ``window_size``.

    Args:
        arr: 1D array ``(T,)`` or 2D array ``(T, D)``; smoothing runs along axis 0.
        window_size: Number of samples in the averaging window (>= 1).

    Returns:
        For 1D input, a float64 array ``(T,)``. For 2D input, an array ``(T, D)``
        that keeps ``arr``'s floating dtype; non-float input becomes float64.

    Raises:
        ValueError: if ``window_size`` < 1.
    """
    arr = np.asarray(arr)
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    kernel = np.ones(window_size)
    # Start of the 'same'-mode slice inside the 'full' convolution output.
    offset = (window_size - 1) // 2

    def _smooth_1d(col):
        col = col.astype(np.float64)
        n = col.shape[0]
        if n == 0:
            return col
        valid = ~np.isnan(col)
        sums = np.convolve(np.where(valid, col, 0.0), kernel, mode='full')[offset:offset + n]
        counts = np.convolve(valid.astype(np.float64), kernel, mode='full')[offset:offset + n]
        with np.errstate(invalid='ignore', divide='ignore'):
            return sums / counts  # 0/0 -> NaN where the window holds no valid sample

    if arr.ndim == 1:
        return _smooth_1d(arr)

    out_dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float64
    smoothed = np.empty(arr.shape, dtype=out_dtype)
    for j in range(arr.shape[1]):
        smoothed[:, j] = _smooth_1d(arr[:, j])
    return smoothed


def get_final_frame_info(scope_dir):
    """Return ``(max_index, filename)`` for the highest-numbered frame in ``scope_dir``.

    The frame number is the text after the last ``_`` and before the next ``.``
    in each entry name (e.g. ``img_0042.png`` -> 42). Entries whose number does
    not parse as an integer are ignored. If several entries share the winning
    number, the one listed last by ``os.listdir`` is returned.

    Returns ``(None, None)`` in two cases: ``scope_dir`` is not a directory, or no
    entry yields a non-negative frame number.
    """
    if not os.path.isdir(scope_dir):
        return None, None

    name_for_index = {}
    for entry in os.listdir(scope_dir):
        try:
            name_for_index[int(entry.split("_")[-1].split(".")[0])] = entry
        except (ValueError, IndexError):
            pass

    highest = max([-1, *name_for_index])
    if highest == -1:
        return None, None
    return highest, name_for_index[highest]


def get_image_frame_list(scope_dir, max_index, total_frames):
    """Map every frame number ``0 .. total_frames-1`` to an image file in ``scope_dir``.

    File names are parsed as ``prefix_<digits>.ext``. Names whose last
    ``_``-separated field (before the next ``.``) is not all digits are ignored.
    ``total_frames`` is capped at ``max_index + 1``, and a falsy value (``None``/0)
    means "up to max_index".

    A frame with no file of its own borrows the file of the nearest existing
    frame. The search moves outward one step at a time and prefers the earlier
    frame on ties. If nothing is found within ``frame + 10000`` steps,
    ``"placeholder.jpg"`` is used instead.

    Returns:
        ``np.array`` built from ``(frame, name)`` pairs (shape ``(total_frames, 2)``
        when non-empty). Because the pairs mix ints and strings, numpy stores both
        columns as strings.
    """
    digit_names = {}
    for name in os.listdir(scope_dir):
        tail = name.split("_")[-1].split(".")[0]
        if tail.isdigit():
            digit_names[int(tail)] = name

    if not total_frames or total_frames > max_index + 1:
        total_frames = max_index + 1

    def nearest_name(frame):
        for step in range(1, frame + 10001):
            for candidate in (frame - step, frame + step):
                if candidate in digit_names:
                    return digit_names[candidate]
        return "placeholder.jpg"

    pairs = [
        (frame, digit_names[frame] if frame in digit_names else nearest_name(frame))
        for frame in range(total_frames)
    ]
    return np.array(pairs)


# ==============================================================================
# 3. CORE LOGIC FUNCTIONS
# ==============================================================================
def discover_scopes(root_dir, run_regex, scope_regex):
    """Find every ``<root_dir>/<run folder>/<scope folder>`` pair to process.

    A first-level directory qualifies when ``run_regex`` matches its name. A
    second-level directory qualifies when ``scope_regex`` matches its name. The
    first capture group of each match is parsed as the integer run / scope id.
    Non-directories are ignored, and the order follows ``os.listdir`` (unsorted).

    Returns:
        A list of dicts with keys ``run_id``, ``scope_id``, ``strain`` (run folder
        name) and ``scope`` (scope folder name). The list is empty if ``root_dir``
        is not a directory.
    """
    print(f"Scanning for experiments in: {root_dir}")
    if not os.path.isdir(root_dir):
        return []

    def subdirectories(parent):
        for name in os.listdir(parent):
            full_path = os.path.join(parent, name)
            if os.path.isdir(full_path):
                yield name, full_path

    jobs = []
    for strain_name, strain_path in subdirectories(root_dir):
        run_m = re.search(run_regex, strain_name)
        if run_m is None:
            continue
        for scope_name, _ in subdirectories(strain_path):
            scope_m = re.search(scope_regex, scope_name)
            if scope_m is None:
                continue
            jobs.append({
                "run_id": int(run_m.group(1)),
                "scope_id": int(scope_m.group(1)),
                "strain": strain_name,
                "scope": scope_name,
            })

    print(f"-> Discovered {len(jobs)} potential scopes to process.")
    return jobs


def _atomic_pickle_dump(obj, path):
    """Pickle ``obj`` to ``path`` through a temporary file, so an interrupted write never leaves a truncated file."""
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, 'wb') as f_out:
            pickle.dump(obj, f_out)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_and_patch_model(pkl_path):
    """Load the network pickle and ensure the encoder's ``b4`` heads carry a MinibatchStdLayer.

    A copy is cached next to the original as ``<name>-patched<ext>``. When that
    copy exists, it is loaded directly instead.

    Args:
        pkl_path: Local path of the original network pickle.

    Returns:
        The dict returned by ``legacy.load_network_pkl``. This file reads its
        ``'G_ema'`` and ``'E_ema'`` entries.
    """
    # Why splitext instead of str.replace('.pkl', ...):
    # - replace() rewrites every '.pkl' in the path;
    # - for a path without '.pkl' it returns the original path unchanged. That
    #   path exists, so the unpatched model would be loaded as "already patched".
    root, ext = os.path.splitext(pkl_path)
    patched_pkl_path = f"{root}-patched{ext}"

    if os.path.exists(patched_pkl_path):
        print(f"Loading already patched model from: {patched_pkl_path}")
        with dnnlib.util.open_url(patched_pkl_path) as fp:
            return legacy.load_network_pkl(fp)

    print(f"Loading original model from: {pkl_path}")
    with dnnlib.util.open_url(pkl_path) as fp:
        models = legacy.load_network_pkl(fp)

    E_ema = models['E_ema']
    # The patch ensures MinibatchStdLayer is present.
    # NOTE(review): isinstance is checked against the imported class; if the pickle
    # embeds its own class objects this may always report "needs patching" — confirm.
    if not isinstance(getattr(E_ema.b4.get_mu, 'mbstd', None), MinibatchStdLayer):
        print("Model needs patching...")
        E_ema.b4.get_mu.mbstd = MinibatchStdLayer(group_size=4).eval()
        E_ema.b4.get_logvar.mbstd = MinibatchStdLayer(group_size=4).eval()
        _atomic_pickle_dump(models, patched_pkl_path)
        print(f"Saved patched model to: {patched_pkl_path}")
    else:
        print("Model already patched.")
        # Cache under the patched name too, so the next call takes the fast path above.
        _atomic_pickle_dump(models, patched_pkl_path)
    return models


def _synthesize_uint8_frames(G, zs_numpy, batch_size, device):
    """Yield one 2D uint8 frame per NaN-free row of ``zs_numpy``, rendered by ``G`` in batches.

    Rows containing any NaN (features never extracted) are skipped. Only the first
    output channel of ``G`` is kept, after mapping [-1, 1] to [0, 255].
    """
    for start in range(0, len(zs_numpy), batch_size):
        z_chunk = zs_numpy[start:start + batch_size]
        keep = ~np.isnan(z_chunk).any(axis=1)
        if not keep.any():
            continue
        z_batch = torch.Tensor(z_chunk[keep]).to(device)
        # no_grad only around the forward pass, so grad mode never leaks across `yield`.
        with torch.no_grad():
            images = G(z_batch, None, noise_mode='const')
        images = ((images + 1) * 127.5).permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8).cpu().numpy()[:, :, :, 0]
        yield from images


def _write_gray_video(path, frames, fourcc, resolution_tuple):
    """Write an iterable of 2D uint8 frames to a 24 fps grayscale video, skipping ``None`` frames."""
    writer = cv2.VideoWriter(path, fourcc, 24.0, resolution_tuple, isColor=False)
    try:
        for frame in frames:
            if frame is not None:
                writer.write(frame)
    finally:
        # Release even if synthesis fails, so the file handle is not leaked.
        writer.release()


def generate_videos(config, run_id, scope_id, G, zs_numpy, processed_imgs, device):
    """Write the original frames and the frames synthesized from ``zs_numpy`` as videos.

    The videos go to ``<videos_dir>/RunXXXX/ScopeYY/``:
    ``original.avi``, ``generated_net<N>.avi`` and, if ``config['smooth_video']``
    is set, ``generated_net<N>_smoothed.avi`` (rendered from ``smooth_array(zs_numpy)``).
    Generation is skipped entirely if that directory already contains files.

    Rows of ``zs_numpy`` containing NaN are skipped in both generated videos, so
    these can be shorter than ``original.avi``. ``None`` images are also skipped
    in ``original.avi``.
    """
    video_scope_dir = os.path.join(config['videos_dir'], f"Run{run_id:04d}", f"Scope{scope_id:02d}")
    if os.path.isdir(video_scope_dir) and len(os.listdir(video_scope_dir)) > 0:
        print("  - Videos already exist. Skipping generation.")
        return

    os.makedirs(video_scope_dir, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    resolution_tuple = (config['resolution'], config['resolution'])
    batch_size = config['batch_size']

    # --- Original Video ---
    _write_gray_video(os.path.join(video_scope_dir, "original.avi"),
                      processed_imgs, fourcc, resolution_tuple)

    # --- Generated Video ---
    _write_gray_video(os.path.join(video_scope_dir, f"generated_net{config['network_idx']}.avi"),
                      _synthesize_uint8_frames(G, zs_numpy, batch_size, device),
                      fourcc, resolution_tuple)

    # --- Smoothed Generated Video ---
    if config['smooth_video']:
        # NaN rows are skipped here as well, instead of being fed to G.
        _write_gray_video(os.path.join(video_scope_dir, f"generated_net{config['network_idx']}_smoothed.avi"),
                          _synthesize_uint8_frames(G, smooth_array(zs_numpy), batch_size, device),
                          fourcc, resolution_tuple)
    print("  - Finished video generation.")


def _load_existing_features(npz_path, total_frames):
    """Return ``(z, var, w)`` from ``npz_path`` if it loads and its ``z`` has ``total_frames`` rows, else ``None``."""
    if not os.path.exists(npz_path):
        return None
    try:
        # The context manager closes the NpzFile; np.load otherwise keeps the archive open.
        with np.load(npz_path) as data:
            if data['z'].shape[0] != total_frames:
                print("  - Existing NPZ has mismatched frame count. Reprocessing.")
                return None
            print("  - Found existing features.npz. Checking for completeness.")
            return data['z'], data['var'], data['w']
    except Exception as e:
        print(f"  - Could not load or parse existing NPZ file: {e}. Reprocessing.")
        return None


def _normalize_frame_indices(indices, total_frames):
    """Map Python-style (possibly negative) indices into ``0 .. total_frames-1``; drop out-of-range ones; sort and dedupe."""
    return sorted({idx % total_frames for idx in indices if -total_frames <= idx < total_frames})


def _extract_features(E, G, processed_imgs, frame_indices, batch_size, device, zs, logvars, wss):
    """Encode the listed frames and fill ``zs``/``logvars``/``wss`` in place.

    ``E.mu_var`` produces z and logvar, and ``G.mapping(z)`` produces ws. Frames
    whose image is ``None`` are skipped and keep their previous values.
    """
    # Inference only: no autograd graph, and .numpy() works even if params require grad.
    with torch.no_grad():
        for start in range(0, len(frame_indices), batch_size):
            batch_indices = [idx for idx in frame_indices[start:start + batch_size]
                             if processed_imgs[idx] is not None]
            if not batch_indices:
                continue

            img_batch = torch.from_numpy(np.stack([processed_imgs[idx] for idx in batch_indices])).to(device)
            # [0, 255] -> [-1, 1], then add the channel dimension.
            img_batch = (img_batch.to(torch.float32) / 127.5 - 1).unsqueeze(1)

            z, logvar = E.mu_var(img_batch, None)
            ws = G.mapping(z, None)

            zs[batch_indices] = z.cpu().numpy()
            logvars[batch_indices] = logvar.cpu().numpy()
            wss[batch_indices] = ws.cpu().numpy()


def process_scope_features_and_videos(job, config, models):
    """Extract features for one scope and write the requested outputs.

    Steps:
    1. Resume from ``<out_dir>/RunXXXX/ScopeYY/features.npz`` if it exists and
       has the right frame count.
    2. Encode the frames whose ``z`` is still NaN. These are either
       ``config['selected_feature_frames']`` or, if that is falsy, all frames.
    3. Save the updated NPZ when something was encoded.
    4. Produce videos / reconstructions / cropped originals as configured.

    Only the images actually needed are loaded from disk: the frames to encode,
    ``config['selected_frames']`` when cropped originals are saved, and every
    frame when videos are saved. All other entries of ``processed_imgs`` stay ``None``.

    Returns:
        False if the scope directory contains no parseable image files, else True.
    """
    run_id, scope_id = job['run_id'], job['scope_id']
    out_dir = os.path.join(config['out_dir'], f"Run{run_id:04d}", f"Scope{scope_id:02d}")
    os.makedirs(out_dir, exist_ok=True)
    scope_dir = os.path.join(config['root_dir'], job['strain'], job['scope'])
    npz_path = os.path.join(out_dir, "features.npz")

    max_frame_idx, _ = get_final_frame_info(scope_dir)
    if max_frame_idx is None:
        print(f"  - ERROR: No valid image files found in {scope_dir}")
        return False

    idx_name_list = get_image_frame_list(scope_dir, max_frame_idx, config['frames_to_process'])
    total_frames = len(idx_name_list)

    # --- Load existing data or initialize new arrays ---
    G_ema = models['G_ema']
    existing = _load_existing_features(npz_path, total_frames)
    if existing is not None:
        zs, logvars, wss = existing
    else:
        print("  - Initializing new feature arrays.")
        zs = np.full([total_frames, G_ema.z_dim], np.nan, dtype=np.float32)
        logvars = np.full([total_frames, G_ema.z_dim], np.nan, dtype=np.float32)
        wss = np.full([total_frames, G_ema.num_ws, G_ema.w_dim], np.nan, dtype=np.float32)

    # --- Decide which frames need encoding (NaN z == not yet extracted) ---
    selected = config.get('selected_feature_frames')
    frames_to_process_indices = [
        idx for idx in _normalize_frame_indices(selected or range(total_frames), total_frames)
        if np.isnan(zs[idx, 0])
    ]

    # --- Load only the images something will actually use ---
    needed = set(frames_to_process_indices)
    if config['save_video']:
        needed.update(range(total_frames))
    if config['save_cropped_original']:
        needed.update(_normalize_frame_indices(config['selected_frames'], total_frames))

    processed_imgs = [None] * total_frames
    load_indices = sorted(needed)
    if load_indices:
        print("  - Loading and preprocessing image frames...")
        resize_crop_partial = functools.partial(resize_crop, strain_dir=scope_dir, resize_by=config['resize_by'],
                                                resolution=config['resolution'])
        with ThreadPoolExecutor(max_workers=config['num_workers']) as executor:
            loaded = executor.map(resize_crop_partial, idx_name_list[load_indices, 1])
            for idx, img in zip(load_indices, loaded):
                processed_imgs[idx] = img

    # --- Run feature extraction ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = G_ema.eval().to(device)

    if not frames_to_process_indices:
        print("  - No new frames to process for feature extraction.")
    else:
        if selected:
            print(f"  - Targeted processing for {len(frames_to_process_indices)} selected frames.")
        else:
            print(f"  - Processing {len(frames_to_process_indices)} remaining frames.")

        E = models['E_ema'].eval().to(device)
        _extract_features(E, G, processed_imgs, frames_to_process_indices, config['batch_size'], device,
                          zs, logvars, wss)

        np.savez(npz_path, w=wss, z=zs, var=logvars)
        print(f"  - Saved/updated features to: {npz_path}")

    # --- Generate videos, reconstructions, etc. ---
    if config['save_video']:
        generate_videos(config, run_id, scope_id, G, zs, processed_imgs, device)

    if config['save_reconstructions']:
        save_reconstructions(config, run_id, scope_id, G, zs, device)

    if config['save_cropped_original']:
        save_cropped_originals(config, run_id, scope_id, processed_imgs)

    return True


def save_cropped_originals(config, run_id, scope_id, processed_imgs):
    """Write the preprocessed frames listed in ``config['selected_frames']`` as PNGs.

    Files go to ``<cropped_original_dir>/RunXXXX/ScopeYY/<frame>.png``. The
    directory is always created. Frames beyond ``len(processed_imgs)`` or whose
    image is ``None`` are silently skipped.
    """
    target_dir = os.path.join(config['cropped_original_dir'], f"Run{run_id:04d}", f"Scope{scope_id:02d}")
    os.makedirs(target_dir, exist_ok=True)

    wanted = config['selected_frames']
    print(f"  - Saving cropped original frames: {wanted}")

    n_imgs = len(processed_imgs)
    for frame_idx in wanted:
        if frame_idx >= n_imgs:
            continue
        frame = processed_imgs[frame_idx]
        if frame is None:
            continue
        cv2.imwrite(os.path.join(target_dir, f"{frame_idx}.png"), frame)


def save_reconstructions(config, run_id, scope_id, G, zs_numpy, device):
    """Render ``G`` for each frame in ``config['selected_frames']`` and save it as a PNG.

    ``G``'s constant noise is first re-seeded via
    ``reset_noise_const(G, config['reconstruction_seed'])``. Frames beyond
    ``len(zs_numpy)``, and frames whose z contains NaN (features never
    extracted), are skipped. Images are written to
    ``<reconstruction_dir>/RunXXXX/ScopeYY/<frame>_seed_<seed>.png``.
    """
    recon_dir = os.path.join(config['reconstruction_dir'], f"Run{run_id:04d}", f"Scope{scope_id:02d}")
    os.makedirs(recon_dir, exist_ok=True)

    selected = config['selected_frames']
    seed = config['reconstruction_seed']
    print(f"  - Saving reconstructions for frames: {selected}")

    reset_noise_const(G, seed)

    # Inference only: avoid building an autograd graph for every rendered frame.
    with torch.no_grad():
        for frame_idx in selected:
            if frame_idx >= len(zs_numpy):
                continue
            z_row = zs_numpy[frame_idx]
            if np.isnan(z_row).any():  # check on the host before the device transfer
                continue
            z = torch.Tensor(z_row).to(device).unsqueeze(0)
            image = G(z, None, noise_mode='const')
            image = ((image + 1) * 127.5).permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8).cpu().numpy()[0, :, :, 0]
            cv2.imwrite(os.path.join(recon_dir, f"{frame_idx}_seed_{seed}.png"), image)


# ==============================================================================
# 4. MAIN EXECUTION
# ==============================================================================
def main():
    """Discover all scopes under CONFIG['root_dir'] and process them one by one.

    The model is loaded only if at least one scope was found. A scope whose
    processing returns False is reported and skipped.
    """
    print("Process starts!")

    cfg = CONFIG
    jobs = discover_scopes(cfg['root_dir'], cfg['run_id_regex'], cfg['scope_id_regex'])
    if not jobs:
        print("No scopes found to process.")
        return

    models = load_and_patch_model(cfg['model_pkl_path'])
    print(f"\nReady to process {len(jobs)} jobs.")

    for job in jobs:
        rid, sid = job['run_id'], job['scope_id']
        print(f"--- Processing Run {rid:04d}, Scope {sid:02d} ---")
        if process_scope_features_and_videos(job, cfg, models):
            continue
        print(f"  - ERROR: Processing failed for Run {rid}/{sid}. Skipping to next job.")


if __name__ == '__main__':
    main()



Process starts!
Scanning for experiments in: /home/xavier/Documents/DAE_project/dataset/WT/images
-> Discovered 366 potential scopes to process.
Loading already patched model from: /home/xavier/Documents/DAE_project/models/network-snapshot-001512-patched.pkl

Ready to process 366 jobs.
--- Processing Run 0195, Scope 35 ---
  - Found existing features.npz. Checking for completeness.
  - Loading and preprocessing image frames...
  - No new frames to process for feature extraction.
  - Saving cropped original frames: [0, 1440]
--- Processing Run 0195, Scope 12 ---
  - Found existing features.npz. Checking for completeness.
  - Loading and preprocessing image frames...
  - No new frames to process for feature extraction.
  - Saving cropped original frames: [0, 1440]
--- Processing Run 0195, Scope 11 ---
  - Found existing features.npz. Checking for completeness.
  - Loading and preprocessing image frames...
  - No new frames to process for feature extraction.
  - Saving cropped original fr