#### Packages

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim  # Make sure this is imported
from torch.distributions import Categorical
import numpy as np
import random
import os
import pickle
import matplotlib.pyplot as plt
import pandas as pd

#### DataLoader

In [18]:
class GenDataLoader:

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.token_stream = []
        self.num_batch = 0
        self.pointer = 0
        self.sequence_batch = []

    def create_batches(self, data_file):
        self.token_stream = []
        with open(data_file, 'r') as f:
            for line in f:
                line = line.strip()
                line = line.split()
                parse_line = [int(x) for x in line]
                if len(parse_line) == 20:  # Fixed sequence length of 20
                    self.token_stream.append(parse_line)

        self.num_batch = int(len(self.token_stream) / self.batch_size)
        self.token_stream = self.token_stream[:self.num_batch * self.batch_size]
        self.sequence_batch = np.split(np.array(self.token_stream), self.num_batch, 0)
        self.pointer = 0

    def next_batch(self):
        ret = self.sequence_batch[self.pointer]
        self.pointer = (self.pointer + 1) % self.num_batch
        return ret

    def reset_pointer(self):
        self.pointer = 0


class DisDataloader:

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.sentences = np.array([])
        self.labels = np.array([])
        self.num_batch = 0
        self.pointer = 0
        self.sentences_batches = []
        self.labels_batches = []

    def load_train_data(self, positive_file, negative_file):
        # Load data
        positive_examples = []
        negative_examples = []

        with open(positive_file) as fin:
            for line in fin:
                line = line.strip()
                line = line.split()
                parse_line = [int(x) for x in line]
                positive_examples.append(parse_line)

        with open(negative_file) as fin:
            for line in fin:
                line = line.strip()
                line = line.split()
                parse_line = [int(x) for x in line]
                if len(parse_line) == 20:
                    negative_examples.append(parse_line)

        self.sentences = np.array(positive_examples + negative_examples)

        # Generate labels
        positive_labels = [[0, 1] for _ in positive_examples]
        negative_labels = [[1, 0] for _ in negative_examples]
        self.labels = np.concatenate([positive_labels, negative_labels], 0)

        # Shuffle the data
        shuffle_indices = np.random.permutation(np.arange(len(self.labels)))
        self.sentences = self.sentences[shuffle_indices]
        self.labels = self.labels[shuffle_indices]

        # Split batches
        self.num_batch = int(len(self.labels) / self.batch_size)
        self.sentences = self.sentences[:self.num_batch * self.batch_size]
        self.labels = self.labels[:self.num_batch * self.batch_size]
        self.sentences_batches = np.split(self.sentences, self.num_batch, 0)
        self.labels_batches = np.split(self.labels, self.num_batch, 0)

        self.pointer = 0

    def next_batch(self):
        ret = self.sentences_batches[self.pointer], self.labels_batches[self.pointer]
        self.pointer = (self.pointer + 1) % self.num_batch
        return ret

    def reset_pointer(self):
        self.pointer = 0



#### Generator

