In [88]:
import torch
import time
import matplotlib.pyplot as plt
from mlrg.hmc import HMCSampler
from rgflow import RGLayer, RGPartition

In [89]:
class GaussianModel(torch.nn.Module):
    ''' Gaussian (purely quadratic) lattice model energy
        E = (1/2) sum_<ij> |x_i-x_j|^2 + (r/2) sum_i |x_i|^2

        Parameters:
            r :: real - (initial) value of the coefficient of the on-site
                 |x_i|^2 term; it is stored as a trainable parameter.
                 Integer values are accepted and converted to the default
                 floating dtype.
    '''
    def __init__(self, r=0.):
        super().__init__()
        r_init = torch.tensor(r)
        # An integer tensor cannot require gradients, so torch.nn.Parameter
        # would raise for e.g. r=3; promote such values to the default float
        # dtype. Floating/complex inputs keep their dtype.
        if not (r_init.is_floating_point() or r_init.is_complex()):
            r_init = r_init.to(torch.get_default_dtype())
        self.r = torch.nn.Parameter(r_init)

    def extra_repr(self):
        ''' Show the current value of r in the module repr. '''
        return f'r={self.r.item()}'

    def return_param(self):
        ''' Return the parameter object r itself (not a copy). '''
        return self.r

    def clone(self):
        ''' Build a new instance of the same class (default constructor) and
            copy this model's state dict into it. '''
        mdl = type(self)()
        mdl.load_state_dict(self.state_dict())
        return mdl

    def forward(self, x):
        ''' Evaluate the energy of each configuration in x.

            Dimension 0 is kept, dimension 1 is summed over, and every
            dimension from 2 onward is treated as a lattice axis whose
            neighbour is obtained with a cyclic roll (periodic wrap-around).

            Returns a 1-D tensor with one energy per entry along dimension 0.
        '''
        energy = 0.
        for axis in range(2, x.dim()):
            # squared difference between each site and its rolled neighbour
            dx2 = (x.roll(1, axis) - x).square().sum(1)
            energy = energy + dx2 / 2
        x2 = x.square().sum(1)
        energy = energy + self.r * x2 / 2
        # collapse all remaining (lattice) dimensions into one and sum them
        return energy.reshape(energy.shape[0], -1).sum(-1)

In [71]:
# Draw 10 HMC samples from a Gaussian model with r=3 and print their energies.
device = 'cpu'
# NOTE(review): the list [1,1,2] is presumably the shape of one configuration
# expected by HMCSampler — confirm against mlrg.hmc.
hmc = HMCSampler(GaussianModel(r=3.), [1,1,2])
# detach() so that x carries no autograd history from the sampler
x = hmc.sample(device, samples=10).detach()
print(x.dim())
# The energies are evaluated with a second, freshly constructed model (same r),
# not with the instance that was handed to the sampler.
print(GaussianModel(r=3.)(x))

3
tensor([4.3420e-01, 1.6956e-01, 3.8666e-01, 2.4564e-01, 5.3521e-01, 5.6790e-02,
        8.3578e-04, 5.8300e-02, 9.8736e-01, 1.7811e-01],
       grad_fn=<SumBackward1>)


In [25]:
# Split x with an RGPartition built from [2] and show the partition mask.
# NOTE(review): this rebinds `x` to the first output of split(x), so the cell is
# not idempotent — running it again feeds the already-split tensor back in.
x, z = RGPartition([2]).split(x)
# a second RGPartition([2]) is constructed here only to print its mask
print(RGPartition([2]).mask)
print(x)

tensor([False,  True])
tensor([[[0.4042]]])


