<a href="https://colab.research.google.com/github/Jaeyoung-Choi/Hardware/blob/master/mnist_backpropagation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
import torch
from torchvision import datasets as dsets
import torchvision.transforms as transforms

In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(42)
if device == 'cuda':
    torch.cuda.manual_seed_all(42)
    print(torch.cuda.current_device(), torch.cuda.device_count(), torch.cuda.get_device_name(0))

In [17]:
def sigmoid(X):
    return 1 / (1 + torch.exp(-X))

def softmax(X):
    return torch.exp(X) / torch.unsqueeze(torch.sum(torch.exp(X), dim = 1), 1)

In [18]:
def dsigmoid(X):
    # return torch.exp(-X) / torch.pow(1 + torch.exp(-X), 2)
    return torch.mul(sigmoid(X), 1 - sigmoid(X))

def dsoftmax(X):
    #return torch.mul((torch.sum(torch.exp(X), dim = 1) - torch.exp(X)) / torch.pow(torch.sum(torch.exp(X), dim = 1), 2), torch.exp(X))
    return softmax(X) - torch.pow(softmax(X), 2)

In [19]:
# lr = 0.001
lr = 0.00001

In [20]:
class model:
    def __init__(self):
        self.W = [torch.normal(torch.zeros((784, 16)), torch.tensor([2])).to(device),
                  torch.normal(torch.zeros((16, 16)), torch.tensor([2])).to(device),
                  torch.normal(torch.zeros((16, 10)), torch.tensor([2])).to(device)]
        self.B = [torch.normal(torch.zeros((16)), torch.tensor(2)).to(device),
                  torch.normal(torch.zeros((16)), torch.tensor(2)).to(device),
                  torch.normal(torch.zeros((10)), torch.tensor(2)).to(device)]
    
    def forward(self, X):
        L1 = sigmoid(torch.matmul(X, self.W[0]) + self.B[0])
        L2 = sigmoid(torch.matmul(L1, self.W[1]) + self.B[1])
        Y = softmax(torch.matmul(L2, self.W[2]) + self.B[2])
        return Y

    def loss(self, X, Y):
        return torch.sum(torch.pow(self.forward(X) - Y, 2)) # / 원소 갯수(어케 구함?)
    
    def backward(self, X, Y):
        dW = [0, 0, 0]
        dB = [0, 0, 0]
        L1 = torch.matmul(X, self.W[0]) + self.B[0]
        L2 = torch.matmul(sigmoid(L1), self.W[1]) + self.B[1]
        FOR = torch.matmul(sigmoid(L2), self.W[2]) + self.B[2]

        dB[2] = torch.mul(dsoftmax(FOR), 2 * (softmax(FOR) - Y))
        dW[2] = torch.matmul(torch.transpose(torch.unsqueeze(sigmoid(L2), 1), 1, 2), torch.unsqueeze(dB[2], 1))

        dB[1] = torch.mul(torch.matmul(dB[2], torch.transpose(self.W[2], 0, 1)), dsigmoid(L2))
        dW[1] = torch.matmul(torch.transpose(torch.unsqueeze(sigmoid(L1), 1), 1, 2), torch.unsqueeze(dB[1], 1))

        dB[0] = torch.mul(torch.matmul(dB[1], torch.transpose(self.W[1], 0, 1)), dsigmoid(L1))
        dW[0] = torch.matmul(torch.transpose(torch.unsqueeze(X, 1), 1, 2), torch.unsqueeze(dB[0], 1))


        self.W[2] = self.W[2] + torch.mul(torch.sum(dW[2], dim = 0), -1 * lr)
        self.B[2] = self.B[2] + torch.mul(torch.sum(dB[2], dim = 0), -1 * lr)

        self.W[1] = self.W[1] + torch.mul(torch.sum(dW[1], dim = 0), -1 * lr)
        self.B[1] = self.B[1] + torch.mul(torch.sum(dB[1], dim = 0), -1 * lr)

        self.W[0] = self.W[0] + torch.mul(torch.sum(dW[0], dim = 0), -1 * lr)
        self.B[0] = self.B[0] + torch.mul(torch.sum(dB[0], dim = 0), -1 * lr)

