In [53]:
import torch
import torch.nn as nn
from torch.distributions import Beta, Uniform, Normal
import torch.optim as optim
# from scipy.stats import norm
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import numpy as np

In [81]:
# class SampleProbs(Function):
#     def __init__(self):
#         super(SampleProbs, self).__init__():
            
#     def forward(alpha, beta):
#         u = Uniform().sample()
#         beta_k = torch.tensor([0])
#         remaining_stick = 1
#         beta_list = []
#         while(beta_k>u):
#             beta_new = Beta(alpha, beta).sample()
#             beta_k = remaining_stick*beta_new
#             beta_list.append(beta_k)
#             remaining_stick = remaining_stick - beta_k
#         probs = torch.tensor(beta_list)
#         ctx.save_for_backward(probs)
#         return probs
    
#     def backward(output_grad):
#         probs = ctx.saved_tensors
#         input_grad_alpha = ().sum()
#         return input_grad_alpha, input_grad_beta

def norm(x, loc, scale):
    return (1/np.sqrt(2*np.pi*scale**2))*torch.exp(-(x-loc)**2/(2*scale**2))

def target(x):
    return 0.5*norm(x, loc=1, scale=0.5) + 0.5*norm(x, loc=-1, scale=0.5)

def KL(prob, theta):
    KL = (prob * torch.log(prob / target(theta))).sum()
    return KL


class SampleProbs(nn.Module):
    def __init__(self):
        super(SampleProbs, self).__init__()
            
    def forward(self, alpha, beta):
        u = Uniform(0,1).sample().cuda()
        beta_k = torch.tensor([0])
        remaining_stick = 1
        beta_list = []
        while(remaining_stick>1e-4):
            u0 = Uniform(0,1).sample().cuda()
            beta_new = (1-u0**(1/beta))**(1/alpha) #torch.sigmoid(beta + alpha)#
            beta_k = remaining_stick*beta_new
            beta_list.append(beta_k)
            remaining_stick = remaining_stick - beta_k
        beta_list.append(remaining_stick)
        probs = torch.tensor(beta_list).cuda()
        return probs
    
            
    
    
class Dirichlet(nn.Module):
    def __init__(self):
        super(Dirichlet, self).__init__()
        self.logalpha = nn.Parameter(torch.tensor([np.log(0.5)], device='cuda'))
        self.logbeta = nn.Parameter(torch.tensor([np.log(0.5)], device='cuda'))
        self.logmean = nn.Parameter(torch.tensor([np.log(1)], device='cuda'))
        
    def forward(self):
        alpha = torch.exp(self.logalpha)
        beta = torch.exp(self.logbeta)
        mean = torch.exp(self.logmean)
        
        prob = SampleProbs()(alpha, beta)
        noise = Normal(loc=0, scale=1).sample(prob.size()).cuda()
        theta = mean + noise
        out = KL(prob, theta)
        return out, prob
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Dirichlet().to(device)

In [82]:
optimizer = optim.SGD(model.parameters(), lr=0.001)
epoch = 10000
for i in range(epoch):
    loss, prob = model() #.to_device()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 1000 == 0:
        print('loss:',loss.data.cpu().numpy(), 'probs:', prob.data.cpu().numpy())


loss: 0.8297592401504517 probs: [  3.80693167e-01   6.17742956e-01   1.37684646e-03   1.20136021e-04
   6.68645880e-05]
loss: 1.7183347940444946 probs: [  2.91489065e-01   6.82435857e-05   1.48234174e-01   4.76091504e-01
   4.71678041e-02   3.63407992e-02   2.13393687e-05   3.22495216e-05
   5.53939200e-04   8.63918103e-07]
loss: 1.3264468908309937 probs: [  7.31804311e-01   4.39009480e-02   1.28426194e-01   5.64674512e-02
   2.27971487e-02   1.42250014e-02   2.09360936e-04   5.74741221e-04
   7.58637907e-04   5.50049299e-04   2.36548512e-04   4.96036664e-05]
loss: 0.7578646540641785 probs: [  1.00445310e-02   5.14874995e-01   3.94638240e-01   7.52717331e-02
   3.16626858e-04   3.41552868e-03   3.35713412e-04   2.16308717e-06
   1.04651775e-03   5.39668836e-05]
