In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [None]:
# MNIST train / test splits stored under ./data; every image goes through
# transforms.ToTensor(). Only the training split requests a download.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

# Mini-batches of 100 samples; only the training stream is shuffled.
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

In [None]:
# A notebook cell only renders its final expression, so the input and target
# shapes are gathered into a single value instead of two separate statements
# (the first of which would be silently discarded).
{
    'inputs (train, test)': (train_dataset.data.shape, test_dataset.data.shape),
    'targets (train, test)': (train_dataset.targets.shape, test_dataset.targets.shape),
}

In [None]:
import matplotlib.pyplot as plt

In [None]:
plot_size = 4  # number of images drawn per displayed row

# Preview the first 10 training digits, titled with their labels.
for it_sam, _sample in enumerate(train_dataset.data[:10]):
    cur_idx = it_sam % plot_size + 1  # 1-based slot inside the current row
    plt.subplot(1, plot_size, cur_idx)

    _sample_target = train_dataset.targets[it_sam]
    plt.title(_sample_target.item())
    plt.imshow(_sample, 'gray')

    # The row is full: render it so the next sample starts a fresh figure.
    if cur_idx == plot_size:
        plt.show()

# 10 samples do not fill a whole number of rows of 4, so the last row is only
# partially filled. Show it explicitly rather than relying on the notebook
# backend to flush open figures when the cell finishes.
if cur_idx != plot_size:
    plt.show()

# Model definition

In [None]:
class Mnist_fcn(nn.Module):
    """Fully connected network: model_input -> 64 -> 32 -> model_output.

    Args:
        model_input: number of input features (28 x 28 = 784 for a flattened
            MNIST image).
        model_output: number of output units produced by the last layer.
    """

    def __init__(self, model_input, model_output):
        super().__init__()

        # Linear layers (each holds a learnable weight and bias).
        self.fc = nn.Linear(in_features=model_input, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=model_output)

        # One shared in-place ReLU is applied after each hidden layer.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the raw output of the last linear layer for input ``x``."""
        hidden = self.relu(self.fc(x))
        hidden = self.relu(self.fc2(hidden))

        # No softmax is applied here (original note: with a cross-entropy loss
        # the softmax is omitted), so the un-normalised outputs are returned.
        return self.fc3(hidden)

In [None]:
# One output unit: the network predicts the digit as a single scalar, which is
# what the MSE-based training in this notebook compares against the label.
# (A cross-entropy setup would presumably use one output per class instead.)
model = Mnist_fcn(model_input=28 * 28, model_output=1)

In [None]:
# Print the module tree (layer names and their configuration) as a sanity check.
print(model)

In [None]:
from torchsummary import summary

# The model has not been moved off the CPU yet, so ask torchsummary to build
# its dummy input on the CPU as well. Its default device is CUDA, which would
# clash with the CPU-resident weights on a machine where CUDA is available.
summary(model, input_size=(1, 28 * 28), device='cpu')

# Model training

In [None]:
# Optimisation setup: Adam with a step size of 0.01.
learning_rate = 0.01
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Mean-squared error between the scalar output and the digit label.
# Alternative kept from the original for reference (not active):
# criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()

In [None]:
# Pick the compute backend at run time instead of hard-coding Apple's MPS,
# so the notebook also runs on CUDA or CPU-only machines. MPS is checked
# first so behaviour on Apple Silicon is unchanged. The getattr guard covers
# PyTorch builds that predate the torch.backends.mps module.
_mps_backend = getattr(torch.backends, 'mps', None)
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
epochs = 10

model = model.to(device)


def _rounded_accuracy(net, loader, target_device):
    """Count predictions that equal the label after rounding.

    The network emits one scalar per sample; a prediction counts as correct
    when ``torch.round(output)`` equals the label.

    Args:
        net: the model to evaluate.
        loader: iterable yielding ``(images, labels)`` batches.
        target_device: device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(correct, total)`` of Python ints.
    """
    was_training = net.training
    net.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            # Flatten using the real batch size so a different batch_size or
            # a smaller final batch does not break the reshape.
            batch_images = batch_images.reshape(batch_images.shape[0], -1).to(target_device)
            batch_labels = batch_labels.reshape(batch_labels.shape[0], -1)
            batch_labels = batch_labels.type(torch.float32).to(target_device)

            batch_outputs = net(batch_images)
            total += batch_outputs.shape[0]
            correct += torch.sum(torch.round(batch_outputs) == batch_labels).item()

    # Put the model back into the mode it was in before evaluation.
    net.train(was_training)
    return correct, total