In [19]:
class Generator(nn.Module):

    def __init__(self, num_emb, batch_size, emb_dim, hidden_dim, 
                 sequence_length, start_token, learning_rate=0.01, reward_gamma=0.95):
        super(Generator, self).__init__()
        
        self.num_emb = num_emb  # Vocabulary size
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = start_token
        self.learning_rate = learning_rate
        self.reward_gamma = reward_gamma
        self.temperature = 1.0
        
        # Initialize embeddings
        self.embeddings = nn.Embedding(num_emb, emb_dim)
        
        # LSTM cell
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        
        # Output layer: maps hidden state to vocabulary distribution
        self.output_layer = nn.Linear(hidden_dim, num_emb)
        
        # Optimizer
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        
    def init_hidden(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
            
        h = torch.zeros(1, batch_size, self.hidden_dim).to(next(self.parameters()).device)
        c = torch.zeros(1, batch_size, self.hidden_dim).to(next(self.parameters()).device)
        return (h, c)
    
    def forward(self, x, hidden=None):
        """
        Forward pass for the generator.
        
        Args:
            x: Input tensor of token indices [batch_size, seq_len]
            hidden: Initial hidden state tuple (h, c)
            
        Returns:
            logits: Output logits [batch_size, seq_len, vocab_size]
            hidden: Final hidden state
        """
        if hidden is None:
            hidden = self.init_hidden()
            
        # Get embeddings for input tokens
        emb = self.embeddings(x)  # [batch_size, seq_len, emb_dim]
        
        # Process through LSTM
        output, hidden = self.lstm(emb, hidden)  # [batch_size, seq_len, hidden_dim]
        
        # Map to vocabulary space
        logits = self.output_layer(output)  # [batch_size, seq_len, num_emb]
        
        return logits, hidden
    
    def sample(self, num_samples=None, hidden=None):
        """
        Sample a batch of sequences from the generator.
        
        Args:
            num_samples: Number of samples to generate (defaults to batch_size)
            hidden: Initial hidden state
            
        Returns:
            generated_samples: Tensor of token indices [batch_size, seq_len]
        """
        if num_samples is None:
            num_samples = self.batch_size
            
        if hidden is None:
            hidden = self.init_hidden(num_samples)
            
        with torch.no_grad():
            device = next(self.parameters()).device
            # Start with start token for each sequence
            x = torch.full((num_samples,), self.start_token, dtype=torch.long).to(device)
            
            # Store generated tokens
            generated_samples = torch.zeros(num_samples, self.sequence_length, dtype=torch.long).to(device)
            
            # Generate tokens one at a time
            for i in range(self.sequence_length):
                # Get embeddings for current token
                emb = self.embeddings(x).unsqueeze(1)  # [batch_size, 1, emb_dim]
                
                # Process through LSTM
                output, hidden = self.lstm(emb, hidden)  # [batch_size, 1, hidden_dim]
                
                # Get logits
                logits = self.output_layer(output.squeeze(1))  # [batch_size, num_emb]
                
                # Sample from the distribution
                probs = F.softmax(logits / self.temperature, dim=-1)
                x = torch.multinomial(probs, 1).squeeze()
                
                # Store the generated token
                generated_samples[:, i] = x
                
            return generated_samples
    
    def pretrain_step(self, x):
        """
        Perform one step of maximum likelihood pretraining.
        
        Args:
            x: Input tensor of token indices [batch_size, seq_len]
            
        Returns:
            loss: The cross-entropy loss
        """
        self.train()
        self.optimizer.zero_grad()
        
        # Prepare inputs and targets
        # Input: all tokens except the last one
        inputs = x[:, :-1]
        # Target: all tokens except the first one (which is <start>)
        targets = x[:, 1:]
        
        # Forward pass
        logits, _ = self.forward(inputs)
        
        # Compute loss
        loss = F.cross_entropy(logits.reshape(-1, self.num_emb), targets.reshape(-1))
        
        # Backward pass and optimize
        loss.backward()
        self.optimizer.step()
        
        return loss.item()
    
    def adversarial_loss(self, x, rewards):
        """
        Compute the policy gradient loss for adversarial training.
        
        Args:
            x: Input tensor of token indices [batch_size, seq_len]
            rewards: Rewards for each token [batch_size, seq_len]
            
        Returns:
            loss: The policy gradient loss
        """
        self.train()
        self.optimizer.zero_grad()
        
        # Prepare inputs and targets
        # Input: all tokens except the last one
        inputs = x[:, :-1]
        # Target: all tokens except the first one (which is <start>)
        targets = x[:, 1:]
        # Rewards: align with targets
        rewards = rewards[:, 1:]
        
        # Forward pass
        logits, _ = self.forward(inputs)
        
        # Compute log probabilities
        log_probs = F.log_softmax(logits, dim=-1)
        
        # One-hot encode targets
        target_one_hot = F.one_hot(targets, num_classes=self.num_emb).float()
        
        # Compute token-level rewards
        token_rewards = torch.sum(log_probs * target_one_hot, dim=-1) * rewards
        
        # Policy gradient loss (negative expected reward)
        loss = -torch.mean(token_rewards)
        
        # Backward pass and optimize
        loss.backward()
        self.optimizer.step()
        
        return loss.item()




#### Discriminator

In [28]:
class Highway(nn.Module):

    def __init__(self, size, num_layers=1, bias=-2.0):
        super(Highway, self).__init__()
        self.num_layers = num_layers
        self.bias = bias
        
        self.highways = nn.ModuleList([
            nn.ModuleDict({
                'transform': nn.Linear(size, size),
                'gate': nn.Linear(size, size)
            })
            for _ in range(num_layers)
        ])
    
    def forward(self, x):
        for i in range(self.num_layers):
            transform = F.relu(self.highways[i]['transform'](x))
            gate = torch.sigmoid(self.highways[i]['gate'](x) + self.bias)
            x = gate * transform + (1.0 - gate) * x
        return x


class Discriminator(nn.Module):

    def __init__(self, sequence_length, num_classes, vocab_size, 
                 embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0, dropout_keep_prob=0.75):
        super(Discriminator, self).__init__()
        
        self.sequence_length = sequence_length
        self.num_classes = num_classes
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.filter_sizes = filter_sizes
        self.num_filters = num_filters
        self.l2_reg_lambda = l2_reg_lambda
        self.dropout_keep_prob = dropout_keep_prob
        
        # Initialize embeddings
        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        
        # Create multiple convolutional layers with different filter sizes
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filter, [filter_size, embedding_size], padding=(0, 0))
            for filter_size, num_filter in zip(filter_sizes, num_filters)
        ])
        
        # Highway layer
        self.highway = Highway(sum(num_filters), num_layers=1, bias=0)
        
        # Dropout layer
        self.dropout = nn.Dropout(1.0 - dropout_keep_prob)
        
        # Final output layer
        self.fc = nn.Linear(sum(num_filters), num_classes)
        
        # Optimizer
        self.optimizer = optim.Adam(self.parameters(), lr=1e-4, weight_decay=l2_reg_lambda)
    
    def forward(self, x):
        """
        Forward pass for the discriminator.
        
        Args:
            x: Input tensor of token indices [batch_size, seq_len]
            
        Returns:
            logits: Output logits [batch_size, num_classes]
            probs: Output probabilities [batch_size, num_classes]
        """
        # Get embeddings for input tokens
        emb = self.embeddings(x)  # [batch_size, seq_len, emb_dim]
        
        # Add channel dimension for CNN
        emb = emb.unsqueeze(1)  # [batch_size, 1, seq_len, emb_dim]
        
        # Apply convolutions and max-pooling
        conv_outputs = []
        for i, conv in enumerate(self.convs):
            # Convolution
            h = F.relu(conv(emb))  # [batch_size, num_filters[i], seq_len-filter_sizes[i]+1, 1]
            
            # Max-pooling over time
            pooled = F.max_pool2d(h, (h.size(2), 1))  # [batch_size, num_filters[i], 1, 1]
            
            # Flatten
            pooled = pooled.squeeze(3).squeeze(2)  # [batch_size, num_filters[i]]
            
            conv_outputs.append(pooled)
        
        # Concatenate all conv outputs
        h_pool_flat = torch.cat(conv_outputs, dim=1)  # [batch_size, sum(num_filters)]
        
        # Apply highway network
        h_highway = self.highway(h_pool_flat)
        
        # Apply dropout
        h_drop = self.dropout(h_highway)
        
        # Final fully connected layer
        logits = self.fc(h_drop)
        probs = F.softmax(logits, dim=1)
        
        return logits, probs
    
    def train_step(self, x, y):
        """
        Perform one training step.
        
        Args:
            x: Input tensor of token indices [batch_size, seq_len]
            y: Target tensor of labels [batch_size, num_classes]
            
        Returns:
            loss: The cross-entropy loss
            probs: The output probabilities
        """
        self.train()
        self.optimizer.zero_grad()
        
        # Forward pass
        logits, probs = self.forward(x)
        
        # Compute loss
        loss = F.cross_entropy(logits, torch.argmax(y, dim=1))
        
        # Add L2 regularization
        if self.l2_reg_lambda > 0:
            l2_loss = 0
            for param in self.parameters():
                l2_loss += torch.norm(param, 2)
            loss += self.l2_reg_lambda * l2_loss
        
        # Backward pass and optimize
        loss.backward()
        self.optimizer.step()
        
        return loss.item(), probs.detach()  # Return detached probability tensor



