In [7]:
from models import *
from utils import *
import tensorflow as tf

In [3]:
class Transformer(Model):
    """Encoder-decoder model built from the project's ``encoder`` and ``decoder``.

    The forward pass runs ``x`` through the encoder and then gives the decoder
    both ``y`` and the encoder output.  ``compile`` and ``train_step`` are
    overridden so that training uses the hand-written gradient step below
    (assumes ``Model`` is the Keras ``Model`` base class -- it is brought in
    by a star import, TODO confirm).

    Args:
        input_dims: Passed as ``input_dims`` to both the encoder and the decoder.
        output_dims: Accepted but currently not used anywhere in this class.
        n_layers: Passed as ``n_layers`` to both the encoder and the decoder.
        n_heads: Passed as ``n_heads`` to both the encoder and the decoder.
    """

    def __init__(self, input_dims=6, output_dims=6, n_layers=6, n_heads=8):
        super().__init__()
        # NOTE(review): ``output_dims`` is never forwarded; the decoder is
        # configured with ``input_dims``.  Confirm whether the decoder should
        # receive a separate output width.
        self.encoder = encoder(n_layers=n_layers, n_heads=n_heads, input_dims=input_dims)
        self.decoder = decoder(n_layers=n_layers, n_heads=n_heads, input_dims=input_dims)

    def call(self, x, y, training=None):
        """Run the encoder on ``x`` and decode ``y`` against the encoder output.

        Args:
            x: Input handed to the encoder.
            y: Input handed to the decoder as its first argument.
            training: Keras training-mode flag.  ``None`` (the default) lets
                Keras decide, which keeps existing ``model(x, y)`` calls
                working unchanged.

        Returns:
            Whatever the decoder returns for ``(y, encoder_output)``.
        """
        # Forward the flag explicitly so layers with training-specific
        # behaviour inside the sub-modules see the same mode as this model.
        encoded = self.encoder(x, training=training)
        return self.decoder(y, encoded, training=training)

    def compile(self, optimizer, loss):
        """Store the optimizer and loss used by ``train_step``.

        Args:
            optimizer: Object exposing ``apply_gradients`` (called in
                ``train_step``).
            loss: Callable invoked as ``loss(y, y_pred)`` in ``train_step``;
                string identifiers are not resolved here.
        """
        # The base compile is called without arguments; the two attributes
        # are then set directly on the instance.
        super().compile()
        self.optimizer = optimizer
        self.loss = loss

    def train_step(self, data):
        """Perform one gradient update on a batch.

        Args:
            data: A two-element ``(x, y)`` batch.

        Returns:
            A dict with the single key ``'loss'`` holding the batch loss.
        """
        x, y = data
        with tf.GradientTape() as tape:
            # training=True: without it the forward pass inside the training
            # loop would run in inference mode (e.g. dropout disabled).
            # NOTE(review): the decoder receives the same ``y`` that is used
            # as the loss target, with no shift -- confirm the decoder / data
            # pipeline accounts for that.
            y_pred = self(x, y, training=True)
            loss = self.loss(y, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {'loss': loss}

In [4]:
# All-ones dummy tensors used to exercise the model's forward pass.
x = tf.ones((10, 300), dtype=tf.float32)       # handed to the encoder
y = tf.ones((10, 300, 300), dtype=tf.float32)  # handed to the decoder

2024-09-02 13:19:30.956894: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-09-02 13:19:30.957711: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 8.00 GB
2024-09-02 13:19:30.957745: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 2.67 GB
2024-09-02 13:19:30.959000: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-09-02 13:19:30.960487: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [5]:
# Instantiate the model with 300-wide inputs; other settings keep their defaults.
t = Transformer(input_dims=300)

# One forward pass on the dummy tensors creates the weights, so that
# summary() can report parameter counts.
z = t(x, y)
t.summary()

Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (encoder)           multiple                  19055616  
                                                                 
 decoder (decoder)           multiple                  25824812  
                                                                 
Total params: 44880428 (171.21 MB)
Trainable params: 44880428 (171.21 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [8]:
# Display the tensor returned by the forward pass (z = t(x, y)).
z

<tf.Tensor: shape=(10, 300, 300), dtype=float32, numpy=
array([[[-0.29845726,  2.29079   , -0.12942715, ...,  0.45482853,
          0.33290613,  0.04219987],
        [-0.29891664,  2.2901628 , -0.1308194 , ...,  0.45329648,
          0.33261505,  0.04334272],
        [-0.2989568 ,  2.289533  , -0.13213918, ...,  0.4520475 ,
          0.3329478 ,  0.04350449],
        ...,
        [-0.30312443,  2.2996764 , -0.13052866, ...,  0.46019652,
          0.33033192,  0.03498219],
        [-0.30304295,  2.2995281 , -0.12877221, ...,  0.45966458,
          0.33046624,  0.03400264],
        [-0.30250347,  2.299864  , -0.12739386, ...,  0.4588153 ,
          0.3310257 ,  0.03437804]],

       [[-0.29845726,  2.29079   , -0.12942715, ...,  0.45482853,
          0.33290613,  0.04219987],
        [-0.29891664,  2.2901628 , -0.1308194 , ...,  0.45329648,
          0.33261505,  0.04334272],
        [-0.2989568 ,  2.289533  , -0.13213918, ...,  0.4520475 ,
          0.3329478 ,  0.04350449],
        ...