In [None]:
# Switch the kernel's working directory to the project checkout.
# NOTE(review): presumably required so the `qwenimage` imports in the next cell resolve — confirm.
%cd /home/ubuntu/Qwen-Image-Edit-Angles

In [None]:
import huggingface_hub 
from qwenimage.datamodels import QwenConfig
from qwenimage.foundation import QwenImageFoundationSaveInterm
from datasets import concatenate_datasets, load_dataset, interleave_datasets

In [None]:
# Enumerate the entries stored under "data" in the CrispEdit-2M dataset repo.
repo_tree = huggingface_hub.list_repo_tree(
    "WeiChow/CrispEdit-2M",
    "data",
    repo_type="dataset",
)

# Keep only the repo-relative path of each listed entry.
all_paths = [entry.path for entry in repo_tree]

# For every parquet file whose name contains an underscore, record the part of
# the file name that precedes the first underscore.
parquet_prefixes = {
    file_name.split('_')[0]
    for file_name in (
        repo_path.split('/')[-1]
        for repo_path in all_paths
        if repo_path.endswith('.parquet')
    )
    if '_' in file_name
}

print("Found parquet prefixes:", sorted(parquet_prefixes))


In [None]:
# How many parquet shards are loaded for each edit type.
total_per = 10

# Edit categories; this order is also the order in which the per-type
# datasets are built and interleaved later in the notebook.
EDIT_TYPES = [
    "color", "style", "replace", "remove", "add",
    "motion change", "background change",
]

In [None]:


# One dataset per edit type: load that type's first `total_per` local parquet
# shards (numbered 00000, 00001, ...) and concatenate them in shard order.
all_edit_datasets = [
    concatenate_datasets([
        load_dataset(
            "parquet",
            data_files=f"/data/CrispEdit/{edit_type}_{ds_n:05d}.parquet",
            split="train",
        )
        for ds_n in range(total_per)
    ])
    for edit_type in EDIT_TYPES
]

# Interleave the per-type datasets so sample indices follow a consistent
# ordering, which also leaves room to extend the set later.
join_ds = interleave_datasets(all_edit_datasets)

In [None]:
from pathlib import Path


# Directory that receives one serialized pipeline output per sample.
save_base_dir = Path("/data/regression_output")
# Create it together with any missing parents; an existing directory is not an error.
save_base_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Build the foundation object from a default-constructed QwenConfig. Later cells
# use its base_pipe, INPUT_MODEL, unpack_latents and latents_to_pil members.
foundation = QwenImageFoundationSaveInterm(QwenConfig())

In [None]:
import torch


# Run the base pipeline on every sample of the interleaved dataset and persist
# whatever it returns, one file per sample, named by the zero-padded index.
for idx, input_data in enumerate(join_ds):
    # The pipeline input takes the source image wrapped in a list plus the edit instruction.
    model_input = foundation.INPUT_MODEL(
        image=[input_data["input_img"]],
        prompt=input_data["instruction"],
    )
    output_dict = foundation.base_pipe(model_input)

    out_path = save_base_dir / f"{idx:06d}.pt"
    torch.save(output_dict, out_path)


In [None]:
# Reload one saved pipeline output for inspection. `idx` is whatever value the
# generation loop last assigned, so after a full run this is the final sample.
# weights_only=False lets torch.load unpickle arbitrary Python objects: only use
# it on files this notebook wrote itself.
output_dict = torch.load(save_base_dir/f"{idx:06d}.pt", weights_only=False)

In [None]:
# Show which entries the saved output contains; the cells below read per-index
# keys of the form latents_<i>_start, t_<i> and noise_pred_<i>, plus height/width.
output_dict.keys()

In [None]:
# Index of the saved intermediates to inspect (presumably a denoising step — confirm).
test_ind = 10

# Fetch, for that index, the starting latents, the `t` value and the model prediction.
latents_i_start, t_i, v_i = (
    output_dict[key]
    for key in (
        f"latents_{test_ind}_start",
        f"t_{test_ind}",
        f"noise_pred_{test_ind}",
    )
)

# Step from the starting latents along the prediction, scaled by t.
# NOTE(review): this reads like the flow-matching clean-sample estimate
# x0 = x_t - t * v, which presumes t_i is a sigma in [0, 1] rather than a raw
# timestep — confirm against what the pipeline stores.
proj_out_i = latents_i_start - t_i * v_i

In [None]:
# Size arguments for unpack_latents: the stored height and width, each floor-divided by 16.
proj_h = output_dict["height"] // 16
proj_w = output_dict["width"] // 16

# Unpack the projected latents, convert them to images and display the first one.
proj_out_i_1d = proj_out_i
proj_out_i_2d = foundation.unpack_latents(proj_out_i_1d, proj_h, proj_w)
proj_out_i_pil = foundation.latents_to_pil(proj_out_i_2d)
proj_out_i_pil[0]

In [None]:
# Same unpack-and-decode path as for the projection, applied to the stored "image_latents".
# NOTE(review): the variables are named `out_*`, but the key is "image_latents" —
# confirm this is the tensor intended for comparison.
out_h = output_dict["height"] // 16
out_w = output_dict["width"] // 16

out_1d = output_dict["image_latents"]
out_2d = foundation.unpack_latents(out_1d, out_h, out_w)
out_pil = foundation.latents_to_pil(out_2d)

# Optional manual checks (uncomment one to display it):
# out_pil[0]
# join_ds[idx]["input_img"]
# join_ds[idx]["instruction"]

In [None]:
# Report the size of the first decoded projection image
# (presumably a PIL image, whose .size is (width, height) — confirm).
proj_out_i_pil[0].size