In [None]:
import os
import pickle

import torch

from src.modules import *  # NOTE(review): star import; presumably provides UNet_conditional

In [None]:
# Target device for diffAE. GPU index 3 is machine-specific; fall back to CPU when
# fewer than four CUDA devices are visible (device_count() is 0 without CUDA) so the
# notebook still runs end to end instead of failing at the first .to(device).
device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cpu'

In [None]:
# Destination network: the conditional UNet whose layers receive the transferred weights.
diffAE = UNet_conditional(img_width=128, img_height=64, feat_num=3, device=device)
diffAE = diffAE.to(device)

# Source network: the model stored under the "ema" key of the pretrained EDM snapshot,
# downloaded from https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/.
# NOTE(review): pickle.load can execute arbitrary code; only load snapshots from a trusted source.
# NOTE(review): edm is not moved to `device` here, so it stays wherever unpickling placed it.
with open('models/edm-imagenet-64x64-cond-adm.pkl', 'rb') as snapshot_file:
    edm = pickle.load(snapshot_file)["ema"].model

In [None]:
def generate_indices(length, max_value, generator=None):
    """Return ``length`` indices into ``range(max_value)`` that cover every value.

    The result starts with ``0, 1, ..., max_value - 1`` so that every source
    index is used at least once. It is then padded with ``length - max_value``
    indices drawn uniformly at random from ``[0, max_value)``.

    Args:
        length: number of indices to return; must be >= ``max_value``.
        max_value: exclusive upper bound of the index values.
        generator: optional ``torch.Generator`` for reproducible padding.
            ``None`` (the default) uses the global RNG, as before.

    Returns:
        1-D int64 tensor of shape ``(length,)``.

    Raises:
        ValueError: if ``length < max_value`` (not every value could be covered).
    """
    if length < max_value:
        raise ValueError(f"length ({length}) must be >= max_value ({max_value})")
    unique_values = torch.arange(max_value)
    repetitions = torch.randint(high=max_value, size=(length - max_value,), generator=generator)
    return torch.cat((unique_values, repetitions))

def load_weights(src_layer, dst_layer):
    """Copy a resampled version of ``src_layer.weight`` into ``dst_layer``.

    Both weights must have the same number of dimensions, but each dimension
    may differ in size. Each dimension is handled as follows:

    * destination larger  -> every source index once, then random repeats
      (via ``generate_indices``);
    * destination smaller -> a random subset of source indices, no repeats;
    * same size           -> the identity mapping.

    The resampled tensor is cast to the destination weight's device and dtype.
    It is installed as a new ``torch.nn.Parameter`` on ``dst_layer``.
    Biases are not transferred.

    Args:
        src_layer: module with a ``.weight`` tensor to copy from.
        dst_layer: module whose ``.weight`` is replaced.

    Returns:
        ``dst_layer`` on success, or ``None`` (after printing a message) when
        the two weights have a different number of dimensions.
    """
    src_weight = src_layer.weight
    dst_shape = dst_layer.weight.shape
    if src_weight.dim() != len(dst_shape):
        # Report via the weights' ranks; modules themselves have no `.shape`, which
        # previously turned this message into an AttributeError.
        print("Source layer has ", src_weight.dim(), " dimensions, but destination layer has ", len(dst_shape), " dimensions.")
        return

    extracted = src_weight.detach()
    # Index tensors are drawn in dimension order, so the RNG is consumed in the same order as before.
    for dim, (src_size, dst_size) in enumerate(zip(src_weight.shape, dst_shape)):
        if dst_size > src_size:
            index = generate_indices(dst_size, src_size)
        elif dst_size < src_size:
            index = torch.randperm(src_size)[:dst_size]
        else:
            index = torch.arange(0, src_size, dtype=int)
        # index_select works for any rank; the old chained indexing assumed 4-D weights.
        extracted = extracted.index_select(dim, index.to(extracted.device))

    # Match the destination's device/dtype so the model does not end up with mixed-device parameters.
    dst_layer.weight = torch.nn.Parameter(
        extracted.to(device=dst_layer.weight.device, dtype=dst_layer.weight.dtype)
    )
    return dst_layer

