<a href="https://colab.research.google.com/github/Swarna1804/GenerativeAI/blob/main/VEA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
class TextVAE(nn.Module):
    """Variational autoencoder over token sequences built from GRUs.

    The encoder is a bidirectional GRU whose two final hidden states are
    concatenated and projected to the mean and log-variance of a diagonal
    Gaussian.  The decoder is a unidirectional GRU that receives, at every
    step, the embedding of the previous token concatenated with the latent
    vector ``z``.

    Args:
        vocab_size: Number of rows in the embedding table and width of the
            output logits.
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden size of both GRUs (per direction for the encoder).
        latent_dim: Size of the latent vector ``z``.
        pad_idx: Index passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, latent_dim, pad_idx):
        super().__init__()
        # Embedding layer shared by encoder and decoder
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        # Encoder
        self.encoder_rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.mu_layer = nn.Linear(hidden_dim * 2, latent_dim)
        self.logvar_layer = nn.Linear(hidden_dim * 2, latent_dim)
        # Decoder
        self.decoder_initial = nn.Linear(latent_dim, hidden_dim)
        # decode() feeds [token embedding ; z] at every step, so the GRU input
        # width must be embedding_dim + latent_dim (not + hidden_dim).
        self.decoder_rnn = nn.GRU(embedding_dim + latent_dim, hidden_dim, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)
        # Save dimensions
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

    def encode(self, x, lengths):
        """Map a padded batch of token ids to latent Gaussian parameters.

        Args:
            x: Token-id tensor laid out batch-first (the GRU is built with
                ``batch_first=True``).
            lengths: Per-sequence lengths; moved to the CPU for packing.
                The batch does not need to be sorted by length.

        Returns:
            Tuple ``(mu, logvar)``, each with ``latent_dim`` features per
            batch element.
        """
        embedded = self.embedding(x)
        # Pack so the GRU stops at each sequence's true length
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, hidden = self.encoder_rnn(packed)
        # Last two entries of the final hidden state are the two directions
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        mu = self.mu_layer(hidden)
        logvar = self.logvar_layer(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, target, teacher_forcing_ratio=0.5):
        """Unroll the decoder for ``target.size(1)`` steps.

        The first input token is ``target[:, 0]``.  After step ``t`` the next
        input is ``target[:, t]`` with probability ``teacher_forcing_ratio``
        (one draw per step, shared by the whole batch) and the arg-max of the
        step's logits otherwise.

        Args:
            z: Latent vectors, one row per batch element.
            target: Batch-first token ids used for the start token and for
                teacher forcing.
            teacher_forcing_ratio: Probability of feeding the ground-truth
                token instead of the model's own prediction.

        Returns:
            Logits of shape ``(batch, target.size(1), vocab_size)``.
        """
        batch_size = z.size(0)
        max_length = target.size(1)
        # Initialize decoder hidden state from the latent vector
        hidden = torch.tanh(self.decoder_initial(z)).unsqueeze(0)
        # Initialize first input as start token
        decoder_input = target[:, 0].unsqueeze(1)
        outputs = torch.zeros(
            batch_size, max_length, self.output_layer.out_features, device=z.device
        )
        # Every step consumes exactly one token, so z is needed as (batch, 1, latent)
        z_step = z.unsqueeze(1)
        for t in range(max_length):
            embedded = self.embedding(decoder_input)
            decoder_input_combined = torch.cat((embedded, z_step), dim=2)
            output, hidden = self.decoder_rnn(decoder_input_combined, hidden)
            prediction = self.output_layer(output)
            outputs[:, t:t+1] = prediction
            # Teacher forcing
            use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
            decoder_input = target[:, t:t+1] if use_teacher_forcing else prediction.argmax(2)
        return outputs

    def forward(self, x, lengths, target, teacher_forcing_ratio=0.5):
        """Encode ``x``, sample ``z`` and decode against ``target``.

        Returns:
            Tuple ``(outputs, mu, logvar)`` where ``outputs`` are the decoder
            logits and ``mu``/``logvar`` parameterize the posterior.
        """
        mu, logvar = self.encode(x, lengths)
        z = self.reparameterize(mu, logvar)
        outputs = self.decode(z, target, teacher_forcing_ratio)
        return outputs, mu, logvar
def vae_loss(predictions, target, mu, logvar, pad_idx):
    """Return the summed VAE objective: reconstruction term plus KL term.

    Args:
        predictions: Logits whose last dimension indexes the classes; all
            leading dimensions are flattened together.
        target: Class indices, flattened to one dimension.
        mu: Posterior means.
        logvar: Posterior log-variances.
        pad_idx: Target value excluded from the cross-entropy.

    Returns:
        Scalar tensor: summed cross-entropy plus
        ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    num_classes = predictions.size(-1)
    flat_logits = predictions.view(-1, num_classes)
    flat_labels = target.view(-1)
    # Reconstruction term: summed (not averaged) over non-padding positions
    reconstruction = F.cross_entropy(
        flat_logits, flat_labels, reduction='sum', ignore_index=pad_idx
    )
    # KL term against a standard normal prior
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_elements.sum()
    return reconstruction + divergence
