In [1]:


import caffe

from caffe import layers as L, params as P, to_proto

from caffe.proto import caffe_pb2

from caffe.coord_map import crop

import copy
from nbfinder import NotebookFinder
import sys
sys.meta_path.append(NotebookFinder())
from layer_util import *
from make_network import make_netcdf_network
import numpy as np
import h5py
from make_solver import make_solver
import argparse
from os.path import join
#from accuracy import BBox_Accuracy

importing Jupyter notebook from layer_util.ipynb
importing Jupyter notebook from make_network.ipynb
importing Jupyter notebook from network_architecture.ipynb
importing Jupyter notebook from make_solver.ipynb


In [2]:
def get_source_path_from_proto(proto_file,phase="TRAIN"):
    """Return the ``source:`` path recorded for ``phase`` in a prototxt file.

    The file is scanned top to bottom. A line containing ``source: `` is
    accepted when the line three positions above it contains ``phase``; the
    first accepted line wins.

    Args:
        proto_file: path of the prototxt text file to scan.
        phase: substring that must appear three lines above the ``source:``
            line (for example ``"TRAIN"`` or ``"TEST"``).

    Returns:
        The text following ``source: `` with double quotes and the trailing
        line ending removed.

    Raises:
        ValueError: if no ``source:`` line is paired with ``phase``.
    """
    with open(proto_file, "r") as f:
        lines = f.readlines()

    # Start at index 3: for smaller indices ``i - 3`` is negative and would
    # silently wrap around to a line near the END of the file.
    for i in range(3, len(lines)):
        line = lines[i]
        if "source: " in line and phase in lines[i - 3]:
            return line.split("source: ")[-1].replace('"', "").strip("\r\n")

    # Fail with a descriptive error instead of an unbound-local crash.
    raise ValueError(
        "no 'source: ' entry found for phase %r in %s" % (phase, proto_file)
    )
            
        

In [3]:
def get_num_examples(prototxt_file,phase="TRAIN", time_stride=1, examples_per_file=8):
    """Estimate how many examples the data source of ``phase`` provides.

    The source path for ``phase`` is looked up in ``prototxt_file`` via
    ``get_source_path_from_proto``; that source is opened as a text file and
    its lines are counted.

    Args:
        prototxt_file: path of the network prototxt to inspect.
        phase: which phase's source to count ("TRAIN" or "TEST" at the call
            sites in this notebook).
        time_stride: divisor applied to ``examples_per_file``.
        examples_per_file: multiplier applied per line of the source file.

    Returns:
        ``line_count * (examples_per_file / time_stride)``.
    """
    # Forward ``phase``; previously it was dropped, so every call counted the
    # default (TRAIN) source regardless of the phase requested.
    source_path = get_source_path_from_proto(prototxt_file, phase=phase)

    with open(source_path, 'r') as f:
        num_lines = sum(1 for _ in f)

    # NOTE(review): ``/`` is floor division for ints on Python 2 but true
    # division (float result) on Python 3 -- confirm which is intended.
    return num_lines * (examples_per_file / time_stride)

In [4]:
def make_train(cl_args):
    """Build the train/val network specs and write them to one prototxt file.

    Args:
        cl_args: dict of settings. Keys read here: "tr_batch_size",
            "data_dir", "mode", "filters_scale", "copies",
            "train_proto_name" and "proto_basepath".

    Returns:
        The value returned by ``write_to_file`` (presumably the path of the
        written prototxt -- confirm against ``layer_util``). Earlier versions
        computed this value but discarded it.
    """
    tr_netspec, val_netspec = make_netcdf_network(batch_size=cl_args["tr_batch_size"],
                                                  data_path=cl_args["data_dir"],
                                                  modes=cl_args["mode"],
                                                  filters_scale=cl_args["filters_scale"],
                                                  copies=cl_args["copies"])

    # When a validation spec was produced it is written ahead of the training
    # spec; otherwise only the training spec goes into the file.
    if val_netspec:
        ns = [val_netspec, tr_netspec]
    else:
        ns = [tr_netspec]

    tr_net_filepath = write_to_file(ns,
                                    filename=cl_args["train_proto_name"],
                                    basepath=cl_args["proto_basepath"])
    return tr_net_filepath

