# BigQuery ML Semi-supervised Self-training Classification with mnist Dataset

## Imports and project variables

In [3]:
import os
import shutil
from google.cloud import bigquery
import numpy as np
import pandas as pd
import tensorflow as tf
print(tf.__version__)

1.14.0


In [2]:
# Allow you to easily have Python variables in SQL query.
from IPython.core.magic import register_cell_magic
from IPython import get_ipython


@register_cell_magic("with_globals")
def with_globals(line, cell):
    contents = cell.format(**globals())
    if "print" in line:
        print(contents)
    get_ipython().run_cell(contents)

In [4]:
# change these to try this notebook out
# PROJECT = "cloud-training-demos"
# BUCKET = "cloud-training-demos-ml"
PROJECT = "qwiklabs-gcp-8312a1428d9eb5e2"
BUCKET = "qwiklabs-gcp-8312a1428d9eb5e2-bucket"
REGION = "us-central1"

In [5]:
os.environ["PROJECT"] = PROJECT
os.environ["BUCKET"] = BUCKET
os.environ["REGION"] = REGION

## Create data

In [6]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [7]:
print("x_train.shape = {}".format(x_train.shape))
print("y_train.shape = {}".format(y_train.shape))
print("x_test.shape = {}".format(x_test.shape))
print("y_test.shape = {}".format(y_test.shape))

x_train.shape = (60000, 28, 28)
y_train.shape = (60000,)
x_test.shape = (10000, 28, 28)
y_test.shape = (10000,)


In [8]:
x_train_flat = x_train.reshape(
  x_train.shape[0], x_train.shape[1] * x_train.shape[2])
x_train_flat.shape

(60000, 784)

In [9]:
x_test_flat = x_test.reshape(
  x_test.shape[0], x_test.shape[1] * x_test.shape[2])
x_test_flat.shape

(10000, 784)

In [10]:
train = np.concatenate([x_train_flat, np.expand_dims(y_train, -1),
                        np.random.rand(x_train_flat.shape[0], 1)],
                       axis = 1)
train.shape

(60000, 786)

In [11]:
test = np.concatenate([x_test_flat,
                       np.expand_dims(y_test, -1)],
                      axis = 1)
test.shape

(10000, 785)

In [12]:
train_df = pd.DataFrame(
  train,
  columns=["v_" + str(i)
           for i in range(x_train_flat.shape[1])] + ["label", "rand"])
train_df.head()

Unnamed: 0,v_0,v_1,v_2,v_3,v_4,v_5,v_6,v_7,v_8,v_9,...,v_776,v_777,v_778,v_779,v_780,v_781,v_782,v_783,label,rand
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,0.287787
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.284469
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.916785
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.378841
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.0,0.363079


In [13]:
test_df = pd.DataFrame(
  test,
  columns=["v_" + str(i)
           for i in range(x_test_flat.shape[1])] + ["label"])
test_df.head()

Unnamed: 0,v_0,v_1,v_2,v_3,v_4,v_5,v_6,v_7,v_8,v_9,...,v_775,v_776,v_777,v_778,v_779,v_780,v_781,v_782,v_783,label
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0


In [14]:
train_df.describe()

Unnamed: 0,v_0,v_1,v_2,v_3,v_4,v_5,v_6,v_7,v_8,v_9,...,v_776,v_777,v_778,v_779,v_780,v_781,v_782,v_783,label,rand
count,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,...,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0,60000.0
mean,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.000179,7.6e-05,5.9e-05,8e-06,0.0,0.0,0.0,0.0,4.453933,0.499276
std,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.011137,0.006615,0.006582,0.001359,0.0,0.0,0.0,0.0,2.88927,0.288386
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.2e-05
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.249056
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.498355
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0,0.74947
max,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.992157,0.992157,0.996078,0.243137,0.0,0.0,0.0,0.0,9.0,0.999963


In [15]:
test_df.describe()

Unnamed: 0,v_0,v_1,v_2,v_3,v_4,v_5,v_6,v_7,v_8,v_9,...,v_775,v_776,v_777,v_778,v_779,v_780,v_781,v_782,v_783,label
count,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,...,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0
mean,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.000642,0.000206,2e-06,0.0,0.0,0.0,0.0,0.0,0.0,4.4434
std,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.022494,0.00949,0.000235,0.0,0.0,0.0,0.0,0.0,0.0,2.895865
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0
max,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.992157,0.611765,0.023529,0.0,0.0,0.0,0.0,0.0,0.0,9.0