# Example usage
def train_vae(model, train_loader, optimizer, pad_idx, epochs=10):
    """Run a plain training loop and print the per-sample loss after each epoch.

    Args:
        model: Called as ``model(data, lengths, target)``; must return
            ``(predictions, mu, logvar)``.
        train_loader: Iterable yielding ``(data, lengths, target)`` batches
            and exposing a sized ``dataset`` attribute.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per
            batch.
        pad_idx: Padding index forwarded to ``vae_loss``.
        epochs: Number of passes over ``train_loader``.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for data, lengths, target in train_loader:
            optimizer.zero_grad()
            predictions, mu, logvar = model(data, lengths, target)
            batch_loss = vae_loss(predictions, target, mu, logvar, pad_idx)
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()
        # The loss is a sum over the batch, so normalise by the sample count
        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}, Average loss: {avg_loss:.4f}')
def generate_text(model, start_token, max_length, temperature=1.0):
    """Sample a token-id sequence from the decoder using a random latent vector.

    The model is switched to evaluation mode and left in it.  Generation
    always runs for exactly ``max_length`` steps; there is no early stop.

    Args:
        model: Module exposing ``latent_dim``, ``embedding``,
            ``decoder_initial``, ``decoder_rnn`` and ``output_layer``.
        start_token: Integer id fed to the decoder at the first step.
        max_length: Number of tokens to sample.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        List of ``max_length + 1`` integer ids, beginning with
        ``start_token``.
    """
    model.eval()
    # Create the inputs on the model's device; CPU-only tensors would break a
    # model that has been moved to another device.
    device = next(model.parameters()).device
    with torch.no_grad():
        # Sample from latent space
        z = torch.randn(1, model.latent_dim, device=device)
        # Initialize sequence with start token
        current_token = torch.tensor([[start_token]], device=device)
        generated = [start_token]
        # Initialize hidden state
        hidden = torch.tanh(model.decoder_initial(z)).unsqueeze(0)
        # One token per step, so z is needed as (1, 1, latent_dim)
        z_step = z.unsqueeze(1)
        for _ in range(max_length):
            embedded = model.embedding(current_token)
            decoder_input = torch.cat((embedded, z_step), dim=2)
            output, hidden = model.decoder_rnn(decoder_input, hidden)
            # Explicit indexing keeps a 1-D logits vector for any vocab size
            logits = model.output_layer(output)[0, -1] / temperature
            probabilities = F.softmax(logits, dim=-1)
            # Sample from distribution; reshape to (1, 1) for the next step
            current_token = torch.multinomial(probabilities, 1).unsqueeze(0)
            generated.append(current_token.item())
        return generated

In [None]:
import torch
from torch.utils.data import Dataset
from collections import Counter
import numpy as np
# Toy corpus used by the demo: five short, lowercase phrases whose words are
# separated by single spaces.
sample_texts = [
    "hello world",
    "machine learning is amazing",
    "deep learning with pytorch",
    "natural language processing",
    "artificial intelligence research"
]
class Vocabulary:
    """Word/index mapping built from whitespace-tokenised texts.

    Indices 0-3 are reserved for ``<pad>``, ``<sos>``, ``<eos>`` and
    ``<unk>``.  Remaining words are numbered in order of first appearance,
    keeping only those that occur at least ``min_freq`` times.

    Attributes are ``stoi`` (word -> index dict) and ``itos`` (index -> word
    list).
    """

    def __init__(self, texts, min_freq=1):
        self.itos = ['<pad>', '<sos>', '<eos>', '<unk>']
        self.stoi = {token: index for index, token in enumerate(self.itos)}
        frequencies = Counter(word for text in texts for word in text.split())
        for token, frequency in frequencies.items():
            if frequency < min_freq or token in self.stoi:
                continue
            # The next free index is the current vocabulary size
            self.stoi[token] = len(self.itos)
            self.itos.append(token)

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.itos)
class TextDataset(Dataset):
    """Dataset that turns each text into a tensor of token ids.

    Each item is ``<sos>``, the ids of the whitespace-split words (unknown
    words map to ``<unk>``), then ``<eos>``.

    Args:
        texts: Indexable collection of strings.
        vocab: Object with a ``stoi`` dict containing the special tokens.
    """

    def __init__(self, texts, vocab):
        self.vocab = vocab
        self.texts = texts

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the id tensor for the text at position ``idx``."""
        words = self.texts[idx].split()
        stoi = self.vocab.stoi
        token_ids = [
            stoi['<sos>'],
            *(stoi.get(word, stoi['<unk>']) for word in words),
            stoi['<eos>'],
        ]
        return torch.tensor(token_ids)
