In [61]:
import warnings

import numpy as np
import torch

# Checkpoint to export.
CHECKPOINT_PATH = "latest_checkpoint.pth"

# Weights are exported as signed fixed-point integers with `q` fractional bits,
# i.e. real_value ≈ int_value / 2**q. (`scalse_factor` keeps its original
# spelling because later code refers to it by that name.)
q = 12
scalse_factor = 2 ** q

try:
    # NOTE(review): torch.load unpickles arbitrary Python objects — only load
    # trusted checkpoints (consider weights_only=True if the file holds only
    # tensors/dicts).
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
except FileNotFoundError:
    print("=> ERROR: File tidak ditemukan.")
    # Re-raise instead of calling exit(): inside a notebook exit() shuts down
    # the kernel, whereas an exception simply stops "Run All" at this cell.
    raise

# Accept either a training checkpoint that nests the generator under "gen"
# or a bare generator state dict.
gen_state = checkpoint["gen"] if "gen" in checkpoint else checkpoint

def get_scaled_weight(tensor_weight, gain=2):
    """Fold a fan-in based scale factor into a weight tensor.

    Returns ``tensor_weight * sqrt(gain / fan_in)``, where ``fan_in`` is the
    product of every dimension after the first. For a conv kernel of shape
    (out, in, k_h, k_w) this is ``in * k_h * k_w`` — exactly what the previous
    4-D-only version computed — but it now also works for e.g. 2-D linear
    weights of shape (out, in).

    NOTE(review): this presumably reproduces the runtime scale of a
    weight-scaled ("equalized learning rate") conv layer; confirm the model
    applies sqrt(gain / fan_in) with gain=2.

    Args:
        tensor_weight: Weight array with the output dimension first.
        gain: Numerator of the scale. Defaults to 2 (the previous fixed value).

    Returns:
        A new array equal to ``tensor_weight`` times the scale.

    Raises:
        ValueError: If the weight has fewer than 2 dimensions (no fan-in axis).
    """
    w = tensor_weight
    if w.ndim < 2:
        raise ValueError(
            f"expected a weight with at least 2 dims, got shape {tuple(w.shape)}"
        )
    # Fan-in = number of inputs that feed a single output unit.
    fan_in = int(np.prod(w.shape[1:]))
    scale = np.sqrt(gain / fan_in)
    return w * scale

# Container for every exported fixed-point tensor, keyed by layer name
# (kept at module level so later cells can inspect it).
weights_numpy = {}

_INT16 = np.iinfo(np.int16)


def _to_fixed_point(values, name):
    """Quantize a float array to int16 fixed point using `scalse_factor`.

    Values are rounded to the nearest integer (a bare `.astype(np.int16)`
    truncates toward zero, biasing every value toward 0) and saturated to the
    int16 range (a bare cast silently wraps out-of-range values to the
    opposite sign). A warning naming the tensor is issued if any value had to
    be saturated.
    """
    scaled = np.rint(np.asarray(values) * scalse_factor)
    n_clipped = int(np.count_nonzero((scaled < _INT16.min) | (scaled > _INT16.max)))
    if n_clipped:
        warnings.warn(f"{name}: {n_clipped} value(s) saturated to the int16 range")
    return np.clip(scaled, _INT16.min, _INT16.max).astype(np.int16)


def _export(key, values):
    """Quantize `values`, store them in `weights_numpy[key]`, write `weights_<key>.bin`.

    `ndarray.tofile` always writes raw elements in C order in the machine's
    native byte order, so no explicit flatten is needed.
    """
    fixed = _to_fixed_point(values, key)
    weights_numpy[key] = fixed
    fixed.tofile(f"weights_{key}.bin")


# --- 1. Initial 4x4 block ---

# Learned constant input tensor.
_export("const_input", gen_state["starting_constant"].numpy())

# Initial conv: the scale is folded into the weight; the bias is exported as-is.
# NOTE(review): unlike the progressive blocks (keys `conv1.conv.weight` /
# `conv1.bias`), this layer is stored as `initial_conv.weight` /
# `initial_conv.bias`, i.e. like a plain conv layer. If the model does NOT
# apply a runtime weight scale to initial_conv, get_scaled_weight must not be
# used here — confirm against the model definition.
_export("conv_4x4", get_scaled_weight(gen_state["initial_conv.weight"].numpy()))
_export("bias_4x4", gen_state["initial_conv.bias"].numpy())

