In [1]:
#-----------------------------Custom Mode------------------------------------#Final---version-----------------------------
import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import nltk
from nltk.tokenize import word_tokenize
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import matplotlib.pyplot as plt
import torch.nn.functional as F
nltk.download('punkt')
import random


# 1. Dataset Preparation
class PoemDataset(Dataset):
    """Next-token prediction pairs built from whitespace-tokenised poem lines.

    Each item is ``(input_ids, target_ids)``: two LongTensors of length
    ``seq_length`` where ``target_ids[i]`` is the token that follows
    ``input_ids[i]``. Unknown words map to ``<unk>``; both tensors are
    right-padded with ``<pad>``.

    Args:
        poems: list of strings, one line per item (tokens separated by spaces).
        vocab: list of tokens; must contain ``"<pad>"`` and ``"<unk>"``.
        seq_length: fixed length of the returned tensors.
        add_start_token: when True (and ``"<start>"`` is in ``vocab``), every
            line is prefixed with ``<start>``. ``Generator.sample`` always begins
            from ``<start>``, so without the prefix the model is never trained on
            what follows it, and real lines would differ from generated ones in
            their first token.
    """

    def __init__(self, poems, vocab, seq_length, add_start_token=True):
        self.poems = poems
        self.vocab = vocab
        self.seq_length = seq_length
        self.word_to_idx = {word: idx for idx, word in enumerate(vocab)}
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        self.add_start_token = add_start_token and "<start>" in self.word_to_idx

    def __len__(self):
        return len(self.poems)

    def __getitem__(self, idx):
        # Keep at most seq_length words; together with the optional <start>
        # prefix this yields up to seq_length input/target positions.
        words = self.poems[idx].split()[: self.seq_length]
        if self.add_start_token:
            words = ["<start>"] + words

        unk_id = self.word_to_idx["<unk>"]
        pad_id = self.word_to_idx["<pad>"]
        token_ids = [self.word_to_idx.get(word, unk_id) for word in words]

        input_ids = token_ids[:-1]
        target_ids = token_ids[1:]

        # Right-pad both sequences to the fixed length so items can be stacked.
        input_ids = input_ids + [pad_id] * (self.seq_length - len(input_ids))
        target_ids = target_ids + [pad_id] * (self.seq_length - len(target_ids))

        return (
            torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(target_ids, dtype=torch.long),
        )

# Batch collation: PoemDataset already pads every item to seq_length, so this only stacks.
def collate_fn(batch):
    """Stack a list of ``(input_ids, target_ids)`` pairs into two batch tensors.

    No padding happens here; all items must already share the same length.
    """
    inputs_seq, targets_seq = zip(*batch)
    return torch.stack(inputs_seq), torch.stack(targets_seq)

# Load and preprocess data
DATA_PATH = r"F:\LeakGan_thesis\poem.txt\poem.txt"  # one poem line per text line

# Seed Python's and torch's RNGs so runs are repeatable. The split below is
# already seeded; model init, shuffling and sampling are not, without this.
random.seed(42)
torch.manual_seed(42)

with open(DATA_PATH, "r", encoding="utf-8") as file:
    poems = file.readlines()

# Lower-case, tokenise with NLTK, and re-join with single spaces; blank lines are dropped.
poems = [' '.join(word_tokenize(line.lower())) for line in poems if line.strip()]
train_data, val_data = train_test_split(poems, test_size=0.1, random_state=42)

# Build vocabulary from the training split only: words that occur only in the
# validation split would get embedding/output rows that are never trained, so
# they are mapped to <unk> instead.
all_words = {word for line in train_data for word in line.split()}
vocab = ["<pad>", "<unk>", "<start>"] + sorted(all_words)  # "<pad>" is index 0
vocab_size = len(vocab)
seq_length = 50

train_dataset = PoemDataset(train_data, vocab, seq_length)
val_dataset = PoemDataset(val_data, vocab, seq_length)

# Use the custom collate_fn in the DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

