In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import nn


In [2]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                       download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:18<00:00, 1392236.14it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 72282.31it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:02<00:00, 1942042.10it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 12723793.16it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [None]:
class RBM():
    def __init__(self, nv, nh):
        self.W = torch.randn(nh, nv)
        self.a = torch.randn(1, nh)
        self.b = torch.randn(1, nv)
        
    def sample_h(self, x):
        wx = torch.mm(x, self.W.t())
        activation = wx + self.a.expand_as(wx)
        p_h_given_v = torch.sigmoid(activation)
        return p_h_given_v, torch.bernoulli(p_h_given_v)
    
    def sample_v(self, y):
        wy = torch.mm(y, self.W)
        activation = wy + self.b.expand_as(wy)
        p_v_given_h = torch.sigmoid(activation)
        return p_v_given_h, torch.bernoulli(p_v_given_h)
    
    def train(self, v0, vk, ph0, phk):
        self.W += (torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)).t()
        self.b += torch.sum((v0-vk), 0)
        self.a += torch.sum((ph0-phk), 0)

In [3]:
class DBN(nn.Module):
    def __init__(self, num_visible, num_hidden1, num_hidden2, num_classes):
        super(DBN, self).__init__()
        self.fc1 = RBM(num_visible, num_hidden1)
        self.fc2 = RBM(num_hidden1, num_hidden2)
        self.fc3 = RBM(num_hidden2, num_classes)
        self.activation = nn.Sigmoid()
        
    def forward(self, x):
        _,x = self.fc1.sample_h(x)
        _,x = self.fc1.sample_v(x)
        
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


In [4]:
num_visible = 784
num_hidden1 = 500
num_hidden2 = 250
num_classes = 10

dbn = DBN(num_visible, num_hidden1, num_hidden2, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(dbn.parameters(), lr=0.001)


In [5]:
num_epochs = 10

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs.view(-1, num_visible)), Variable(labels)
        
        optimizer.zero_grad()
        
        outputs = dbn(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0


[1,   100] loss: 1.248
[1,   200] loss: 0.662
[1,   300] loss: 0.548
[1,   400] loss: 0.522
[1,   500] loss: 0.485
[1,   600] loss: 0.463
[1,   700] loss: 0.441
[1,   800] loss: 0.434
[1,   900] loss: 0.433
[2,   100] loss: 0.393
[2,   200] loss: 0.405
[2,   300] loss: 0.373
[2,   400] loss: 0.393
[2,   500] loss: 0.385
[2,   600] loss: 0.386
[2,   700] loss: 0.386
[2,   800] loss: 0.360
[2,   900] loss: 0.373
[3,   100] loss: 0.355
[3,   200] loss: 0.328
[3,   300] loss: 0.349
[3,   400] loss: 0.349
[3,   500] loss: 0.348
[3,   600] loss: 0.339
[3,   700] loss: 0.352
[3,   800] loss: 0.338
[3,   900] loss: 0.320
[4,   100] loss: 0.329
[4,   200] loss: 0.322
[4,   300] loss: 0.306
[4,   400] loss: 0.328
[4,   500] loss: 0.315
[4,   600] loss: 0.293
[4,   700] loss: 0.320
[4,   800] loss: 0.309
[4,   900] loss: 0.315
[5,   100] loss: 0.274
[5,   200] loss: 0.298
[5,   300] loss: 0.286
[5,   400] loss: 0.301
[5,   500] loss: 0.300
[5,   600] loss: 0.294
[5,   700] loss: 0.296
[5,   800] 

In [6]:
correct = 0
total = 0
for data in testloader:
    images, labels = data
    images, labels = Variable(images.view(-1, num_visible)), Variable(labels)
    outputs = dbn(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Accuracy of the DBN classifier on the 10000 test images: %d %%' % (
    100 * correct / total))


Accuracy of the DBN classifier on the 10000 test images: 88 %
