In [None]:
import glob
import os
import re
import shutil
import subprocess
import sys

%cd /kaggle/working
if os.path.exists("benchmarking-generative-models-for-domain-adaptation"):
    shutil.rmtree("benchmarking-generative-models-for-domain-adaptation")

!git clone --recursive https://github.com/AntonioRosano/benchmarking-generative-models-for-domain-adaptation.git
repo_path = "/kaggle/working/benchmarking-generative-models-for-domain-adaptation/models/munit"
%cd {repo_path}

!pip install -q tensorboardX

In [None]:
# Patch the legacy MUNIT sources (in the current working directory) so they run
# on current PyTorch / PyYAML. Each rule is a regex whose output never matches
# its own pattern again, so re-running this cell is a no-op.
SOURCE_PATCHES = (
    # Indexing a 0-dim tensor with [0] fails on modern PyTorch; .item() replaces it.
    # NOTE(review): blind textual rewrite, as before — a non-scalar `x.data[0]`
    # would change meaning; confirm every occurrence in the MUNIT sources is a scalar loss.
    (r"\.data\[0\]", ".item()"),
    # yaml.load(<name>) without an explicit Loader is rejected by recent PyYAML.
    (r"yaml\.load\((\w+)\)", r"yaml.load(\1, Loader=yaml.FullLoader)"),
    # torch.utils.serialization no longer exists. Only an uncommented import
    # matches (anchored at line start), so it is never prefixed with '#' twice.
    (r"^(\s*)from torch\.utils\.serialization import load_lua",
     r"\1# from torch.utils.serialization import load_lua"),
)

files_to_fix = ['train.py', 'test.py', 'utils.py', 'trainer.py', 'networks.py']


def patch_source_file(path, patches=SOURCE_PATCHES):
    """Apply (regex, replacement) `patches` to the file at `path`, in place.

    Patterns are applied in order with re.MULTILINE so `^` anchors at each line.
    Returns True if the file content changed and was rewritten, else False.
    """
    with open(path, 'r') as fh:
        content = fh.read()

    new_content = content
    for pattern, replacement in patches:
        new_content = re.sub(pattern, replacement, new_content, flags=re.MULTILINE)

    if new_content == content:
        return False
    with open(path, 'w') as fh:
        fh.write(new_content)
    return True


for fname in files_to_fix:
    if not os.path.exists(fname):
        # Previously skipped silently; a missing file usually means the cwd is wrong.
        print(f" -> File non trovato, patch saltata: {fname}")
    elif patch_source_file(fname):
        print(f" -> Patch applicata a {fname}")

In [None]:
# Write the MUNIT YAML config that test.py is invoked with in the translation cell.
# NOTE(review): the model/data options presumably must match those used to
# train the loaded checkpoint — confirm before changing them.
config_file = "configs/ego_kaggle_train.yaml"

# Plain (non-f) string: the YAML has no placeholders, and an f-string would
# misread any future `{...}` flow mapping as a format field.
munit_config = """
# MUNIT Config EGO-CH Kaggle
# logger options
image_save_iter: 1000
image_display_iter: 1000
display_size: 4
snapshot_save_iter: 5000
log_iter: 100

# optimization options
max_iter: 75000
batch_size: 4
weight_decay: 0.0001
beta1: 0.5
beta2: 0.999
init: kaiming
lr: 0.0001
lr_policy: step
step_size: 40000
gamma: 0.5
gan_w: 1
recon_x_w: 10
recon_s_w: 1
recon_c_w: 1
recon_x_cyc_w: 0
vgg_w: 0

# model options
gen:
  dim: 64
  mlp_dim: 256
  style_dim: 8
  activ: relu
  n_downsample: 2
  n_res: 4
  pad_type: reflect
dis:
  dim: 64
  norm: none
  activ: lrelu
  n_layer: 4
  gan_type: lsgan
  num_scales: 3
  pad_type: reflect

# data options
input_dim_a: 3
input_dim_b: 3
num_workers: 4
new_size: 256
crop_image_height: 256
crop_image_width: 256
data_root: ""
"""

