In [1]:
# -------------------------------------------------------------
# 📦 INSTALL REQUIRED DEPENDENCIES
# -------------------------------------------------------------
# Use the %pip magic (not !pip) so packages are installed into the
# environment of the kernel that is actually running this notebook.
%pip install ninja imageio tqdm einops albumentations
# Re-pin NumPy last: the install above can pull in a NumPy 2.x wheel as a
# dependency, so 1.26.4 is force-reinstalled afterwards.
# NOTE: if numpy was already imported in this kernel, restart it before
# continuing so the pinned version is the one that gets loaded.
%pip install numpy==1.26.4 --force-reinstall --no-cache-dir

Collecting numpy (from imageio)
  Downloading numpy-2.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m62.0/62.0 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
INFO: pip is looking at multiple versions of mkl-fft to determine which version is compatible with other requirements. This could take a while.
Collecting mkl_fft (from numpy->imageio)
  Downloading mkl_fft-2.1.1-0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (7.3 kB)
  Downloading mkl_fft-2.0.0-22-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (7.1 kB)
INFO: pip is looking at multiple versions of mkl-random to determine which version is compatible with other requirements. This could take a while.
Collecting mkl_random (from numpy->imageio)
  Downloading mkl_random-1.3.0-0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
  Downloading mkl_

In [2]:
# -------------------------------------------------------------
# 📥 CLONE NVIDIA EDM2 REPOSITORY
# -------------------------------------------------------------
# Clone only when the checkout is absent, so re-running this cell does not
# fail because the destination directory `edm2` already exists.
![ -d edm2 ] || git clone https://github.com/NVlabs/edm2.git
!ls edm2

