#### Packages

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class OTHER(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim, sequence_length, start_token, batch_size, device='cpu'):
        
        super(OTHER, self).__init__()
        
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = start_token
        self.batch_size = batch_size
        self.device = device
        
        # Define layers
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)
        
        # Initialize on device
        self.to(device)
        
    def forward(self, x, hidden=None):

        emb = self.embeddings(x)                    # [batch_size, sequence_length, embedding_dim]
        lstm_out, hidden = self.lstm(emb, hidden)   # lstm_out: [batch_size, sequence_length, hidden_dim]
        logits = self.output_layer(lstm_out)        # [batch_size, sequence_length, vocab_size]
        
        return logits, hidden
    
    def generate(self, num_samples):

        with torch.no_grad():
            
            # Start token for all sequences
            x = torch.full((num_samples, 1), self.start_token, dtype=torch.long, device=self.device)
            hidden = None  # Let PyTorch initialize the hidden state

            generated_sequences = torch.zeros(num_samples, self.sequence_length, dtype=torch.long, device=self.device)

            for i in range(self.sequence_length):
                # Forward pass
                emb = self.embeddings(x[:, -1:])  # Only use the last token
                lstm_out, hidden = self.lstm(emb, hidden)
                logits = self.output_layer(lstm_out)
                
                # Sample from distribution
                probs = F.softmax(logits.squeeze(1), dim=-1)
                next_token = torch.multinomial(probs, 1)
                
                # Add to sequence
                generated_sequences[:, i] = next_token.squeeze()
                
                # Update input for next step (only need the current token, not the entire history)
                x = next_token
            
            return generated_sequences
            
    def calculate_nll(self, generated_sequences):

        with torch.no_grad():
            # Use all tokens except the last one as input
            inputs = generated_sequences[:, :-1]
            
            # Use all tokens except the first one as targets
            targets = generated_sequences[:, 1:]
            
            # Forward pass
            logits, _ = self.forward(inputs)
            
            # Calculate negative log-likelihood
            nll = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1), reduction='mean')
            
            return nll.item()
            
    def save_params(self, path):
        torch.save(self.state_dict(), path)
        
    def load_params(self, path):
        self.load_state_dict(torch.load(path))
        
    def save_samples(self, samples, file_path):
        with open(file_path, 'w') as f:
            for sample in samples.cpu().numpy():
                f.write(' '.join([str(int(x)) for x in sample]) + '\n')
    

In [3]:
class TargetLSTM(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim, sequence_length, start_token, batch_size, device='cpu'):

        np.random.seed(66)
        torch.manual_seed(66)
        
        super(TargetLSTM, self).__init__()
        
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = start_token
        self.batch_size = batch_size
        self.device = device
        
        # Define layers
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

        for param in self.parameters():
            nn.init.normal_(param, mean=0.0, std=0.1)
        
        # Initialize on device
        self.to(device)
        
    def forward(self, x, hidden=None):

        emb = self.embeddings(x)                    # [batch_size, sequence_length, embedding_dim]
        lstm_out, hidden = self.lstm(emb, hidden)   # lstm_out: [batch_size, sequence_length, hidden_dim]
        logits = self.output_layer(lstm_out)        # [batch_size, sequence_length, vocab_size]
        
        return logits, hidden
    
    def generate(self, num_samples):

        with torch.no_grad():
            
            # Start token for all sequences
            x = torch.full((num_samples, 1), self.start_token, dtype=torch.long, device=self.device)
            hidden = None  # Let PyTorch initialize the hidden state

            generated_sequences = torch.zeros(num_samples, self.sequence_length, dtype=torch.long, device=self.device)

            for i in range(self.sequence_length):
                # Forward pass
                emb = self.embeddings(x[:, -1:])  # Only use the last token
                lstm_out, hidden = self.lstm(emb, hidden)
                logits = self.output_layer(lstm_out)
                
                # Sample from distribution
                probs = F.softmax(logits.squeeze(1), dim=-1)
                next_token = torch.multinomial(probs, 1)
                
                # Add to sequence
                generated_sequences[:, i] = next_token.squeeze()
                
                # Update input for next step (only need the current token, not the entire history)
                x = next_token
            
            return generated_sequences
            
    def calculate_nll(self, generated_sequences):

        with torch.no_grad():
            # Use all tokens except the last one as input
            inputs = generated_sequences[:, :-1]
            
            # Use all tokens except the first one as targets
            targets = generated_sequences[:, 1:]
            
            # Forward pass
            logits, _ = self.forward(inputs)
            
            # Calculate negative log-likelihood
            nll = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1), reduction='mean')
            
            return nll.item()
            
    def save_params(self, path):
        torch.save(self.state_dict(), path)
        
    def load_params(self, path):
        self.load_state_dict(torch.load(path))
        
    def save_samples(self, samples, file_path):
        with open(file_path, 'w') as f:
            for sample in samples.cpu().numpy():
                f.write(' '.join([str(int(x)) for x in sample]) + '\n')
    

In [4]:


# Hyperparameters (matching the original implementation)
vocab_size = 5000
embedding_dim = 32
hidden_dim = 32
sequence_length = 20
start_token = 0
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

oracle = TargetLSTM(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    sequence_length=sequence_length,
    start_token=start_token,
    batch_size=batch_size,
    device=device
)

In [5]:
# Hyperparameters (matching the original implementation)
vocab_size = 5000
embedding_dim = 32
hidden_dim = 32
sequence_length = 20
start_token = 0
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

OTHER = TargetLSTM(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    sequence_length=sequence_length,
    start_token=start_token,
    batch_size=batch_size,
    device=device
)

In [6]:
# Generate some sequences using the Oracle
num_samples = 1000
oracle_sequences = oracle.generate(num_samples)
print(f"Generated {num_samples} sequences of length {sequence_length}")
print(f"Sequences shape: {oracle_sequences.shape}")

Generated 1000 sequences of length 20
Sequences shape: torch.Size([1000, 20])


In [7]:
oracle_nll = oracle.calculate_nll(oracle_sequences)
print(f"\nNLL of oracle-generated sequences: {oracle_nll}")


NLL of oracle-generated sequences: 8.512005805969238


In [8]:
random_sequences = OTHER.generate(num_samples)
random_nll = oracle.calculate_nll(random_sequences)
print(f"NLL of random sequences: {random_nll}")
print(f"Difference: {random_nll - oracle_nll}")

NLL of random sequences: 8.511055946350098
Difference: -0.000949859619140625


In [34]:
random_sequences = torch.randint(1, vocab_size, (num_samples, sequence_length), device=device)
random_nll = oracle.calculate_nll(random_sequences)
print(f"NLL of random sequences: {random_nll}")
print(f"Difference: {random_nll - oracle_nll}")

NLL of random sequences: 8.522820472717285
Difference: 0.010814666748046875


In [35]:
repeated_sequences = torch.ones(num_samples, sequence_length, device=device, dtype=torch.long)
for i in range(num_samples):
    repeated_sequences[i] = i % 100 + 1  # Use tokens 1-100, repeating
repeated_nll = oracle.calculate_nll(repeated_sequences)
print(f"NLL of repeated token sequences: {repeated_nll}")

NLL of repeated token sequences: 8.521069526672363
