-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
gd_mf_weights.cc
134 lines (113 loc) · 4.31 KB
/
gd_mf_weights.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <boost/program_options.hpp>

#include "../vowpalwabbit/parser.h"
#include "../vowpalwabbit/vw.h"
// Short alias for the Boost.Program_options namespace used by main() for command-line parsing.
namespace po = boost::program_options;
int main(int argc, char *argv[])
{
using std::cout;
using std::string;
string infile;
string outdir(".");
string vwparams;
po::variables_map vm;
po::options_description desc("Allowed options");
desc.add_options()
("help,h", "produce help message")
("infile,I", po::value<string>(&infile), "input (in vw format) of weights to extract")
("outdir,O", po::value<string>(&outdir), "directory to write model files to (default: .)")
("vwparams", po::value<string>(&vwparams), "vw parameters for model instantiation (-i model.reg -t ...")
;
try
{ po::store(po::parse_command_line(argc, argv, desc), vm);
po::notify(vm);
}
catch(std::exception & e)
{cout << std::endl << argv[0] << ": " << e.what() << std::endl << std::endl << desc << std::endl;
exit(2);
}
if (vm.count("help") || infile.empty() || vwparams.empty())
{ cout << "Dumps weights for matrix factorization model (gd_mf)." << std::endl;
cout << "The constant will be written to <outdir>/constant." << std::endl;
cout << "Linear and quadratic weights corresponding to the input features will be " << std::endl;
cout << "written to <outdir>/<ns>.linear and <outdir>/<ns>.quadratic,respectively." << std::endl;
cout << std::endl;
cout << desc << "\n";
cout << "Example usage:" << std::endl;
cout << " Extract weights for user 42 and item 7 under randomly initialized rank 10 model:" << std::endl;
cout << " echo '|u 42 |i 7' | ./gd_mf_weights -I /dev/stdin --vwparams '-q ui --rank 10'" << std::endl;
return 1;
}
// initialize model
vw* model = VW::initialize(vwparams);
model->audit = true;
string target("--rank ");
size_t loc = vwparams.find(target);
const char* location = vwparams.c_str()+loc+target.size();
size_t rank = atoi(location);
// global model params
std::vector<unsigned char> first_pair;
for (auto const& i : model->interactions)
{
if(i.size() == 2)
{
first_pair = i;
break;
}
}
if(first_pair.size() != 2)
{
cout << "Model doesn't include a quadratic interaction." << std::endl;
return 2;
}
unsigned char left_ns = first_pair[0];
unsigned char right_ns = first_pair[1];
dense_parameters& weights = model->weights.dense_weights;
FILE* file;
VW::file_open(&file, infile.c_str(), "r");
char* line = NULL;
size_t len = 0;
ssize_t read;
// output files
std::ofstream constant((outdir + string("/") + string("constant")).c_str()),
left_linear((outdir + string("/") + string(1, left_ns) + string(".linear")).c_str()),
left_quadratic((outdir + string("/") + string(1, left_ns) + string(".quadratic")).c_str()),
right_linear((outdir + string("/") + string(1, right_ns) + string(".linear")).c_str()),
right_quadratic((outdir + string("/") + string(1, right_ns) + string(".quadratic")).c_str());
example *ec = NULL;
while ((read = getline(&line, &len, file)) != -1)
{ line[strlen(line)-1] = 0; // chop
ec = VW::read_example(*model, line);
// write out features for left namespace
features& left = ec->feature_space[left_ns];
for (size_t i = 0; i < left.size(); ++i)
{
left_linear << left.space_names[i].second << '\t' << weights[left.indicies[i]];
left_quadratic << left.space_names[i].second;
for (size_t k = 1; k <= rank; k++)
left_quadratic << '\t' << weights[(left.indicies[i] + k)];
}
left_linear << std::endl;
left_quadratic << std::endl;
// write out features for right namespace
features& right = ec->feature_space[right_ns];
for (size_t i = 0; i < right.size(); ++i)
{
right_linear << right.space_names[i].second << '\t' << weights[right.indicies[i]];
right_quadratic << right.space_names[i].second;
for (size_t k = 1; k <= rank; k++)
right_quadratic << '\t' << weights[(right.indicies[i] + k + rank)];
}
right_linear << std::endl;
right_quadratic << std::endl;
VW::finish_example(*model, *ec);
}
// write constant
constant << weights[ec->feature_space[constant_namespace].indicies[0]] << std::endl;
// clean up
VW::finish(*model);
fclose(file);
}