Skip to content

Commit

Permalink
Move translate chunks step to where it is needed
Browse files Browse the repository at this point in the history
  • Loading branch information
Kalior committed Aug 31, 2018
1 parent 7827455 commit 5c68c08
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 3 additions & 1 deletion action_recognition/classifiers/classification_visualiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging

from ..analysis import ChunkVisualiser
from ..transforms import TranslateChunks


class ClassificationVisualiser:
Expand Down Expand Up @@ -43,7 +44,7 @@ def plot_confusion_matrix(self, labels, test_labels, class_names, title):
plt.savefig(title + '.png', bbox_inches='tight')
plt.show(block=False)

def visualise_incorrect_classifications(self, pred_labels, test_labels, chunks, frames, translated_chunks, videos):
def visualise_incorrect_classifications(self, pred_labels, test_labels, chunks, frames, videos):
"""Displays videos of the incorrect classifications.
Can help with identifying features that can be added to the
Expand All @@ -63,6 +64,7 @@ def visualise_incorrect_classifications(self, pred_labels, test_labels, chunks,
The paths to the corresponding videos of the chunks.
"""
translated_chunks = TranslateChunks().transform(chunks)
visualiser = ChunkVisualiser(chunks, frames, translated_chunks)
unique_labels = set(pred_labels)
for pred_label in unique_labels:
Expand Down
3 changes: 1 addition & 2 deletions train_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def main(args):
def train_classifier(train, test, title, classifier, visualise_incorrect_classifications):
train_chunks, _, train_labels, _ = train
test_chunks, test_frames, test_labels, test_videos = test
test_translated_chunks = transforms.TranslateChunks().transform(test_chunks)

logging.info("Fitting classifier.")
classifier.fit(train_chunks, train_labels)
Expand All @@ -63,7 +62,7 @@ def train_classifier(train, test, title, classifier, visualise_incorrect_classif

if visualise_incorrect_classifications:
visualiser.visualise_incorrect_classifications(
pred_labels, test_labels, test_chunks, test_frames, test_translated_chunks, test_videos)
pred_labels, test_labels, test_chunks, test_frames, test_videos)

file_name = "{}.pkl".format(title)
logging.info("Saving model to {}.".format(file_name))
Expand Down

0 comments on commit 5c68c08

Please sign in to comment.