## Imports

In [1]:
import torch

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim

## Create Fully Connected Network

In [3]:
class NN(nn.Module):
    def __init__(self, input_size, num_classes): ## 28*28 image size
        super(NN,self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, num_classes)
        
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
        

In [4]:
modelx = NN(784,10)
x = torch.rand(64,784)
print(modelx(x).shape)


torch.Size([64, 10])


## set device

In [5]:
device = torch.device('mps')
# device = ('mps' if torch.mps.is_available() else "cpu")

## Hyperparameters

In [6]:
input_size = 784
num_classes = 10
lr = 0.001
epochs =2
batch_size = 64

In [7]:
train_dataset = datasets.MNIST(root = 'datasets/', train = True, transform = transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle=True)

In [8]:
test_dataset = datasets.MNIST(root = 'datasets/', train = False, transform = transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset = test_dataset, batch_size=batch_size, shuffle=True)

## Initialize Network

In [9]:
model = NN(input_size, num_classes = num_classes).to(device)

In [10]:
model.parameters

<bound method Module.parameters of NN(
  (fc1): Linear(in_features=784, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)>

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

## Train Network

In [12]:
for epoch in range(epochs):
    for idx ,(data,target) in enumerate(train_loader):
        data = data.to(device=device)
        target = target.to(device=device)
        
    print(data.shape)

torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])


In [13]:
data.reshape(data.shape[0],-1)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')

In [14]:
for  epoch in range(epochs):
    for idx,(data,target) in enumerate(train_loader):
        data = data.to(device = device)
        target = target.to(device=device)
        
        data = data.reshape(data.shape[0],-1)
        scores =model(data)
        loss= criterion(scores,target)
        
        optimizer.zero_grad()
        loss.backward()
        
        optimizer.step()

In [15]:

print((model.fc1.weight).shape)

torch.Size([50, 784])


In [23]:
def check_accuracy(loader,model):
    num_correct = 0
    num_samples = 0
    
    model.eval()
    
    with torch.no_grad():
        for x,y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            x = x.reshape(x.shape[0],-1)
            scores = model(x)
            _,predictions = scores.max(1) 
            num_correct = num_correct +  (predictions == y).sum()
            num_samples = num_samples + predictions.size(0)
            
            print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples) * 100:.2f}")
    

In [26]:
(scores.shape)

torch.Size([32, 10])

In [24]:
check_accuracy(train_loader, model)

Got 62/64 with accuracy 96.88
Got 121/128 with accuracy 94.53
Got 184/192 with accuracy 95.83
Got 245/256 with accuracy 95.70
Got 307/320 with accuracy 95.94
Got 369/384 with accuracy 96.09
Got 428/448 with accuracy 95.54
Got 488/512 with accuracy 95.31
Got 548/576 with accuracy 95.14
Got 611/640 with accuracy 95.47
Got 670/704 with accuracy 95.17
Got 733/768 with accuracy 95.44
Got 795/832 with accuracy 95.55
Got 857/896 with accuracy 95.65
Got 917/960 with accuracy 95.52
Got 978/1024 with accuracy 95.51
Got 1040/1088 with accuracy 95.59
Got 1098/1152 with accuracy 95.31
Got 1160/1216 with accuracy 95.39
Got 1221/1280 with accuracy 95.39
Got 1282/1344 with accuracy 95.39
Got 1345/1408 with accuracy 95.53
Got 1406/1472 with accuracy 95.52
Got 1466/1536 with accuracy 95.44
Got 1528/1600 with accuracy 95.50
Got 1591/1664 with accuracy 95.61
Got 1654/1728 with accuracy 95.72
Got 1711/1792 with accuracy 95.48
Got 1772/1856 with accuracy 95.47
Got 1836/1920 with accuracy 95.62
Got 1896/1984