Yue Wu edited this page Aug 29, 2016 · 2 revisions

Design & Extension of the Library

The design principle is to keep the package simple, easy to read, and easy to extend. All code follows the C++11 standard and requires no external libraries. C++ was chosen for its feasibility and efficiency in handling large-scale, high-dimensional data. The system is designed so that machine learning researchers can quickly implement a new algorithm based on a new idea, and compare it with a large family of existing algorithms, without spending much time and effort on handling large-scale data.

In general, SOL is written in a modular way, comprising PARIO (for PARallel IO, whose key component is DataIter), Loss, and Model. Users can extend it by inheriting the base classes of these modules and implementing the corresponding interfaces. Thus, we hope that SOL is not only a machine learning tool, but also a comprehensive experimental platform for conducting online learning research. The following figure shows the framework of the system.

How to Add New Algorithms

A salient property of this library is that it provides a fairly easy-to-use testbed that helps online learning researchers develop new algorithms and conduct side-by-side comparisons with state-of-the-art algorithms, on various datasets and with various loss functions, with minimal effort. More specifically, adding a new algorithm requires addressing three major issues:

  • What is the condition for making an update? This is usually equivalent to defining a proper loss function (e.g., a hinge loss) such that an update occurs whenever the loss is nonzero (i.e., $l_t > 0$).

  • How is the classifier (i.e., the weight vector) updated whenever the condition is satisfied? For example, Perceptron updates $w = w + y_t * x_t$.

  • Does your new algorithm have parameters? If so, and you want your algorithm to be serialized to and deserialized from files, you need to handle parameter setting (SetParameter), model serialization (GetModelInfo, GetModelParam), and deserialization (SetModelParam).

Model

Model covers all the learning algorithms. There are several child base classes for different kinds of algorithms.

Model

This is the base class for all algorithms. It implements the main test strategy of algorithms and serialization/deserialization functions. The interfaces include:

  • Constructor:

    Model(int class_num, const std::string& type);

    Here the parameter type indicates whether the algorithm is an online, stochastic, or batch algorithm, for future extension purposes.

  • Parameter Setting[optional]:

    virtual void SetParameter(const string& name, const string& value);

    Each algorithm may have its own parameters with different names. This interface allows new algorithms to parse their parameters easily. The function throws an invalid_argument exception if an invalid parameter is set.

  • Training Initialization[optional]:

    virtual void BeginTrain();

    Some algorithms need to do some initialization before training. If so, the algorithm can override this function and place the initialization code there.

  • Training Finalization[optional]:

    virtual void EndTrain();

    Some algorithms need to do some finalization after training. If so, the algorithm can override this function and place the finalization code there. For example, sparse online learning algorithms need to truncate weights after all iterations if they use the lazy-update strategy.

  • GetModelInfo[optional]:

    virtual void GetModelInfo(Json::Value& root) const;

    This function serializes the model to model files. If the new algorithm contains hyper-parameters, it should override this function to serialize them into a JSON object.

  • GetModelParameter[optional]:

    virtual void GetModelParam(ostream& os) const;

    This function serializes the model to model files. If the new algorithm contains other parameters, it should override this function to serialize them into an output stream. For example, second-order online learning algorithms often maintain an additional matrix of second-order information. Since the base class only knows about the weight vector, such algorithms should override this interface and serialize the second-order matrix themselves.

  • SetModelParameter[optional]:

    virtual void SetModelParam(istream& os);

    This function deserializes the model from model files (the inverse of GetModelParam). Generally, if an algorithm overrides GetModelParam, it should also override SetModelParam.

Users may wonder why they do not need to override SetModelInfo. The reason is that new algorithms already provide the SetParameter interface for setting hyper-parameters; the base class automatically calls it during deserialization.

Online Model

This is the base class for all online algorithms. It implements the main training strategy of online algorithms and defines some shared hyper-parameters and auxiliary parameters for online learning algorithms. The two interfaces that users may need to take care of are:

  • Predict in training[optional]:

    label_t TrainPredict(const DataPoint& dp, float* predicts);

    dp is the input data sample. predicts is the output prediction, whose length equals the number of classifiers. This interface exists because some online algorithms need to do some calculation before prediction in each iteration, like the second-order perceptron (SOP) algorithm. In other cases, there is no need to override this interface.

  • Update dimension[optional]:

    void update_dim(index_t dim);

    For online algorithms, data samples are processed one by one, so the model does not know the dimension of the whole dataset. When a new data sample arrives, this function is called to ensure that the dimension of the model is updated. This function should be overridden in the same cases as GetModelParam and SetModelParam, i.e., when the algorithm has extra model parameters.

Online Linear Model

This is the base class for all online linear algorithms. The only required interface here is the update function, which is the core of an online linear algorithm.

  • Update function[required]:

    void Update(const DataPoint& dp, const float* predicts, float loss);

    dp is the input data sample. predicts is the prediction on each class. loss is the output of the loss function.

Loss Function

At the moment, we provide a base class (a purely virtual class) for loss functions and four inherited classes (BoolLoss, HingeLoss, LogisticLoss, and SquareLoss) for binary classification, as well as their max-score and uniform variants for multi-class classification. The interfaces are:

  • Loss:

    virtual float loss(const DataPoint& dp, float* predicts, label_t predict_label, int cls_num);
    

    Get the loss for the current predictions. The first parameter is the data sample. The second is the prediction on each class. The third is the predicted class label, and the last one is the number of classes.

  • Gradient:

    virtual float gradient(const DataPoint& dp, float* predicts, label_t predict_label, float* gradient, int cls_num);
    

    Get the gradient of the loss function at the current data point. Note that we do not calculate the exact gradient here. The dimension of the fourth parameter gradient is the number of classifiers. For linear classification problems, the gradients on different features share a common term. Take hinge loss as an example:

    l(w) = max(0, 1 - y * w * x)

    The gradient (when the loss is nonzero) is:

    l'(w) = -y * x
    

    As a result, for efficiency we only calculate the shared term $-y$ of the gradients on different features. Users need to multiply by the corresponding feature $x[i]$ in the optimization algorithms.

DataIter

DataIter is in charge of loading and transferring data from disk to models efficiently. Its two major functions are adding a data reader and providing parsed data.

Extension of Data Reader

Adding a data reader means DataIter should allow users to add a reader to parse the source data files. Generally, users just need to call the following function of DataIter to add a reader:

int AddReader(const std::string& path, const std::string& dtype, int pass_num = 1);

  • path: path to the data file;

  • dtype: type of the data file ("svm", "bin", "csv");

  • pass_num: number of passes to go through the data.

By default, dtype supports "svm", "bin", and "csv". Extending the data reader means implementing a reader for a new data type.

The new data reader should inherit from the base class DataReader (in include/sol/pario/data_reader.h) and implement the following interfaces:

  • Open data file:

    virtual int Open(const std::string& path, const char* mode = "r");

    Open a data file to load data. The function returns Status_OK(0) if everything is ok.

  • Parse data:

    virtual int Next(DataPoint& dst_data);

    Parse a new data point from the input source. The parameter is the place to store the parsed data. The function returns Status_OK(0) if everything is ok.

  • Rewind[optional]:

    virtual void Rewind();

    For some special data formats like csv, the first line contains meta data. In this case, Rewind should be overridden.
  • Close[optional]:

    virtual void Close();

    If the new reader allocates resources, it should override the Close function to avoid leaks.