In [95]:
class RGLearner(torch.nn.Module):
    ''' Couples a frozen UV energy model to a trainable IR energy model
        through an RGLayer decoder.

        Parameters:
            uv_model  :: module - UV energy model; its gradients are switched off
            ir_model  :: module - IR energy model; its gradients are switched on
            device    :: passed through to RGLayer
            uv_shape  :: passed through to RGLayer
            dim       :: passed through to RGLayer and used as the second entry
                         of the shape given to the HMC sampler
            base_dist :: str - name of a class in torch.distributions, built
                         with arguments (0., 1.) and used to draw z
            **kwargs  :: forwarded to RGLayer

        Both models must provide return_param().
    '''
    def __init__(self, uv_model, ir_model, device, uv_shape, dim, base_dist='Normal', **kwargs):
        super().__init__()
        self.uv_model = uv_model.requires_grad_(False)
        self.ir_model = ir_model.requires_grad_(True)
        self.rglayer = RGLayer(uv_shape, device, dim, **kwargs)
        ir_shape = self.rglayer.partitioner.out_shape
        self.ir_sampler = HMCSampler(self.ir_model, (1, dim)+torch.Size(ir_shape))
        self.base_dist = getattr(torch.distributions, base_dist)(0., 1.)
        self.ir_param = self.ir_model.return_param()
        self.uv_param = self.uv_model.return_param()
        # Hard-coded 2x2 matrix: (2 + r_uv) on the diagonal, -2 off the diagonal.
        # NOTE(review): this equals the quadratic form of GaussianModel on a
        # 2-site periodic chain, so it presumably only matches uv_shape == [2]
        # with dim == 1 — confirm before using other shapes.
        # It is built once from the value of uv_param at construction time.
        self.inv_K = torch.tensor([[2 + self.uv_param, -2],[-2, 2 + self.uv_param]])
        self.K = torch.inverse(self.inv_K)

    def check_ir_param(self):
        ''' Return whether the IR parameter currently requires gradients. '''
        return self.ir_param.requires_grad

    def sample(self, samples, device):
        ''' Same as rsample, but executed under torch.no_grad(). '''
        with torch.no_grad():
            return self.rsample(samples, device)

    def rsample(self, samples, device):
        ''' Draw IR configurations from the HMC sampler, draw z from the base
            distribution and decode both with the RG layer.

            Returns the first output of rglayer.decode(x_ir, z).
        '''
        x_ir = self.ir_sampler.sample(device, samples=samples)
        # z takes the two leading dimensions of x_ir followed by the
        # partitioner's res_shape
        z_shape = x_ir.shape[:2] + self.rglayer.partitioner.res_shape
        z = self.base_dist.rsample(z_shape).to(device)
        x_uv, *_ = self.rglayer.decode(x_ir, z)
        return x_uv

    def loss(self, samples, device, lk=0.01, lg=0.01, mode=None, **kwargs):
        ''' Compute the training loss and its components.

            Parameters:
                samples  :: number of samples requested from the HMC sampler
                device   :: device passed to the sampler; z is moved to it
                lk, lg   :: weights of the Ek and Eg terms returned by decode
                mode     :: accepted but not used; decode is always called
                            with mode='jf_reg'
                **kwargs :: forwarded to BOTH ir_sampler.sample and
                            rglayer.decode

            Returns (Loss, diff, Ek, Eg), each reduced with .mean().
        '''
        x_ir = self.ir_sampler.sample(device, samples=samples, **kwargs)
        z = self.base_dist.rsample(x_ir.shape[:2]+self.rglayer.partitioner.res_shape).to(device)
        x_uv, logJ, Ek, Eg = self.rglayer.decode(x_ir, z, mode='jf_reg', **kwargs)
        # These two quantities are needed several times below; evaluate once.
        ir_energy = self.ir_model(x_ir)
        ir_log_norm = 1/2 * torch.log(1 / self.ir_param)
        # The original author notes that the global minimum would be 0.5.
        diff = self.uv_model(x_uv) - ir_log_norm - ir_energy - logJ + 1 / 2 * torch.log(torch.det(self.K))
        Loss = diff + lk * Ek + lg * Eg      # original Loss
        Loss_ = Loss.mean().detach()         # take average and detach
        # REINFORCE factor: its value is 1 (up to rounding), while its gradient
        # is that of -(ir_energy + ir_log_norm); log(p(x_ir)) = -E(x_ir) + const
        score = 1 + ir_energy.detach() + ir_log_norm.detach() - ir_energy - ir_log_norm
        Loss = score * (Loss - Loss_)
        Loss, diff, Ek, Eg = [val.mean() for val in (Loss, diff, Ek, Eg)]
        return Loss, diff, Ek, Eg

In [96]:
# Build the learner on the CPU: a UV Gaussian model with r=1, an IR Gaussian
# model initialised at r=5, and an Adam optimiser over the learner's parameters.
device = torch.device('cpu')
rgl = RGLearner(
    GaussianModel(r=1.).to(device),  # uv_model
    GaussianModel(r=5.).to(device),  # ir_model
    device,
    [2],                             # uv_shape
    1,                               # dim
    hdims=[8,8],                     # forwarded to RGLayer
    hyper_dim=16,                    # forwarded to RGLayer
)
optimizer = torch.optim.Adam(rgl.parameters(), lr=0.001)
# sanity check: the IR parameter should report requires_grad
print(rgl.check_ir_param())


True


In [98]:
# Run 100 Adam steps with 1000 samples per step.
# Each printed row holds the means returned by rgl.loss: Loss, diff, Ek, Eg.
# The first column is the baseline-subtracted surrogate, so it stays near zero;
# the second column (diff) is the one to watch.
for _ in range(100):
    optimizer.zero_grad()
    loss, *rest = rgl.loss(1000, device)
    loss.backward()
    optimizer.step()
    print(f'{loss.item()} '+' '.join(f'{r.item()}' for r in rest))

-1.4495849143258965e-07 2.938091278076172 0.4170213043689728 1.0147730112075806
2.0980834847250662e-08 2.075449228286743 0.3822191655635834 0.6212579607963562
9.536743306171047e-08 1.749548077583313 0.386790931224823 0.36216145753860474
7.247924571629483e-08 1.3561747074127197 0.3868293762207031 0.2308245450258255
1.3351440841802287e-08 1.105602502822876 0.3434905409812927 0.22334060072898865
2.574920721087892e-08 0.9526402950286865 0.3048913776874542 0.2738724648952484
1.0490417423625331e-08 0.8986373543739319 0.2831771671772003 0.36143797636032104
-1.907348590179936e-08 0.7896808981895447 0.34496837854385376 0.552113950252533
1.1444091896350983e-08 0.7041557431221008 0.44465509057044983 0.6498399376869202
-2.2888183792701966e-08 0.8220619559288025 0.5459113717079163 0.7846616506576538
3.814697446813398e-09 0.8994033336639404 0.5892844796180725 0.8097887635231018
3.2424928519958485e-08 0.7494833469390869 0.6217252016067505 0.9134032130241394
2.8610228852699038e-08 0.7644792199134827 0

In [94]:
# Display the IR model; its repr includes the current value of r (see extra_repr).
rgl.ir_model

GaussianModel(r=4.96737813949585)