In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
# Define constants
EMBEDDING_DIM = 16
HIDDEN_DIM = 16
LATENT_DIM = 16 # Dimension of the latent space
SEQ_LEN = 16 # Max length of the sequence

# Gumbel softmax temperature
TAU = 1.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.random.manual_seed(1024)

<torch._C.Generator at 0x113efd8b0>

In [3]:
# Pass embeded into decoder instead of using the original x
class TransformerEncoder(nn.Module):
    def __init__(self, d_model=EMBEDDING_DIM, nhead=4, num_layers=2):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead), num_layers
        )
        self.fc_logits = nn.Linear(d_model, LATENT_DIM)

    def forward(self, x):
        embedded = self.embedding(x).permute(1, 0, 2)  # Transformer expects seq_len, batch, features
        transformed = self.transformer_encoder(embedded)
        # Use the final state to predict logits for latent space
        logits = self.fc_logits(transformed[-1])
        return logits, embedded


class TransformerDecoder(nn.Module):
    def __init__(self, d_model=EMBEDDING_DIM, nhead=4, num_layers=2):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead), num_layers
        )
        self.fc_out = nn.Linear(d_model, VOCAB_SIZE)
        self.fc_z = nn.Linear(LATENT_DIM, d_model)  # Convert z to feature size for transformer

    def forward(self, embedded, z):
        # embedded = self.embedding(x).permute(1, 0, 2) # Transformer expects [seq_len, batch, features], permute函数用于改变张量的维度顺序
        z_adjusted = self.fc_z(z).unsqueeze(0)
        output = self.transformer_decoder(embedded, z_adjusted)
        return self.fc_out(output.permute(1, 0, 2))


class TransformerCVAE(nn.Module):
    def __init__(self):
        super(TransformerCVAE, self).__init__()
        self.encoder = TransformerEncoder()
        self.decoder = TransformerDecoder()

    def reparameterize(self, logits):
        return F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1)

    def forward(self, x):
        logits, emb = self.encoder(x)
        z = self.reparameterize(logits)
        return self.decoder(emb, z), logits

