In [2]:
!pip install -q tensorflow==2.10.0 wandb python-dotenv tensorboard_plugin_profile tensorflow_io==0.27.0

In [3]:
from pathlib import Path
import os

try:
    import wandb
except:
    if os.environ['COLAB_RELEASE_TAG']:
        print("Found Colab Environment")
        from google.colab import drive
        drive.mount('/content/drive')
        from google.colab import auth
        auth.authenticate_user()

        %pip install -q tensorflow==2.10.0 wandb python-dotenv tensorboard_plugin_profile tensorflow_io==0.27.0
        exit()
    elif Path().cwd().name == 'Mushroom-Classifier':
        print("Found Other Environment")
        %pip install -q tensorflow==2.10.0 wandb python-dotenv tensorboard_plugin_profile tensorflow_io==0.27.0
        exit()
    else:
        print('Please run this notebook from the root of the repository')
        exit()

%cd /content/drive/MyDrive/Mushroom-Classifier

[Errno 2] No such file or directory: '/content/drive/MyDrive/Mushroom-Classifier'
/home/broug/Desktop/Mushroom-Classifier/training


In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [42]:
import math, re, os, pickle
import tensorflow as tf
from datetime import datetime
import wandb
from wandb.keras import WandbCallback, WandbModelCheckpoint
import numpy as np
from matplotlib import pyplot as plt
# from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
from src.models.swintransformer import SwinTransformer
# from src.optimizers import lion
# from prefect import task, flow

# Report the TensorFlow version in use.
print(f"Tensorflow version {tf.__version__}")

# Autotune sentinel for tf.data parallelism/prefetch; tf.data.AUTOTUNE is the same
# constant as tf.data.experimental.AUTOTUNE.
AUTO = tf.data.AUTOTUNE

# Keep printed NumPy arrays short: summarize beyond 15 elements, wrap at 80 columns.
np.set_printoptions(linewidth=80, threshold=15)

from config import GCFG, CFG

# Global configuration object; later cells read and extend its attributes.
CFG2 = GCFG()

Tensorflow version 2.10.0


In [43]:
# Run identifier (month-day, hour-minute) reused in log and output file names below.
run_started_at = datetime.now()
save_time = datetime.strftime(run_started_at, '%m%d-%H%M')

# Log location for this model and run, under the configured repository root.
log_dir = f"{CFG2.GCS_REPO}/logs/{CFG2.MODEL}/{save_time}"

# wandb.tensorboard.patch(root_logdir=log_dir + "/tf")
# wandb.init(project="Mushroom-Classifier", tags=[f'{CFG2.MODEL}', "Adam - Cosine", str(CFG2.IMAGE_SIZE[0])])

In [44]:
# Detect hardware: prefer a TPU when one can be resolved, otherwise use the default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:  # If TPU not found
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # The TPU system must be connected and initialized before the strategy is created.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

# Record the replica count on the global config for later cells.
CFG2.REPLICAS = strategy.num_replicas_in_sync
print("Number of accelerators: ", strategy.num_replicas_in_sync)

Number of accelerators:  1


## Visualization Utils

In [45]:
def batch_to_numpy_images_and_labels(data):
    """Convert an ``(images, labels)`` batch into NumPy form.

    Both elements of ``data`` must expose a ``.numpy()`` method.

    Returns:
        A ``(numpy_images, numpy_labels)`` pair. When the converted labels have
        ``object`` dtype (binary strings, i.e. image IDs instead of class labels,
        as is the case for test data), ``numpy_labels`` is a list holding one
        ``None`` per image.
    """
    image_tensor, label_tensor = data
    numpy_images, numpy_labels = image_tensor.numpy(), label_tensor.numpy()

    has_class_labels = numpy_labels.dtype != object
    if has_class_labels:
        return numpy_images, numpy_labels
    # Only image IDs are available: report "no label" for every image.
    return numpy_images, [None] * len(numpy_images)

