In [2]:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

class CustomELBO(Trace_ELBO):
    def _get_trace(self, model, guide, *args, **kwargs):
        trace = super()._get_trace(model, guide, *args, **kwargs)

        # Access the sampled latent variables from the trace
        z_1 = trace.nodes["z1"]["value"]
        z_2 = trace.nodes["z2"]["value"]

        # Get the reconstructed images
        loc_x1 = model.decoder1(z_2)
        loc_x2 = model.decoder2(z_1)

        # Compute the reconstruction loss
        x = kwargs["obs"]
        recon_loss = -dist.Normal(loc_x1, 1.0).log_prob(x).sum(-1)
        recon_loss += -dist.Normal(loc_x2, 1.0).log_prob(x).sum(-1)

        # Compute the KL divergence
        kl_divergence = 0.5 * (z_1 ** 2).sum(-1) - 0.5 * (z_2 ** 2).sum(-1)

        # Compute the ELBO
        elbo = recon_loss - kl_divergence

        # Update the trace with the ELBO value
        trace.add_node("elbo", value=elbo.sum())

        return trace


class HVAE:
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, latent_dim):
        self.input_dim = input_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.latent_dim = latent_dim

        # 注册模型的参数
        # 注册模型的参数
        self.z1_loc = pyro.param("z1_loc", torch.zeros(latent_dim))
        self.z1_scale = pyro.param("z1_scale", torch.ones(latent_dim))
        self.z2_loc = pyro.param("z2_loc", torch.zeros(latent_dim))
        self.z2_scale = pyro.param("z2_scale", torch.ones(latent_dim))

        # Define encoder1
        self.encoder1 = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim1),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim1, latent_dim * 2)
        )

        # Define encoder2
        self.encoder2 = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim2, latent_dim * 2)
        )

        # Define decoder1
        self.decoder1 = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim2, input_dim)
        )

        # Define decoder2
        self.decoder2 = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim1),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim1, input_dim)
        )

    def model(self, x):
        pyro.module("hvae", self)

        with pyro.plate("data", x.shape[0]):
            # First encoding
            z1_loc, z1_scale = self.encoder1(x).chunk(2, dim=-1)
            z1 = pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))

            # Second encoding
            z2_loc, z2_scale = self.encoder2(z1).chunk(2, dim=-1)
            z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

            # First decoding
            loc_x1 = self.decoder1(z2)
            pyro.sample("obs1", dist.Normal(loc_x1, 1.0).to_event(1), obs=x)

            # Second decoding
            loc_x2 = self.decoder2(z1)
            pyro.sample("obs2", dist.Normal(loc_x2, 1.0).to_event(1), obs=x)

    def guide(self, x):
        with pyro.plate("data", x.shape[0]):
            # First encoding parameters
            z1_loc = pyro.param("z1_loc", torch.zeros(x.shape[0], self.latent_dim))
            z1_scale = pyro.param("z1_scale", torch.ones(x.shape[0], self.latent_dim))
            z1 = pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))

            # Second encoding parameters
            z2_loc = pyro.param("z2_loc", torch.zeros(x.shape[0], self.latent_dim))
            z2_scale = pyro.param("z2_scale", torch.ones(x.shape[0], self.latent_dim))
            z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

            return z1, z2

    def fit(self, x, num_epochs):
        optimizer = Adam({"lr": 0.01})
        elbo = CustomELBO()
        svi = SVI(self.model, self.guide, optimizer, loss=elbo)

        pyro.clear_param_store()
        for epoch in range(num_epochs):
            loss = svi.step(x)
            #loss = svi.step({"obs":x})
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss}")

# Create model instance
input_dim = 1000
hidden_dim1 = 256
hidden_dim2 = 128
latent_dim = 20

model = HVAE(input_dim, hidden_dim1, hidden_dim2, latent_dim)

# Generate sample data
x = torch.randn(100, input_dim)

# Train the model
model.fit(x, num_epochs=100)



AssertionError: module has no parameters