In [37]:
# Lab 10 MNIST and softmax
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pylab as plt

In [38]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(1)
if device == 'cuda':
    torch.cuda.manual_seed_all(1)
    
# parameters
learning_rate = 0.01
training_epochs = 10
batch_size = 32

In [39]:
# MNIST dataset
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

In [40]:
 # dataset loader
train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          drop_last=True)

In [41]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 32, bias =True)
        self.fc2 = nn.Linear(32, 32, bias =True)
        self.fc3 = nn.Linear(32, 10, bias =True)
        self.bn1 = torch.nn.BatchNorm1d(32)
        self.bn2 = torch.nn.BatchNorm1d(32)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn1(self.fc2(x)))
        x = self.fc3(x)
        return x

print("init model done")

init model done


In [53]:
# define cost/loss & optimizer
model = Net().to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)    # Softmax is internally computed.
optimizer = torch.optim.Adam(bn_model.parameters(), lr=learning_rate)

In [58]:
#Define Train function and Test function to validate.

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.view(-1,28 * 28).to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = (criterion(output,target))
        loss.backward()
        optimizer.step()
        #print(output.shape,target.shape)
    print("epoch : ", epoch)
    print("training loss : ",loss.item())
        
    return data,target,output

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.view(-1,28 * 28).to(device), target.to(device)
            output = model(data)            
            test_loss += (criterion(output, target)).item()
    test_loss /= len(test_loader.dataset)/(batch_size)
    print("test loss : ",test_loss)
    return (data, target ,output)     

In [60]:
# Train and Test the model and save it.
print("training")
for epoch in range(1, training_epochs+1):
    train(model, device, train_loader, optimizer, epoch)
DATA, TARGET, OUTPUT = test(model, device, test_loader)


training
epoch :  1
training loss :  2.4162983894348145
epoch :  2
training loss :  2.3375144004821777
epoch :  3
training loss :  2.397766590118408
epoch :  4
training loss :  2.375211000442505
epoch :  5
training loss :  2.383662223815918
epoch :  6
training loss :  2.3326337337493896
epoch :  7
training loss :  2.3297290802001953
epoch :  8
training loss :  2.354098081588745
epoch :  9
training loss :  2.2777340412139893
epoch :  10
training loss :  2.3781213760375977
test loss :  2.47136481552124


In [61]:
for tup in list(zip(TARGET,max(OUTPUT)):
    print(tup)

(tensor(0), tensor([-0.4924, -0.1519, -0.1132,  0.4359, -0.8443, -0.1157,  0.6434, -0.2933,
         0.5668, -0.9996]))
(tensor(6), tensor([ 0.1797,  0.1022, -0.0864,  0.4608, -0.4559,  0.1060,  0.4205, -0.4455,
         0.0742, -0.6775]))
(tensor(2), tensor([-0.1841,  0.6403,  1.3885,  0.1246, -1.1655,  0.0736,  0.3090, -0.9560,
        -0.3105, -0.9588]))
(tensor(1), tensor([-0.0498,  0.2157,  0.1627,  0.5169, -0.4822,  0.1588,  0.2144, -0.3291,
         0.1007, -0.4106]))
(tensor(1), tensor([-0.1003,  0.3250,  0.2894,  0.4009, -0.5822,  0.1804,  0.1346, -0.3168,
         0.1032, -0.2565]))
(tensor(7), tensor([-0.1031,  0.3368, -0.0236,  0.3962, -0.8828,  0.3711,  0.9947, -0.3711,
         0.4466, -0.5609]))
(tensor(7), tensor([ 0.0988,  0.3227,  0.0365,  0.7360, -0.9248,  0.2196,  0.6376, -0.5244,
         0.4530, -0.7183]))
(tensor(8), tensor([-0.2513,  0.3625,  0.2211,  0.2090, -0.8944,  0.1603,  0.3615, -0.5286,
         0.2263, -0.3525]))
(tensor(4), tensor([ 0.0486,  0.5827,  0