In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the Generator
class Generator(nn.Module):
    """Sequence model built from an embedding, a single LSTM and a linear head.

    Args:
        input_size: Number of entries in the embedding table.
        hidden_size: Width used for both the embedding vectors and the LSTM
            hidden state.
        output_size: Number of features emitted for every time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=input_size,
            embedding_dim=hidden_size,
        )
        # batch_first=True: the LSTM consumes (batch, seq, feature) tensors.
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, input):
        """Map integer indices to per-time-step output scores.

        Args:
            input: Integer index tensor accepted by ``nn.Embedding``
                (typically shaped ``(batch, seq)``).

        Returns:
            The linear projection of every LSTM output step; the last
            dimension has ``output_size`` entries. No activation is applied.
        """
        token_vectors = self.embedding(input)
        # The LSTM also returns its final (h_n, c_n) state, which is unused here.
        per_step_states, _final_state = self.rnn(token_vectors)
        return self.fc(per_step_states)

# Define the Discriminator
class Discriminator(nn.Module):
    """LSTM critic that emits one real/fake logit per time step.

    Layers: ``nn.Embedding(input_size, hidden_size)`` ->
    ``nn.LSTM(hidden_size, hidden_size, batch_first=True)`` ->
    ``nn.Linear(hidden_size, 1)``.

    Args:
        input_size: Number of entries in the embedding table (vocabulary size).
        hidden_size: Width used for both the embedding vectors and the LSTM
            hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, input):
        """Score a sequence given either token ids or per-token distributions.

        Args:
            input: Either
                * an integer index tensor (e.g. ``(batch, seq)``), looked up in
                  the embedding table as before, or
                * a floating-point tensor whose last dimension equals the
                  vocabulary size (e.g. ``(batch, seq, vocab)``), holding
                  one-hot or softmax weights over the vocabulary. It is
                  multiplied with the embedding matrix, which equals the
                  lookup for one-hot rows and keeps the operation
                  differentiable with respect to ``input``.

        Returns:
            Raw logits (no sigmoid applied) with a trailing dimension of 1,
            one value per time step.

        Raises:
            ValueError: If a floating-point ``input`` does not have the
                vocabulary size as its last dimension.
        """
        if input.is_floating_point():
            vocab = self.embedding.num_embeddings
            if input.dim() == 0 or input.size(-1) != vocab:
                raise ValueError(
                    f"floating-point input must have last dimension {vocab} "
                    f"(vocabulary size), got shape {tuple(input.shape)}"
                )
            table = self.embedding.weight
            # Weighted sum of embedding rows; cast so mixed float dtypes don't fail.
            embedded = input.to(table.dtype) @ table
        else:
            embedded = self.embedding(input)
        output, _ = self.rnn(embedded)
        output = self.fc(output)
        return output






In [None]:
# Hyperparameters
# --- data ---
vocab_size = 10_000  # Adjust based on your dataset
seq_length = 20  # Adjust based on your prompt length
batch_size = 64

# --- model ---
embedding_size = 128
# NOTE(review): the visible model construction passes embedding_size as the
# models' hidden size; hidden_size is not referenced there - confirm intent.
hidden_size = 256

# --- optimisation ---
lr = 1e-3

In [None]:
# Instantiate Generator and Discriminator
# Seed PyTorch's global RNG first so weight initialisation and the random
# batches drawn below are repeatable after Restart Kernel -> Run All.
torch.manual_seed(42)

generator = Generator(vocab_size, embedding_size, vocab_size)
discriminator = Discriminator(vocab_size, embedding_size)

# Loss and Optimizer
criterion = nn.BCEWithLogitsLoss()  # expects raw logits; target must match input shape
gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)


def _score_generated(disc, token_logits):
    """Run the discriminator on generator logits rather than on token ids.

    ``nn.Embedding`` only accepts integer indices, and picking indices with
    argmax would cut the gradient path back to the generator. Instead the
    logits are turned into a softmax distribution over the vocabulary and
    multiplied with the discriminator's embedding matrix (an expected
    embedding), after which the discriminator's LSTM and linear head are
    applied as usual.

    Args:
        disc: Module exposing ``embedding``, ``rnn`` and ``fc`` sub-modules.
        token_logits: Float tensor whose last dimension equals the number of
            rows in ``disc.embedding.weight``.

    Returns:
        The discriminator's raw logits for the soft token sequence.
    """
    token_probs = torch.softmax(token_logits, dim=-1)
    soft_embedded = token_probs @ disc.embedding.weight
    hidden_states, _ = disc.rnn(soft_embedded)
    return disc.fc(hidden_states)


# Training Loop
num_epochs = 1000  # Adjust based on your dataset and convergence
for epoch in range(num_epochs):
    # NOTE(review): random token ids stand in for a real dataset here.
    real_data = torch.randint(0, vocab_size, (batch_size, seq_length))
    generated_recepies = generator(torch.randint(0, vocab_size, (batch_size, seq_length)))

    # Training the Discriminator
    disc_optimizer.zero_grad()

    real_output = discriminator(real_data)
    # Build targets from the output itself so the shapes always agree
    # (BCEWithLogitsLoss rejects a target whose size differs from the input).
    real_labels = torch.ones_like(real_output)
    real_loss = criterion(real_output, real_labels)

    # detach(): the discriminator update must not back-propagate into the generator.
    fake_output = _score_generated(discriminator, generated_recepies.detach())
    fake_labels = torch.zeros_like(fake_output)
    fake_loss = criterion(fake_output, fake_labels)

    disc_loss = real_loss + fake_loss
    disc_loss.backward()
    disc_optimizer.step()

    # Training the Generator
    gen_optimizer.zero_grad()
    # Fresh forward pass through the just-updated discriminator, this time
    # keeping the graph back to the generator's parameters.
    fake_output = _score_generated(discriminator, generated_recepies)
    # The generator is rewarded when its samples are classified as real.
    gen_loss = criterion(fake_output, torch.ones_like(fake_output))
    gen_loss.backward()
    gen_optimizer.step()

    # Print losses
    if epoch % 100 == 0:
        print(f'Epoch [{epoch}/{num_epochs}], Disc Loss: {disc_loss.item()}, Gen Loss: {gen_loss.item()}')