In [None]:
import torch
import torch.nn as nn
import fasttext

In [None]:
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer built on a frozen, pre-trained embedding table.

    Source and target token ids are looked up in the same embedding table, so
    ``d_model`` equals the embedding width. The output layer projects decoder
    states back onto that same vocabulary.

    Args:
        embedding_matrix: 2-D float tensor of shape ``(vocab_size, d_model)``.
            It is wrapped in a frozen ``nn.Embedding``.
        nhead: Number of attention heads; must divide ``d_model``.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
    """

    def __init__(self, embedding_matrix, nhead, num_encoder_layers, num_decoder_layers):
        super().__init__()
        vocab_size = embedding_matrix.size(0)
        d_model = embedding_matrix.size(1)

        # freeze=True: the pre-trained vectors are used as-is, not fine-tuned.
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)

        # nn.Transformer defaults to the sequence-first layout (seq, batch, d_model).
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
        )

        # Maps each decoder state to one logit per vocabulary entry.
        self.fc = nn.Linear(d_model, vocab_size)

    @staticmethod
    def _positional_encoding(seq_len, d_model, device, dtype):
        """Return the fixed sinusoidal position table, shaped (seq_len, 1, d_model)."""
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        # Geometric progression of frequencies, as in "Attention Is All You Need".
        div_term = 10000.0 ** (
            -torch.arange(0, d_model, 2, device=device, dtype=torch.float32) / d_model
        )
        table = torch.zeros(seq_len, d_model, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term so an odd d_model (one fewer cosine column) also works.
        table[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # The size-1 middle axis broadcasts across the batch in (seq, batch, dim) layout.
        return table.unsqueeze(1).to(dtype)

    @staticmethod
    def _causal_mask(size, device):
        """Return a (size, size) boolean mask; True blocks attention to a later position."""
        return torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1
        )

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        """Run the encoder on ``src`` and the causally-masked decoder on ``tgt``.

        Args:
            src: Source token ids of shape ``(batch_size, src_seq_length)``.
            tgt: Target token ids of shape ``(batch_size, tgt_seq_length)``.
            src_key_padding_mask: Optional bool tensor ``(batch_size, src_seq_length)``.
                True marks padding. It is applied in the encoder and in the
                decoder's cross-attention.
            tgt_key_padding_mask: Optional bool tensor ``(batch_size, tgt_seq_length)``.
                True marks padding.

        Returns:
            Logits of shape ``(batch_size, tgt_seq_length, vocab_size)``.
        """
        d_model = self.embedding.embedding_dim

        # (batch, seq) -> (batch, seq, d_model) -> (seq, batch, d_model)
        src_embedding = self.embedding(src).permute(1, 0, 2)
        tgt_embedding = self.embedding(tgt).permute(1, 0, 2)

        # Attention itself is order-agnostic; without this the model cannot see word order.
        src_embedding = src_embedding + self._positional_encoding(
            src_embedding.size(0), d_model, src_embedding.device, src_embedding.dtype
        )
        tgt_embedding = tgt_embedding + self._positional_encoding(
            tgt_embedding.size(0), d_model, tgt_embedding.device, tgt_embedding.dtype
        )

        # Stops target position i from attending to positions > i, so the
        # training behavior matches left-to-right generation.
        tgt_mask = self._causal_mask(tgt_embedding.size(0), tgt_embedding.device)

        memory = self.transformer.encoder(
            src_embedding, src_key_padding_mask=src_key_padding_mask
        )
        output = self.transformer.decoder(
            tgt_embedding,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        output = output.permute(1, 0, 2)  # Back to (batch_size, seq_length, d_model)
        return self.fc(output)

# Load pre-trained FastText Hindi embeddings.
# NOTE(review): assumes the .pt file holds a 2-D (vocab_size, embedding_dim)
# matrix exported from fastText vectors — confirm against the export script.
embedding_dim = 300  # Must match the width of the saved embedding matrix.
# map_location lets a tensor saved from a GPU load on a CPU-only machine, and
# float32 matches the dtype of the Transformer's own parameters.
embedding_matrix = torch.as_tensor(
    torch.load("path_to_fasttext_hindi_embeddings.pt", map_location="cpu"),
    dtype=torch.float32,
)
if embedding_matrix.dim() != 2 or embedding_matrix.size(1) != embedding_dim:
    raise ValueError(
        f"expected an embedding matrix of shape (vocab_size, {embedding_dim}), "
        f"got {tuple(embedding_matrix.shape)}"
    )

# Example usage
nhead = 4  # Must divide embedding_dim (300 / 4 = 75 dims per head).
num_encoder_layers = 3
num_decoder_layers = 3

model = TransformerModel(embedding_matrix, nhead, num_encoder_layers, num_decoder_layers)

# Generate random token ids for demonstration. TransformerModel.forward takes
# batch-first (batch_size, seq_length) ids and permutes them internally.
batch_size = 16
src_seq_length = 20
tgt_seq_length = 15
vocab_size = embedding_matrix.size(0)

src_input = torch.randint(vocab_size, (batch_size, src_seq_length))
tgt_input = torch.randint(vocab_size, (batch_size, tgt_seq_length))

# Inference-only demo: disable dropout and skip gradient bookkeeping.
model.eval()
with torch.no_grad():
    output = model(src_input, tgt_input)
print("Output shape:", output.shape)  # (batch_size, tgt_seq_length, vocab_size)