In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow_datasets as tfds

import wandb
from wandb.keras import WandbMetricsLogger

In [None]:
# Log in to Weights & Biases up front; the training cell later in this
# notebook opens a run with wandb.init() and logs metrics to it.
wandb.login()

In [None]:
import pickle
import numpy as np
from tensorflow.keras.utils import to_categorical

def load_cifar10_batch(file_path):
    """Unpickle a single CIFAR-10 batch file and return its contents.

    Args:
        file_path: Path of the pickled batch file to read.

    Returns:
        The object stored in the pickle. The file is decoded with
        ``encoding='bytes'``, so callers in this notebook index the result
        with bytes keys such as ``b'data'`` and ``b'labels'``.

    Note:
        ``pickle.load`` can execute arbitrary code while deserializing;
        only use this on batch files obtained from a trusted source.
    """
    with open(file_path, 'rb') as stream:
        return pickle.load(stream, encoding='bytes')

def load_cifar10_data(folder_path):
    """Read the five training batches and the test batch from a folder.

    Args:
        folder_path: Directory containing the files ``data_batch_1`` ..
            ``data_batch_5`` and ``test_batch``.

    Returns:
        A tuple ``(train_data, train_labels, test_data, test_labels)``:
        ``train_data`` is the row-wise stack of every training batch's
        ``b'data'`` entry, ``train_labels`` is an array of all training
        ``b'labels'`` in batch order, ``test_data`` is the test batch's
        ``b'data'`` entry as stored, and ``test_labels`` is an array of the
        test batch's ``b'labels'``.
    """
    image_chunks, label_list = [], []

    # Training batches are numbered 1 through 5 on disk.
    for batch_no in range(1, 6):
        chunk = load_cifar10_batch(f"{folder_path}/data_batch_{batch_no}")
        image_chunks.append(chunk[b'data'])
        label_list += chunk[b'labels']

    held_out = load_cifar10_batch(f"{folder_path}/test_batch")

    return (
        np.vstack(image_chunks),
        np.array(label_list),
        held_out[b'data'],
        np.array(held_out[b'labels']),
    )

def preprocess_data(train_data, train_labels, test_data, test_labels, num_classes=None):
    """Reshape flat image rows to channels-last images and one-hot the labels.

    Args:
        train_data: Array whose rows each hold one image as 3*32*32 values
            laid out channel by channel.
        train_labels: Integer class labels for ``train_data``.
        test_data: Same layout as ``train_data``, for the test split.
        test_labels: Integer class labels for ``test_data``.
        num_classes: Width of the one-hot vectors. When ``None`` it is
            inferred once as ``largest label in either split + 1`` so both
            encodings always have the same width, even if one split happens
            not to contain the highest class.

    Returns:
        A tuple ``(train_data, train_labels_onehot, test_data,
        test_labels_onehot)`` where the image arrays have shape
        ``(N, 32, 32, 3)`` and the label arrays have shape
        ``(N, num_classes)``.
    """
    # (N, 3072) -> (N, 3, 32, 32) -> (N, 32, 32, 3): move channels to the last axis.
    train_data = train_data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    test_data = test_data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    if num_classes is None:
        # Use one shared width; letting to_categorical infer it per split
        # could yield train/test encodings of different widths.
        num_classes = int(max(np.max(train_labels), np.max(test_labels))) + 1

    train_labels_onehot = to_categorical(train_labels, num_classes=num_classes)
    test_labels_onehot = to_categorical(test_labels, num_classes=num_classes)

    return train_data, train_labels_onehot, test_data, test_labels_onehot

# Directory with the CIFAR-10 python batch files, relative to the working directory.
cifar10_folder = 'cifar-10-batches-py'

# Read the raw arrays, then reshape the images and one-hot encode the labels.
train_data, train_labels, test_data, test_labels = load_cifar10_data(cifar10_folder)
x_train, y_train, x_test, y_test = preprocess_data(train_data, train_labels, test_data, test_labels)

# Report the shape of every prepared array as a quick sanity check.
for _caption, _array in (
    ("Train Data Shape:", x_train),
    ("Train Labels Shape:", y_train),
    ("Test Data Shape:", x_test),
    ("Test Labels Shape:", y_test),
):
    print(_caption, _array.shape)

