Skip to content
Fetching contributors…
Cannot retrieve contributors at this time
executable file 193 lines (163 sloc) 5.06 KB
#!/usr/bin/perl -w
#
# Generate & display convergence charts from vw progress outputs
# Requires R to generate the charts
#
use Getopt::Std;
use vars qw ($opt_d $opt_x $opt_y $opt_t
$opt_w $opt_h $opt_q $opt_Q
$opt_o
);
# Default output image, used when -o is not given; get_args() removes any
# existing copy of it before a new chart is generated.
my $TmpImgFile = '/tmp/vw-convergence.png';
# Loss label used in chart texts: 'loss', or '%loss' under -Q.
# Assigned by get_args().
my $LossStr;
#
# Print an optional error message, then die with the usage text.
#
# Arguments: any strings given are printed to STDERR (followed by a
#            newline) ahead of the usage text.
# Never returns: always terminates through die().
#
sub usage(@) {
    print STDERR @_, "\n" if (@_);
    # The option list must stay in sync with the getopts() spec in get_args().
    die "Usage: $0 [options] [input_files...]
If input_file is missing, input is read from stdin.
Assuming inputs are one or more 'vw' progress reports
Requires R to generate the chart
Options:
-d Don't display the chart, only report the image file name
-q Convert from squared-loss X to abs-loss (apply sqrt(X))
-Q Convert from squared-loss X to (exp(sqrt(X)) - 1.0) * 100
-o<IMG> Output chart to <IMG> file
-x<XL> Use <XL> as X-axis label in chart
-y<YL> Use <YL> as Y-axis label in chart
-t<T> Use <T> as title of chart
-w<W> set image width in pixels (default 800)
-h<H> set image height in pixels (default 600)
-o <image_file> is optional, if not given, will create a temp-file:
$TmpImgFile
and display it.
";
}
#
# App-specific conversion of a squared log-error into a percentage:
#     100 * (e^sqrt(x) - 1)
# Used for the -Q option. An undefined argument is not special-cased.
#
sub sqlosslog2pct($) {
    my ($sq_log_err) = @_;
    my $ratio = exp(sqrt($sq_log_err));
    return 100 * ($ratio - 1.0);
}
#
# Convert one raw loss value according to the -Q / -q options.
# -Q (percent conversion) takes precedence over -q (square root);
# with neither option set the value is returned unchanged.
#
sub transform_loss($) {
    my ($raw) = @_;
    if ($opt_Q) {
        return sqlosslog2pct($raw);
    }
    elsif ($opt_q) {
        return sqrt($raw);
    }
    return $raw;
}
#
# Read vw progress output (one or more reports, possibly concatenated)
# from the files named on the command line, or from STDIN.
#
# Returns a list of array refs, one per report; each holds the
# transformed average-loss values of that report in input order.
# Calls usage() (which dies) when no loss values were found at all.
#
sub average_loss_arrays() {
    my @reports = ();    # finished series, one array ref per report
    my @series  = ();    # series currently being collected

    while (my $line = <>) {
        if ($line =~ /^([0-9.]+)/) {
            # Progress line: the leading number is taken as the loss.
            push(@series, transform_loss($1));
        }
        elsif ($line =~ /^average loss\s*=\s*([0-9.]+)/) {
            # Summary line: record the final value and close this series.
            push(@series, transform_loss($1));
            push(@reports, [ @series ]);
            @series = ();
        }
    }

    # Input ended without a summary line: keep what was collected so far.
    push(@reports, [ @series ]) if (@series);

    usage("Couldn't identify 'vw' progress report(s) in input")
        unless (@reports);

    return @reports;
}
#
# Generate an R script that plots every loss series and pipe it into R.
#
# Arguments:
#   $loss_vec_ref - reference to a list of array refs, each one a series
#                   of loss values to draw as one line
#   $imgfile      - image file to create; its extension selects the R
#                   graphics device: .ps/.eps -> postscript,
#                   .jpg/.jpeg -> jpeg, anything else -> PNG
#
# Reads the globals $opt_w, $opt_h, $opt_x, $opt_y, $opt_t and $LossStr.
# Dies if R cannot be started; warns if the pipe to R does not close
# cleanly.
#
sub do_plot($;$) {
    my ($loss_vec_ref, $imgfile) = @_;

    # User-supplied texts end up inside single-quoted R string literals:
    # escape backslashes and single quotes so that e.g. a title holding
    # an apostrophe cannot break the generated script.
    my $r_escape = sub {
        my ($text) = @_;
        $text = '' unless (defined $text);
        $text =~ s/([\\'])/\\$1/g;
        return $text;
    };
    my ($r_img, $r_xlab, $r_ylab, $r_title) =
        map { $r_escape->($_) } ($imgfile, $opt_x, $opt_y, $opt_t);

    my $set_r_device_line =
        ($imgfile =~ /\.e?ps$/) ?
            "postscript(file='$r_img', width=$opt_w, height=$opt_h)"
      : ($imgfile =~ /\.jpe?g$/) ?
            "jpeg(file='$r_img', width=$opt_w, height=$opt_h)"
      : # default: if you don't have 'Cairo', just use 'png' instead
            "library(graphics); autoload('CairoPNG', 'Cairo');\n" .
            "if (exists('Cairo', mode='function')) {\n" .
            "CairoPNG(file='$r_img', width=$opt_w, height=$opt_h)\n" .
            "} else {\n" .
            "png(file='$r_img', width=$opt_w, height=$opt_h)\n" .
            "}"
      ;

    my $R_input = "$set_r_device_line;\n";
    my $lineno = 0;
    my @colors = (4, 2, 1, 3, 5, 6, 7);
    my $ncols = @colors;
    my @colorlist = ();
    my @namelist = ();    # legend labels: final value of each series
    my @pchs = ();

    foreach my $lossref (@$loss_vec_ref) {
        my $R_losses_array = sprintf('c(%s)', join(',', @$lossref));
        # Colors are reused cyclically when there are more series than colors.
        my $color = $colors[$lineno % $ncols];
        $R_input .= "loss = $R_losses_array;\n";
        # The first series creates the plot; later ones are drawn onto it.
        $R_input .=
            ($lineno == 0) ?
                "plot(loss, pch=20, t='o', col=$color, lwd=2, cex=1.25,\n" .
                "\tlab=c(10,20,7), xlab='$r_xlab', ylab='$r_ylab',\n" .
                "\tmain='$r_title', panel.first=grid(col='gray63'));\n"
            :
                "lines(loss, pch=20, t='o', col=$color, lwd=2, cex=1.25);\n";
        $lineno++;
        push(@colorlist, $color);
        push(@namelist, sprintf("%.4f", $lossref->[-1]));
        push(@pchs, 20);
    }

    $R_input .= sprintf(
        "legend('topright', legend=c(%s), col=c(%s), pch=c(%s), %s);\n",
        join(',', @namelist),
        join(',', @colorlist),
        join(',', @pchs),
        "inset=0.01, title='final mean $LossStr', pt.cex=2, lwd=2.5"
    );

    # List-form pipe open: no shell involved, and a missing R binary is
    # reported here instead of surfacing later as a missing image file.
    open(my $r_pipe, '|-', 'R', '--no-save', '--no-init-file')
        or die "$0: cannot run R: $!\n";
    print {$r_pipe} $R_input;
    unless (close($r_pipe)) {
        # For a pipe, close() fails on I/O errors ($! set) or when the
        # child exited with a non-zero status ($? set).
        warn "$0: R did not finish cleanly ("
            . ($! ? "error: $!" : 'exit status: ' . ($? >> 8)) . ")\n";
    }
}
#
# Parse the command line options into the $opt_* globals and fill in
# defaults for everything that was not given.
#
# Side effects:
#   - strips the directory part from $0 (used in messages)
#   - sets $LossStr
#   - removes an existing default temp image ($TmpImgFile)
# Dies via usage() on an invalid option, and dies when the file named by
# -o already exists (it is never overwritten).
#
sub get_args() {
    $0 =~ s{.*/}{};

    # getopts() returns false on an unknown option or a missing option
    # argument; stop with the usage text instead of silently continuing.
    getopts('dx:y:t:w:h:qQo:')
        or usage();

    $LossStr = ($opt_Q) ? '%loss' : 'loss';
    $opt_x = 'vw progress iteration' unless (defined $opt_x);
    $opt_y = "mean $LossStr"
        unless (defined $opt_y);
    # The default title is derived from the (possibly user-given) Y label.
    $opt_t = "online training $opt_y convergence" unless (defined $opt_t);
    $opt_w = 800 unless (defined $opt_w);
    $opt_h = 600 unless (defined $opt_h);
    $opt_o = $TmpImgFile
        unless (defined $opt_o);

    # The temp image is always regenerated; a user-named image is protected.
    unlink($TmpImgFile) if (-e $TmpImgFile);
    if (-e $opt_o && $opt_o ne $TmpImgFile) {
        die "$0: image file '$opt_o' already exists: avoiding overwrite\n";
    }
}
#
# Show the generated image with the external 'display' program, or, when
# -d was given, only report the image file name on STDERR.
#
# Argument: $imgfile - path of the image file; must exist.
# Returns:  nothing under -d, otherwise the status returned by system().
# Dies when the image file does not exist.
#
sub display($) {
    my $imgfile = $_[0];
    die "$0: display($imgfile): $! - must be a bug\n"
        unless (-e $imgfile);

    if ($opt_d) {
        # User requested no display,
        # Be friendly, give a hint that everything is OK
        printf STDERR "image file is: %s\n", $imgfile;
        return;
    }

    my @cmd = ('display');

    # NOTE(review): $opt_i is not part of the getopts() spec in this file,
    # so it is normally undefined; presumably extra 'display' options -
    # confirm before relying on it.
    push(@cmd, split(' ', $opt_i)) if (defined $opt_i && length $opt_i);

    # Postscript files are generated in portrait: need 90-degrees
    # rotation. Same extension test as the one selecting the postscript
    # device in do_plot().
    push(@cmd, '-rotate', '90') if ($imgfile =~ /\.e?ps$/);
    push(@cmd, $imgfile);

    # List form: no shell, so file names with spaces or shell
    # metacharacters are passed through unharmed.
    my $status = system(@cmd);
    if ($status == -1) {
        warn "$0: cannot run 'display': $!\n";
    }
    elsif ($status != 0) {
        warn "$0: 'display' exited with status " . ($status >> 8) . "\n";
    }
    return $status;
}
# -- main
# Parse options, collect the loss series from the input, render the
# chart through R, then show (or just report) the resulting image.
get_args();
my @loss_series = average_loss_arrays();
my $chart_file  = $opt_o;
do_plot(\@loss_series, $chart_file);
display($chart_file);
Jump to Line
Something went wrong with that request. Please try again.