In [None]:
import pynn
from config import config

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets

class MNISTDataLoader:
    """
    Builds the torchvision MNIST dataset for one split ('train' or 'test')
    and wraps it in a torch DataLoader.
    """
    def __init__(self, type: str, batch_size: int, num_workers: int=1, transform: object=None, shuffle: bool=False):
        """
        Initialize MNIST data loader.
        Params:
            type : (type str) dataset split, 'train' or 'test' (case-insensitive).
            batch_size : (type int) batch size of data loader.
            num_workers : (type int) number of workers to use for data loader.
            transform : (type object) transform to apply to the dataset.
            shuffle : (type bool) forwarded to the DataLoader's `shuffle` argument.
                Defaults to False, which keeps the previous (unshuffled) behaviour.
        Raises:
            TypeError: if `type` is not a string.
            ValueError: if `type` is neither 'train' nor 'test'.
        """
        # Type check the type. The parameter shadows the builtin `type` inside
        # this method, so the class name is read via `__class__` instead.
        if not isinstance(type, str):
            raise TypeError(f"Invalid type: expected str, got {type.__class__.__name__}.")

        # Check if the type is valid (done before any download is attempted)
        split = type.lower()
        if split not in ('train', 'test'):
            raise ValueError(f"Invalid type: {split}. Expected 'train' or 'test'.")

        self.type = split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.shuffle = shuffle

        # Get the dataset; each split is stored under its own ./data/<split> root
        self.dataset = datasets.MNIST(
            f'./data/{self.type}',
            train=self.type == 'train',
            download=True,
            transform=self.transform,
        )

        # Create the data loader
        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=self.shuffle,
        )

    def __len__(self):
        """
        Get the length of the data loader.
        Returns:
            len: (type int) length of the data loader.
        """
        return len(self.dataloader)

    def get_dataloader(self):
        """
        Get the data loader.
        Returns:
            dataloader: (type torch.utils.data.DataLoader) data loader.
        """
        return self.dataloader

In [None]:
from torchvision import transforms

# Preprocessing applied to every image handed to the MNIST loaders:
# convert to a tensor, then normalize with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

In [None]:
# One DataLoader per split, both built with the same settings from `config`
# and the shared `transform` pipeline.
dataloaders = {
    split: MNISTDataLoader(
        type=split,
        batch_size=config['batch size'],
        num_workers=config['num workers'],
        transform=transform,
    ).dataloader
    for split in ('train', 'test')
}

In [None]:
# Assemble the network: three Linear layers (784 -> 512 -> 128 -> 10), each
# followed by a Sigmoid. Layers are created in forward order.
nn = pynn.NeuralNetwork(name='Test NN')

l1 = pynn.Linear(
    in_features=784, out_features=512, bias=True,
    initialization='random', name='Linear 1',
)
a1 = pynn.Sigmoid(name='Sigmoid 1')
l2 = pynn.Linear(
    in_features=512, out_features=128, bias=True,
    initialization='random', name='Linear 2',
)
a2 = pynn.Sigmoid(name='Sigmoid 2')
l3 = pynn.Linear(
    in_features=128, out_features=10, bias=True,
    initialization='random', name='Linear 3',
)
a3 = pynn.Sigmoid(name='Sigmoid 3')

# Register every layer under its block name, in the same forward order.
for block_name, layer in (
    ('input', l1),
    ('input', a1),
    ('hidden', l2),
    ('hidden', a2),
    ('output', l3),
    ('output', a3),
):
    nn.add(block_name=block_name, layer=layer)

print(nn.summary())

In [None]:
# Loss object handed to learner.train (as `L`) in the training cell below.
# NOTE(review): presumably mean squared error, judging by the name — confirm against pynn.
loss = pynn.MSE(name='MSE')

In [None]:
# Learner instance whose train() and test() methods are called on `nn` in the cells below.
learner = pynn.Learner(name='Learner')

In [None]:
# Train `nn` with the MSE loss object; epoch count and learning rate come from `config`.
# NOTE(review): the test split is also passed as the validation set, so the
# later learner.test call evaluates on data already seen during validation.
learner.train(
    model=nn,
    train_set=dataloaders['train'],
    val_set=dataloaders['test'],
    epochs=config['epochs'],
    L=loss,
    lr=config['learning rate'],
)

In [None]:
# Run the learner's test routine on `nn` using the test-split loader.
learner.test(model=nn, test_set=dataloaders['test'])