In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import wandb
import tensorflow.keras.backend as K
from dotenv import load_dotenv
import math
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
from dataclasses import dataclass
import pickle
from os import environ

# Load variables from a local .env file into the process environment
# (GCS_REPO is read from `environ` later in this notebook).
load_dotenv()

# Let tf.data choose parallelism / prefetch buffer sizes at runtime.
AUTO = tf.data.experimental.AUTOTUNE

# Object pickled by the training code (presumably the class-id lookup — confirm).
# SECURITY: unpickling executes arbitrary code, so only load this file from a
# trusted checkout. The context manager guarantees the file handle is closed
# instead of leaking it until garbage collection.
with open("../training/src/class_dict.pkl", "rb") as class_dict_file:
    class_dict = pickle.load(class_dict_file)

2023-10-29 04:24:13.986887: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-29 04:24:13.992270: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-29 04:24:14.059361: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-29 04:24:14.059404: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-29 04:24:14.059449: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to regi

In [2]:
@dataclass
class CFG:
    """Run configuration shared by the input-pipeline helpers.

    The notebook passes the class itself around (e.g. ``CFG.IMAGE_SIZE[0]``),
    so every field carries a default that is readable as a class attribute.
    """

    BATCH_SIZE: int = 8  # examples per batch; also scales the shuffle buffer
    IMAGE_SIZE: tuple = (224, 224)  # (height, width) the decoded JPEGs are reshaped to
    AUGMENT: bool = False  # when True, the training pipeline applies data_augment

    # Magnitudes read by ``transform`` when AUGMENT is enabled. They were missing
    # from this config, so enabling augmentation raised AttributeError.
    # NOTE(review): the defaults below are assumed values — confirm them against
    # the configuration used for training before relying on augmentation.
    ROT_: float = 180.0  # scale of the random rotation, in degrees
    SHR_: float = 2.0  # scale of the random shear, in degrees
    HZOOM_: float = 8.0  # divisor for the random height-zoom perturbation
    WZOOM_: float = 8.0  # divisor for the random width-zoom perturbation
    HSHIFT_: float = 8.0  # scale of the random shift along the height axis
    WSHIFT_: float = 8.0  # scale of the random shift along the width axis

In [3]:
def decode_image(image_data, CFG):
    """Turn encoded JPEG bytes into a float32 image tensor.

    Args:
        image_data: string tensor holding one JPEG-encoded image.
        CFG: configuration object; only ``CFG.IMAGE_SIZE`` is read.

    Returns:
        float32 tensor of shape ``[*CFG.IMAGE_SIZE, 3]``. Values are the decoded
        uint8 intensities cast to float; no rescaling is applied.
    """
    pixels_uint8 = tf.image.decode_jpeg(image_data, channels=3)
    # Pin the static shape explicitly (the original notes this is needed for TPU).
    target_shape = [*CFG.IMAGE_SIZE, 3]
    return tf.cast(tf.reshape(pixels_uint8, target_shape), tf.float32)


def read_labeled_tfrecord(CFG, example):
    """Parse one serialized labeled example into an ``(image, label)`` pair.

    Args:
        CFG: configuration object forwarded to ``decode_image``.
        example: serialized ``tf.train.Example`` read from a TFRecord file.

    Returns:
        Tuple of the decoded image (see ``decode_image``) and the ``class_id``
        feature cast to int32. The other parsed features are not returned.
    """
    def scalar(dtype):
        # Every feature in these records is declared as a single scalar value.
        return tf.io.FixedLenFeature([], dtype)

    schema = {
        'image': scalar(tf.string),
        'dataset': scalar(tf.int64),
        'longitude': scalar(tf.float32),
        'latitude': scalar(tf.float32),
        'norm_date': scalar(tf.float32),
        'class_priors': scalar(tf.float32),
        'class_id': scalar(tf.int64),
    }
    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image'], CFG), tf.cast(parsed['class_id'], tf.int32)


