In [9]:
import torch
import torchvision
import torchvision.transforms as transforms
from rbm import RBM
import numpy as np

In [2]:
transform = transforms.Compose([transforms.ToTensor()])
batch_size = 128


def binarize_dataset(dataset, threshold=0.5):
    """Materialize `dataset` as a TensorDataset of thresholded images.

    Args:
        dataset: iterable of `(image, label)` pairs, where `image` is a tensor
            and `label` is convertible with `int()`.
        threshold: values strictly greater than this become 1, everything
            else becomes 0.

    Returns:
        torch.utils.data.TensorDataset holding the stacked binary images
        (same dtype as the inputs) and a 1-D integer label tensor.
    """
    images, labels = [], []
    for image, label in dataset:
        # A single comparison guarantees every value is exactly 0 or 1; the
        # former `x > 0.5` / `x < 0.5` pair left values equal to 0.5 untouched.
        images.append((image > threshold).to(image.dtype))
        labels.append(int(label))
    return torch.utils.data.TensorDataset(torch.stack(images), torch.tensor(labels))


trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
bin_train_data = binarize_dataset(trainset)
trainloader = torch.utils.data.DataLoader(bin_train_data, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
bin_test_data = binarize_dataset(testset)
# Evaluation order does not matter, so the test loader is not shuffled.
testloader = torch.utils.data.DataLoader(bin_test_data, batch_size=batch_size, shuffle=False, num_workers=2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
# Load the saved model object (presumably an `RBM` instance pickled whole --
# TODO confirm) and place it on `device`.
# `map_location` remaps the stored tensors while loading, so a checkpoint that
# was saved on a GPU can still be opened on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# checkpoint files from a trusted source. Recent PyTorch releases may also
# require `weights_only=False` to load a full pickled model -- verify against
# the installed version.
rmb = torch.load("./models/bin_k2_500_15iters", map_location=device).to(device)

In [11]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Fully connected classifier: two ReLU hidden layers and a linear output.

    The layer widths are configurable; the defaults reproduce the original
    fixed 500 -> 120 -> 84 -> 10 architecture, so `Net()` is unchanged.

    Args:
        in_features: width of each input feature vector.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=500, hidden1=120, hidden2=84, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Map a `(batch, in_features)` tensor to `(batch, num_classes)` logits.

        No softmax is applied to the output.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Build the classifier, then move its parameters onto the selected device.
net = Net()
net = net.to(device)

In [12]:
import torch.optim as optim

# SGD with momentum over the classifier's parameters only.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)
# Cross-entropy over the raw logits returned by `net`.
criterion = nn.CrossEntropyLoss()

In [19]:
# Train the classifier `net` on hidden features produced by `rmb`.
# Only `net.parameters()` are registered with the optimizer, so `rmb` acts as
# a fixed feature extractor in this loop.
# NOTE: `net` and `optimizer` are created in earlier cells; re-running this
# cell continues training from the current weights rather than starting over.

# Mode switches do not depend on the batch, so set them once up front.
rmb.eval()
net.train()

for epoch in range(20):  # loop over the dataset multiple times

    loss_ = []  # per-batch losses, averaged for the epoch report
    for inputs, labels in trainloader:
        # flatten each image to a 784-vector and send the batch to the device
        inputs = inputs.view(-1, 784).to(device)
        labels = labels.to(device)

        # convert to h; no autograd graph is needed through `rmb` because it
        # is not being optimized, which avoids wasted memory and backward work
        with torch.no_grad():
            h = rmb.visible_to_hidden(inputs)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(h)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_.append(loss.item())

    print(f"epoch {epoch+1} loss: {np.mean(loss_)}")

print('Finished Training')

epoch 1 loss: 0.21042371469774226
epoch 2 loss: 0.20574578444268912
epoch 3 loss: 0.2035207254831979
epoch 4 loss: 0.19948548121429455
epoch 5 loss: 0.19869840396111454
epoch 6 loss: 0.19774184441134365
epoch 7 loss: 0.19253657533447627
epoch 8 loss: 0.1916420421620676
epoch 9 loss: 0.18825506997197422
epoch 10 loss: 0.18781734496227967
epoch 11 loss: 0.18527372430827319
epoch 12 loss: 0.18301729788022764
epoch 13 loss: 0.18151527335013407
epoch 14 loss: 0.1789666433244753
epoch 15 loss: 0.17663834402873826
epoch 16 loss: 0.17473675270896477
epoch 17 loss: 0.1744324513462815
epoch 18 loss: 0.1727280708105325
epoch 19 loss: 0.1706992977543045
epoch 20 loss: 0.16791899568204688
Finished Training


In [20]:
# Measure top-1 accuracy of `net` on the batches yielded by `testloader`.
correct = 0
total = 0

# Put both models in inference mode once, before the loop. `Net` as defined in
# this notebook has only Linear/ReLU layers, so `net.eval()` is a no-op today;
# it guards against dropout or batch-norm layers being added later.
rmb.eval()
net.eval()

with torch.no_grad():
    for inputs, labels in testloader:
        # flatten each image to a 784-vector and send the batch to the device
        inputs = inputs.view(-1, 784).to(device)
        labels = labels.to(device)

        # convert to h
        h = rmb.visible_to_hidden(inputs)

        outputs = net(h)
        # `.data` is unnecessary under no_grad; index of the largest logit is
        # the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy}')

Accuracy of the network on the 10000 test images: 0.9467
