# In [None]:
import torch
import torch.nn as nn
import math
import torch.optim as optim


class EmbeddingLayer(nn.Module):
    """Thin wrapper around ``nn.Embedding`` that maps token ids to vectors.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each embedding vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
        )

    def forward(self, x):
        """Look up the embedding vector for every index in ``x``."""
        embedded = self.embedding(x)
        return embedded


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first input.

    The table follows the "Attention Is All You Need" definition: for
    position ``pos`` and pair index ``k``::

        PE[pos, 2k]     = sin(pos / 10000 ** (2k / embed_dim))
        PE[pos, 2k + 1] = cos(pos / 10000 ** (2k / embed_dim))

    so each sine/cosine pair shares one frequency.

    Args:
        embed_dim: Size of the last dimension of the inputs. Odd values are
            supported; the final (even-indexed) column then holds only a sine.
        max_seq_len: Number of positions to precompute.
    """

    def __init__(self, embed_dim, max_seq_len):
        super(PositionalEncoding, self).__init__()

        positions = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        even_dims = torch.arange(0, embed_dim, 2, dtype=torch.float32)
        # 1 / 10000 ** (2k / embed_dim), computed in log space for stability.
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / embed_dim))
        angles = positions * inv_freq  # (max_seq_len, ceil(embed_dim / 2))

        table = torch.zeros(max_seq_len, embed_dim)
        table[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even ones.
        table[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        # Registered as a buffer so .to(device) / .cuda() move it with the
        # module. persistent=False keeps it out of state_dict, matching the
        # previous plain-attribute behaviour for existing checkpoints.
        self.register_buffer(
            "positional_encoding", table.unsqueeze(0), persistent=False
        )

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``embed_dim``.
        """
        return x + self.positional_encoding[:, :x.size(1), :]


class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder block for batch-first inputs.

    Applies multi-head self-attention and then a two-layer ReLU feed-forward
    network. Each sub-layer's output goes through dropout, is added back to
    its input (residual connection) and is then layer-normalised.

    Args:
        embed_dim: Size of the last dimension of the inputs.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention and after each
            sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        # batch_first=True: the rest of this file treats dim 0 as the batch
        # and dim 1 as the sequence. With the default (sequence-first) layout
        # attention would mix information across different batch elements.
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run one encoder block.

        Args:
            x: Input of shape (batch, seq_len, embed_dim).
            mask: Optional attention mask forwarded to ``nn.MultiheadAttention``
                as ``attn_mask`` (shape (seq_len, seq_len) or
                (batch * num_heads, seq_len, seq_len)).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention weights are never used, so skip computing them.
        attn_output, _ = self.attention(
            x, x, x, attn_mask=mask, need_weights=False
        )
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        ff_output = self.feed_forward(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)
        return x


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` ``TransformerEncoderLayer`` modules.

    Args:
        embed_dim: Forwarded to every layer.
        num_heads: Forwarded to every layer.
        ff_dim: Forwarded to every layer.
        num_layers: How many layers to stack.
        dropout: Forwarded to every layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(embed_dim, num_heads, ff_dim, dropout)
            )

    def forward(self, x, mask=None):
        """Feed ``x`` through every layer in order, reusing the same ``mask``."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return hidden


class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder block for batch-first inputs.

    Applies, in order, self-attention over the target, cross-attention from
    the target to the encoder output, and a two-layer ReLU feed-forward
    network. Each sub-layer's output goes through dropout, is added back to
    its input (residual connection) and is then layer-normalised.

    Args:
        embed_dim: Size of the last dimension of the inputs.
        num_heads: Number of attention heads in both attention modules.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention and after each
            sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(TransformerDecoderLayer, self).__init__()
        # batch_first=True: the rest of this file treats dim 0 as the batch
        # and dim 1 as the sequence. With the default (sequence-first) layout
        # attention would mix batch elements, and cross-attention would fail
        # whenever source and target lengths differ.
        self.self_attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.cross_attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, memory_mask=None):
        """Run one decoder block.

        Args:
            x: Target-side input of shape (batch, tgt_len, embed_dim).
            enc_output: Encoder output of shape (batch, src_len, embed_dim).
            tgt_mask: Optional ``attn_mask`` for the self-attention, e.g. a
                (tgt_len, tgt_len) causal mask. When omitted, every target
                position can attend to every other target position.
            memory_mask: Optional ``attn_mask`` for the cross-attention, of
                shape (tgt_len, src_len).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention weights are never used, so skip computing them.
        self_attn_output, _ = self.self_attention(
            x, x, x, attn_mask=tgt_mask, need_weights=False
        )
        x = x + self.dropout(self_attn_output)
        x = self.norm1(x)
        cross_attn_output, _ = self.cross_attention(
            x, enc_output, enc_output, attn_mask=memory_mask, need_weights=False
        )
        x = x + self.dropout(cross_attn_output)
        x = self.norm2(x)
        ff_output = self.feed_forward(x)
        x = x + self.dropout(ff_output)
        x = self.norm3(x)
        return x


class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` ``TransformerDecoderLayer`` modules.

    Args:
        embed_dim: Forwarded to every layer.
        num_heads: Forwarded to every layer.
        ff_dim: Forwarded to every layer.
        num_layers: How many layers to stack.
        dropout: Forwarded to every layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerDecoderLayer(embed_dim, num_heads, ff_dim, dropout)
            )

    def forward(self, x, enc_output, tgt_mask=None, memory_mask=None):
        """Feed ``x`` through every layer in order.

        The same ``enc_output``, ``tgt_mask`` and ``memory_mask`` are handed
        to each layer unchanged.
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_output, tgt_mask, memory_mask)
        return hidden


