In [51]:
%run '../catalog_common.ipynb'

In [52]:
#!pip3 install torchvision

In [53]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import binascii
from datetime import datetime

In [54]:
def search_onrl_models(catalog, name):
    """Find the first catalog entry whose metadata ``name`` matches ``name``.

    Names are compared after stripping surrounding whitespace on both sides.
    Entries whose ``"metadata"`` is missing/``None`` (or not a dict), or whose
    metadata has no string ``"name"``, are skipped instead of raising.

    Args:
        catalog: object handed to ``list_onrl`` (it must provide ``list_all()``).
        name: model name to look for.

    Returns:
        ``(True, entry_id, metadata)`` for the first matching entry, where
        ``entry_id`` is that entry's ``"_id"``; otherwise ``(False, None, None)``.
    """
    wanted = name.strip()  # hoisted: the target does not change per entry
    for item in list_onrl(catalog):
        metadata = item.get("metadata")
        if not isinstance(metadata, dict):
            continue
        candidate = metadata.get("name")
        # Malformed entries (no name / non-string name) simply cannot match.
        if isinstance(candidate, str) and candidate.strip() == wanted:
            print("Model found in the Catalog with the following metadata")
            print(metadata)
            return True, item["_id"], metadata
    return False, None, None

def slac_list_onrl_models(obj_list):
    """Return the ``name`` of every catalog entry that carries metadata.

    Entries whose ``"metadata"`` value is ``None`` are ignored; the remaining
    names are returned in their original order.
    """
    # First keep only the metadata dicts that exist, then pull out their names.
    present = [entry["metadata"] for entry in obj_list if entry["metadata"] is not None]
    return [meta["name"] for meta in present]

def list_onrl(catalog):
    """Return all entries of ``catalog`` as parsed JSON.

    ``catalog.list_all()`` is expected to return a JSON document (str/bytes);
    it is decoded with ``json.loads`` and returned as-is -- no filtering by
    scope or type is applied here.
    """
    # Local import so this helper does not depend on the %run'd notebook
    # having put ``json`` into the namespace.
    import json

    return json.loads(catalog.list_all())

In [55]:
# Working paths, all rooted at the notebook's current directory.
base_dir = os.getcwd()
data_dir = os.path.join(base_dir, 'onrl', 'data')      # presumably the dataset location
temp_file = os.path.join(base_dir, 'onrl', 'temp.pt')  # scratch file written by get_serialized_model()

# Loader / SGD settings read by train_model().
batch_size, learning_rate, momentum = 128, 0.01, 0.5

# Compute device used for training.
device = "cpu"

In [56]:
class BasicModel(nn.Module):
    """Small CNN mapping single-channel images to 10 class log-probabilities.

    Layout: conv(1->10, 5x5) -> maxpool(2) -> ReLU -> conv(10->20, 5x5)
    -> channel dropout -> maxpool(2) -> ReLU -> fc(320->50) -> ReLU
    -> dropout -> fc(50->10) -> log_softmax over the class dimension.

    The hard-coded flatten size 320 (= 20 channels x 4 x 4) matches 28x28
    inputs; other spatial sizes are not supported (presumably MNIST-sized
    data -- confirm against the dataset actually used).
    """

    def __init__(self):
        super().__init__()
        # input channel 1, output channel 10
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5, stride=1)
        # input channel 10, output channel 20
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5, stride=1)
        # channel-wise dropout; as a module it follows train()/eval()
        self.conv2_drop = nn.Dropout2d()
        # fully connected layers: 320 flattened features -> 50 -> 10 classes
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, 10)``.

        ``x`` is assumed to be ``(batch, 1, 28, 28)`` -- TODO confirm with callers.
        """
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.conv2_drop(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        # (batch, 20, 4, 4) -> (batch, 320) for 28x28 inputs
        x = x.view(-1, 320)
        x = self.fc1(x)
        x = F.relu(x)
        # Functional dropout is always active unless told otherwise; tie it to
        # the module's train/eval mode so inference is deterministic.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalize across the 10 classes (dim=1), not across the batch (dim=0).
        return F.log_softmax(x, dim=1)


def get_training_model(device="cpu"):
    """Build a freshly initialised ``BasicModel`` and place it on ``device``."""
    fresh_model = BasicModel()
    return fresh_model.to(device)
    
    
def train_model(model, training_data, log_interval=100):
    """Train ``model`` with SGD for a single pass over ``training_data``.

    Hyper-parameters come from the module-level ``batch_size``,
    ``learning_rate``, ``momentum`` and ``device`` settings. The model is moved
    to ``device``, switched to train mode and updated in place; nothing is
    returned.

    Args:
        model: ``nn.Module`` to train.
        training_data: dataset for ``DataLoader``; each batch must unpack as
            ``(data, target)``.
        log_interval: print a progress line every ``log_interval`` batches.
            ``None`` or a value <= 0 disables progress output.
    """
    log_enabled = log_interval is not None and log_interval > 0

    # Move the model first and only then build the optimizer over its
    # parameters, as the torch.optim documentation recommends.
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    train_loader = DataLoader(training_data, shuffle=True, batch_size=batch_size)
    # CrossEntropyLoss applies log_softmax itself; re-applying it to outputs
    # that are already log-probabilities over the class dimension leaves them
    # unchanged, so this works for logit- and log_softmax-producing models.
    criterion = nn.CrossEntropyLoss()
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if log_enabled and batch_idx % log_interval == 0:
            # Counts logging steps within this single pass, not epochs.
            step = batch_idx // log_interval
            print('Train Step: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                step, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


            
def get_serialized_model(model):
    """Compile ``model`` with TorchScript and return the saved artifact.

    The scripted module is saved to the module-level ``temp_file`` path and
    then read back through ``get_file_contents`` (defined outside this view;
    its return format is not visible here). The temp file is left on disk.
    """
    model_scripted = torch.jit.script(model)
    # Saving fails if the scratch directory does not exist yet, so create it.
    parent_dir = os.path.dirname(temp_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    model_scripted.save(temp_file)
    return get_file_contents(temp_file)
        