def load_dataset(filenames, CFG, labeled=True, ordered=False, shuffle=True):
    """Build an unbatched ``tf.data`` pipeline of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file path(s) passed to ``tf.data.TFRecordDataset``.
        CFG: configuration object; ``BATCH_SIZE`` sizes the shuffle buffer and
            ``IMAGE_SIZE`` is used when decoding images.
        labeled: when True each element is ``(image, label)``; when False each
            element is the decoded image only.
        ordered: when False, non-deterministic ordering is allowed so records
            are consumed as soon as they stream in.
        shuffle: when True, shuffle serialized records with a buffer of
            ``CFG.BATCH_SIZE * 10`` elements.

    Returns:
        A ``tf.data.Dataset`` of decoded examples (not batched, not prefetched).
    """
    options = tf.data.Options()
    if not ordered:
        # Use data as soon as it streams in rather than in its original order.
        options.experimental_deterministic = False

    # num_parallel_reads interleaves reads from multiple files.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    if shuffle:
        dataset = dataset.shuffle(CFG.BATCH_SIZE * 10)
    dataset = dataset.with_options(options)

    if labeled:
        def parse(record):
            return read_labeled_tfrecord(CFG, record)
    else:
        def parse(record):
            # read_unlabeled_tfrecord accepts only the serialized example and
            # returns the still-encoded JPEG bytes (the previous two-argument
            # call raised TypeError), so the image is decoded here.
            return decode_image(read_unlabeled_tfrecord(record), CFG)

    return dataset.map(parse, num_parallel_calls=AUTO)

def data_augment(img, label, CFG):
    """Apply random geometric and colour augmentation to one training example.

    The image is first warped by ``transform`` and then randomly flipped
    left/right and jittered in saturation, contrast and brightness. The label
    is passed through untouched.

    Args:
        img: a single image tensor (not a batch).
        label: the example's label.
        CFG: configuration object forwarded to ``transform``.

    Returns:
        ``(augmented_image, label)``.
    """
    # This runs inside the tf.data pipeline, so with the prefetch applied by
    # the dataset builders it overlaps with the training step.
    warped = transform(img, CFG)
    flipped = tf.image.random_flip_left_right(warped)
    # Random hue jitter is intentionally left disabled.
    saturated = tf.image.random_saturation(flipped, 0.7, 1.3)
    contrasted = tf.image.random_contrast(saturated, 0.8, 1.2)
    # NOTE(review): images from decode_image are float32 in [0, 255], so a
    # brightness delta of 0.1 is a very small shift — confirm this is intended.
    brightened = tf.image.random_brightness(contrasted, 0.1)
    return brightened, label

def get_training_dataset(filenames, CFG):
    """Build the training input pipeline.

    Labeled examples are loaded with the default shuffling of ``load_dataset``,
    optionally augmented (when ``CFG.AUGMENT`` is truthy), batched to
    ``CFG.BATCH_SIZE``, repeated indefinitely and prefetched.

    Args:
        filenames: TFRecord file path(s).
        CFG: configuration object.

    Returns:
        An infinitely repeating ``tf.data.Dataset`` of ``(images, labels)`` batches.
    """
    pipeline = load_dataset(filenames, CFG, labeled=True)

    if CFG.AUGMENT:
        def augment(image, label):
            return data_augment(image, label, CFG)

        pipeline = pipeline.map(augment, num_parallel_calls=AUTO)

    # Batch first, then repeat so training can run for several epochs, then
    # prefetch the next batch while the current one is being consumed.
    return pipeline.batch(CFG.BATCH_SIZE).repeat().prefetch(AUTO)

def get_validation_dataset(filenames, CFG, ordered=False):
    """Build the validation input pipeline: labeled, unshuffled, batched, prefetched.

    Args:
        filenames: TFRecord file path(s).
        CFG: configuration object; ``BATCH_SIZE`` sets the batch size.
        ordered: forwarded to ``load_dataset``.

    Returns:
        A finite ``tf.data.Dataset`` of ``(images, labels)`` batches.
    """
    return (
        load_dataset(filenames, CFG, labeled=True, ordered=ordered, shuffle=False)
        .batch(CFG.BATCH_SIZE)
        .prefetch(AUTO)  # fetch the next batch while the current one is in use
    )

