In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["TF_FORCE_UNIFIED_MEMORY"]="1"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="16.0"
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
from einops.layers.tensorflow import Rearrange

In [None]:
class SwinTransformerBlock(tf.keras.layers.Layer):
    """Pre-norm transformer block: self-attention followed by an MLP.

    Each of the two sub-layers is wrapped as
    ``x + Dropout(SubLayer(LayerNorm(x)))``.

    NOTE(review): despite the name, this block performs no window
    partitioning or window shifting; attention is applied directly to the
    tensor it receives.

    Args:
        hidden_dim: Width of the MLP output. It is also passed as ``key_dim``
            to ``MultiHeadAttention``, i.e. it is the size of *each* head
            (not ``hidden_dim // num_heads``).
        num_heads: Number of attention heads.
        mlp_ratio: The hidden MLP layer has ``mlp_ratio * hidden_dim`` units.
        drop_rate: Dropout rate applied to the output of each sub-layer.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, hidden_dim, num_heads, mlp_ratio=4, drop_rate=0.0, **kwargs):
        super(SwinTransformerBlock, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.drop_rate = drop_rate

        # Attention sub-layer.
        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.attention = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=hidden_dim)
        self.drop1 = tf.keras.layers.Dropout(drop_rate)

        # Feed-forward sub-layer.
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(mlp_ratio * hidden_dim, activation=tf.keras.activations.gelu),
            tf.keras.layers.Dense(hidden_dim),
        ])
        self.drop2 = tf.keras.layers.Dropout(drop_rate)

    def call(self, x, training=None):
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            x: Input tensor. Its last dimension must equal ``hidden_dim`` so
                that the second residual addition is shape-compatible.
            training: Keras training flag, forwarded to every sub-layer.

        Returns:
            A tensor with the same shape as ``x``.
        """
        residual = x
        x = self.norm1(x, training=training)
        # Self-attention: the normalized input is used as query, value and key.
        x = self.attention(x, x, x, training=training)
        x = self.drop1(x, training=training)
        x = x + residual

        residual = x
        x = self.norm2(x, training=training)
        x = self.mlp(x, training=training)
        x = self.drop2(x, training=training)
        x = x + residual

        return x

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer."""
        config = super(SwinTransformerBlock, self).get_config()
        config.update({
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "mlp_ratio": self.mlp_ratio,
            "drop_rate": self.drop_rate,
        })
        return config


class SwinTransformer(tf.keras.Model):
    """Patch-based transformer mapping an image to an image of the same size.

    Pipeline: split the input into non-overlapping ``patch_size`` x
    ``patch_size`` patches, embed each patch to ``hidden_dim``, run
    ``num_layers`` transformer blocks, project every token back to the pixels
    of its patch, reassemble the patches into an image and squash the result
    to (0, 1) with a sigmoid.

    The input height and width must be divisible by ``patch_size``; the
    patch rearrangement cannot split the image otherwise.

    Args:
        input_shape: Accepted for interface compatibility; not used here.
        window_size: Stored on the instance but not used by any layer in
            this class (no windowed attention is performed).
        patch_size: Side length, in pixels, of each square patch.
        hidden_dim: Embedding width used by the transformer blocks.
        num_heads: Number of attention heads per block.
        num_layers: Number of ``SwinTransformerBlock`` layers.
        channels: Number of channels in the output image.
        num_classes: Accepted for interface compatibility; not used here.
    """

    def __init__(self, input_shape, window_size, patch_size, hidden_dim, num_heads, num_layers, channels, num_classes):
        super(SwinTransformer, self).__init__()

        self.window_size = window_size
        self.patch_size = patch_size
        self.channels = channels

        # (b, H, W, c) -> (b, H/p, W/p, p*p*c) -> (b, H/p, W/p, hidden_dim)
        self.patch_proj = tf.keras.Sequential([
            Rearrange('b (h p1) (w p2) c -> b h w (p1 p2 c)', p1=patch_size, p2=patch_size),
            tf.keras.layers.Dense(hidden_dim),
        ])

        self.blocks = [
            SwinTransformerBlock(hidden_dim, num_heads, mlp_ratio=4, drop_rate=0.0)
            for _ in range(num_layers)
        ]

        self.final_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5)

        # Each token must carry every pixel of its patch so that the spatial
        # resolution lost in `patch_proj` can be restored. Projecting to only
        # `channels` values would leave the output patch_size-times smaller
        # than the input along both spatial axes.
        self.proj = tf.keras.layers.Dense(patch_size * patch_size * channels)
        # (b, H/p, W/p, p*p*channels) -> (b, H, W, channels)
        self.unpatch = Rearrange(
            'b h w (p1 p2 c) -> b (h p1) (w p2) c',
            p1=patch_size, p2=patch_size, c=channels,
        )

    def call(self, x, training=None):
        """Run the model on a batch of channels-last images.

        Args:
            x: Tensor laid out as (batch, height, width, channels), with
                height and width divisible by ``patch_size``.
            training: Keras training flag, forwarded to the sub-layers.

        Returns:
            Tensor of shape (batch, height, width, ``channels``) with values
            in (0, 1).
        """
        x = self.patch_proj(x)

        for block in self.blocks:
            x = block(x, training=training)

        x = self.final_norm(x, training=training)
        x = self.proj(x)
        x = self.unpatch(x)
        x = tf.keras.activations.sigmoid(x)

        return x

In [None]:
# Training configuration
DIV2K_SCALE = 2       # super-resolution factor handed to the loader and the model builder
BATCH_SIZE = 1        # number of image pairs per batch
LEARNING_RATE = 1e-4  # optimizer step size
EPOCHS = 100          # passes over the training set

# Load the DIV2K dataset
def load_div2k(scale=2):
    """Load the bicubic DIV2K train/validation splits as batched pipelines.

    Args:
        scale: Downscaling factor of the low-resolution images. It selects
            the TFDS config ``div2k/bicubic_x{scale}``, so it must be a
            factor for which such a config exists.

    Returns:
        A ``(train, valid)`` tuple of ``tf.data.Dataset`` objects yielding
        ``(lr, hr)`` float32 batches scaled to [0, 1]. Only the training
        pipeline is shuffled.
    """
    # Honor the requested scale instead of always loading the x2 config.
    train, valid = tfds.load(f"div2k/bicubic_x{scale}", split=["train", "validation"], as_supervised=True)

    def preprocessing(lr, hr):
        """Convert both images to float32 and rescale them from [0, 255] to [0, 1]."""
        lr = tf.cast(lr, tf.float32) / 255.0
        hr = tf.cast(hr, tf.float32) / 255.0
        return lr, hr

    train = train.map(preprocessing, num_parallel_calls=tf.data.AUTOTUNE).cache().shuffle(100).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    valid = valid.map(preprocessing, num_parallel_calls=tf.data.AUTOTUNE).cache().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    return train, valid

train_dataset, valid_dataset = load_div2k(DIV2K_SCALE)

# Build the Swin Transformer model
def build_swin_transformer_model(scale=2):
    """Assemble the functional model: bilinear upsampling, then SwinTransformer.

    Args:
        scale: Factor by which ``UpSampling2D`` enlarges the input height and
            width before the transformer is applied.

    Returns:
        A ``tf.keras.Model`` mapping the low-resolution input to whatever
        ``SwinTransformer`` produces for the upsampled image.
    """
    # Height and width are left unspecified; only the 3 channels are fixed.
    image_shape = (None, None, 3)
    transformer_kwargs = dict(
        input_shape=image_shape,
        window_size=8,
        patch_size=4,
        hidden_dim=128,
        num_heads=4,
        num_layers=2,
        channels=3,
        num_classes=0,
    )

    low_res = tf.keras.layers.Input(shape=image_shape)
    upsample = tf.keras.layers.UpSampling2D(size=(scale, scale), interpolation='bilinear')
    backbone = SwinTransformer(**transformer_kwargs)

    prediction = backbone(upsample(low_res))
    return tf.keras.Model(inputs=low_res, outputs=prediction)

model = build_swin_transformer_model(scale=DIV2K_SCALE)

# Objective: mean absolute error, optimized with AdamW (decoupled weight decay)
loss = tf.keras.losses.MeanAbsoluteError()
optimizer = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=LEARNING_RATE)

# Wire loss and optimizer into the model, then train with per-epoch validation
model.compile(loss=loss, optimizer=optimizer)
model.fit(
    x=train_dataset,
    validation_data=valid_dataset,
    epochs=EPOCHS,
)