In [3]:
import tensorflow as tf
import numpy as np
from swin_transformer_tensorflow.model.swin_transformer import SwinTransformer
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStack

 The versions of TensorFlow you are currently using is 2.12.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [4]:
def make_env(env_name="ALE/Pong-v5", seed=42):
    """Create a preprocessed, frame-stacked and seeded Atari environment.

    The environment is created through ``gym.make`` with ``frameskip=1`` and
    the reduced action space, then wrapped in ``AtariPreprocessing`` and
    ``FrameStack(env, 4)``.

    Args:
        env_name: Gymnasium environment id handed to ``gym.make``.
        seed: Seed applied to the environment itself (via an initial
            ``reset``) and to the observation/action space samplers.

    Returns:
        The wrapped environment. It has already been reset once so that the
        seed takes effect; calling ``reset()`` again afterwards is fine.
    """
    # frameskip=1 here: presumably frame skipping is left to the
    # AtariPreprocessing wrapper -- confirm against the wrapper's defaults.
    env = gym.make(env_name, render_mode="rgb_array", full_action_space=False, frameskip=1)
    env = AtariPreprocessing(env)
    env = FrameStack(env, 4)

    # Seeding the spaces only makes `space.sample()` reproducible. The
    # environment's own random number generator is seeded through reset(),
    # so do that here; otherwise `seed` would not affect the rollouts.
    env.reset(seed=seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)

    return env

# Smoke test: build the environment, grab the first observation and give it
# a leading batch axis so it can be passed directly to a model.
env = make_env()
first_frames, _ = env.reset()
state = np.array(first_frames)[np.newaxis, ...]
state

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


array([[[[ 52,  52,  52, ...,  87,  87,  87],
         [ 87,  87,  87, ...,  87,  87,  87],
         [ 87,  87,  87, ...,  87,  87,  87],
         ...,
         [236, 236, 236, ..., 236, 236, 236],
         [236, 236, 236, ..., 236, 236, 236],
         [236, 236, 236, ..., 236, 236, 236]],

        [[ 52,  52,  52, ...,  87,  87,  87],
         [ 87,  87,  87, ...,  87,  87,  87],
         [ 87,  87,  87, ...,  87,  87,  87],
         ...,
         [236, 236, 236, ..., 236, 236, 236],
         [236, 236, 236, ..., 236, 236, 236],
         [236, 236, 236, ..., 236, 236, 236]],

        [[ 52,  52,  52, ...,  87,  87,  87],
         [ 87,  87,  87, ...,  87,  87,  87],
         [ 87,  87,  87, ...,  87,  87,  87],
         ...,
         [236, 236, 236, ..., 236, 236, 236],
         [236, 236, 236, ..., 236, 236, 236],
         [236, 236, 236, ..., 236, 236, 236]],

        [[ 52,  52,  52, ...,  87,  87,  87],
         [ 87,  87,  87, ...,  87,  87,  87],
         [ 87,  87,  87, ...,  8

In [6]:
class SwinTransformerAtariBlock(tf.keras.Model):
    """Swin Transformer backbone followed by a linear per-action output layer.

    Inputs are scaled by 1/255, passed through a ``SwinTransformer`` and then
    through a ``Dense`` layer with ``num_actions`` linear units.

    NOTE(review): nothing about the backbone's output size is configured
    here, so the dense layer consumes whatever ``SwinTransformer`` emits by
    default -- confirm that this is the intended feature input.
    """

    def __init__(self, num_actions):
        """Create the sub-layers.

        Args:
            num_actions: Number of units of the final dense layer.
        """
        super().__init__()

        # Preprocessing phase
        self.rescaling = tf.keras.layers.Rescaling(scale=1.0 / 255)
        self.swin = SwinTransformer(
            img_size=84,
            patch_size=3,
            in_chans=4,
            embed_dim=96,
            depths=[2, 3, 2],
            num_heads=[3, 3, 6],
            window_size=7,
            mlp_ratio=4.0,
            drop_path_rate=0.1,
        )
        self.action_outputs = tf.keras.layers.Dense(num_actions, name="action", activation="linear")

    def call(self, inputs, training=None, **kwargs):
        """Run a forward pass.

        Args:
            inputs: Batch of observations; values are multiplied by 1/255
                before reaching the backbone.
            training: Training-mode flag, forwarded to the backbone.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            The output of the final dense layer (``num_actions`` linear
            units per example).
        """
        rescaled = self.rescaling(inputs)
        # Hand the flag to the backbone explicitly instead of relying on
        # implicit propagation from the outer call.
        swin_outputs = self.swin(rescaled, training=training)
        logits = self.action_outputs(swin_outputs)

        return logits

    def build(self, input_shape):
        """Create the backbone's weights by running it once on zeros.

        Args:
            input_shape: Shape of the model input. Unknown dimensions
                (``None``, e.g. the batch axis) are replaced by 1 for the
                dummy forward pass.
        """
        # tf.zeros cannot allocate a tensor with unknown dimensions, so
        # substitute a concrete size of 1 wherever the shape has None.
        dims = tf.TensorShape(input_shape).as_list()
        concrete_shape = [1 if dim is None else dim for dim in dims]
        self.swin(tf.zeros(concrete_shape), training=False)

# Optional: uncomment for reproducible weight initialisation.
# tf.random.set_seed(42)

# Size the action layer from the environment instead of hard-coding 6, so the
# model still matches if make_env() is pointed at a different game.
model = SwinTransformerAtariBlock(int(env.action_space.n))
# One forward pass on a real observation so the model is built before
# summary() is called.
output = model(state, training=False)
model.summary(expand_nested=True)

Model: "swin_transformer_atari_block_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rescaling_2 (Rescaling)     multiple                  0         
                                                                 
 swin_transformer_1 (SwinTra  multiple                 5872219   
 nsformer)                                                       
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| patch_embed_1 (PatchEmbed)  multiple                3744      |
|                                                               |
| dropout_22 (Dropout)      multiple                  0         |
|                                                               |
| basic_layers_seq (Sequentia  multiple               5482707   |
| l)                                                            |
||¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯||
|| basic_layer_3 (BasicLayer)  multi

# Swin hyper-parameters gathered in one mapping and expanded into SwinConfig.
# NOTE(review): SwinConfig is not imported in the visible cells -- confirm
# where it comes from before running this statement.
swin_settings = {
    "image_size": 84,
    "patch_size": 3,
    "num_channels": 4,
    "embed_dim": 96,
    "depths": [2, 3, 2],
    "num_heads": [3, 3, 6],
    "window_size": 7,
    "mlp_ratio": 4.0,
    "drop_path_rate": 0.1,
}
configuration = SwinConfig(**swin_settings)