for _epoch in range(epochs):
    for it_batch, (images, labels) in enumerate(train_loader):

        # Flatten each image to a vector; labels become a float column so
        # they match the (batch, 1) model output expected by the loss.
        images = images.reshape(images.shape[0], -1).to(device)
        labels = labels.reshape(labels.shape[0], -1).type(torch.float32).to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        # Every 100 training batches: log the loss and score the test set.
        # The evaluation lives in a helper with its own loop variables, so it
        # no longer overwrites it_batch / images / labels of the training loop.
        if (it_batch+1) % 100 == 0 :
            print(f'epoch{_epoch}, batch{it_batch}, loss{loss.item()}' )

            correct_, total_ = _rounded_accuracy(model, test_loader, device)

            acc = correct_/total_ * 100
            print(f'acc:{acc:.2f}, correct:{correct_}, total:{total_}')
        

# Accuracy evaluation

In [None]:
# Bring the model's parameters back to the CPU for the cells that follow.
model = model.cpu()

In [None]:
# Final test-set accuracy. A prediction counts as correct when the rounded
# scalar output equals the label.

# Evaluate on whatever device the model currently lives on, so this cell does
# not silently depend on the model having been moved to the CPU beforehand.
eval_device = next(model.parameters()).device

model.eval()

correct_ = 0
total_ = 0

with torch.no_grad():
    for images, labels in test_loader:
        # Use the actual batch size rather than a hard-coded 100 so other
        # batch sizes (or a smaller last batch) still work.
        images = images.reshape(images.shape[0], -1).to(eval_device)
        labels = labels.reshape(labels.shape[0], -1).to(eval_device)

        outputs = model(images)
        total_ += outputs.shape[0]
        correct_ += torch.sum(torch.round(outputs) == labels).item()

# Note: an argmax over the outputs would replace the rounding step if the
# model were switched to one output per class.
acc = correct_/total_ * 100
print(f'acc{acc:.2f}, correct{correct_}, total{total_}')

# Prediction visualization

In [None]:
def display_torch_ret(inputs, targets, predicts, plot_size=4):
    """Plot samples in rows of ``plot_size`` with ground truth and prediction.

    Each subplot is titled ``gt:<target>, p:<prediction>``. A marker line is
    printed for every sample whose prediction differs from its target.

    Args:
        inputs: 4-D tensor that is permuted with ``(0, 2, 3, 1)`` before
            plotting (assumed to be channels-first images - TODO confirm).
        targets: per-sample ground-truth values; each entry must hold exactly
            one element (``.item()`` is called on it).
        predicts: per-sample predictions; each entry must hold exactly one
            element.
        plot_size: number of subplots per displayed row.

    Returns:
        None. Figures are rendered through ``plt.show()``.
    """
    # Detach and move to the CPU so tensors that still carry gradients or
    # live on an accelerator can be handed to matplotlib.
    inputs = inputs.detach().cpu().permute(0, 2, 3, 1)

    # Initialised to "row complete" so the trailing flush below is a no-op
    # when there are no samples at all.
    cur_idx = plot_size

    for it_sam, _sample in enumerate(inputs):
        cur_idx = it_sam % plot_size + 1  # 1-based slot inside the current row
        plt.subplot(1, plot_size, cur_idx)

        # Plain Python scalars: comparison and title work for any device/shape
        # as long as each entry is a single element.
        _sample_target = targets[it_sam].item()
        _sample_predict = predicts[it_sam].item()

        if _sample_predict != _sample_target:
            print('################# incorrect ##############')
        _title = f'gt:{_sample_target}, p:{_sample_predict}'
        plt.title(_title)
        plt.imshow(_sample, 'gray')

        # The row is full: render it so the next sample starts a new figure.
        if cur_idx == plot_size:
            plt.show()

    # Render a trailing, partially filled row instead of leaving it pending.
    if cur_idx != plot_size:
        plt.show()

In [None]:
# Explicit iterator over the test loader (created with shuffle=False).
iter_test = iter(test_loader)

# Take the first batch from that iterator for a qualitative look at predictions.
images, labels = next(iter_test)

In [None]:
# Pure inference: skip autograd bookkeeping, and flatten with the batch's real
# size rather than a hard-coded 100 so any batch size works.
with torch.no_grad():
    outputs = model(images.reshape(images.shape[0], -1))

In [None]:
# Number of samples in this batch whose rounded output equals the label
# (labels are reshaped to a column to line up with the model output).
torch.eq(outputs.round(), labels.reshape(-1, 1)).sum()

In [None]:

# The scalar output is rounded to obtain the predicted value.
# (Alternative kept from the original, not active: torch.argmax(outputs, axis=1).)
predicts = outputs.round()

In [None]:
# Shape of the sampled batch; display_torch_ret permutes it with (0, 2, 3, 1), so four dimensions are expected.
images.shape

In [None]:
# Show the sampled batch six images per row, with ground truth and prediction.
display_torch_ret(inputs=images, targets=labels, predicts=predicts, plot_size=6)