In [30]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from tqdm import tqdm

In [3]:
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 12896980.93it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 209362.66it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3821721.54it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 10010327.77it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw






In [33]:
def check_accuracy(y_pred,y):
    _,predpos=y_pred.max(1)
    num_samples=len(y)
    num_correct=(predpos==y).sum()
    return (num_correct/num_samples)*100

In [4]:
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

In [5]:
class BasicANN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_block = nn.Sequential(
            nn.Linear(28*28,512),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.ReLU(),
            nn.Linear(256,10))
        
    def forward(self,x):
        x = self.flatten(x)
        x = self.linear_block(x)
        return x

In [18]:
# Hyperparams
learning_rate = 0.001
batch_size = 64
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using",device)

Using cuda


In [11]:
model = BasicANN()
print(model)

BasicANN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_block): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
)


In [12]:
loss_fun = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(),learning_rate)

In [35]:
def train_loop(model,dataloader,loss_fun,optimizer):
    size = len(dataloader.dataset)
    model.train()
    model.to(device)
    for epoch in range(num_epochs):
        losses = []
        accuracies = []
        loop = tqdm(enumerate(dataloader), total=len(dataloader), leave=True)
        for batch,(x,y) in loop:
            # put on cuda
            x = x.to(device)
            y = y.to(device)
    
            # forward pass
            y_pred = model(x)
            
            # calculate loss & accuracy
            loss = loss_fun(y_pred,y)
            losses.append(loss.detach().item())
            
            accuracy = check_accuracy(y_pred.detach(),y)
            accuracies.append(accuracy.item())
            
            # zero out prior gradients
            optimizer.zero_grad()
            
            # # backprop
            loss.backward()
            
            # update weights
            optimizer.step()
            
            # Update TQDM progress bar
            loop.set_description(f"Epoch [{epoch}/{num_epochs}] ")
            loop.set_postfix(loss=loss.detach().item(), accuracy=accuracy.item())

In [36]:
train_loop(model,train_dataloader,loss_fun,optim)

Epoch [0/10] : 100%|██████████| 938/938 [00:10<00:00, 89.23it/s, accuracy=93.8, loss=0.188] 
Epoch [1/10] : 100%|██████████| 938/938 [00:10<00:00, 88.81it/s, accuracy=96.9, loss=0.155] 
Epoch [2/10] : 100%|██████████| 938/938 [00:10<00:00, 88.78it/s, accuracy=93.8, loss=0.164] 
Epoch [3/10] : 100%|██████████| 938/938 [00:10<00:00, 89.42it/s, accuracy=96.9, loss=0.127] 
Epoch [4/10] : 100%|██████████| 938/938 [00:10<00:00, 87.78it/s, accuracy=90.6, loss=0.209] 
Epoch [5/10] : 100%|██████████| 938/938 [00:10<00:00, 89.17it/s, accuracy=90.6, loss=0.129] 
Epoch [6/10] : 100%|██████████| 938/938 [00:10<00:00, 89.48it/s, accuracy=96.9, loss=0.122] 
Epoch [7/10] : 100%|██████████| 938/938 [00:10<00:00, 88.60it/s, accuracy=100, loss=0.106]  
Epoch [8/10] : 100%|██████████| 938/938 [00:10<00:00, 88.24it/s, accuracy=96.9, loss=0.081] 
Epoch [9/10] : 100%|██████████| 938/938 [00:10<00:00, 89.28it/s, accuracy=100, loss=0.085]  


In [44]:
def test_loop(model,dataloader,loss_fun):
    model.eval()
    model.to(device)
    losses = []
    samples,correct = 0,0
    loop = tqdm(enumerate(dataloader), total=len(dataloader), leave=True)
    with torch.no_grad():
        for batch,(x,y) in loop:
            # put on cuda
            x = x.to(device)
            y = y.to(device)

            # forward pass
            y_pred = model(x)
            
            # caclulate test loss
            loss = loss_fun(y_pred,y)
            losses.append(loss.item())

            # accuracy over entire dataset
            _,predpos=y_pred.max(1)
            samples+=len(y)
            correct+=(predpos==y).sum().item()
            
            # Update TQDM progress bar
            loop.set_postfix(loss=loss.item())

    print("Final Accuracy = ",100 * (correct/samples))

In [45]:
test_loop(model,test_dataloader,loss_fun)

100%|██████████| 157/157 [00:01<00:00, 110.97it/s, loss=0.0236]

Final Accuracy =  88.78