#     val_netspec = make_netcdf_network(batch_size=cl_args["val_batch_size"],data_path=cl_args["data_dir"], mode="val" ,filters_scale=cl_args["filters_scale"])
#     val_net_filepath = write_to_file(val_netspec, filename=cl_args["val_proto_name"], basepath=cl_args["proto_basepath"] )

In [5]:
def make_deploy(cl_args):
    """Placeholder for building a deploy network; currently does nothing.

    Args:
        cl_args: settings dict (unused for now).

    Returns:
        None.
    """
    return None

In [9]:
def create_solver(cl_args):
    """Generate a solver definition for the train/val prototxt.

    Example counts for the TRAIN and TEST phases are read from the train
    prototxt with ``get_num_examples`` and passed to ``make_solver`` together
    with the learning rate, batch sizes and snapshot settings in ``cl_args``.

    Args:
        cl_args: dict of settings. Keys read here: "proto_basepath",
            "train_proto_name", "lr", "solver_name", "snapshot_path",
            "tr_batch_size", "val_batch_size" and "snapshot_frequency".

    Returns:
        The second element of the pair returned by ``make_solver`` (the
        solver filename).
    """
    trainval_proto = join(cl_args["proto_basepath"], cl_args["train_proto_name"])

    # Both counts come from the same prototxt; only the phase differs.
    train_count = get_num_examples(trainval_proto, phase="TRAIN")
    val_count = get_num_examples(trainval_proto, phase="TEST")

    solver_kwargs = dict(
        net_path=cl_args["proto_basepath"],
        base_lr=cl_args["lr"],
        solver_name=cl_args["solver_name"],
        train_file_name=cl_args["train_proto_name"],
        tr_num_examples=train_count,
        test_num_examples=val_count,
        snapshot_path=cl_args["snapshot_path"],
        tr_batch_size=cl_args["tr_batch_size"],
        test_batch_size=cl_args["val_batch_size"],
        snapshot_frequency=cl_args["snapshot_frequency"],
        print_every_iteration=train_count,
    )

    # make_solver returns a pair; only the filename part is needed here.
    _unused_solver_text, solver_filename = make_solver(**solver_kwargs)
    return solver_filename

In [11]:
if __name__ == "__main__":
    # Hand-written settings standing in for parsed command-line arguments.
    # NOTE(review): make_train reads cl_args["mode"] and cl_args["copies"], and
    # create_solver reads cl_args["snapshot_path"] and
    # cl_args["snapshot_frequency"]; none of these keys are set below, so the
    # calls look like they will raise KeyError -- confirm and add the values.
    cl_args = {
        "lr": 1e-5,
        "num_epochs": 20,
        "filters_scale": 1.0 / 8,
        "data_dir": "/project/projectdirs/dasrepo/gordon_bell/deep_learning/networks/climate/2d_semi_sup/extremely_small_dataset/",
        # "save_dir": "/global/homes/r/racah/projects/climate-caffe/2d_semi_sup/notebooks/plots",
        "tr_batch_size": 32,
        "val_batch_size": 64,
        "proto_basepath": "/global/homes/r/racah/projects/climate-caffe/2d_semi_sup/notebooks/proto_files/",
        "train_proto_name": "trval_foo.prototxt",
        # "val_proto_name": "val_foo.prototxt",
        "solver_name": "solver_foo",
    }

    # Write the network prototxt, generate the matching solver file, then run
    # a single solver step as a smoke test.
    make_train(cl_args)
    solver_path = create_solver(cl_args)
    solver = caffe.SGDSolver(solver_path)
    solver.step(1)
   
    