In [None]:
#TODO copy and paste your assignment code here:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
def _parse_function(proto):
    """Decode one serialized tf.train.Example into an (image, label) pair.

    Args:
        proto: Scalar string tensor holding a serialized Example with a
            JPEG-encoded "image" bytes feature and a length-1 int64 "label"
            feature.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the decoded uint8 image
        with a single channel and ``label`` is a scalar int64 tensor.
    """
    keys_to_features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([1], tf.int64)
    }
    parsed_features = tf.io.parse_single_example(proto, keys_to_features)
    # Pin the decode to one channel: the downstream preprocessing calls
    # tf.image.grayscale_to_rgb, which requires a last dimension of exactly 1.
    # With the default (channels=0) the channel count depends on how each
    # JPEG happened to be encoded and is unknown in the static shape.
    image = tf.image.decode_jpeg(parsed_features["image"], channels=1)
    # The label feature is stored as a length-1 vector; unwrap it to a scalar.
    label = parsed_features["label"][0]
    return image, label

# Stream of serialized records straight from the TFRecord file (not yet decoded).
raw_dataset = tf.data.TFRecordDataset(filenames="shared/filtered_emnist.tfrecord")

# Decoded (image, label) pairs, before any resizing or colour conversion.
dataset = raw_dataset.map(map_func=_parse_function)


def preprocess(image, label):
    """Prepare one decoded sample for the ResNet50 backbone.

    The image is resized to 32x32, expanded from one channel to three, and
    then passed through the ResNet50 ``preprocess_input`` routine. The label
    is passed through untouched.

    Args:
        image: Single-channel image tensor.
        label: Label tensor for the image.

    Returns:
        A tuple ``(image, label)`` with the transformed image.
    """
    resized = tf.image.resize(image, [32, 32])
    three_channel = tf.image.grayscale_to_rgb(resized)
    model_ready = tf.keras.applications.resnet50.preprocess_input(three_channel)
    return model_ready, label

def create_dataset(tfrecord_file):
    """Build the decoded and preprocessed dataset for a TFRecord file.

    Args:
        tfrecord_file: Path (or paths) accepted by ``tf.data.TFRecordDataset``.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(image, label)`` pairs: each
        record is parsed with ``_parse_function`` and then transformed with
        ``preprocess``.
    """
    return (
        tf.data.TFRecordDataset(tfrecord_file)
        .map(_parse_function)
        .map(preprocess)
    )

# Build the preprocessed pipeline once and keep it cached in memory. The
# dataset is deliberately NOT re-created after counting: doing so would throw
# the populated cache away and force the file to be decoded again per split.
full_dataset = create_dataset("shared/filtered_emnist.tfrecord").cache()

# Counting requires one complete pass over the data; that same pass fills the
# cache, so every later take/skip reads from memory.
dataset_size = sum(1 for _ in full_dataset)
print("Dataset size:", dataset_size)

# Positional 80/20 split: first 80% for training, the remainder for validation.
# NOTE(review): if the TFRecord is ordered by label, the validation slice will
# be class-skewed -- confirm the file is already shuffled on disk.
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size
train_dataset = full_dataset.take(train_size)
val_dataset = full_dataset.skip(train_size).take(val_size)

batch_size = 128
# Bounded shuffle buffer keeps memory use predictable; clamp to >= 1 because a
# zero-sized buffer is rejected by tf.data.
shuffle_buffer = max(1, min(train_size, 10_000))

# Only the training split is shuffled (and reshuffled every epoch) so that
# mini-batches differ between epochs; validation order stays deterministic.
train_dataset = (
    train_dataset
    .shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Sanity check: print the shapes of the first training batch.
for images, labels in train_dataset.take(1):
    print(f"Shape of images: {images.shape}")
    print(f"Shape of labels: {labels.shape}")

def build_model():
    """Create and compile the ResNet50-based classifier.

    An ImageNet-initialised ResNet50 without its top is used as the backbone
    for 32x32x3 inputs, with its first 150 layers frozen. A global average
    pooling layer, a 256-unit ReLU layer and a 36-way softmax layer are
    stacked on top. (36 is presumably 10 digits + 26 letters -- confirm
    against the dataset.)

    Returns:
        A compiled ``Sequential`` model using Adam with default settings,
        sparse categorical cross-entropy loss and an accuracy metric.
    """
    backbone = ResNet50(include_top=False, weights='imagenet', input_shape=(32, 32, 3))

    # Freeze the first 150 backbone layers; the remaining ones stay trainable.
    for frozen_layer in backbone.layers[:150]:
        frozen_layer.trainable = False

    classifier = models.Sequential()
    classifier.add(backbone)
    classifier.add(layers.GlobalAveragePooling2D())
    classifier.add(layers.Dense(256, activation='relu'))
    classifier.add(layers.Dense(36, activation='softmax'))

    # Labels are integer class ids, hence the sparse variant of the loss.
    classifier.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=tf.keras.optimizers.Adam(),
        metrics=['accuracy'],
    )
    return classifier

