We will reimplement the customer support ticket
management model by subclassing `keras.Model`.

In [None]:
from tensorflow import keras
from tensorflow.keras import layers, Model

In [None]:
# Sizes shared by the dummy data and the model definition:
# input vocabulary width, number of tag features, number of departments.
vocabulary_size, num_tags, num_departments = 10_000, 100, 4

In [None]:
import numpy as np
num_samples = 1280

# Dummy input data: random binary (multi-hot) vectors for each input.
title_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))

# Dummy output data.
# Priority targets lie in [0, 1), the same range as the sigmoid priority head.
priority_data = np.random.random(size=(num_samples, 1))
# Department targets are one-hot rows (one random department per sample),
# so each row is a valid probability distribution for the softmax head and
# categorical cross-entropy. Independent random bits per column would give
# multi-hot rows that do not sum to 1.
department_data = np.eye(num_departments)[
    np.random.randint(0, num_departments, size=num_samples)
]

In [None]:
class CustomerTicketModel(keras.Model):
    """Multi-output ticket model predicting a priority score and a department.

    The three inputs (``title``, ``text_body``, ``tags``) are concatenated and
    passed through one shared ReLU ``Dense`` layer. The result then feeds two
    heads:

    * a 1-unit sigmoid head producing the ticket priority, and
    * a ``num_departments``-unit softmax head producing department
      probabilities.

    As with any subclassed Keras model, weights are created the first time
    the model is called on data.

    Args:
        num_departments: Number of output units of the department head.
        hidden_units: Width of the shared mixing layer (default 64).
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_departments, hidden_units=64, **kwargs):
        # Don't forget to call the super() constructor! Forwarding kwargs
        # keeps standard Model arguments such as ``name`` working.
        super().__init__(**kwargs)
        self.concat_layer = layers.Concatenate()
        self.mixing_layer = layers.Dense(hidden_units, activation="relu")
        self.priority_scorer = layers.Dense(1, activation="sigmoid")
        self.department_classifier = layers.Dense(
            num_departments, activation="softmax"
        )

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: Mapping with keys ``"title"``, ``"text_body"`` and
                ``"tags"``. Presumably 2-D batches that can be concatenated
                along the last axis; confirm against the data that feeds it.

        Returns:
            Tuple ``(priority, department)`` holding the two head outputs.
        """
        features = self.concat_layer(
            [inputs["title"], inputs["text_body"], inputs["tags"]]
        )
        features = self.mixing_layer(features)
        priority = self.priority_scorer(features)
        department = self.department_classifier(features)
        return priority, department

Once you’ve defined the model, you can instantiate it. Note that it will only create its
weights the first time you call it on some data, much like `Layer` subclasses do.

In [None]:
# Reuse the shared constant so the model head always matches the data.
model = CustomerTicketModel(num_departments=num_departments)
# Calling the model once on data builds its weights, because a subclassed
# model creates them lazily.
priority, department = model(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data}
)

You can compile and train a `Model` subclass just like a Sequential or Functional
model.

The structure of what you pass as the `loss` and
`metrics` arguments must match exactly what gets
returned by `call()`—here, two outputs, so each argument is a two-element list (one entry per output).

The structure of the input data must match
exactly what is expected by the `call()` method—
here, a dict with keys `title`, `text_body`, and `tags`.

In [None]:
# One loss and one metric list per model output, in the order call()
# returns them: (priority, department).
model.compile(
    optimizer="rmsprop",
    loss=[
        "mean_squared_error",        # priority head (sigmoid output)
        "categorical_crossentropy",  # department head (softmax output)
    ],
    metrics=[
        ["mean_absolute_error"],     # priority head
        ["accuracy"],                # department head
    ],
)

The structure of the target
data must match exactly what is
returned by the `call()` method—
here, two outputs, so we pass a list of two target arrays.

In [None]:
# Inputs are keyed the way call() reads them. Targets are one array per
# output, in the same (priority, department) order.
model.fit(
    x={
        "title": title_data,
        "text_body": text_body_data,
        "tags": tags_data,
    },
    y=[priority_data, department_data],
    epochs=1,
)



<keras.src.callbacks.History at 0x7836cd7e7ee0>