In [4]:
def load_and_preprocess_wikitext(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Use regular expressions to split the text into sentences
    sentences = re.split(r'(?<=[.!?])\s+', text)
    sentences = [sentence.strip() for sentence in sentences]
    
    return sentences

train_file_path = "wikitext-2/wiki.train.tokens"
test_file_path = "wikitext-2/wiki.test.tokens"
val_file_path = "wikitext-2/wiki.valid.tokens"

wikitext_sentences_train = load_and_preprocess_wikitext(train_file_path)
wikitext_sentences_test = load_and_preprocess_wikitext(test_file_path)
wikitext_sentences_val = load_and_preprocess_wikitext(val_file_path)

# Print the first few sentences to check
print("\nSample of train sentences:")
print(wikitext_sentences_train[:5])
print("\nSample of test sentences:")
print(wikitext_sentences_test[:5])
print("\nSample of val sentences:")
print(wikitext_sentences_val[:5])


Sample of train sentences:
['= Valkyria Chronicles III = \n \n Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit .', 'Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable .', 'Released in January 2011 in Japan , it is the third game in the Valkyria series .', '<unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " .', 'The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II .']

Sample of test sentences:
['= Robert <unk> = \n \n Robert <unk> is an English film , television and theatre actor .',

In [5]:
from torch.utils.data import Dataset, DataLoader
import torch

# Hyperparameters
BATCH_SIZE = 32
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"

# Tokenize the data
tokens = [word for sentence in wikitext_sentences_train for word in sentence.split()]

# Build vocabulary
vocab = [PAD_TOKEN, UNK_TOKEN] + list(set(tokens))
word_index = {word: index for index, word in enumerate(vocab)}
# 添加新的tokens
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'
word_index[SOS_TOKEN] = len(word_index)
word_index[EOS_TOKEN] = len(word_index)
vocab = {v: k for k, v in word_index.items()}
# Convert tokens to integers
def tokenize_and_encode(text):
    return [word_index.get(word, word_index[UNK_TOKEN]) for word in text.split()]

encoded_data_train = [tokenize_and_encode(sentence) for sentence in wikitext_sentences_train]

# Create a PyTorch Dataset
class WikiDataset(Dataset):
    def __init__(self, data, sequence_length):
        self.data = data
        self.sequence_length = sequence_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if len(sample) < self.sequence_length:
            sample.extend([word_index[PAD_TOKEN]] * (self.sequence_length - len(sample)))
        else:
            sample = sample[:self.sequence_length]
        return torch.tensor(sample)

# dataset = WikiDataset(encoded_data_train, SEQUENCE_LENGTH)
# dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Split the data into train and validation sets
dataset = WikiDataset(encoded_data_train, SEQ_LEN)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Display a sample batch
next(iter(train_dataloader))

VOCAB_SIZE = len(vocab)
print(f'Vocabulary size: {VOCAB_SIZE}')


Vocabulary size: 33281


In [6]:
import torch.nn.functional as F
import matplotlib.pyplot as plt
import gradio as gr

class MultiMultiSignalingGame:
    def __init__(self, senders: list, receivers: list, optimizer, criterion):
        self.senders = senders
        self.receivers = receivers
        self.optimizer = optimizer
        self.criterion = criterion

    def play_round(self, states):
        all_decoded_outputs = []
        all_logits = []
        
        for i, sender in enumerate(self.senders):
            # Sender encodes the state
            logits, emb = sender(states[i])
            all_logits.append(logits)
            z = F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1)
            
            # Each receiver decodes the signal from the sender
            for receiver in self.receivers:
                decoded_output = receiver(emb, z)
                all_decoded_outputs.append(decoded_output)
      
        # Calculate loss
        loss = self.compute_loss(states, all_decoded_outputs, all_logits, beta=1.0)
        
        # Update model parameters
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Capture the input and output sentences
        _, input_sentence_ids = torch.max(states[0], dim=1)
        input_sentence_ids = input_sentence_ids.cpu().numpy()
        input_sentence = ' '.join([vocab[idx] for idx in input_sentence_ids])

        _, output_sentence_ids = torch.max(all_decoded_outputs[0][0], dim=1)
        output_sentence_ids = output_sentence_ids.cpu().numpy()
        output_sentence = ' '.join([vocab[idx] for idx in output_sentence_ids])

        return loss.item(), input_sentence, output_sentence

    def compute_loss(self, original_states, decoded_states, logits, beta):
        recon_loss = sum([self.criterion(decoded_state.view(-1, VOCAB_SIZE), original_state.view(-1))
                          for original_state, decoded_state in zip(original_states * len(self.receivers), decoded_states)])
        
        # Calculate KLD loss
        kld_losses = []
        for logit in logits:
            mean, logvar = torch.chunk(logit, 2, dim=-1)
            kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
            kld_losses.append(kld_loss)

        return recon_loss + beta * sum(kld_losses)

def train_signal_game(NUM_SENDERS=3, NUM_RECEIVERS=3, num_rounds=10000):
    senders = [TransformerEncoder().to(device) for _ in range(NUM_SENDERS)]
    receivers = [TransformerDecoder().to(device) for _ in range(NUM_RECEIVERS)]
    params = [list(sender.parameters()) for sender in senders]
    params.extend([list(receiver.parameters()) for receiver in receivers])
    optimizer = torch.optim.Adam([param for sublist in params for param in sublist], lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    game = MultiMultiSignalingGame(senders, receivers, optimizer, criterion)

    losses = []
    conversations = []
    for round in range(num_rounds):
        states = [torch.randint(VOCAB_SIZE, (BATCH_SIZE, 16)).to(device) for _ in range(NUM_SENDERS)]
        loss, input_sentence, output_sentence = game.play_round(states)
        losses.append(loss)
        conversations.append(f"Round {round+1} - Input: {input_sentence} | Output: {output_sentence}")

    conversation_str = "\n".join(conversations)
    
    plt.figure(figsize=(10,6))
    plt.plot(losses, label='losses')
    plt.xlabel('Round')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss Curve')
    plt.savefig('loss_curve.png')
    plt.close()

    return 'loss_curve.png', conversation_str

iface = gr.Interface(
    fn=train_signal_game,
    inputs=[
        gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label="NUM_SENDERS"),
        gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label="NUM_RECEIVERS"),
        gr.inputs.Slider(minimum=1000, maximum=20000, step=1000, default=10000, label="num_rounds"),
    ],
    outputs=[
        gr.outputs.Image(type="filepath", label="Loss Curve"),
        gr.outputs.Textbox(label="Conversations")
    ],
    live=True
)

iface.launch()


  gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label="NUM_SENDERS"),
  gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label="NUM_SENDERS"),
  gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label="NUM_RECEIVERS"),
  gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label="NUM_RECEIVERS"),
  gr.inputs.Slider(minimum=1000, maximum=20000, step=1000, default=10000, label="num_rounds"),
  gr.inputs.Slider(minimum=1000, maximum=20000, step=1000, default=10000, label="num_rounds"),
  gr.outputs.Image(type="filepath", label="Loss Curve"),
  gr.outputs.Textbox(label="Conversations")


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




Error while flagging: field larger than field limit (131072)
Error while flagging: field larger than field limit (131072)


In [8]:
import streamlit as st
import torch.nn.functional as F
import torch.nn.functional as F
import matplotlib.pyplot as plt
import gradio as gr

