#### Packages

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#### Generator

In [None]:
class Generator(nn.Module):
    """LSTM language model used as the sequence generator.

    The network is an embedding layer, a single-layer ``nn.LSTM``
    (``batch_first=True``) and a linear projection from the hidden state to
    vocabulary logits.  Besides the plain forward pass it offers ancestral
    sampling (:meth:`sample`), a maximum-likelihood training step
    (:meth:`pretrain_step`) and a policy-gradient training step
    (:meth:`reinforce_step`).

    Args:
        vocab_size: Number of distinct tokens.
        embedding_dim: Size of the token embeddings.
        hidden_dim: Size of the LSTM hidden state.
        sequence_length: Default length of sampled sequences.
        start_token: Token id placed in the first column of every sample.
        device: Stored on ``self.device`` for callers that read it.  Tensors
            created inside the module follow the device of the module's own
            weights, so the model keeps working after ``.to(...)`` / ``.cpu()``
            and on machines without CUDA.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, sequence_length, start_token, device='cuda'):
        super(Generator, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = start_token
        self.device = device

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Single-layer LSTM over [batch, seq, embedding_dim] inputs
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

        # Output layer (maps hidden state to token logits)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """
        Forward pass for training with teacher forcing

        Args:
            x: Input tensor of shape [batch_size, seq_length]
            hidden: Initial ``(h0, c0)`` LSTM state (optional; zeros if omitted)

        Returns:
            logits: Output logits of shape [batch_size, seq_length, vocab_size]
            hidden: Final ``(h, c)`` LSTM state
        """
        if hidden is None:
            hidden = self.init_hidden(x.size(0))

        embedded = self.embedding(x)  # [batch_size, seq_length, embedding_dim]
        lstm_out, hidden = self.lstm(embedded, hidden)  # [batch_size, seq_length, hidden_dim]
        logits = self.output_layer(lstm_out)  # [batch_size, seq_length, vocab_size]

        return logits, hidden

    def init_hidden(self, batch_size):
        """Return a zero ``(h0, c0)`` pair of shape [1, batch_size, hidden_dim].

        The tensors are created from the embedding weight so they always share
        the device and dtype of the module's parameters.
        """
        reference = self.embedding.weight
        shape = (1, batch_size, self.hidden_dim)
        return (reference.new_zeros(shape), reference.new_zeros(shape))

    def sample(self, batch_size, seq_length=None):
        """
        Sample sequences from the generator

        Args:
            batch_size: Number of sequences to generate
            seq_length: Length of sequences (defaults to self.sequence_length)

        Returns:
            samples: Generated sequences of shape [batch_size, seq_length];
                column 0 is always ``start_token``.
        """
        if seq_length is None:
            seq_length = self.sequence_length

        # Start with start tokens, on whatever device the weights live on
        current = torch.full(
            (batch_size, 1),
            self.start_token,
            dtype=torch.long,
            device=self.embedding.weight.device,
        )
        columns = [current]
        hidden = self.init_hidden(batch_size)

        # The sampled ids are integers and carry no gradient, so there is no
        # point in recording an autograd graph for the rollout.
        with torch.no_grad():
            for _ in range(seq_length - 1):
                embedded = self.embedding(current)
                lstm_out, hidden = self.lstm(embedded, hidden)
                logits = self.output_layer(lstm_out.squeeze(1))

                # Draw the next token from the predicted distribution
                probs = F.softmax(logits, dim=-1)
                current = torch.multinomial(probs, 1)
                columns.append(current)

        # Concatenate once instead of growing the tensor every step
        return torch.cat(columns, dim=1)

    def pretrain_step(self, x, optimizer):
        """
        Supervised pre-training step using maximum likelihood

        Args:
            x: Input sequences [batch_size, seq_length]
            optimizer: PyTorch optimizer

        Returns:
            loss: Training loss for this batch (Python float)
        """
        # Predict token t+1 from tokens 0..t
        inp = x[:, :-1]
        target = x[:, 1:].reshape(-1)

        logits, _ = self(inp)
        loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    def reinforce_step(self, x, rewards, optimizer):
        """
        Policy gradient update step

        Args:
            x: Input sequences [batch_size, seq_length]
            rewards: Rewards from rollout [batch_size, seq_length]
            optimizer: PyTorch optimizer

        Returns:
            loss: Policy gradient loss (Python float)
        """
        inp = x[:, :-1]
        target = x[:, 1:]

        logits, _ = self(inp)
        log_probs = F.log_softmax(logits, dim=-1)

        # Pick the log-probability of each token that was actually chosen;
        # gather avoids building a [batch, seq, vocab] one-hot tensor.
        selected_log_probs = log_probs.gather(2, target.unsqueeze(2)).squeeze(2)

        # Align rewards with targets.  Rewards are treated as constants so no
        # gradient flows back into whatever model produced them.
        step_rewards = rewards[:, 1:].detach()

        # Policy gradient loss, averaged over the batch
        loss = -torch.sum(selected_log_probs * step_rewards) / target.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

#### Discriminator

In [None]:
# models/discriminator.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class Highway(nn.Module):
    """Stack of highway layers (see http://arxiv.org/abs/1505.00387).

    Every layer owns two square ``nn.Linear`` maps of width ``input_size``:
    ``'transform'`` produces a ReLU-activated candidate and ``'gate'``
    produces a sigmoid mixing weight.  The layer output is
    ``gate * candidate + (1 - gate) * input``.

    Args:
        input_size: Feature width of the input (and output).
        num_layers: How many highway layers to stack.
        bias: Constant the gate biases are initialised to; a negative value
            makes the layers start out mostly passing their input through.
    """

    def __init__(self, input_size, num_layers=1, bias=-2.0):
        super().__init__()
        self.input_size = input_size
        self.num_layers = num_layers
        self.bias = bias

        stack = []
        for _ in range(num_layers):
            block = nn.ModuleDict({
                'transform': nn.Linear(input_size, input_size),
                'gate': nn.Linear(input_size, input_size),
            })
            # Bias the gate towards the carry path at initialisation.
            nn.init.constant_(block['gate'].bias, self.bias)
            stack.append(block)
        self.highway_layers = nn.ModuleList(stack)

    def forward(self, x):
        """Apply each highway layer in order and return the final tensor."""
        out = x
        for block in self.highway_layers:
            candidate = F.relu(block['transform'](out))
            mix = torch.sigmoid(block['gate'](out))
            out = mix * candidate + (1 - mix) * out
        return out

class Discriminator(nn.Module):
    """CNN-based discriminator for sequence classification.

    Token ids are embedded, passed through parallel convolutions whose
    kernels span the full embedding width, max-pooled over time,
    concatenated, sent through one highway layer and dropout, and finally
    projected to two class logits.

    Args:
        sequence_length: Nominal input length.  Stored on the instance but
            not used to size any layer (pooling covers whatever length the
            convolution produces).
        vocab_size: Number of distinct tokens.
        embedding_dim: Size of the token embeddings.
        filter_sizes: Kernel heights (n-gram widths), one per convolution.
        num_filters: Output channels for each convolution, paired with
            ``filter_sizes`` by position; surplus entries in the longer of
            the two sequences are ignored.
        dropout: Dropout probability applied before the output layer.
    """

    def __init__(self, sequence_length, vocab_size, embedding_dim, filter_sizes, num_filters, dropout=0.5):
        super(Discriminator, self).__init__()

        self.sequence_length = sequence_length
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Convolutional layers with different filter sizes
        self.convs = nn.ModuleList([
            nn.Conv2d(1, n_filter, (filter_size, embedding_dim))
            for filter_size, n_filter in zip(filter_sizes, num_filters)
        ])

        # Width of the concatenated features.  Derived from the convolutions
        # that were actually built so the highway/fc layers always match the
        # tensor produced in forward(), even if num_filters has extra entries.
        self.total_filters = sum(conv.out_channels for conv in self.convs)

        # Highway network
        self.highway = Highway(self.total_filters, num_layers=1)

        # Output layer
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.total_filters, 2)  # Binary classification

    def forward(self, x):
        """
        Args:
            x: Input sequences [batch_size, seq_length]

        Returns:
            logits: Output logits of shape [batch_size, 2]
        """
        # Embedding Layer
        x = self.embedding(x)  # [batch_size, seq_length, embedding_dim]
        x = x.unsqueeze(1)  # [batch_size, 1, seq_length, embedding_dim]

        # Convolutional Layers
        conv_outputs = []
        for conv in self.convs:
            h = F.relu(conv(x))  # [batch_size, n_filter, seq_length - filter_size + 1, 1]
            h = F.max_pool2d(h, (h.size(2), 1))  # [batch_size, n_filter, 1, 1]
            h = h.squeeze(-1).squeeze(-1)  # [batch_size, n_filter]
            conv_outputs.append(h)

        # Concatenate
        concat_h = torch.cat(conv_outputs, dim=1)  # [batch_size, total_filters]

        # Highway layer
        highway_out = self.highway(concat_h)

        # Dropout and fully-connected
        dropped = self.dropout(highway_out)
        logits = self.fc(dropped)

        return logits