Skip to content

Commit

Permalink
Merge pull request #5 from SimulatedANeal/stats/refactor
Browse files Browse the repository at this point in the history
Stats/refactor
  • Loading branch information
SimulatedANeal committed May 24, 2018
2 parents cd5d174 + a9d0142 commit 71f9753
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 97 deletions.
22 changes: 15 additions & 7 deletions carpedm/data/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def code2hex(code):
"""Returns hex integer for a unicode string.
The argument code could either be an ascii representation,
e.g. U+3055, or a unicode character.
(e.g. U+3055, <UNK>) or a unicode character.
Args:
code (str): Code to convert.
Expand All @@ -29,15 +29,14 @@ def code2hex(code):
_ = code.encode('ascii')
except UnicodeEncodeError:
code = char2code(code)

if 'U+' in code:
code = code.lstrip('U+')
# Code is either 'U+XXXX(X)' code or unknown format.
code = code.lstrip('U+') if 'U+' in code else code

try:
result = int(code, 16)
except ValueError:
# Not a number, so probably just a raw ascii character.
result = int(char2code(code).lstrip('U+'), 16)
result = code

return result

Expand All @@ -46,7 +45,7 @@ def code2char(code):
"""Returns the unicode string for the character."""
try:
char = chr(code2hex(code))
except ValueError:
except (ValueError, TypeError):
char = code
return char

Expand All @@ -60,7 +59,11 @@ def char2code(unicode):
Raises:
TypeError: string is length two.
"""
return "U+{0:04x}".format(ord(unicode))
try:
code = "U+{0:04x}".format(ord(unicode))
except TypeError:
code = unicode
return code


