<a href="https://colab.research.google.com/github/Jaeyoung-Choi/mnist_tutorial/blob/master/mnist_backpropagation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torchvision import datasets as dsets
import torchvision.transforms as transforms

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(42)
if device == 'cuda':
    torch.cuda.manual_seed_all(42)
    print(torch.cuda.current_device(), torch.cuda.device_count(), torch.cuda.get_device_name(0))

In [3]:
def sigmoid(X):
    return 1 / (1 + torch.exp(-X))

def softmax(X):
    return torch.exp(X) / torch.sum(torch.exp(X), dim = 0)

In [4]:
def dsigmoid(X):
    # return torch.exp(-X) / torch.pow(1 + torch.exp(-X), 2)
    return torch.mul(sigmoid(X), 1 - sigmoid(X))

def dsoftmax(X):
    #return torch.mul((torch.sum(torch.exp(X), dim = 1) - torch.exp(X)) / torch.pow(torch.sum(torch.exp(X), dim = 1), 2), torch.exp(X))
    return softmax(X) - torch.pow(softmax(X), 2)

In [5]:
lr = 0.001

In [6]:
class model:
    def __init__(self):
        self.W = [torch.normal(torch.zeros((784, 16)), torch.tensor([2])).to(device),
                  torch.normal(torch.zeros((16, 16)), torch.tensor([2])).to(device),
                  torch.normal(torch.zeros((16, 10)), torch.tensor([2])).to(device)]
        self.B = [torch.normal(torch.zeros((16)), torch.tensor(2)).to(device),
                  torch.normal(torch.zeros((16)), torch.tensor(2)).to(device),
                  torch.normal(torch.zeros((10)), torch.tensor(2)).to(device)]
    
    def forward(self, X):
        L1 = sigmoid(torch.matmul(X, self.W[0]) + self.B[0])
        L2 = sigmoid(torch.matmul(L1, self.W[1]) + self.B[1])
        Y = softmax(torch.matmul(L2, self.W[2]) + self.B[2])
        return Y

    def loss(self, X, Y):
        return torch.sum(torch.pow(self.forward(X) - Y, 2)) # / 원소 갯수(어케 구함?)
    
    def backward(self, X, Y):
        dW = [0, 0, 0]
        dB = [0, 0, 0]
        L1 = torch.matmul(X, self.W[0]) + self.B[0]
        L2 = torch.matmul(sigmoid(L1), self.W[1]) + self.B[1]
        FOR = torch.matmul(sigmoid(L2), self.W[2]) + self.B[2]

        dB[2] = torch.mul(dsoftmax(FOR), 2 * (softmax(FOR) - Y))
        dW[2] = torch.matmul(torch.transpose(torch.unsqueeze(sigmoid(L2), 0), 0, 1), torch.unsqueeze(dB[2], 0))

        dB[1] = torch.mul(torch.matmul(dB[2], torch.transpose(self.W[2], 0, 1)), dsigmoid(L2))
        dW[1] = torch.matmul(torch.transpose(torch.unsqueeze(sigmoid(L1), 0), 0, 1), torch.unsqueeze(dB[1], 0))

        dB[0] = torch.mul(torch.matmul(dB[1], torch.transpose(self.W[1], 0, 1)), dsigmoid(L1))
        dW[0] = torch.matmul(torch.transpose(torch.unsqueeze(X, 0), 0, 1), torch.unsqueeze(dB[0], 0))


        self.W[2] = self.W[2] + torch.mul(dW[2], -1 * lr)
        self.B[2] = self.B[2] + torch.mul(dB[2], -1 * lr)

        self.W[1] = self.W[1] + torch.mul(dW[1], -1 * lr)
        self.B[1] = self.B[1] + torch.mul(dB[1], -1 * lr)

        self.W[0] = self.W[0] + torch.mul(dW[0], -1 * lr)
        self.B[0] = self.B[0] + torch.mul(dB[0], -1 * lr)

