# Train the network using the REDISAI db as an exchange place and debug the problems


In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.models as models
import torch.utils.data as tdata

import numpy as np
import redisai as rai

from dataclasses import dataclass
import pickle

# import the modules used in the program
import train_utils
import ml2rt

## Create the network

In [73]:
@dataclass
class TrainParams:
    ps_id: str
    N: int
    task: str
    func_id: int
    lr: float
    batch_size: int
    

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
    
def create_model(init: bool):
    """Creates the model used to train the network

    For this example we'll be using the simple model from the MNIST examples
    (https://github.com/pytorch/examples/blob/master/mnist/main.py)
    """

    def init_weights(m: nn.Module):
        """Initialize the weights of the network"""
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.01)
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.01)

    # Create the model and initialize the weights
    model = Net()

    # If the task is initializing the layers do so
    if init:
        print('Initializing layers...')
        model.apply(init_weights)

    return model

In [44]:
torch.cuda.is_available()

True

## Create the Dataloader

In [45]:

# TODO max document size is 16 MB, this could give us problems in the future
# when the datasets are so big, we should calculate the size (easy, and divide the dataset)
def split_dataset(X, Y, subsets):
    """Splits the X and Y in N different subsets"""
    X_split = np.split(X, subsets)
    Y_split = np.split(Y, subsets)
    
    return X_split, Y_split


def approx_size(a: np.array):
    """ approx size of float 32 array in MB"""
    return (32/8) * np.prod(a.shape) / 1e6



In [15]:
47*128, 16*128

(6016, 2048)

In [46]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])

train_data = datasets.MNIST('./data', train=True, download=False, transform=transform)
val_data = datasets.MNIST('./data', train=False, download=False, transform=transform)

train_data.data, train_data.targets = train_data.data[:3000], train_data.targets[:3000]
val_data.data, val_data.targets = val_data.data[:2000], val_data.targets[:2000]

In [47]:

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=128)
len(train_loader), len(val_loader)

(24, 16)

# Define the train and test methods


In [48]:
def train(model: nn.Module, device,
          train_loader: tdata.DataLoader,
          optimizer: torch.optim.Optimizer, tensor_dict) -> float:
    """Loop used to train the network"""
    model.train()
    loss, tot = 0, 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        loss = F.nll_loss(output, target)
        tot += loss.item()
        loss.backward()

        # Here save the gradients to publish on the database
#         train_utils.update_tensor_dict(model, tensor_dict)
        optimizer.step()
        

        if batch_idx % 5 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                1, batch_idx * len(data), len(train_loader.dataset),
                   100. * batch_idx / len(train_loader), loss.item()))

    return tot/len(train_loader)


def validate(model, device, val_loader: tdata.DataLoader) -> (float, float):
    """Loop used to validate the network"""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(val_loader.dataset)

    accuracy = 100. * correct / len(val_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
    return accuracy, test_loss

## Main entrypoint of the code

In [82]:
from copy import deepcopy

params = TrainParams(ps_id='example', func_id=0, N =2, task='train', lr=0.01, batch_size=128)


torch.manual_seed(42) 
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')

# Create the model
model = create_model(init=True).to(device)

Initializing layers...


In [50]:
# Create the redis connection
addr = '192.168.99.101'
port = 31618
con = rai.Client(debug=True, host=addr, port=port)

### Train for a couple of epochs

In [84]:
%%time
# create the tensor dict
tdict = dict()

optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1,10):
    print('Epoch', epoch)
    train(model, device, train_loader, optimizer, tdict)
    validate(model, device, val_loader)


Epoch 1

Test set: Average loss: 0.6271, Accuracy: 1648/2000 (82%)

Epoch 2

Test set: Average loss: 0.4882, Accuracy: 1724/2000 (86%)

Epoch 3

Test set: Average loss: 0.3404, Accuracy: 1793/2000 (90%)

Epoch 4

Test set: Average loss: 0.3638, Accuracy: 1777/2000 (89%)

Epoch 5

Test set: Average loss: 0.3386, Accuracy: 1779/2000 (89%)

Epoch 6

Test set: Average loss: 0.2790, Accuracy: 1833/2000 (92%)

Epoch 7

Test set: Average loss: 0.2843, Accuracy: 1822/2000 (91%)

Epoch 8

Test set: Average loss: 0.2415, Accuracy: 1847/2000 (92%)

