In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# import matplotlib.pyplot as plt

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

cpu


# Hyperparameters

In [None]:
# Hyperparameters for data loading, the model and the optimizer.
# Seed PyTorch's global RNG first so weight initialisation, DataLoader shuffling
# and dropout masks repeat across Restart & Run All (GPU kernels may still add
# some non-determinism).
seed = 42
torch.manual_seed(seed)

input_size = 28 * 28   # flattened 28x28 image; NOTE(review): NeuralNet does not use this
hidden_size = 100      # width of the fully connected hidden layer
num_classes = 10       # number of output classes (MNIST digits)
epochs = 2
batch_size = 64
learning_rate = 0.01   # Adam step size

# MNIST

In [None]:
# MNIST train/test splits, converted to tensors by ToTensor.
# download=True on both splits so each constructor works on its own instead of
# relying on the train-split call having fetched the files first; torchvision
# does not re-download when the dataset is already present under root.
to_tensor = transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(root='../data', train=True,
                                           transform=to_tensor, download=True)
test_dataset = torchvision.datasets.MNIST(root='../data', train=False,
                                          transform=to_tensor, download=True)

In [None]:
# Batch both splits; only the training split is shuffled, so the test split is
# always visited in the same order.
loader_kwargs = dict(batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Peek at one training batch to sanity-check shapes
# (the recorded output shows images [64, 1, 28, 28] and labels [64]).
examples = iter(train_loader)
samples, labels = next(examples)  # public iterator protocol, not the private _next_data()
print(samples.shape, labels.shape)

torch.Size([64, 1, 28, 28]) torch.Size([64])


In [None]:
# for i in range(6):
#     plt.subplot(2,3,i+1)
#     plt.imshow(samples[i][0],cmap='gray')
# plt.show()

# Model

In [None]:
class NeuralNet(nn.Module):
    """Small CNN classifier for square images (MNIST by default).

    Architecture: Conv2d(16 filters, 5x5, stride 1, padding 2) -> ReLU ->
    MaxPool2d(2) -> flatten -> Linear(hidden_size) -> ReLU -> Dropout(0.5) ->
    Linear(num_classes).

    Args:
        hidden_size: width of the fully connected hidden layer.
        num_classes: number of output logits.
        in_channels: channels of the input images (default 1, grayscale).
        image_size: height/width of the square input images (default 28).
            The first linear layer is sized from this, so it must match the data.

    ``forward`` returns raw, unnormalised logits of shape (batch, num_classes).
    """

    def __init__(self, hidden_size, num_classes, in_channels=1, image_size=28):
        super().__init__()
        conv_channels = 16
        self.cnn = nn.Conv2d(
            in_channels=in_channels, out_channels=conv_channels,
            kernel_size=5, stride=1, padding=2,
        )
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # A 5x5 kernel with stride 1 and padding 2 keeps the spatial size; the
        # 2x2/stride-2 pool halves it (rounding down).
        pooled = image_size // 2
        self.fc1 = nn.Linear(conv_channels * pooled * pooled, hidden_size)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, image_size, image_size) tensor to (batch, num_classes) logits."""
        out = self.pool(self.relu(self.cnn(x)))
        out = torch.flatten(out, 1)  # (batch, C, H, W) -> (batch, C*H*W)
        out = self.relu(self.fc1(out))
        out = self.dropout(out)
        return self.fc2(out)


In [None]:
# Put the model's weights on `device`: the training and evaluation loops move
# every batch there, and the weights must live on the same device as the inputs.
model = NeuralNet(hidden_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train for `epochs` passes over the training set, printing the current batch
# loss every `log_interval` steps.
log_interval = 64  # steps between progress prints (independent of batch_size)
n_total_steps = len(train_loader)
# Make sure dropout is active, in case the model was switched to eval mode
# (e.g. by an evaluation cell) before this cell is re-run.
model.train()
for epoch in range(epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % log_interval == 0:
            print(f'epoch {epoch+1}/{epochs},step {i+1}/{n_total_steps},loss={loss.item():.4f}')
        

epoch 1/2,step 64/938,loss=0.4350
epoch 1/2,step 128/938,loss=0.2257
epoch 1/2,step 192/938,loss=0.1110
epoch 1/2,step 256/938,loss=0.1116
epoch 1/2,step 320/938,loss=0.0909
epoch 1/2,step 384/938,loss=0.1317
epoch 1/2,step 448/938,loss=0.2119
epoch 1/2,step 512/938,loss=0.2762
epoch 1/2,step 576/938,loss=0.1751
epoch 1/2,step 640/938,loss=0.0461
epoch 1/2,step 704/938,loss=0.0939
epoch 1/2,step 768/938,loss=0.1976
epoch 1/2,step 832/938,loss=0.1375
epoch 1/2,step 896/938,loss=0.1476
epoch 2/2,step 64/938,loss=0.0534
epoch 2/2,step 128/938,loss=0.0338
epoch 2/2,step 192/938,loss=0.1704
epoch 2/2,step 256/938,loss=0.0867
epoch 2/2,step 320/938,loss=0.0456
epoch 2/2,step 384/938,loss=0.1649
epoch 2/2,step 448/938,loss=0.0759
epoch 2/2,step 512/938,loss=0.0758
epoch 2/2,step 576/938,loss=0.1092
epoch 2/2,step 640/938,loss=0.1272
epoch 2/2,step 704/938,loss=0.2066
epoch 2/2,step 768/938,loss=0.1963
epoch 2/2,step 832/938,loss=0.0627
epoch 2/2,step 896/938,loss=0.2198


# Test

In [None]:
# Measure classification accuracy (in percent) on the test split.
# eval() disables dropout so predictions are deterministic; the model's previous
# mode is restored afterwards so re-running the training cell still uses dropout.
was_training = model.training
model.eval()
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)

        # torch.max over dim 1 returns (values, indices); the index of the
        # largest logit is the predicted class.
        _, prediction = torch.max(outputs, 1)
        n_samples += labels.shape[0]
        n_correct += (prediction == labels).sum().item()

    acc = 100 * n_correct / n_samples
    print(f'accuracy={acc}')
model.train(was_training)

accuracy=96.03