def title_from_label_and_target(label, correct_label):
    """Build a plot title for a prediction and report whether it is correct.

    Args:
        label: predicted class key, looked up in the global ``class_dict``.
        correct_label: ground-truth class key, or ``None`` when it is unknown.

    Returns:
        ``(title, correct)``. With no ground truth the title is just the predicted
        class name and ``correct`` is ``True``. Otherwise the title is the predicted
        name followed by ``[OK]`` or ``[NO→<expected name>]``.
    """
    predicted_name = class_dict[label]
    if correct_label is None:
        return predicted_name, True

    correct = (label == correct_label)
    if correct:
        verdict, arrow, expected_name = 'OK', '', ''
    else:
        verdict, arrow, expected_name = 'NO', u"\u2192", class_dict[correct_label]
    return "{} [{}{}{}]".format(predicted_name, verdict, arrow, expected_name), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the given subplot slot and return the next slot.

    Args:
        image: image data accepted by ``plt.imshow``.
        title: title text; no title is drawn when it is empty.
        subplot: ``(rows, cols, index)`` triple passed to ``plt.subplot``.
        red: when true, the title is drawn in red with a slightly smaller font.
        titlesize: base font size for the title.

    Returns:
        ``(rows, cols, index + 1)``, i.e. the slot for the next image.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)

    if len(title) > 0:
        if red:
            font_size, font_color = int(titlesize/1.2), 'red'
        else:
            font_size, font_color = int(titlesize), 'black'
        plt.title(title, fontsize=font_size, color=font_color, fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))

    grid_rows, grid_cols, slot = subplot[0], subplot[1], subplot[2]
    return (grid_rows, grid_cols, slot+1)

