
# Customizing what happens in `fit()` with PyTorch

**Author:** [fchollet](https://twitter.com/fchollet)<br>
**Date created:** 2023/06/27<br>
**Last modified:** 2024/08/01<br>
**Description:** Overriding the training step of the Model class with PyTorch.


## Introduction

When you're doing supervised learning, you can use `fit()` and everything works
smoothly.

When you need to take control of every little detail, you can write your own training
loop entirely from scratch.

But what if you need a custom training algorithm and still want to benefit from
the convenient features of `fit()`, such as callbacks, built-in distribution support,
or step fusing?

A core principle of Keras is **progressive disclosure of complexity**. You should
always be able to get into lower-level workflows in a gradual way. You shouldn't fall
off a cliff if the high-level functionality doesn't exactly match your use case. You
should be able to gain more control over the small details while retaining a
commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should **override the training step
function of the `Model` class**. This is the function that `fit()` calls for
every batch of data. You can then call `fit()` as usual, and it will run your own
learning algorithm.
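To make "called for every batch" concrete, here is a deliberately simplified sketch in plain Python. `ToyModel` and its `fit()` are hypothetical stand-ins, not Keras classes; Keras's real `fit()` also handles callbacks, distribution, validation, and logging, none of which is modeled here.

```python
class ToyModel:
    """Hypothetical stand-in for a model whose train_step() can be overridden."""

    def __init__(self):
        self.steps_seen = 0

    def train_step(self, data):
        # A real train_step() would run the forward/backward passes here.
        x, y = data
        self.steps_seen += 1
        return {"batch_size": len(x)}  # placeholder log entry

    def fit(self, x, y, batch_size=32, epochs=1):
        logs = {}
        for _ in range(epochs):
            for start in range(0, len(x), batch_size):
                batch = (x[start:start + batch_size], y[start:start + batch_size])
                # The outer loop only slices batches and hands each one to train_step().
                logs = self.train_step(batch)
        return logs


model = ToyModel()
logs = model.fit(list(range(100)), list(range(100)), batch_size=32, epochs=2)
print(model.steps_seen)  # 8 (4 batches per epoch x 2 epochs)
print(logs)              # {'batch_size': 4} (the last batch holds 100 - 96 samples)
```

Only `train_step()` is meant to change; the surrounding loop stays the same.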

Note that this pattern does not prevent you from building models with the Functional
API. You can do this whether you're building `Sequential` models, Functional API
models, or subclassed models.

Let's see how that works.

## Setup

```python
import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
from keras import layers
import numpy as np
```

## A first simple example

Let's start from a simple example:

- We create a new class that subclasses `keras.Model`.
- We just override the method `train_step(self, data)`.
- We return a dictionary mapping metric names (including the loss) to their current
value.

The input argument `data` is what gets passed to `fit()` as training data:

- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple
`(x, y)`
- If you pass a `torch.utils.data.DataLoader` or a `tf.data.Dataset`,
by calling `fit(dataset, ...)`, then `data` will be what gets yielded
by `dataset` at each batch.
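Because the structure of `data` depends on what you pass to `fit()`, it can help to normalize it in one place. The sketch below uses a hypothetical plain-Python helper, `unpack_batch()` (not a Keras API), that splits a batch into `(x, y, sample_weight)` and fills missing parts with `None`. Keras also exports `keras.utils.unpack_x_y_sample_weight()` for the same job; check that your version includes it before relying on it.

```python
def unpack_batch(data):
    """Hypothetical helper: split one batch into (x, y, sample_weight).

    Handles a bare input or a tuple/list of 1, 2, or 3 elements.
    """
    if not isinstance(data, (tuple, list)):
        return data, None, None            # e.g. fit(x) with no targets
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None      # e.g. fit(x, y)
    if len(data) == 3:
        return data[0], data[1], data[2]   # e.g. fit(x, y, sample_weight=sw)
    raise ValueError(f"Expected a batch of 1 to 3 elements, got {len(data)}")


print(unpack_batch(("x", "y")))        # ('x', 'y', None)
print(unpack_batch(["x", "y", "sw"]))  # ('x', 'y', 'sw')
print(unpack_batch("x"))               # ('x', None, None)
```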

In the body of the `train_step()` method, we implement a regular training update,
similar to what you are already familiar with. Importantly, **we compute the loss via
`self.compute_loss()`**, which wraps the loss function(s) that were passed to
`compile()`.

Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,
to update the state of the metrics that were passed in `compile()`,
and we query results from `self.metrics` at the end to retrieve their current value.
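Stripped of any framework, the "regular training update" follows a fixed recipe: forward pass, loss, gradients, weight update, logs. The NumPy sketch below shows that recipe for a linear model under mean squared error with plain SGD. The model, the learning rate, and the hand-derived gradients are assumptions for illustration only; in the PyTorch version, `loss.backward()` computes the gradients and `self.optimizer.apply()` performs the update.

```python
import numpy as np


def train_step(w, b, x, y, lr=0.1):
    """One plain-SGD step for a linear model y ~ x @ w + b under MSE (illustrative)."""
    n = x.shape[0]
    # Forward pass
    y_pred = x @ w + b
    # Loss: mean squared error
    err = y_pred - y
    loss = float(np.mean(err ** 2))
    # Gradients of the loss (the role loss.backward() plays in PyTorch)
    grad_w = 2.0 * x.T @ err / n
    grad_b = 2.0 * np.mean(err, axis=0)
    # Weight update (the role optimizer.apply() plays, here plain SGD)
    w = w - lr * grad_w
    b = b - lr * grad_b
    return w, b, {"loss": loss}


rng = np.random.default_rng(0)
x = rng.random((256, 4))
true_w = np.array([[1.0], [-2.0], [0.5], [3.0]])
y = x @ true_w + 0.7

w, b = np.zeros((4, 1)), np.zeros(1)
losses = []
for _ in range(200):  # repeated calls, as fit() would make once per batch
    w, b, logs = train_step(w, b, x, y)
    losses.append(logs["loss"])
print(losses[-1] < losses[0])  # True
```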

```python
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
```
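The `self.zero_grad()` call at the top of `train_step()` matters because PyTorch *accumulates* gradients: each `backward()` call adds to a tensor's existing `.grad` instead of replacing it. A minimal standalone illustration with a single scalar tensor (not part of the guide's model):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()           # d(3w)/dw = 3
after_first = w.grad.item()  # 3.0

(3 * w).backward()            # no clearing in between
after_second = w.grad.item()  # 6.0: the new gradient was added to the old one

w.grad = None                # clear the stale gradient before the next step
(3 * w).backward()
after_reset = w.grad.item()  # 3.0

print(after_first, after_second, after_reset)  # 3.0 6.0 3.0
```

Setting `.grad` to `None` is what `torch.nn.Module.zero_grad()` does by default in PyTorch 2.x (`set_to_none=True`); earlier versions filled the gradients with zeros instead.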


Let's try this out:

```python
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
```

```
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - mae: 0.3817 - loss: 0.2217
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - mae: 0.3614 - loss: 0.2050
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - mae: 0.3455 - loss: 0.1859

<keras.src.callbacks.history.History at 0x7b9938ff3d40>
```

## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do
everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:

- We start by creating `Metric` instances to track our loss and an MAE score (in `__init__()`).
- We implement a custom `train_step()` that updates the state of these metrics
(by calling `update_state()` on them), then queries them (via `result()`) to return their current average value,
to be displayed by the progress bar and passed to any callback.
- Note that we would need to call `reset_state()` on our metrics between epochs! Otherwise
calling `result()` would return an average since the start of training, whereas we usually work
with per-epoch averages. Thankfully, the framework can do that for us: just list any metric
you want to reset in the `metrics` property of the model. The model will call `reset_state()`
on any object listed there at the beginning of each `fit()` epoch or at the beginning of a call to
`evaluate()`.
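Why per-epoch resets matter can be seen with a minimal stand-in for `keras.metrics.Mean` in plain Python. `RunningMean` is hypothetical, not a Keras class, and the batch losses are made-up numbers chosen only to keep the arithmetic easy to follow.

```python
class RunningMean:
    """Hypothetical, minimal stand-in for keras.metrics.Mean."""

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


per_epoch_losses = [[4.0, 2.0], [1.0, 1.0]]  # illustrative batch losses, 2 epochs

tracker = RunningMean()
with_reset = []
for epoch_losses in per_epoch_losses:
    tracker.reset_state()  # what fit() does for each metric listed in `metrics`
    for loss in epoch_losses:
        tracker.update_state(loss)
    with_reset.append(tracker.result())

tracker = RunningMean()
without_reset = []
for epoch_losses in per_epoch_losses:
    for loss in epoch_losses:
        tracker.update_state(loss)
    without_reset.append(tracker.result())

print(with_reset)     # [3.0, 1.0]
print(without_reset)  # [3.0, 2.0]: epoch 2 still averages in epoch 1's losses
```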

```python
class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.loss_fn(y, y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_state()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
```

```
Epoch 1/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 2.0532 - mae: 1.3397
Epoch 2/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9426 - mae: 0.8554
Epoch 3/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.4511 - mae: 0.5592
Epoch 4/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.2886 - mae: 0.4388
Epoch 5/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.2397 - mae: 0.3992

<keras.src.callbacks.history.History at 0x7b990d319190>
```

## Supporting `sample_weight` & `class_weight`

You may have noticed that our first basic example didn't make any mention of sample
weighting. If you want to support the `fit()` arguments `sample_weight` and
`class_weight`, you'd simply do the following:

- Unpack `sample_weight` from the `data` argument
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply
it manually if you don't rely on `compile()` for losses & metrics)
- That's it.
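To see what the weights do numerically, here is a NumPy sketch under two stated assumptions: (a) `class_weight` becomes a per-sample weight by looking up each sample's class, and (b) the loss uses Keras's default `"sum_over_batch_size"` reduction, i.e. the weighted per-sample losses are summed and divided by the number of samples, not by the sum of the weights. `weighted_mse()` is a hypothetical helper; check the `reduction` argument of your loss if the exact scaling matters.

```python
import numpy as np

# Integer class labels for a tiny batch, and a class_weight dict as you
# would pass it to fit(class_weight=...). Values are illustrative.
y_true_cls = np.array([0, 1, 1, 0])
class_weight = {0: 1.0, 1: 3.0}

# class_weight -> per-sample weights: each sample gets its class's weight.
sample_weight = np.array([class_weight[int(c)] for c in y_true_cls])
print(sample_weight)  # [1. 3. 3. 1.]


def weighted_mse(y, y_pred, sample_weight=None):
    """MSE with optional per-sample weights, sum-over-batch-size reduction (assumed)."""
    per_sample = np.mean((y - y_pred) ** 2, axis=-1)  # one loss value per sample
    if sample_weight is not None:
        per_sample = per_sample * sample_weight
    return per_sample.sum() / per_sample.shape[0]  # divide by batch size


y = y_true_cls.reshape(-1, 1).astype("float64")  # [[0.], [1.], [1.], [0.]]
y_pred = np.full((4, 1), 0.5)
print(weighted_mse(y, y_pred))                 # 0.25
print(weighted_mse(y, y_pred, sample_weight))  # 0.5
```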

```python
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
        )

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
```

```
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - mae: 1.8127 - loss: 1.8478
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - mae: 1.3254 - loss: 1.0657
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - mae: 0.8994 - loss: 0.5520

<keras.src.callbacks.history.History at 0x7b990d31a300>
```

## Providing your own evaluation step

What if you want to do the same for calls to `model.evaluate()`? Then you would
override `test_step` in exactly the same way. Here's what it looks like:

```python
class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Compute the loss
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
```

```
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - mae: 0.4940 - loss: 0.3865

[tensor(0.3701, device='cuda:0'), {'mae': tensor(0.4903, device='cuda:0')}]
```

## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and
"real").
- One optimizer for each.
- A loss function to train the discriminator.
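The discriminator below ends in `Dense(1)` with no activation, so it outputs logits, and the GAN is later compiled with `BinaryCrossentropy(from_logits=True)`. A NumPy sketch of the numerically stable per-element formula commonly used for sigmoid cross-entropy on logits, `max(z, 0) - z * y + log(1 + exp(-|z|))`, checked against the naive sigmoid-then-log formula. Averaging over all elements stands in for the loss's reduction (an assumption); both helper functions are illustrative.

```python
import numpy as np


def bce_from_logits(labels, logits):
    """Mean binary cross-entropy computed directly from logits (stable form)."""
    # max(z, 0) - z * y + log1p(exp(-|z|)) never exponentiates a large positive number.
    per_element = (
        np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    )
    return per_element.mean()


def bce_naive(labels, logits):
    """Same quantity via sigmoid + log; fine for moderate logits, unstable for extreme ones."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return np.mean(-(labels * np.log(p) + (1 - labels) * np.log(1 - p)))


labels = np.array([[1.0], [0.0], [1.0], [0.0]])
logits = np.array([[2.0], [-1.0], [-0.5], [0.3]])
print(np.isclose(bce_from_logits(labels, logits), bce_naive(labels, logits)))  # True

# With an extreme logit, the stable form stays finite.
print(bce_from_logits(np.array([0.0]), np.array([1000.0])))  # 1000.0
```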

```python
# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
```
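The spatial sizes in these two models line up because of how `padding="same"` behaves with strides: a strided `Conv2D` divides the spatial size by the stride (rounding up), and a strided `Conv2DTranspose` without `output_padding` multiplies it by the stride. A small plain-Python check of that arithmetic; the helper functions are illustrative, not Keras APIs.

```python
import math


def conv_same_out(size, stride):
    """Spatial size after a Conv2D with padding="same": ceil(size / stride)."""
    return math.ceil(size / stride)


def conv_transpose_same_out(size, stride):
    """Spatial size after a Conv2DTranspose with padding="same", no output_padding."""
    return size * stride


# Discriminator: 28 -> 14 -> 7; GlobalMaxPooling2D then collapses 7x7x128 to 128 features.
d = 28
for stride in (2, 2):
    d = conv_same_out(d, stride)
print(d)  # 7

# Generator: Dense(7 * 7 * 128) reshaped to 7x7x128 -> 14 -> 28; the final
# stride-1 Conv2D keeps 28x28 and maps to 1 channel.
g = 7
for stride in (2, 2):
    g = conv_transpose_same_out(g, stride)
g = conv_same_out(g, 1)
print(g, 7 * 7 * 128)  # 28 6272
```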

Here's a feature-complete GAN class that overrides `compile()` to use its own signature
and implements the entire GAN algorithm in `train_step`:

```python
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.built = True

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # A DataLoader batch arrives as a tuple/list: keep only the images.
        if isinstance(real_images, (tuple, list)):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = real_images.shape[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images. `torch.as_tensor()` moves the batch
        # to `device` without copy-constructing from an existing tensor, which
        # avoids the warning `torch.tensor()` emits on every step.
        real_images = torch.as_tensor(real_images, device=device)
        combined_images = torch.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = torch.concat(
            [
                torch.ones((batch_size, 1), device=device),
                torch.zeros((batch_size, 1), device=device),
            ],
            axis=0,
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(labels.shape, seed=self.seed_generator)

        # Train the discriminator
        self.zero_grad()
        predictions = self.discriminator(combined_images)
        d_loss = self.loss_fn(labels, predictions)
        d_loss.backward()
        grads = [v.value.grad for v in self.discriminator.trainable_weights]
        with torch.no_grad():
            self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = torch.zeros((batch_size, 1), device=device)

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        self.zero_grad()
        predictions = self.discriminator(self.generator(random_latent_vectors))
        g_loss = self.loss_fn(misleading_labels, predictions)
        # backward() returns None; the gradients are read from `.grad` below.
        g_loss.backward()
        grads = [v.value.grad for v in self.generator.trainable_weights]
        with torch.no_grad():
            self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }
```
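The labeling convention in `train_step` is easy to misread: generated images get label 1 and real images get label 0, so the "misleading" labels for the generator step are all zeros ("these are real"). A NumPy sketch of that label assembly, including the small uniform noise scaled by 0.05; the NumPy RNG stands in for `keras.random.uniform` (an assumption), and the batch size is illustrative.

```python
import numpy as np

rng = np.random.default_rng(1337)
batch_size = 4

# Discriminator labels: first half generated (1), second half real (0),
# matching the order of concat([generated_images, real_images]).
labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))], axis=0)
# Label noise in [0, 0.05), mirroring `0.05 * keras.random.uniform(...)`.
labels += 0.05 * rng.uniform(size=labels.shape)

# Generator step: ask the discriminator to call the fakes "real" (label 0).
misleading_labels = np.zeros((batch_size, 1))

print(labels.shape, misleading_labels.shape)  # (8, 1) (4, 1)
```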


Let's test-drive it:

```python
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))

# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(all_digits), torch.from_numpy(all_digits)
)
# Create a DataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

gan.fit(dataloader, epochs=1)
```

```
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 57s 52ms/step - d_loss: 0.4283 - g_loss: 1.3497

<keras.src.callbacks.history.History at 0x7b9938fb8c20>
```
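The `1094` in the progress bar is the number of batches per epoch. MNIST has 60,000 training and 10,000 test digits, and a `DataLoader` keeps the final partial batch by default (`drop_last=False`), so the step count is a ceiling division. The same arithmetic explains the `32/32` in the earlier examples, which use 1,000 samples at `fit()`'s default `batch_size` of 32:

```python
import math

num_samples = 60_000 + 10_000  # MNIST train + test digits, concatenated above
batch_size = 64
steps_per_epoch = math.ceil(num_samples / batch_size)  # last batch is partial (16 samples)
print(steps_per_epoch)  # 1094

print(math.ceil(1000 / 32))  # 32
```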