def collate_fn(batch):
    """Sort a batch by descending length and pad it into one tensor.

    Args:
        batch: List of 1-D tensors of possibly different lengths.

    Returns:
        Tuple ``(padded, lengths, padded)``: the zero-padded batch-first
        tensor, the sorted lengths, and the same padded tensor again as the
        target.
    """
    raw_lengths = torch.tensor([len(sequence) for sequence in batch])
    # Longest first, and remember where each sequence came from
    lengths, order = torch.sort(raw_lengths, descending=True)
    ordered = [batch[position] for position in order]
    padded = torch.nn.utils.rnn.pad_sequence(
        ordered, batch_first=True, padding_value=0
    )
    return padded, lengths, padded
class TextVAE(torch.nn.Module):
    """Variational autoencoder over token sequences built from GRUs.

    A bidirectional GRU encodes the input; its two final hidden states are
    concatenated and projected to the mean and log-variance of the latent
    distribution.  A unidirectional GRU decodes, receiving the previous
    token's embedding concatenated with the latent vector at every step.

    Args:
        vocab_size: Number of rows in the embedding table and width of the
            output logits.
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden size of both GRUs (per direction for the encoder).
        latent_dim: Size of the latent vector ``z``.
        pad_idx: Index passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, latent_dim, pad_idx):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.encoder_rnn = torch.nn.GRU(
            embedding_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        self.mu_layer = torch.nn.Linear(hidden_dim * 2, latent_dim)
        self.logvar_layer = torch.nn.Linear(hidden_dim * 2, latent_dim)
        self.decoder_initial = torch.nn.Linear(latent_dim, hidden_dim)
        # Decoder input is [token embedding ; z] at every step
        self.decoder_rnn = torch.nn.GRU(
            embedding_dim + latent_dim,
            hidden_dim,
            batch_first=True
        )
        self.output_layer = torch.nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

    def forward(self, x, lengths, target, teacher_forcing_ratio=0.5):
        """Encode ``x``, sample ``z`` and decode against ``target``.

        Returns:
            Tuple ``(outputs, mu, logvar)`` where ``outputs`` are the decoder
            logits and ``mu``/``logvar`` parameterize the posterior.
        """
        mu, logvar = self.encode(x, lengths)
        z = self.reparameterize(mu, logvar)
        outputs = self.decode(z, target, teacher_forcing_ratio)
        return outputs, mu, logvar

    def encode(self, x, lengths):
        """Map a padded batch of token ids to latent Gaussian parameters.

        Args:
            x: Batch-first token-id tensor.
            lengths: Per-sequence lengths; moved to the CPU for packing.
                Any order is accepted.

        Returns:
            Tuple ``(mu, logvar)``, each with ``latent_dim`` features per
            batch element, in the same order as the rows of ``x``.
        """
        embedded = self.embedding(x)
        # enforce_sorted=False lets the packer sort internally and restore the
        # caller's order, so batches that are not pre-sorted by length no
        # longer raise.
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, hidden = self.encoder_rnn(packed)
        # Last two entries of the final hidden state are the two directions
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        mu = self.mu_layer(hidden)
        logvar = self.logvar_layer(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z`` in training mode; return ``mu`` unchanged otherwise."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def decode(self, z, target, teacher_forcing_ratio=0.5):
        """Unroll the decoder for ``target.size(1)`` steps.

        The first input token is ``target[:, 0]``.  After step ``t`` the next
        input is ``target[:, t]`` with probability ``teacher_forcing_ratio``
        (one draw per step, shared by the whole batch) and the arg-max of the
        step's logits otherwise.

        Args:
            z: Latent vectors, one row per batch element.
            target: Batch-first token ids used for the start token and for
                teacher forcing.
            teacher_forcing_ratio: Probability of feeding the ground-truth
                token instead of the model's own prediction.

        Returns:
            Logits of shape ``(batch, target.size(1), vocab_size)``.
        """
        batch_size = z.size(0)
        max_length = target.size(1)
        # Initialize decoder hidden state from the latent vector
        hidden = torch.tanh(self.decoder_initial(z)).unsqueeze(0)
        # Initialize first input
        decoder_input = target[:, 0].unsqueeze(1)
        outputs = torch.zeros(
            batch_size,
            max_length,
            self.output_layer.out_features,
            device=z.device
        )
        # Every step consumes exactly one token, so z is needed as (batch, 1, latent)
        z_step = z.unsqueeze(1)
        for t in range(max_length):
            embedded = self.embedding(decoder_input)
            decoder_input_combined = torch.cat((embedded, z_step), dim=2)
            output, hidden = self.decoder_rnn(decoder_input_combined, hidden)
            prediction = self.output_layer(output)
            outputs[:, t:t+1] = prediction
            use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
            decoder_input = target[:, t:t+1] if use_teacher_forcing else prediction.argmax(2)
        return outputs
def vae_loss(predictions, target, mu, logvar, pad_idx):
    """Compute the VAE loss as summed cross-entropy plus KL divergence.

    Args:
        predictions: Logits whose last dimension indexes the classes; all
            leading dimensions are flattened together.
        target: Class indices, flattened to one dimension.
        mu: Posterior means.
        logvar: Posterior log-variances.
        pad_idx: Target value excluded from the cross-entropy.

    Returns:
        Scalar tensor holding the sum of both terms.
    """
    class_count = predictions.size(-1)
    logits_2d = predictions.view(-1, class_count)
    labels_1d = target.view(-1)
    cross_entropy = torch.nn.functional.cross_entropy
    # Summed (not averaged) over all non-padding positions
    rec_term = cross_entropy(
        logits_2d, labels_1d, reduction='sum', ignore_index=pad_idx
    )
    # KL divergence from the standard normal prior
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return rec_term + kl_term
def demo_vae():
    """Train a small TextVAE on ``sample_texts`` for 50 epochs.

    Builds the vocabulary, dataset and an unshuffled loader with batches of
    two, trains with Adam at its default settings, and prints the per-sample
    loss after every tenth epoch.
    """
    vocab = Vocabulary(sample_texts)
    dataset = TextDataset(sample_texts, vocab)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, collate_fn=collate_fn, shuffle=False
    )
    pad_idx = vocab.stoi['<pad>']
    model = TextVAE(
        vocab_size=len(vocab),
        embedding_dim=128,
        hidden_dim=256,
        latent_dim=32,
        pad_idx=pad_idx,
    )
    optimizer = torch.optim.Adam(model.parameters())
    model.train()
    print("Training the model...")
    for epoch in range(50):
        running_loss = 0.0
        for data, lengths, target in train_loader:
            optimizer.zero_grad()
            predictions, mu, logvar = model(data, lengths, target)
            batch_loss = vae_loss(predictions, target, mu, logvar, vocab.stoi['<pad>'])
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()
        # The loss is summed per batch, so divide by the number of samples
        avg_loss = running_loss / len(train_loader.dataset)
        if (epoch + 1) % 10 != 0:
            continue
        print(f'Epoch {epoch+1}, Average loss: {avg_loss:.4f}')
# Run the demo only when this code is executed as the main module.
if __name__ == "__main__":
    demo_vae()

Training the model...
Epoch 10, Average loss: 7.0274
Epoch 20, Average loss: 7.4995
Epoch 30, Average loss: 4.7260
Epoch 40, Average loss: 3.7060
Epoch 50, Average loss: 3.2084


In [None]:
def generate_text(model, vocab, max_length=10, temperature=0.8):
    """Sample a sentence from the decoder using a random latent vector.

    The model is switched to evaluation mode and left in it.  Sampling stops
    early when ``<eos>`` is drawn; special tokens are dropped from the
    returned text.

    Args:
        model: Module exposing ``latent_dim``, ``embedding``,
            ``decoder_initial``, ``decoder_rnn`` and ``output_layer``.
        vocab: Object with ``stoi`` (word -> index) and ``itos``
            (index -> word) containing the four special tokens.
        max_length: Maximum number of sampling steps.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        The sampled words joined by single spaces (possibly empty).
    """
    model.eval()  # Set to evaluation mode
    # Create the inputs on the model's device; CPU-only tensors would break a
    # model that has been moved to another device.
    device = next(model.parameters()).device
    eos_id = vocab.stoi['<eos>']
    # Built once rather than once per filtered token
    special_ids = {vocab.stoi[name] for name in ('<pad>', '<sos>', '<eos>', '<unk>')}

    with torch.no_grad():
        # Sample from latent space
        z = torch.randn(1, model.latent_dim, device=device)

        # Initialize with start token
        current_token = torch.tensor([[vocab.stoi['<sos>']]], device=device)

        # Initialize hidden state
        hidden = torch.tanh(model.decoder_initial(z)).unsqueeze(0)

        # One token per step, so z is needed as (1, 1, latent_dim)
        z_step = z.unsqueeze(1)
        generated_tokens = []

        for _ in range(max_length):
            embedded = model.embedding(current_token)
            decoder_input = torch.cat((embedded, z_step), dim=2)

            # Generate next-token logits
            output, hidden = model.decoder_rnn(decoder_input, hidden)
            # Explicit indexing keeps a 1-D logits vector for any vocab size
            logits = model.output_layer(output)[0, -1] / temperature
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # Sample from the distribution
            next_token = torch.multinomial(probs, 1)
            token_id = next_token.item()

            # Break if end token is generated
            if token_id == eos_id:
                break

            generated_tokens.append(token_id)
            current_token = next_token.unsqueeze(0)

        # Convert tokens to words, skipping special tokens
        generated_words = [
            vocab.itos[idx] for idx in generated_tokens if idx not in special_ids
        ]

        return ' '.join(generated_words)

def demo_vae():
    """Train a small TextVAE on ``sample_texts`` and print five samples.

    Builds the vocabulary, dataset and an unshuffled loader with batches of
    two, trains for 50 epochs with Adam at its default settings (printing
    the per-sample loss after every tenth epoch), then generates five texts.
    """
    vocab = Vocabulary(sample_texts)
    dataset = TextDataset(sample_texts, vocab)

    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, collate_fn=collate_fn, shuffle=False
    )

    pad_idx = vocab.stoi['<pad>']
    model = TextVAE(
        vocab_size=len(vocab),
        embedding_dim=128,
        hidden_dim=256,
        latent_dim=32,
        pad_idx=pad_idx,
    )

    optimizer = torch.optim.Adam(model.parameters())
    model.train()

    print("Training the model...")
    for epoch in range(50):
        running_loss = 0.0
        for data, lengths, target in train_loader:
            optimizer.zero_grad()
            predictions, mu, logvar = model(data, lengths, target)
            batch_loss = vae_loss(predictions, target, mu, logvar, vocab.stoi['<pad>'])
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        # The loss is summed per batch, so divide by the number of samples
        avg_loss = running_loss / len(train_loader.dataset)
        if (epoch + 1) % 10 != 0:
            continue
        print(f'Epoch {epoch+1}, Average loss: {avg_loss:.4f}')

    print("\nGenerating sample texts:")
    for i in range(5):
        generated = generate_text(model, vocab, max_length=15, temperature=0.8)
        print(f"Sample {i+1}: {generated}")

# Run the train-and-generate demo only when executed as the main module.
if __name__ == "__main__":
    demo_vae()

Training the model...
Epoch 10, Average loss: 8.0573
Epoch 20, Average loss: 5.7624
Epoch 30, Average loss: 7.9546
Epoch 40, Average loss: 4.7905
Epoch 50, Average loss: 5.4193

Generating sample texts:
Sample 1: machine learning is amazing
Sample 2: hello world
Sample 3: deep learning with pytorch
Sample 4: natural language processing
Sample 5: natural world
