Skip to content

Commit

Permalink
fixed bugs in the minibatch code---may still not work
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Hoffman committed Dec 4, 2010
1 parent 4c4d630 commit 2281fee
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
50 changes: 41 additions & 9 deletions lda.cc
Expand Up @@ -53,6 +53,14 @@ float decayfunc3(double t, double old_t, double power_t)
return (old_t / t) * exp(0.5*power_t_plus_one*(-logt*logt + logoldt*logoldt));
}

float decayfunc4(double t, double old_t, double power_t)
{
if (power_t > 0.99)
return decayfunc3(t, old_t, power_t);
else
return decayfunc2(t, old_t, power_t);
}

void expdigammify(float* gamma)
{
float sum=0;
Expand Down Expand Up @@ -161,8 +169,14 @@ void merge_pair(v_array<index_triple>& source, v_array<index_triple>& dest)

for (index_triple* s=source.begin; s != source.end; s++)
{
while((old_index < limit) && (dest[old_index].f.weight_index < s->f.weight_index))
// fprintf(stderr, "moving dest to dest\n");
while((old_index < limit) && (dest[old_index].f.weight_index < s->f.weight_index)) {
// fprintf(stderr, "moving dest[%d] to dest[%d] --- %d to %d\n", old_index, new_index, dest[old_index].f.weight_index/global.stride, dest[new_index].f.weight_index/global.stride);
dest[new_index++] = dest[old_index++];
}
// for (index_triple* i = dest.begin; i != dest.begin + new_index; i++)
// fprintf(stderr, "%d \t %d\n", (i-dest.begin), i->f.weight_index/global.stride);
// fprintf(stderr, "moving s to dest[%d] --- %d to %d\n", new_index, s->f.weight_index/global.stride, dest[new_index].f.weight_index/global.stride);
dest[new_index++] = *s;
}
source.erase();
Expand All @@ -189,6 +203,17 @@ void merge_all()
}
}

void bubble_sort(feature* f0, feature*f1) {
for (; f0 < f1; f1--)
for (feature* f = f0+1; f < f1; f++)
if (f->weight_index < (f-1)->weight_index) {
feature temp = *f;
*f = *(f-1);
*(f-1) = temp;
}
}

// TODO: this doesn't work, because the features don't come in in order.
void merge_in(example* ec, size_t document)
{
size_t next_index = merge_set.index();
Expand All @@ -197,6 +222,7 @@ void merge_in(example* ec, size_t document)
for (size_t* i = ec->indices.begin; i != ec->indices.end; i++)
{
feature* f = ec->subsets[*i][0];
bubble_sort(f, ec->subsets[*i][1]);
for (; f != ec->subsets[*i][1]; f++)
{
index_triple temp = {document,*f};
Expand Down Expand Up @@ -250,16 +276,22 @@ void start_lda(gd_thread_params t)
}
else if (thread_done(0))
batch_size = d;
else
else
d--;
}

merge_all(); //Now merge_set[0] contains everything.

// fprintf(stderr, "merge_set[0].index() = %d\n", merge_set[0].index());
// for (int i = 0; i < merge_set[0].index(); i++) {
// fprintf(stderr, "merge_set[0][%d] = %d --- %d\n", i,
// merge_set[0][i].f.weight_index/global.stride, merge_set[0][i].document);
// }

example_t = (examples[0]->example_t - global.initial_t)/global.minibatch + global.initial_t;
eta = (global.eta*global.lda_D) /
(pow(example_t, t.vars->power_t)*batch_size);
minuseta = decayfunc3(example_t, example_t-1, power_t);
minuseta = decayfunc4(example_t, example_t-1, power_t);

float digammas[global.lda];
float additional = (float)(global.length()) * global.lda_rho;
Expand All @@ -273,8 +305,8 @@ void start_lda(gd_thread_params t)
continue;
last_weight_index = s->f.weight_index;
float* weights_for_w = &(weights[s->f.weight_index & global.thread_mask]);
float olddecay = decayfunc3(example_t-1, weights_for_w[global.lda], power_t);
float decay = decayfunc3(example_t, weights_for_w[global.lda], power_t);
float olddecay = decayfunc4(example_t-1, weights_for_w[global.lda], power_t);
float decay = decayfunc4(example_t, weights_for_w[global.lda], power_t);
float* u_for_w = weights_for_w + global.lda+1;

weights_for_w[global.lda] = example_t;
Expand All @@ -290,7 +322,7 @@ void start_lda(gd_thread_params t)
float v[batch_size*global.lda];

for (size_t d = 0; d < batch_size; d++)
lda_loop(&v[d*global.lda],weights,ec,t.vars->power_t);
lda_loop(&v[d*global.lda],weights,examples[d],t.vars->power_t);

for (index_triple* s = merge_set[0].begin; s != merge_set[0].end;)
{
Expand Down Expand Up @@ -324,14 +356,14 @@ void start_lda(gd_thread_params t)
for (size_t d = 0; d < batch_size; d++)
{
if (global.audit)
print_audit_features(reg, ec);
finish_example(ec);
print_audit_features(reg, examples[d]);
finish_example(examples[d]);
}
if (thread_done(0))
{
for (size_t i = 0; i < global.length(); i++) {
weight* weights_for_w = & (weights[i*global.stride]);
float decay = decayfunc3(example_t, weights_for_w[global.lda], power_t);
float decay = decayfunc4(example_t, weights_for_w[global.lda], power_t);
for (size_t k = 0; k < global.lda; k++)
weights_for_w[k] *= decay;
}
Expand Down
10 changes: 6 additions & 4 deletions parse_regressor.cc
Expand Up @@ -36,19 +36,21 @@ void initialize_regressor(regressor &r)
r.weight_vectors[i][j] = global.initial_weight;
if (global.random_weights)
for (size_t j = 0; j < length/num_threads; j++) {
r.weight_vectors[i][j] = -log(drand48());
r.weight_vectors[i][j] *= r.weight_vectors[i][j];
r.weight_vectors[i][j] *= r.weight_vectors[i][j];
r.weight_vectors[i][j] = drand48() - 0.5;
}
if (global.lda)
{
size_t stride = global.stride;

for (size_t j = 0; j < stride*length/num_threads; j+=stride)
{
for (size_t k = 0; k < global.lda; k++)
for (size_t k = 0; k < global.lda; k++) {
r.weight_vectors[i][j+k] = -log(drand48());
r.weight_vectors[i][j+k] *= r.weight_vectors[i][j+k];
r.weight_vectors[i][j+k] *= r.weight_vectors[i][j+k];
r.weight_vectors[i][j+k] *= (float)global.lda_D / (float)global.lda
/ global.length() * 200;
}
r.weight_vectors[i][j+global.lda] = global.initial_t;
}
}
Expand Down

0 comments on commit 2281fee

Please sign in to comment.