In [None]:
import tensorflow as tf

def create_model(input_shape=(32, 32, 3), num_classes=10):
    """Build an uncompiled CNN that mirrors the VGG-16 layer layout.

    The network is five convolutional blocks (2, 2, 3, 3 and 3 stacked 3x3
    ReLU convolutions with 64, 128, 256, 512 and 512 filters), each followed
    by a 2x2 max-pool, then two 4096-unit ReLU dense layers and a softmax
    output layer.

    Args:
        input_shape: Shape of one input image, channels last. Defaults to
            ``(32, 32, 3)``. The five 2x2 poolings reduce a 32x32 input to
            1x1, so spatially smaller inputs are not expected to work.
        num_classes: Number of units in the softmax output layer.
            Defaults to 10.

    Returns:
        A ``tf.keras.Sequential`` model; the caller is responsible for
        compiling it.
    """
    # (filters, number of stacked convolutions) for each block.
    conv_blocks = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    stack = []
    for filters, depth in conv_blocks:
        for _ in range(depth):
            # Only the very first layer of the network declares the input shape.
            shape_kwargs = {} if stack else {'input_shape': input_shape}
            stack.append(
                tf.keras.layers.Conv2D(
                    filters, (3, 3), activation='relu', padding='same', **shape_kwargs
                )
            )
        # Each block ends by halving the spatial resolution.
        stack.append(tf.keras.layers.MaxPooling2D((2, 2)))

    # Classifier head.
    stack += [
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(4096, activation='relu'),
        tf.keras.layers.Dense(4096, activation='relu'),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ]
    return tf.keras.Sequential(stack)

In [None]:
# Hyperparameters for the experiment. The training cell reads the batch size,
# epoch count and early-stopping patience from here, the compile cell reads
# the learning rate, and the whole dict is passed to wandb.init(config=...).
configs = {
    "num_classes": 10,
    "batch_size": 32,
    "earlystopping_patience": 1,
    "learning_rate": 1e-5,
    "epochs": 3,
}

In [None]:
# Reset Keras' global state first, so re-running this cell does not pile up
# state left over from a previously built model.
tf.keras.backend.clear_session()
# Build a fresh, still-uncompiled network and print its layer-by-layer summary.
model = create_model()
model.summary()

In [None]:
from tensorflow.keras.optimizers import Adam

# Configure training: Adam at the configured learning rate, cross-entropy
# against the one-hot labels, and two reported metrics (top-1 and top-3 accuracy).
model.compile(
    optimizer=Adam(learning_rate=configs["learning_rate"]),
    loss="categorical_crossentropy",
    metrics=[
        "accuracy",
        tf.keras.metrics.TopKCategoricalAccuracy(name='top@3_accuracy', k=3),
    ],
)

In [None]:
from tensorflow.keras.callbacks import EarlyStopping

# Start a W&B run. The run name is derived from the configured learning rate
# so the label shown in the dashboard always matches the hyperparameters that
# were actually used (a hard-coded name goes stale as soon as configs change).
# Using the run as a context manager guarantees it is closed even when
# training or evaluation raises, instead of leaving a dangling open run.
with wandb.init(
    project="CIFAR-10_Classification",
    name=f"lr_tune_{configs['learning_rate']}",
    config=configs,
) as run:
    # Stop once validation loss fails to improve for the configured number of
    # epochs, and roll the model back to the best weights seen.
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=configs["earlystopping_patience"],
        restore_best_weights=True,
    )

    # Train the model; 20% of the training arrays are held out for validation
    # and per-epoch metrics are streamed to W&B.
    history = model.fit(
        x_train, y_train,
        epochs=configs["epochs"],
        batch_size=configs["batch_size"],
        validation_split=0.2,
        callbacks=[
            WandbMetricsLogger(log_freq='epoch'),
            early_stopping,
        ],
    )

    # evaluate() returns [loss, *metrics] in compile order, so index 1 is the
    # "accuracy" metric.
    test_accuracy = model.evaluate(x_test, y_test)[1]
    run.log({"Test Accuracy": test_accuracy})