# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sklearn.preprocessing import LabelEncoder

# In [None]:


# Define the VAE Model
class VAE(nn.Module):
    """LSTM sequence variational autoencoder over token ids.

    The encoder embeds a batch of token ids, runs an LSTM over it and maps the
    final hidden state to the mean and log-variance of a diagonal Gaussian
    posterior. The decoder turns a latent vector into the initial hidden state
    of a second LSTM that emits per-position vocabulary logits.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        embedding_dim: Token embedding size, shared by encoder and decoder inputs.
        hidden_dim: Hidden size of both LSTMs.
        latent_dim: Size of the latent vector ``z``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()

        # Encoder. batch_first=True because inputs arrive as (batch, seq_len),
        # which is how a DataLoader stacks 1-D token-id tensors.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: z -> initial hidden state; the LSTM consumes token embeddings
        # (hence input size embedding_dim, not hidden_dim).
        self.fc_decode = nn.Linear(latent_dim, hidden_dim)
        self.rnn_decode = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def encode(self, x):
        """Map token ids of shape (batch, seq_len) to ``(mu, logvar)``, each (batch, latent_dim)."""
        embedded = self.embedding(x)
        _, (hidden, _) = self.rnn(embedded)
        # hidden is (num_layers, batch, hidden_dim); use the top layer's final state.
        last_hidden = hidden[-1]
        return self.fc_mu(last_hidden), self.fc_logvar(last_hidden)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` (differentiable in mu/logvar)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, x=None, max_len=20):
        """Decode latent vectors into per-position vocabulary logits.

        Args:
            z: Latent vectors, shape (batch, latent_dim).
            x: Optional target token ids, shape (batch, seq_len). When given the
                decoder is teacher-forced: position t is conditioned on x[:, :t].
            max_len: Number of positions to generate greedily when ``x`` is None.

        Returns:
            Logits of shape (batch, seq_len, vocab_size) when ``x`` is given,
            otherwise (batch, max_len, vocab_size).
        """
        batch_size = z.size(0)
        h0 = self.fc_decode(z).unsqueeze(0)  # (1, batch, hidden_dim)
        state = (h0, torch.zeros_like(h0))
        # A zero vector stands in for a start-of-sequence embedding, so no special
        # token id is needed; training and generation use the same start input.
        start = z.new_zeros(batch_size, 1, self.embedding.embedding_dim)

        if x is not None:
            # Shift targets right by one so each position predicts the next token.
            inputs = torch.cat([start, self.embedding(x[:, :-1])], dim=1)
            output, _ = self.rnn_decode(inputs, state)
            return self.fc_out(output)

        if max_len < 1:
            return z.new_zeros(batch_size, 0, self.fc_out.out_features)

        # Greedy autoregressive generation: feed each step's argmax token back in.
        step_input = start
        step_logits = []
        for _ in range(max_len):
            output, state = self.rnn_decode(step_input, state)
            logits = self.fc_out(output)  # (batch, 1, vocab_size)
            step_logits.append(logits)
            step_input = self.embedding(logits.argmax(dim=-1))
        return torch.cat(step_logits, dim=1)

    def forward(self, x):
        """Encode ``x``, sample ``z`` and reconstruct ``x`` with teacher forcing.

        Returns ``(logits, mu, logvar)``; logits are (batch, seq_len, vocab_size)
        and aligned position-by-position with ``x`` for the reconstruction loss.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        output = self.decode(z, x)
        return output, mu, logvar


# VAE Loss function
def vae_loss(recon_x, x, mu, logvar, ignore_index=-100):
    """Return the VAE objective: token reconstruction cross-entropy plus KL divergence.

    Args:
        recon_x: Logits whose last dimension is the vocabulary; all leading
            dimensions are flattened and must match ``x`` element-for-element.
        x: Target token ids.
        mu: Posterior means, shape (batch, latent_dim).
        logvar: Posterior log-variances, same shape as ``mu``.
        ignore_index: Target id excluded from the cross-entropy (e.g. a padding
            id). Defaults to -100, ``nn.CrossEntropyLoss``'s own default. If every
            target is ignored, the mean-reduced cross-entropy is NaN.

    Returns:
        Scalar tensor: mean per-token cross-entropy + mean per-sample
        KL(q(z|x) || N(0, I)).
    """
    recon_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)(
        recon_x.reshape(-1, recon_x.size(-1)), x.reshape(-1)
    )
    # Closed-form KL against a standard normal, summed over latent dims and
    # averaged over the batch so that, like the mean-reduced cross-entropy,
    # it does not scale with batch size.
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1).mean()
    return recon_loss + kl_div