def get_test_dataset(filenames, CFG, ordered=False):
    """Build the test input pipeline: unlabeled, unshuffled, batched, prefetched.

    Args:
        filenames: TFRecord file path(s).
        CFG: configuration object; ``BATCH_SIZE`` sets the batch size.
        ordered: forwarded to ``load_dataset``.

    Returns:
        A finite ``tf.data.Dataset`` of batches produced by ``load_dataset``
        with ``labeled=False``.
    """
    unlabeled = load_dataset(filenames, CFG, labeled=False, ordered=ordered, shuffle=False)
    batched = unlabeled.batch(CFG.BATCH_SIZE)
    # Fetch the next batch while the current one is in use.
    return batched.prefetch(AUTO)

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example and return its image feature.

    Args:
        example: serialized ``tf.train.Example`` read from a TFRecord file.

    Returns:
        The ``'image'`` feature as a scalar string tensor. The bytes are
        returned as stored; no JPEG decoding happens here.
    """
    def scalar(dtype):
        # Every feature in these records is declared as a single scalar value.
        return tf.io.FixedLenFeature([], dtype)

    schema = {
        'image': scalar(tf.string),
        'dataset': scalar(tf.int64),
        'longitude': scalar(tf.float32),
        'latitude': scalar(tf.float32),
        'norm_date': scalar(tf.float32),
        'class_priors': scalar(tf.float32),
    }
    return tf.io.parse_single_example(example, schema)['image']


def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Compose a 3x3 index-transform matrix from rotation, shear, zoom and shift.

    Args:
        rotation: rotation angle in degrees.
        shear: shear angle in degrees.
        height_zoom: zoom factor along the height axis (the matrix uses its reciprocal).
        width_zoom: zoom factor along the width axis (the matrix uses its reciprocal).
        height_shift: shift along the height axis.
        width_shift: shift along the width axis.

        Each argument must contribute exactly one element to the matrix;
        ``transform`` passes float32 tensors of shape ``[1]``.

    Returns:
        The 3x3 matrix ``(rotation @ shear) @ (zoom @ shift)``.
    """
    # Angles arrive in degrees; the trig ops need radians.
    rotation_rad = math.pi * rotation / 180.0
    shear_rad = math.pi * shear / 180.0

    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    def as_3x3(entries):
        # Nine single-element tensors, listed row by row, become one 3x3 matrix.
        return tf.reshape(tf.concat([entries], axis=0), [3, 3])

    cos_rot, sin_rot = tf.math.cos(rotation_rad), tf.math.sin(rotation_rad)
    rotate_m = as_3x3([
        cos_rot, sin_rot, zero,
        -sin_rot, cos_rot, zero,
        zero, zero, one,
    ])

    cos_shear, sin_shear = tf.math.cos(shear_rad), tf.math.sin(shear_rad)
    shear_m = as_3x3([
        one, sin_shear, zero,
        zero, cos_shear, zero,
        zero, zero, one,
    ])

    zoom_m = as_3x3([
        one / height_zoom, zero, zero,
        zero, one / width_zoom, zero,
        zero, zero, one,
    ])

    shift_m = as_3x3([
        one, zero, height_shift,
        zero, one, width_shift,
        zero, zero, one,
    ])

    return K.dot(K.dot(rotate_m, shear_m), K.dot(zoom_m, shift_m))


def transform(image, CFG):
    """Randomly rotate, shear, zoom and shift a single image.

    Args:
        image: one image of shape ``[dim, dim, 3]`` (not a batch). Only
            ``CFG.IMAGE_SIZE[0]`` is read, so the image is treated as square.
        CFG: configuration object providing ``IMAGE_SIZE`` and the augmentation
            magnitudes ``ROT_``, ``SHR_``, ``HZOOM_``, ``WZOOM_``, ``HSHIFT_``
            and ``WSHIFT_``.

    Returns:
        Tensor of shape ``[dim, dim, 3]`` whose pixels are gathered from the
        input at the transformed (and clipped) source coordinates.
    """
    dim = CFG.IMAGE_SIZE[0]
    parity = dim % 2  # 1 for odd sizes (e.g. 331), 0 for even sizes

    def noise():
        # One standard-normal sample, shaped [1] as get_mat expects.
        return tf.random.normal([1], dtype='float32')

    # Draw the six random parameters (the draw order is kept fixed).
    rot_deg = CFG.ROT_ * noise()
    shear_deg = CFG.SHR_ * noise()
    zoom_h = 1.0 + noise() / CFG.HZOOM_
    zoom_w = 1.0 + noise() / CFG.WZOOM_
    shift_h = CFG.HSHIFT_ * noise()
    shift_w = CFG.WSHIFT_ * noise()

    matrix = get_mat(rot_deg, shear_deg, zoom_h, zoom_w, shift_h, shift_w)

    # Homogeneous coordinates of every destination pixel, centred on the image.
    upper = dim // 2
    lower = -dim // 2  # floor division of the negated size, as in the original
    row_coords = tf.repeat(tf.range(upper, lower, -1), dim)
    col_coords = tf.tile(tf.range(lower, upper), [dim])
    homogeneous = tf.ones([dim * dim], dtype='int32')
    dest_idx = tf.stack([row_coords, col_coords, homogeneous])

    # Map destination pixels back onto source pixels and keep them in bounds.
    src_idx = K.dot(matrix, tf.cast(dest_idx, dtype='float32'))
    src_idx = K.cast(src_idx, dtype='int32')
    src_idx = K.clip(src_idx, lower + parity + 1, upper)

    # Convert centred coordinates back to array indices and gather the pixels.
    gather_idx = tf.stack([upper - src_idx[0], upper - 1 + src_idx[1]])
    gathered = tf.gather_nd(image, tf.transpose(gather_idx))

    return tf.reshape(gathered, [dim, dim, 3])

