In [1]:
import os

# Project root from which every dataset, model and output path in this notebook is derived.
# Set the DAE_WORKING_DIR environment variable to run on another machine; when it is
# unset (or empty) the original location is used, so existing setups are unaffected.
WORKING_DIR = os.environ.get("DAE_WORKING_DIR") or "/home/xavier/Documents/DAE_project"

# Get selected frames of selected strains

In [4]:
import os
import cv2
import numpy as np
import pandas as pd

# Root of the raw image folders; the code below expects one sub-folder per run whose
# name ends with the 4-digit run number, each holding one sub-folder per scope.
IMG_DIR = f"{WORKING_DIR}/dataset/Roy_training/images"
# Destination for the extracted frames (one sub-folder per strain is created below).
OUT_DIR = f"{WORKING_DIR}/images/figure1"
# Per-run reference table. NOTE: despite the name this is an Excel workbook and is
# read with pd.read_excel; it must provide the "Run" and "Mutant #" columns.
REFERENCE_CSV = f"{WORKING_DIR}/dataset/Roy_training/Caro_3d_9.7.22_2.20_new.xlsx"

# Specify strains and frames of interest
# Strain numbers are compared (as strings) with the digits extracted from "Mutant #".
strains = [1622, 8615, 2232, 4398, 4299, 5208, 4408]
# Target frame numbers; when a target frame file is missing a nearby frame is used instead.
frames = [1, 361, 721, 1081, 1441]  # Example: specify your target frame numbers here