In [None]:
# Layer mapping for a diffAE built with 2 down/up blocks (down1/down2, up1/up2).
# This cell and the "3 blocks" / "4 blocks" cells are alternatives: each one
# reassigns `layers`, `src_layers` and `dst_layers`, so run only the one matching
# how UNet_conditional was constructed. Under Run All every cell executes in order,
# so a later cell overwrites this mapping. A later cell also raises if diffAE lacks
# the submodules it references (e.g. down3/down4).
# `layers` interleaves pairs: each diffAE destination layer is followed by the
# EDM source layer whose weights load_weights copies into it.
layers = [  diffAE.inc.double_conv[0],                            edm.enc["64x64_conv"],
            diffAE.inc.double_conv[3],                            edm.enc["64x64_block0"].conv0,
            diffAE.down1.maxpool_conv[1].double_conv[0],          edm.enc["64x64_block0"].conv1,
            diffAE.down1.maxpool_conv[1].double_conv[3],          edm.enc["64x64_block1"].conv0,
            diffAE.down1.maxpool_conv[2].double_conv[0],          edm.enc["64x64_block1"].conv1,
            diffAE.down1.maxpool_conv[2].double_conv[3],          edm.enc["64x64_block2"].conv0,
            diffAE.down2.maxpool_conv[1].double_conv[0],          edm.enc["64x64_block2"].conv1,
            diffAE.down2.maxpool_conv[1].double_conv[3],          edm.enc["32x32_down"].conv0,
            diffAE.down2.maxpool_conv[2].double_conv[0],          edm.enc["32x32_down"].conv1,  # dim 0 mismatch
            diffAE.down2.maxpool_conv[2].double_conv[3],          edm.enc["32x32_block0"].conv0, # dim 1 mismatch
            diffAE.bot1.double_conv[0],                           edm.dec["8x8_in0"].conv0,
            diffAE.bot1.double_conv[3],                           edm.dec["8x8_in0"].conv1,
            diffAE.bot2.double_conv[0],                           edm.dec["8x8_in1"].conv0,
            diffAE.bot2.double_conv[3],                           edm.dec["8x8_in1"].conv1, 
            diffAE.bot3.double_conv[0],                           edm.dec["8x8_block0"].conv0,
            diffAE.bot3.double_conv[3],                           edm.dec["8x8_block0"].conv1,
            diffAE.up1.conv[0].double_conv[0],                    edm.dec["64x64_block0"].conv0,
            diffAE.up1.conv[0].double_conv[3],                    edm.dec["64x64_block0"].conv1,
            diffAE.up1.conv[1].double_conv[0],                    edm.dec["64x64_block1"].conv0,
            diffAE.up1.conv[1].double_conv[3],                    edm.dec["64x64_block1"].conv1,
            diffAE.up2.conv[0].double_conv[0],                    edm.dec["64x64_block2"].conv0,
            diffAE.up2.conv[0].double_conv[3],                    edm.dec["64x64_block2"].conv1,
            diffAE.up2.conv[1].double_conv[0],                    edm.dec["64x64_block3"].conv0,
            diffAE.up2.conv[1].double_conv[3],                    edm.dec["64x64_block3"].conv1,
            diffAE.outc,                                          edm.out_conv,
]
src_layers = layers[1::2]  # odd positions: EDM source layers
dst_layers = layers[::2]   # even positions: diffAE destination layers