loss: 0.7605620622634888 probs: [  6.78521633e-01   2.64677137e-01   1.47507126e-02   2.23997068e-02
   6.95942482e-03   9.81504470e-03   2.61476060e-04   2.01069145e-03
   2.03061689e-04   2.00859184e-04   1.77177659e-04   2.30

   5.76930579e-05   1.05195519e-04   1.94802269e-05]
loss: 0.7377040982246399 probs: [  8.43161047e-01   1.56001911e-01   2.57343872e-05   4.38497722e-04
   3.24799737e-04   4.80111048e-05]
loss: 1.1550740003585815 probs: [  3.01705658e-01   6.84049904e-01   6.26379624e-03   5.88814495e-03
   3.01659893e-05   4.35147464e-04   7.92527688e-04   7.68585247e-04
   6.60696533e-05]
loss: -0.1958269476890564 probs: [  4.23104107e-01   1.02660567e-01   1.58836409e-01   7.60154203e-02
   2.35421360e-01   1.07286491e-04   1.15199387e-03   1.73791114e-03
   9.42245650e-04   2.27220007e-05]
loss: 0.15299633145332336 probs: [  3.75589132e-01   9.98188630e-02   3.92048866e-01   1.75246410e-02
   5.71698602e-03   1.10229128e-03   1.06096908e-01   3.65347078e-04
   1.70952838e-03   2.74101039e-05]
loss: 1.023352861404419 probs: [  3.37085038e-01   5.14517367e-01   1.39348119e-01   2.32007634e-03
   5.50275575e-03   1.05217646e-03   1.53996123e-04   2.05003162e-05]