class OutputLayer(nn.Module):
    """Linear projection from ``embed_dim`` features to ``vocab_size`` scores.

    Args:
        embed_dim: Size of the last dimension of the inputs.
        vocab_size: Size of the last dimension of the outputs.
    """

    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        self.linear = nn.Linear(in_features=embed_dim, out_features=vocab_size)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        logits = self.linear(x)
        return logits


class GenerativeQAModel(nn.Module):
    """Encoder-decoder Transformer assembled from the layers in this file.

    The source and target sequences share one embedding table and one
    positional-encoding module. The decoder output is projected to
    ``vocab_size`` scores per position.

    Args:
        vocab_size: Size of the embedding table and of the output projection.
        embed_dim: Model width used by every sub-module.
        num_heads: Attention heads per layer.
        ff_dim: Hidden width of each feed-forward network.
        num_encoder_layers: Depth of the encoder stack.
        num_decoder_layers: Depth of the decoder stack.
        max_seq_len: Number of positions precomputed by the positional encoding.
        dropout: Dropout probability forwarded to encoder and decoder.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_heads,
        ff_dim,
        num_encoder_layers,
        num_decoder_layers,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embedding = EmbeddingLayer(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len)
        self.encoder = TransformerEncoder(
            embed_dim, num_heads, ff_dim, num_encoder_layers, dropout
        )
        self.decoder = TransformerDecoder(
            embed_dim, num_heads, ff_dim, num_decoder_layers, dropout
        )
        self.output_layer = OutputLayer(embed_dim, vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """Encode ``src``, decode ``tgt`` against it and return the scores.

        Args:
            src: Source token indices.
            tgt: Target token indices fed to the decoder.
            src_mask: Passed to the encoder as its attention mask.
            tgt_mask: Passed to the decoder for its self-attention. No mask is
                created here, so with the default ``None`` the decoder
                self-attention is unmasked.
            memory_mask: Passed to the decoder for its cross-attention.

        Returns:
            The output-layer projection of the decoder output.
        """
        # Source side: embed, add positions, encode.
        encoder_input = self.positional_encoding(self.embedding(src))
        memory = self.encoder(encoder_input, src_mask)

        # Target side: embed, add positions, decode against the encoder output.
        decoder_input = self.positional_encoding(self.embedding(tgt))
        decoded = self.decoder(decoder_input, memory, tgt_mask, memory_mask)

        return self.output_layer(decoded)


# Example setup: model, loss, optimiser and a stand-in data source.
model = GenerativeQAModel(
    vocab_size=30522,
    embed_dim=512,
    num_heads=8,
    ff_dim=2048,
    num_encoder_layers=6,
    num_decoder_layers=6,
    max_seq_len=512,
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)
num_epochs = 10
# Placeholder: an empty dict yields no batches, so the inner loop never runs
# until this is replaced by a real iterable of (src, tgt) pairs.
train_dataloader = {}

# Training loop (simplified): teacher forcing with next-token targets.
# NOTE(review): no tgt_mask is passed to the model, so decoder self-attention
# is unmasked and each position can see later target tokens -- confirm whether
# a causal mask is intended here.
for epoch in range(num_epochs):
    model.train()
    for src, tgt in train_dataloader:
        decoder_input = tgt[:, :-1]  # every target token except the last
        labels = tgt[:, 1:]          # the same tokens shifted left by one

        optimizer.zero_grad()
        outputs = model(src, decoder_input)
        flat_logits = outputs.view(-1, outputs.size(-1))
        loss = criterion(flat_logits, labels.reshape(-1))
        loss.backward()
        optimizer.step()



# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.nn import TransformerDecoder, TransformerDecoderLayer

# class PositionalEncoding(nn.Module):
#     def __init__(self, d_model, max_len=5000):
#         super(PositionalEncoding, self).__init__()
#         pe = torch.zeros(max_len, d_model)
#         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
#         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
#         pe[:, 0::2] = torch.sin(position * div_term)
#         pe[:, 1::2] = torch.cos(position * div_term)
#         pe = pe.unsqueeze(0).transpose(0, 1)
#         self.register_buffer('pe', pe)

#     def forward(self, x):
#         return x + self.pe[:x.size(0), :]

# class GenerativeQAModel(nn.Module):
#     def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_len=1000):
#         super(GenerativeQAModel, self).__init__()
#         self.embedding = nn.Embedding(vocab_size, d_model)
#         self.pos_encoder = PositionalEncoding(d_model, max_len)
        
#         decoder_layer = TransformerDecoderLayer(d_model=d_model, nhead=nhead)
#         self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_layers)
        
#         self.fc_out = nn.Linear(d_model, vocab_size)
#         self.d_model = d_model
#         self.init_weights()

#     def init_weights(self):
#         initrange = 0.1
#         self.embedding.weight.data.uniform_(-initrange, initrange)
#         self.fc_out.bias.data.zero_()
#         self.fc_out.weight.data.uniform_(-initrange, initrange)

#     def forward(self, input_tokens, memory):
#         embedded = self.embedding(input_tokens) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
#         embedded = self.pos_encoder(embedded)
#         output = self.transformer_decoder(embedded, memory)
#         logits = self.fc_out(output)
#         return logits

# # Example of using the model:
# vocab_size = 30522  # Example vocab size (matches BERT base uncased; GPT-2 uses 50257)
# model = GenerativeQAModel(vocab_size)

# # Dummy input for testing
# input_tokens = torch.randint(0, vocab_size, (50, 10))  # (sequence_length, batch_size)
# memory = torch.randn(50, 10, 512)  # Memory from encoder

# # Forward pass
# output = model(input_tokens, memory)
# print(output.shape)  # Should return (sequence_length, batch_size, vocab_size)
