In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.optim as optim
import matplotlib.pyplot as plt
import math
from torchvision import datasets, transforms
from torch.autograd import Variable
from PSO import PSO
from simple_cnn import simpleCNN
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [3]:
transform = transforms.Compose([
transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
data_train = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)

data_test = datasets.MNIST(root="./data/",
                           transform = transform,
                           train = False)

In [4]:
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size = 256,
                                                shuffle = True)

data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size = 256,
                                               shuffle = True)

In [5]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [6]:
model = LeNet()
model.to(device)

LeNet(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [7]:
def train_pso(model, train_data, test_data, epochs):
    criterion = nn.CrossEntropyLoss()
    optimizer = PSO(model)

    for epoch in range(epochs):
        # trainning
#         best_loss = float('INF')
        loss_history = np.zeros((5000, 1))
        for batch_idx, (inp, target) in enumerate(train_data):
            inp, target = inp.to(device), target.to(device)
            inp, target = Variable(inp), Variable(target)
            def _closure():
                out = model(inp)
                return criterion(out, target)
            out = model(inp)
            optimizer.step(_closure)
            loss = optimizer.best_loss
            loss_history[batch_idx] = loss
            if (batch_idx%100 == 0):
                print("%d Loss=%.3f" % (epoch, np.mean(loss_history[np.where(loss_history)])))
        # testing
        correct_cnt, ave_loss = 0, 0
        total_cnt = 0
        for batch_idx, (x, target) in enumerate(test_data):
            x, target = x.cuda(), target.cuda()
            x, target = Variable(x, volatile=True), Variable(target, volatile=True)
            out = model(x)
            loss = criterion(out, target)
            
            _, pred_label = torch.max(out.data, 1)
            total_cnt += x.data.size()[0]
            correct_cnt += (pred_label == target.data).sum()
            # smooth average

            if (batch_idx+1) == len(test_data):
                print ('==>>> epoch: {}, batch index: {}, acc: {:.3f}'.format(
                    epoch, batch_idx+1, correct_cnt * 1.0 / total_cnt))

In [None]:
train_pso(model, data_loader_train, data_loader_test, 1000)

0 Loss=36.983
0 Loss=6.852
0 Loss=5.960




==>>> epoch: 0, batch index: 40, acc: 0.086
1 Loss=4.701
1 Loss=4.701
1 Loss=4.701
==>>> epoch: 1, batch index: 40, acc: 0.195
2 Loss=4.665
2 Loss=4.665
2 Loss=4.665
==>>> epoch: 2, batch index: 40, acc: 0.207
3 Loss=4.665
3 Loss=4.665
3 Loss=4.665
==>>> epoch: 3, batch index: 40, acc: 0.175
4 Loss=4.665
4 Loss=4.665
4 Loss=4.665
==>>> epoch: 4, batch index: 40, acc: 0.218
5 Loss=4.665
5 Loss=4.665
5 Loss=4.665
==>>> epoch: 5, batch index: 40, acc: 0.103
6 Loss=4.665
6 Loss=4.665
6 Loss=4.665
==>>> epoch: 6, batch index: 40, acc: 0.184
7 Loss=4.665
7 Loss=4.665
7 Loss=4.665
==>>> epoch: 7, batch index: 40, acc: 0.145
8 Loss=4.665
8 Loss=4.574
8 Loss=4.548
==>>> epoch: 8, batch index: 40, acc: 0.223
9 Loss=4.522
9 Loss=4.522
9 Loss=4.522
==>>> epoch: 9, batch index: 40, acc: 0.204
10 Loss=4.522
10 Loss=4.522
10 Loss=4.522
==>>> epoch: 10, batch index: 40, acc: 0.224
11 Loss=4.522
11 Loss=4.522
11 Loss=4.522
==>>> epoch: 11, batch index: 40, acc: 0.226
12 Loss=4.522
12 Loss=4.522
12 Loss