#### Rollout

In [29]:
class Rollout(nn.Module):
    """
    Rollout module for Monte Carlo search in SeqGAN.
    This creates a copy of the generator and performs rollouts to estimate rewards.
    """
    def __init__(self, generator, update_rate):
        super(Rollout, self).__init__()
        
        self.generator = generator
        self.update_rate = update_rate
        
        # Copy parameters from the generator
        self.own_generator = Generator(
            generator.num_emb,
            generator.batch_size,
            generator.emb_dim,
            generator.hidden_dim,
            generator.sequence_length,
            generator.start_token,
            generator.learning_rate,
            generator.reward_gamma
        )
        
        # Copy all parameters from original generator
        self.update_params()
    
    def update_params(self):
        """
        Update the parameters of the rollout generator using the original generator.
        In the original TF implementation, this was done with an update rate.
        In PyTorch, we simply copy all parameters.
        """
        for target_param, source_param in zip(self.own_generator.parameters(), self.generator.parameters()):
            target_param.data.copy_(
                self.update_rate * target_param.data + (1.0 - self.update_rate) * source_param.data
            )
    
    def get_reward(self, x, rollout_num, discriminator):
        """
        Get reward for each token in the sequences using rollouts with the discriminator.
        
        Args:
            x: Input sequences [batch_size, seq_len]
            rollout_num: Number of rollouts to perform
            discriminator: The discriminator model
            
        Returns:
            rewards: Reward for each token in each sequence [batch_size, seq_len]
        """
        batch_size = x.size(0)
        seq_len = self.generator.sequence_length
        rewards = torch.zeros(batch_size, seq_len).to(x.device)
        
        with torch.no_grad():
            discriminator.eval()
            self.own_generator.eval()
            
            # Evaluate final reward (complete sequences) using discriminator
            logits, probs = discriminator(x)
            rewards[:, -1] = probs[:, 1]  # Probability of being real for completed sequence
            
            # For each position in the sequence
            for given_num in range(1, seq_len):
                # Store intermediate rollout rewards
                position_rewards = torch.zeros(batch_size).to(x.device)
                
                # Perform multiple rollouts
                for _ in range(rollout_num):
                    # Continue sequence generation from this position
                    rollout_samples = self._rollout_from_position(x, given_num)
                    
                    # Get reward from discriminator for the rolled-out sequences
                    logits, probs = discriminator(rollout_samples)
                    position_rewards += probs[:, 1]  # Accumulate probability of being real
                
                # Average rewards across rollouts
                rewards[:, given_num-1] = position_rewards / rollout_num
            
            # Apply reward discount
            if self.generator.reward_gamma < 1.0:
                for i in range(seq_len - 1):
                    rewards[:, i] = rewards[:, i] * self.generator.reward_gamma
            
        return rewards
    
    def _rollout_from_position(self, x, given_num):
        """
        Generate rollout sequences starting from a given position.
        
        Args:
            x: Input sequences [batch_size, seq_len]
            given_num: Position up to which to use original sequence
            
        Returns:
            rollout_samples: Rolled-out sequences [batch_size, seq_len]
        """
        batch_size = x.size(0)
        
        # Create a copy of the input sequence
        rollout_samples = x.clone()
        
        # Initialize hidden state by running the prefix through the LSTM
        hidden = self.own_generator.init_hidden(batch_size)
        
        # First run the prefix through LSTM to get the hidden state at given_num
        prefix = x[:, :given_num]
        prefix_emb = self.own_generator.embeddings(prefix)
        _, hidden = self.own_generator.lstm(prefix_emb, hidden)
        
        # Then generate tokens from given_num to the end
        current = x[:, given_num-1]
        
        for i in range(given_num, self.generator.sequence_length):
            # Get embeddings for current token
            emb = self.own_generator.embeddings(current).unsqueeze(1)  # [batch_size, 1, emb_dim]
            
            # Process through LSTM
            output, hidden = self.own_generator.lstm(emb, hidden)  # [batch_size, 1, hidden_dim]
            
            # Get logits
            logits = self.own_generator.output_layer(output.squeeze(1))  # [batch_size, num_emb]
            
            # Sample from the distribution
            probs = F.softmax(logits / self.own_generator.temperature, dim=-1)
            next_token = torch.multinomial(probs, 1).squeeze()
            
            # Store the generated token
            rollout_samples[:, i] = next_token
            current = next_token
        
        return rollout_samples


