In [88]:
import torch
import time
import matplotlib.pyplot as plt
from mlrg.hmc import HMCSampler
from rgflow import RGLayer, RGPartition

In [89]:
class GaussianModel(torch.nn.Module):
    ''' Gaussian (quadratic) lattice model energy
        E = (1/2) sum_<ij> |x_i-x_j|^2 + (r/2) sum_i |x_i|^2

        The energy contains only quadratic terms (no quartic coupling).

        Parameters:
            r :: real - (initial) value of the learnable coefficient of the
                 on-site |x_i|^2 term. Integer values are accepted and are
                 stored as floating point so the parameter can carry gradients.
    '''
    def __init__(self, r=0.):
        super().__init__()
        r_init = torch.tensor(r)
        # nn.Parameter only accepts floating-point/complex tensors when
        # requires_grad=True, so promote integer/bool inputs such as r=3.
        if not (r_init.is_floating_point() or r_init.is_complex()):
            r_init = r_init.to(torch.get_default_dtype())
        self.r = torch.nn.Parameter(r_init)

    def extra_repr(self):
        ''' Show the current value of r in the module repr. '''
        return f'r={self.r.item()}'

    def return_param(self):
        ''' Return the parameter r itself (not a copy). '''
        return self.r

    def clone(self):
        ''' Return a new instance of the same class carrying a copy of this
            model's state_dict. The instance is built with the default
            constructor arguments before the state is loaded. '''
        mdl = type(self)()
        mdl.load_state_dict(self.state_dict())
        return mdl

    def forward(self, x):
        ''' Evaluate the energy of x.

            Dimension 1 of x is summed as the component index, every
            dimension from 2 onwards is treated as a lattice axis, and
            dimension 0 is kept.

            Returns:
                tensor with one energy per entry along dimension 0.
        '''
        energy = 0.
        for axis in range(2, x.dim()):
            # Difference to the neighbour one step along this axis; roll
            # wraps around, so the boundary is periodic.
            dx2 = (x.roll(1, axis) - x).square().sum(1)
            energy = energy + dx2 / 2
        x2 = x.square().sum(1)
        energy = energy + self.r * x2 / 2
        # Collapse all remaining (lattice) dimensions into one and sum them.
        return energy.reshape(energy.shape[0], -1).sum(-1)

In [71]:
# Smoke test: draw 10 samples from an HMC sampler built on GaussianModel(r=3.)
# and evaluate their energies with a freshly constructed model of the same r.
# NOTE(review): [1,1,2] is presumably the shape of one sample
# (batch, components, sites) — confirm against HMCSampler.
device = 'cpu'
hmc = HMCSampler(GaussianModel(r=3.), [1,1,2])
x = hmc.sample(device, samples=10).detach()  # detach: no graph is needed for inspection
print(x.dim())
print(GaussianModel(r=3.)(x))

3
tensor([4.3420e-01, 1.6956e-01, 3.8666e-01, 2.4564e-01, 5.3521e-01, 5.6790e-02,
        8.3578e-04, 5.8300e-02, 9.8736e-01, 1.7811e-01],
       grad_fn=<SumBackward1>)


In [25]:
# Split x with an RGPartition built from [2] and print the partition mask.
# NOTE: this rebinds the notebook-level name `x` to the first returned part,
# so the cell is not idempotent — re-running it splits the already-split x.
# NOTE(review): two separate RGPartition([2]) instances are built here; this
# assumes their masks agree — confirm that construction is deterministic.
x, z = RGPartition([2]).split(x)
print(RGPartition([2]).mask)
print(x)

tensor([False,  True])
tensor([[[0.4042]]])


