In [1]:
%env HSA_OVERRIDE_GFX_VERSION=10.3.0
%env TF_FORCE_GPU_ALLOW_GROWTH=true
%env ROCM_PATH=/opt/rocm
#https://github.com/chaitanya100100/VAE-for-Image-Generation https://github.com/podgorskiy/VAE
import tensorflow as tf
tf.test.is_built_with_rocm()
# TODO: document the shell equivalent of the %env line above: export HSA_OVERRIDE_GFX_VERSION=10.3.0
# Ask TensorFlow to grow GPU memory on demand on the first visible GPU, if any.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
  try:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
  except (ValueError, RuntimeError) as err:
    # Invalid device or cannot modify virtual devices once initialized.
    # Still best-effort, but report the reason instead of hiding it.
    print(f"Could not enable GPU memory growth: {err}")
else:
  print("No GPU visible to TensorFlow; memory growth not configured.")

env: HSA_OVERRIDE_GFX_VERSION=10.3.0
env: TF_FORCE_GPU_ALLOW_GROWTH=true
env: ROCM_PATH=/opt/rocm


2023-03-18 18:46:55.617455: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-18 18:46:55.781604: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-03-18 18:46:57.278238: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-03-18 18:46:57.308033: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] suc

In [2]:
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow as tf


In [3]:
class VectorQuantizer(layers.Layer):
    """Vector-quantization layer holding a trainable codebook.

    Each vector along the last axis of the input is replaced by its nearest
    codebook column (squared Euclidean distance). The layer registers the
    codebook + commitment loss through ``add_loss`` and passes gradients to
    the input with a straight-through estimator.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Size of each codebook vector; inputs are flattened to
            ``[-1, embedding_dim]`` before the lookup.
        beta: Weight applied to the commitment loss.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # The `beta` parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize.
        # Stored as (embedding_dim, num_embeddings): one code per column.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Without this, Keras cannot serialise or clone a model containing the
        layer because `num_embeddings` and `embedding_dim` are required.
        """
        config = super().get_config()
        config.update(
            {
                "num_embeddings": self.num_embeddings,
                "embedding_dim": self.embedding_dim,
                "beta": self.beta,
            }
        )
        return config

    def call(self, x):
        """Quantize `x` and register the vector-quantization loss.

        Returns a tensor with the same shape as `x`. Its forward value is the
        quantized tensor, while its gradient flows to `x` unchanged.
        """
        # Calculate the input shape of the inputs and
        # then flatten the inputs keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization: one-hot code selection followed by a codebook lookup.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)

        # Reshape the quantized values back to the original input shape
        quantized = tf.reshape(quantized, input_shape)

        # Calculate vector quantization loss and add that to the layer. You can learn more
        # about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
        # the original paper to get a handle on the formulation of the loss function.
        # The commitment term only updates the encoder (codes are frozen) and
        # the codebook term only updates the codes (encoder output is frozen).
        commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(self.beta * commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        """Return, for each row of `flattened_inputs`, the index of the nearest code.

        Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so that all
        pairwise squared distances come from a single matmul.
        """
        # Calculate the squared L2 distance between the inputs and the codes.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


In [4]:
def get_encoder(latent_dim=1536):
    """Build the convolutional encoder for 64x64 single-channel images.

    Two stride-2 3x3 ReLU convolutions (1536 then 3072 filters) are followed
    by a 1x1 convolution that projects to `latent_dim` channels.
    """
    image_in = keras.Input(shape=(64, 64, 1))

    features = image_in
    for n_filters in (1536, 3072):
        features = layers.Conv2D(
            n_filters, 3, activation="relu", strides=2, padding="same"
        )(features)

    latents = layers.Conv2D(latent_dim, 1, padding="same")(features)
    return keras.Model(image_in, latents, name="encoder")

In [5]:
def get_decoder(latent_dim=1536, latent_shape=None):
    """Build the transposed-convolution decoder that mirrors the encoder.

    Args:
        latent_dim: Channel count of the latent tensor. Only used to derive
            the input shape when `latent_shape` is not given.
        latent_shape: Optional per-sample input shape (without the batch
            axis), e.g. ``encoder.output.shape[1:]``. Passing it avoids
            building a second encoder just to read its output shape.

    Returns:
        A Keras model named ``"decoder"`` with a single-channel output.
    """
    if latent_shape is None:
        # Original behaviour: instantiate a throwaway encoder (weights and
        # all) purely to learn the latent shape.
        latent_shape = get_encoder(latent_dim).output.shape[1:]

    latent_inputs = keras.Input(shape=latent_shape)
    x = layers.Conv2DTranspose(3072, 3, activation="relu", strides=2, padding="same")(latent_inputs)
    x = layers.Conv2DTranspose(1536, 3, activation="relu", strides=2, padding="same")(x)
    decoder_outputs = layers.Conv2DTranspose(1, 3, padding="same")(x)
    return keras.Model(latent_inputs, decoder_outputs, name="decoder")

In [6]:
def get_vqvae(latent_dim=1536, num_embeddings=2048):
    """Chain encoder -> vector quantizer -> decoder into one Keras model.

    The model takes 64x64 single-channel images and returns reconstructions.
    """
    quantizer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
    encoder_net = get_encoder(latent_dim)
    decoder_net = get_decoder(latent_dim)

    image_in = keras.Input(shape=(64, 64, 1))
    reconstructed = decoder_net(quantizer(encoder_net(image_in)))
    return keras.Model(image_in, reconstructed, name="vq_vae")


# Build a throwaway model with the default sizes just to print its summary.
get_vqvae().summary()


2023-03-18 18:46:57.900208: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-18 18:46:57.901256: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-03-18 18:46:57.901410: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-03-18 18:46:57.901503: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), bu