In [7]:
mnist_train = dsets.MNIST(root='MNIST_data/', # 다운로드 경로 지정
                          train=True, # True를 지정하면 훈련 데이터로 다운로드
                          transform=transforms.ToTensor(), # 텐서로 변환
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/', # 다운로드 경로 지정
                         train=False, # False를 지정하면 테스트 데이터로 다운로드
                         transform=transforms.ToTensor(), # 텐서로 변환
                         download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [8]:
def one_hot_encoding(x):
    arg = [0] * 10
    arg[x] = 1
    return arg

In [9]:
mnist_test_X = []
for i in range(10000):
    mnist_test_X.append(mnist_test[i][0].reshape(1, 784))

mnist_test_X = torch.cat(mnist_test_X, dim = 0)
mnist_test_X = mnist_test_X.to(device)

In [10]:
mnist_train_X = []
for i in range(60000):
    mnist_train_X.append(mnist_train[i][0].reshape(1, 784))

mnist_train_X = torch.cat(mnist_train_X, dim = 0)
mnist_train_X = mnist_train_X.to(device)

In [11]:
mnist_test_Y = []
for i in range(10000):
    mnist_test_Y.append(one_hot_encoding(mnist_test[i][1]))
mnist_test_Y = torch.tensor(mnist_test_Y).to(device)

In [12]:
mnist_train_Y = []
for i in range(60000):
    mnist_train_Y.append(one_hot_encoding(mnist_train[i][1]))
mnist_train_Y = torch.tensor(mnist_train_Y).to(device)

In [13]:
M = model()

In [None]:
M.W[0] = torch.load('/content/drive/MyDrive/mnist_train/88.3W0')
M.W[1] = torch.load('/content/drive/MyDrive/mnist_train/88.3W1')
M.W[2] = torch.load('/content/drive/MyDrive/mnist_train/88.3W2')

M.B[0] = torch.load('/content/drive/MyDrive/mnist_train/88.3B0')
M.B[1] = torch.load('/content/drive/MyDrive/mnist_train/88.3B1')
M.B[2] = torch.load('/content/drive/MyDrive/mnist_train/88.3B2')

In [14]:
for tr in range(100):
    loss_sum = torch.tensor([0], dtype=torch.float).to(device)
    for i in range(60000):
        #print(tr, i, M.loss(mnist_train_X[i], mnist_train_Y[i]).to(torch.device('cpu')), end = '\t')
        M.backward(mnist_train_X[i], mnist_train_Y[i])
        loss_sum += M.loss(mnist_train_X[i], mnist_train_Y[i])
        #print(M.loss(mnist_train_X[i], mnist_train_Y[i]).to(torch.device('cpu')))
    print(tr, loss_sum)

0 tensor([70694.4062])
1 tensor([54909.1836])
2 tensor([53013.5977])
3 tensor([52208.5820])
4 tensor([51711.8828])
5 tensor([51315.1016])
6 tensor([50961.8086])
7 tensor([50633.8867])
8 tensor([50324.5625])
9 tensor([50018.0938])
10 tensor([49699.2461])
11 tensor([49353.9531])
12 tensor([48943.9922])
13 tensor([48407.7422])
14 tensor([47785.3594])
15 tensor([47158.8711])
16 tensor([46549.4961])
17 tensor([45947.6289])
18 tensor([45344.4062])
19 tensor([44733.5977])
20 tensor([44106.2500])
21 tensor([43438.1172])
22 tensor([42771.6914])
23 tensor([42145.9219])
24 tensor([41565.3555])
25 tensor([41032.4883])
26 tensor([40537.5625])
27 tensor([40076.])
28 tensor([39635.7578])
29 tensor([39216.1406])
30 tensor([38814.8398])
31 tensor([38432.9492])
32 tensor([38063.1836])
33 tensor([37703.3945])
34 tensor([37353.4844])
35 tensor([37013.5820])
36 tensor([36682.1758])
37 tensor([36356.3633])
38 tensor([36035.4727])
39 tensor([35717.9219])
40 tensor([35409.5820])
41 tensor([35109.5547])
42 ten

In [None]:
s_hit = 0
for i in range(10000):
    ans_y = torch.argmax(mnist_test_Y[i])
    exp_y = torch.argmax(M.forward(mnist_test_X[i]))
    print(ans_y, exp_y)
    if ans_y == exp_y:
        s_hit += 1

print(s_hit / 100)

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
tensor(2) tensor(2)
tensor(8) tensor(8)
tensor(7) tensor(7)
tensor(4) tensor(4)
tensor(4) tensor(4)
tensor(0) tensor(0)
tensor(9) tensor(9)
tensor(3) tensor(3)
tensor(9) tensor(7)
tensor(5) tensor(5)
tensor(2) tensor(7)
tensor(6) tensor(6)
tensor(9) tensor(9)
tensor(1) tensor(1)
tensor(8) tensor(6)
tensor(6) tensor(6)
tensor(3) tensor(3)
tensor(6) tensor(6)
tensor(6) tensor(2)
tensor(8) tensor(8)
tensor(7) tensor(7)
tensor(8) tensor(8)
tensor(0) tensor(0)
tensor(4) tensor(4)
tensor(4) tensor(4)
tensor(8) tensor(8)
tensor(0) tensor(0)
tensor(3) tensor(5)
tensor(0) tensor(0)
tensor(0) tensor(0)
tensor(1) tensor(1)
tensor(2) tensor(2)
tensor(2) tensor(2)
tensor(7) tensor(9)
tensor(3) tensor(3)
tensor(1) tensor(1)
tensor(4) tensor(4)
tensor(3) tensor(3)
tensor(6) tensor(6)
tensor(2) tensor(2)
tensor(7) tensor(7)
tensor(5) tensor(2)
tensor(8) tensor(8)
tensor(2) tensor(3)
tensor(0) tensor(0)
tensor(1) tensor(1)
tensor(1) tensor(1)
tensor(0) 

In [None]:
while True:
    for tr in range(100):
        for i in range(60000):
            #print(tr, i, M.loss(mnist_train_X[i], mnist_train_Y[i]).to(torch.device('cpu')), end = '\t')
            M.backward(mnist_train_X[i], mnist_train_Y[i])
            #print(M.loss(mnist_train_X[i], mnist_train_Y[i]).to(torch.device('cpu')))

    s_hit = 0
    for i in range(10000):
        ans_y = torch.argmax(mnist_test_Y[i])
        exp_y = torch.argmax(M.forward(mnist_test_X[i]))
        if ans_y == exp_y:
            s_hit += 1

    print(s_hit / 100)
    
    torch.save(M.W[0], '/content/drive/MyDrive/mnist_train/'+str(s_hit / 100)+'W0.pt')
    torch.save(M.W[1], '/content/drive/MyDrive/mnist_train/'+str(s_hit / 100)+'W1.pt')
    torch.save(M.W[2], '/content/drive/MyDrive/mnist_train/'+str(s_hit / 100)+'W2.pt')

    torch.save(M.B[0], '/content/drive/MyDrive/mnist_train/'+str(s_hit / 100)+'B0.pt')
    torch.save(M.B[1], '/content/drive/MyDrive/mnist_train/'+str(s_hit / 100)+'B1.pt')
    torch.save(M.B[2], '/content/drive/MyDrive/mnist_train/'+str(s_hit / 100)+'B2.pt')

88.83
88.99
89.21
89.3
89.24
89.28
89.28
89.16


KeyboardInterrupt: ignored

In [None]:
torch.save(torch.tensor([0]), '/content/drive/MyDrive/mnist_train/'+'test.pt')