# 2. Generator (Transformer-based architecture)
class Generator(nn.Module):
    """Autoregressive Transformer language model used as the GAN generator.

    Args:
        vocab_size: number of tokens (size of the output layer).
        embed_size: token embedding / model width; must be divisible by the 8 heads.
        hidden_size: width of the feed-forward sublayer in each Transformer layer.
        word_to_idx: vocabulary mapping; must contain ``"<pad>"``.
        max_len: number of learned positional embeddings, i.e. the longest
            context the model can attend to. Defaults to the module-level
            ``seq_length``.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, word_to_idx, max_len=None):
        super(Generator, self).__init__()
        if max_len is None:
            max_len = seq_length  # module-level setting, as in the original code
        self.max_len = max_len
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Learned positional embeddings, one row per position (sliced to the input length).
        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, embed_size))
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=8,
            dim_feedforward=hidden_size,
            dropout=0.3,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=4)
        self.fc = nn.Linear(embed_size, vocab_size)
        self.word_to_idx = word_to_idx

    def forward(self, x):
        """Return next-token logits ``(batch, length, vocab_size)`` for ids ``x`` ``(batch, length)``.

        A causal mask restricts position ``i`` to tokens ``0..i``. Without it,
        teacher-forced training can read the very token it must predict (the
        target at ``i`` is the input at ``i + 1``).

        Raises:
            ValueError: if ``length`` exceeds ``max_len``.
        """
        length = x.size(1)
        if length > self.max_len:
            raise ValueError(f"sequence length {length} exceeds max_len {self.max_len}")
        embedded = self.embedding(x) + self.positional_encoding[:, :length, :]
        causal_mask = torch.triu(
            torch.full((length, length), float("-inf"), device=x.device), diagonal=1
        )
        transformer_output = self.transformer_encoder(embedded, mask=causal_mask)
        return self.fc(transformer_output)

    def sample(self, start_token, max_length, num_lines=1, temperature=1.0, top_k=50):
        """Sample ``num_lines`` sequences of exactly ``max_length`` token ids.

        Every line starts with ``start_token``. Each next token is drawn with
        temperature-scaled top-k sampling, conditioned on the whole line so far
        (the most recent ``max_len`` tokens). Drawing ``<pad>`` ends a line:
        every later position in that line is ``<pad>``. Runs without autograd
        and restores the module's previous train/eval mode afterwards.

        Returns:
            LongTensor of shape ``(num_lines, max(max_length, 0))``.
        """
        pad_idx = self.word_to_idx["<pad>"]
        was_training = self.training
        self.eval()
        device = self.fc.weight.device
        lines = torch.full((num_lines, 1), start_token, dtype=torch.long, device=device)
        finished = torch.zeros(num_lines, dtype=torch.bool, device=device)
        try:
            with torch.no_grad():
                for _ in range(max_length - 1):
                    context = lines[:, -self.max_len:]  # sliding window over the positional table
                    logits = self(context)[:, -1, :] / temperature
                    k = min(top_k, logits.size(-1))  # top_k may exceed the vocabulary size
                    topk_logits, topk_indices = torch.topk(logits, k, dim=-1)
                    choice = torch.multinomial(torch.softmax(topk_logits, dim=-1), 1)
                    next_tokens = topk_indices.gather(1, choice).squeeze(1)
                    next_tokens = next_tokens.masked_fill(finished, pad_idx)
                    finished |= next_tokens == pad_idx
                    lines = torch.cat([lines, next_tokens.unsqueeze(1)], dim=1)
        finally:
            self.train(was_training)
        return lines[:, :max(max_length, 0)]

# 3. Discriminator (Transformer-based architecture)
class Discriminator(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = nn.Parameter(torch.randn(1, seq_length, embed_size))  # Positional Encoding
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=8, dropout=0.3, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
        self.fc = nn.Linear(embed_size, 1)

    def forward(self, x):
        embedded = self.embedding(x) + self.positional_encoding
        transformer_output = self.transformer_encoder(embedded)
        logits = self.fc(transformer_output[:, -1, :])
        return torch.sigmoid(logits)

# Initialize models (both networks use the same width settings)
generator = Generator(
    vocab_size,
    embed_size=256,
    hidden_size=512,
    word_to_idx=train_dataset.word_to_idx,
)
discriminator = Discriminator(vocab_size, embed_size=256, hidden_size=512)

# Adam with (0.9, 0.98) betas for both networks; the discriminator's learning
# rate was deliberately lowered to half the generator's.
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.98))
d_optimizer = optim.Adam(discriminator.parameters(), lr=5e-5, betas=(0.9, 0.98))

# Loss functions
# criterion_gen: token-level cross-entropy for teacher-forced pretraining.
# NOTE(review): <pad> targets are not ignored, so the model is also trained to
# emit <pad> after a line's last word (sample() treats <pad> specially).
criterion_gen = nn.CrossEntropyLoss()
# criterion_disc: binary cross-entropy on the discriminator's sigmoid output.
criterion_disc = nn.BCELoss()

# 4. Pretrain Generator
# Teacher-forced next-token training, with a validation pass after every epoch
# (val_loader was previously built but never used).
print("Pretraining Generator...")
num_epochs = 3
for epoch in range(num_epochs):
    generator.train()
    total_loss = 0.0
    for inputs, targets in tqdm(train_loader):
        logits = generator(inputs)
        loss = criterion_gen(logits.reshape(-1, vocab_size), targets.reshape(-1))

        g_optimizer.zero_grad()
        loss.backward()
        # Clip to keep early Transformer updates from blowing up.
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        g_optimizer.step()

        total_loss += loss.item()

    generator.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            logits = generator(inputs)
            val_loss += criterion_gen(logits.reshape(-1, vocab_size), targets.reshape(-1)).item()
    val_loss /= max(len(val_loader), 1)

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")

# 5. Pretrain Discriminator
# Real training lines are labelled 1, lines sampled from the generator 0.
print("Pretraining Discriminator...")
start_idx = train_dataset.word_to_idx["<start>"]
for epoch in range(num_epochs):
    discriminator.train()
    total_loss = 0.0
    y_true, y_pred = [], []
    for inputs, _ in tqdm(train_loader):
        batch_size = inputs.size(0)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # sample() always returns exactly max_length columns, so no extra
        # padding/truncation is needed. no_grad avoids building an autograd
        # graph through the generator's (max_length - 1) forward passes.
        with torch.no_grad():
            fake_data = generator.sample(start_token=start_idx, max_length=seq_length, num_lines=batch_size)

        real_preds = discriminator(inputs)
        fake_preds = discriminator(fake_data)

        loss = criterion_disc(real_preds, real_labels) + criterion_disc(fake_preds, fake_labels)

        d_optimizer.zero_grad()
        loss.backward()
        d_optimizer.step()

        total_loss += loss.item()

        # Collect flat 0/1 labels and predictions for the epoch metrics.
        y_true.extend([1] * batch_size + [0] * batch_size)
        probs = torch.cat([real_preds, fake_preds]).detach().view(-1)
        y_pred.extend((probs > 0.5).long().tolist())

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")
    accuracy = accuracy_score(y_true, y_pred)
    # zero_division=0: report 0 instead of warning when no line is predicted "real".
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    print(f"Discriminator Metrics -> Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

# 6. Train the GAN (Adversarial Training)
g_losses = []  # mean generator loss per adversarial epoch
d_losses = []  # mean discriminator loss per adversarial epoch

# Token ids used for sampling and for decoding generated lines (BLEU evaluation)
start_token = train_dataset.word_to_idx["<start>"]
id_to_word = train_dataset.idx_to_word


def decode_poem(poem_tensor):
    """Turn each row of token ids into a space-joined string, skipping ``<pad>``."""
    pad_id = train_dataset.word_to_idx["<pad>"]
    decoded = []
    for line in poem_tensor:
        tokens = [id_to_word[tok.item()] for tok in line if tok.item() != pad_id]
        decoded.append(" ".join(tokens))
    return decoded

gan_epochs = 3
pad_idx = train_dataset.word_to_idx["<pad>"]
print("Adversarial Training...")
for epoch in range(gan_epochs):
    g_loss_epoch = 0.0
    d_loss_epoch = 0.0

    for inputs, _ in tqdm(train_loader):
        batch_size = inputs.size(0)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # --- Train Discriminator: real lines -> 1, sampled lines -> 0 ---
        discriminator.train()
        with torch.no_grad():
            # sample() returns exactly seq_length columns of discrete token ids.
            fake_data = generator.sample(start_token=start_token, max_length=seq_length, num_lines=batch_size)

        real_preds = discriminator(inputs)
        fake_preds = discriminator(fake_data)
        d_loss = criterion_disc(real_preds, real_labels) + criterion_disc(fake_preds, fake_labels)

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        d_loss_epoch += d_loss.item()

        # --- Train Generator with REINFORCE ---
        # Sampled token ids are discrete, so BCE(D(fake), 1) has no gradient
        # path back to the generator (g_optimizer.step() would change nothing).
        # Instead, use D's P(real) as a per-line reward and scale each sampled
        # line's log-likelihood by (reward - batch-mean baseline).
        with torch.no_grad():
            fake_data = generator.sample(start_token=start_token, max_length=seq_length, num_lines=batch_size)
            discriminator.eval()  # deterministic rewards (no dropout)
            rewards = discriminator(fake_data)  # (batch, 1)
            discriminator.train()

        generator.train()
        logits = generator(fake_data)
        targets = fake_data[:, 1:]  # position 0 is the fixed start token, not a sampled choice
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        token_log_probs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)
        token_mask = (targets != pad_idx).float()
        seq_log_probs = (token_log_probs * token_mask).sum(dim=1)
        advantages = rewards.squeeze(1) - rewards.mean()
        g_loss = -(advantages * seq_log_probs).mean()

        g_optimizer.zero_grad()
        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        g_optimizer.step()
        # Log the quantity the original reported: BCE of D(fake) against the "real" label.
        g_loss_epoch += criterion_disc(rewards, real_labels).item()

    g_losses.append(g_loss_epoch / len(train_loader))
    d_losses.append(d_loss_epoch / len(train_loader))

    print(f"Epoch {epoch + 1}, Generator Loss: {g_losses[-1]:.4f}, Discriminator Loss: {d_losses[-1]:.4f}")


# Function to remove unwanted tokens and adjust the length
def generate_poem_lines(generator, start_token, train_dataset, num_lines=10, min_words=5, max_words=10, max_length=None):
    """Sample ``num_lines`` lines from ``generator`` and turn them into strings.

    For each sampled line:
    - ``<start>``, ``<unk>`` and ``<pad>`` are dropped, along with repeated words.
    - The line is cut to a random target length in ``[min_words, max_words]``.
    - If the line is shorter than that target, random vocabulary words that the
      model did NOT produce are appended until the target is reached or no
      unused word is left.

    NOTE(review): the filler words make any metric computed on these lines only
    partly a measure of the model.

    Args:
        generator: object with a ``sample(start_token=..., max_length=..., num_lines=...)``
            method returning rows of token ids (tensors or ints).
        start_token: id passed to ``generator.sample``.
        train_dataset: provides ``idx_to_word``.
        max_length: sampling length; defaults to the module-level ``seq_length``.

    Returns:
        A list of ``num_lines`` space-joined strings.
    """
    if max_length is None:
        max_length = seq_length
    special_tokens = {"<start>", "<unk>", "<pad>"}
    vocabulary = [w for w in train_dataset.idx_to_word.values() if w not in special_tokens]

    poem_lines = generator.sample(start_token=start_token, max_length=max_length, num_lines=num_lines)

    poem_words = []
    for line in poem_lines:
        words = []
        used_words = set()  # avoid repeating a word within one line
        for idx in line:
            word = train_dataset.idx_to_word[int(idx)]  # int() accepts tensor scalars and ints
            if word not in special_tokens and word not in used_words:
                words.append(word)
                used_words.add(word)

        target_count = random.randint(min_words, max_words)

        # Top up short lines with unused vocabulary words. The candidate pool is
        # built once per line, and filling stops if the vocabulary runs out.
        if len(words) < target_count:
            pool = [w for w in vocabulary if w not in used_words]
            while len(words) < target_count and pool:
                filler = pool.pop(random.randrange(len(pool)))
                words.append(filler)
                used_words.add(filler)

        poem_words.append(" ".join(words[:target_count]))

    return poem_words

# Final BLEU score calculation with smoothing function
if len(val_data) > 0:
    # The first 10 validation lines form one shared reference pool. Generated
    # lines are not conditioned on any particular validation line, so each one
    # is scored against all references rather than paired with one at random.
    references = [line.split() for line in val_data[:10]]

    # Generate the final poem
    final_generated_poem = generate_poem_lines(generator, start_token=train_dataset.word_to_idx["<start>"], train_dataset=train_dataset, num_lines=random.randint(7, 10))
    # Join outside the f-string: a backslash inside an f-string expression is a
    # SyntaxError before Python 3.12.
    poem_text = "\n".join(final_generated_poem)
    print(f"Final Generated Poem: {poem_text}")

    smoothing_function = SmoothingFunction().method4  # Choose a smoothing function
    bleu_scores = []
    for cand in final_generated_poem:
        try:
            score = sentence_bleu(references, cand.split(), smoothing_function=smoothing_function)
        except Exception as e:  # best-effort metric: log and score 0 rather than abort
            print(f"Error calculating BLEU score: {e}")
            score = 0
        bleu_scores.append(score)

    avg_bleu_score = float(np.mean(bleu_scores)) if bleu_scores else 0.0
    print(f"Average BLEU Score: {avg_bleu_score:.4f}")




# 7. Save the Models
torch.save(generator.state_dict(), "generator.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")

# 8. Plot Losses
# Use the epochs that actually recorded a loss, numbered from 1 like the training
# log, so an interrupted run does not raise a length-mismatch error.
epochs_run = range(1, len(g_losses) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs_run, g_losses, label='Generator Loss')
plt.plot(epochs_run, d_losses, label='Discriminator Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Generator and Discriminator Losses during Adversarial Training')
plt.show()



[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Muhtasim\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Pretraining Generator...


  0%|          | 1/1477 [00:17<7:19:45, 17.88s/it]


KeyboardInterrupt: 