# Import 必要的库：

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
import re

# 数据处理：

In [3]:
def load_and_preprocess_wikitext(file_path):
    # Load the data
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read().splitlines()
    # Preprocess: Remove lines that are too short, and tokenize
    sentences = [line for line in text if len(line.split()) > 2]
    return sentences[:1000]  # Taking only 1000 sentences for simplicity

def tokenize(sentence):
    """Tokenizes a sentence."""
    return [word for word in re.split(r'\W+', sentence) if word]

def sentence_to_tensor(sentence):
    """Converts a sentence to its tensor representation."""
    tokens = tokenize(sentence)
    tokens = ['<SOS>'] + tokens + ['<EOS>']
    if len(tokens) > MAX_LENGTH:
        tokens = tokens[:MAX_LENGTH]
    while len(tokens) < MAX_LENGTH:
        tokens.append('<PAD>')
    tensor = torch.tensor([word2index.get(token, word2index["<UNK>"]) for token in tokens], dtype=torch.long)
    return tensor

wikitext_train_path = "wikitext-2/wiki.train.tokens"  # Modify the path as per your directory structure
wikitext_sentences = load_and_preprocess_wikitext(wikitext_train_path)
all_words = [word for sentence in wikitext_sentences for word in tokenize(sentence)]
vocab = Counter(all_words)
vocab


Counter({'the': 4964,
         'of': 2571,
         'unk': 2369,
         'and': 2240,
         'in': 1686,
         'to': 1645,
         'a': 1506,
         'was': 869,
         'The': 812,
         'with': 678,
         'for': 644,
         'as': 627,
         's': 622,
         'that': 617,
         'is': 552,
         'on': 552,
         'by': 489,
         'were': 392,
         'at': 355,
         'from': 346,
         'his': 340,
         'are': 302,
         'an': 264,
         'her': 255,
         'which': 252,
         'In': 244,
         'he': 221,
         'be': 219,
         'had': 212,
         'it': 204,
         'their': 186,
         'gods': 175,
         'or': 174,
         'has': 169,
         'also': 167,
         'not': 164,
         'one': 156,
         'who': 153,
         'but': 148,
         'two': 148,
         'its': 137,
         'she': 137,
         'have': 133,
         'this': 130,
         'other': 126,
         'Fey': 125,
         'first': 124,
        

# 模型参数和词汇表：

In [3]:
EMBEDDING_DIM = 256
NUM_HEADS = 4
NUM_LAYERS = 2
LATENT_DIM = 50
MAX_LENGTH = 10  # or whatever maximum length you decide on

special_tokens = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
all_tokens = special_tokens + [word for word, _ in vocab.most_common()]
word2index = {word: index for index, word in enumerate(all_tokens)}
index2word = {index: word for word, index in word2index.items()}
VOCAB_SIZE = len(word2index)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 模型定义

In [4]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(EMBEDDING_DIM, NUM_HEADS),
            num_layers=NUM_LAYERS
        )
        # self.fc_mu = nn.Linear(MAX_LENGTH * EMBEDDING_DIM, LATENT_DIM)
        # self.fc_var = nn.Linear(MAX_LENGTH * EMBEDDING_DIM, LATENT_DIM)
        self.fc_mu = nn.Linear((MAX_LENGTH-1)*EMBEDDING_DIM, LATENT_DIM)
        self.fc_var = nn.Linear((MAX_LENGTH-1)*EMBEDDING_DIM, LATENT_DIM)




    def forward(self, src):
        embedded = self.embedding(src)
        encoded = self.transformer(embedded)
        mu = self.fc_mu(encoded.permute(1, 0, 2).reshape(src.size(0), -1))
        logvar = self.fc_var(encoded.permute(1, 0, 2).reshape(src.size(0), -1))
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.sequence_length = MAX_LENGTH
        self.latent_dim = LATENT_DIM
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.transformer = nn.Transformer(
            d_model=EMBEDDING_DIM,
            nhead=NUM_HEADS,
            num_encoder_layers=NUM_LAYERS,
            num_decoder_layers=NUM_LAYERS
        )
        self.fc_out = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)
        # self.fc_latent_to_embedding = nn.Linear(self.sequence_length * self.latent_dim, MAX_LENGTH * EMBEDDING_DIM)
        self.fc_latent_to_embedding = nn.Linear(LATENT_DIM, MAX_LENGTH * EMBEDDING_DIM)


    def forward(self, tgt, z):
        embedded = self.embedding(tgt)
        # Flatten z and then reshape it to the desired shape
        z_flattened = z.view(z.size(0), -1)  # Shape: [batch_size, 9*50]
        embedded_latent = self.fc_latent_to_embedding(z_flattened)  # Shape: [batch_size, MAX_LENGTH * EMBEDDING_DIM]
        embedded_latent = embedded_latent.view(z.size(0), MAX_LENGTH, EMBEDDING_DIM)  # Reshape to [batch_size, MAX_LENGTH, EMBEDDING_DIM]
        print(embedded.shape)
        print(embedded_latent.shape)
        decoded = self.transformer(embedded, embedded_latent)
        out = self.fc_out(decoded)
        return out


