In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

In [4]:
#create fully connected NN
class NN(nn.Module):
    """Fully connected classifier with a single hidden layer.

    The network is ``Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, num_classes)``. ``forward`` returns the output of the
    second linear layer as-is (no softmax is applied).

    Args:
        input_size: Number of features in each input row.
        num_classes: Number of scores produced per input row.
        hidden_size: Width of the hidden layer. Defaults to 50, the value that
            used to be hard-coded.
    """

    def __init__(self,input_size,num_classes,hidden_size=50):
        super().__init__()
        self.fc1=nn.Linear(input_size,hidden_size)
        self.fc2=nn.Linear(hidden_size,num_classes)

    def forward(self,x):
        """Return per-class scores for ``x``.

        ``x`` is passed straight to ``fc1``, so its last dimension must equal
        ``input_size``; callers are responsible for flattening beforehand.
        """
        x=F.relu(self.fc1(x))
        return self.fc2(x)
    
#set device: prefer CUDA, then Apple MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device='cuda'
elif torch.backends.mps.is_available():
    device='mps'
else:
    device='cpu'
# device='cpu'   #uncomment to force CPU execution

#hyperparameters
input_size=784   #features per flattened image fed to the first linear layer
num_classes=10
learning_rate=0.001
batch_size=64
num_epoch=2
seed=0

#seed the global RNG so weight initialisation and the shuffle order of the
#training loader start from the same state on every run
torch.manual_seed(seed)

#load data (download=True fetches MNIST into datasets/ when it is missing)
train_dataset=datasets.MNIST(root='datasets/',train=True,transform=transforms.ToTensor(),download=True)
train_loader=DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)

test_dataset=datasets.MNIST(root='datasets/',train=False,transform=transforms.ToTensor(),download=True)
#accuracy does not depend on sample order, so the test split is not shuffled
test_loader=DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)


#init network
model=NN(input_size=input_size,num_classes=num_classes).to(device)

#loss and optimizer
loss_fn=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=learning_rate)


#train network
for epoch in range(num_epoch):
    for batch_idx, (data,targets) in enumerate(train_loader):
        #move the batch to the device the model lives on
        data=data.to(device)
        targets=targets.to(device)

        #flatten every sample to one row: (batch, ...) -> (batch, features)
        data=data.reshape(data.shape[0],-1)

        #forward
        scores=model(data)
        loss=loss_fn(scores,targets)

        #backward: clear gradients left from the previous step, then backpropagate
        optimizer.zero_grad()
        loss.backward()

        #Adam update step
        optimizer.step()

#check accuracy on both train and test set
def check_accuracy(loader,model):
    """Print the classification accuracy of ``model`` over ``loader``.

    Each batch is moved to the device holding the model's parameters,
    flattened to ``(batch, features)`` and scored; the predicted class of a
    row is the index of its highest score. The accuracy is printed as a
    percentage rounded to two decimals (``nan`` when the loader yields no
    samples).

    Args:
        loader: Iterable of ``(inputs, labels)`` batches. If it exposes a
            ``dataset`` with a ``train`` attribute, that flag selects the
            train/test heading; otherwise a neutral heading is printed.
        model: ``nn.Module`` mapping flattened inputs to per-class scores.

    Returns:
        None. The model is put in eval mode while scoring and is always left
        in training mode afterwards, even if scoring raises.
    """
    #not every dataset carries a `train` flag, so look it up defensively
    is_train=getattr(getattr(loader,"dataset",None),"train",None)
    if is_train is None:
        print("checking accuracy on dataset")
    elif is_train:
        print("checking accuracy on train dataset")
    else:
        print("checking accuravy on test dataset")

    #score on the device the model actually lives on instead of relying on a
    #module-level global; a parameter-free model is evaluated on the CPU
    first_param=next(model.parameters(),None)
    model_device=first_param.device if first_param is not None else 'cpu'

    num_correct=0
    num_samples=0
    model.eval()

    try:
        with torch.no_grad():
            for x,y in loader:
                x=x.to(model_device)
                y=y.to(model_device)
                x=x.reshape(x.shape[0],-1)

                scores=model(x)
                _,pred=scores.max(1)
                num_correct+=(pred==y).sum()
                num_samples+=pred.size(0)
    finally:
        #hand the model back in training mode, as callers expect
        model.train()

    if num_samples==0:
        #nothing was scored; report nan rather than dividing by zero
        accuracy=float("nan")
    else:
        accuracy=float(num_correct)/float(num_samples)*100
    print("Accuracy is :",round(accuracy,2))
    return

#report accuracy on the training split first, then on the test split
check_accuracy(loader=train_loader,model=model)
check_accuracy(loader=test_loader,model=model)

checking accuracy on train dataset
Accuracy is : 95.34
checking accuravy on test dataset
Accuracy is : 94.76
Execution time: 1.8723421096801758 seconds