In [None]:
# Layer mapping for a diffAE built with 3 down blocks (down1..down3) and up1..up4.
# This cell and the "2 blocks" / "4 blocks" cells are alternatives: each one
# reassigns `layers`, `src_layers` and `dst_layers`, so run only the one matching
# how UNet_conditional was constructed. Under Run All every cell executes in order,
# so a later cell overwrites this mapping. This cell raises if diffAE lacks the
# submodules it references (e.g. down3, up3, up4).
# `layers` interleaves pairs: each diffAE destination layer is followed by the
# EDM source layer whose weights load_weights copies into it.
layers = [  diffAE.inc.double_conv[0],                            edm.enc["64x64_conv"],
            diffAE.inc.double_conv[3],                            edm.enc["64x64_block0"].conv0,
            diffAE.down1.maxpool_conv[1].double_conv[0],          edm.enc["64x64_block0"].conv1,
            diffAE.down1.maxpool_conv[1].double_conv[3],          edm.enc["64x64_block1"].conv0,
            diffAE.down1.maxpool_conv[2].double_conv[0],          edm.enc["64x64_block1"].conv1,
            diffAE.down1.maxpool_conv[2].double_conv[3],          edm.enc["64x64_block2"].conv0,
            diffAE.down2.maxpool_conv[1].double_conv[0],          edm.enc["64x64_block2"].conv1,
            diffAE.down2.maxpool_conv[1].double_conv[3],          edm.enc["32x32_down"].conv0,
            diffAE.down2.maxpool_conv[2].double_conv[0],          edm.enc["32x32_down"].conv1,  # dim 0 mismatch
            diffAE.down2.maxpool_conv[2].double_conv[3],          edm.enc["32x32_block0"].conv0, # dim 1 mismatch
            diffAE.down3.maxpool_conv[1].double_conv[0],          edm.enc["32x32_block0"].conv1,
            diffAE.down3.maxpool_conv[1].double_conv[3],          edm.enc["32x32_block1"].conv0,
            diffAE.down3.maxpool_conv[2].double_conv[0],          edm.enc["32x32_block1"].conv1,
            diffAE.down3.maxpool_conv[2].double_conv[3],          edm.enc["32x32_block2"].conv0,
            diffAE.bot1.double_conv[0],                           edm.dec["8x8_in0"].conv0,
            diffAE.bot1.double_conv[3],                           edm.dec["8x8_in0"].conv1,
            diffAE.bot2.double_conv[0],                           edm.dec["8x8_in1"].conv0,
            diffAE.bot2.double_conv[3],                           edm.dec["8x8_in1"].conv1, 
            diffAE.bot3.double_conv[0],                           edm.dec["8x8_block0"].conv0,
            diffAE.bot3.double_conv[3],                           edm.dec["8x8_block0"].conv1,
            diffAE.up1.conv[0].double_conv[0],                    edm.dec["32x32_block1"].conv0,
            diffAE.up1.conv[0].double_conv[3],                    edm.dec["32x32_block1"].conv1,
            diffAE.up1.conv[1].double_conv[0],                    edm.dec["32x32_block2"].conv0,
            diffAE.up1.conv[1].double_conv[3],                    edm.dec["32x32_block2"].conv1,
            diffAE.up2.conv[0].double_conv[0],                    edm.dec["32x32_block3"].conv0,
            diffAE.up2.conv[0].double_conv[3],                    edm.dec["32x32_block3"].conv1,
            diffAE.up2.conv[1].double_conv[0],                    edm.dec["64x64_up"].conv0,
            diffAE.up2.conv[1].double_conv[3],                    edm.dec["64x64_up"].conv1,
            diffAE.up3.conv[0].double_conv[0],                    edm.dec["64x64_block0"].conv0,
            diffAE.up3.conv[0].double_conv[3],                    edm.dec["64x64_block0"].conv1,
            diffAE.up3.conv[1].double_conv[0],                    edm.dec["64x64_block1"].conv0,
            diffAE.up3.conv[1].double_conv[3],                    edm.dec["64x64_block1"].conv1,
            diffAE.up4.conv[0].double_conv[0],                    edm.dec["64x64_block2"].conv0,
            diffAE.up4.conv[0].double_conv[3],                    edm.dec["64x64_block2"].conv1,
            diffAE.up4.conv[1].double_conv[0],                    edm.dec["64x64_block3"].conv0,
            diffAE.up4.conv[1].double_conv[3],                    edm.dec["64x64_block3"].conv1,
            diffAE.outc,                                          edm.out_conv,
]
src_layers = layers[1::2]  # odd positions: EDM source layers
dst_layers = layers[::2]   # even positions: diffAE destination layers