Epoch 9

Test set: Average loss: 0.2434, Accuracy: 1855/2000 (93%)

Wall time: 50.5 s


In [70]:
con.set('model', pickle.dumps(model))

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



True

In [71]:
del Net

## Save the model in the database

In [74]:
s = pickle.loads(con.get('model'))

GET model


In [75]:
s

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [66]:
con.modelset('test-model', 'torch', 'cpu', m)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



ResponseError: [enforce fail at inline_container.cc:208] . file not found: archive/constants.pkl frame #0: c10::ThrowEnforceNotMet(char const*, int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, void const*) + 0x67 (0x7f4653cd7787 in /usr/lib/redis/modules/backends/redisai_torch/lib/libc10.so) frame #1: caffe2::serialize::PyTorchStreamReader::getRecordID(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xd6 (0x7f4646c4b376 in /usr/lib/redis/modules/backends/redisai_torch/lib/libtorch_cpu.so) frame #2: caffe2::serialize::PyTorchStreamReader::getRecord(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x38 (0x7f4646c4c018 in /usr/lib/redis/modules/backends/redisai_torch/lib/libtorch_cpu.so) frame #3: torch::jit::readArchiveAndTensors(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<std::function<c10::StrongTypePtr (c10::QualifiedName const&)> >, c10::optional<std::function<c10::intrusive_ptr<c10::ivalue::Object, c10::detail::intrusive_target_default_null_type<c10::ivalue::Object> > (c10::StrongTypePtr, c10::IValue)> >, c10::optional<c10::Device>, caffe2::serialize::PyTorchStreamReader&) + 0xda (0x7f4647ccf3aa in /usr/lib/redis/modules/backends/redisai_torch/lib/libtorch_cpu.so) frame #4: <unknown function> + 0x2f3bc9d (0x7f4647ccfc9d in /usr/lib/redis/modules/backends/redisai_torch/lib/libtorch_cpu.so) frame #5: <unknown function> + 0x2f3e26f (0x7f4647cd226f in /usr/lib/redis/modules/backends/redisai_torch/lib/libtorch_cpu.so) frame #6: torch::jit::load(std::unique_ptr<caffe2::serialize::ReadAdapterInterface, std::default_delete<caffe2::serialize::ReadAdapterInterface> >, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x179 (0x7f4647cd2bf9 in /usr/lib/redis/modules/backends/redisai_torch/lib/libtorch_cpu.so) frame #7: torch::jit::load(std::istream&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x75 (0x7f4647cd33f5 in /usr/lib/redis/modules/backends/redisai_torch/lib/libtorch_cpu.so) frame #8: torchLoadModel + 0x215 (0x7f465c91b425 in /usr/lib/redis/modules/backends/redisai_torch/redisai_torch.so) frame #9: RAI_ModelCreateTorch + 0x8a (0x7f465c9141ea in /usr/lib/redis/modules/backends/redisai_torch/redisai_torch.so) frame #10: RAI_ModelCreate + 0x16d (0x7f465c94ac8d in /usr/lib/redis/modules/redisai.so) frame #11: RedisAI_ModelSet_RedisCommand + 0x6ea (0x7f465c943f3a in /usr/lib/redis/modules/redisai.so) frame #12: RedisModuleCommandDispatcher + 0x54 (0x5608c664f114 in redis-server *:6379) frame #13: call + 0x9d (0x5608c65daffd in redis-server *:6379) frame #14: processCommand + 0x33f (0x5608c65db78f in redis-server *:6379) frame #15: processCommandAndResetClient + 0x10 (0x5608c65e9480 in redis-server *:6379) frame #16: processInputBuffer + 0x18f (0x5608c65edacf in redis-server *:6379) frame #17: <unknown function> + 0xd5fac (0x5608c666afac in redis-server *:6379) frame #18: aeProcessEvents + 0x2e7 (0x5608c65d4bd7 in redis-server *:6379) frame #19: aeMain + 0x1d (0x5608c65d4f1d in redis-server *:6379) frame #20: main + 0x4c9 (0x5608c65d17c9 in redis-server *:6379) frame #21: __libc_start_main + 0xeb (0x7f465c98909b in /lib/x86_64-linux-gnu/libc.so.6) frame #22: _start + 0x2a (0x5608c65d1a5a in redis-server *:6379) 