# Dataset class for text corpus
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one raw string per lookup.

    Each item is tokenized on demand, truncated/padded to exactly ``max_len``
    ids, and returned as a 1-D tensor, so a DataLoader can stack items into a
    (batch, max_len) tensor.

    Args:
        texts: Indexable collection of strings.
        tokenizer: Callable tokenizer; it must be able to pad to ``max_length``.
        max_len: Fixed number of token ids per item.
    """

    def __init__(self, texts, tokenizer, max_len=512):
        self.texts, self.tokenizer, self.max_len = texts, tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        batch_encoding = self.tokenizer(
            self.texts[idx],
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # The tokenizer returns a (1, max_len) batch; drop the leading batch axis.
        token_ids = batch_encoding['input_ids']
        return token_ids.squeeze(0)


# Initialize model, tokenizer, and data
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# The GPT-2 tokenizer has no padding token by default, so padding='max_length'
# in TextDataset would raise; reuse the end-of-text token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size
embedding_dim = 256
hidden_dim = 512
latent_dim = 128
# The example sentences are short; padding to 512 would run hundreds of extra
# LSTM steps and produce vocab-sized logits for every padded position.
max_seq_len = 32
vae = VAE(vocab_size, embedding_dim, hidden_dim, latent_dim)

# Example corpus
corpus = [
    "Artificial intelligence is transforming industries.",
    "Machine learning is a subset of AI that focuses on building systems that can learn from data.",
    "Neural networks are a class of machine learning models inspired by the brain's structure."
]

# Prepare dataset and dataloader
dataset = TextDataset(corpus, tokenizer, max_len=max_seq_len)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Training loop
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

epochs = 10
vae.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        optimizer.zero_grad()
        output, mu, logvar = vae(batch)
        loss = vae_loss(output, batch, mu, logvar)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Report the mean over all batches instead of only the last batch's loss;
    # max(..., 1) keeps an empty dataloader from dividing by zero.
    avg_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")


# Function to generate prompt for LLM (GPT-2)
def generate_prompt(latent_vector):
    """Decode a latent vector into prompt text using the module-level ``vae`` and ``tokenizer``.

    Args:
        latent_vector: Tensor of shape ``(latent_dim,)`` or ``(1, latent_dim)``.
            If a larger batch is passed, only the first row is decoded.

    Returns:
        The per-position argmax tokens of ``vae.decode``, detokenized to a string.
    """
    if latent_vector.dim() == 1:
        latent_vector = latent_vector.unsqueeze(0)
    # Inference only: don't build an autograd graph through the decoder.
    with torch.no_grad():
        logits = vae.decode(latent_vector)
    # Index the first sample explicitly; a blanket squeeze() could collapse a
    # length-1 result to a 0-d value.
    token_ids = torch.argmax(logits, dim=-1)[0]
    return tokenizer.decode(token_ids.tolist())


# Generate a prompt and pass it to an LLM (GPT-2)
latent_vector = torch.randn(1, latent_dim)  # Sample a latent vector from the N(0, I) prior
prompt = generate_prompt(latent_vector)

# Load GPT-2 model (the tokenizer loaded above is reused)
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Encode the prompt, keeping the attention mask so generate() need not infer it.
encoded = tokenizer(prompt, return_tensors='pt')
# max_new_tokens budgets the continuation independently of the prompt length;
# max_length would count prompt tokens too and could leave no room to generate.
output = gpt2.generate(
    encoded['input_ids'],
    attention_mask=encoded['attention_mask'],
    max_new_tokens=50,
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode and display the result
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Generated Prompt: {prompt}")
print(f"GPT-2 Response: {generated_text}")
