In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def create_imbalanced_cifar10(split, imbalance_ratio=0.1, imbalance_class=1):
    """Load a CIFAR-10 split in which one class is artificially under-represented.

    All examples whose label differs from ``imbalance_class`` are kept. For
    ``imbalance_class`` itself, only ``imbalance_ratio`` times the split's
    per-class example count is kept. The kept minority examples are appended
    after all other examples, so shuffle before batching if order matters.

    Args:
        split: ``"train"`` or ``"test"``.
        imbalance_ratio: Fraction in ``[0, 1]`` of the minority class to keep.
        imbalance_class: Integer label of the class to down-sample
            (defaults to 1, matching the previous hard-coded value).

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` pairs (``as_supervised``).

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"test"``, or if
            ``imbalance_ratio`` is outside ``[0, 1]``.
    """
    if split not in ("train", "test"):
        # Previously an unknown split fell through and crashed with
        # UnboundLocalError on `ds`.
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")
    if not 0.0 <= imbalance_ratio <= 1.0:
        # A negative count passed to `take` would silently keep every example.
        raise ValueError(
            f"imbalance_ratio must be in [0, 1], got {imbalance_ratio!r}"
        )

    # Load only the split that is actually used.
    ds, ds_info = tfds.load(
        "cifar10",
        split=split,
        as_supervised=True,
        with_info=True,
    )

    # Derive the per-class count from the dataset metadata rather than
    # hard-coding 5000. That constant only matches the train split, so on the
    # test split the old code kept far more than `imbalance_ratio` of the
    # minority class.
    # NOTE(review): assumes the source split is class-balanced, as the old
    # constant did.
    num_classes = ds_info.features["label"].num_classes
    examples_per_class = ds_info.splits[split].num_examples // num_classes
    minority_count = int(examples_per_class * imbalance_ratio)

    def filter_class(image, label):
        return tf.math.equal(label, imbalance_class)

    def filter_other_classes(image, label):
        return tf.math.not_equal(label, imbalance_class)

    ds_imbalanced_class = ds.filter(filter_class).take(minority_count)
    ds_other_classes = ds.filter(filter_other_classes)

    return ds_other_classes.concatenate(ds_imbalanced_class)

# Build the train and test pipelines, both using the default imbalance settings.
ds_train_imbalanced, ds_test_imbalanced = (
    create_imbalanced_cifar10(split_name) for split_name in ("train", "test")
)

def preprocess(image, label):
    """Cast ``image`` to float32 and divide it by 255.0; pass ``label`` through.

    NOTE(review): presumably the input pixels are uint8 in [0, 255], so the
    output lies in [0, 1]. Confirm against the dataset features.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return scaled, label

# Apply the same float32 scaling to both pipelines.
ds_train_imbalanced = ds_train_imbalanced.map(preprocess)
ds_test_imbalanced = ds_test_imbalanced.map(preprocess)


def _dataset_to_numpy(dataset):
    """Materialise an ``(image, label)`` dataset into two stacked NumPy arrays.

    Every element is pulled into memory, so this is only suitable for
    datasets that fit in RAM.
    """
    images, labels = zip(*tfds.as_numpy(dataset))
    return np.array(images), np.array(labels)


X_train, y_train = _dataset_to_numpy(ds_train_imbalanced)
X_test, y_test = _dataset_to_numpy(ds_test_imbalanced)

# Build a randomly initialised ResNet50 backbone topped by a dense softmax head.
num_classes = 10
input_shape = X_train.shape[1:]  # per-example shape (batch axis dropped)

inputs = Input(shape=input_shape)
base_model = ResNet50(
    include_top=False,
    weights=None,
    input_tensor=inputs,
)
x = Flatten()(base_model.output)
x = Dense(num_classes, activation="softmax")(x)
model = Model(inputs=inputs, outputs=x)

# Integer labels with softmax outputs -> sparse categorical cross-entropy.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=Adam(),
    metrics=["accuracy"],
)

# Train. The test arrays double as validation data here.
# NOTE(review): no class_weight is passed, so the under-represented class is
# not re-weighted. Overall accuracy alone may hide poor minority-class recall.
model.fit(
    X_train,
    y_train,
    batch_size=64,
    epochs=10,
    validation_data=(X_test, y_test),
)