#### Target LSTM

In [30]:
class TargetLSTM(nn.Module):
    """
    Oracle LSTM model used for synthetic data experiments.
    This model should be pre-trained and fixed during SeqGAN training.
    """
    def __init__(self, num_emb, batch_size, emb_dim, hidden_dim, sequence_length, start_token):
        super(TargetLSTM, self).__init__()
        
        self.num_emb = num_emb
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = start_token
        self.temperature = 1.0
        
        # Initialize embeddings
        self.embeddings = nn.Embedding(num_emb, emb_dim)
        
        # LSTM cell
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        
        # Output layer: maps hidden state to vocabulary distribution
        self.output_layer = nn.Linear(hidden_dim, num_emb)
    
    def init_hidden(self, batch_size=None):
        """Initialize hidden state and cell state for LSTM"""
        if batch_size is None:
            batch_size = self.batch_size
            
        h = torch.zeros(1, batch_size, self.hidden_dim).to(next(self.parameters()).device)
        c = torch.zeros(1, batch_size, self.hidden_dim).to(next(self.parameters()).device)
        return (h, c)
    
    def forward(self, x, hidden=None):
        """
        Forward pass for the target LSTM.
        
        Args:
            x: Input tensor of token indices [batch_size, seq_len]
            hidden: Initial hidden state tuple (h, c)
            
        Returns:
            logits: Output logits [batch_size, seq_len, vocab_size]
            hidden: Final hidden state
        """
        if hidden is None:
            hidden = self.init_hidden(x.size(0))
            
        # Get embeddings for input tokens
        emb = self.embeddings(x)  # [batch_size, seq_len, emb_dim]
        
        # Process through LSTM
        output, hidden = self.lstm(emb, hidden)  # [batch_size, seq_len, hidden_dim]
        
        # Map to vocabulary space
        logits = self.output_layer(output)  # [batch_size, seq_len, num_emb]
        
        return logits, hidden
    
    def sample(self, num_samples=None, hidden=None):
        """
        Sample a batch of sequences from the target LSTM.
        
        Args:
            num_samples: Number of samples to generate (defaults to batch_size)
            hidden: Initial hidden state
            
        Returns:
            generated_samples: Tensor of token indices [batch_size, seq_len]
        """
        if num_samples is None:
            num_samples = self.batch_size
            
        if hidden is None:
            hidden = self.init_hidden(num_samples)
            
        with torch.no_grad():
            device = next(self.parameters()).device
            # Start with start token for each sequence
            x = torch.full((num_samples,), self.start_token, dtype=torch.long).to(device)
            
            # Store generated tokens
            generated_samples = torch.zeros(num_samples, self.sequence_length, dtype=torch.long).to(device)
            
            # Generate tokens one at a time
            for i in range(self.sequence_length):
                # Get embeddings for current token
                emb = self.embeddings(x).unsqueeze(1)  # [batch_size, 1, emb_dim]
                
                # Process through LSTM
                output, hidden = self.lstm(emb, hidden)  # [batch_size, 1, hidden_dim]
                
                # Get logits
                logits = self.output_layer(output.squeeze(1))  # [batch_size, num_emb]
                
                # Sample from the distribution
                probs = F.softmax(logits / self.temperature, dim=-1)
                x = torch.multinomial(probs, 1).squeeze()
                
                # Store the generated token
                generated_samples[:, i] = x
                
            return generated_samples
    
    def pretrain_loss(self, x):
        """
        Calculate the negative log-likelihood loss for the target LSTM.
        
        Args:
            x: Input tensor of token indices [batch_size, seq_len]
            
        Returns:
            loss: The cross-entropy loss
        """
        # Prepare inputs and targets
        # Input: all tokens except the last one
        inputs = x[:, :-1]
        # Target: all tokens except the first one (which is <start>)
        targets = x[:, 1:]
        
        # Forward pass
        logits, _ = self.forward(inputs)
        
        # Compute loss
        loss = F.cross_entropy(logits.reshape(-1, self.num_emb), targets.reshape(-1))
        
        return loss.item()

    def init_model(self, params=None):
        """
        Initialize the model with specific parameters to create a fixed oracle.
        
        Args:
            params: Dictionary of pre-defined weights or list of tensors
        """
        if params is None:
            return
            
        if isinstance(params, dict):
            # If params is a dictionary with the expected structure
            if 'emb' in params:
                self.embeddings.weight.data = params['emb']
                
            if 'lstm' in params:
                if isinstance(params['lstm'], dict):
                    # Modern format with named parameters
                    for layer, param_dict in params['lstm'].items():
                        if isinstance(param_dict, dict):
                            for param_name, param_tensor in param_dict.items():
                                # Set the appropriate LSTM parameters
                                getattr(self.lstm, param_name + '_' + layer).data = param_tensor
                        else:
                            # Direct tensor assignment
                            setattr(self.lstm, layer, param_dict)
                else:
                    # Old format or different structure
                    print("Unrecognized LSTM parameter format.")
                    
            if 'out_w' in params:
                self.output_layer.weight.data = params['out_w']
                
            if 'out_b' in params:
                self.output_layer.bias.data = params['out_b']
        elif isinstance(params, list):
            # If params is a list of tensors (original TF implementation format)
            try:
                # Embeddings
                self.embeddings.weight.data = params[0]
                
                # LSTM parameters - map to PyTorch's format
                # This is a simplified mapping and may need adjustments
                ih_weights = torch.cat([params[1], params[4], params[7], params[10]], dim=0)
                hh_weights = torch.cat([params[2], params[5], params[8], params[11]], dim=0)
                ih_bias = torch.cat([params[3], params[6], params[9], params[12]], dim=0)
                
                self.lstm.weight_ih_l0.data = ih_weights
                self.lstm.weight_hh_l0.data = hh_weights
                self.lstm.bias_ih_l0.data = ih_bias
                self.lstm.bias_hh_l0.data = torch.zeros_like(ih_bias)
                
                # Output layer
                self.output_layer.weight.data = params[13]
                self.output_layer.bias.data = params[14]
                
                print("Initialized target LSTM from tensor list format.")
            except (IndexError, ValueError) as e:
                print(f"Error initializing from tensor list: {e}")
        else:
            print(f"Unsupported parameter format: {type(params)}")


    

