In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.distributions as tdist
import numpy as np

# Set seed
torch.manual_seed(0)

# Where to add a new import
from torch.optim.lr_scheduler import StepLR

In [2]:
n_iters = 1000
n_sub_iters = 10
batch_size = 1 
num_atoms = 10
te

In [3]:
normal_dist = tdist.MultivariateNormal(torch.tensor([0.]), torch.tensor([[1.]]))

In [4]:
def l2_cost(x,y):
    return torch.norm(x-y)

In [5]:
def er_chi_unnorm(x,yj,gj, epsilon, cost_func=l2_cost):
    return torch.exp((-cost_func(x,yj)+gj)/epsilon)

In [6]:
def er_chi_j(x, y, g, epsilon, j):
    numerator = er_chi_unnorm(x,y[j],g[j], epsilon)
    print(numerator)
    denominator = torch.sum(er_chi_unnorm(x,y,g, epsilon))
    return numerator/ denominator

In [7]:
def er_chi(x, y, g, epsilon):
    chis = er_chi_unnorm(x,y,g, epsilon)
    normaliser = torch.sum(chis)
    return chis/ normaliser

In [8]:
# entropic reg c-transform
def er_ctran(x, g, y, epsilon, cost_func):
    return -epsilon * torch.log(torch.sum(torch.exp((-cost_func(x,y)+g)/epsilon))) + torch.sum(g, )

In [9]:
# init vectors
y = torch.tensor(np.arange(10)/10., requires_grad = True)
g = torch.tensor([np.random.random(size=num_atoms)], requires_grad = True)
epsilon = torch.tensor(0.1)

In [10]:
n_iters

1000

## INSTANTIATE OPTIMIZER CLASS

In [11]:
learning_rate = 0.1
optimizer_atoms = torch.optim.SGD([y], lr=learning_rate, momentum=0.9, nesterov=True)
optimizer_map = torch.optim.SGD([g], lr=learning_rate, momentum=0.9, nesterov=True)

## INSTANTIATE STEP LEARNING SCHEDULER CLASS

In [12]:
# step_size: at how many multiples of epoch you decay
# new_lr = lr*gamma 

# gamma = decaying factor
scheduler_atoms = StepLR(optimizer_atoms, step_size=1, gamma=0.8)
scheduler_map = StepLR(optimizer_map, step_size=1, gamma=0.8)

In [13]:
for i in range(n_iters):
    x = normal_dist.sample()

    # Clear gradients w.r.t. parameters
    optimizer_map.zero_grad()

    # Get dual objective to maximise
    dual_objective = er_ctran(x, g, y, epsilon, l2_cost)
    map_loss = -dual_objective

    # Getting gradients w.r.t. parameters
    map_loss.backward()
    optimizer_map.step()
    
    # Updating parameters
    if i % 100 == 0:
        print(g)
        print(map_loss)
        scheduler_map.step()
        

tensor([[0.6144, 0.7008, 1.0332, 0.3547, 0.9850, 1.0443, 0.4550, 1.0159, 0.6922,
         0.4975]], dtype=torch.float64, requires_grad=True)
tensor(-8.2268, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[67.4907, 67.4906, 67.4906, 67.4907, 67.4906, 67.4906, 67.4907, 67.4906,
         67.4906, 67.4907]], dtype=torch.float64, requires_grad=True)
tensor(-603.6753, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       dtype=torch.float64, requires_grad=True)
tensor(nan, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       dtype=torch.float64, requires_grad=True)
tensor(nan, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       dtype=torch.float64, requires_grad=True)
tensor(nan, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       dtype=torch.float64, re

## TRAIN MODEL

### Min Max Optimization

In [14]:
for i in range(n_iters):
    for _ in range(n_sub_iters):
        # sample x
        x = normal_dist.sample()

        # Clear gradients w.r.t. parameters
        optimizer_map.zero_grad()
        
        # Get dual objective to maximise
        dual_objective = er_ctran(x, g, y, epsilon, l2_cost)
        map_loss = -dual_objective
        
        # Getting gradients w.r.t. parameters
        map_loss.backward()

        # Updating parameters
        optimizer_map.step()
        
    for _ in range(n_sub_iters):
        # sample x
        x = normal_dist.sample()

        # Clear gradients w.r.t. parameters
        optimizer_atoms.zero_grad()
        
        # Get loss objective to minimise
        atoms_loss = er_ctran(x, g, y, epsilon, l2_cost)
        
        # Getting gradients w.r.t. parameters
        atoms_loss.backward()

        # Updating parameters
        optimizer_atoms.step()
    
    # Updating parameters
    if i % 100 == 0:
        print(g)
        print(map_loss)
        # Decay Learning Rate
        scheduler_atoms.step()
        scheduler_map.step()
    
    print("Loss: {0}".format(map_loss))

tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       dtype=torch.float64, requires_grad=True)
tensor(nan, dtype=torch.float64, grad_fn=<NegBackward>)
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss:

Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
Loss: nan
tensor([[nan, nan, n

In [15]:
g

tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       dtype=torch.float64, requires_grad=True)

In [16]:
y

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=torch.float64,
       requires_grad=True)