
Inquiry About PirateNet Code Release Date #7

Open
liwenlin664477 opened this issue May 2, 2024 · 2 comments

Comments

@liwenlin664477

I am writing to inquire about the release date of the PirateNet code.

Thanks!!

@sifanexisted
Collaborator

Sorry about this. We plan to release it ASAP and will get back to you with a more definite answer next week.

In fact, implementing this architecture should be quite easy. As a proof of concept, you can simply define a trainable parameter alpha (initialized to zero) for each block. I believe it should work and perform better than a plain MLP when the underlying PINN model is deep.

import jax.numpy as jnp
import flax.linen as nn
from jax.nn.initializers import constant


class Bottleneck(nn.Module):
    hidden_dim: int
    output_dim: int            # unused in this sketch; kept from the original
    nonlinearity: float = 0.0  # initial value of the gate alpha

    @nn.compact
    def __call__(self, x):
        # Residual branch; assumes x already has hidden_dim features.
        identity = x

        x = nn.Dense(features=self.hidden_dim)(x)
        x = jnp.tanh(x)

        x = nn.Dense(features=self.hidden_dim)(x)
        x = jnp.tanh(x)

        x = nn.Dense(features=self.hidden_dim)(x)
        x = jnp.tanh(x)

        # Trainable gate: with alpha = 0 at init, the block starts as the
        # identity map, so depth is added without hurting trainability.
        alpha = self.param("alpha", constant(self.nonlinearity), (1,))
        x = alpha * x + (1 - alpha) * identity

        return x
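For completeness, here is a minimal sketch of how such blocks might be stacked into a deep PINN backbone. The module name PirateMLP, the lifting layer, and the hyperparameter values are my own assumptions for illustration, not the released PirateNet code; the lifting Dense layer is only there so the residual connection inside each block is shape-compatible.

import jax


class PirateMLP(nn.Module):  # hypothetical wrapper, not the official code
    hidden_dim: int
    output_dim: int
    num_blocks: int

    @nn.compact
    def __call__(self, x):
        # Lift the input coordinates to the hidden width so that
        # identity and the block output have matching shapes.
        x = nn.Dense(features=self.hidden_dim)(x)
        x = jnp.tanh(x)

        for _ in range(self.num_blocks):
            x = Bottleneck(hidden_dim=self.hidden_dim,
                           output_dim=self.hidden_dim)(x)

        # Linear read-out to the solution dimension.
        return nn.Dense(features=self.output_dim)(x)


# Example usage (assumed shapes: 2D coordinates in, scalar solution out):
model = PirateMLP(hidden_dim=128, output_dim=1, num_blocks=4)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
u = model.apply(params, jnp.ones((8, 2)))  # shape (8, 1)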

@liwenlin664477
Author


Thanks!!