# Create config_file's own parent directory instead of a separately hard-coded one.
os.makedirs(os.path.dirname(config_file) or ".", exist_ok=True)
with open(config_file, "w") as f:
    f.write(munit_config)

print(f"Configurazione salvata in: {config_file}")

In [None]:
import os
import glob
import shutil

# Translate every real training frame with the MUNIT generator (a2b, one style
# per image) by running the repository's test.py once per image.
CHECKPOINT_PATH = "/kaggle/input/datasets/marcogionfriddo/gen-00030000/gen_00030000.pt"
DIR_TRAIN = "/kaggle/input/datasets/marcogionfriddo/ego-ch-obj-seg/EGO-CH-OBJ-SEG/EGO-CH-OBJ-SEG/real/train/frames"

# Scratch folder test.py writes into (emptied before each image) and final destination.
TEMP_OUT = "/kaggle/working/temp_munit"
FINAL_TRAIN_OUT = "/kaggle/working/synthetic_train"

# Set to a small int (e.g. 5) for a quick smoke test; None translates every frame.
MAX_IMAGES = None

os.makedirs(TEMP_OUT, exist_ok=True)
os.makedirs(FINAL_TRAIN_OUT, exist_ok=True)

immagini = sorted(glob.glob(os.path.join(DIR_TRAIN, "**/*.jpg"), recursive=True))

print("\n--- Avvio conversione TRAIN SET ---")
print(f"Trovate {len(immagini)} immagini JPG...")

if MAX_IMAGES is not None:
    immagini = immagini[:MAX_IMAGES]

falliti = []
for idx, img_path in enumerate(immagini):
    nome_file = os.path.basename(img_path)
    # Mirror the input sub-folder layout: the glob is recursive, and flattening
    # by basename would let same-named frames overwrite each other.
    rel_senza_estensione = os.path.splitext(os.path.relpath(img_path, DIR_TRAIN))[0]
    final_file = os.path.join(FINAL_TRAIN_OUT, f"{rel_senza_estensione}.jpg")

    if idx % 50 == 0 or idx == len(immagini) - 1:
        print(f"[{idx+1}/{len(immagini)}] Traduzione di: {nome_file}...")

    # Empty the scratch folder so a stale output from a previous image is never picked up.
    shutil.rmtree(TEMP_OUT, ignore_errors=True)
    os.makedirs(TEMP_OUT, exist_ok=True)

    # Argument list (no shell) avoids quoting problems in paths; sys.executable
    # runs the kernel's interpreter; cwd pins test.py and the relative config
    # path to the MUNIT folder instead of depending on the kernel's cwd.
    cmd = [
        sys.executable, "test.py",
        "--config", config_file,
        "--input", img_path,
        "--output_folder", TEMP_OUT,
        "--checkpoint", CHECKPOINT_PATH,
        "--a2b", "1",
        "--num_style", "1",
    ]
    esito = subprocess.run(cmd, cwd=repo_path, capture_output=True, text=True)

    # Sorted so the choice is deterministic if test.py ever writes several outputs.
    file_generati = sorted(glob.glob(os.path.join(TEMP_OUT, "output*.jpg")))

    if file_generati:
        os.makedirs(os.path.dirname(final_file), exist_ok=True)
        shutil.move(file_generati[0], final_file)
    else:
        falliti.append(img_path)
        print(f"ERRORE: Nessun output trovato per {nome_file}")
        if esito.returncode != 0:
            # Show the tail of test.py's traceback instead of discarding it.
            print(esito.stderr[-2000:])

if falliti:
    print(f"\nTraduzione completata con {len(falliti)} errori su {len(immagini)} immagini.")
else:
    print("\nTraduzione del dataset completata con successo.")

In [None]:
# Bundle the translated frames (contents of FINAL_TRAIN_OUT) into
# /kaggle/working/synthetic_train_results.zip for download.
print("Compressione del Train Set in corso...")
shutil.make_archive(
    base_name="/kaggle/working/synthetic_train_results",
    format="zip",
    root_dir=FINAL_TRAIN_OUT,
)
print("ZIP pronto.")