In [4]:
from datetime import date
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import pandas as pd
import time

from commons import Commons


class QualityCheck(Commons):
    """Interactive image review for the merged records of one job/marker.

    ``check_quality`` walks through every species in the merged-records TSV and asks
    the user to approve images until each species has a fixed number of validation
    ('val') records. ``add_train`` afterwards labels every remaining image-bearing
    record as 'train'. Both read and rewrite
    ``{dir_merged_recs}/{job_id}_{marker}.tsv``.

    NOTE(review): ``dir_merged_recs`` and ``job_id`` are not assigned in this class;
    presumably ``Commons.__init__`` provides them — confirm.
    """

    def __init__(self, project_dir, job_id, num_processes, marker):
        super().__init__(project_dir, job_id, num_processes)

        # Marker name; selects which merged-records TSV this instance works on.
        self.marker = marker

    def _records_path(self):
        """Return the path of the merged-records TSV for this job and marker."""
        return f"{self.dir_merged_recs}/{self.job_id}_{self.marker}.tsv"

    def check_quality(self, n_val_per_species=2):
        """Load the records TSV and run the interactive validation-set review on it.

        Args:
            n_val_per_species: number of approved images to collect per species.

        Raises:
            FileNotFoundError: if the records TSV does not exist.
        """
        records = pd.read_csv(self._records_path(), header=0, sep='\t')
        self.check_dataset(records, n_val_per_species=n_val_per_species)

    def check_dataset(self, records, n_val_per_species=2):
        """Review images species by species until each has enough approved ones.

        An approved record gets ``dataset='val'`` and ``qualified=<today>``. A rejected
        record has its image fields cleared (``image_url``/``image_path`` -> None,
        ``downloaded``/``duplicate`` -> False), so it is never offered again.
        ``records`` is modified in place. The TSV is rewritten after each reviewed
        species, so an interrupted session resumes where it stopped when
        ``check_quality`` is called again.

        Args:
            records: merged records. Must contain ``species_name``, ``record_id``,
                ``image_url``, ``image_path``, ``downloaded`` and ``duplicate``.
            n_val_per_species: number of approved images to collect per species.
        """
        # Both bookkeeping columns must exist and hold mixed values (str / date / None).
        # After a CSV round trip an all-empty column is read back as float64, and
        # writing a string or a date into a float column is deprecated in pandas.
        for col in ('qualified', 'dataset'):
            if col not in records.columns:
                records[col] = None
            records[col] = records[col].astype(object)

        # NaN names never match `==` below and can break sorting on mixed types.
        species_list = sorted(records['species_name'].dropna().unique())
        species_total = len(species_list)

        for i, species in enumerate(species_list, start=1):
            print(f"Species {i} of {species_total}")

            is_species = records['species_name'] == species
            n_approved = int((is_species & records['qualified'].notnull()).sum())
            if n_approved >= n_val_per_species:
                continue  # already completed in an earlier session

            for _, row in records[is_species].iterrows():
                if n_approved >= n_val_per_species:
                    break
                # Skip records with no image or ones that were already approved.
                if pd.isnull(row['image_url']) or pd.notnull(row['qualified']):
                    continue
                if pd.isnull(row['image_path']):
                    # Nothing on disk to show; mpimg.imread would fail on a null path.
                    print(f"Skipping record {row['record_id']}: no local image path")
                    continue

                record_mask = records['record_id'] == row['record_id']
                if self.check_image(row):
                    records.loc[record_mask, ['dataset', 'qualified']] = 'val', date.today()
                    n_approved += 1
                else:
                    # Clear the image fields so this record is never offered again.
                    records.loc[record_mask, ['image_url', 'image_path', 'downloaded', 'duplicate']] = None, None, False, False

            # Checkpoint after every reviewed species so progress survives interruption.
            records.to_csv(self._records_path(), header=True, index=False, sep='\t')

    def check_image(self, row):
        """Display one record's image and ask the user whether to keep it.

        Args:
            row: one record (pandas Series) with ``species_name``, ``image_path``,
                ``image_url`` and ``record_id`` fields.

        Returns:
            True if the user answered 'y' or 'yes' (case-insensitive, surrounding
            whitespace ignored), otherwise False.
        """
        image_path = row['image_path']
        print(f"Species: {row['species_name']}")
        print(f"Image path: {image_path}")
        print(f"Image url: {row['image_url']}")
        print(f"Record id: {row['record_id']}")

        fig, ax = plt.subplots()
        ax.imshow(mpimg.imread(image_path))
        ax.set_title(str(row['species_name']))
        ax.axis('off')
        plt.show()
        plt.close(fig)  # don't accumulate open figures over a long review session
        # Give the notebook frontend time to render the figure before input() blocks.
        time.sleep(1.5)

        answer = input("Quality image? Type 'y' to use this picture; anything else (or nothing) rejects it.\n")
        if answer.strip().lower() in ('y', 'yes'):
            return True
        print(f"Made a mistake? Here's the record id: {row['record_id']}")
        return False

    def add_train(self):
        """Drop records without an image and label every unassigned record 'train'.

        Rewrites the records TSV. Records that already have a dataset label (for
        example 'val' from ``check_dataset``) keep it.
        """
        records = pd.read_csv(self._records_path(), header=0, sep='\t')
        records = records.loc[records['image_url'].notnull(), :].copy()
        if 'dataset' not in records.columns:
            records['dataset'] = None
        # Object dtype so that writing strings into an all-NaN column is well defined.
        records['dataset'] = records['dataset'].astype(object)
        records.loc[records['dataset'].isnull(), 'dataset'] = 'train'
        records.to_csv(self._records_path(), sep='\t', header=True, index=False)

In [None]:
# Configuration: fill these in before running.
NUM_PROCESSES = 4  # passed through to Commons.__init__
JOB_ID = ''        # job whose merged records are reviewed
PROJECT_DIR = ''   # project directory passed through to Commons.__init__
MARKER = ''        # marker name; selects the <JOB_ID>_<MARKER>.tsv records file

# Fail fast with a clear message instead of an obscure file-not-found error later.
_missing_config = [
    name
    for name, value in (('JOB_ID', JOB_ID), ('PROJECT_DIR', PROJECT_DIR), ('MARKER', MARKER))
    if not value
]
if _missing_config:
    raise ValueError(f"Set {', '.join(_missing_config)} before running the quality check.")

qc = QualityCheck(PROJECT_DIR, JOB_ID, NUM_PROCESSES, MARKER)

# Interactive: shows each candidate image and waits for keyboard input.
qc.check_quality()

In [6]:
# Final step: drop records without an image and mark every record not chosen for
# validation as 'train', rewriting the TSV. Uses `qc` from the cell above.
qc.add_train();