Cloning into 'edm2'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 60 (delta 13), reused 10 (delta 10), pack-reused 33 (from 1)[K
Receiving objects: 100% (60/60), 1.27 MiB | 10.77 MiB/s, done.
Resolving deltas: 100% (24/24), done.
calculate_metrics.py  Dockerfile	  README.md		train_edm2.py
count_flops.py	      docs		  reconstruct_phema.py	training
dataset_tool.py       generate_images.py  torch_utils
dnnlib		      LICENSE.txt	  toy_example.py


In [3]:
# -------------------------------------------------------------
# 📁 SET PATHS TO YOUR CELEBA64 DATASET
# -------------------------------------------------------------
import os

DATA_ROOT = "/kaggle/input/celeva-64x64-dataset/celeba64"
IMG_DIR   = f"{DATA_ROOT}/train"
ATTR_CSV  = f"{DATA_ROOT}/list_attr_celeba.csv"

# Show the first ten filenames in sorted order, one per line.
# Done in Python rather than `ls | head`: `head` closes the pipe after ten
# lines, which makes `ls` report "write error: Broken pipe".
print(*sorted(os.listdir(IMG_DIR))[:10], sep="\n")


000001.jpg
000002.jpg
000003.jpg
000004.jpg
000005.jpg
000006.jpg
000007.jpg
000008.jpg
000009.jpg
000010.jpg
ls: write error: Broken pipe


In [4]:
# -------------------------------------------------------------
# 🧱 BUILD ATTRIBUTE LABELS FROM list_attr_celeba.csv
# -------------------------------------------------------------
import os
import numpy as np
import pandas as pd

# Load attribute CSV
attr_df = pd.read_csv(ATTR_CSV)

# First column holds the image filenames; every remaining column is an attribute.
image_ids_attr = attr_df.iloc[:, 0].values
attr_values = attr_df.iloc[:, 1:].values.astype(np.float32)

# Convert CelebA -1/+1 → 0/1
# NOTE(review): assumes the CSV stores attributes as -1/+1 — confirm for this dataset.
attr_values = (attr_values + 1) / 2.0

# Load image filenames from the folder (sorted so row i of `labels`
# corresponds to image_files[i]).
image_files = sorted([f for f in os.listdir(IMG_DIR) if f.lower().endswith(".jpg")])

# Map: filename → row index in CSV
idx_map = {img_id: i for i, img_id in enumerate(image_ids_attr)}

# Fail early with a complete report instead of a bare KeyError on the first
# file that has no row in the CSV.
missing_files = [fname for fname in image_files if fname not in idx_map]
if missing_files:
    raise KeyError(
        f"{len(missing_files)} image(s) in {IMG_DIR} have no row in {ATTR_CSV}; "
        f"first few: {missing_files[:5]}"
    )

# Collect aligned labels with a single indexed gather.
row_indices = np.fromiter(
    (idx_map[fname] for fname in image_files),
    dtype=np.intp,
    count=len(image_files),
)
labels = attr_values[row_indices]

# One label row per image, one column per attribute.
assert labels.shape == (len(image_files), attr_values.shape[1])

labels.shape


(162770, 40)

In [5]:
# -------------------------------------------------------------
# 💾 SAVE LABEL FILES FOR DATASET
# -------------------------------------------------------------
# Persist the aligned label matrix and the filename order it corresponds to.
SAVE_DIR = "/kaggle/working/celeba64_labels"
os.makedirs(SAVE_DIR, exist_ok=True)

file_name_array = np.array(image_files)
np.save(f"{SAVE_DIR}/train_labels.npy", labels)
np.save(f"{SAVE_DIR}/train_files.npy", file_name_array)

# List what ended up in the output directory.
saved_entries = os.listdir(SAVE_DIR)
print("Saved:", saved_entries)


Saved: ['train_labels.npy', 'train_files.npy']


In [6]:
%%writefile edm2/training/dataset.py
import os
import numpy as np
from PIL import Image
import torch

class ImageFolderDataset(torch.utils.data.Dataset):
    """Minimal EDM2-compatible image-folder dataset with optional CelebA labels.

    Images are read from the flat directory ``path`` (files ending in
    ``.jpg`` or ``.png``, case-insensitive, sorted by filename). When
    ``use_labels`` is true, attributes are loaded from
    ``<parent_of_path>/list_attr_celeba.csv`` if that file exists; otherwise
    the dataset runs unconditionally.

    Args:
        path: Directory containing the image files.
        resolution: Accepted for call compatibility but not used; the
            resolution is taken from the height of the first image.
        use_labels: Whether to try loading attribute labels.
        **kwargs: Accepted and ignored.

    Attributes:
        num_channels: Always 3 (images are converted to RGB).
        resolution: Height in pixels of the first image.
        has_labels: True when attribute labels were loaded.
        label_dim: Number of attribute columns (0 without labels).
        labels: float32 array of shape [N, label_dim], or None.

    Raises:
        RuntimeError: If ``path`` contains no ``.jpg`` / ``.png`` files.
    """

    _IMAGE_EXTENSIONS = (".jpg", ".png")
    _ATTR_CSV_NAME = "list_attr_celeba.csv"

    def __init__(self, path, resolution=None, use_labels=True, **kwargs):
        super().__init__()
        self.path = path
        self.use_labels = use_labels

        # -----------------------------------------------------
        # 1) Collect image filenames
        # -----------------------------------------------------
        self._image_fnames = sorted(
            f for f in os.listdir(path)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )
        if len(self._image_fnames) == 0:
            raise RuntimeError(f"No images found in {path}")

        # -----------------------------------------------------
        # 2) Basic image info (num_channels, resolution)
        # -----------------------------------------------------
        first_img_path = os.path.join(self.path, self._image_fnames[0])
        # Context manager closes the file handle deterministically.
        with Image.open(first_img_path) as first_img:
            _, h = first_img.size
        self.num_channels = 3
        # Only the height is recorded.
        # NOTE(review): assumes square images (true for 64x64 CelebA) — confirm
        # before pointing this at other data.
        self.resolution = h

        # -----------------------------------------------------
        # 3) Load CelebA attribute labels if available
        # -----------------------------------------------------
        self.labels = None
        self.label_dim = 0
        self.has_labels = False

        if not self.use_labels:
            print(">> use_labels=False → Unconditional mode.")
            return

        attr_csv = self._attr_csv_path(self.path)
        if not os.path.exists(attr_csv):
            print(f">> Attribute CSV not found at {attr_csv}. Running unconditional.")
            return

        print(f">> Loading CelebA attributes from: {attr_csv}")
        self.labels, missing = self._read_attribute_labels(attr_csv, self._image_fnames)
        if missing > 0:
            print(f">> Warning: {missing} filenames not found in CSV; filled with zeros.")

        self.label_dim = self.labels.shape[1]
        self.has_labels = True
        print(f">> Loaded labels with shape: {self.labels.shape}")

    @classmethod
    def _attr_csv_path(cls, path):
        """Return the attribute-CSV path located in the parent directory of `path`."""
        # abspath() normalises the path first, so a trailing slash
        # ("/data/train/") still resolves to the parent ("/data") instead of
        # the image folder itself.
        parent = os.path.dirname(os.path.abspath(path))
        return os.path.join(parent, cls._ATTR_CSV_NAME)

    @staticmethod
    def _read_attribute_labels(attr_csv, image_fnames):
        """Build a float32 label matrix aligned with `image_fnames`.

        Column 0 of the CSV is treated as the image filename and every other
        column as an attribute. Values are mapped with (x + 1) / 2.

        Returns:
            (labels, missing): `labels` has shape
            [len(image_fnames), n_attributes]; rows whose filename is absent
            from the CSV stay all-zero and are counted in `missing`.
        """
        import pandas as pd  # local import: only needed when labels are used

        attr_df = pd.read_csv(attr_csv)
        image_ids_attr = attr_df.iloc[:, 0].values
        attr_values = attr_df.iloc[:, 1:].values.astype(np.float32)

        # Convert CelebA -1/+1 → 0/1
        # NOTE(review): assumes the CSV stores -1/+1 values — confirm.
        attr_values = (attr_values + 1.0) / 2.0

        # Map filename → row index
        idx_map = {img_id: i for i, img_id in enumerate(image_ids_attr)}

        # Pre-allocated zeros double as the fallback for missing filenames.
        labels = np.zeros((len(image_fnames), attr_values.shape[1]), dtype=np.float32)
        missing = 0
        for row, fname in enumerate(image_fnames):
            src = idx_map.get(fname)
            if src is None:
                missing += 1
            else:
                labels[row] = attr_values[src]
        return labels, missing

    def __len__(self):
        """Return the number of images found in `path`."""
        return len(self._image_fnames)

    def __getitem__(self, idx):
        """Return `(image, label)` for index `idx`.

        `image` is a uint8 tensor laid out [C, H, W] with C == 3; `label` is a
        float32 vector of length `label_dim` (empty when there are no labels).
        """
        img_path = os.path.join(self.path, self._image_fnames[idx])

        # -----------------------------------------------------
        # 4) Return raw pixels as uint8 [C,H,W]
        # -----------------------------------------------------
        with Image.open(img_path) as pil_img:
            # np.array copies into a writable buffer; np.asarray on a PIL
            # image can yield a read-only array, which torch.from_numpy
            # warns about.
            pixels = np.array(pil_img.convert("RGB"), dtype=np.uint8)
        img = torch.from_numpy(pixels).permute(2, 0, 1)  # [H,W,C] → [C,H,W], uint8

        # -----------------------------------------------------
        # 5) Return label vector
        # -----------------------------------------------------
        if self.labels is not None:
            label = torch.tensor(self.labels[idx], dtype=torch.float32)
        else:
            # Empty tensor if no labels
            label = torch.zeros(0, dtype=torch.float32)

        return img, label


Overwriting edm2/training/dataset.py


In [7]:
# Launch edm2/train_edm2.py through torchrun on this machine with two worker
# processes (--standalone, --nproc_per_node=2); results go to --outdir.
# NOTE: --data repeats the literal train-folder path that IMG_DIR holds
# earlier in the notebook; keep the two in sync if the dataset location changes.
# NOTE(review): --cond=True presumably requires the dataset class written to
# edm2/training/dataset.py to find list_attr_celeba.csv next to the train
# folder — confirm against train_edm2.py.
# NOTE(review): --batch=64 with --batch-gpu=32 looks like 32 samples per
# process across the two processes — confirm against train_edm2.py.
!torchrun --standalone --nproc_per_node=2 edm2/train_edm2.py \
    --outdir=/kaggle/working/training-runs/celeba64-cond-karras-rho \
    --data=/kaggle/input/celeva-64x64-dataset/celeba64/train \
    --cond=True \
    --preset=edm2-img64-xs \
    --batch=64 \
    --batch-gpu=32 \
    --duration=2Mi \
    --status=16Ki \
    --snapshot=512Ki \
    --checkpoint=0 \
    --seed=0

W1125 19:48:38.474000 73 torch/distributed/run.py:792] 
W1125 19:48:38.474000 73 torch/distributed/run.py:792] *****************************************
W1125 19:48:38.474000 73 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1125 19:48:38.474000 73 torch/distributed/run.py:792] *****************************************
[W1125 19:48:38.906150515 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1125 19:48:38.906796740 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1125 19:48:40.591616721 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1125 19:48:40.592487457 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1125 19:48:40.596328117 socket.