In [16]:
train_df.to_csv("mnist_train.csv", index=False)
test_df.to_csv("mnist_test.csv", index=False)

In [17]:
!head -2 mnist_train.csv

v_0,v_1,v_2,v_3,v_4,v_5,v_6,v_7,v_8,v_9,v_10,v_11,v_12,v_13,v_14,v_15,v_16,v_17,v_18,v_19,v_20,v_21,v_22,v_23,v_24,v_25,v_26,v_27,v_28,v_29,v_30,v_31,v_32,v_33,v_34,v_35,v_36,v_37,v_38,v_39,v_40,v_41,v_42,v_43,v_44,v_45,v_46,v_47,v_48,v_49,v_50,v_51,v_52,v_53,v_54,v_55,v_56,v_57,v_58,v_59,v_60,v_61,v_62,v_63,v_64,v_65,v_66,v_67,v_68,v_69,v_70,v_71,v_72,v_73,v_74,v_75,v_76,v_77,v_78,v_79,v_80,v_81,v_82,v_83,v_84,v_85,v_86,v_87,v_88,v_89,v_90,v_91,v_92,v_93,v_94,v_95,v_96,v_97,v_98,v_99,v_100,v_101,v_102,v_103,v_104,v_105,v_106,v_107,v_108,v_109,v_110,v_111,v_112,v_113,v_114,v_115,v_116,v_117,v_118,v_119,v_120,v_121,v_122,v_123,v_124,v_125,v_126,v_127,v_128,v_129,v_130,v_131,v_132,v_133,v_134,v_135,v_136,v_137,v_138,v_139,v_140,v_141,v_142,v_143,v_144,v_145,v_146,v_147,v_148,v_149,v_150,v_151,v_152,v_153,v_154,v_155,v_156,v_157,v_158,v_159,v_160,v_161,v_162,v_163,v_164,v_165,v_166,v_167,v_168,v_169,v_170,v_171,v_172,v_173,v_174,v_175,v_176,v_177,v_178,v_179,v_180,v_181,v_182,v_183,v_184,

In [None]:
%%bash
gcloud storage cp mnist*.csv gs://${BUCKET}

## Write data to BigQuery

In [19]:
client = bigquery.Client()
dataset_id = "semi"
dataset_ref = client.dataset(dataset_id)
feature_schema = [bigquery.SchemaField(
  name="v_{}".format(i),
  field_type="FLOAT64",
  mode="NULLABLE",
  description="Feature {}".format(i))
                  for i in range(x_train_flat.shape[-1])]
label_schema = [bigquery.SchemaField(
  name="label",
  field_type="FLOAT64",
  mode="NULLABLE",
  description="Label")]
rand_schema = [bigquery.SchemaField(
  name="rand",
  field_type="FLOAT64",
  mode="NULLABLE",
  description="Random number")]
job_config = bigquery.LoadJobConfig()
job_config.schema = feature_schema + label_schema + rand_schema
job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE
job_config.skip_leading_rows = 1
# The source format defaults to CSV, so the line below is optional.
job_config.source_format = bigquery.SourceFormat.CSV

In [20]:
def load_csv_data_to_bigquery(client, dataset_ref, job_config, name):
  uri = "gs://{bucket}/{name}.csv".format(bucket=BUCKET, name=name)

  load_job = client.load_table_from_uri(
      uri, dataset_ref.table(name), job_config=job_config
  )  # API request
  print("Starting job {}".format(load_job.job_id))

  load_job.result()  # Waits for table load to complete.
  print("Job finished.")

  destination_table = client.get_table(dataset_ref.table(name))
  print("Loaded {} rows.".format(destination_table.num_rows))

  return None

### Train set

In [21]:
job_config.schema = feature_schema + label_schema + rand_schema
load_csv_data_to_bigquery(client, dataset_ref, job_config, "mnist_train")

Starting job d0624e4a-8806-4f6b-8758-2305f7bd447c
Job finished.
Loaded 60000 rows.


