In [1]:
import torch
import torchvision
from tqdm import tqdm

In [4]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0),(1))
])
dataset_train = torchvision.datasets.MNIST("./",train=True,download=True,transform=transform)
dataset_test = torchvision.datasets.MNIST("./",train=False,download=True,transform=transform)
train_data = torch.utils.data.DataLoader(dataset_train,batch_size=32,shuffle=True)
test_data = torch.utils.data.DataLoader(dataset_test,batch_size=32)

In [5]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fully_connect1 = torch.nn.Linear(784,512)
        self.fully_connect2 = torch.nn.Linear(512,256)
        self.fully_connect3 = torch.nn.Linear(256,10)
    
    def forward(self,x):
        x = x.reshape((x.shape[0],784))
        x = self.fully_connect1(x)
        x = torch.relu(x)
        x = torch.dropout(x,0.2,train=True)
        x = self.fully_connect2(x)
        x = torch.relu(x)
        x = torch.dropout(x,0.5,train=True)
        x = self.fully_connect3(x)
        x = torch.softmax(x,dim=1)

        return x

In [6]:
device = torch.device("cpu")
model = Model()
model = model.to(device)
model.train = True

In [7]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
loss_function = torch.nn.CrossEntropyLoss()

In [8]:
def accuracy(predict,label):
    _,predict_max = torch.max(predict,1)
    acc = torch.sum(predict_max==label,dtype=torch.float64) / len(predict)
    return acc

In [9]:
for epoch in range(20):
    train_loss = 0.0
    test_loss = 0.0
    train_acc = 0.0
    test_acc = 0.0
    for image,label in tqdm(train_data):
        image.to(device)
        label.to(device)
        optimizer.zero_grad()
        predict = model(image)
        loss = loss_function(predict,label)
        loss.backward()
        optimizer.step()
        
        train_loss +=loss
        train_acc += accuracy(predict,label)
        
    for image,label in test_data:
        image.to(device)
        label.to(device)
        predict = model(image)
        loss = loss_function(predict,label)
        
        test_loss += loss
        test_acc += accuracy(predict,label)
        
        
    train_total_loss = train_loss / len(train_data)
    train_total_acc = train_acc / len(train_data)
    test_total_loss = test_loss / len(test_data)
    test_total_acc = test_acc / len(test_data)
    
    print(f"Epoch: {epoch+1}, Accuracy: {train_total_acc}, Loss: {train_total_loss}")
    print(f"Epoch: {epoch+1}, Val_Accuracy: {test_total_acc}, Val_Loss: {test_total_loss}")

100%|██████████| 1875/1875 [00:20<00:00, 89.30it/s]


Epoch: 1, Accuracy: 0.8971, Loss: 1.571567416191101
Epoch: 1, Val_Accuracy: 0.9405950479233227, Val_Loss: 1.5227594375610352


100%|██████████| 1875/1875 [00:22<00:00, 83.35it/s]


Epoch: 2, Accuracy: 0.9473833333333334, Loss: 1.5148673057556152
Epoch: 2, Val_Accuracy: 0.9447883386581469, Val_Loss: 1.5163865089416504


100%|██████████| 1875/1875 [00:23<00:00, 80.98it/s]


Epoch: 3, Accuracy: 0.9544833333333334, Loss: 1.5067291259765625
Epoch: 3, Val_Accuracy: 0.9618610223642172, Val_Loss: 1.4993170499801636


100%|██████████| 1875/1875 [00:24<00:00, 76.93it/s]


Epoch: 4, Accuracy: 0.9599833333333333, Loss: 1.5014593601226807
Epoch: 4, Val_Accuracy: 0.9592651757188498, Val_Loss: 1.5012420415878296


100%|██████████| 1875/1875 [00:25<00:00, 72.20it/s]


Epoch: 5, Accuracy: 0.9611333333333333, Loss: 1.499843955039978
Epoch: 5, Val_Accuracy: 0.9564696485623003, Val_Loss: 1.5050262212753296


100%|██████████| 1875/1875 [00:27<00:00, 68.28it/s]