In [None]:
# Build the network and print its layer summary.
model = build_model()
model.summary()

# Train for at most 10 epochs, validating after each one.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=[
        # Stop after 3 epochs without val_loss improvement and roll the
        # in-memory model back to its best weights, so the evaluation that
        # follows does not use a degraded final epoch.
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=3,
            restore_best_weights=True,
        ),
        # Without save_best_only the file is overwritten every epoch and
        # 'best_model.h5' would simply hold the last epoch.
        tf.keras.callbacks.ModelCheckpoint(
            'best_model.h5',
            monitor='val_loss',
            save_best_only=True,
        ),
    ],
)
def evaluate_model(dataset):
    """Collect true labels and predicted class ids over a batched dataset.

    Uses the module-level ``model``. Labels and predictions are gathered in
    the same single pass so they stay aligned even if the dataset's iteration
    order changes between passes.

    Args:
        dataset: Batched ``tf.data.Dataset`` yielding ``(images, labels)``.

    Returns:
        A tuple ``(y_true, y_pred)`` of 1-D NumPy arrays; both are empty when
        the dataset yields no batches.
    """
    y_true = []
    y_pred = []
    for images, labels in dataset:
        # predict_on_batch runs a single forward pass on the given batch;
        # calling model.predict inside a Python loop rebuilds its input
        # pipeline and prints a progress bar on every iteration.
        probabilities = model.predict_on_batch(images)
        y_true.extend(labels.numpy())
        y_pred.extend(np.argmax(probabilities, axis=1))
    return np.array(y_true), np.array(y_pred)

# Predictions for the whole validation split.
y_true, y_pred = evaluate_model(val_dataset)

# Fix the class list explicitly: without `labels`, confusion_matrix only
# includes classes that occur in y_true or y_pred, so a missing class would
# shrink the matrix and misalign it with the 36 tick labels below.
class_ids = list(range(36))
tick_labels = [str(i) for i in class_ids]

cm = confusion_matrix(y_true, y_pred, labels=class_ids)
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=tick_labels, yticklabels=tick_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

def plot_predictions(dataset, num_samples=10):
    """Show samples from the first batch with their true and predicted labels.

    Uses the module-level ``model``. Images are assumed to have gone through
    ``preprocess`` (i.e. ResNet50 ``preprocess_input``), so that transform is
    reversed before display; showing the preprocessed tensor directly would
    render mean-subtracted BGR values and produce meaningless pictures.

    Args:
        dataset: Batched ``tf.data.Dataset`` yielding ``(images, labels)``.
        num_samples: Maximum number of images to show (capped at the size of
            the first batch). The grid is 5 columns wide and grows in rows.
    """
    # ResNet50 preprocess_input converts RGB -> BGR and subtracts these
    # per-channel ImageNet means (given here in BGR order), without scaling.
    bgr_means = np.array([103.939, 116.779, 123.68], dtype=np.float32)
    n_cols = 5

    for images, labels in dataset.take(1):
        count = min(num_samples, len(images))
        if count <= 0:
            return

        # One forward pass for all displayed samples instead of one per image.
        predicted = np.argmax(model.predict_on_batch(images[:count]), axis=1)

        # Ceiling division so any num_samples fits the grid.
        n_rows = -(-count // n_cols)
        plt.figure(figsize=(15, 5 * n_rows))
        for j in range(count):
            plt.subplot(n_rows, n_cols, j + 1)
            # Undo the preprocessing: add the means back, flip BGR -> RGB and
            # clamp to valid 8-bit pixel values for imshow.
            restored_bgr = images[j].numpy() + bgr_means
            displayable = np.clip(restored_bgr[..., ::-1], 0, 255).astype(np.uint8)
            plt.imshow(displayable)
            plt.title(f"True: {labels[j].numpy()}\nPred: {predicted[j]}")
            plt.axis('off')
        plt.tight_layout()
        plt.show()

plot_predictions(val_dataset)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

# Summary metrics over the validation predictions. With average='weighted'
# each class's score is weighted by its number of true samples.
averaging = dict(average='weighted')
precision = precision_score(y_true, y_pred, **averaging)
recall = recall_score(y_true, y_pred, **averaging)
f1 = f1_score(y_true, y_pred, **averaging)

# Report in the same order as before: F1, then precision, then recall.
for metric_name, metric_value in (("f1 score", f1), ("precision", precision), ("recall", recall)):
    print(metric_name, metric_value)