In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.distributions as tdist
import numpy as np

# Set seed
torch.manual_seed(0)

# Where to add a new import
from torch.optim.lr_scheduler import StepLR

In [2]:
n_iters = 1000
n_sub_iters = 10
batch_size = 1 
num_atoms = 10

In [3]:
normal_dist = tdist.MultivariateNormal(torch.tensor([0.]), torch.tensor([[1.]]))

In [4]:
def l2_cost(x,y):
    return torch.norm(x-y)

In [5]:
def er_chi_unnorm(x,yj,gj, epsilon, cost_func=l2_cost):
    return torch.exp((-cost_func(x,yj)+gj)/epsilon)

In [6]:
def er_chi_j(x, y, g, epsilon, j):
    numerator = er_chi_unnorm(x,y[j],g[j], epsilon)
    print(numerator)
    denominator = torch.sum(er_chi_unnorm(x,y,g, epsilon))
    return numerator/ denominator

In [7]:
def er_chi(x, y, g, epsilon):
    chis = er_chi_unnorm(x,y,g, epsilon)
    normaliser = torch.sum(chis)
    return chis/ normaliser

In [8]:
# entropic reg c-transform
def er_ctran(x, g, y, epsilon, cost_func):
    return -epsilon * torch.log(torch.sum(torch.exp((-cost_func(x,y)+g)/epsilon))) + torch.sum(g)/ torch.tensor(num_atoms)

In [9]:
# init vectors
y = torch.tensor(np.arange(10)/10., requires_grad = True)
g = torch.tensor([np.random.random(size=num_atoms)], requires_grad = True)
epsilon = torch.tensor(0.1)

In [10]:
n_iters

1000

In [17]:
y

tensor([0.0429, 0.0429, 0.0429, 0.0429, 0.0429, 0.0429, 0.0429, 0.0429, 0.0429,
        0.0429], dtype=torch.float64, requires_grad=True)

## INSTANTIATE OPTIMIZER CLASS

In [11]:
learning_rate = 0.1
optimizer_atoms = torch.optim.SGD([y], lr=learning_rate, momentum=0.9, nesterov=True)
optimizer_map = torch.optim.SGD([g], lr=learning_rate, momentum=0.9, nesterov=True)

## INSTANTIATE STEP LEARNING SCHEDULER CLASS

In [12]:
# step_size: at how many multiples of epoch you decay
# new_lr = lr*gamma 

# gamma = decaying factor
scheduler_atoms = StepLR(optimizer_atoms, step_size=1, gamma=0.8)
scheduler_map = StepLR(optimizer_map, step_size=1, gamma=0.8)

In [13]:
for i in range(n_iters):
    x = normal_dist.sample()

    # Clear gradients w.r.t. parameters
    optimizer_map.zero_grad()

    # Get dual objective to maximise
    dual_objective = er_ctran(x, g, y, epsilon, l2_cost)
    map_loss = -dual_objective

    # Getting gradients w.r.t. parameters
    map_loss.backward()
    optimizer_map.step()
    
    # Updating parameters
    if i % 100 == 0:
        print(g)
        print(map_loss)
        scheduler_map.step()
        

tensor([[0.3479, 0.1961, 0.7294, 0.1560, 0.4886, 0.1317, 0.7441, 0.0317, 0.5805,
         0.2598]], dtype=torch.float64, requires_grad=True)
tensor(-3.0419, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)
tensor(-2.7394, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)
tensor(-2.7701, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)
tensor(-3.0077, dtype=torch.float64, grad_fn=<NegBackward>)
tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)
tensor(-1.1555, dtype=torch.float64, grad_fn=<NegBackwa

## TRAIN MODEL

### Min Max Optimization

In [14]:
for i in range(n_iters):
    for _ in range(n_sub_iters):
        # sample x
        x = normal_dist.sample()

        # Clear gradients w.r.t. parameters
        optimizer_map.zero_grad()
        
        # Get dual objective to maximise
        dual_objective = er_ctran(x, g, y, epsilon, l2_cost)
        map_loss = -dual_objective
        
        # Getting gradients w.r.t. parameters
        map_loss.backward()

        # Updating parameters
        optimizer_map.step()
        
    for _ in range(n_sub_iters):
        # sample x
        x = normal_dist.sample()

        # Clear gradients w.r.t. parameters
        optimizer_atoms.zero_grad()
        
        # Get loss objective to minimise
        atoms_loss = er_ctran(x, g, y, epsilon, l2_cost)
        
        # Getting gradients w.r.t. parameters
        atoms_loss.backward()

        # Updating parameters
        optimizer_atoms.step()
    
    # Updating parameters
    if i % 100 == 0:
        print(g)
        print(map_loss)
        # Decay Learning Rate
        scheduler_atoms.step()
        scheduler_map.step()
    
    print("Loss: {0}".format(map_loss))

tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)
tensor(-2.3242, dtype=torch.float64, grad_fn=<NegBackward>)
Loss: -2.3241907345162387
Loss: -3.1829023809139434
Loss: -2.6812323799178124
Loss: -3.008164645047274
Loss: -0.2701686462248475
Loss: -0.6533732384174038
Loss: -2.745527292332235
Loss: -4.068370050707714
Loss: -3.6785262578120967
Loss: -0.3519530951295929
Loss: -0.13929094718402488
Loss: -0.49393664815895477
Loss: -2.9691273010168313
Loss: -0.3681468157132289
Loss: -2.7036094815799285
Loss: -1.5924728043746585
Loss: -4.236500499143582
Loss: -2.4857191267215266
Loss: -4.054245175620577
Loss: 0.08852705174950731
Loss: -0.8619305474198184
Loss: -1.510816624054502
Loss: -0.45690866789379936
Loss: -2.227293540485137
Loss: -4.035310459354085
Loss: -0.38637046403357117
Loss: -3.1265627274282477
Loss: -0.9678124335972247
Loss: -1.1388815736926912
Loss: -0.28657569678159117
Loss: 0.17278693245011