In [21]:
mnist_train = dsets.MNIST(root='MNIST_data/', # 다운로드 경로 지정
                          train=True, # True를 지정하면 훈련 데이터로 다운로드
                          transform=transforms.ToTensor(), # 텐서로 변환
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/', # 다운로드 경로 지정
                         train=False, # False를 지정하면 테스트 데이터로 다운로드
                         transform=transforms.ToTensor(), # 텐서로 변환
                         download=True)

In [22]:
def one_hot_encoding(x):
    arg = [0] * 10
    arg[x] = 1
    return arg

In [23]:
mnist_test_X = []
for i in range(10000):
    mnist_test_X.append(mnist_test[i][0].reshape(1, 784))

mnist_test_X = torch.cat(mnist_test_X, dim = 0)
mnist_test_X = mnist_test_X.to(device)

In [24]:
mnist_train_X = []
for i in range(60000):
    mnist_train_X.append(mnist_train[i][0].reshape(1, 784))

mnist_train_X = torch.cat(mnist_train_X, dim = 0)
mnist_train_X = mnist_train_X.to(device)

In [25]:
mnist_test_Y = []
for i in range(10000):
    mnist_test_Y.append(one_hot_encoding(mnist_test[i][1]))
mnist_test_Y = torch.tensor(mnist_test_Y).to(device)

In [26]:
mnist_train_Y = []
for i in range(60000):
    mnist_train_Y.append(one_hot_encoding(mnist_train[i][1]))
mnist_train_Y = torch.tensor(mnist_train_Y).to(device)

In [27]:
M = model()

In [28]:
def hit():
    return torch.sum(torch.eq(torch.argmax(mnist_test_Y, dim = 1), torch.argmax(M.forward(mnist_test_X), dim = 1)))

In [32]:
comp = torch.tensor([9000]).to(device)
# for tr in range(100):
tr = 1
max_hit = curr_hit = hit()
while True:
    # for i in range(60000):
    #     #print(tr, i, M.loss(mnist_train_X[i], mnist_train_Y[i]).to(torch.device('cpu')), end = '\t')
    #     M.backward(mnist_train_X[i], mnist_train_Y[i])
    #     loss_sum += M.loss(mnist_train_X[i], mnist_train_Y[i])
    #     #print(M.loss(mnist_train_X[i], mnist_train_Y[i]).to(torch.device('cpu')))
    M.backward(mnist_train_X, mnist_train_Y)
    loss_sum = M.loss(mnist_train_X, mnist_train_Y)
    print(tr, curr_hit, loss_sum)
    tr += 1
    curr_hit = hit()
    if torch.gt(curr_hit, max_hit + 10):
        torch.save(M.W[0], '/content/drive/MyDrive/mnist_train/'+str(curr_hit.tolist() / 100)+'W0.pt')
        torch.save(M.W[1], '/content/drive/MyDrive/mnist_train/'+str(curr_hit.tolist() / 100)+'W1.pt')
        torch.save(M.W[2], '/content/drive/MyDrive/mnist_train/'+str(curr_hit.tolist() / 100)+'W2.pt')

        torch.save(M.B[0], '/content/drive/MyDrive/mnist_train/'+str(curr_hit.tolist() / 100)+'B0.pt')
        torch.save(M.B[1], '/content/drive/MyDrive/mnist_train/'+str(curr_hit.tolist() / 100)+'B1.pt')
        torch.save(M.B[2], '/content/drive/MyDrive/mnist_train/'+str(curr_hit.tolist() / 100)+'B2.pt')
        max_hit = curr_hit