#### SeqGAN


In [31]:
# Helper functions for training and evaluation

def generate_samples(model, batch_size, generated_num, output_file, device):
    """
    Generate samples using the model and save to file.
    
    Args:
        model: Generator or TargetLSTM model
        batch_size: Batch size for generation
        generated_num: Total number of samples to generate
        output_file: File to save the generated samples
        device: Device to run the model on
    """
    model.eval()
    generated_samples = []
    
    with torch.no_grad():
        # Generate batches of samples
        for _ in range(int(generated_num / batch_size)):
            samples = model.sample(batch_size)
            # Convert to numpy and append to list
            generated_samples.extend(samples.cpu().numpy().tolist())
    
    # Write samples to file
    with open(output_file, 'w') as fout:
        for sample in generated_samples:
            buffer = ' '.join([str(x) for x in sample]) + '\n'
            fout.write(buffer)


def target_loss(target_lstm, data_loader, device):
    """
    Calculate the negative log-likelihood of data according to target_lstm.
    
    Args:
        target_lstm: The oracle LSTM model
        data_loader: Data loader containing batches to evaluate
        device: Device to run the model on
        
    Returns:
        avg_loss: Average negative log-likelihood
    """
    target_lstm.eval()
    total_loss = 0.0
    data_loader.reset_pointer()
    
    with torch.no_grad():
        for _ in range(data_loader.num_batch):
            batch = data_loader.next_batch()
            batch_tensor = torch.LongTensor(batch).to(device)
            loss = target_lstm.pretrain_loss(batch_tensor)
            total_loss += loss
    
    avg_loss = total_loss / data_loader.num_batch
    return avg_loss


def pre_train_epoch(model, data_loader, device):
    """
    Pre-train the generator for one epoch using maximum likelihood.
    
    Args:
        model: The generator model
        data_loader: Data loader containing real data batches
        device: Device to run the model on
        
    Returns:
        avg_loss: Average loss for the epoch
    """
    model.train()
    total_loss = 0.0
    data_loader.reset_pointer()
    
    for _ in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        batch_tensor = torch.LongTensor(batch).to(device)
        loss = model.pretrain_step(batch_tensor)
        total_loss += loss
    
    avg_loss = total_loss / data_loader.num_batch
    return avg_loss


def evaluate_discriminator(discriminator, data_loader, device, dropout_keep_prob=0.75):
    """
    Evaluate the discriminator on a dataset without updating weights.
    
    Args:
        discriminator: The discriminator model
        data_loader: Data loader containing evaluation data
        device: Device to run the model on
        dropout_keep_prob: Dropout keep probability
        
    Returns:
        metrics: Dictionary of evaluation metrics
    """
    discriminator.eval()
    data_loader.reset_pointer()
    
    d_loss_sum = 0.0
    real_probs_sum = 0.0
    fake_probs_sum = 0.0
    real_count = 0
    fake_count = 0
    batch_count = 0
    
    with torch.no_grad():
        for _ in range(data_loader.num_batch):
            x_batch, y_batch = data_loader.next_batch()
            x_tensor = torch.LongTensor(x_batch).to(device)
            y_tensor = torch.FloatTensor(y_batch).to(device)
            
            # Forward pass
            logits, probs = discriminator(x_tensor)
            loss = F.cross_entropy(logits, torch.argmax(y_tensor, dim=1))
            
            # Process probabilities
            real_probs = probs[:, 1].cpu().numpy()  # Already detached in torch.no_grad()
            real_indices = np.where(y_batch[:, 1] == 1)[0]  # Indices of real samples
            fake_indices = np.where(y_batch[:, 0] == 1)[0]  # Indices of fake samples
            
            if len(real_indices) > 0:
                real_probs_sum += np.sum(real_probs[real_indices])
                real_count += len(real_indices)
            
            if len(fake_indices) > 0:
                fake_probs_sum += np.sum(real_probs[fake_indices])
                fake_count += len(fake_indices)
            
            d_loss_sum += loss.item()
            batch_count += 1
    
    # Calculate overall metrics
    avg_d_loss = d_loss_sum / batch_count if batch_count > 0 else 0
    avg_real_prob = real_probs_sum / real_count if real_count > 0 else 0
    avg_fake_prob = fake_probs_sum / fake_count if fake_count > 0 else 0
    
    real_accuracy = avg_real_prob  # Average probability assigned to real samples
    fake_accuracy = 1 - avg_fake_prob  # 1 - probability assigned to fake samples
    overall_accuracy = (real_accuracy + fake_accuracy) / 2
    
    return {
        'avg_d_loss': avg_d_loss,
        'avg_real_prob': avg_real_prob,
        'avg_fake_prob': avg_fake_prob,
        'real_accuracy': real_accuracy,
        'fake_accuracy': fake_accuracy,
        'overall_accuracy': overall_accuracy
    }