Model: "vq_vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 64, 64, 1)]       0         
                                                                 
 encoder (Functional)        (None, 16, 16, 2048)      83912704  
                                                                 
 vector_quantizer (VectorQua  (None, 16, 16, 2048)     4194304   
 ntizer)                                                         
                                                                 
 decoder (Functional)        (None, 64, 64, 1)         151019521 
                                                                 
Total params: 239,126,529
Trainable params: 239,126,529
Non-trainable params: 0
_________________________________________________________________


In [7]:
class VQVAETrainer(keras.models.Model):
    """Keras model wrapper that trains a VQ-VAE with a custom `train_step`.

    Args:
        train_variance: Scalar the reconstruction MSE is divided by.
        latent_dim: Latent channel count passed to `get_vqvae`.
        num_embeddings: Codebook size passed to `get_vqvae`.
        **kwargs: Forwarded to ``keras.models.Model``.

    The wrapped network is available as ``self.vqvae``.
    """

    def __init__(self, train_variance, latent_dim=1536, num_embeddings=2048, **kwargs):
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings

        self.vqvae = get_vqvae(self.latent_dim, self.num_embeddings)

        # Running means reported by `fit`.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

    @property
    def metrics(self):
        """Trackers exposed to Keras as this model's metrics."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]

    def train_step(self, x):
        """Run one optimisation step on a batch of images.

        Returns a dict with the running means of the total, reconstruction
        and vector-quantization losses.
        """
        # Keras may hand over the batch as a tuple (e.g. `(x,)` or `(x, y)`);
        # only the images are used here.
        if isinstance(x, tuple):
            x = x[0]

        with tf.GradientTape() as tape:
            # Outputs from the VQ-VAE.
            reconstructions = self.vqvae(x)

            # Calculate the losses.
            reconstruction_loss = (
                tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance
            )
            # Losses registered by the quantizer via `add_loss` during the
            # forward pass above; summed once and reused for logging.
            vq_loss = sum(self.vqvae.losses)
            total_loss = reconstruction_loss + vq_loss

        # Backpropagation.
        trainable_vars = self.vqvae.trainable_variables
        grads = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Loss tracking.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(vq_loss)

        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }


In [8]:
#TODO document jupyter
import pickle
import configparser
import sqlite3 as sl
import pandas as pd
import numpy as np
from PIL import Image

# Resolve the SQLite database location from the project configuration file.
configParser = configparser.RawConfigParser()
configFilePath = r'configuration.txt'
# `read` silently skips files it cannot open and returns the list it parsed,
# so fail early with a clear message instead of a later NoSectionError.
if not configParser.read(configFilePath):
    raise FileNotFoundError(f"Configuration file not found or unreadable: {configFilePath}")
datasetPathDatabase =  configParser.get('COMMON', 'datasetPathDatabase') + '/dataset.db'

# Fetch one row per VIDEO/AUDIO/FACE join match.
con = sl.connect(datasetPathDatabase)
try:
    data = con.execute("SELECT V.ID, V.VIDEO_PATH, V.AGE, V.ETHNICITY, V.GENDER, A.SPEAKER_EMB, A.LANG, F.FACE_PATH  FROM VIDEO V INNER JOIN AUDIO A ON V.ID = A.VIDEO_ID INNER JOIN FACE F ON V.ID = F.VIDEO_ID")
    dataGotten = data.fetchall()
finally:
    # Every row is already in `dataGotten`, so the connection can be released.
    con.close()

pd.set_option('display.max_columns', None)
df = pd.DataFrame(dataGotten,columns = ['ID','VIDEO_PATH','AGE','ETHNICITY','GENDER','SPEAKER_EMB','LANG','FACE_PATH'])
# SECURITY: unpickling can execute arbitrary code; only run this against a
# database whose SPEAKER_EMB blobs come from a trusted source.
df['SPEAKER_EMB'] = df['SPEAKER_EMB'].apply(pickle.loads)


In [9]:
# Call each stored embedding's own `.squeeze()` and keep the result.
df['SPEAKER_EMB'] = df['SPEAKER_EMB'].apply(lambda emb: emb.squeeze())

In [10]:


def getImage(face_path):
    """Load the image at `face_path` as a 64x64 grayscale array.

    The image is converted to 8-bit grayscale (mode ``'L'``), resized to
    64x64 and returned as a NumPy array.
    """
    # `with` closes the underlying file once the pixels have been copied out.
    with Image.open(face_path) as im:
        gray = im.convert('L')
        # Build the array from the *resized* image. The previous code resized
        # into one variable and then converted the un-resized image, so the
        # resize had no effect.
        resized = gray.resize((64, 64))
        return np.array(resized)

df['IMAGE'] = df['FACE_PATH'].apply(getImage)


In [11]:
# Shape of the image stored in the first row (index label 0).
df.head(1).loc[0, 'IMAGE'].shape

(64, 64)

In [12]:
# Raw pixel values of the image stored in the first row (index label 0).
df.head(1).loc[0, 'IMAGE']

array([[ 43,  40,  38, ..., 133, 121, 108],
       [ 39,  38,  35, ..., 131, 118, 104],
       [ 37,  34,  29, ..., 133, 111,  90],
       ...,
       [ 19,  16,  13, ..., 170, 162, 158],
       [ 23,  18,  12, ..., 164, 162, 162],
       [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8)

In [13]:
# Turn the first stored array back into a PIL image for visual inspection.
first_pixels = df.head(1).loc[0, 'IMAGE']
im = Image.fromarray(first_pixels)

In [14]:
# Reproducible 80/20 row split: sample the training rows with a fixed seed,
# then keep every remaining row for testing.
x_train = df.sample(frac=0.8, random_state=200)
x_test = df.drop(index=x_train.index)

In [15]:
# Display the held-out rows (everything not sampled into x_train).
x_test

Unnamed: 0,ID,VIDEO_PATH,AGE,ETHNICITY,GENDER,SPEAKER_EMB,LANG,FACE_PATH,IMAGE
2,91,/home/gamal/Datasets/Dataset1/Video/id00039/fp...,31.0,asian,Woman,"[-9.479224, 1.5550817, -2.5814886, -9.568478, ...",Indonesian,/home/gamal/Datasets/Dataset1/Faces/id00039/fp...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
3,211,/home/gamal/Datasets/Dataset1/Video/id00126/8E...,26.0,asian,Man,"[-0.9427292, -25.27765, -12.866363, -24.726225...",Indonesian,/home/gamal/Datasets/Dataset1/Faces/id00126/8E...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
4,211,/home/gamal/Datasets/Dataset1/Video/id00126/8E...,26.0,asian,Man,"[-0.9427292, -25.27765, -12.866363, -24.726225...",Indonesian,/home/gamal/Datasets/Dataset1/Faces/id00126/8E...,"[[115, 116, 122, 122, 116, 118, 111, 111, 113,..."
5,211,/home/gamal/Datasets/Dataset1/Video/id00126/8E...,26.0,asian,Man,"[-0.9427292, -25.27765, -12.866363, -24.726225...",Indonesian,/home/gamal/Datasets/Dataset1/Faces/id00126/8E...,"[[137, 124, 103, 103, 113, 121, 113, 103, 106,..."
17,181,/home/gamal/Datasets/Dataset1/Video/id00126/8E...,30.0,asian,Man,"[-4.496821, -24.527327, 0.6800653, -7.232779, ...",Indonesian,/home/gamal/Datasets/Dataset1/Faces/id00126/8E...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
...,...,...,...,...,...,...,...,...,...
55256,17939,/home/gamal/Datasets/Dataset1/Video/id00111/ZX...,39.0,white,Woman,"[1.0937226, 17.07921, -37.014835, 35.95936, -2...",Polish,/home/gamal/Datasets/Dataset1/Faces/id00111/ZX...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
55257,17969,/home/gamal/Datasets/Dataset1/Video/id00111/sl...,27.0,white,Woman,"[-6.559404, -2.633592, -24.764732, -10.985298,...",English,/home/gamal/Datasets/Dataset1/Faces/id00111/sl...,"[[181, 235, 245, 234, 222, 243, 234, 219, 243,..."
55261,17939,/home/gamal/Datasets/Dataset1/Video/id00111/ZX...,39.0,white,Woman,"[3.5408385, 15.353499, -26.113169, 33.54378, -...",Polish,/home/gamal/Datasets/Dataset1/Faces/id00111/ZX...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
55265,17909,/home/gamal/Datasets/Dataset1/Video/id00111/Ks...,31.0,white,Woman,"[6.6956983, -8.293536, -8.618606, 13.906551, -...",English,/home/gamal/Datasets/Dataset1/Faces/id00111/Ks...,"[[33, 33, 46, 44, 34, 18, 26, 37, 20, 19, 38, ..."


In [16]:
# Peek at the first sampled training row.
x_train.iloc[:1]

Unnamed: 0,ID,VIDEO_PATH,AGE,ETHNICITY,GENDER,SPEAKER_EMB,LANG,FACE_PATH,IMAGE
24881,7889,/home/gamal/Datasets/Dataset1/Video/id00029/TG...,37.0,latino hispanic,Man,"[-12.147309, 0.7924035, -3.731266, -13.940857,...",English,/home/gamal/Datasets/Dataset1/Faces/id00029/TG...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [17]:
# From here on only the image column of the training rows is needed.
x_train = x_train.loc[:, 'IMAGE']


In [18]:
# Likewise keep only the image column of the test rows.
x_test = x_test.loc[:, 'IMAGE']

In [19]:
# Swap the pandas Series for their underlying NumPy arrays.
x_train, x_test = x_train.to_numpy(), x_test.to_numpy()

In [20]:
# Inspect the training data: at this point an array whose elements are the
# individual image arrays.
x_train

array([array([[ 0,  0,  0, ...,  0,  0,  0],
              [17, 18, 23, ..., 15, 15, 15],
              [17, 18, 23, ..., 15, 15, 15],
              ...,
              [56, 52, 51, ..., 41, 30, 46],
              [58, 54, 52, ..., 37, 35, 53],
              [ 0,  0,  0, ...,  0,  0,  0]], dtype=uint8),
       array([[0, 0, 0, ..., 0, 0, 0],
              [0, 0, 0, ..., 0, 0, 0],
              [0, 0, 0, ..., 0, 0, 0],
              ...,
              [0, 0, 0, ..., 0, 0, 0],
              [0, 0, 0, ..., 0, 0, 0],
              [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
       array([[  0,   0,   0, ...,   0,   0,   0],
              [ 47,  51,  37, ...,  22,  38,  41],
              [ 48,  49,  30, ...,   9,  14,  23],
              ...,
              [ 37,  41,  41, ..., 107, 109, 110],
              [  0,   0,   0, ...,   0,   0,   0],
              [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8), ...,
       array([[165, 100, 103, ...,  34,  31,  30],
              [188, 136, 125, ..

In [21]:
# Keep handles on the per-image arrays before x_train / x_test are rebuilt.
x_tr, x_te = x_train, x_test

In [22]:
# Start from empty lists; the next cells fill them image by image.
x_train, x_test = [], []

In [23]:
# Append every training image with a trailing channel axis added.
x_train.extend(np.expand_dims(img, -1) for img in x_tr)

In [24]:
# Append every test image with a trailing channel axis added.
x_test.extend(np.expand_dims(img, -1) for img in x_te)

In [25]:
# Convert the lists of per-image arrays into single NumPy arrays.
x_train, x_test = np.asarray(x_train), np.asarray(x_test)

In [26]:
# Sanity check: the recorded run shows (n_images, 64, 64, 1).
x_train.shape

(44218, 64, 64, 1)

In [27]:


# Divide the pixel values by 255 and shift them down by 0.5.
x_train_scaled = x_train / 255.0 - 0.5
x_test_scaled = x_test / 255.0 - 0.5




In [28]:
# Inspect the first training image; x_train itself is left unscaled.
x_train[0]

array([[[ 0],
        [ 0],
        [ 0],
        ...,
        [ 0],
        [ 0],
        [ 0]],

       [[17],
        [18],
        [23],
        ...,
        [15],
        [15],
        [15]],

       [[17],
        [18],
        [23],
        ...,
        [15],
        [15],
        [15]],

       ...,

       [[56],
        [52],
        [51],
        ...,
        [41],
        [30],
        [46]],

       [[58],
        [54],
        [52],
        ...,
        [37],
        [35],
        [53]],

       [[ 0],
        [ 0],
        [ 0],
        ...,
        [ 0],
        [ 0],
        [ 0]]], dtype=uint8)

In [29]:
# Variance of the training pixels after dividing by 255; passed to the
# trainer below as `train_variance`.
data_variance = (x_train / 255.0).var()

In [30]:
# Show the variance computed in the previous cell.
data_variance

0.059315418150385386

In [31]:
# Build, compile and fit the VQ-VAE trainer.
# NOTE(review): the output recorded below this cell is a GPU out-of-memory
# error from an earlier run that used epochs=2, batch_size=3.
adam = keras.optimizers.Adam()
vqvae_trainer = VQVAETrainer(
    train_variance=data_variance,
    latent_dim=1536,
    num_embeddings=2048,
)
vqvae_trainer.compile(optimizer=adam)
vqvae_trainer.fit(x=x_train_scaled, batch_size=10, epochs=10)


Epoch 1/2


2023-03-18 18:47:18.536060: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8401
2023-03-18 18:47:19.687978: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-03-18 18:47:19.700893: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x5627e5d49dd0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-03-18 18:47:19.700910: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): NVIDIA GeForce RTX 3060 Ti, Compute Capability 8.6
2023-03-18 18:47:19.705236: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-03-18 18:47:19.768806: I tensorflow/compiler/jit/xla_compilation_cache.cc:477] Compiled cluster using XLA!  This line is logged at most once for the lifetime 

ResourceExhaustedError: Graph execution error:

Detected at node 'StatefulPartitionedCall_9' defined at (most recent call last):
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 725, in start
      self.io_loop.start()
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
      await self.process_one()
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 502, in process_one
      await dispatch(*args)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
      await result
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_3078922/2750333534.py", line 3, in <module>
      vqvae_trainer.fit(x_train_scaled, epochs=2, batch_size=3)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/engine/training.py", line 1650, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/engine/training.py", line 1249, in train_function
      return step_function(self, iterator)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/engine/training.py", line 1233, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/engine/training.py", line 1222, in run_step
      outputs = model.train_step(data)
    File "/tmp/ipykernel_3078922/1119447023.py", line 37, in train_step
      self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1140, in apply_gradients
      return super().apply_gradients(grads_and_vars, name=name)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 634, in apply_gradients
      iteration = self._internal_apply_gradients(grads_and_vars)
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1166, in _internal_apply_gradients
      return tf.__internal__.distribute.interim.maybe_merge_call(
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1216, in _distributed_apply_gradients_fn
      distribution.extended.update(
    File "/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1211, in apply_grad_to_update_var
      return self._update_step_xla(grad, var, id(self._var_key(var)))
Node: 'StatefulPartitionedCall_9'
Out of memory while trying to allocate 301989888 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.12GiB
              constant allocation:         0B
        maybe_live_out allocation:  864.00MiB
     preallocated temp allocation:       260B
  preallocated temp fragmentation:       124B (47.69%)
                 total allocation:    1.12GiB
Peak buffers:
	Buffer 1:
		Size: 288.00MiB
		Operator: op_name="XLA_Args"
		Entry Parameter Subshape: f32[3,3,2048,4096]
		==========================

	Buffer 2:
		Size: 288.00MiB
		Operator: op_name="XLA_Args"
		Entry Parameter Subshape: f32[3,3,2048,4096]
		==========================

	Buffer 3:
		Size: 288.00MiB
		Operator: op_name="XLA_Args"
		Entry Parameter Subshape: f32[3,3,2048,4096]
		==========================

	Buffer 4:
		Size: 288.00MiB
		Operator: op_name="XLA_Args"
		Entry Parameter Subshape: f32[3,3,2048,4096]
		==========================

	Buffer 5:
		Size: 24B
		Operator: op_type="AssignSubVariableOp" op_name="AssignSubVariableOp" source_file="/home/gamal/anaconda3/envs/ds2f_m_i/lib/python3.8/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200
		XLA Label: fusion
		Shape: (f32[3,3,2048,4096], f32[3,3,2048,4096], f32[3,3,2048,4096])
		==========================

	Buffer 6:
		Size: 16B
		Operator: op_type="Sqrt" op_name="Sqrt"
		XLA Label: fusion
		Shape: (f32[], f32[])
		==========================

	Buffer 7:
		Size: 8B
		Operator: op_name="XLA_Args"
		Entry Parameter Subshape: s64[]
		==========================

	Buffer 8:
		Size: 4B
		Operator: op_type="Sqrt" op_name="Sqrt"
		XLA Label: fusion
		Shape: f32[]
		==========================

	Buffer 9:
		Size: 4B
		Operator: op_type="Sqrt" op_name="Sqrt"
		XLA Label: fusion
		Shape: f32[]
		==========================

	Buffer 10:
		Size: 4B
		Operator: op_name="XLA_Args"
		Entry Parameter Subshape: f32[]
		==========================


	 [[{{node StatefulPartitionedCall_9}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_2603]

In [None]:
def show_subplot(original, reconstructed):
    """Show an input image and its reconstruction side by side.

    Both arrays are squeezed and shifted by +0.5 before plotting. They are
    drawn in grayscale on a fixed [0, 1] intensity scale so the two panels
    are directly comparable; values outside that range are clipped.
    """
    panels = (("Original", original), ("Reconstructed", reconstructed))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        # Without an explicit cmap/vmin/vmax, imshow applies its default
        # colour map and rescales each panel to its own min/max.
        plt.imshow(image.squeeze() + 0.5, cmap="gray", vmin=0.0, vmax=1.0)
        plt.title(title)
        plt.axis("off")

    plt.show()


# Reconstruct a handful of random test images with the trained network.
trained_vqvae_model = vqvae_trainer.vqvae

# Sample without replacement so the same image is never shown twice, and cap
# the count so a small test set does not raise.
n_preview = min(10, len(x_test_scaled))
idx = np.random.choice(len(x_test_scaled), n_preview, replace=False)
test_images = x_test_scaled[idx]
reconstructions_test = trained_vqvae_model.predict(test_images)

for test_image, reconstructed_image in zip(test_images, reconstructions_test):
    show_subplot(test_image, reconstructed_image)
