In [1]:
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image
%matplotlib inline
import matplotlib.pyplot as plt
import random as rd


In [2]:
# Mini-batch size shared by the training and the test loader.
batch_size = 64

# The only preprocessing step applied to each image is ToTensor().
to_tensor = transforms.Compose([transforms.ToTensor()])

# Training split of MNIST; download=True asks torchvision to fetch it into ./data.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=to_tensor),
    batch_size=batch_size,
)

# Test split of MNIST; no download is requested here.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=to_tensor),
    batch_size=batch_size,
)

In [6]:
class RBM(nn.Module):
    """Restricted Boltzmann Machine with binary visible and hidden units.

    Parameters
    ----------
    n_vis : int
        Number of visible units (default 784).
    n_hin : int
        Number of hidden units (default 1000).

    Learnable parameters
    --------------------
    W      : (n_hin, n_vis) weights, initialised as 0.01 * standard normal.
    v_bias : (n_vis,) visible bias, initialised to zero.
    h_bias : (n_hin,) hidden bias, initialised to zero.
    """

    def __init__(self,
                 n_vis=784,
                 n_hin=1000):
        super(RBM, self).__init__()
        self.W = nn.Parameter(torch.randn(n_hin, n_vis) * 0.01)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))

    def sample_from_p(self, p: torch.Tensor):
        """Draw one 0/1 sample per entry of ``p``.

        An entry becomes 1 when its value exceeds a fresh uniform draw from
        [0, 1), otherwise 0. The result has the dtype, device and shape of
        ``p`` and carries no autograd history.
        """
        # rand_like keeps the noise on p's device/dtype; the comparison is the
        # same decision rule as relu(sign(p - noise)) without a dead grad path.
        noise = torch.rand_like(p)
        return (p > noise).to(p.dtype)

    def v_to_h(self, v: torch.Tensor):
        """Return ``(p_h, sample_h)`` for a batch of visible vectors ``v``.

        ``p_h`` is ``sigmoid(v @ W.T + h_bias)``; ``sample_h`` is a binary
        sample drawn from it with :meth:`sample_from_p`.
        """
        p_h = torch.sigmoid(F.linear(v, self.W, self.h_bias))
        sample_h = self.sample_from_p(p_h)
        return p_h, sample_h

    def h_to_v(self, h: torch.Tensor):
        """Return ``(p_v, sample_v)`` for a batch of hidden vectors ``h``.

        ``p_v`` is ``sigmoid(h @ W + v_bias)``; ``sample_v`` is a binary
        sample drawn from it with :meth:`sample_from_p`.
        """
        p_v = torch.sigmoid(F.linear(h, self.W.t(), self.v_bias))
        sample_v = self.sample_from_p(p_v)
        return p_v, sample_v

    def forward(self, v: torch.Tensor):
        """Run one visible -> hidden -> visible sampling step.

        Returns
        -------
        (v, v_sample) : the unchanged input and the sampled reconstruction.
        """
        _, h_sample = self.v_to_h(v)
        _, v_sample = self.h_to_v(h_sample)
        return v, v_sample

    def free_energy(self, v: torch.Tensor):
        """Mean free energy of a 2-D batch ``v`` of shape (batch, n_vis).

        Per row: ``-(v . v_bias) - sum_j softplus((v @ W.T + h_bias)_j)``;
        the scalar returned is the mean over rows.
        """
        vbias_term = v.mv(self.v_bias)
        wx_b = F.linear(v, self.W, self.h_bias)
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        hidden_term = F.softplus(wx_b).sum(1)
        return (-hidden_term - vbias_term).mean()

    def fit(self, train_loader, epochs=10, lr=0.01):
        """Train with SGD on ``free_energy(v0) - free_energy(vk)``.

        Parameters
        ----------
        train_loader : iterable yielding ``(data, target)`` batches; each
            ``data`` batch is flattened to (batch, -1) and ``target`` is
            ignored.
        epochs : int
            Number of passes over ``train_loader``.
        lr : float
            SGD learning rate.

        Prints the mean per-batch loss after every epoch and returns None.
        """
        opti = optim.SGD(self.parameters(), lr=lr)
        device = self.W.device
        for epoch in range(epochs):
            train_loss = 0.0
            n_batches = 0
            for data, _ in train_loader:
                # Flatten per sample instead of assuming 784 visible units.
                v0 = data.reshape(data.size(0), -1).to(device)
                # The sampled reconstruction is a constant for the gradient,
                # so there is no need to record the sampling chain.
                with torch.no_grad():
                    _, vk = self(v0)
                loss = self.free_energy(v0) - self.free_energy(vk)
                opti.zero_grad()
                loss.backward()
                opti.step()
                train_loss += loss.item()
                n_batches += 1
            # Each loss is already a batch mean, so average over batches.
            print('Epoch : {} Average loss: {:.4f}'.format(epoch, train_loss / max(n_batches, 1)))


In [8]:
# Seed PyTorch's global RNG so weight initialisation and the sampling done
# during training give the same result on every fresh-kernel run.
torch.manual_seed(0)

# Build an RBM with its default sizes and train it on the MNIST training loader.
rbm = RBM()
rbm.fit(train_loader, epochs=15, lr=0.1)

In [None]:
# For each training batch: binarise the pixels, pass them through the RBM once
# and display the first reconstructed image of the batch.
# NOTE(review): this walks every batch and redraws on the same axes, so only
# the last image stays visible — consider stopping after the first batch.
for data, _ in train_loader:
    v = data.view(data.size(0), -1)
    v = v.bernoulli()
    print(v.shape)
    # Inference only: no autograd graph is needed for the reconstruction.
    with torch.no_grad():
        _, v = rbm(v)
    # -1 lets the batch dimension follow the actual batch; the final batch is
    # smaller than 64 whenever the dataset size is not a multiple of it.
    b = v.detach().cpu().numpy().reshape(-1, 28, 28)[0]
    print(b)
    plt.imshow(b)

torch.Size([64, 784])
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0.]

: 

: 