def resize_crop(img_path, resize_by=1.0, resolution=512, brightness_norm=True, brightness_mean=107):
    """Load an image, optionally rescale it, and return its brightness-normalised centre crop.

    Args:
        img_path: Path of the image file to read (read with ``cv2.IMREAD_UNCHANGED``).
        resize_by: Scale factor applied to both axes before cropping; ``1.0`` skips resizing.
        resolution: Side length, in pixels, of the square centre crop.
        brightness_norm: If true, shift all pixel values so the crop's mean equals
            ``brightness_mean``.
        brightness_mean: Target mean intensity used when ``brightness_norm`` is true.

    Returns:
        A ``uint8`` array of shape ``(resolution, resolution)`` for single-channel images or
        ``(resolution, resolution, C)`` for multi-channel images.

    Raises:
        FileNotFoundError: If the image cannot be read.
        ValueError: If the (resized) image is smaller than ``resolution`` along either axis.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"Image not found: {img_path}")

    # Reduce deeper images to 8 bits. Dividing by 256 maps the 16-bit range onto 0..255;
    # NOTE(review): other non-uint8 dtypes (e.g. float images) go through the same path.
    if img.dtype != np.uint8:
        img = (img / 256).astype(np.uint8)

    height, width = img.shape[:2]
    new_width, new_height = int(width * resize_by), int(height * resize_by)

    # A crop larger than the image would produce negative slice bounds, which numpy
    # interprets as "from the end" and silently yields a wrong or empty crop.
    if new_height < resolution or new_width < resolution:
        raise ValueError(
            f"Image {img_path} is {new_width}x{new_height} after resizing, "
            f"smaller than the requested {resolution}x{resolution} crop"
        )

    if resize_by != 1.0:
        # cv2.resize takes the target size as (width, height).
        img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)

    # Crop the center
    top = (new_height - resolution) // 2
    left = (new_width - resolution) // 2
    center_crop = img[top:top + resolution, left:left + resolution]

    if brightness_norm:
        delta = brightness_mean - np.mean(center_crop)
        # Shift every channel by the same amount. cv2.add(img, scalar) would treat the scalar
        # as (delta, 0, 0, 0) and only adjust the first channel of multi-channel images.
        shifted = np.rint(center_crop.astype(np.float64) + delta)
        center_crop = np.clip(shifted, 0, 255).astype(np.uint8)

    return center_crop


# Read the per-run reference table (an Excel workbook keyed by the "Run" column)
reference_df = pd.read_excel(REFERENCE_CSV)

# Index the image folders: a folder qualifies when its last four characters are digits,
# and those digits are taken as its run number
folder_data = [
    [int(entry[-4:]), entry]
    for entry in os.listdir(IMG_DIR)
    if entry[-4:].isdigit()
]
folder_df = pd.DataFrame(folder_data, columns=["Run", "Folder"])

# Attach the reference rows to each folder (left join: folders without a reference row are kept)
image_df = pd.merge(folder_df, reference_df, on="Run", how="left")
# Reduce each mutant label to its first run of digits so it can be compared with the strain numbers
image_df["Mutant #"] = image_df["Mutant #"].astype(str).str.extract(r'(\d+)')

# Process each target strain: for every run/scope of the strain, save the centre crop of
# each target frame (or of the closest available frame) as a JPEG under OUT_DIR/<strain>/.

# How far (in frames) to look on either side of a target frame whose file is missing.
MAX_FRAME_SEARCH_OFFSET = 50

for strain in strains:
    strain = str(strain)
    subset = image_df[image_df["Mutant #"] == strain]
    if subset.empty:
        print(f"Strain {strain} not found in metadata. Skipping.")
        continue

    output_path = os.path.join(OUT_DIR, str(strain))
    os.makedirs(output_path, exist_ok=True)

    for _, experiment in subset.iterrows():
        folder_path = os.path.join(IMG_DIR, experiment["Folder"])
        if not os.path.isdir(folder_path):
            print(f"Folder not found: {folder_path}. Skipping.")
            continue

        for scope in os.listdir(folder_path):
            scope_path = os.path.join(folder_path, scope)
            if not os.path.isdir(scope_path):
                continue
            print(scope_path)
            all_filenames = os.listdir(scope_path)
            # Map frame number -> file name; the frame number is read from the four
            # characters that precede the last four characters of the file name.
            frame_dict = {
                int(fname[-8:-4]): fname for fname in all_filenames if fname[-8:-4].isdigit()
            }

            for frame in frames:
                # Closest available frame: try the target itself, then -1, +1, -2, +2, ...
                # The earlier frame wins a tie. (The previous search only probed
                # -1, +2, -3, +4, ... and therefore skipped half of the neighbours.)
                search_frame = next(
                    (
                        candidate
                        for offset in range(MAX_FRAME_SEARCH_OFFSET + 1)
                        for candidate in (frame - offset, frame + offset)
                        if candidate in frame_dict
                    ),
                    None,
                )
                if search_frame is None:
                    print(f"Could not find frame close to {frame} in {scope_path}")
                    continue

                filename = frame_dict[search_frame]
                img_path = os.path.join(scope_path, filename)
                try:
                    img = resize_crop(img_path)
                    # NOTE(review): the run is not part of the file name, so two runs of the
                    # same strain that share a scope name overwrite each other's frames.
                    save_path = os.path.join(output_path, f"{strain}_{scope}_{search_frame}.jpg")
                    # cv2.imwrite reports failure through its return value, not an exception.
                    if not cv2.imwrite(save_path, img):
                        raise OSError(f"cv2.imwrite could not write {save_path}")
                    print(f"Saved {save_path}")
                except Exception as e:
                    print(f"Failed to process {img_path}: {e}")


/home/xavier/Documents/DAE_project/dataset/Roy_training/images/CS5_04_1622_1%agar_Run0582/Scope09
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope09_1.jpg
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope09_361.jpg
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope09_721.jpg
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope09_1081.jpg
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope09_1441.jpg
/home/xavier/Documents/DAE_project/dataset/Roy_training/images/CS1_33_A1622_1%agar_Run0271/Scope31
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope31_1.jpg
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope31_361.jpg
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope31_721.jpg
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_Scope31_1081.jpg
Saved /home/xavier/Documents/DAE_project/images/figure1/1622/1622_

# Construct videos

In [6]:
# === New Cell: For Video Generation (with Reconstruction) ===
# Ensure all variables from the first cell (IMG_DIR, REFERENCE_CSV, strains, resize_crop, image_df, etc.) are defined before running this cell.

import os
import cv2
import numpy as np
import torch
import dnnlib
import legacy  # Assuming this is from the StyleGAN/PGAN repo

# --- Environment and Model Setup ---
# Point the C/C++ compiler environment variables at gcc-9/g++-9.
# NOTE(review): presumably needed so the PyTorch custom-op plugins (bias_act, upfirdn2d)
# build at first use — confirm against the toolchain required on this machine.
os.environ['CC'] = "/usr/bin/gcc-9"
os.environ['CXX'] = "/usr/bin/g++-9"

# Network snapshot to load; the 'E_ema' and 'G_ema' entries are read from it.
network_pkl = f"{WORKING_DIR}/models/network-snapshot-003024-patched.pkl"
RECONSTRUCTION_SEED = 42  # Fixed seed for reproducible noise

# --- Video Parameters ---
# Output directory for the videos built from the original (cropped) frames.
VIDEO_OUT_DIR = f"{WORKING_DIR}/videos/movie1"
RECON_VIDEO_OUT_DIR = f"{WORKING_DIR}/videos/movie1_reconstructed"  # Output dir for the reconstructed videos
start_frame = 1  # First frame number included in the video (inclusive)
end_frame = 1441  # Last frame number included in the video (inclusive)
fps = 20  # Frames per second (FPS) for the output video


# ---------------------

# --- Helper Function for Model ---
def reset_noise_const(G, seed):
    """Overwrite every synthesis layer's ``noise_const`` with seeded Gaussian noise.

    Seeds torch's global RNG with ``seed``, then walks the children of each child of
    ``G.synthesis`` and, for every module whose class is named ``SynthesisLayer``,
    copies a fresh ``[resolution, resolution]`` standard-normal tensor into its
    ``noise_const`` in place.
    """
    torch.manual_seed(seed)
    # Lazily enumerate the target layers in traversal order, so the random draws
    # happen in the same sequence as the modules are visited.
    synthesis_layers = (
        candidate
        for sub_block in G.synthesis.children()
        for candidate in sub_block.children()
        if candidate.__class__.__name__ == "SynthesisLayer"
    )
    with torch.no_grad():
        for synthesis_layer in synthesis_layers:
            side = synthesis_layer.resolution
            synthesis_layer.noise_const.copy_(torch.randn([side, side]))


# --- Load Model ---
# Load the encoder (E) and generator (G) from the snapshot onto the GPU when one is
# available, otherwise onto the CPU, then fix the generator's noise constants.
print("Loading reconstruction model...")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
try:
    with dnnlib.util.open_url(network_pkl) as fp:
        models = legacy.load_network_pkl(fp)
        E = models['E_ema'].to(device)
        G = models['G_ema'].to(device)

    # Deterministic noise so repeated runs give identical reconstructions
    reset_noise_const(G, RECONSTRUCTION_SEED)
    print(f"Model {network_pkl} loaded to {device}.")
except Exception as e:
    # Best effort: report the failure and carry on without a model, which makes the
    # video loop skip the reconstruction step.
    print(f"FATAL: Could not load model. Error: {e}")
    E = G = None

# Create both output directories up front (no-op when they already exist)
for out_dir in (VIDEO_OUT_DIR, RECON_VIDEO_OUT_DIR):
    os.makedirs(out_dir, exist_ok=True)
print(f"Original videos will be saved to: {VIDEO_OUT_DIR}")
print(f"Reconstructed videos will be saved to: {RECON_VIDEO_OUT_DIR}")

# Build, for every run/scope of every target strain, one video of the cropped original
# frames and (when the model loaded) one video of their encoder/generator reconstructions.

# How far (in frames) to look on either side of a frame number whose file is missing.
MAX_FRAME_SEARCH_OFFSET = 50

# 1. Iterate over each target strain
for strain in strains:
    strain = str(strain)
    subset = image_df[image_df["Mutant #"] == strain]
    if subset.empty:
        print(f"Strain {strain} not found in metadata. Skipping.")
        continue

    print(f"--- Processing Strain: {strain} ---")

    # 2. Iterate over all experiments for this strain
    for _, experiment in subset.iterrows():
        folder_path = os.path.join(IMG_DIR, experiment["Folder"])
        run = experiment["Run"]
        if not os.path.isdir(folder_path):
            print(f"Folder not found: {folder_path}. Skipping.")
            continue

        # 3. Iterate over each scope in the experiment (every sub-directory of the run folder)
        for scope in os.listdir(folder_path):
            scope_path = os.path.join(folder_path, scope)
            if not os.path.isdir(scope_path):
                continue

            print(f"  Processing Run: {run}, Scope: {scope}")

            # 4. Find all frames within this scope (frame number -> file name)
            try:
                all_filenames = os.listdir(scope_path)
                frame_dict = {
                    int(fname[-8:-4]): fname for fname in all_filenames if fname[-8:-4].isdigit()
                }
            except Exception as e:
                print(f"    Could not read files in {scope_path}: {e}")
                continue

            if not frame_dict:
                print(f"    No valid frame files found in {scope_path}.")
                continue

            video_writer = None
            recon_video_writer = None  # Writer for the reconstruction video
            # Set once the reconstruction video of this scope has been given up on, so that
            # later frames neither re-create (and overwrite) the file nor run inference.
            recon_aborted = False
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MP4 codec

            # Paths for original video
            video_filename = f"{strain}_{run}_{scope}_{start_frame}_{end_frame}.mp4"
            video_save_path = os.path.join(VIDEO_OUT_DIR, video_filename)
            is_color_video = None  # Decided from the first frame: color or grayscale

            # Paths for reconstruction video
            recon_video_filename = f"{strain}_{run}_{scope}_{start_frame}_{end_frame}_recon.mp4"
            recon_video_save_path = os.path.join(RECON_VIDEO_OUT_DIR, recon_video_filename)

            resolution = (512, 512)  # Frame size; matches the default crop size of resize_crop

            # The try/finally guarantees both writers are released (and the files finalized)
            # even if the frame loop is interrupted.
            try:
                # 5. Iterate over each frame in the specified range
                for frame_num in range(start_frame, end_frame + 1):
                    # Closest available frame: try the frame itself, then -1, +1, -2, +2, ...
                    # The earlier frame wins a tie. (The previous search only probed
                    # -1, +2, -3, +4, ... and therefore skipped half of the neighbours.)
                    search_frame = next(
                        (
                            candidate
                            for offset in range(MAX_FRAME_SEARCH_OFFSET + 1)
                            for candidate in (frame_num - offset, frame_num + offset)
                            if candidate in frame_dict
                        ),
                        None,
                    )
                    if search_frame is None:
                        # Nothing within the search window; leave this frame out of the video
                        continue

                    # 6. Load and process the image
                    filename = frame_dict[search_frame]
                    img_path = os.path.join(scope_path, filename)

                    try:
                        img = resize_crop(img_path)
                    except Exception as e:
                        print(f"    Failed to process frame {img_path}: {e}")
                        continue

                    # 7. Initialize Original VideoWriter (if needed)
                    if video_writer is None:
                        img_shape = img.shape
                        if len(img_shape) == 3:
                            is_color_video = True
                        elif len(img_shape) == 2:
                            is_color_video = False
                        else:
                            print(f"    Abnormal image shape {img_shape} for {img_path}. Skipping this scope.")
                            break  # Stop processing this scope

                        video_writer = cv2.VideoWriter(video_save_path, fourcc, fps, resolution, is_color_video)

                        if not video_writer.isOpened():
                            print(f"    Could not open VideoWriter for saving to {video_save_path}")
                            # Drop the unusable writer so the scope is not reported as finished
                            video_writer.release()
                            video_writer = None
                            break  # Stop processing this scope
                        print(f"    Creating video: {video_filename} (isColor={is_color_video})")

                    # 8. Write to Original Video (with color correction if needed)
                    current_is_color = (len(img.shape) == 3)
                    if current_is_color != is_color_video:
                        try:
                            if is_color_video:  # Video needs color, but frame is grayscale
                                img_for_video = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                            else:  # Video needs grayscale, but frame is color
                                img_for_video = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                        except Exception as e:
                            print(f"    Conversion failed: {e}. Skipping this frame.")
                            continue
                    else:
                        img_for_video = img  # No conversion needed

                    video_writer.write(img_for_video)

                    # 9. --- Start Reconstruction ---
                    if G is not None and E is not None and not recon_aborted:
                        try:
                            # 9a. Prepare model input (must be grayscale)
                            if len(img.shape) == 3:  # If original is color
                                model_input_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                            else:
                                model_input_img = img.copy()  # Already grayscale

                            # 9b. Run inference (batch size 1); pixels are mapped from [0, 255] to [-1, 1]
                            img_batch_numpy = np.array([model_input_img])
                            img_batch = torch.from_numpy(img_batch_numpy).to(device).to(torch.float32) / 127.5 - 1
                            img_batch = img_batch.unsqueeze(1)  # Add the channel axis: (1, 1, H, W)

                            with torch.no_grad():
                                z, logvar = E.mu_var(img_batch, None)
                                ws = G.mapping(z, None)
                                synth_image_tensor = G.synthesis(ws, noise_mode='const')

                            # 9c. Post-process the output image: back to [0, 255] uint8, first channel only
                            synth_image_tensor = (synth_image_tensor + 1) * 127.5
                            synth_image = \
                                synth_image_tensor.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8).cpu().numpy()[
                                    0, :, :, 0]
                            # Hand OpenCV a contiguous buffer regardless of how the slice is laid out
                            synth_image = np.ascontiguousarray(synth_image)

                            # 9d. Initialize Reconstruction VideoWriter (if needed)
                            if recon_video_writer is None:
                                is_color_recon = False  # Only one channel of the output is written
                                recon_video_writer = cv2.VideoWriter(recon_video_save_path, fourcc, fps, resolution,
                                                                     is_color_recon)
                                if not recon_video_writer.isOpened():
                                    print(f"    Could not open Recon VideoWriter for {recon_video_save_path}")
                                    # Nothing can be written for this scope: stop reconstructing
                                    recon_video_writer.release()
                                    recon_video_writer = None
                                    recon_aborted = True
                                else:
                                    print(f"    Creating recon video: {recon_video_filename}")

                            # 9e. Write to Reconstruction Video
                            if recon_video_writer is not None and recon_video_writer.isOpened():
                                recon_video_writer.write(synth_image)

                        except Exception as e:
                            print(f"    Failed during reconstruction for frame {search_frame}: {e}")
                            # If recon fails after the video was started, close it and stop
                            # reconstructing this scope so the file is not re-created and overwritten
                            if recon_video_writer is not None:
                                recon_video_writer.release()
                                recon_video_writer = None
                                recon_aborted = True
                                print(f"    Aborting reconstruction video: {recon_video_filename}")
                    # --- End Reconstruction ---
            finally:
                # 10. After the loop, release the video_writers
                if video_writer is not None:
                    video_writer.release()
                    print(f"    Finished video: {video_filename}")
                else:
                    print(f"    Could not create video for {scope} (perhaps no frames were found).")

                if recon_video_writer is not None:
                    recon_video_writer.release()
                    print(f"    Finished recon video: {recon_video_filename}")
                elif G is not None and E is not None:
                    print(f"    Could not create recon video for {scope} (perhaps no frames found or error occurred).")

print("--- Video generation process complete ---")



Loading reconstruction model...
Model /home/xavier/Documents/DAE_project/models/network-snapshot-003024-patched.pkl loaded to cuda.
Original videos will be saved to: /home/xavier/Documents/DAE_project/videos/movie1
Reconstructed videos will be saved to: /home/xavier/Documents/DAE_project/videos/movie1_reconstructed
--- Processing Strain: 1622 ---
  Processing Run: 582, Scope: Scope09
    Creating video: 1622_582_Scope09_1_1441.mp4 (isColor=False)
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "upfirdn2d_plugin"... Done.
    Creating recon video: 1622_582_Scope09_1_1441_recon.mp4
    Finished video: 1622_582_Scope09_1_1441.mp4
    Finished recon video: 1622_582_Scope09_1_1441_recon.mp4
  Processing Run: 271, Scope: Scope31
    Creating video: 1622_271_Scope31_1_1441.mp4 (isColor=False)
    Creating recon video: 1622_271_Scope31_1_1441_recon.mp4
    Finished video: 1622_271_Scope31_1_1441.mp4
    Finished recon video: 1622_271_Scope31_1_1441_recon.mp4
  Pr