In [1]:
import torch
from models.ae import IRVAE
from models import get_net 
from models.modules import (
    FC_vec,
    FC_image,
    IsotropicGaussian,
    ConvNet28,
    DeConvNet28,
    PreTrained_Model
)
from omegaconf import OmegaConf

In [24]:
def relaxed_distortion_measure(func, z, eta=0.2, metric='identity', create_graph=True):
    if metric == 'identity':
        bs = len(z)
        z_perm = z[torch.randperm(bs)]
        if eta is not None:
            alpha = (torch.rand(bs) * (1 + 2*eta) - eta).unsqueeze(1).to(z)
            z_augmented = alpha*z + (1-alpha)*z_perm
        else:
            z_augmented = z
        v = torch.randn(z.size()).to(z)
        Jv = torch.autograd.functional.jvp(func, z_augmented, v=v, create_graph=create_graph)[1]
        TrG = torch.sum(Jv.view(bs, -1)**2, dim=1).mean()
        JTJv = (torch.autograd.functional.vjp(func, z_augmented, v=Jv, create_graph=create_graph)[1]).view(bs, -1)
        TrG2 = torch.sum(JTJv**2, dim=1).mean()
        return TrG2/TrG**2

    elif isinstance(metric, PreTrained_Model):
        model = metric.class_name()
        model.load_state_dict(torch.load(metric.parameter_path))
        
        bs = len(z)
        z_perm = z[torch.randperm(bs)]
        if eta is not None:
            alpha = (torch.rand(bs) * (1 + 2*eta) - eta).unsqueeze(1).to(z)
            z_augmented = alpha*z + (1-alpha)*z_perm
        else:
            z_augmented = z

        v = torch.randn(z.size()).to(z)
        print(f"v.shape: {v.shape}")
        print(f"z_augmented.shape: {z_augmented.shape}")
        Jv = torch.autograd.functional.jvp(func, z_augmented, v=v, create_graph=create_graph)[1]
        print(f"Jv.shape: {Jv.size()}")
        HJv = torch.autograd.functional.jvp(model, func(z_augmented), v=Jv, create_graph=create_graph)[1]

        TrG = torch.sum(HJv.view(bs, -1)**2, dim=1).mean()

        print(f"HJv.shape: {HJv.shape}")
        HTHJv = (torch.autograd.functional.vjp(model, func(z_augmented), v=HJv, create_graph=create_graph)[1]).view(bs, -1)
        print(f"HTHJv.shape: {HTHJv.shape}")
        print(f"z_augmented.shape: {z_augmented.shape}")
        JTHTHJv = (torch.autograd.functional.vjp(func, z_augmented, v=HTHJv.view(bs, Jv.shape[1], Jv.shape[2], Jv.shape[3]), create_graph=create_graph)[1]).view(bs, -1)

        TrG2 = torch.sum(JTHTHJv**2, dim=1).mean()

        return TrG2/TrG**2

    else:
        raise NotImplementedError


SyntaxError: f-string: cannot use double starred expression here (2808115615.py, line 33)

In [25]:
cfg = OmegaConf.load('configs/mnist_irvae_z2_pretrain.yml')

if "model" in cfg:
    model_dict = cfg["model"]
elif "arch" in cfg:
    model_dict = cfg


x_dim = model_dict['x_dim']
z_dim = model_dict['z_dim']
arch = model_dict["arch"]

metric = PreTrained_Model("simple_linear", "models/saved_model/simple_linear.pt")
iso_reg = 1.0
encoder = get_net(in_dim=x_dim, out_dim=z_dim * 2, **model_dict["encoder"])
decoder = get_net(in_dim=z_dim, out_dim=x_dim, **model_dict["decoder"])
model = IRVAE(encoder, IsotropicGaussian(decoder), iso_reg=iso_reg, metric=metric)


In [26]:
print(model.encoder)
x = torch.rand(100, 1, 28, 28)
z = model.encoder(x)
print(z.shape)
z_sample = model.sample_latent(z)
print(z_sample.shape)
iso_loss = relaxed_distortion_measure(model.decode, z_sample, eta=0.2, metric=model.metric)

FC_image(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): ReLU(inplace=True)
    (8): Linear(in_features=256, out_features=4, bias=True)
  )
)
torch.Size([100, 4])
torch.Size([100, 2])
v.shape: torch.Size([100, 2])
z_augmented.shape: torch.Size([100, 2])
Jv.shape: torch.Size([100, 1, 28, 28])
HJv.shape: torch.Size([100, 10])
HTHJv.shape: torch.Size([100, 784])
z_augmented.shape: torch.Size([100, 2])


RuntimeError: shape '[100, 1]' is invalid for input of size 78400