In [None]:
# Layer mapping for a diffAE built with 4 down blocks (down1..down4) and up1..up3.
# This cell and the "2 blocks" / "3 blocks" cells are alternatives: each one
# reassigns `layers`, `src_layers` and `dst_layers`, so run only the one matching
# how UNet_conditional was constructed. Under Run All this cell executes last, so
# its mapping is the one used, provided diffAE has every submodule referenced here.
# `layers` interleaves pairs: each diffAE destination layer is followed by the
# EDM source layer whose weights load_weights copies into it.
layers = [  diffAE.inc.double_conv[0],                            edm.enc["64x64_conv"],
            diffAE.inc.double_conv[3],                            edm.enc["64x64_block0"].conv0,
            diffAE.down1.maxpool_conv[1].double_conv[0],          edm.enc["64x64_block0"].conv1,
            diffAE.down1.maxpool_conv[1].double_conv[3],          edm.enc["64x64_block1"].conv0,
            diffAE.down1.maxpool_conv[2].double_conv[0],          edm.enc["64x64_block1"].conv1,
            diffAE.down1.maxpool_conv[2].double_conv[3],          edm.enc["64x64_block2"].conv0,
            diffAE.down2.maxpool_conv[1].double_conv[0],          edm.enc["64x64_block2"].conv1,
            diffAE.down2.maxpool_conv[1].double_conv[3],          edm.enc["32x32_down"].conv0,
            diffAE.down2.maxpool_conv[2].double_conv[0],          edm.enc["32x32_down"].conv1,  # dim 0 mismatch
            diffAE.down2.maxpool_conv[2].double_conv[3],          edm.enc["32x32_block0"].conv0, # dim 1 mismatch
            diffAE.down3.maxpool_conv[1].double_conv[0],          edm.enc["32x32_block0"].conv1,
            diffAE.down3.maxpool_conv[1].double_conv[3],          edm.enc["32x32_block1"].conv0,
            diffAE.down3.maxpool_conv[2].double_conv[0],          edm.enc["32x32_block1"].conv1,
            diffAE.down3.maxpool_conv[2].double_conv[3],          edm.enc["32x32_block2"].conv0,
            diffAE.down4.maxpool_conv[1].double_conv[0],          edm.enc["32x32_block2"].conv1,
            diffAE.down4.maxpool_conv[1].double_conv[3],          edm.enc["16x16_down"].conv0,
            diffAE.down4.maxpool_conv[2].double_conv[0],          edm.enc["16x16_down"].conv1,
            diffAE.down4.maxpool_conv[2].double_conv[3],          edm.enc["16x16_block0"].conv0,
            diffAE.bot1.double_conv[0],                           edm.dec["8x8_in0"].conv0,
            diffAE.bot1.double_conv[3],                           edm.dec["8x8_in0"].conv1,
            diffAE.bot2.double_conv[0],                           edm.dec["8x8_in1"].conv0,
            diffAE.bot2.double_conv[3],                           edm.dec["8x8_in1"].conv1, 
            diffAE.bot3.double_conv[0],                           edm.dec["8x8_block0"].conv0,
            diffAE.bot3.double_conv[3],                           edm.dec["8x8_block0"].conv1,
            diffAE.up1.conv[0].double_conv[0],                    edm.dec["32x32_block3"].conv0,
            diffAE.up1.conv[0].double_conv[3],                    edm.dec["32x32_block3"].conv1,
            diffAE.up1.conv[1].double_conv[0],                    edm.dec["64x64_up"].conv0,
            diffAE.up1.conv[1].double_conv[3],                    edm.dec["64x64_up"].conv1,
            diffAE.up2.conv[0].double_conv[0],                    edm.dec["64x64_block0"].conv0,
            diffAE.up2.conv[0].double_conv[3],                    edm.dec["64x64_block0"].conv1,
            diffAE.up2.conv[1].double_conv[0],                    edm.dec["64x64_block1"].conv0,
            diffAE.up2.conv[1].double_conv[3],                    edm.dec["64x64_block1"].conv1,
            diffAE.up3.conv[0].double_conv[0],                    edm.dec["64x64_block2"].conv0,
            diffAE.up3.conv[0].double_conv[3],                    edm.dec["64x64_block2"].conv1,
            diffAE.up3.conv[1].double_conv[0],                    edm.dec["64x64_block3"].conv0,
            diffAE.up3.conv[1].double_conv[3],                    edm.dec["64x64_block3"].conv1,
            diffAE.outc,                                          edm.out_conv,
]
src_layers = layers[1::2]  # odd positions: EDM source layers
dst_layers = layers[::2]   # even positions: diffAE destination layers

In [None]:
# Copy every EDM source layer into its paired diffAE layer.
# src_layers/dst_layers are slices of one interleaved list. If `layers` had an odd
# length, zip() would silently drop the last destination layer, so check the pairing first.
assert len(src_layers) == len(dst_layers), (
    f"{len(dst_layers)} destination layers vs {len(src_layers)} source layers; "
    "`layers` must alternate destination, source"
)
skipped_pairs = []
for pair_idx, (src_layer, dst_layer) in enumerate(zip(src_layers, dst_layers)):
    # load_weights prints a reason and returns None when it cannot copy a pair.
    if load_weights(src_layer, dst_layer) is None:
        skipped_pairs.append(pair_idx)
if skipped_pairs:
    print("Weights were NOT transferred for layer pairs:", skipped_pairs)

In [None]:
# Persist diffAE's parameters (including the transferred weights) as a state_dict.
# NOTE(review): the filename spelling "transfered" is kept as-is; other code may load it by this name.
transfer_checkpoint = os.path.join("models", "transfered.pt")
torch.save(diffAE.state_dict(), transfer_checkpoint)