In [32]:
def main():
    # Configuration parameters
    EMB_DIM = 32  # Embedding dimension
    HIDDEN_DIM = 32  # Hidden state dimension of LSTM cell
    SEQ_LENGTH = 20  # Sequence length
    START_TOKEN = 0
    PRE_EPOCH_NUM = 20  # Supervised pre-training epochs
    SEED = 88
    BATCH_SIZE = 64
    
    # Discriminator parameters
    dis_embedding_dim = 64
    dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
    dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
    dis_dropout_keep_prob = 0.75
    dis_l2_reg_lambda = 0.2
    
    # Training parameters
    TOTAL_BATCH = 200
    generated_num = 10000
    
    # File paths
    os.makedirs('save', exist_ok=True)
    positive_file = 'save/real_data.txt'
    negative_file = 'save/generator_sample.txt'
    eval_file = 'save/eval_file.txt'
    
    # Set random seeds
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    
    # Check if CUDA is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize data loaders
    gen_data_loader = GenDataLoader(BATCH_SIZE)
    likelihood_data_loader = GenDataLoader(BATCH_SIZE)  # For testing
    dis_data_loader = DisDataloader(BATCH_SIZE)
    
    # Vocabulary size
    vocab_size = 5000
    
    # Initialize models
    generator = Generator(
        num_emb=vocab_size,
        batch_size=BATCH_SIZE,
        emb_dim=EMB_DIM,
        hidden_dim=HIDDEN_DIM,
        sequence_length=SEQ_LENGTH,
        start_token=START_TOKEN
    ).to(device)
    
    # Create or load the target LSTM (oracle)
    try:
        print("Attempting to load target parameters...")
        with open('save/target_params.pkl', 'rb') as f:
            try:
                target_params = pickle.load(f)
                print(f"Loaded parameters of type: {type(target_params)}")
                
                # Initialize a fresh target LSTM
                target_lstm = TargetLSTM(
                    num_emb=vocab_size,
                    batch_size=BATCH_SIZE,
                    emb_dim=EMB_DIM,
                    hidden_dim=HIDDEN_DIM,
                    sequence_length=SEQ_LENGTH,
                    start_token=START_TOKEN
                ).to(device)
                
                # Try to initialize with loaded parameters
                try:
                    target_lstm.init_model(target_params)
                    print("Initialized target LSTM with loaded parameters.")
                except Exception as e:
                    print(f"Failed to initialize with loaded parameters: {e}")
                    print("Creating a fresh target LSTM...")
                    target_lstm = TargetLSTM(
                        num_emb=vocab_size,
                        batch_size=BATCH_SIZE,
                        emb_dim=EMB_DIM,
                        hidden_dim=HIDDEN_DIM,
                        sequence_length=SEQ_LENGTH,
                        start_token=START_TOKEN
                    ).to(device)
                    
                    # Save new parameters
                    new_target_params = {
                        'emb': target_lstm.embeddings.weight.data,
                        'lstm': {
                            'weight_ih_l0': target_lstm.lstm.weight_ih_l0.data,
                            'weight_hh_l0': target_lstm.lstm.weight_hh_l0.data,
                            'bias_ih_l0': target_lstm.lstm.bias_ih_l0.data,
                            'bias_hh_l0': target_lstm.lstm.bias_hh_l0.data
                        },
                        'out_w': target_lstm.output_layer.weight.data,
                        'out_b': target_lstm.output_layer.bias.data
                    }
                    
                    with open('save/target_params.pkl', 'wb') as f:
                        pickle.dump(new_target_params, f)
                    print("Created and saved new target parameters.")
                    
            except (UnicodeDecodeError, pickle.UnpicklingError) as e:
                print(f"Error unpickling parameters: {e}")
                print("Creating a fresh target LSTM...")
                target_lstm = TargetLSTM(
                    num_emb=vocab_size,
                    batch_size=BATCH_SIZE,
                    emb_dim=EMB_DIM,
                    hidden_dim=HIDDEN_DIM,
                    sequence_length=SEQ_LENGTH,
                    start_token=START_TOKEN
                ).to(device)
                
                # Save new parameters
                new_target_params = {
                    'emb': target_lstm.embeddings.weight.data,
                    'lstm': {
                        'weight_ih_l0': target_lstm.lstm.weight_ih_l0.data,
                        'weight_hh_l0': target_lstm.lstm.weight_hh_l0.data,
                        'bias_ih_l0': target_lstm.lstm.bias_ih_l0.data,
                        'bias_hh_l0': target_lstm.lstm.bias_hh_l0.data
                    },
                    'out_w': target_lstm.output_layer.weight.data,
                    'out_b': target_lstm.output_layer.bias.data
                }
                
                with open('save/target_params.pkl', 'wb') as f:
                    pickle.dump(new_target_params, f)
                print("Created and saved new target parameters.")
    except FileNotFoundError:
        print("Target parameters file not found. Creating a new target LSTM...")
        target_lstm = TargetLSTM(
            num_emb=vocab_size,
            batch_size=BATCH_SIZE,
            emb_dim=EMB_DIM,
            hidden_dim=HIDDEN_DIM,
            sequence_length=SEQ_LENGTH,
            start_token=START_TOKEN
        ).to(device)
        
        # Save the parameters for future use
        new_target_params = {
            'emb': target_lstm.embeddings.weight.data,
            'lstm': {
                'weight_ih_l0': target_lstm.lstm.weight_ih_l0.data,
                'weight_hh_l0': target_lstm.lstm.weight_hh_l0.data,
                'bias_ih_l0': target_lstm.lstm.bias_ih_l0.data,
                'bias_hh_l0': target_lstm.lstm.bias_hh_l0.data
            },
            'out_w': target_lstm.output_layer.weight.data,
            'out_b': target_lstm.output_layer.bias.data
        }
        
        os.makedirs('save', exist_ok=True)
        with open('save/target_params.pkl', 'wb') as f:
            pickle.dump(new_target_params, f)
        print("Created and saved new target parameters.")

        # Initialize discriminator
        try:
            discriminator = Discriminator(
                sequence_length=SEQ_LENGTH,
                num_classes=2,
                vocab_size=vocab_size,
                embedding_size=dis_embedding_dim,
                filter_sizes=dis_filter_sizes,
                num_filters=dis_num_filters,
                l2_reg_lambda=dis_l2_reg_lambda,
                dropout_keep_prob=dis_dropout_keep_prob
            ).to(device)
            print("Initialized discriminator successfully.")
        except Exception as e:
            print(f"Error initializing discriminator: {e}")
            raise  # Re-raise the exception to see the full error
        
        # Open log files
        log = open('save/experiment-log.txt', 'w')
        metrics_log = open('save/training-metrics.txt', 'w')
        metrics_log.write('phase\tepoch\tg_loss\tpre_d_loss\tpre_real_prob\tpre_fake_prob\tpre_accuracy\tpost_d_loss\tpost_real_prob\tpost_fake_prob\tpost_accuracy\n')
        
        # First, use the oracle model to generate real data
        generate_samples(target_lstm, BATCH_SIZE, generated_num, positive_file, device)
        gen_data_loader.create_batches(positive_file)
        
        # Pre-train the generator using MLE
        print('Start pre-training generator...')
        log.write('pre-training generator...\n')
        for epoch in range(PRE_EPOCH_NUM):
            loss = pre_train_epoch(generator, gen_data_loader, device)
            
            if epoch % 5 == 0 or epoch == PRE_EPOCH_NUM - 1:
                generate_samples(generator, BATCH_SIZE, generated_num, eval_file, device)
                likelihood_data_loader.create_batches(eval_file)
                test_loss = target_loss(target_lstm, likelihood_data_loader, device)
                print(f'Pre-train epoch {epoch}, loss: {loss:.4f}, test_loss: {test_loss:.4f}')
                buffer = f'epoch:\t{epoch}\tnll:\t{test_loss}\n'
                log.write(buffer)
                log.flush()
    
    # Pre-train the discriminator
    print('Start pre-training discriminator...')
    for d_pre_epoch in range(50):
        generate_samples(generator, BATCH_SIZE, generated_num, negative_file, device)
        dis_data_loader.load_train_data(positive_file, negative_file)
        
        for inner_epoch in range(3):
            dis_data_loader.reset_pointer()
            
            # Metrics tracking
            d_loss_sum = 0.0
            real_probs_sum = 0.0
            fake_probs_sum = 0.0
            real_count = 0
            fake_count = 0
            batch_count = 0
            
            for _ in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                x_tensor = torch.LongTensor(x_batch).to(device)
                y_tensor = torch.FloatTensor(y_batch).to(device)
                
                # Train discriminator
                discriminator.train()
                loss, probs = discriminator.train_step(x_tensor, y_tensor)
                
                # Extract probabilities of being real
                real_probs = probs[:, 1].detach().cpu().numpy()
                real_indices = np.where(y_batch[:, 1] == 1)[0]
                fake_indices = np.where(y_batch[:, 0] == 1)[0]
                
                if len(real_indices) > 0:
                    real_probs_sum += np.sum(real_probs[real_indices])
                    real_count += len(real_indices)
                
                if len(fake_indices) > 0:
                    fake_probs_sum += np.sum(real_probs[fake_indices])
                    fake_count += len(fake_indices)
                
                d_loss_sum += loss
                batch_count += 1
            
            # Calculate metrics
            if batch_count > 0:
                avg_d_loss = d_loss_sum / batch_count
                avg_real_prob = real_probs_sum / real_count if real_count > 0 else 0
                avg_fake_prob = fake_probs_sum / fake_count if fake_count > 0 else 0
                real_accuracy = avg_real_prob
                fake_accuracy = 1 - avg_fake_prob
                overall_accuracy = (real_accuracy + fake_accuracy) / 2
                
                # Log metrics
                metrics_log.write(f'pre-train-d-epoch\t{d_pre_epoch}\t{inner_epoch}\t-\t{avg_d_loss:.4f}\t{avg_real_prob:.4f}\t{avg_fake_prob:.4f}\t{overall_accuracy:.4f}\n')
                metrics_log.flush()
                
                # Print progress less frequently
                if d_pre_epoch % 10 == 0 and inner_epoch == 2:
                    print(f'Pre-train D epoch {d_pre_epoch}, inner epoch {inner_epoch}, loss: {avg_d_loss:.4f}')
                    print(f'  Avg prob for real samples: {avg_real_prob:.4f}, fake samples: {avg_fake_prob:.4f}')
                    print(f'  Real accuracy: {real_accuracy:.4f}, Fake accuracy: {fake_accuracy:.4f}, Overall: {overall_accuracy:.4f}')
    
    # Create rollout module
    rollout = Rollout(generator, 0.8).to(device)
    
    print('#' * 80)
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        g_loss_sum = 0.0
        
        # Generate samples and get rewards
        samples = generator.sample(BATCH_SIZE)
        rewards = rollout.get_reward(samples, 16, discriminator)
        
        # Adversarial training step
        generator.train()
        g_loss = generator.adversarial_loss(samples, rewards)
        g_loss_sum += g_loss
        
        avg_g_loss = g_loss_sum  # Since we only did one iteration
        
        # MEASURE POINT 1: Evaluate discriminator on new generator output before updating discriminator
        generate_samples(generator, BATCH_SIZE, generated_num, negative_file, device)
        dis_eval_loader = DisDataloader(BATCH_SIZE)
        dis_eval_loader.load_train_data(positive_file, negative_file)
        
        pre_update_metrics = evaluate_discriminator(discriminator, dis_eval_loader, device, dis_dropout_keep_prob)
        
        # Test generator against the oracle model
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(generator, BATCH_SIZE, generated_num, eval_file, device)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(target_lstm, likelihood_data_loader, device)
            buffer = f'epoch:\t{total_batch}\tnll:\t{test_loss}\n'
            print(f'Adversarial training epoch {total_batch}, test_loss: {test_loss:.4f}')
            log.write(buffer)
            log.flush()
        
        # Update rollout parameters
        rollout.update_params()
        
        # Train the discriminator
        d_loss_sum = 0.0
        
        for _ in range(5):  # Train discriminator for 5 rounds
            generate_samples(generator, BATCH_SIZE, generated_num, negative_file, device)
            dis_data_loader.load_train_data(positive_file, negative_file)
            
            for inner_epoch in range(3):  # Train 3 epochs each round
                dis_data_loader.reset_pointer()
                
                for _ in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    x_tensor = torch.LongTensor(x_batch).to(device)
                    y_tensor = torch.FloatTensor(y_batch).to(device)
                    
                    # Train discriminator
                    discriminator.train()
                    loss, _ = discriminator.train_step(x_tensor, y_tensor)
                    d_loss_sum += loss
        
        # MEASURE POINT 2: Evaluate discriminator after it has been updated
        dis_eval_loader.reset_pointer()
        post_update_metrics = evaluate_discriminator(discriminator, dis_eval_loader, device, dis_dropout_keep_prob)
        
        # Log metrics
        metrics_log.write(f'adv-train-epoch\t{total_batch}\t{avg_g_loss:.4f}\t{pre_update_metrics["avg_d_loss"]:.4f}\t'
                          f'{pre_update_metrics["avg_real_prob"]:.4f}\t{pre_update_metrics["avg_fake_prob"]:.4f}\t'
                          f'{pre_update_metrics["overall_accuracy"]:.4f}\t{post_update_metrics["avg_d_loss"]:.4f}\t'
                          f'{post_update_metrics["avg_real_prob"]:.4f}\t{post_update_metrics["avg_fake_prob"]:.4f}\t'
                          f'{post_update_metrics["overall_accuracy"]:.4f}\n')
        metrics_log.flush()
    
    # Close log files
    log.close()
    metrics_log.close()
    
    print("Training finished!")
    
    # Save the trained models
    torch.save(generator.state_dict(), 'save/generator.pt')
    torch.save(discriminator.state_dict(), 'save/discriminator.pt')
    
    # Visualize training progress
    plot_training_metrics('save/training-metrics.txt')