### Test set

In [22]:
job_config.schema = feature_schema + label_schema
load_csv_data_to_bigquery(client, dataset_ref, job_config, "mnist_test")

Starting job 3e514686-96b1-4919-bff9-4c67a1c6d323
Job finished.
Loaded 10000 rows.


## Create semi-supervised simulated splits

In [23]:
PERCENT_LABELED = 10.0

In [24]:
def create_semi_supervised_simulated_splits_in_bigquery(dataset_id, sql, name):
  job_config = bigquery.QueryJobConfig()
  # Set the destination table
  table_ref = client.dataset(dataset_id).table(name)
  job_config.destination = table_ref
  job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE
  # Start the query, passing in the extra configuration.
  query_job = client.query(
      sql,
      # Location must match that of the dataset(s) referenced in the query
      # and of the destination table.
      location="US",
      job_config=job_config)  # API request - starts the query

  query_job.result()  # Waits for the query to finish
  print('Query results loaded to table {}'.format(table_ref.path))

  return None

### Labeled

In [25]:
def create_labeled_train_set(project, dataset_id, percent_labeled):
  mnist_train_labeled_sql = """
  SELECT
    * EXCEPT(rand)
  FROM
    `{project}.{dataset}.{table}`
  WHERE rand < {percent}
  """.format(
    project=project,
    dataset=dataset_id,
    table="mnist_train",
    percent=percent_labeled / 100.0)

  create_semi_supervised_simulated_splits_in_bigquery(
    dataset_id, mnist_train_labeled_sql, "mnist_train_labeled")

  return None

In [26]:
create_labeled_train_set(PROJECT, dataset_id, PERCENT_LABELED)

Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_labeled


### Unlabeled

In [27]:
def create_unlabeled_train_set(project, dataset_id, percent_labeled):
  mnist_train_unlabeled_sql = """
  SELECT
    * EXCEPT(rand)
  FROM
    `{project}.{dataset}.{table}`
  WHERE rand >= {percent}
  """.format(
    project=project,
    dataset=dataset_id,
    table="mnist_train",
    percent=percent_labeled / 100.0)

  create_semi_supervised_simulated_splits_in_bigquery(
    dataset_id, mnist_train_unlabeled_sql, "mnist_train_unlabeled")

  return None

In [28]:
create_unlabeled_train_set(PROJECT, dataset_id, PERCENT_LABELED)

Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_unlabeled


## BQML

### Train model on labeled train set

In [29]:
def bqml_train_model_on_labeled_dataset():
  query_job = client.query("""
  CREATE OR REPLACE MODEL
    `bqml_ssl.self_training`
  OPTIONS
    ( model_type="logistic_reg",
      auto_class_weights=true,
      input_label_cols = ["label"]) AS
  SELECT
    *
  FROM
    `semi.mnist_train_labeled`
  """)

  try:
    query_job.result()
  finally:
    print("Training complete.")

  return None

In [30]:
bqml_train_model_on_labeled_dataset()

Training complete.


### Look at training info

In [31]:
def bqml_training_info():
  query_job = client.query("""
  SELECT
      *
  FROM
      ML.TRAINING_INFO(MODEL `bqml_ssl.self_training`)
  """)

  results = query_job.result()  # Waits for job to complete.

  return results

In [32]:
pd.DataFrame([{key: value for key, value in row.items()} for row in bqml_training_info()])

Unnamed: 0,duration_ms,eval_loss,iteration,learning_rate,loss,training_run
0,56049,0.034381,9,1.6,0.025464,0
1,48802,0.034698,8,0.8,0.02648,0
2,45422,0.035927,7,3.2,0.028079,0
3,50042,0.037224,6,1.6,0.030945,0
4,56957,0.039116,5,0.8,0.033941,0
5,48610,0.040741,4,0.4,0.036025,0
6,50531,0.047081,3,1.6,0.04192,0
7,55793,0.051897,2,0.8,0.04943,0
8,55009,0.069936,1,0.4,0.068228,0
9,49424,0.11537,0,0.2,0.114392,0


### Evaluate on test set

In [33]:
def bqml_evaluate_on_test_dataset():
  query_job = client.query("""
  SELECT
    *
  FROM
    ML.EVALUATE(MODEL `bqml_ssl.self_training`,
    (SELECT * FROM `semi.mnist_test`))
  """)

  results = query_job.result()  # Waits for job to complete.

  return results