In [4]:
# Start a W&B run in the "Mushroom-Classifier" project, tagged as a testing job.
wandb.init(project="Mushroom-Classifier", job_type="testing", )
# Declare the latest registered model artifact as an input of this run and
# download it; `path` holds the value returned by download().
path = wandb.use_artifact("g-broughton/model-registry/Mushroom-Classifier:latest", type="model").download()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mg-broughton[0m. Use [1m`wandb login --relogin`[0m to force relogin
git-nbdiffdriver diff: 1: git-nbdiffdriver: not found
fatal: external diff died, stopping at training/train.ipynb


[34m[1mwandb[0m: Downloading large artifact Mushroom-Classifier:latest, 2262.80MB. 5 files... 
[34m[1mwandb[0m:   5 of 5 files downloaded.  
Done. 0:0:3.8


In [15]:
# Load the Keras model from the artifact location downloaded above.
# NOTE(review): this cell's execution count is out of sequence with its
# neighbours — re-run the notebook top to bottom on a fresh kernel to confirm
# it does not depend on hidden state.
model = tf.keras.models.load_model(str(path))

  function = cls._parse_function_from_config(


In [6]:
# TFRecord directory for each supported image edge length. The bucket root
# comes from the GCS_REPO environment variable; a missing variable raises
# KeyError here.
GCS_PATH_SELECT = {
    192: f"{environ['GCS_REPO']}/tfrecords-jpeg-192x192",
    224: f"{environ['GCS_REPO']}/tfrecords-jpeg-224x224v2",
    384: f"{environ['GCS_REPO']}/tfrecords-jpeg-384x384",
    512: f"{environ['GCS_REPO']}/tfrecords-jpeg-512x512",
}
# Pick the directory matching the configured image height (read from the class default).
GCS_PATH = GCS_PATH_SELECT[CFG.IMAGE_SIZE[0]]

# All validation shards under the selected directory.
VALIDATION_FILENAMES = tf.io.gfile.glob(f"{GCS_PATH}/val*.tfrec")

# Disabled: image counting depends on a `tr_fn` helper that is not imported in this notebook.
# NUM_TRAINING_IMAGES = tr_fn.count_data_items(TRAINING_FILENAMES)
# NUM_VALIDATION_IMAGES = tr_fn.count_data_items(VALIDATION_FILENAMES)

In [7]:
# Display the resolved TFRecord directory as a sanity check.
GCS_PATH

'gs://mush-img-repo/tfrecords-jpeg-224x224v2'

In [36]:
# Collect model outputs and ground-truth labels batch by batch.
# NOTE: .take(1) limits this to the first validation batch only.
preds_l = []
labels_l = []
for images, labels in get_validation_dataset(VALIDATION_FILENAMES, CFG).take(1):
    # The batch is already materialised, so run a single forward pass with
    # predict_on_batch. Model.predict sets up its own batching loop (and
    # progress output) on every call, which is wasteful inside a Python loop.
    preds = model.predict_on_batch(images)
    preds_l.append(preds)
    labels_l.append(labels)



In [37]:
# Inspect the raw outputs of the last batch processed by the loop above
# (`preds` is the loop variable left over after the loop finished).
preds

array([[9.0538904e-12, 1.3545815e-12, 1.7053411e-11, ..., 3.4159637e-11,
        7.8716643e-12, 5.5715502e-12],
       [3.5986499e-08, 9.1118423e-08, 5.2418478e-09, ..., 7.2495165e-09,
        1.2011550e-09, 3.2498737e-09],
       [5.4941909e-11, 1.5688406e-10, 5.3636410e-12, ..., 2.7036511e-09,
        2.9166478e-10, 4.7350941e-08],
       ...,
       [1.6051135e-10, 4.8369749e-08, 2.9984335e-09, ..., 2.6298627e-08,
        1.2206925e-08, 7.5581198e-08],
       [7.4830867e-05, 1.3993742e-08, 4.5303295e-06, ..., 1.1116118e-06,
        3.9574690e-08, 3.5087339e-07],
       [1.0749997e-06, 5.3991571e-09, 3.8924462e-08, ..., 2.1432578e-09,
        9.2097252e-10, 5.6187344e-09]], dtype=float32)

In [41]:
# Highest output value per example: preds_l (a list of per-batch arrays) is
# stacked by NumPy and reduced over the last (class) axis.
# NOTE(review): stacking needs every batch to have the same shape; a smaller
# final batch would break this once more than one batch is collected.
np.max(preds_l, axis=-1)

array([[0.99999547, 0.9998259 , 0.99997485, 0.9859897 , 0.48543844,
        0.99956673, 0.9739571 , 0.652332  ]], dtype=float32)