def plot_training_metrics(metrics_file):
    """
    Plot training metrics from the log file.
    
    Args:
        metrics_file: Path to the metrics log file
    """
    # Read metrics
    df = pd.read_csv(metrics_file, sep='\t')
    
    # Plot discriminator metrics during pre-training
    pre_train_df = df[df['phase'] == 'pre-train-d-epoch']
    
    plt.figure(figsize=(12, 8))
    
    plt.subplot(2, 2, 1)
    plt.plot(pre_train_df['epoch'], pre_train_df['pre_d_loss'])
    plt.title('Pre-training Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    
    plt.subplot(2, 2, 2)
    plt.plot(pre_train_df['epoch'], pre_train_df['pre_real_prob'], label='Real')
    plt.plot(pre_train_df['epoch'], pre_train_df['pre_fake_prob'], label='Fake')
    plt.title('Pre-training Discriminator Probabilities')
    plt.xlabel('Epoch')
    plt.ylabel('Probability')
    plt.legend()
    
    # Plot generator and discriminator metrics during adversarial training
    adv_train_df = df[df['phase'] == 'adv-train-epoch']
    
    plt.subplot(2, 2, 3)
    plt.plot(adv_train_df['epoch'], adv_train_df['g_loss'])
    plt.title('Adversarial Training Generator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    
    plt.subplot(2, 2, 4)
    plt.plot(adv_train_df['epoch'], adv_train_df['pre_accuracy'], label='Pre-update')
    plt.plot(adv_train_df['epoch'], adv_train_df['post_accuracy'], label='Post-update')
    plt.title('Discriminator Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('save/training_metrics.png')
    plt.close()

if __name__ == '__main__':
    main()

Using device: cpu
Attempting to load target parameters...
Loaded parameters of type: <class 'dict'>
Failed to initialize with loaded parameters: cannot assign 'torch.FloatTensor' as parameter 'weight_ih_l0' (torch.nn.Parameter or None expected)
Creating a fresh target LSTM...
Created and saved new target parameters.
Start pre-training discriminator...


UnboundLocalError: cannot access local variable 'discriminator' where it is not associated with a value