loss: 1.6588726043701172 probs: [  1.

loss: nan probs: [  6.12311773e-02   9.34390724e-01   3.15779401e-03   7.08043983e-04
   1.84410717e-04   1.12061520e-04   4.42958844e-06   1.92352512e-04
   1.89880375e-05]
loss: nan probs: [  7.69727290e-01   2.97963619e-02   2.38371883e-02   1.74578384e-01
   9.68593406e-04   3.75940945e-05   4.08995140e-04   1.49562238e-05
   5.42953203e-04   8.76788981e-05]
loss: nan probs: [  1.60029754e-01   8.32552850e-01   6.15975587e-03   8.62475834e-04
   1.00085172e-05   1.63161203e-05   2.89592310e-04   7.92321516e-05]
loss: nan probs: [  9.98913527e-01   4.79702838e-04   4.47136757e-04   4.34767899e-05
   8.02796640e-05   3.58774196e-05]
loss: nan probs: [  5.69214225e-01   1.10404141e-01   3.18932652e-01   6.66252396e-04
   7.68003578e-04   1.47329411e-05]
loss: nan probs: [  1.18054509e-01   8.79527628e-01   9.60025180e-04   4.90800885e-04
   7.02867983e-04   3.18233360e-05   2.13020059e-04   1.93249143e-05]
loss: nan probs: [  8.74625981e-01   1.25250861e-01   1.31233319e-05   1.050668

loss: nan probs: [  7.54418015e-01   4.60726134e-02   6.36063367e-02   8.50369874e-03
   1.06201105e-01   1.63270645e-02   4.38987045e-03   3.22770065e-04
   1.25526058e-04   3.29894974e-05]
loss: nan probs: [  7.59765804e-01   7.81132653e-02   1.58632606e-01   1.12719683e-03
   2.33243359e-03   2.87014991e-05]
loss: nan probs: [  9.49110508e-01   1.00968078e-04   5.06575853e-02   1.19666547e-04
   1.12736816e-05]
loss: nan probs: [  9.01440561e-01   4.99777645e-02   4.53753509e-02   1.10390130e-03
   1.80811423e-03   2.24818868e-04   6.94894261e-05]
loss: nan probs: [  2.32107434e-02   3.57082370e-03   5.71974516e-01   3.77682418e-01
   1.35566562e-03   1.08290194e-02   6.81804493e-03   1.06497664e-05
   9.58111923e-05   4.02290933e-03   4.16413503e-04   1.29946566e-05]
loss: nan probs: [  9.02913213e-01   6.64510950e-02   1.80061944e-02   3.85400112e-04
   1.22440737e-02   2.42143869e-08]
loss: nan probs: [  1.42561361e-01   7.12139648e-04   4.28793877e-01   3.00614387e-01
   2.51047

loss: nan probs: [  4.56432849e-02   9.40552711e-01   1.11266144e-03   1.21842604e-03
   5.98500192e-04   1.04095861e-02   3.99223849e-04   6.56206685e-05]
loss: nan probs: [  4.23699543e-02   8.00150573e-01   1.52467817e-01   2.01893272e-03
   5.82488065e-05   2.63927085e-03   1.01145606e-05   7.05807324e-05
   1.30748174e-06   1.51742497e-04   6.14502205e-05]
loss: nan probs: [  7.20534474e-02   9.18534279e-01   7.50692794e-03   1.53952278e-03
   1.99346992e-04   6.42368541e-05   9.83397695e-05   3.91432695e-06]
loss: nan probs: [  6.06401339e-02   3.80581915e-01   5.58772326e-01   5.60283661e-06]
loss: nan probs: [  9.30453777e-01   6.94175288e-02   3.63567324e-07   2.70694755e-09
   6.90515953e-05   5.92760116e-05]
loss: nan probs: [  8.46197903e-01   5.09371012e-02   3.77578125e-03   7.97980726e-02
   1.79868657e-02   1.25720596e-03   4.70684608e-05]
loss: nan probs: [  9.95392799e-01   3.06920521e-03   1.40239578e-03   1.03671315e-04
   3.19283208e-05]
loss: nan probs: [  7.54564

loss: nan probs: [  7.14186728e-01   1.48643479e-01   8.52660239e-02   3.56772952e-02
   9.42569633e-04   5.66254370e-03   7.60342227e-03   7.83640426e-04
   5.38603403e-04   5.97311475e-04   9.83836944e-05]
loss: nan probs: [  3.25578940e-03   9.54026759e-01   1.80457793e-02   6.45835837e-03
   1.82131883e-02   1.30385160e-07]
loss: nan probs: [  9.63975728e-01   1.92982182e-02   8.04285737e-05   5.87146683e-03
   1.07738804e-02   2.77534127e-07]
loss: nan probs: [  7.69175738e-02   1.54798233e-03   7.32719481e-01   8.23353678e-02
   8.40158090e-02   1.33322123e-02   4.97804815e-03   3.46399564e-03
   6.57835044e-04   3.16700898e-05]
loss: nan probs: [  9.85943079e-01   8.84081237e-03   4.02645953e-03   3.68067267e-05
   8.53552541e-04   1.81249125e-04   5.73488023e-06   1.11133246e-04
   1.17262971e-06]
loss: nan probs: [  2.39340603e-01   7.21216321e-01   1.00671016e-02   1.44202234e-02
   7.44900061e-03   1.59643765e-04   6.97324611e-03   3.00568121e-04
   7.32922344e-05]
loss: nan

loss: nan probs: [  6.73052251e-01   7.09413588e-02   2.45854005e-01   2.88151973e-03
   5.32658072e-04   4.17717639e-03   2.30332674e-03   1.37432566e-04
   1.15156916e-04   5.11469261e-06]
loss: nan probs: [  2.03251496e-01   7.34381318e-01   6.07050471e-02   1.06146885e-03
   4.08912747e-04   1.00014113e-04   9.17580401e-05]
loss: nan probs: [  9.61748779e-01   1.81640685e-03   6.52800547e-03   5.37630776e-03
   4.49117180e-03   1.85949281e-02   4.17601201e-04   9.90052824e-04
   3.67490575e-05]
loss: nan probs: [  7.56392956e-01   5.26299118e-05   8.52037128e-03   1.78916706e-03
   1.07555926e-01   2.75924839e-02   4.68522497e-03   1.33369234e-04
   5.97649701e-02   3.34252790e-02   8.76300037e-05]
loss: nan probs: [  9.88835275e-01   1.82432740e-03   7.49885384e-03   1.84153265e-03
   1.05937943e-08]
loss: nan probs: [  9.72360075e-01   1.56869926e-02   4.86303918e-07   2.42663478e-03
   7.85322301e-03   4.31754049e-10   7.33908440e-04   8.85643996e-04
   5.30363177e-05]
loss: nan

loss: nan probs: [  8.61400506e-04   8.29696894e-01   5.72517030e-02   1.13525717e-02
   2.40090452e-02   4.75171469e-02   2.36099996e-02   5.70101058e-03
   2.18395144e-07]
loss: nan probs: [  4.71104473e-01   2.52616405e-01   8.44851695e-03   2.24381313e-01
   4.28019017e-02   6.21851868e-05   1.89858765e-05   5.27096039e-04
   3.90989590e-05]
loss: nan probs: [  5.28828204e-01   3.22667286e-02   2.88595527e-01   1.12774551e-01
   3.49589367e-03   2.61287205e-03   1.27615007e-02   9.18493606e-03
   1.42650644e-03   3.53110861e-03   2.05961452e-03   1.73211272e-03
   9.81529083e-05   6.31181465e-04   1.10088149e-06]
loss: nan probs: [  7.64666200e-01   4.30035442e-02   1.91066086e-01   1.17355725e-03
   9.06125642e-05]
loss: nan probs: [  4.20196086e-01   1.82506703e-02   2.53910050e-02   1.23874269e-01
   1.52512312e-01   1.60169885e-01   3.00370646e-03   8.70734006e-02
   9.81355668e-04   3.18762730e-04   4.69195889e-03   5.46488503e-04
   2.99011730e-03   5.02914190e-08]
loss: nan 

loss: nan probs: [  9.06562924e-01   9.27776396e-02   2.13586609e-04   1.00226491e-04
   2.92922650e-05   6.59960278e-05   9.46441924e-06   2.04251643e-04
   3.66185268e-05]
loss: nan probs: [  7.44250000e-01   2.54940391e-01   5.89184056e-04   1.27324354e-04
   9.31014802e-05]
loss: nan probs: [  4.14593309e-01   3.98981154e-01   1.67610466e-01   3.48092220e-03
   1.51559748e-02   9.33247647e-05   8.48186173e-05]
loss: nan probs: [  9.19268847e-01   5.12732305e-02   1.78588089e-02   9.43369605e-03
   1.71228906e-03   2.76265910e-05   1.60444863e-04   2.54204031e-04
   1.08535751e-05]
loss: nan probs: [  1.59973010e-01   5.22700131e-01   3.06475371e-01   1.02982146e-03
   9.70621873e-03   6.48505520e-05   5.05822245e-05]
loss: nan probs: [  3.40317525e-02   9.57325459e-01   2.35583005e-03   2.07379204e-03
   3.58815282e-03   5.16147702e-05   5.86961069e-05   4.41808777e-04
   7.28986342e-05]
loss: nan probs: [  8.40442896e-01   1.58399731e-01   5.44982497e-04   4.77082358e-04
   1.1902

loss: nan probs: [  4.88687634e-01   3.60112544e-03   5.09409793e-02   1.49775624e-01
   9.95669421e-03   6.64043650e-02   1.13874651e-03   3.64226638e-03
   1.66822657e-01   4.02503349e-02   1.49617279e-02   2.05890465e-04
   3.46626411e-03   1.00045276e-04   4.56294074e-05]
loss: nan probs: [  2.51199119e-02   4.18755319e-03   4.68178779e-01   4.61751312e-01
   1.11746257e-02   2.35198755e-02   1.96799659e-03   3.36637977e-03
   2.53662089e-04   4.29062813e-04   5.08511148e-05]
loss: nan probs: [  9.93432105e-03   1.67881280e-01   5.30956566e-01   1.76429063e-01
   4.97430041e-02   6.29822165e-02   2.00684345e-03   6.67499844e-05]
loss: nan probs: [  7.81235874e-01   1.02667980e-01   9.43649113e-02   2.02823058e-02
   2.17821333e-04   1.03818486e-03   9.57396842e-05   9.71834379e-05]
loss: nan probs: [  2.55099297e-01   4.06483054e-01   2.24457458e-01   1.17551330e-02
   9.97919291e-02   5.25144220e-04   1.35081017e-03   2.62782822e-04
   9.54013686e-08   6.45736145e-06   8.02986033e

loss: nan probs: [  9.85913515e-01   1.30425282e-02   1.04201341e-03   1.94332097e-06]
loss: nan probs: [  6.35551929e-01   4.04425710e-02   1.77832007e-01   1.12699429e-02
   9.92718339e-03   5.45518957e-02   3.24708894e-02   3.66731621e-02
   1.19802612e-03   8.23747832e-05]
loss: nan probs: [  8.84725809e-01   1.79343782e-02   1.72945356e-03   2.96609811e-02
   1.10597834e-02   4.48598973e-02   1.00041535e-02   2.55461782e-05]
loss: nan probs: [  4.98599261e-01   4.64915782e-01   9.90395434e-03   2.13096347e-02
   2.28808681e-03   2.97963456e-03   3.61655839e-06]
loss: nan probs: [  1.75160468e-01   6.15916610e-01   1.54512554e-01   2.28300443e-04
   1.07906829e-03   1.32449577e-02   1.97907002e-03   2.91936733e-02
   1.86740537e-03   4.64574713e-03   2.17214110e-03   4.88944352e-09]
loss: nan probs: [  5.50413281e-02   9.38335478e-01   3.71440081e-03   3.39007238e-05
   2.76224525e-03   1.10243411e-04   2.41844828e-06]
loss: nan probs: [  8.32683265e-01   1.62074938e-01   1.7093892

loss: nan probs: [  3.26639861e-01   5.36064625e-01   3.54087278e-02   5.62274130e-03
   6.14381917e-02   3.37151177e-02   8.59906795e-05   5.66782954e-04
   3.62608116e-04   9.53210401e-05]
loss: nan probs: [  7.96786666e-01   1.48202792e-01   3.22689116e-02   2.27414146e-02
   2.16066837e-07]
loss: nan probs: [  9.28099360e-03   1.06143323e-03   5.24625182e-03   3.88775865e-04
   3.60354602e-01   1.21935301e-01   1.42794579e-01   3.58924389e-01
   1.36792660e-05]
loss: nan probs: [  8.02021444e-01   4.98624370e-02   2.08142344e-02   1.27085283e-01
   6.47004126e-05   7.52031483e-05   7.66848243e-05]
loss: nan probs: [  8.61952841e-01   1.30049497e-01   6.44150656e-04   4.75050649e-03
   3.75238538e-04   4.62536009e-05   2.46143827e-06   1.09763663e-04
   5.12207218e-04   1.52946811e-03   2.76122009e-05]
loss: nan probs: [  4.30298805e-01   1.61159337e-02   1.93212822e-01   1.60250083e-01
   7.99751505e-02   8.43083113e-02   3.48749310e-02   8.80272710e-04
   8.37355619e-05]
loss: nan

loss: nan probs: [  7.72104084e-01   9.94284451e-02   7.70682171e-02   5.12332730e-02
   1.61389878e-04   4.59043076e-06]
loss: nan probs: [  6.57464206e-01   1.51173100e-02   1.50844023e-01   1.71909109e-01
   1.14790811e-04   6.78986893e-04   1.25379057e-03   1.02677767e-03
   1.57721550e-03   1.37836905e-05]
loss: nan probs: [  6.15684927e-01   1.10590316e-01   2.00608358e-01   7.30230063e-02
   9.34004784e-05]
loss: nan probs: [  9.89312887e-01   5.60298096e-03   2.78280443e-03   2.12001894e-03
   6.58635763e-05   1.34186075e-06   1.05605250e-04   8.49779462e-06]
loss: nan probs: [  9.53964829e-01   4.01273519e-02   4.40240186e-03   1.29286409e-03
   2.00906667e-04   1.16460578e-05]
loss: nan probs: [  7.56339133e-01   2.20975056e-01   1.41026825e-02   8.09405185e-03
   4.88601101e-04   4.75498382e-07]
loss: nan probs: [  9.89022255e-01   1.40945963e-03   5.82567991e-05   4.98903792e-05
   9.44305863e-03   1.70795247e-05]
loss: nan probs: [  2.50724912e-01   7.45092273e-01   3.8853

loss: nan probs: [  6.68562353e-02   2.17856750e-01   7.10452974e-01   4.61601373e-03
   4.45415717e-05   1.55533024e-04   1.79079652e-05]
loss: nan probs: [  3.15224864e-02   9.67284977e-01   5.43165021e-04   2.26807242e-04
   4.07268206e-04   1.52696448e-05]
loss: nan probs: [  1.07115038e-01   8.91904950e-01   1.10804052e-04   4.57069109e-04
   1.54393711e-04   2.11503357e-04   4.62493626e-05]
loss: nan probs: [  3.23570579e-01   3.67521495e-01   8.46999325e-03   2.94667900e-01
   2.80846516e-03   2.36262195e-03   8.63957466e-05   1.74556539e-04
   2.00115159e-04   2.60841352e-05   5.06893548e-06   4.20861761e-05
   6.46040207e-05]
loss: nan probs: [  3.47820312e-01   6.52179480e-01   2.38418579e-07]
loss: nan probs: [  2.33570278e-01   7.61312127e-01   5.10850130e-03   9.09389928e-06]
loss: nan probs: [  5.26435316e-01   3.45025817e-03   2.88085550e-01   1.17451414e-01
   5.40128201e-02   1.04839206e-02   8.07344913e-05]
loss: nan probs: [  8.44198287e-01   1.57629550e-02   1.22898

loss: nan probs: [  8.49339724e-01   1.31401688e-01   1.91627089e-02   9.58796591e-05]
loss: nan probs: [  8.58480692e-01   2.31458675e-02   5.50301000e-02   2.11236626e-03
   1.52583839e-02   4.56804819e-02   3.56701166e-05   8.42468871e-05
   1.65506382e-04   6.68407301e-06]
loss: nan probs: [  7.19590425e-01   8.54123011e-03   1.95686077e-03   3.35124061e-02
   1.71031594e-01   4.70659584e-02   1.15082702e-02   3.73794086e-04
   6.36136159e-03   5.81056811e-05]
loss: nan probs: [  8.14790785e-01   1.79006413e-01   2.20638746e-03   5.76280116e-04
   4.69807448e-04   2.67230510e-03   1.27597095e-05   8.52338144e-06
   1.09162676e-04   1.47570056e-04   6.09725248e-09]
loss: nan probs: [  2.82365326e-02   1.91657040e-02   3.84247936e-02   2.12420017e-01
   3.53471518e-01   3.01597655e-01   4.31737900e-02   2.90256320e-03
   2.19962472e-04   2.82947818e-04   4.26272521e-07   9.25220811e-05
   1.16360898e-05]
loss: nan probs: [  9.97833014e-01   8.16892309e-04   1.24236417e-03   7.1330228

loss: nan probs: [  7.34991252e-01   1.62244469e-01   2.83625852e-02   6.38647825e-02
   1.04955025e-02   4.14066017e-05]
loss: nan probs: [  9.77982223e-01   2.07003113e-02   7.99680769e-04   5.14284824e-04
   3.50008486e-06]
loss: nan probs: [  7.94389606e-01   1.89042762e-01   1.88679274e-04   1.29580554e-02
   1.77094992e-03   5.71911572e-04   1.02657080e-03   5.14663989e-05]
loss: nan probs: [  9.35706019e-01   4.98849433e-04   3.62357572e-02   2.49708034e-02
   5.03649680e-05   1.85611891e-03   2.75279075e-04   1.18988370e-04
   2.62902118e-04   2.49203877e-05]
loss: nan probs: [  9.99626458e-01   3.33653763e-04   3.98885459e-05]
loss: nan probs: [  9.67913628e-01   2.88924370e-02   8.28336633e-05   1.01268946e-04
   9.21409577e-04   7.30218453e-05   6.01353124e-04   7.77468551e-04
   2.21815004e-04   2.37871296e-04   4.26408442e-05   6.00732201e-05
   7.41793192e-05]
loss: nan probs: [  5.54704249e-01   1.43757218e-03   9.65673700e-02   3.41024727e-01
   2.23000930e-03   6.95094

loss: nan probs: [  5.82922459e-01   2.05287039e-01   1.26476035e-01   7.83648863e-02
   3.55689926e-03   2.71321018e-03   6.33433927e-04   4.60378360e-05]
loss: nan probs: [  9.85910296e-01   1.11949339e-04   1.36012211e-03   1.00887846e-02
   2.43783533e-03   9.10118688e-05]
loss: nan probs: [  3.85760844e-01   5.94794154e-01   1.59796271e-02   2.21403006e-05
   3.35388980e-03   8.93448014e-05]
loss: nan probs: [  3.25095728e-02   6.40129298e-02   3.20787309e-04   9.66693312e-02
   4.84641865e-02   2.90763348e-01   2.52093613e-01   1.15101375e-01
   7.33963922e-02   2.65712198e-02   9.72542912e-05]
loss: nan probs: [  9.34435949e-02   2.21959725e-02   8.58195961e-01   6.35359716e-03
   1.82565842e-02   1.01169811e-04   3.21853440e-04   9.64539184e-04
   1.01139485e-04   6.55879194e-05]
loss: nan probs: [  5.66759147e-03   3.64591507e-03   2.07277760e-01   2.09628418e-01
   2.67082565e-02   3.62018347e-01   6.73394799e-02   9.07957926e-03
   1.08509116e-01   4.71154893e-10   7.7231947

loss: nan probs: [  7.88930833e-01   3.96835897e-03   1.98142812e-01   7.47689139e-03
   3.98564298e-04   5.48246026e-04   5.04646916e-04   2.96483631e-05]
loss: nan probs: [  6.38375819e-01   6.79615960e-02   2.08252713e-01   1.55083546e-02
   4.42210101e-02   2.56064460e-02   7.40550458e-05]
loss: nan probs: [  2.60092854e-01   5.40693820e-01   1.55789210e-02   4.00478346e-03
   7.28897890e-03   2.18247436e-03   2.61372491e-03   5.39402403e-02
   9.42457020e-02   3.53717059e-03   1.44532705e-02   1.09902001e-03
   1.44431287e-05   2.27840210e-04   2.67713331e-05]
loss: nan probs: [  7.99432918e-02   3.92457247e-01   1.73739597e-01   9.89935100e-02
   1.88761458e-01   3.96328010e-02   2.16504652e-02   2.81794416e-03
   1.97268650e-06   9.14364064e-04   5.05614269e-04   7.97617322e-05
   1.11509602e-04   3.87265871e-04   3.17531521e-06]
loss: nan probs: [  9.71660495e-01   1.71488598e-02   4.89731738e-03   1.26756515e-06
   6.26326073e-03   2.87997536e-05]
loss: nan probs: [  6.8492019

loss: nan probs: [  9.47287738e-01   4.32566181e-02   5.57741476e-03   1.81286805e-03
   1.42546651e-05   1.59244530e-03   2.34307518e-04   8.00268972e-05
   6.93181073e-06   1.37365394e-04   2.92784534e-08]
loss: nan probs: [  7.36474812e-01   4.56385054e-02   3.63621935e-02   3.92246321e-02
   7.46374577e-02   8.43531266e-03   5.70415407e-02   9.76694748e-04
   1.65918918e-05   5.13774343e-04   7.45139914e-06   2.20501406e-05
   2.96491780e-04   3.01792519e-04   5.07027726e-05]
loss: nan probs: [  8.65980268e-01   1.27888486e-01   6.02409709e-03   6.31941930e-06
   5.73784055e-05   4.34517715e-05]
loss: nan probs: [  9.87511516e-01   4.11995780e-03   7.83970021e-03   2.15044420e-05
   5.04263502e-04   3.05840513e-06]
loss: nan probs: [  5.92293382e-01   1.64259583e-01   6.75434768e-02   1.68043122e-01
   7.47500965e-03   7.14441194e-06   1.71946231e-04   2.05640754e-04
   6.95989002e-07]
loss: nan probs: [  3.76281768e-01   5.89017034e-01   2.25491710e-02   8.59586708e-03
   6.510301

KeyboardInterrupt: 

In [6]:
torch.tensor([torch.tensor([1]),torch.tensor([2])])

tensor([1, 2])

In [7]:
torch.cuda.is_available()

True