# In [None]:  (Jupyter cell prompt left over from the notebook export)
import tensorflow as tf
from tensorflow.keras.layers import Embedding, GRU, Dense, Bidirectional
from tensorflow.keras.models import Model

# Hyperparameters
# -- vocabulary and sequence limits --
vocab_size = 10000  # number of token ids covered by the embedding tables and output logits
max_sentence_length = 50  # tokens per sentence in the demo batches
# -- layer widths --
embedding_dim = 300  # width of each token embedding vector
hidden_units = 512  # unit count handed to the GRU layers

# Encoder
class SkipThoughtEncoder(Model):
    """Sentence encoder: token embedding followed by a bidirectional GRU.

    Args:
        vocab_size: Number of token ids the embedding table covers.
        embedding_dim: Width of each token embedding.
        hidden_units: Units of the GRU that is wrapped in ``Bidirectional``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units):
        super().__init__()
        self.embedding = Embedding(input_dim=vocab_size, output_dim=embedding_dim)
        # return_sequences=False: only one vector per sentence is emitted,
        # not one per time step.
        recurrent_layer = GRU(units=hidden_units, return_sequences=False)
        self.bidirectional_gru = Bidirectional(recurrent_layer)

    def call(self, input_sentence):
        """Encode a batch of token-id sentences into one vector per sentence.

        Args:
            input_sentence: Integer token ids; presumably shaped
                (batch, sentence_length) -- confirm against the caller.

        Returns:
            The bidirectional GRU output. With Keras' default ``concat``
            merge mode this is expected to be 2 * hidden_units wide.
        """
        return self.bidirectional_gru(self.embedding(input_sentence))

# Decoder
class SkipThoughtDecoder(Model):
    """GRU decoder that scores vocabulary tokens at every step of a sentence.

    Token ids are embedded, run through a GRU seeded with a caller-supplied
    initial state, and projected to ``vocab_size`` scores per time step.

    Args:
        vocab_size: Number of token ids; also the width of the output scores.
        embedding_dim: Width of each token embedding.
        hidden_units: Units of the GRU (and therefore of its state).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units):
        super().__init__()
        self.embedding = Embedding(input_dim=vocab_size, output_dim=embedding_dim)
        # Emit the full output sequence and also hand back the final state.
        self.gru = GRU(units=hidden_units, return_state=True, return_sequences=True)
        # No activation is configured, so this layer produces unnormalized scores.
        self.dense = Dense(units=vocab_size)

    def call(self, input_sentence, initial_state):
        """Run the decoder over a batch of token-id sentences.

        Args:
            input_sentence: Integer token ids fed to the embedding layer;
                presumably shaped (batch, sentence_length) -- confirm.
            initial_state: Tensor used as the GRU's initial state. It has to
                be compatible with the GRU's state size (``hidden_units``).

        Returns:
            A ``(logits, state)`` pair: the dense projection of every GRU
            step output, and the GRU's final state.
        """
        token_vectors = self.embedding(input_sentence)
        step_outputs, final_state = self.gru(token_vectors, initial_state=initial_state)
        return self.dense(step_outputs), final_state

# Build the Skip-Thought model
class SkipThoughtModel(Model):
    """Skip-Thought model: one sentence encoder feeding two sentence decoders.

    The encoder turns the input sentence into a single vector. That vector is
    used as the initial GRU state of two independent decoders, one scoring the
    tokens of the previous sentence and one scoring the tokens of the next.

    Args:
        vocab_size: Number of token ids shared by encoder and decoders.
        embedding_dim: Width of the token embeddings in every sub-model.
        hidden_units: Units of the encoder GRU in each direction. The
            decoders are built with twice this many units (see below).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units):
        super().__init__()
        self.encoder = SkipThoughtEncoder(vocab_size, embedding_dim, hidden_units)
        # The encoder wraps its GRU in Bidirectional, whose default merge mode
        # concatenates the forward and backward outputs, so the sentence
        # vector is 2 * hidden_units wide. That vector becomes the decoders'
        # GRU initial state, so the decoder GRUs must have a matching state
        # size; building them with hidden_units makes the forward pass fail
        # on the state-shape mismatch.
        decoder_units = 2 * hidden_units
        self.decoder_prev = SkipThoughtDecoder(vocab_size, embedding_dim, decoder_units)
        self.decoder_next = SkipThoughtDecoder(vocab_size, embedding_dim, decoder_units)

    def call(self, input_sentence, target_prev_sentence, target_next_sentence):
        """Encode the input sentence and decode its two neighbours.

        Args:
            input_sentence: Token ids of the sentence to encode.
            target_prev_sentence: Token ids fed to the previous-sentence
                decoder.
            target_next_sentence: Token ids fed to the next-sentence decoder.

        Returns:
            ``(logits_prev, logits_next)``: the per-step vocabulary scores
            from each decoder. The decoders' final GRU states are discarded.
        """
        sentence_vector = self.encoder(input_sentence)
        # Both decoders start from the same encoded sentence vector.
        logits_prev, _ = self.decoder_prev(target_prev_sentence, sentence_vector)
        logits_next, _ = self.decoder_next(target_next_sentence, sentence_vector)
        return logits_prev, logits_next

# Instantiate the model from the hyperparameters defined above.
model = SkipThoughtModel(vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_units=hidden_units)

# Input tensors
# The batch size is named once so all three dummy batches stay in sync.
batch_size = 32
token_batch_shape = (batch_size, max_sentence_length)

# Random token ids in [0, vocab_size) stand in for real tokenized sentences.
input_sentence = tf.random.uniform(token_batch_shape, minval=0, maxval=vocab_size, dtype=tf.int32)
target_prev_sentence = tf.random.uniform(token_batch_shape, minval=0, maxval=vocab_size, dtype=tf.int32)
target_next_sentence = tf.random.uniform(token_batch_shape, minval=0, maxval=vocab_size, dtype=tf.int32)

# Forward pass
logits_prev, logits_next = model(input_sentence, target_prev_sentence, target_next_sentence)

# Print shapes
print("Logits for previous sentence prediction:", logits_prev.shape)
print("Logits for next sentence prediction:", logits_next.shape)