# Noise-injection weights.
_export("noise1_4x4", gen_state["initial_noise1.weight"].numpy())
_export("noise2_4x4", gen_state["initial_noise2.weight"].numpy())

# --- 2. Progressive blocks (8x8 and up) ---

# Count the progressive blocks in the checkpoint by probing consecutive indices.
num_blocks = 0
while f"prog_blocks.{num_blocks}.conv1.conv.weight" in gen_state:
    num_blocks += 1

for i in range(num_blocks):
    res = 4 * 2 ** (i + 1)  # block i -> 8, 16, 32, ...
    prefix = f"prog_blocks.{i}"

    k1 = gen_state[f"{prefix}.conv1.conv.weight"].numpy()
    k2 = gen_state[f"{prefix}.conv2.conv.weight"].numpy()
    n1 = gen_state[f"{prefix}.inject_noise1.weight"].numpy()
    n2 = gen_state[f"{prefix}.inject_noise2.weight"].numpy()
    b1 = gen_state[f"{prefix}.conv1.bias"].numpy()
    b2 = gen_state[f"{prefix}.conv2.bias"].numpy()

    # Conv weights get the scale folded in; noise weights and biases do not.
    _export(f"conv1_{res}", get_scaled_weight(k1))
    _export(f"conv2_{res}", get_scaled_weight(k2))
    _export(f"noise1_{res}", n1)
    _export(f"noise2_{res}", n2)
    _export(f"bias1_{res}", b1)
    _export(f"bias2_{res}", b2)

    print(f"Block {res}x{res} loaded. Conv1: {k1.shape}, Conv2: {k2.shape}, {n1.dtype}")



Block 8x8 loaded. Conv1: (256, 256, 3, 3), Conv2: (256, 256, 3, 3), float32
Block 16x16 loaded. Conv1: (256, 256, 3, 3), Conv2: (256, 256, 3, 3), float32
Block 32x32 loaded. Conv1: (256, 256, 3, 3), Conv2: (256, 256, 3, 3), float32
Block 64x64 loaded. Conv1: (128, 256, 3, 3), Conv2: (128, 128, 3, 3), float32
Block 128x128 loaded. Conv1: (64, 128, 3, 3), Conv2: (64, 64, 3, 3), float32
Block 256x256 loaded. Conv1: (32, 64, 3, 3), Conv2: (32, 32, 3, 3), float32
Block 512x512 loaded. Conv1: (16, 32, 3, 3), Conv2: (16, 16, 3, 3), float32
Block 1024x1024 loaded. Conv1: (8, 16, 3, 3), Conv2: (8, 8, 3, 3), float32


[[[[ -58  143   48]
   [ 199 -210  129]
   [ -97  -16   52]]

  [[ -78  -58   21]
   [-298   74 -104]
   [  23  -62 -151]]

  [[  71  -22  -49]
   [-278   63  -21]
   [-216  144   48]]

  ...

  [[  84  191   12]
   [  58  -31    0]
   [  61  118 -126]]

  [[ -58   -7  -72]
   [ -73    5   91]
   [ 117  -43  123]]

  [[ -36   74 -237]
   [   8   58 -266]
   [ 122 -203   45]]]


 [[[ -31 -152  -92]
   [ -66   94 -106]
   [  70  -28   21]]

  [[ -31   -1   50]
   [ 279  169  -33]
   [ -31  -53  169]]

  [[  40  202  -87]
   [ -58  164   45]
   [-214  157  174]]

  ...

  [[  65 -128 -112]
   [ -52   35 -298]
   [  39  -23 -137]]

  [[-139  -79  -47]
   [  79  -43 -137]
   [   1 -187 -155]]

  [[-218   10  -10]
   [  26 -124   55]
   [ 336  111  -35]]]


 [[[  85   94  230]
   [-189  -35   83]
   [  11 -184   35]]

  [[  84  -51  -63]
   [-138 -108  -96]
   [ 214 -106   55]]

  [[-105 -116   72]
   [ -75  139  -92]
   [  62  -25   86]]

  ...

  [[ 308  -22  136]
   [ 130  -43    7]
   [ 