class CharacterSet(object):
Expand Down Expand Up @@ -228,6 +231,7 @@ def __init__(self, reserved, vocab):
"""
self._vocab = {}
self._reserved = reserved

for ix, char in enumerate(vocab):
self._vocab[char] = ix
Expand Down Expand Up @@ -278,3 +282,7 @@ def id_to_char(self, char_id):
def get_num_classes(self):
"""Returns number of classes, includes <UNK>."""
return len(self._vocab)

def get_num_reserved(self):
"""Returns number of reserved IDs."""
return len(self._reserved)
123 changes: 33 additions & 90 deletions carpedm/data/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,11 @@
import warnings
from collections import Counter

import numpy as np

from carpedm.data.download import get_books_list
from carpedm.data.providers import TFDataSet
from carpedm.data.io import CSVParser, DataWriter
from carpedm.data.lang import JapaneseUnicodes, Vocabulary, code2char
from carpedm.data.providers import TFDataSet
from carpedm.data.stats import majority, ratio, ClassCounts


DEFAULT_SEED = 123456
Expand All @@ -97,30 +96,30 @@ def _get_split(split):
"""Returns (ratio, heldout)."""
valid_books = get_books_list('pmjtc')
if isinstance(split, float):
ratio = split
assert 0 <= ratio <= 1, "Invalid split {}.".format(split)
frac = split
assert 0 <= frac <= 1, "Invalid split {}.".format(split)
heldout = None
elif isinstance(split, str):
if split in valid_books:
ratio = None
frac = None
heldout = [split]
else:
try:
ratio = float(split)
frac = float(split)
heldout = None
except ValueError:
ratio = None
frac = None
heldout = split.split(',')
for bib in heldout:
assert bib in valid_books, "Invalid ID %s" % bib
else:
assert 0 <= ratio <= 1, "Invalid split {}.".format(split)
assert 0 <= frac <= 1, "Invalid split {}.".format(split)
else:
raise ValueError(
"Invalid split {}. Must be float or string.".format(split)
)

return ratio, heldout
return frac, heldout


def num_examples_per_epoch(data_dir, subset):
Expand Down Expand Up @@ -399,107 +398,51 @@ def view_images(self, subset, shape=None):
def data_stats(self,
which_sets=('train', 'dev', 'test'),
which_stats=('majority', 'frequency', 'unknowns'),
save_figures=False,
id_start=None,
id_stop=None):
save_dir=None, include=(None, None)):
"""Print or show data statistics.
Args:
which_sets (tuple): Data subsets to see statistics for.
which_stats (tuple): Statistics to view. Default gives all
options.
save_figures (bool): Save figures when possible.
id_start (int): lowest character ID to include in
visualizations.
id_stop (int): highest character ID to include in
visualizations.
save_dir (str): If not None, save figures/files to this
directory.
include (tuple): Include class IDs from this range.
"""

def all_labels_flat(sub):
meta = self._image_meta[sub]
return [u for img in meta for u in img.char_labels]
def alf(metadata):
"""Return all labels flat"""
return [u for img in metadata for u in img.char_labels]

if 'frequency' in which_stats:

try:
import matplotlib.pyplot as plt
except ImportError:
warnings.warn("The view_images method is not available."
"Please install matplotlib if you wish to use it.")
return
else:
from carpedm.data.viewer import font

rr = len(self.reserved_tokens)
start = rr if id_start is None else id_start
stop = self.vocab.get_num_classes() if id_stop is None else id_stop
nn = stop - start + 1
colors = ['green', 'blue', 'red']
fig, ax = plt.subplots()
rects = []
bar_width = 1.0 / (len(which_sets) + 1)
# center groups of bars
ind = np.arange(nn)
centers = ind + bar_width * len(which_sets) / 2.
max_count = 0
counts = ClassCounts()
for i in range(len(which_sets)):
chars = [self.vocab.char_to_id(u)
for u in all_labels_flat(which_sets[i])]
chars = [cid for cid in chars if start <= cid <= stop]
counts = Counter(chars)
x = ind + i * bar_width
y = [0] * nn
for cid, count in counts.items():
if count > max_count:
max_count = count
y[cid - start] = count
rects.append(ax.bar(x, y, bar_width, color=colors[i]))
char_list = [self.vocab.id_to_char(i)
for i in range(start, stop + 1)]
if start < rr:
char_list = char_list[:rr - start] + list(
map(code2char, char_list[rr - start:])
)
else:
char_list = map(code2char, char_list)

ax.set_ylabel("Counts")
ax.set_ylim([0, max_count])
ax.set_title("Characters {0}-{1} Relative Frequency".format(
start, stop))
ax.set_xlabel("Characters")
ax.set_xticks(centers)
ax.set_xticklabels(char_list, fontproperties=font(10))
ax.legend([r[0] for r in rects], which_sets)
ax.autoscale()
if save_figures:
plt.savefig("{0}-{1}_frequency.svg".format(start, stop))
plt.show()
name = which_sets[i]
data = self._image_meta[name]
if len(data) > 0:
counts.add_dataset(data=alf(data), label=name)
counts.plot_counts(vocab=self.vocab, include=include,
save_dir=save_dir)

if 'majority' in which_stats:
for primary in which_sets:
chars = all_labels_flat(primary)
total = len(chars)
counts = Counter(chars)
majority, count = counts.most_common(1)[0]
major, count, rr = majority(alf(self._image_meta[primary]))
print("Majority class from {0}: {1} ({2}), {3:.2f}%".format(
primary,
self.vocab.char_to_id(majority),
code2char(majority),
float(count) / total * 100.))
secondary = list(which_sets)
secondary.remove(primary)
self.vocab.char_to_id(major),
code2char(major),
rr * 100.))
secondary = list([s for s in which_sets if not s == primary])
for subset in secondary:
chars = all_labels_flat(subset)
total = len(chars)
count = chars.count(majority)
print("\t% of {0}: {1:.2f}".format(
subset, count / float(total) * 100.))
subset,
ratio(major, alf(self._image_meta[subset])) * 100.)
)

if 'unknowns' in which_stats:
for subset in which_sets:
chars = all_labels_flat(subset)
chars = set(chars)
chars = set(alf(self._image_meta[subset]))
unknowns = [c for c in chars if
self.vocab.id_to_char(self.vocab.char_to_id(c))
== "<UNK>"]
Expand Down

0 comments on commit 71f9753

Please sign in to comment.