In [1]:
import sys 
sys.path.append('../')

import VNCA_ST 

In [2]:
from VNCA_ST.Model.VNCA import VNCA, VNCA_paras

paras = VNCA_paras()

# --- image geometry (height, width, colour channels) ---
paras.h, paras.w, paras.n_channels = 32, 32, 3

# --- latent size and NCA rollout settings ---
# NOTE(review): p_update / dmg_size / min_steps / max_steps are consumed inside
# VNCA; presumably cell-update probability, damage patch size and the range of
# NCA steps per rollout — confirm against VNCA_paras.
paras.z_size, paras.p_update, paras.dmg_size = 256, 1.0, 16
paras.min_steps, paras.max_steps = 64, 128

paras.batch_size = 32

# --- update net ---
paras.nca_hid, paras.n_mixtures = 128, 1

# --- encoder net ---
paras.filter_size = 5
# Half the kernel size keeps the spatial size unchanged for odd kernels at stride 1.
paras.pad = paras.filter_size // 2
paras.encoder_hid = 32


In [4]:
from torch import nn
from VNCA_ST.Model.base import Residual
from VNCA_ST.Model.distribution import DiscretizedMixtureLogitsDistribution

# Snapshot the mixture count into a plain module-level constant; later edits to
# `paras.n_mixtures` will not affect `state_to_dist`.
n_mixtures = paras.n_mixtures


def state_to_dist(state):
    """Turn the leading channels of an NCA state into an output distribution.

    Only the first ``10 * n_mixtures`` channels (dimension 1) of ``state`` are
    handed to ``DiscretizedMixtureLogitsDistribution``; all other channels are
    ignored. The state is indexed with four axes, so it is expected to be a
    4-D tensor (presumably (batch, channels, h, w) — TODO confirm).
    """
    n_param_channels = 10 * n_mixtures
    logits = state[:, :n_param_channels, :, :]
    return DiscretizedMixtureLogitsDistribution(n_mixtures, logits)

def create_encoder(paras):
    """Build the convolutional encoder.

    Architecture: one stride-1 convolution, then four stride-2 convolutions,
    each followed by ELU, with channel widths ``encoder_hid * 2**i`` for
    i = 0..4. The result is flattened and projected by a Linear layer to
    ``2 * paras.z_size`` features (presumably posterior mean and scale
    parameters for the latent — TODO confirm in VNCA).

    The Linear layer's input size is derived from the actual Conv2d output
    sizes. It therefore matches the flattened tensor for any ``h``/``w``
    (not only multiples of 16) and any ``filter_size``/``pad``.

    Args:
        paras: config object; reads ``n_channels``, ``encoder_hid``,
            ``filter_size``, ``pad``, ``h``, ``w`` and ``z_size``.

    Returns:
        nn.Sequential mapping ``(bs, n_channels, h, w)`` to
        ``(bs, 2 * z_size)``. Layer indices are unchanged: convs/ELUs at
        0..9, Flatten at 10, Linear at 11.
    """
    n_down = 4  # number of stride-2 (downsampling) stages

    def _conv_out(size, stride):
        # Conv2d output-size formula (dilation 1): floor((L + 2p - k) / s) + 1.
        return (size + 2 * paras.pad - paras.filter_size) // stride + 1

    layers = []
    in_ch = paras.n_channels
    out_h, out_w = paras.h, paras.w
    for i in range(n_down + 1):
        stride = 1 if i == 0 else 2
        out_ch = paras.encoder_hid * 2 ** i
        layers += [
            nn.Conv2d(in_ch, out_ch, paras.filter_size, padding=paras.pad, stride=stride),
            nn.ELU(),
        ]
        in_ch = out_ch
        out_h, out_w = _conv_out(out_h, stride), _conv_out(out_w, stride)

    layers += [
        nn.Flatten(),  # (bs, in_ch * out_h * out_w)
        nn.Linear(in_ch * out_h * out_w, 2 * paras.z_size),
    ]
    return nn.Sequential(*layers)

def create_updateNet(paras, n_residual=4):
    """Build the per-cell update network of the NCA.

    Layout:
      * a 3x3 convolution (padding 1) from ``z_size`` to ``nca_hid`` channels,
        so each cell's output depends on its 3x3 neighbourhood;
      * ``n_residual`` ``Residual`` blocks, each wrapping
        1x1 conv -> ELU -> 1x1 conv at width ``nca_hid``;
      * a final 1x1 convolution back to ``z_size`` channels.

    The final convolution's weight and bias are zero-initialised, so the
    freshly built network outputs all zeros for any input.

    Args:
        paras: config object; reads ``z_size`` and ``nca_hid``.
        n_residual: number of residual blocks. The default of 4 reproduces the
            original architecture, including layer indices.

    Returns:
        nn.Sequential; ``update_net[-1]`` is the zero-initialised output conv.
    """
    hid = paras.nca_hid

    def _residual_block():
        # Pointwise (1x1) two-layer MLP applied at every cell. How the branch
        # is combined with its input is defined by Residual (VNCA_ST.Model.base).
        return Residual(
            nn.Conv2d(hid, hid, 1),
            nn.ELU(),
            nn.Conv2d(hid, hid, 1),
        )

    # Modules are constructed in the same order as the original hand-written
    # version, so default parameter initialisation consumes the RNG identically.
    update_net = nn.Sequential(
        nn.Conv2d(paras.z_size, hid, 3, padding=1),
        *(_residual_block() for _ in range(n_residual)),
        nn.Conv2d(hid, paras.z_size, 1),
    )
    # Zero the output layer so the initial update is exactly zero
    # (presumably so the untrained NCA starts from a no-op update — TODO confirm in VNCA).
    nn.init.zeros_(update_net[-1].weight)
    nn.init.zeros_(update_net[-1].bias)
    return update_net


# Build both networks from the shared config; the encoder is created first,
# then the update net (left-to-right evaluation).
encoder, update_net = create_encoder(paras), create_updateNet(paras)

In [5]:
import os
from torchvision import datasets, transforms

# Root directory holding (or receiving) the CelebA files. Set the
# CELEBA_DATA_DIR environment variable to point elsewhere instead of editing
# this cell; the previous hard-coded path is kept as the fallback.
data_dir = os.environ.get(
    "CELEBA_DATA_DIR", "/home/shi/WorkSpace/projects/scLLM_workspace/data/test/"
)

# Resize every image to the model's (h, w), then convert it to a tensor.
tp = transforms.Compose([transforms.Resize((paras.h, paras.w)), transforms.ToTensor()])

# NOTE: torchvision downloads CelebA from Google Drive, which can fail with
# "daily quota ... exceeded". In that case, put the dataset files under
# data_dir manually (or retry later). Per the torchvision docs, download=True
# does not re-download a dataset that is already present.
train_data, val_data, test_data = [
    datasets.CelebA(data_dir, split=split, download=True, transform=tp)
    for split in ("train", "valid", "test")
]



RuntimeError: The daily quota of the file img_align_celeba.zip is exceeded and it can't be downloaded. This is a limitation of Google Drive and can only be overcome by trying again later.