In [121]:
#importer
import jsonschema
import json
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from ValidJson import VALID_JSON

DATASET_TRAIN_PATH = 'data/train'
DATASET_TEST_PATH  = 'data/test'

### Parsing configuration

In [122]:
def parse_config(path:str)->dict:
        with open(path, 'r') as file:
            try:
                config_schema = json.load(file)
                jsonschema.validate(config_schema, VALID_JSON)
                return config_schema
            except FileNotFoundError:
                print("Error! Configuration file is not exist.")
                exit(1)
            except jsonschema.exceptions.ValidationError as err:
                print(f"Error! Invalid format of configuration file: {err}")
                exit(1)

### Prepare data for learning

In [123]:
def prepare_learn_data(config)->DataLoader:
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.1307), (0.3081))]
    )#(0.1307), (0.3081)
    train_loader = DataLoader(
        dataset=datasets.MNIST(root=DATASET_TRAIN_PATH, train=True, transform=transform, download=True),
        batch_size=config['nn_config']['batch_size'], 
        shuffle=True)
    test_loader = DataLoader(
        dataset=datasets.MNIST(root=DATASET_TEST_PATH, train=False, transform=transform, download=True),
        batch_size=config['nn_config']['batch_size'],
        shuffle=False)
    return train_loader, test_loader 

In [124]:
### Model

In [125]:
from typing import List, Tuple

class NeuralNetwork(nn.Module):
    def __init__(self, hidden_arr:List[Tuple[int, int]], act_f:str) -> None:
        super().__init__()
        self._HIDDEN_ARR = hidden_arr
        self._activation = self._activation_decode(act_f)
        self._body = self._generate_architecture()
        
        
    def _activation_decode(self, act_f_str):
            if act_f_str == 'identity':
                return nn.Identity()
            elif act_f_str == 'sig':
                return nn.SiLU()
            else:
                return nn.ReLU()
    
    def _generate_architecture(self)->nn.ModuleList:
        hidden = nn.ModuleList()
        for hidden_layer in self._HIDDEN_ARR:
           hidden.append(nn.Linear(hidden_layer[0], hidden_layer[1]))
        return hidden

    def forward(self, x):
        x = x.view(-1, self._HIDDEN_ARR[0][0])  # Flatten the input
        l = len(self._body)-1
        for i, layer in enumerate(self._body):
            x = layer(x)
            if i != l:
                x = self._activation(x)
        return x

### Learning process

In [126]:
def train(device:torch.device, model:NeuralNetwork, config:dict, train_loader:DataLoader, criterion:nn.CrossEntropyLoss, optimizer:optim.Adam):
    model.train()
    for epoch in range(config['nn_config']['epochs']):
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device) 
            output = model(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            #if batch_idx % 100 == 0:
            #    print(f'Epoch [{epoch+1}/{config['nn_config']['epochs']}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
    
        print(f'Epoch {epoch+1} finished with avg loss: {running_loss/len(train_loader):.4f}')

def test(device: torch.device, model:NeuralNetwork, test_loader:DataLoader):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                outputs = model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        
        accuracy = 100 * correct / total
        print(f'Accuracy on test set: {accuracy:.2f}%')

def learning(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = prepare_learn_data(config)
    model = NeuralNetwork(config['nn_config']['hidden_arr'], config['nn_config']['act_f']).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['nn_config']['learning_rate'])
    #return
    train(device, model, config, train_loader, criterion, optimizer)
    test(device, model, test_loader)

### Test 1

In [127]:
config = parse_config('config/config1.json')
learning(config)

Epoch 1 finished with avg loss: 0.2513
Epoch 2 finished with avg loss: 0.1052
Epoch 3 finished with avg loss: 0.0741
Epoch 4 finished with avg loss: 0.0594
Epoch 5 finished with avg loss: 0.0464
Epoch 6 finished with avg loss: 0.0369
Epoch 7 finished with avg loss: 0.0338
Epoch 8 finished with avg loss: 0.0281
Epoch 9 finished with avg loss: 0.0246
Epoch 10 finished with avg loss: 0.0226
Accuracy on test set: 97.89%


### Test 2
Use more deeper structure with decrise inputs cnt

In [128]:
config = parse_config('config/config2.json')
learning(config)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x256 and 512x128)

### Test 3
Activation function is sig

In [None]:
config = parse_config('config/config3.json')
learning(config)

### Test 4
Activation function is identity

In [None]:
config = parse_config('config/config4.json')
learning(config)

### Test 5
Use greater learning rate

In [None]:
config = parse_config('config/config5.json')
learning(config)

### Test 6
Use smaller learning rate

In [None]:
config = parse_config('config/config6.json')
learning(config)

### Test 7
Use smaller batches

In [None]:
config = parse_config('config/config7.json')
learning(config)

### Test 8
Use greater batches

In [None]:
config = parse_config('config/config8.json')
learning(config)

### Test 9
Use more epochs

In [None]:
config = parse_config('config/config9.json')
learning(config)