# Image classification on CINIC-10

Dataset source: [https://github.com/BayesWatch/cinic-10](https://github.com/BayesWatch/cinic-10)

License: [MIT](https://github.com/BayesWatch/cinic-10/blob/master/LICENSE)

### Call to action
If you find a bug or have an idea for a new feature, don't hesitate to [open a new issue on GitHub](https://github.com/Toloka/toloka-kit/issues/new/choose).
Like our library and examples? Star [our repo on GitHub](https://github.com/Toloka/toloka-kit).

## Install dependencies and import

In [None]:
%%capture
!pip install toloka-kit==0.1.26
!pip install crowd-kit==1.0.0
!pip install ipyplot # display images

In [59]:
import datetime
import time
import pandas as pd
import numpy as np
import ipyplot
from sklearn.metrics import balanced_accuracy_score
import os
import logging
import sys
import getpass

import toloka.client as toloka
import toloka.client.project.template_builder as tb

from crowdkit.aggregation import DawidSkene
%matplotlib inline

In [60]:
# Send log records to stdout so that library messages (e.g. the
# "[INFO] toloka.client: ..." lines) show up inline in the notebook output.
logging.basicConfig(
    format='[%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

# Load the dataset

In [61]:
# Number of images to sample from the dataset's test.csv for labeling.
N_ROWS = 1000

In [62]:
def sample_stratified(df, label_column, n_rows, random_state=None):
    """Sample about ``n_rows`` rows from ``df`` while preserving class distribution.

    Every class contributes ``round(n_rows * class_size / len(df))`` rows, so
    the total can differ from ``n_rows`` by a few rows because of rounding.

    Args:
        df: Source dataframe.
        label_column: Name of the column holding the class label.
        n_rows: Desired size of the sample.
        random_state: Optional seed forwarded to ``DataFrame.sample`` to make
            the sample reproducible. ``None`` keeps the previous behaviour.

    Returns:
        A shuffled dataframe with a fresh ``0..n-1`` index and all original
        columns, including ``label_column``.
    """
    total_rows = len(df)
    # Iterate over the groups explicitly instead of using ``groupby.apply``:
    # newer pandas versions stop passing the grouping column to the applied
    # function, which would silently drop ``label_column`` from the result.
    per_class = [
        group.sample(
            int(np.rint(n_rows * len(group) / total_rows)),
            random_state=random_state,
        )
        for _, group in df.groupby(label_column)
    ]
    if not per_class:
        # Nothing to sample from (empty input): return an empty frame of the same shape.
        return df.iloc[0:0].reset_index(drop=True)
    return (
        pd.concat(per_class)
        .sample(frac=1, random_state=random_state)
        .reset_index(drop=True)
    )

# URLs are joined with an explicit '/' separator: os.path.join uses the OS
# path separator and would produce backslashes (broken links) on Windows.
base_url = 'https://tlk.s3.yandex.net/ext_dataset/CINIC-10'
df = pd.read_csv(f'{base_url}/test.csv')
df['img_url'] = base_url + '/' + df['img_path']

# Keep a class-balanced subset of N_ROWS images for labeling.
df = sample_stratified(df, 'label', n_rows=N_ROWS)
df.head()

Unnamed: 0,img_path,label,img_url
0,test/ship/n03545470_11883.png,ship,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...
1,test/dog/n02111889_9460.png,dog,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...
2,test/frog/n01645776_9028.png,frog,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...
3,test/frog/n01639765_21028.png,frog,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...
4,test/dog/cifar10-train-45915.png,dog,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...


In [63]:
# Visual sanity check of the sample: show images next to their class labels.
ipyplot.plot_class_representations(images=df.img_url, labels=df.label, img_width=70)

# Set up the project

In [64]:
# The token is requested interactively through getpass, so it is not echoed
# and never ends up stored in the notebook.
toloka_client = toloka.TolokaClient(getpass.getpass('Enter your OAuth token: '), 'PRODUCTION') # Or switch to 'SANDBOX'

## Create project

In [65]:
# Local project draft; it is sent to Toloka later with create_project.
# The training and the main pool created below both refer to this project.
project = toloka.Project(
    public_name='Small images classification',
    public_description='Classify small images into 10 categories',
    private_comment='OOTB: CINIC-10'
)

In [66]:
# Task schema: one image URL goes in ('image'), one string label comes out ('result').
input_specification = {'image': toloka.project.UrlSpec()}
output_specification = {'result': toloka.project.StringSpec()}

In [67]:
# Class names offered to annotators; used as answer values and, capitalized,
# as button captions in the interface below.
CINIC_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
len(CINIC_LABELS)

10

## Annotator interface

In [68]:
# --- Widgets ---------------------------------------------------------------
# Square preview of the image referenced by the task input, popup disabled.
image_viewer = tb.ImageViewV1(
    tb.InputData('image'),
    popup=False,
    ratio=[1, 1],
)

# One radio button per class: the stored value is the raw class name,
# the caption is its capitalized form. An answer is mandatory.
label_buttons = [
    tb.GroupFieldOption(class_name, class_name.capitalize())
    for class_name in CINIC_LABELS
]
radio_group_field = tb.ButtonRadioGroupFieldV1(
    tb.OutputData('result'),
    label_buttons,
    validation=tb.RequiredConditionV1(),
)

# --- Plugins ---------------------------------------------------------------
# Hotkeys follow the order of CINIC_LABELS: keys 1..9 pick the first nine
# classes and key 0 picks the tenth one.
hot_keys_plugin = tb.HotkeysPluginV1(**{
    f'key_{position % 10}': tb.SetActionV1(tb.OutputData('result'), class_name)
    for position, class_name in enumerate(CINIC_LABELS, start=1)
})
task_width_plugin = tb.TolokaPluginV1('scroll', task_width=300)

# --- Assemble the task specification ---------------------------------------
project_interface = toloka.project.TemplateBuilderViewSpec(
    config=tb.TemplateBuilder(
        view=tb.ListViewV1([image_viewer, radio_group_field]),
        plugins=[task_width_plugin, hot_keys_plugin],
    )
)
project.task_spec = toloka.project.task_spec.TaskSpec(
    input_spec=input_specification,
    output_spec=output_specification,
    view_spec=project_interface,
)

In [69]:
# Instructions shown to annotators (HTML). The surrounding whitespace of the
# triple-quoted literal is stripped before it is assigned to the project.
project.public_instructions = """
In this task, you will see images from 10 different classes.<br/>
Your task is to classify these images.<br/>

<b>Some images are blurry and hard to label</b>. That's the nature of the task, so just assign whatever label seems most appropriate.

How to complete the task:
<ul>
<li>Look at the picture.</li>
<li>Click on the image to resize it. You can rotate the image if it's in the wrong orientation.</li>
<li>Choose one of the possible answers. If the picture is unavailable or you have any other technical difficulty, please write us about it.</li>
<li>If you think that you can not classify the image correctly, choose the most appropriate label in your opinion.</li>
<li>You can use keyboard shortcuts (numbers from 1 to 0) to pick labels.</li>
</ul>
""".strip()

In [70]:
# Send the draft to Toloka; the returned object (with the assigned ID, see the
# log line) replaces the local draft.
project = toloka_client.create_project(project)

[INFO] toloka.client: A new project with ID "60772" has been created. Link to open in web interface: https://toloka.dev/requester/project/60772


## Create training tasks

In [71]:
# Training attached to the project; the main pool later requires it through
# its training requirement.
training_pool = toloka.training.Training(
    project_id=project.id,
    private_name='Training pool',
    inherited_instructions=True,
    may_contain_adult_content=False,
    # One suite of ten training tasks has to be completed.
    training_tasks_in_task_suite_count=10,
    task_suites_required_to_pass=1,
    # Five minutes per assignment.
    assignment_max_duration_seconds=5 * 60,
    retry_training_after_days=1,
    mix_tasks_in_creation_order=True,
    shuffle_tasks_in_task_suite=True,
)

In [72]:
# Create the training in Toloka and keep the returned object, whose ID is used below.
training_pool = toloka_client.create_training(training_pool)

[INFO] toloka.client: A new training with ID "28225833" has been created. Link to open in web interface: https://toloka.dev/requester/project/60772/training/28225833


In [None]:
# One reference image per class: the first sampled row carrying that label.
label_examples = {
    class_name: df.loc[df.label == class_name, 'img_url'].iloc[:1].item()
    for class_name in CINIC_LABELS
}

# Each training task carries its known answer and the hint that is attached
# to the task as message_on_unknown_solution.
tasks = [
    toloka.Task(
        pool_id=training_pool.id,
        input_values={'image': example_url},
        known_solutions=[
            toloka.task.BaseTask.KnownSolution(output_values={'result': class_name}),
        ],
        message_on_unknown_solution=f'Incorrect label! The actual label is: {class_name}',
        infinite_overlap=True,
    )
    for class_name, example_url in label_examples.items()
]
toloka_client.create_tasks(tasks, allow_defaults=True)

## Create task pool

In [74]:
# Main labeling pool: paid per assignment, five minutes per assignment,
# expiring one year from now.
pool = toloka.Pool(
    project_id=project.id,
    private_name='Pool',
    reward_per_assignment=0.01,
    may_contain_adult_content=False,
    assignment_max_duration_seconds=5 * 60,
    will_expire=datetime.datetime.utcnow() + datetime.timedelta(days=365),
)

# Only performers with English in their language filter are admitted.
pool.filter = toloka.filter.Languages.in_('EN')

# Ten real tasks are mixed into each task suite.
pool.set_mixer_config(real_tasks_count=10)

# New tasks get an overlap of 5 by default; the aggregation and the
# majority-vote rule below rely on several answers per image.
pool.defaults = toloka.pool.Pool.Defaults(
    default_overlap_for_new_task_suites=0,
    default_overlap_for_new_tasks=5,
)

In [75]:
def _one_day_ban(scope, reason):
    """Build a restriction lasting one day within ``scope``, annotated with ``reason``."""
    return toloka.actions.RestrictionV2(
        scope=scope,
        duration=1,
        duration_unit='DAYS',
        private_comment=reason,
    )


# The training created earlier must be passed with a skill value of 30.
pool.quality_control.training_requirement = toloka.quality_control.QualityControl.TrainingRequirement(
    training_passing_skill_value=30,
    training_pool_id=training_pool.id,
)

# Rule 1: disagreement with the majority vote.
pool.quality_control.add_action(
    collector=toloka.collectors.MajorityVote(history_size=5, answer_threshold=4),
    conditions=[
        toloka.conditions.TotalAnswersCount >= 5,
        toloka.conditions.IncorrectAnswersRate > 30,
    ],
    action=_one_day_ban('PROJECT', 'Wrong on over 30% cases'),
)

# Rule 2: assignments submitted in under 15 seconds.
pool.quality_control.add_action(
    collector=toloka.collectors.AssignmentSubmitTime(
        fast_submit_threshold_seconds=15,
        history_size=5,
    ),
    conditions=[
        toloka.conditions.TotalSubmittedCount >= 5,
        toloka.conditions.FastSubmittedCount >= 3,
    ],
    action=_one_day_ban('PROJECT', 'Answering too fast'),
)

# Rule 3: three or more assignments skipped in a row.
pool.quality_control.add_action(
    collector=toloka.collectors.SkippedInRowAssignments(),
    conditions=[toloka.conditions.SkippedInRowCount >= 3],
    action=_one_day_ban(
        toloka.user_restriction.UserRestriction.PROJECT,
        'Lazy performer',
    ),
)



In [76]:
# Create the pool in Toloka; the returned object carries the assigned pool ID.
pool = toloka_client.create_pool(pool)

[INFO] toloka.client: A new pool with ID "28225839" has been created. Link to open in web interface: https://toloka.dev/requester/project/60772/pool/28225839


## Create tasks from dataset

In [None]:
# One task per sampled image, all placed in the main pool.
tasks = [
    toloka.Task(pool_id=pool.id, input_values={'image': image_url})
    for image_url in df.img_url
]
toloka_client.create_tasks(tasks, allow_defaults=True)

# Start annotation

In [78]:
# Start the training with the training-specific client call, so that
# `training_pool` stays a training object instead of being rebound to the
# result of a pool operation. The main pool is opened with open_pool.
training_pool = toloka_client.open_training(training_pool.id)
pool = toloka_client.open_pool(pool.id)

In [None]:
# Keep the pool ID in its own variable; it is reused for waiting and for
# downloading the assignments.
pool_id = pool.id

def wait_pool_for_close(pool_id, minutes_to_wait=0.5):
    """Block until the pool is closed, printing its completion percentage.

    Args:
        pool_id: ID of the pool to watch.
        minutes_to_wait: Pause between two consecutive status checks, in minutes.
    """
    pause_seconds = 60 * minutes_to_wait
    while True:
        current = toloka_client.get_pool(pool_id)
        if current.is_closed():
            break
        # Ask Toloka for the completion percentage and wait for the
        # analytics operation to finish before reading its result.
        request = toloka.analytics_request.CompletionPercentagePoolAnalytics(subject_id=current.id)
        operation = toloka_client.wait_operation(toloka_client.get_analytics([request]))
        percentage = operation.details['value'][0]['result']['value']
        print(
            f'   {datetime.datetime.now().strftime("%H:%M:%S")}\t'
            f'Pool {current.id} - {percentage}%'
        )
        time.sleep(pause_seconds)
    print('Pool was closed.')

# Blocks until the main pool is closed, printing progress after every check.
wait_pool_for_close(pool_id)

In [80]:
# Labeling is finished, so stop the training as well. Use the
# training-specific client call rather than the pool one.
training_pool = toloka_client.close_training(training_pool.id)

# Extract results

In [81]:
# Download the assignments of the main pool and rename the columns to the
# task / label / worker names used by the aggregation step.
answers_df = toloka_client.get_assignments_df(pool_id).rename(
    columns={
        'ASSIGNMENT:worker_id': 'worker',
        'INPUT:image': 'task',
        'OUTPUT:result': 'label',
    }
)



# Aggregate results

In [82]:
# Combine the overlapping answers into one predicted label per task with the
# Dawid-Skene aggregator (100 iterations).
aggregated_answers = DawidSkene(n_iter=100).fit_predict(answers_df)

In [83]:
# Turn the aggregation result into a two-column frame (image URL, predicted
# label) and attach the ground-truth rows from the dataset by image URL.
aggregated_answers = (
    aggregated_answers
    .reset_index()
    .set_axis(['img_url', 'pred_label'], axis='columns')
    .merge(df, on='img_url')
)
aggregated_answers.head()

Unnamed: 0,img_url,pred_label,img_path,label
0,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...,airplane,test/airplane/n02686121_2744.png,airplane
1,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...,deer,test/deer/cifar10-test-1720.png,deer
2,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...,deer,test/deer/n02419796_6293.png,deer
3,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...,cat,test/cat/n02125081_7666.png,cat
4,https://tlk.s3.yandex.net/ext_dataset/CINIC-10...,deer,test/deer/cifar10-test-4957.png,deer


# View results

In [84]:
# Show ten random images captioned with their true and predicted labels.
sample = aggregated_answers.sample(n=10)
captions = [
    f'True: {row.label}\nPred: {row.pred_label}'
    for row in sample[['label', 'pred_label']].itertuples(index=False)
]

ipyplot.plot_images(
    images=sample.img_url.values,
    labels=captions,
    img_width=100,
    max_images=10,
)

# View mistakes

In [85]:
# Rows where the aggregated label disagrees with the ground truth.
wrong_answers = aggregated_answers[aggregated_answers.pred_label != aggregated_answers.label]

if wrong_answers.empty:
    # Nothing to plot when every aggregated label is correct.
    print('No mistakes to show.')
else:
    # Sample exactly as many rows as will be displayed, and never more than
    # exist: DataFrame.sample raises when asked for more rows than available.
    sample = wrong_answers.sample(min(10, len(wrong_answers)))
    captions = [f'True: {row.label}\nPred: {row.pred_label}' for row in sample.itertuples()]

    ipyplot.plot_images(
        images=sample.img_url.values,
        labels=captions,
        max_images=10,
        img_width=100,
    )

# Obtain accuracy

In [86]:
# Note: the reported "accuracy" is scikit-learn's balanced accuracy score of
# the aggregated labels against the ground truth.
accuracy = balanced_accuracy_score(
    y_true=aggregated_answers['label'],
    y_pred=aggregated_answers['pred_label'],
)
print(f'Accuracy: {accuracy:.2f}')
print(f'Error: {1-accuracy:.2f}')

Accuracy: 0.88
Error: 0.12