In [34]:
pd.DataFrame([{key: value for key, value in row.items()}
              for row in bqml_evaluate_on_test_dataset()])

Unnamed: 0,accuracy,f1_score,log_loss,precision,recall,roc_auc
0,0.8993,0.897872,1.840486,0.898469,0.898071,0.960578


### Predict on unlabeled train set

In [35]:
def bqml_predict_unlabeled_dataset():
  query_job = client.query("""
  SELECT
      * EXCEPT(predicted_label_probs, label)
  FROM
      ML.PREDICT(MODEL `bqml_ssl.self_training`,
                 (SELECT * FROM `semi.mnist_train_unlabeled` LIMIT 10)),
    UNNEST(predicted_label_probs) AS unnested_predicted_label_probs
  """)

  results = query_job.result()  # Waits for job to complete.

  return results

In [36]:
pd.DataFrame([{key: value for key, value in row.items()}
              for row in bqml_predict_unlabeled_dataset()])

Unnamed: 0,predicted_label,prob,v_0,v_1,v_10,v_100,v_101,v_102,v_103,v_104,...,v_90,v_91,v_92,v_93,v_94,v_95,v_96,v_97,v_98,v_99
0,0.0,0.149127,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.133352,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.133298,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.127621,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.107521,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,0.0,0.061369,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
96,0.0,0.059155,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
97,0.0,0.058982,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
98,0.0,0.058966,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Check confidence

In [37]:
percent_over_random = 80.0
number_of_classes = 10
confidence_percent = (1.0 + percent_over_random / 100.0) / number_of_classes

In [38]:
features_list = ["v_{}".format(i) for i in range(x_train_flat.shape[-1])]
features = ",\n  ".join(features_list)

In [39]:
confidence_query = """
WITH
  CTE_gen_ids AS (
  SELECT
    ROW_NUMBER() OVER () AS row_id,
    *
  FROM
    ML.PREDICT(MODEL `bqml_ssl.self_training`,
      (
      SELECT
        *
      FROM
        `semi.mnist_train_unlabeled`))),
  CTE_max_probs AS (
  SELECT
    row_id,
    MAX(unnested_predicted_label_probs.prob) AS max_prob
  FROM
    CTE_gen_ids,
    UNNEST(predicted_label_probs) AS unnested_predicted_label_probs
  GROUP BY
    row_id),
  CTE_filtered_max_probs AS (
  SELECT
    *
  FROM
    CTE_max_probs
  WHERE
    max_prob {inequality} {confidence_percent})
SELECT
  {features}{label}
FROM
  CTE_filtered_max_probs AS A
INNER JOIN
  CTE_gen_ids AS B
ON
  A.row_id = B.row_id
"""

In [40]:
high_confidence_features_label_query = confidence_query.format(
  inequality=">=",
  confidence_percent=confidence_percent,
  features=features,
  label=", predicted_label AS label")

In [41]:
high_confidence_features_query = confidence_query.format(
  inequality=">=",
  confidence_percent=confidence_percent,
  features=features,
  label="")

In [42]:
low_confidence_features_query = confidence_query.format(
  inequality="<",
  confidence_percent=confidence_percent,
  features=features,
  label="")

In [43]:
%%with_globals
%%bigquery --project $PROJECT
{high_confidence_features_label_query}

Unnamed: 0,v_0,v_1,v_2,v_3,v_4,v_5,v_6,v_7,v_8,v_9,...,v_775,v_776,v_777,v_778,v_779,v_780,v_781,v_782,v_783,label
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
457,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0
458,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0
459,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.0
460,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0


### Check initial table counts

In [44]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_labeled`

Unnamed: 0,row_count
0,5963


In [45]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_unlabeled`

Unnamed: 0,row_count
0,54037


In [46]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count
FROM ({high_confidence_features_query})

Unnamed: 0,row_count
0,462


In [47]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count
FROM ({low_confidence_features_query})

Unnamed: 0,row_count
0,53575


## Adjust tables based on confidence of predictions

### Add high confidence examples to labeled dataset with predicted labels