class MultiMultiSignalingGame:
    def __init__(self, senders: list, receivers: list, optimizer, criterion):
        self.senders = senders
        self.receivers = receivers
        self.optimizer = optimizer
        self.criterion = criterion

    def play_round(self, states):
        all_decoded_outputs = []
        all_logits = []
        
        for i, sender in enumerate(self.senders):
            # Sender encodes the state
            logits, emb = sender(states[i])
            all_logits.append(logits)
            z = F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1)
            
            # Each receiver decodes the signal from the sender
            for receiver in self.receivers:
                decoded_output = receiver(emb, z)
                all_decoded_outputs.append(decoded_output)
      
        # Calculate loss
        loss = self.compute_loss(states, all_decoded_outputs, all_logits, beta=1.0)
        
        # Update model parameters
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Capture the input and output sentences
        _, input_sentence_ids = torch.max(states[0], dim=1)
        input_sentence_ids = input_sentence_ids.cpu().numpy()
        input_sentence = ' '.join([vocab[idx] for idx in input_sentence_ids])

        _, output_sentence_ids = torch.max(all_decoded_outputs[0][0], dim=1)
        output_sentence_ids = output_sentence_ids.cpu().numpy()
        output_sentence = ' '.join([vocab[idx] for idx in output_sentence_ids])

        return loss.item(), input_sentence, output_sentence

    def compute_loss(self, original_states, decoded_states, logits, beta):
        recon_loss = sum([self.criterion(decoded_state.view(-1, VOCAB_SIZE), original_state.view(-1))
                          for original_state, decoded_state in zip(original_states * len(self.receivers), decoded_states)])
        
        # Calculate KLD loss
        kld_losses = []
        for logit in logits:
            mean, logvar = torch.chunk(logit, 2, dim=-1)
            kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
            kld_losses.append(kld_loss)

        return recon_loss + beta * sum(kld_losses)

def train_signal_game(NUM_SENDERS=3, NUM_RECEIVERS=3, num_rounds=10000):
    senders = [TransformerEncoder().to(device) for _ in range(NUM_SENDERS)]
    receivers = [TransformerDecoder().to(device) for _ in range(NUM_RECEIVERS)]
    params = [list(sender.parameters()) for sender in senders]
    params.extend([list(receiver.parameters()) for receiver in receivers])
    optimizer = torch.optim.Adam([param for sublist in params for param in sublist], lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    game = MultiMultiSignalingGame(senders, receivers, optimizer, criterion)

    losses = []
    conversations = []
    for round in range(num_rounds):
        states = [torch.randint(VOCAB_SIZE, (BATCH_SIZE, 16)).to(device) for _ in range(NUM_SENDERS)]
        loss, input_sentence, output_sentence = game.play_round(states)
        losses.append(loss)
        conversations.append(f"Round {round+1} - Input: {input_sentence} | Output: {output_sentence}")

    conversation_str = "\n".join(conversations)
    
    plt.figure(figsize=(10,6))
    plt.plot(losses, label='losses')
    plt.xlabel('Round')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss Curve')
    plt.savefig('loss_curve.png')
    plt.close()

    return 'loss_curve.png', conversation_str

def main():
    st.title('Transformer-based CatVAE and Signal Game')

    # 添加 sliders 来选择参数
    NUM_SENDERS = st.sidebar.slider('NUM_SENDERS', 1, 10, 3)
    NUM_RECEIVERS = st.sidebar.slider('NUM_RECEIVERS', 1, 10, 3)
    num_rounds = st.sidebar.slider('num_rounds', 1000, 20000, 10000)

    # 为了记录所有的参数组合尝试
    try:
        attempts = st.session_state.attempts
    except AttributeError:
        st.session_state.attempts = []
        attempts = st.session_state.attempts

    # 显示过往所有尝试的参数组合
    st.sidebar.text("Previous Attempts:")
    for attempt in attempts:
        st.sidebar.text(attempt)

    # 添加一个按钮来开始游戏
    if st.button('Start'):
        # 运行游戏并获取结果
        losses, conversations = play_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds)  # 假设 play_game 是实际执行信号游戏的函数

        # 显示结果
        st.line_chart(losses, use_container_width=True)
        st.text_area("Conversations", "\n".join(conversations), height=200)

        # 保存这次尝试的参数组合
        attempts.append(f"Senders: {NUM_SENDERS}, Receivers: {NUM_RECEIVERS}, Rounds: {num_rounds}")

if __name__ == '__main__':
    main()
!streamlit run /Users/YUAN/opt/anaconda3/envs/myCVAE/lib/python3.11/site-packages/ipykernel_launcher.py

AttributeError: st.session_state has no attribute "attempts". Did you forget to initialize it? More info: https://docs.streamlit.io/library/advanced-features/session-state#initialization