In [1]:
import torch
import time
import matplotlib.pyplot as plt
from mlrg.hmc import HMCSampler
from rgflow import RGLayer, RGPartition

In [20]:
class GaussianModel(torch.nn.Module):
    ''' phi4 model energy
        E = (1/2) sum_<ij> |x_i-x_j|^2 + (r/2) sum_i |x_i|^2

        Parameters:
            r :: real - (initial) value of
            '''
    def __init__(self, r=0.):
        super().__init__()
        self.r = torch.nn.Parameter(torch.tensor(r))

    def extra_repr(self):
        return f'r={self.r.item()}'

    def return_param(self):
        return torch.exp(self.r)

    def clone(self):
        mdl = type(self)()
        mdl.load_state_dict(self.state_dict())
        return mdl

    def forward(self, x):
        energy = 0.
        for axis in range(1, x.dim()):
            dx2 = (x.roll(1,axis) - x).square().sum(1)
            energy = energy + dx2 / 2
        x2 = x.square().sum(1)
        energy = energy + torch.exp(self.r) * x2 / 2
        energy = energy.view(energy.shape[:1]+(-1,)).sum(-1)
        return energy

In [127]:
device = 'cpu'
hmc = HMCSampler(GaussianModel(r=1.), [1,1])
x_ir = hmc.sample(device, samples=10000)
E = GaussianModel(r=0.91)(x_ir).mean()
print(E)

tensor(0.4624, grad_fn=<MeanBackward0>)


In [122]:
class LinearRGLearner(torch.nn.Module):
    def __init__(self, uv_model, ir_model, device, base_dist='Normal'):
        super().__init__()
        self.uv_model = uv_model.requires_grad_(False)
        # self.ir_model = uv_model.clone().requires_grad_(True)
        self.ir_model = ir_model.requires_grad_(True)
        self.ir_sampler = HMCSampler(self.ir_model, [1, 1])
        self.base_dist = getattr(torch.distributions, base_dist)(0., 1.)
        self.uv_param = self.uv_model.return_param()
        self.mat = torch.nn.Parameter(torch.randn(2, 2, requires_grad=True)).to(device)
        self.inv_K = torch.tensor([[2 + self.uv_param, -2],[-2, 2 + self.uv_param]])
        self.K = torch.inverse(self.inv_K)


    def get_mat(self):
        return self.mat

    def get_ir_model(self):
        return self.ir_model

    def sample(self, samples, device):
        with torch.no_grad():
            return self.rsample(samples, device)

    def rsample(self, samples, device):
        x_ir = self.ir_sampler.sample(device, samples=samples)
        z = self.base_dist.rsample([samples, 1]).to(device)
        return x_ir, z

    def loss(self, samples, device, **kwargs):
        ir_param = self.ir_model.return_param()
        x_ir = self.ir_sampler.sample(device, samples=samples, **kwargs)
        z = self.base_dist.rsample([samples, 1]).to(device)
        X_ir = torch.cat((x_ir, z), dim=1)
        # X_ir.requires_grad_(True)
        x_uv = torch.matmul(X_ir, torch.transpose(self.mat, 0, 1))
        diff = self.uv_model(x_uv) - 1/2 * torch.log(1 / ir_param) - self.ir_model(x_ir) - torch.log(torch.abs(torch.det(self.mat))) + 1 / 2 * torch.log(torch.det(self.K))
        Loss = diff    # original Loss
        Loss = Loss.mean()
        return Loss, self.ir_model(x_ir).mean()

    def exact_loss(self, device):
        ir_param = self.ir_model.return_param()
        G = torch.zeros((2, 2))
        G[0, 0] = 1 / ir_param
        G[1, 1] = 1
        M = torch.matmul(self.mat, G)
        M = torch.matmul(self.inv_K, M)
        Loss = 1 / 2 * torch.trace(torch.matmul(torch.transpose(self.mat, 0, 1), M)) + 1 / 2 * torch.log(torch.det(self.K)) - 1 / 2 * torch.log(torch.det(G)) - torch.log(torch.abs(torch.det(self.mat))) - 1
        return Loss

In [138]:
device = 'cpu'
rgl = LinearRGLearner(GaussianModel(r=1.).to(device), GaussianModel(r=1.).to(device), device)
optimizer = torch.optim.Adam(rgl.parameters(), lr=0.005)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
print(sum(p.numel() for p in rgl.parameters() if p.requires_grad))
print(rgl.get_mat())

5
Parameter containing:
tensor([[ 0.1675,  0.3877],
        [ 0.8321, -1.1968]], requires_grad=True)


In [140]:
for _ in range(1000):
        optimizer.zero_grad()
        # loss, *rest = rgl.loss(10000, device)
        loss = rgl.exact_loss(device)
        loss.backward(retain_graph=True)
        optimizer.step()
        # scheduler.step()
        print(f'{loss.item()}')
        # print(f'{loss.item()} '+' '.join(f'{r.item()}' for r in rest))

0.02224147319793701
0.02154099941253662
0.020860791206359863
0.020200252532958984
0.01955890655517578
0.01893627643585205
0.01833200454711914
0.017745256423950195
0.017175912857055664
0.016623258590698242
0.016086935997009277
0.015566349029541016
0.015061378479003906
0.014571428298950195
0.01409614086151123
0.01363515853881836
0.013187766075134277
0.012754201889038086
0.012333393096923828
0.011925339698791504
0.011530041694641113
0.0111464262008667
0.010774731636047363
0.010414361953735352
0.010065317153930664
0.009726643562316895
0.009398698806762695
0.009080886840820312
0.008772850036621094
0.008474588394165039
0.008185386657714844
0.00790548324584961
0.007634520530700684
0.007371783256530762
0.007117509841918945
0.006871342658996582
0.00663304328918457
0.006402254104614258
0.006179213523864746
0.005962967872619629
0.005753874778747559
0.005551457405090332
0.0053555965423583984
0.005166292190551758
0.004983067512512207
0.004805803298950195
0.004634261131286621
0.004468679428100586
0.

In [141]:
print(rgl.get_ir_model)

<bound method LinearRGLearner.get_ir_model of LinearRGLearner(
  (uv_model): GaussianModel(r=1.0)
  (ir_model): GaussianModel(r=0.44882655143737793)
)>


In [142]:
print(rgl.get_mat())

Parameter containing:
tensor([[ 0.6098, -0.1449],
        [ 0.0943, -0.5027]], requires_grad=True)


In [144]:
mat = rgl.get_mat()
inv_k = torch.tensor([[2 + torch.exp(torch.tensor(1.)), -2],[-2, 2 + torch.exp(torch.tensor(1.))]])
M = torch.matmul(inv_k, mat)
M = torch.matmul(torch.transpose(mat, 0, 1), M)
print(M)

tensor([[ 1.5665e+00, -1.3411e-07],
        [-1.1921e-07,  1.0000e+00]], grad_fn=<MmBackward0>)


In [208]:
test = rgl.get_ir_model()
hmc = HMCSampler(test, [1,1])
x = hmc.sample(device, samples=3).detach()
E = test(x)
print(x)
print(x.roll(1,1))

tensor([[-0.4934],
        [-2.3520],
        [-4.9539]])
tensor([[-0.4934],
        [-2.3520],
        [-4.9539]])