In [90]:
class RGLearner(torch.nn.Module):
    ''' Couples a frozen UV energy model with a trainable IR energy model
        through an RGLayer.

        IR configurations are drawn from an HMCSampler built on the IR model,
        paired with noise from a base distribution, and decoded by the
        RGLayer into UV configurations.

        Parameters:
            uv_model  :: module - UV energy model; its parameters are frozen.
            ir_model  :: module - IR energy model; its parameters are trained.
            device    :: device handed to RGLayer.
            uv_shape  :: shape handed to RGLayer.
            dim       :: int - handed to RGLayer and used as dimension 1 of
                         the shape given to the IR sampler.
            base_dist :: str - name of a class in torch.distributions,
                         instantiated with arguments (0., 1.).
            **kwargs  :: forwarded to RGLayer.

        Both models must provide return_param().
    '''
    def __init__(self, uv_model, ir_model, device, uv_shape, dim, base_dist='Normal', **kwargs):
        super().__init__()
        self.uv_model = uv_model.requires_grad_(False)
        self.ir_model = ir_model.requires_grad_(True)
        self.rglayer = RGLayer(uv_shape, device, dim, **kwargs)
        ir_shape = self.rglayer.partitioner.out_shape
        self.ir_sampler = HMCSampler(self.ir_model, (1, dim)+torch.Size(ir_shape))
        self.base_dist = getattr(torch.distributions, base_dist)(0., 1.)
        self.ir_param = self.ir_model.return_param()
        self.uv_param = self.uv_model.return_param()
        # Hard-coded 2x2 matrix: 2 + uv_param on the diagonal, -2 off it.
        # It is a snapshot taken at construction and does not follow later
        # changes of uv_param.
        # NOTE(review): this looks like the quadratic form of GaussianModel on
        # a periodic two-site chain (uv_shape=[2]) — confirm before using any
        # other uv_shape.
        self.inv_K = torch.tensor([[2 + self.uv_param, -2],[-2, 2 + self.uv_param]])
        self.K = torch.inverse(self.inv_K)

    def check_ir_param(self):
        ''' Return whether the IR parameter currently requires grad. '''
        return self.ir_param.requires_grad

    def sample(self, samples, device):
        ''' Same as rsample, but evaluated under torch.no_grad(). '''
        with torch.no_grad():
            return self.rsample(samples, device)

    def rsample(self, samples, device):
        ''' Draw IR samples and base-distribution noise, then decode them.

            Returns:
                the first value returned by rglayer.decode(x_ir, z).
        '''
        x_ir = self.ir_sampler.sample(device, samples=samples)
        noise_shape = x_ir.shape[:2]+self.rglayer.partitioner.res_shape
        z = self.base_dist.rsample(noise_shape).to(device)
        x_uv, *_ = self.rglayer.decode(x_ir, z)
        return x_uv

    def loss(self, samples, device, lk=0.01, lg=0.01, mode=None, **kwargs):
        ''' Build the training objective from freshly drawn samples.

            Parameters:
                samples  :: number of IR samples requested from the sampler.
                device   :: device for sampling and for the noise tensor.
                lk, lg   :: weights of the Ek and Eg terms returned by
                            rglayer.decode.
                mode     :: accepted but not used; decode is always called
                            with mode='jf_reg'.
                **kwargs :: forwarded to both ir_sampler.sample and
                            rglayer.decode.

            Returns:
                (Loss, diff, Ek, Eg), each averaged to a scalar. Loss is
                centred by its own mean, so its value is numerically ~0;
                only its gradient is meaningful. diff is the quantity to
                monitor.
        '''
        x_ir = self.ir_sampler.sample(device, samples=samples, **kwargs)
        z = self.base_dist.rsample(x_ir.shape[:2]+self.rglayer.partitioner.res_shape).to(device)
        x_uv, logJ, Ek, Eg = self.rglayer.decode(x_ir, z, mode='jf_reg', **kwargs)
        # Evaluate the IR energy once and reuse it below; it enters both diff
        # and the score-function factor.
        E_ir = self.ir_model(x_ir)
        diff = self.uv_model(x_uv) - 1/2 * torch.log(1 / self.ir_param) - E_ir - logJ + 1 / 2 * torch.log(torch.det(self.K)) # The global minimum would be 0.5
        Loss = diff + lk * Ek + lg * Eg      # original Loss
        Loss_ = Loss.mean().detach()         # mean of the loss, detached, used as baseline
        # REINFORCE: log(p(x_ir)) = -E(x_ir) + const. The factor equals 1 in
        # value, but its gradient is that of -E(x_ir).
        Loss = (1 + E_ir.detach() - E_ir) * (Loss-Loss_)
        Loss, diff, Ek, Eg = [val.mean() for val in (Loss, diff, Ek, Eg)]
        return Loss, diff, Ek, Eg

In [91]:
device = torch.device('cpu')
# UV model starts at r=1, IR model at r=5; the learner acts on a UV shape of
# [2] with dim=1. hdims and hyper_dim are passed through to RGLayer.
rgl = RGLearner(
    GaussianModel(r=1.).to(device),
    GaussianModel(r=5.).to(device),
    device,
    [2],
    1,
    hdims=[8,8],
    hyper_dim=16,
)
optimizer = torch.optim.Adam(rgl.parameters(), lr=0.001)
# Sanity check: the IR parameter must be trainable.
print(rgl.check_ir_param())


True


In [93]:
# 100 Adam steps, each on a fresh batch of 1000 samples.
# Printed columns: Loss, diff, Ek, Eg.
for _ in range(100):
    optimizer.zero_grad()
    loss, *rest = rgl.loss(1000, device)
    loss.backward()
    optimizer.step()
    print(f'{loss.item()} '+' '.join(f'{r.item()}' for r in rest))

-3.6239622858147413e-08 1.6502635478973389 1.005913496017456 0.755322277545929
2.479553273815327e-08 1.3593195676803589 0.724684476852417 0.5482970476150513
1.23977656585339e-07 1.2210677862167358 0.5099406838417053 0.5050588250160217
-3.6239622858147413e-08 1.1664806604385376 0.4050995409488678 0.437779039144516
-9.536743617033494e-10 1.0510467290878296 0.3214738070964813 0.38977086544036865
3.910064805268121e-08 0.9596055746078491 0.2588021755218506 0.4024229943752289
-1.907348723406699e-09 0.7425617575645447 0.20915080606937408 0.3447747826576233
-9.53674295089968e-09 0.7503708600997925 0.1861545741558075 0.390712708234787
1.525878978725359e-08 0.7415147423744202 0.18060027062892914 0.4217604994773865
-4.482269133632144e-08 0.6209588646888733 0.18910381197929382 0.3638084828853607
3.814697446813398e-09 0.7445825934410095 0.23300479352474213 0.37127551436424255
-2.1457672971791908e-08 0.6792638301849365 0.271903395652771 0.41504931449890137
1.9550324026340604e-08 0.6365292072296143 0

In [94]:
rgl.ir_model

GaussianModel(r=4.96737813949585)