In [1]:
import os
import time
import numpy as np
import datasets
from datasets import load_dataset, list_datasets, Dataset
import huggingface_hub
from matplotlib import pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

In [2]:
# start time 
t0 = time.time()

In [3]:
# context
torch.multiprocessing.set_start_method('spawn')

In [4]:
def transforms(data):
    data['data'] = data['image'].type(torch.FloatTensor)
    return data

In [5]:
# load dataset `mnist`
mnist_train = load_dataset(path='mnist', 
                           split='train', 
                           cache_dir="/eagle/projects/candle_aesp/siebenschuh/HF") #.cast_column("image", float)
mnist_test  = load_dataset('mnist', 
                           split='test', 
                           cache_dir="/eagle/projects/candle_aesp/siebenschuh/HF") #.cast_column("image", float)

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# float tensor
dset_train = torchvision.datasets.MNIST(root="/eagle/projects/candle_aesp/siebenschuh/PT", train=True, download=True)
dset_test  = torchvision.datasets.MNIST(root="/eagle/projects/candle_aesp/siebenschuh/PT", train=False, download=True)

# dsets
#mnist_train = mnist_train.with_format(type="torch", device=device)
#mnist_test  = mnist_test.with_format(type="torch", device=device)

# actual dsets
#dset_train = mnist_train.map(transforms, remove_columns=["image"], batched=False)
#dset_test = mnist_test.map(transforms, remove_columns=["image"], batched=False)

Found cached dataset mnist (/eagle/projects/candle_aesp/siebenschuh/HF/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)
Found cached dataset mnist (/eagle/projects/candle_aesp/siebenschuh/HF/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)


In [6]:
if(device.type!='cuda'):
    # Show a digit
    digit_id = 5
    plt.figure(figsize=(3,3))
    plt.imshow(mnist_train['image'][digit_id], cmap='Greys', interpolation='nearest')
    plt.title(f'Digit `{mnist_train["label"][digit_id]}`', fontsize=14)
    plt.axis('off')
    plt.show()

## Input Parameters

In [7]:
b_size = 128

## Model

In [8]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool  = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16 * 4 * 4, 128)
        self.fc2   = nn.Linear(128, 64)
        self.fc3   = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dim except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

In [9]:
transform = transforms.ToTensor()

# Custom collate function
def custom_collate(batch):
    # Convert PIL.Image.Image objects to tensors
    batch = [transform(image) for image in batch]
    # Return the batch
    return torch.stack(batch, dim=0)

AttributeError: 'function' object has no attribute 'ToTensor'

In [None]:
# data loader
trainloader = torch.utils.data.DataLoader(dset_train, 
                                          batch_size=b_size, 
                                          shuffle=True, 
                                          num_workers=1,
                                          collate_fn=custom_collate)

testloader  = torch.utils.data.DataLoader(dset_test, 
                                          batch_size=b_size, 
                                          shuffle=True, 
                                          num_workers=1,
                                          collate_fn=custom_collate)

classes=tuple([str(i) for i in range(10)])

In [None]:
# Model Instance
net = Net()
net = net.to(device)

# Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [None]:
# Training
for epoch in range(6):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        #inputs, labels = data['image'], data['label']
        inputs, labels = data #['data'].reshape((-1,1,28,28)), data['label']

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 200 == 199:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

## Inference

In [None]:
def acc(inputs, labels, model):
    pred = (model(inputs).argmax(axis=1)).cpu()
    gt = labels.cpu()
    
    # difference
    diff = (pred - gt)
    diff[diff!=0]=1

    return (1. - diff.sum() / len(diff)) * 100.

In [None]:
net.eval()

accList = []
for i, data in enumerate(testloader, 0):
    inputs, labels = data['data'].reshape((-1,1,28,28)), data['label']
    
    accList.append(acc(inputs, labels, net))
    
print(f'Mean accuracy: {np.mean(accList):.2f}%')

In [None]:
# end
t1 = time.time()

print(f'Comp. time: {t1-t0:.2f}')