In [48]:
def add_high_confidence_examples_to_labeled(
  dataset_id, high_confidence_features_label_query):
  job_config = bigquery.QueryJobConfig()
  # Set the destination table
  table_ref = client.dataset(dataset_id).table("mnist_train_labeled")
  job_config.destination = table_ref
  job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
  # Start the query, passing in the extra configuration.
  query_job = client.query(
      high_confidence_features_label_query,
      # Location must match that of the dataset(s) referenced in the query
      # and of the destination table.
      location="US",
      job_config=job_config)  # API request - starts the query

  query_job.result()  # Waits for the query to finish
  print('Query results loaded to table {}'.format(table_ref.path))

  return None

In [49]:
add_high_confidence_examples_to_labeled(
  dataset_id, high_confidence_features_label_query)

Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_labeled


### Check updated table counts

In [50]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_labeled`

Unnamed: 0,row_count
0,6425


In [51]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_unlabeled`

Unnamed: 0,row_count
0,54037


### Remove high confidence examples from unlabeled dataset

In [52]:
def remove_high_confidence_examples_from_unlabeled(
  dataset_id, low_confidence_features_query):
  job_config = bigquery.QueryJobConfig()
  # Set the destination table
  table_ref = client.dataset(dataset_id).table("mnist_train_unlabeled")
  job_config.destination = table_ref
  job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE
  # Start the query, passing in the extra configuration.
  query_job = client.query(
      low_confidence_features_query,
      # Location must match that of the dataset(s) referenced in the query
      # and of the destination table.
      location="US",
      job_config=job_config)  # API request - starts the query

  query_job.result()  # Waits for the query to finish
  print('Query results loaded to table {}'.format(table_ref.path))

  return None

In [53]:
remove_high_confidence_examples_from_unlabeled(
  dataset_id, low_confidence_features_query)

Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_unlabeled


### Check updated table counts

In [54]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_labeled`

Unnamed: 0,row_count
0,6425


In [55]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_unlabeled`

Unnamed: 0,row_count
0,53575


# Semi-supervised Self-training Loop

## Reset labeled and unlabeled datasets

In [56]:
create_labeled_train_set(PROJECT, dataset_id, PERCENT_LABELED)

Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_labeled


In [57]:
create_unlabeled_train_set(PROJECT, dataset_id, PERCENT_LABELED)

Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_unlabeled


## Loop until no improvement

In [58]:
old_accuracy = 0.0
max_iterations = 5
iteration = 0
while iteration < max_iterations:
  print("Iteration = {}".format(iteration))

  # Train model on labeled dataset
  print("Starting training.")
  bqml_train_model_on_labeled_dataset()

  # Evaluate model on test set
  print("Starting evaluation.")
  eval_metrics = pd.DataFrame([{key: value for key, value in row.items()}
                               for row in bqml_evaluate_on_test_dataset()])
  print("eval_metrics = {}".format(eval_metrics))

  # Extract accuracy from eval metrics
  accuracy = eval_metrics["accuracy"][0]

  accuracy_improvement = accuracy - old_accuracy
  old_accuracy = accuracy

  if accuracy_improvement > 0.01:
    # Add high confidence examples to labeled from unlabeled
    print("Adding high confidence examples to labeled.")
    add_high_confidence_examples_to_labeled(
      dataset_id, high_confidence_features_label_query)

    # Remove high confidence examples from unlabeled
    print("Removing high confidence examples from unlabeled.")
    remove_high_confidence_examples_from_unlabeled(
      dataset_id, low_confidence_features_query)
    
    iteration += 1
  else:
    print("Not enough improvement, breaking loop!")
    break

Iteration = 0
Starting training.
Training complete.
Starting evaluation.
eval_metrics =    accuracy  f1_score  log_loss  precision    recall   roc_auc
0    0.8993  0.897872  1.840486   0.898469  0.898071  0.960578
Adding high confidence examples to labeled.
Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_labeled
Removing high confidence examples from unlabeled.
Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_unlabeled
Iteration = 1
Starting training.
Training complete.
Starting evaluation.
eval_metrics =    accuracy  f1_score  log_loss  precision    recall   roc_auc
0    0.9005  0.898888  1.839274   0.899643  0.899122  0.959887
Not enough improvement, breaking loop!