def display_batch_of_images(databatch, predictions=None):
    """Plot a batch of images on a square-ish grid, titled with class names.

    Usage:
        display_batch_of_images((images, labels))
        display_batch_of_images((images, labels), predictions)

    Args:
        databatch: an ``(images, labels)`` pair; it is unpacked by
            ``batch_to_numpy_images_and_labels``, so both elements are required.
        predictions: optional predicted class keys, indexed by position in the
            batch. When given, titles show predicted vs. expected class and
            wrong predictions are drawn in red.
    """
    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1

    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None] * len(images)

    # auto-squaring: images that do not fit into the rows x cols grid are dropped
    n_rows = int(math.sqrt(len(images)))
    n_cols = len(images)//n_rows
    n_shown = n_rows*n_cols

    # Keep the longer side of the figure at FIGSIZE and scale the other proportionally.
    if n_rows < n_cols:
        figure_size = (FIGSIZE,FIGSIZE/n_cols*n_rows)
    else:
        figure_size = (FIGSIZE/n_rows*n_cols,FIGSIZE)
    plt.figure(figsize=figure_size)

    # display
    next_slot = (n_rows,n_cols,1)
    for position, (image, label) in enumerate(zip(images[:n_shown], labels[:n_shown])):
        title = '' if label is None else class_dict[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[position], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(n_rows,n_cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        next_slot = display_one_flower(image, title, next_slot, not correct, titlesize=dynamic_titlesize)

    # layout
    plt.tight_layout()
    # `label` still holds the last label seen by the loop above: with neither labels nor
    # predictions there are no titles, so the images are packed edge to edge.
    untitled = label is None and predictions is None
    gap = 0 if untitled else SPACING
    plt.subplots_adjust(wspace=gap, hspace=gap)
    plt.show()

def display_confusion_matrix(cmat, score, precision, recall):
    """Show a confusion matrix as a red heat map with class tick labels.

    Args:
        cmat: matrix passed to ``Axes.matshow``.
        score: F1 score to print next to the matrix, or ``None`` to omit it.
        precision: precision to print, or ``None`` to omit it.
        recall: recall to print, or ``None`` to omit it.
    """
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')

    # Same tick setup on both axes (x first, then y); only the label anchor side differs.
    for set_ticks, set_ticklabels, get_ticklabels, anchor_side in (
        (ax.set_xticks, ax.set_xticklabels, ax.get_xticklabels, "left"),
        (ax.set_yticks, ax.set_yticklabels, ax.get_yticklabels, "right"),
    ):
        set_ticks(range(len(class_dict)))
        set_ticklabels(class_dict, fontdict={'fontsize': 7})
        plt.setp(get_ticklabels(), rotation=45, ha=anchor_side, rotation_mode="anchor")

    # Assemble the summary text from whichever scores were supplied.
    summary_parts = []
    if score is not None:
        summary_parts.append('f1 = {:.3f} '.format(score))
    if precision is not None:
        summary_parts.append('\nprecision = {:.3f} '.format(precision))
    if recall is not None:
        summary_parts.append('\nrecall = {:.3f} '.format(recall))
    titlestring = "".join(summary_parts)

    if titlestring != "":
        # NOTE(review): the x position 101 is hard-coded in data coordinates — confirm it suits the number of classes.
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()

def display_training_curves(training, validation, title, subplot):
    """Plot training vs. validation curves for one metric, then log and save the figure.

    Args:
        training: sequence of metric values for the training set, plotted against the epoch axis.
        validation: sequence of metric values for the validation set.
        title: metric name; used in the axes title, the y label and the output file name.
        subplot: three-digit matplotlib subplot code (e.g. ``211``). A code whose last
            digit is 1 starts a new figure.

    Side effects:
        Logs the current pyplot state to the active W&B run (skipped when there is no
        run) and writes ``<CFG.ROOT>/images/<CFG.MODEL>/<title>-<save_time>.png``.
    """
    if subplot%10==1: # set up the figure on the first call
        # plt.figure (not plt.subplots): plt.subplots would also create a full-size Axes
        # that the plt.subplot call below overlaps, leaving a stray empty Axes or
        # triggering matplotlib's overlapping-axes warning.
        plt.figure(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title(f'model {title}')
    ax.set_ylabel(title)
    #ax.set_ylim(0.28,1.05)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])
    # wandb.log raises when no run has been initialised, so only log into an active run.
    if wandb.run is not None:
        wandb.log({"chart": plt})
    path = CFG.ROOT / "images" / CFG.MODEL
    # parents=True: the intermediate "images" directory may not exist yet.
    path.mkdir(parents=True, exist_ok=True)
    plt.savefig(path / f'{title}-{save_time}.png')

In [46]:
def decode_image(image_data):
    """Decode JPEG bytes into a 3-channel image tensor of shape ``[*CFG.IMAGE_SIZE, 3]``."""
    decoded = tf.image.decode_jpeg(image_data, channels=3)  # image format uint8 [0,255]
    # explicit size needed for TPU
    static_shape = [*CFG.IMAGE_SIZE, 3]
    return tf.reshape(decoded, static_shape)


def read_labeled_tfrecord(example):
    """Parse one serialized TFRecord example into an ``(image, label)`` pair.

    Every feature listed below is parsed, but only ``image`` and ``class_id``
    are used: the image is decoded with ``decode_image`` and the class id is
    cast to ``int32``.
    """
    def scalar_feature(dtype):
        # Every field is a single fixed-length value of the given dtype.
        return tf.io.FixedLenFeature([], dtype)

    feature_spec = {
        'image': scalar_feature(tf.string),
        'dataset': scalar_feature(tf.int64),
        'longitude': scalar_feature(tf.float32),
        'latitude': scalar_feature(tf.float32),
        'norm_date': scalar_feature(tf.float32),
        'class_priors': scalar_feature(tf.float32),
        'class_id': scalar_feature(tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['class_id'], tf.int32)


def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of ``(image, label)`` pairs from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: currently unused; every record is parsed with
            ``read_labeled_tfrecord`` regardless of this flag.
        ordered: when true, records keep their original order (no shuffling and
            deterministic reads). When false, records are shuffled and
            non-deterministic ordering is allowed for speed.

    Returns:
        An unbatched ``tf.data.Dataset`` yielding ``(image, label)`` pairs.
    """
    # Read from TFRecords. For optimal performance, reading from multiple files at once and
    # disregarding data order when the caller does not need it.
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.cache()
    if not ordered:
        # Shuffling reorders the records (differently on every iteration), so it would
        # defeat ordered=True; only shuffle when order does not matter.
        dataset = dataset.shuffle(CFG.BATCH_SIZE * 10)
    dataset = dataset.with_options(options) # with ordering disabled, uses data as soon as it streams in
    dataset = dataset.map(read_labeled_tfrecord, num_parallel_calls=AUTO) # if labeled else read_unlabeled_tfrecord
    return dataset

def data_augment(image, label):
    """Augmentation hook for training examples; currently returns its inputs unchanged.

    The candidate augmentations below are disabled. This function is mapped over
    the training dataset before batching and prefetching.
    """
    # image = tf.image.random_flip_left_right(image)
    #image = tf.image.random_saturation(image, 0, 2)
    augmented_image = image
    return augmented_image, label

def get_training_dataset():
    """Return the training pipeline: load, augment, batch, repeat, prefetch."""
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .map(data_augment, num_parallel_calls=AUTO)
        .batch(CFG.BATCH_SIZE)
        .repeat()  # the training dataset must repeat for several epochs
        .prefetch(AUTO)  # prefetch next batch while training (autotune prefetch buffer size)
    )

def get_validation_dataset(ordered=False):
    """Return the validation pipeline: load, batch, prefetch (no repeat).

    Args:
        ordered: forwarded to ``load_dataset``.
    """
    return (
        load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
        .batch(CFG.BATCH_SIZE)
        .prefetch(AUTO)  # prefetch next batch while training (autotune prefetch buffer size)
    )

def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    The count is the run of digits between a ``-`` and the following ``.``,
    e.g. ``flowers00-230.tfrec`` contributes 230 items.

    Returns:
        The total, computed with ``np.sum``.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    per_file_counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        per_file_counts.append(int(match.group(1)))
    return np.sum(per_file_counts)

In [47]:
# TFRecord folder for each supported image size, keyed by the first entry of IMAGE_SIZE.
GCS_PATH_SELECT = {
    192: f'{CFG2.GCS_REPO}/tfrecords-jpeg-192x192',
    224: f'{CFG2.GCS_REPO}/tfrecords-jpeg-224x224v2',
    384: f'{CFG2.GCS_REPO}/tfrecords-jpeg-384x384',
    512: f'{CFG2.GCS_REPO}/tfrecords-jpeg-512x512',
}
GCS_PATH = GCS_PATH_SELECT[CFG2.IMAGE_SIZE[0]]
TRAINING_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/train*.tfrec')
VALIDATION_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/val*.tfrec')

# Open the pickle in a context manager so the file handle is closed instead of leaked.
# NOTE(review): pickle.load can execute arbitrary code; only load a class_dict.pkl you trust.
with open('src/class_dict.pkl', 'rb') as class_dict_file:
    class_dict = pickle.load(class_dict_file)

# Dataset sizes are derived from the counts embedded in the TFRecord file names.
CFG2.NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
CFG2.NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)

# `CFG` is rebound below from the imported config class to an instance of it. Importing the
# class again under an alias keeps this cell re-runnable: calling `CFG(...)` a second time
# would otherwise try to call the instance created by the previous run.
from config import CFG as CFGClass
CFG = CFGClass(REPLICAS=CFG2.REPLICAS, NUM_TRAINING_IMAGES=CFG2.NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES=CFG2.NUM_VALIDATION_IMAGES)

In [9]:
# data dump: print the shapes of the first three batches of each split, then the labels
# of the last batch that was read.
for shapes_header, labels_header, make_dataset in (
    ("Training data shapes:", "Training data label examples:", get_training_dataset),
    ("Validation data shapes:", "Validation data label examples:", get_validation_dataset),
):
    print(shapes_header)
    for image, label in make_dataset().take(3):
        print(image.numpy().shape, label.numpy().shape)
    print(labels_header, label.numpy())

Training data shapes:


2023-10-24 03:22:59.032971: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-24 03:22:59.033006: W tensorflow/stream_executor/cuda/cuda_driver.cc:263] failed call to cuInit: UNKNOWN ERROR (303)
2023-10-24 03:22:59.033034: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (instance-3): /proc/driver/nvidia/version does not exist
2023-10-24 03:22:59.034187: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


(8, 224, 224, 3) (8,)
(8, 224, 224, 3) (8,)
(8, 224, 224, 3) (8,)


2023-10-24 03:22:59.704887: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Training data label examples: [135 296  41 463  63 184  18 378]
Validation data shapes:
(8, 224, 224, 3) (8,)
(8, 224, 224, 3) (8,)
(8, 224, 224, 3) (8,)


2023-10-24 03:23:00.476477: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Validation data label examples: [277 218 239   7 226 373 415 434]


In [11]:
# Peek at training data: regroup the training stream into batches of 20 images and keep
# an iterator over it for the display cell below.
training_dataset = get_training_dataset().unbatch().batch(20)
train_batch = iter(training_dataset)

In [None]:
# run this cell again for next set of images
# display_batch_of_images(next(train_batch))

You can select from these models:
- swin_tiny_224
- swin_small_224
- swin_base_224
- swin_base_384
- swin_large_224
- swin_large_384

In [48]:
def make_callbacks(CFG):
    """Return the list of Keras callbacks used for training.

    Args:
        CFG: configuration object; only ``CFG.GCS_REPO`` is read by the active code.

    Returns:
        A list holding a single ``CSVLogger`` that writes
        ``<CFG.GCS_REPO>/logs/<save_time>-csv_log.csv`` (overwriting, not appending).
        The other callbacks are kept below as disabled code.
    """
    # options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
    # Only referenced by the disabled WandbModelCheckpoint callback below.
    options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")

    # Disabled, originally listed before the CSV logger:
    # tf.keras.callbacks.EarlyStopping(
    #     monitor="val_loss",
    #     patience=CFG.ES_PATIENCE,
    #     verbose=1,
    #     restore_best_weights=True,
    # ),
    # tf.keras.callbacks.TensorBoard(log_dir=log_dir + "/tf", profile_batch=(50, 250)),
    csv_logger = tf.keras.callbacks.CSVLogger(
        filename=f'{CFG.GCS_REPO}/logs/{save_time}-csv_log.csv',
        separator=",",
        append=False,
    )
    # Disabled, originally listed after the CSV logger:
    # wandb.keras.WandbMetricsLogger(log_freq='batch'),
    # wandb.keras.WandbModelCheckpoint(
    #     str(CFG.ROOT / 'models' / CFG.MODEL / f"{save_time}.h5"),
    #     monitor='val_loss', verbose=1, save_best_only=True,
    #     save_weights_only=True, options=options,
    # )
    return [csv_logger]

In [38]:
%load_ext tensorboard

In [49]:
# Build the model, metrics, learning-rate schedule and optimizer inside the distribution
# strategy scope.
with strategy.scope():
    # Input layer: cast pixels to float32 and apply Keras imagenet_utils preprocessing in
    # "torch" mode (per the Keras docs: scale to [0, 1], then normalize with the ImageNet
    # mean and standard deviation).
    img_adjust_layer = tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*CFG.IMAGE_SIZE, 3])
    # Swin backbone selected by CFG.MODEL, created with include_top=False and pretrained=False.
    # NOTE(review): use_tpu=True is passed unconditionally, even when no TPU was detected —
    # confirm this is intended for CPU/GPU runs.
    pretrained_model = SwinTransformer(CFG.MODEL, num_classes=len(class_dict), include_top=False, pretrained=False, use_tpu=True)
    # Rebind `pretrained_model` to the full classifier: preprocessing -> backbone -> softmax
    # head with one unit per entry of class_dict.
    pretrained_model = tf.keras.Sequential([
        img_adjust_layer,
        pretrained_model,
        tf.keras.layers.Dense(len(class_dict), activation='softmax')
    ])

    # Top-3 accuracy for integer (sparse) labels.
    top3_acc = tf.keras.metrics.SparseTopKCategoricalAccuracy(
        k=3, name='sparse_top_3_categorical_accuracy'
    )
    # Cosine decay from CFG.LR_START over CFG.DECAY_STEPS optimizer steps.
    lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=CFG.LR_START,
        decay_steps=CFG.DECAY_STEPS
    )
    def get_lr_metric(optimizer):
        # Wrap the optimizer's current learning rate as a metric function. Keras names the
        # metric after the inner function, so it is reported as "lr" in the training logs.
        def lr(y_true, y_pred):
            # NOTE(review): _decayed_lr is a private optimizer method; it may change between TF releases.
            return optimizer._decayed_lr(tf.float32) # I use ._decayed_lr method instead of .lr
        return lr

    optimizer = tf.keras.optimizers.Adam(lr_decayed_fn)
    lr_metric = get_lr_metric(optimizer)

In [9]:
# Initialise the classifier from the stored base weights for the configured model.
base_weights_path = CFG.ROOT / 'base_models' / CFG.MODEL / 'base_model.h5'
pretrained_model.load_weights(base_weights_path)

In [50]:
# Metrics reported during training: accuracy, the current learning rate, and top-3 accuracy.
training_metrics = ['sparse_categorical_accuracy', lr_metric, top3_acc]

# Labels are integer class ids, hence the sparse loss and metrics.
pretrained_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=optimizer,  # lion.Lion(learning_rate=lr_decayed_fn),
    metrics=training_metrics,
)
pretrained_model.summary()

Model: "sequential_29"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lambda_4 (Lambda)           (None, 384, 384, 3)       0         
                                                                 
 swin_base_384 (SwinTransfor  (None, 1024)             89781624  
 merModel)                                                       
                                                                 
 dense_4 (Dense)             (None, 467)               478675    
                                                                 
Total params: 90,260,299
Trainable params: 87,357,259
Non-trainable params: 2,903,040
_________________________________________________________________


In [12]:
# service_addr = tpu.get_master().replace(':8470', ':8466')
# print(service_addr)
# %tensorboard --logdir={log_dir + "/tf"}

In [13]:
# NOTE(review): this rebinds the module attribute `wandb.config` to the CFG object instead of
# calling a wandb API — confirm the values actually reach the run (the usual routes are
# `wandb.init(config=...)` or `wandb.config.update(...)` on an initialised run).
wandb.config = CFG

In [14]:
# Train the compiled classifier. The model built and compiled above is `pretrained_model`;
# no variable named `model` is defined in this notebook, so fitting `model` fails on a fresh kernel.
# steps_per_epoch is required because the training dataset repeats indefinitely.
history = pretrained_model.fit(
    get_training_dataset(),
    steps_per_epoch=CFG.STEPS_PER_EPOCH,
    epochs=CFG.EPOCHS,
    validation_data=get_validation_dataset(),
    validation_steps=CFG.VALIDATION_STEPS,
    callbacks=make_callbacks(CFG)
)



Epoch 1/20
  6/189 [..............................] - ETA: 1:05 - loss: 6.2224 - sparse_categorical_accuracy: 0.0078 - lr: 3.9999e-04 - sparse_top_3_categorical_accuracy: 0.0156



Epoch 1: val_loss improved from inf to 6.09191, saving model to /content/drive/MyDrive/Mushroom-Classifier/models/swin_large_224/1015-0924.h5
Epoch 2/20
Epoch 2: val_loss improved from 6.09191 to 2.80786, saving model to /content/drive/MyDrive/Mushroom-Classifier/models/swin_large_224/1015-0924.h5
Epoch 3/20
Epoch 3: val_loss improved from 2.80786 to 1.76564, saving model to /content/drive/MyDrive/Mushroom-Classifier/models/swin_large_224/1015-0924.h5
Epoch 4/20
Epoch 4: val_loss improved from 1.76564 to 1.21996, saving model to /content/drive/MyDrive/Mushroom-Classifier/models/swin_large_224/1015-0924.h5
Epoch 5/20
Epoch 5: val_loss improved from 1.21996 to 1.16929, saving model to /content/drive/MyDrive/Mushroom-Classifier/models/swin_large_224/1015-0924.h5
Epoch 6/20
Epoch 6: val_loss improved from 1.16929 to 1.06367, saving model to /content/drive/MyDrive/Mushroom-Classifier/models/swin_large_224/1015-0924.h5
Epoch 7/20
Epoch 7: val_loss did not improve from 1.06367
Epoch 8/20
Epoc

In [51]:
# Save the trained model in HDF5 format, named after this run's timestamp.
saved_model_path = CFG.ROOT / '../models' / CFG.MODEL / f"{save_time}.h5"
pretrained_model.save(saved_model_path, save_format='h5')

In [None]:
# Upload the saved model file as a W&B artifact.
# The file name uses `save_time` (`time` is not defined anywhere in this notebook), and the
# path mirrors the one the model is saved to in the preceding save cell.
# NOTE(review): wandb.log_artifact needs an active run — confirm wandb.init has been called.
model_file = CFG.ROOT / '../models' / CFG.MODEL / f"{save_time}.h5"
art = wandb.Artifact(
    'model',
    type='model')
art.add_file(str(model_file))
wandb.log_artifact(art)

In [15]:
# Plot loss (top panel) and accuracy (bottom panel) from the training history.
for train_key, valid_key, panel_title, panel_code in (
    ('loss', 'val_loss', 'loss', 211),
    ('sparse_categorical_accuracy', 'val_sparse_categorical_accuracy', 'accuracy', 212),
):
    display_training_curves(history.history[train_key], history.history[valid_key], panel_title, panel_code)

  ax = plt.subplot(subplot)

Looks like the annotation(s) you are trying 
to draw lies/lay outside the given figure size.

Therefore, the resulting Plotly figure may not be 
large enough to view the full text. To adjust 
the size of the figure, use the 'width' and 
'height' keys in the Layout object. Alternatively,
use the Margin object to adjust the figure's margins.


I found a path object that I don't think is part of a bar chart. Ignoring.



<Figure size 640x480 with 0 Axes>

In [16]:
# Mark the W&B run as finished.
wandb.finish()

VBox(children=(Label(value='4498.752 MB of 4498.752 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.…

0,1
batch/batch_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
batch/learning_rate,██▇▆▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/loss,███▆▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▁▁▁▁▁
batch/lr,██▇▇▆▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/sparse_categorical_accuracy,▁▁▁▂▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████▇▇▇▇▇▇▇▇▇▇▇████
batch/sparse_top_3_categorical_accuracy,▁▁▁▃▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████▇▇▇▇▇▇▇▇██████
epoch/epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
epoch/learning_rate,█▆▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch/loss,█▅▃▂▂▂▂▂▂▁▁▁▁▂▂▂▂▁▁▁
epoch/lr,█▇▅▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch/batch_step,3779.0
batch/learning_rate,0.0
batch/loss,0.3781
batch/lr,0.0
batch/sparse_categorical_accuracy,0.90427
batch/sparse_top_3_categorical_accuracy,0.97718
epoch/epoch,19.0
epoch/learning_rate,0.0
epoch/loss,0.3781
epoch/lr,0.0