Epoch: 6, Accuracy: 0.9619833333333333, Loss: 1.4991505146026611
Epoch: 6, Val_Accuracy: 0.9622603833865815, Val_Loss: 1.4988242387771606


100%|██████████| 1875/1875 [00:28<00:00, 65.20it/s]


Epoch: 7, Accuracy: 0.9642833333333334, Loss: 1.496816873550415
Epoch: 7, Val_Accuracy: 0.9610623003194888, Val_Loss: 1.49973726272583


100%|██████████| 1875/1875 [00:29<00:00, 64.23it/s]


Epoch: 8, Accuracy: 0.9653833333333334, Loss: 1.4956958293914795
Epoch: 8, Val_Accuracy: 0.9562699680511182, Val_Loss: 1.5052043199539185


100%|██████████| 1875/1875 [00:30<00:00, 60.69it/s]


Epoch: 9, Accuracy: 0.9643166666666667, Loss: 1.496751070022583
Epoch: 9, Val_Accuracy: 0.9638578274760383, Val_Loss: 1.4971567392349243


100%|██████████| 1875/1875 [00:30<00:00, 62.03it/s]


Epoch: 10, Accuracy: 0.9661666666666666, Loss: 1.4950542449951172
Epoch: 10, Val_Accuracy: 0.9548722044728435, Val_Loss: 1.506162166595459


100%|██████████| 1875/1875 [00:31<00:00, 59.94it/s]


Epoch: 11, Accuracy: 0.9664, Loss: 1.4947634935379028
Epoch: 11, Val_Accuracy: 0.9599640575079872, Val_Loss: 1.5012686252593994


100%|██████████| 1875/1875 [00:30<00:00, 61.57it/s]


Epoch: 12, Accuracy: 0.9662833333333334, Loss: 1.4949889183044434
Epoch: 12, Val_Accuracy: 0.9625599041533547, Val_Loss: 1.4984110593795776


100%|██████████| 1875/1875 [00:30<00:00, 61.01it/s]


Epoch: 13, Accuracy: 0.9649166666666666, Loss: 1.4961552619934082
Epoch: 13, Val_Accuracy: 0.9600638977635783, Val_Loss: 1.5006816387176514


100%|██████████| 1875/1875 [00:29<00:00, 62.72it/s]


Epoch: 14, Accuracy: 0.9664166666666667, Loss: 1.4946033954620361
Epoch: 14, Val_Accuracy: 0.9646565495207667, Val_Loss: 1.4964838027954102


100%|██████████| 1875/1875 [00:30<00:00, 61.20it/s]


Epoch: 15, Accuracy: 0.9675333333333334, Loss: 1.49355947971344
Epoch: 15, Val_Accuracy: 0.9590654952076677, Val_Loss: 1.5018682479858398


100%|██████████| 1875/1875 [00:29<00:00, 64.12it/s]


Epoch: 16, Accuracy: 0.9660666666666666, Loss: 1.4951894283294678
Epoch: 16, Val_Accuracy: 0.9628594249201278, Val_Loss: 1.4978243112564087


100%|██████████| 1875/1875 [00:29<00:00, 62.64it/s]


Epoch: 17, Accuracy: 0.9676333333333333, Loss: 1.4934922456741333
Epoch: 17, Val_Accuracy: 0.9605630990415336, Val_Loss: 1.500481128692627


100%|██████████| 1875/1875 [00:29<00:00, 63.63it/s]


Epoch: 18, Accuracy: 0.9658833333333333, Loss: 1.4950751066207886
Epoch: 18, Val_Accuracy: 0.9569688498402555, Val_Loss: 1.5036283731460571


100%|██████████| 1875/1875 [00:29<00:00, 63.18it/s]


Epoch: 19, Accuracy: 0.96705, Loss: 1.4940341711044312
Epoch: 19, Val_Accuracy: 0.9618610223642172, Val_Loss: 1.4991676807403564


100%|██████████| 1875/1875 [00:29<00:00, 64.58it/s]


Epoch: 20, Accuracy: 0.9660333333333333, Loss: 1.4950788021087646
Epoch: 20, Val_Accuracy: 0.9645567092651757, Val_Loss: 1.4963780641555786


In [10]:
torch.save(model.state_dict(),"mnist.pth")