-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
gd_mf_weights.cc
136 lines (113 loc) · 4.65 KB
/
gd_mf_weights.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "config/cli_help_formatter.h"
#include "config/option_builder.h"
#include "config/option_group_definition.h"
#include "config/options_cli.h"
#include "crossplat_compat.h"
#include "parser.h"
#include "vw.h"
int main(int argc, char* argv[])
{
using std::cout;
using std::string;
bool help = false;
string infile;
string outdir;
string vwparams;
VW::config::options_cli opts(std::vector<std::string>(argv + 1, argv + argc));
VW::config::option_group_definition desc("GD MF Weights");
desc.add(VW::config::make_option("help", help).short_name("h").help("Produce help message"))
.add(VW::config::make_option("infile", infile).short_name("I").help("Input (in vw format) of weights to extract"))
.add(VW::config::make_option("outdir", outdir).short_name("O").help("Directory to write model files to"))
.add(VW::config::make_option("vwparams", vwparams)
.help("vw parameters for model instantiation (-i model.reg -t ..."));
opts.add_and_parse(desc);
// Return value is ignored as option reachability is not relevant here.
auto warnings = opts.check_unregistered();
_UNUSED(warnings);
if (help || infile.empty() || vwparams.empty())
{
cout << "Dumps weights for matrix factorization model (gd_mf)." << std::endl;
cout << "The constant will be written to <outdir>/constant." << std::endl;
cout << "Linear and quadratic weights corresponding to the input features will be " << std::endl;
cout << "written to <outdir>/<ns>.linear and <outdir>/<ns>.quadratic,respectively." << std::endl;
cout << std::endl;
VW::config::cli_help_formatter help_formatter;
std::cout << help_formatter.format_help(opts.get_all_option_group_definitions()) << std::endl;
cout << "Example usage:" << std::endl;
cout << " Extract weights for user 42 and item 7 under randomly initialized rank 10 model:" << std::endl;
cout << " echo '|u 42 |i 7' | ./gd_mf_weights -I /dev/stdin --vwparams '-q ui --rank 10'" << std::endl;
return 1;
}
// initialize model
VW::workspace* model = VW::initialize(vwparams);
model->audit = true;
string target("--rank ");
size_t loc = vwparams.find(target);
const char* location = vwparams.c_str() + loc + target.size();
size_t rank = atoi(location);
// global model params
std::vector<unsigned char> first_pair;
for (auto const& i : model->interactions)
{
if (i.size() == 2)
{
first_pair = i;
break;
}
}
if (first_pair.size() != 2)
{
cout << "Model doesn't include a quadratic interaction." << std::endl;
return 2;
}
unsigned char left_ns = first_pair[0];
unsigned char right_ns = first_pair[1];
dense_parameters& weights = model->weights.dense_weights;
FILE* file;
VW::file_open(&file, infile.c_str(), "r");
char* line = nullptr;
size_t len = 0;
ssize_t read;
// output files
std::ofstream constant((outdir + string("/") + string("constant")).c_str());
std::ofstream left_linear((outdir + string("/") + string(1, left_ns) + string(".linear")).c_str());
std::ofstream left_quadratic((outdir + string("/") + string(1, left_ns) + string(".quadratic")).c_str());
std::ofstream right_linear((outdir + string("/") + string(1, right_ns) + string(".linear")).c_str());
std::ofstream right_quadratic((outdir + string("/") + string(1, right_ns) + string(".quadratic")).c_str());
VW::example* ec = nullptr;
while ((read = getline(&line, &len, file)) != -1)
{
line[strlen(line) - 1] = 0; // chop
ec = VW::read_example(*model, line);
// write out features for left namespace
features& left = ec->feature_space[left_ns];
for (size_t i = 0; i < left.size(); ++i)
{
left_linear << left.space_names[i].name << '\t' << weights[left.indices[i]];
left_quadratic << left.space_names[i].name;
for (size_t k = 1; k <= rank; k++) { left_quadratic << '\t' << weights[(left.indices[i] + k)]; }
}
left_linear << std::endl;
left_quadratic << std::endl;
// write out features for right namespace
features& right = ec->feature_space[right_ns];
for (size_t i = 0; i < right.size(); ++i)
{
right_linear << right.space_names[i].name << '\t' << weights[right.indices[i]];
right_quadratic << right.space_names[i].name;
for (size_t k = 1; k <= rank; k++) { right_quadratic << '\t' << weights[(right.indices[i] + k + rank)]; }
}
right_linear << std::endl;
right_quadratic << std::endl;
VW::finish_example(*model, *ec);
}
// write constant
constant << weights[ec->feature_space[constant_namespace].indices[0]] << std::endl;
// clean up
VW::finish(*model);
fclose(file);
}