1 tensor(9059) tensor(6886.8560)
2 tensor(9059) tensor(6886.8569)
3 tensor(9059) tensor(6886.8457)
4 tensor(9059) tensor(6886.8335)
5 tensor(9059) tensor(6886.8086)
6 tensor(9059) tensor(6886.8081)
7 tensor(9059) tensor(6886.7910)
8 tensor(9059) tensor(6886.7847)
9 tensor(9059) tensor(6886.7637)
10 tensor(9059) tensor(6886.7856)
11 tensor(9059) tensor(6886.7646)
12 tensor(9059) tensor(6886.7559)
13 tensor(9059) tensor(6886.7446)
14 tensor(9059) tensor(6886.7236)
15 tensor(9059) tensor(6886.7026)
16 tensor(9059) tensor(6886.7070)
17 tensor(9059) tensor(6886.7017)
18 tensor(9059) tensor(6886.6738)
19 tensor(9059) tensor(6886.6499)
20 tensor(9059) tensor(6886.6567)
21 tensor(9059) tensor(6886.6685)
22 tensor(9059) tensor(6886.6665)
23 tensor(9059) tensor(6886.6592)
24 tensor(9059) tensor(6886.6440)
25 tensor(9059) tensor(6886.6338)
26 tensor(9059) tensor(6886.6436)
27 tensor(9059) tensor(6886.6484)
28 tensor(9059) tensor(6886.6313)
29 tensor(9059) tensor(6886.6270)
30 tensor(9059) tensor(

KeyboardInterrupt: ignored

In [None]:
torch.save(M.W[0], '/content/drive/MyDrive/mnist_train/'+str(hit().tolist() / 100)+'W0.pt')
torch.save(M.W[1], '/content/drive/MyDrive/mnist_train/'+str(hit().tolist() / 100)+'W1.pt')
torch.save(M.W[2], '/content/drive/MyDrive/mnist_train/'+str(hit().tolist() / 100)+'W2.pt')

torch.save(M.B[0], '/content/drive/MyDrive/mnist_train/'+str(hit().tolist() / 100)+'B0.pt')
torch.save(M.B[1], '/content/drive/MyDrive/mnist_train/'+str(hit().tolist() / 100)+'B1.pt')
torch.save(M.B[2], '/content/drive/MyDrive/mnist_train/'+str(hit().tolist() / 100)+'B2.pt')

In [None]:
torch.save(torch.tensor([0]), '/content/drive/MyDrive/mnist_train/'+'test.pt')

In [30]:
if torch.cuda.is_available():
    M.W[0] = torch.load('/content/drive/MyDrive/mnist_train/90.59W0.pt')
    M.W[1] = torch.load('/content/drive/MyDrive/mnist_train/90.59W1.pt')
    M.W[2] = torch.load('/content/drive/MyDrive/mnist_train/90.59W2.pt')
    
    M.B[0] = torch.load('/content/drive/MyDrive/mnist_train/90.59B0.pt'
    M.B[1] = torch.load('/content/drive/MyDrive/mnist_train/90.59B1.pt'
    M.B[2] = torch.load('/content/drive/MyDrive/mnist_train/90.59B2.pt'
else:
    M.W[0] = torch.load('/content/drive/MyDrive/mnist_train/90.59W0.pt', map_location=torch.device('cpu'))
    M.W[1] = torch.load('/content/drive/MyDrive/mnist_train/90.59W1.pt', map_location=torch.device('cpu'))
    M.W[2] = torch.load('/content/drive/MyDrive/mnist_train/90.59W2.pt', map_location=torch.device('cpu   '))

    M.B[0] = torch.load('/content/drive/MyDrive/mnist_train/90.59B0.pt', map_location=torch.device('cpu'))
    M.B[1] = torch.load('/content/drive/MyDrive/mnist_train/90.59B1.pt', map_location=torch.device('cpu'))
    M.B[2] = torch.load('/content/drive/MyDrive/mnist_train/90.59B2.pt', map_location=torch.device('cpu'))