class TransformerCVAE(nn.Module):
    def __init__(self):
        super(TransformerCVAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def forward(self, src, tgt, tau=1.0):
        mu, logvar = self.encoder(src)
        z = self.reparameterize(mu, logvar)
        gumbel_softmax_sample = F.gumbel_softmax(z, tau=tau, hard=False)
        out = self.decoder(tgt, gumbel_softmax_sample)
        return out, mu, logvar


# Instantiate the model to check if it's constructed correctly
model = TransformerCVAE().to(device)
model


TransformerCVAE(
  (encoder): Encoder(
    (embedding): Embedding(9694, 256)
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (fc_mu): Linear(in_features=2304, out_features=50, bias=True)
    (fc_var): Linear(in_features=2304, out_features=50, bias=True)
  )
  (decoder): Decoder(
    (embedding): Embe

# 训练 & 生成

In [5]:
def train(model, data, epochs=10, lr=0.001):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        total_loss = 0
        for sentence in data:
            optimizer.zero_grad()
            
            # Prepare the input and target tensors
            sentence_in = sentence[:-1].unsqueeze(0)
            sentence_out = sentence[1:].unsqueeze(0)
            
            outputs, mu, logvar = model(sentence_in, sentence_in)
            loss = criterion(outputs.view(-1, VOCAB_SIZE), sentence_out.view(-1))
            
            # Add KL divergence
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss += KLD
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(data)}")

# Now, let's define a function to generate sentences from the trained model
def generate_sentence(model, max_length=MAX_LENGTH):
    model.eval()
    with torch.no_grad():
        # Start with the <SOS> token
        sentence = [word2index["<SOS>"]]
        for _ in range(max_length - 1):
            input_tensor = torch.tensor(sentence).unsqueeze(0)
            outputs, _, _ = model(input_tensor, input_tensor)
            next_word_idx = torch.argmax(outputs[0, -1]).item()
            if next_word_idx == word2index["<EOS>"]:
                break
            sentence.append(next_word_idx)
        return ' '.join([index2word[idx] for idx in sentence])

# Convert the wikitext sentences to tensor format for training
# tensor_data_train = [sentence_to_tensor(sentence) for sentence in wikitext_sentences]
tensor_data_train = [sentence_to_tensor(sentence).to(device) for sentence in wikitext_sentences]
# Check the lengths to ensure they are all 10
lengths_updated = [len(tensor) for tensor in tensor_data_train]
all(length == 10 for length in lengths_updated)

# Train the model
train(model, tensor_data_train, epochs=5)

# Generate a sentence
generated_sentence = generate_sentence(model)
generated_sentence


RuntimeError: shape '[1]' is invalid for input of size 50