In [1]:
from models import *
from utils import *
import tensorflow as tf

In [2]:
class Transformer(Model):
    def __init__(self, input_dims=6, output_dims=6, n_layers=6, n_heads=8):
        super(Transformer, self).__init__()
        self.encoder = encoder(n_layers= n_layers, n_heads= n_heads, input_dims=input_dims)
        self.decoder = decoder(n_layers= n_layers, n_heads= n_heads, input_dims=input_dims)

    
    def call(self, x, y):
        x = self.encoder(x)
        x = self.decoder(y, x)
        return x
    
    def compile(self, optimizer, loss):
        super(Transformer, self).compile()
        self.optimizer = optimizer
        self.loss = loss
    
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, y)
            loss = self.loss(y, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {'loss': loss}

In [3]:
y = tf.ones(shape=(10,300, 300), dtype=tf.float32)
x = tf.ones(shape=(10,300), dtype=tf.float32)

2024-11-14 22:17:31.698468: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-11-14 22:17:31.698493: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 8.00 GB
2024-11-14 22:17:31.698499: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 2.67 GB
2024-11-14 22:17:31.698544: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-11-14 22:17:31.698599: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [4]:
t = Transformer(input_dims= 300)
z = t(x, y)
t.summary()

Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (encoder)           multiple                  19055616  
                                                                 
 decoder (decoder)           multiple                  25824812  
                                                                 
Total params: 44880428 (171.21 MB)
Trainable params: 44880428 (171.21 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [9]:
z.get_shape

<bound method _EagerTensorBase.get_shape of <tf.Tensor: shape=(10, 300, 300), dtype=float32, numpy=
array([[[-1.2139132 , -0.85386574, -0.38919994, ...,  0.12991093,
          0.3868761 , -0.7477759 ],
        [-1.214941  , -0.8516314 , -0.3897114 , ...,  0.12777205,
          0.38981205, -0.74723107],
        [-1.2165864 , -0.84910727, -0.3894591 , ...,  0.12704293,
          0.39193863, -0.7473798 ],
        ...,
        [-1.2256557 , -0.8601747 , -0.38061914, ...,  0.14272445,
          0.39942598, -0.7473026 ],
        [-1.2259754 , -0.8606748 , -0.38110748, ...,  0.14259157,
          0.3990687 , -0.7473596 ],
        [-1.2266117 , -0.8613493 , -0.38169628, ...,  0.14250286,
          0.39835012, -0.7472714 ]],

       [[-1.2139132 , -0.85386574, -0.38919994, ...,  0.12991093,
          0.3868761 , -0.7477759 ],
        [-1.214941  , -0.8516314 , -0.3897114 , ...,  0.12777205,
          0.38981205, -0.74723107],
        [-1.2165864 , -0.84910727, -0.3894591 , ...,  0.12704293,
   