Loss: -7.53960933169394
Loss: -4.065350281871105
Loss: -0.7183228610120682
Loss: -4.600109803962276
tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)
tensor(-0.9725, dtype=torch.float64, grad_fn=<NegBackward>)
Loss: -0.9725284857619614
Loss: -0.7730489533053886
Loss: -1.4549103035036015
Loss: -3.112540383138763
Loss: -2.524271553838425
Loss: -3.385123074610495
Loss: 0.1596676132997224
Loss: -4.442177838104133
Loss: -1.215587121964959
Loss: -1.2929480769289405
Loss: -4.357722775258153
Loss: -2.6211567214356504
Loss: -2.4356640241301295
Loss: -2.9390015366081155
Loss: -1.6579007013259404
Loss: -5.883294454214101
Loss: -1.9713992081779588
Loss: -4.681642510141437
Loss: -2.73976785923219
Loss: -4.083067758424608
Loss: -1.2720282837356855
Loss: -0.6353079548069025
Loss: -1.5777035429827324
Loss: -0.8870256956988924
Loss: -1.8701039945469573
Loss: -1.3800524371886982
Loss: -0.44345495034442
Loss: -2.19

Loss: -3.244263778587728
Loss: -0.6668867519357036
tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)
tensor(-0.1598, dtype=torch.float64, grad_fn=<NegBackward>)
Loss: -0.1597559240144422
Loss: 0.010310677296564819
Loss: -0.5814815321642939
Loss: -1.167081142022176
Loss: -2.0135327143975044
Loss: -3.0866075948370617
Loss: -2.3316812292904987
Loss: 0.08859219379785066
Loss: -3.1416699659094807
Loss: -5.392264153554474
Loss: -0.02348908571788566
Loss: -3.335533760572375
Loss: -1.7075077687512146
Loss: -0.2738634605285787
Loss: -0.9059362497188681
Loss: -0.7880567372301239
Loss: -1.7263764090114073
Loss: -1.2486435276205452
Loss: -3.649316628772133
Loss: -2.661795294598919
Loss: -1.1322859428398602
Loss: -1.7254924756994774
Loss: -1.7550654449972183
Loss: -4.92407532707834
Loss: -1.8275289480789152
Loss: -1.2683374342661844
Loss: -4.179324611155601
Loss: -2.4248406853521804
Loss: -2.2405292902669007


Loss: -0.06067297599146315
Loss: 0.09764337039690724
Loss: -5.8345717944611515
Loss: -1.311474050722934
Loss: -1.8012178005820625
Loss: -4.702261358180888
tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)
tensor(-4.3519, dtype=torch.float64, grad_fn=<NegBackward>)
Loss: -4.351886718309626
Loss: -0.053066775157621926
Loss: 0.06659161194951013
Loss: -1.0929101926556004
Loss: -1.5217839337624137
Loss: -2.10566118762961
Loss: -2.6883806141191275
Loss: -0.25481276223498417
Loss: -1.2133888507977322
Loss: -1.7347111573311762
Loss: -2.3203510546434805
Loss: -1.2403028736608661
Loss: -0.20974991805660462
Loss: -4.195536294927454
Loss: -1.5398165150366268
Loss: -0.7891494346169807
Loss: -0.5687430806928695
Loss: -5.832119874273461
Loss: -4.365469838513782
Loss: -0.6143241150391301
Loss: -0.716679881060843
Loss: -2.1952278670052237
Loss: -3.181291023142376
Loss: 0.02691723335683105
Loss: -0.677520443927491

In [15]:
g

tensor([[0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666, 0.3666,
         0.3666]], dtype=torch.float64, requires_grad=True)

In [16]:
y

tensor([0.0429, 0.0429, 0.0429, 0.0429, 0.0429, 0.0429, 0.0429, 0.0429, 0.0429,
        0.0429], dtype=torch.float64, requires_grad=True)