In [1]:
import os  # use OS dependent functions like reading or writing files
# Raise the NumExpr thread cap to 10. Per the original author's note, the default is 8 and
# the MacBook M1 Pro has 10 cores; using all of them may slow down the rest of the machine.
os.environ['NUMEXPR_MAX_THREADS'] = '10'

import torch
import torch.nn as nn  

import numpy as np
import collections # compute source_vocab
from tqdm import tqdm 
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 


# Pick the compute device once for the whole script: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


class Config:
    """Hyper-parameters and file locations shared by the rest of the script."""

    # Model sizes
    E = 128  # embedding width (read by MyEmbedding)
    V = 128  # transformer d_model (read by Encoder / Decoder / MI)

    # Channel and decoding
    sigma = 0.1  # standard deviation of the Gaussian noise added by Channel
    num_steps = 40  # decoding-length limit read by translate

    # Files
    checkpoint_path = "model_checkpoint.pth"  # read with torch.load at module level
    corpus_train_vocab = "corpus_train_vocab.txt"  # not referenced in the visible code
    corpus_train_split = "corpus_train_split.txt"  # training corpus, one sentence per line
    input_file_path = "input.txt"  # sentences fed to translate
    output_file_path = "output.txt"  # passed to translate

    # Data loading
    batch_size = 32


# Single shared instance used as a read-only settings object.
CONFIG = Config()



class Vocab:
    """Bidirectional mapping between tokens and integer indices.

    Index 0 is always ``'<unk>'``; ``reserved_tokens`` follow in the order given,
    then the remaining tokens in descending order of frequency.

    NOTE: instances appear to be stored inside the checkpoint
    (``checkpoint['source_vocab']``), so the instance attribute names are kept
    unchanged to stay compatible with already-pickled objects.

    Args:
        tokens: either a mapping ``token -> frequency`` (e.g. ``collections.Counter``),
            a flat iterable of tokens, an iterable of token lists, or ``None``
            for an empty vocabulary.
        min_freq: tokens whose frequency is below this value are dropped.
        reserved_tokens: special tokens placed right after ``'<unk>'``.
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        reserved_tokens = [] if reserved_tokens is None else list(reserved_tokens)
        counter = self._count(tokens)
        # (token, frequency) pairs, most frequent first; ties keep insertion order.
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                   reverse=True)
        # Unknown tokens have an index of 0
        self.idx_to_token = ['<unk>'] + reserved_tokens
        self.token_to_idx = {token: idx
                             for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break  # sorted by frequency, so every later token is rarer too
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
        # Aliases kept for callers that use the torchtext-style names.
        self.stoi = self.token_to_idx
        self.itos = self.idx_to_token
        self.tokens = self.idx_to_token[1:]

    @staticmethod
    def _count(tokens):
        """Normalize the ``tokens`` argument into a ``token -> frequency`` mapping."""
        if tokens is None:
            return collections.Counter()
        if hasattr(tokens, 'items'):
            return tokens  # already a frequency mapping
        tokens = list(tokens)
        if tokens and isinstance(tokens[0], (list, tuple)):
            # A list of tokenized sentences: flatten one level before counting.
            tokens = [token for line in tokens for token in line]
        return collections.Counter(tokens)

    def __len__(self):
        """Return the vocabulary size, including ``'<unk>'`` and reserved tokens."""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Return the index of a token, or a list of indices for a list/tuple of tokens.

        Tokens that are not in the vocabulary map to ``self.unk``.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Return the token for an index, or a list of tokens for a list/tuple of indices."""
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    def to_indices(self, tokens):
        """Return the index (or list of indices) for the given token(s); same as ``vocab[tokens]``."""
        return self[tokens]

    @property
    def unk(self):
        """Index used for unknown tokens (always 0)."""
        return 0

    @property
    def token_freqs(self):
        """The ``(token, frequency)`` pairs sorted by descending frequency at construction."""
        return self._token_freqs

## Loaded at module level because the dataset/model classes below bind
## source_vocab (and train_corpus) as default argument values.
# NOTE(security): torch.load unpickles arbitrary Python objects (the checkpoint stores
# a vocabulary object, presumably a Vocab) -- only load checkpoints from a trusted source.
checkpoint = torch.load(CONFIG.checkpoint_path, map_location=torch.device('cpu'))
source_vocab = checkpoint['source_vocab']
# Read the training corpus and close the file right away; each line still carries its
# trailing "\n", which MyDataset.__getitem__ strips.
with open(CONFIG.corpus_train_split, 'r') as corpus_file:
    train_corpus = corpus_file.readlines()


class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns each corpus line into a list of token indices.

    Args:
        vocab: object used to convert a list of tokens to indices via ``vocab[tokens]``.
        corpus: sequence of raw text lines, one sentence per line.
    """

    def __init__(self, vocab=source_vocab, corpus=train_corpus):
        self.corpus = corpus
        self.vocab = vocab

    def __len__(self):
        """Number of sentences in the corpus."""
        return len(self.corpus)

    def __getitem__(self, index):
        """Return ``(token_indices, length)`` for one sentence; ``length`` counts the appended ``<eos>``."""
        words = self.corpus[index].strip().split()
        words = words + ['<eos>']
        token_ids = self.vocab[words]
        return token_ids, len(words)
    

def collate_fn(batch_data, padding_value=1):
    """Collate ``(token_indices, length)`` samples into one padded batch.

    Samples are ordered by descending sequence length and right-padded to the
    longest one. The caller's ``batch_data`` is not modified.

    Args:
        batch_data: iterable of ``(token_indices, length)`` pairs as produced by
            ``MyDataset.__getitem__``.
        padding_value: index used to fill the padded positions. Defaults to 1,
            the value that was previously hard-coded (presumably the ``<pad>``
            index of the vocabulary -- confirm against ``source_vocab``).

    Returns:
        ``(padded, lengths)`` where ``padded`` has shape ``(batch, max_len)`` and
        ``lengths`` holds the second element of each sample, in the same order.
    """
    # sorted() instead of list.sort() so the input list is left untouched.
    ordered = sorted(batch_data, key=lambda sample: len(sample[0]), reverse=True)
    lengths = torch.tensor([sample[1] for sample in ordered])
    sequences = [torch.tensor(sample[0]) for sample in ordered]
    padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=padding_value)
    return padded, lengths



class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to the input, then apply dropout.

    Args:
        num_hiddens: size of the last (feature) axis of the inputs; may be odd.
        dropout: dropout probability applied after adding the position table.
        max_len: number of positions pre-computed; inputs must not be longer.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        frequencies = torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        angles = positions / frequencies  # (max_len, ceil(num_hiddens / 2))
        # Kept as a plain attribute (not a registered buffer) so the module's
        # state_dict keys stay exactly as before; forward() moves it to the input device.
        self.P = torch.zeros((1, max_len, num_hiddens))
        self.P[:, :, 0::2] = torch.sin(angles)
        # With an odd num_hiddens there is one fewer cosine column than sine columns,
        # so trim the angles to fit instead of failing on a shape mismatch.
        self.P[:, :, 1::2] = torch.cos(angles[:, :num_hiddens // 2])

    def forward(self, X):
        """Return ``dropout(X + P[:, :X.shape[1]])``; axis 1 of ``X`` is the position axis."""
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

class MyEmbedding(nn.Module):
    """Token-embedding table of width ``CONFIG.E`` with one row per vocabulary entry.

    The row of the ``'<pad>'`` token is registered as ``padding_idx``.
    """

    def __init__(self, vocab=source_vocab):
        super().__init__()
        pad_index = vocab['<pad>']
        self.embedding = nn.Embedding(
            num_embeddings=len(vocab),
            embedding_dim=CONFIG.E,
            padding_idx=pad_index,
        )

    def forward(self, X):
        """Look up the embedding vectors for the index tensor ``X``."""
        embedded = self.embedding(X)
        return embedded

# with batch normalization
class Encoder(nn.Module):
    """Transformer encoder that compresses each position to a 16-dimensional symbol.

    Pipeline: positional encoding -> three (BatchNorm1d, TransformerEncoderLayer)
    stages -> Linear/BatchNorm1d/ReLU to ``2 * CONFIG.V`` -> Linear/BatchNorm1d/ReLU to 16.

    Attribute names and registration order are kept unchanged so existing
    ``state_dict`` checkpoints still load.

    Args:
        vocab: accepted for interface compatibility; not used by the encoder.
    """

    def __init__(self, vocab=source_vocab):
        super().__init__()
        self.position_encoding = PositionalEncoding(CONFIG.V, dropout=0.1)
        self.norm1 = nn.BatchNorm1d(CONFIG.V)
        self.transformer_encoder1 = nn.TransformerEncoderLayer(d_model=CONFIG.V, nhead=8, dim_feedforward=512,
                                                              batch_first=True)
        self.norm2 = nn.BatchNorm1d(CONFIG.V)
        self.transformer_encoder2 = nn.TransformerEncoderLayer(d_model=CONFIG.V, nhead=8, dim_feedforward=512,
                                                              batch_first=True)
        self.norm3 = nn.BatchNorm1d(CONFIG.V)
        self.transformer_encoder3 = nn.TransformerEncoderLayer(d_model=CONFIG.V, nhead=8, dim_feedforward=512,
                                                              batch_first=True)
        self.linear1 = nn.Linear(CONFIG.V, 2 * CONFIG.V)
        self.norm4 = nn.BatchNorm1d(2 * CONFIG.V)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(2 * CONFIG.V, 16)
        self.norm5 = nn.BatchNorm1d(16)
        self.relu2 = nn.ReLU()

    @staticmethod
    def _feature_norm(norm, X):
        """Apply a BatchNorm1d over the last axis of a (batch, seq, feature) tensor."""
        # BatchNorm1d normalizes axis 1, so swap the feature axis in and back out.
        return norm(X.transpose(1, 2)).transpose(1, 2)

    def forward(self, X, valid_lens):
        """Encode embedded sentences.

        Args:
            X: embedded inputs, indexed as (batch, seq, CONFIG.V).
            valid_lens: per-sentence number of real (non-padding) positions, shape (batch,).

        Returns:
            Tensor indexed as (batch, seq, 16).
        """
        # True marks padding positions. Built on the input's device (rather than the
        # module-level `device`) so the module works wherever its inputs live.
        positions = torch.arange(X.shape[1], device=X.device)
        mask = positions.unsqueeze(0) >= valid_lens.unsqueeze(1)

        hidden = self.position_encoding(X)
        stages = (
            (self.norm1, self.transformer_encoder1),
            (self.norm2, self.transformer_encoder2),
            (self.norm3, self.transformer_encoder3),
        )
        for norm, layer in stages:
            hidden = layer(self._feature_norm(norm, hidden), src_key_padding_mask=mask)

        hidden = self.relu1(self._feature_norm(self.norm4, self.linear1(hidden)))
        hidden = self.relu2(self._feature_norm(self.norm5, self.linear2(hidden)))
        return hidden
        
def Channel(X, sigma=None):  # AWGN
    """Additive white Gaussian noise channel.

    Args:
        X: tensor to transmit.
        sigma: noise standard deviation; defaults to ``CONFIG.sigma`` when ``None``.

    Returns:
        ``X`` plus zero-mean Gaussian noise of the same shape.
    """
    std = CONFIG.sigma if sigma is None else sigma
    # Sample directly on X's device instead of sampling on the CPU and copying to the
    # module-level `device`, so the channel works wherever X lives.
    return X + torch.normal(0.0, std, size=X.shape, device=X.device)


class Decoder(nn.Module):
    """Transformer decoder that maps received 16-dimensional symbols back to vocabulary logits.

    The channel output is expanded by two Linear+ReLU layers to ``CONFIG.V`` features
    and used as the memory of three TransformerDecoderLayers; a final Linear layer
    projects to one logit per vocabulary entry.

    Attribute names and registration order are kept unchanged so existing
    ``state_dict`` checkpoints still load.

    Args:
        vocab: vocabulary whose size determines the number of output logits.
    """

    def __init__(self, vocab=source_vocab):
        super().__init__()
        self.position_encoding = PositionalEncoding(CONFIG.V, dropout=0.1)
        self.linear1 = nn.Linear(16, 2 * CONFIG.V)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(2 * CONFIG.V, CONFIG.V)
        self.relu2 = nn.ReLU()
        self.transformer_decoder1 = nn.TransformerDecoderLayer(d_model=CONFIG.V, nhead=8, dim_feedforward=512,
                                                              batch_first=True)
        self.transformer_decoder2 = nn.TransformerDecoderLayer(d_model=CONFIG.V, nhead=8, dim_feedforward=512,
                                                              batch_first=True)
        self.transformer_decoder3 = nn.TransformerDecoderLayer(d_model=CONFIG.V, nhead=8, dim_feedforward=512,
                                                              batch_first=True)
        # Size the output layer from the `vocab` argument (it used to read the global
        # source_vocab, silently ignoring the parameter).
        self.linear3 = nn.Linear(CONFIG.V, len(vocab))

    def forward(self, emb_decoder_input, channel_output, origin_len, tgt_mask=None, mode='train'):
        """Decode a batch.

        Args:
            emb_decoder_input: embedded decoder inputs, indexed as (batch, tgt_seq, CONFIG.V).
            channel_output: received symbols, indexed as (batch, src_seq, 16).
            origin_len: per-sentence number of valid source positions, shape (batch,).
            tgt_mask: optional attention mask forwarded to every decoder layer.
            mode: ``'train'`` also applies the source padding mask to the target
                sequence (this assumes target and source have the same length);
                any other value applies no target padding mask.

        Returns:
            Logits indexed as (batch, tgt_seq, vocabulary size).
        """
        # True marks padded source positions; built on the input's device rather than
        # the module-level `device`.
        positions = torch.arange(channel_output.shape[1], device=channel_output.device)
        memory_mask = positions[None, :] >= origin_len[:, None]

        memory = self.relu1(self.linear1(channel_output))
        memory = self.relu2(self.linear2(memory))

        hidden = self.position_encoding(emb_decoder_input)
        tgt_padding_mask = memory_mask if mode == 'train' else None
        for layer in (self.transformer_decoder1, self.transformer_decoder2, self.transformer_decoder3):
            hidden = layer(hidden, memory, tgt_mask=tgt_mask,
                           memory_key_padding_mask=memory_mask,
                           tgt_key_padding_mask=tgt_padding_mask)
        return self.linear3(hidden)
                


class VAE(nn.Module):
    """Fully connected variational auto-encoder.

    Args:
        in_dim: size of the input (and reconstructed) feature vector.
        out_dim: width of the hidden layers.
        latent_dim: size of the latent vector.
    """

    def __init__(self, in_dim=16, out_dim=64, latent_dim=8):
        super().__init__()

        # Three Linear+ReLU stages: in_dim -> out_dim -> out_dim -> out_dim.
        encoder_layers = []
        for width_in in (in_dim, out_dim, out_dim):
            encoder_layers.append(nn.Linear(width_in, out_dim))
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Heads producing the mean and log-variance of the latent Gaussian.
        self.mu = nn.Linear(out_dim, latent_dim)
        self.logvar = nn.Linear(out_dim, latent_dim)

        # Mirror of the encoder, ending in a plain Linear back to in_dim.
        decoder_layers = []
        for width_in in (latent_dim, out_dim, out_dim):
            decoder_layers.append(nn.Linear(width_in, out_dim))
            decoder_layers.append(nn.ReLU())
        decoder_layers.append(nn.Linear(out_dim, in_dim))
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for ``x``."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map a latent vector back to the input space."""
        return self.decoder(z)

    def forward(self, x):
        """Flatten ``x`` to rows of its last axis and return ``(x_hat, mu, logvar)``."""
        flat = x.view(-1, x.shape[-1])
        mu, logvar = self.encode(flat)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

class DeepST(nn.Module):
    """End-to-end model: Encoder -> noisy Channel -> Decoder.

    A ``VAE`` sub-module is also registered. ``forward`` does not use it, but it is
    kept so the module's ``state_dict`` keys are unchanged.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.channel = Channel
        self.decoder = Decoder()
        self.vae = VAE(in_dim=16, out_dim=64)  # change out_dim 16 to 64 for example

    def forward(self, emb_encoder_input, valid_lens, emb_decoder_input=None, embedding=None, phase=1):
        """Run the encoder and channel, and optionally the teacher-forced decoder.

        Args:
            emb_encoder_input: embedded source sentences passed to the encoder.
            valid_lens: per-sentence number of valid positions, shape (batch,).
            emb_decoder_input: embedded target sentences; required when ``phase`` is ``None``.
            embedding: callable that embeds an index tensor; used to embed the
                ``<bos>`` column when ``phase`` is ``None``.
            phase: ``None`` runs the decoder as well; any other value (default 1)
                stops after the channel.

        Returns:
            ``(encode_result, channel_outputs)`` when ``phase`` is not ``None``;
            otherwise ``(encode_result, channel_outputs, decoder_logits)``.
            ``channel_outputs`` is the single noise realization that the decoder
            consumed (previously the channel was sampled a second time for the
            returned value, so the two did not match).
        """
        encode_result = self.encoder(emb_encoder_input, valid_lens)
        # NOTE: the earlier VAE encode/sample/decode pass was removed from here because
        # its results were never used or returned.
        channel_outputs = self.channel(encode_result)

        if phase is not None:
            return encode_result, channel_outputs

        batch_size, tgt_len = emb_decoder_input.shape[0], emb_decoder_input.shape[1]
        target_device = emb_decoder_input.device

        # Causal mask: True strictly above the diagonal, i.e. a position may not
        # attend to later positions.
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=target_device), diagonal=1)

        # Teacher forcing: shift the target right by one and prepend <bos>.
        bos_column = torch.full([batch_size, 1], source_vocab['<bos>'],
                                dtype=torch.long, device=target_device)
        decoder_input = torch.cat([embedding(bos_column), emb_decoder_input[:, :-1, :]], dim=1)

        decoded = self.decoder(decoder_input, channel_outputs, valid_lens, causal_mask)
        return encode_result, channel_outputs, decoded


class MI(nn.Module):
    """Six-layer MLP that scores (X, Y) pairs drawn jointly versus with Y shuffled.

    ``forward`` returns the network output on matched pairs and on pairs whose
    ``Y`` rows have been randomly permuted (a mutual-information-estimator style
    joint/marginal split).

    Attribute names and registration order are kept unchanged so existing
    ``state_dict`` checkpoints still load.
    """

    def __init__(self):  # 6 layers
        super().__init__()
        self.linear1 = nn.Linear(16 * 2, 8 * CONFIG.V)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(8 * CONFIG.V, 4 * CONFIG.V)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(4 * CONFIG.V, 2 * CONFIG.V)
        self.relu3 = nn.ReLU()
        self.linear4 = nn.Linear(2 * CONFIG.V, 2 * CONFIG.V)
        self.relu4 = nn.ReLU()
        self.linear5 = nn.Linear(2 * CONFIG.V, 2 * CONFIG.V)
        self.relu5 = nn.ReLU()
        self.linear6 = nn.Linear(2 * CONFIG.V, 1)
        self.relu6 = nn.ReLU()

    def network(self, X, Y):
        """Score each row pair: concatenate ``X`` and ``Y`` on axis 1 and run the MLP.

        Returns a tensor with one non-negative value per row (the last layer is a ReLU).
        """
        x = torch.cat([X, Y], dim=1)
        stages = (
            (self.linear1, self.relu1),
            (self.linear2, self.relu2),
            (self.linear3, self.relu3),
            (self.linear4, self.relu4),
            (self.linear5, self.relu5),
            (self.linear6, self.relu6),
        )
        for linear, relu in stages:
            x = relu(linear(x))
        return x

    def forward(self, X, Y, valid_lens):
        """Return ``(output_joint, output_marginal)`` over all valid positions.

        Args:
            X: tensor indexed as (batch, seq, 16).
            Y: tensor with the same layout as ``X``.
            valid_lens: per-sentence number of valid positions, shape (batch,).

        Returns:
            Two tensors with one row per valid position: the network output on
            matched ``(X, Y)`` rows and on ``X`` paired with randomly permuted ``Y`` rows.
        """
        # True marks padded positions; flattened to line up with the reshaped rows below.
        padded = (torch.arange(X.shape[1], dtype=torch.long,
                               device=X.device)[None, :] >= valid_lens[:, None]).reshape(-1)
        keep = ~padded
        X = X.reshape(-1, 16)[keep]
        Y = Y.reshape(-1, 16)[keep]

        # torch.randperm replaces random.shuffle: the `random` module was never
        # imported (NameError at call time), and this keeps the indices on X's device
        # and under torch's RNG seed.
        sample_size = X.shape[0]
        joint_order = torch.randperm(sample_size, device=X.device)
        X = X[joint_order]
        Y = Y[joint_order]
        marginal_order = torch.randperm(sample_size, device=X.device)
        shuffle_Y = Y[marginal_order]

        output_joint = self.network(X, Y)
        output_marginal = self.network(X, shuffle_Y)

        return output_joint, output_marginal

# Define the driving app
def translate(input_file_path, output_file_path, checkpoint_path, batch_size=32, num_steps=40, max_batches=3):
    """Send sentences through the model with greedy decoding and print BLEU scores.

    For each processed sentence the input, the reconstructed output and their
    sentence-level BLEU score are printed, followed by the mean BLEU score.

    Args:
        input_file_path: text file with one whitespace-tokenized sentence per line.
        output_file_path: accepted for interface compatibility; nothing is written to it.
        checkpoint_path: accepted for interface compatibility; the weights come from the
            module-level ``checkpoint`` that was already loaded (the model classes are
            bound to its vocabulary).
        batch_size: number of sentences per batch.
        num_steps: maximum decoded length, counting the leading ``<bos>``.
        max_batches: stop after this many batches (positive int); ``None`` processes
            the whole file. Defaults to 3, the limit that was previously hard-coded.
    """
    # Rebuild the embedding table from the checkpoint without mutating the checkpoint dict.
    embedding_weight = checkpoint['embedding_state_dict']['embedding.weight']
    embedding = nn.Embedding.from_pretrained(
        torch.zeros((len(source_vocab), embedding_weight.shape[1])), freeze=False)
    embedding.load_state_dict({'weight': embedding_weight})
    # The inputs are moved to `device` below, so the embedding table must live there
    # too (it used to stay on the CPU, which fails when `device` is CUDA).
    embedding = embedding.to(device)
    embedding.eval()

    model = DeepST().to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    with open(input_file_path, 'r') as input_file:
        input_corpus = input_file.readlines()
    # shuffle=True is kept from the original: with max_batches set, a random subset is shown.
    data_loader = torch.utils.data.DataLoader(dataset=MyDataset(corpus=input_corpus), batch_size=batch_size,
                                              shuffle=True, collate_fn=collate_fn)

    bos_id = source_vocab['<bos>']
    eos_id = source_vocab['<eos>']
    pad_id = source_vocab['<pad>']
    smoothing = SmoothingFunction().method1
    bleus = []

    with torch.no_grad():
        for index, (inputs, valid_lens) in enumerate(tqdm(data_loader), 0):
            inputs, valid_lens = inputs.to(device), valid_lens.to(device)
            _, channel_outputs = model(embedding(inputs), valid_lens)

            # Output buffer: <bos> in column 0, <pad> everywhere else until decoded.
            outputs = torch.full((inputs.shape[0], num_steps), pad_id, dtype=torch.long, device=device)
            outputs[:, 0] = bos_id

            # Rows that have not produced <eos> yet.
            active = torch.arange(inputs.shape[0], device=device)
            for step in range(num_steps - 1):
                if active.numel() == 0:
                    break
                emb_outputs = embedding(outputs[active, :step + 1])
                logits = model.decoder(emb_outputs, channel_outputs[active], valid_lens[active], mode='validate')
                # Greedy choice from the logits of the last decoded position.
                next_tokens = logits[:, -1, :].argmax(dim=1)
                outputs[active, step + 1] = next_tokens
                active = active[next_tokens != eos_id]

            # Strip special tokens, print the pair and score it.
            for source_ids, decoded_ids in zip(inputs.tolist(), outputs[:, 1:].tolist()):
                input_sentence = ' '.join(
                    t for t in source_vocab.to_tokens(source_ids) if t not in ('<pad>', '<bos>', '<eos>'))
                output_sentence = ' '.join(
                    t for t in source_vocab.to_tokens(decoded_ids) if t not in ('<pad>', '<eos>'))
                print(f"Input: {input_sentence}")
                print(f"Output: {output_sentence}")

                bleu = sentence_bleu([input_sentence.split()], output_sentence.split(),
                                     smoothing_function=smoothing)
                bleus.append(bleu)
                print(f"BLEU score: {bleu}\n")

            if max_batches is not None and index + 1 >= max_batches:
                break

    if bleus:
        print(f'BLEU score mean: {sum(bleus) / len(bleus)}')
    else:
        # Empty input file: avoid dividing by zero.
        print('BLEU score mean: n/a (no sentences were translated)')

        
        
# Run the demo with the settings from CONFIG.
translate(
    input_file_path=CONFIG.input_file_path,
    output_file_path=CONFIG.output_file_path,
    checkpoint_path=CONFIG.checkpoint_path,
    batch_size=CONFIG.batch_size,
    num_steps=CONFIG.num_steps,
)


100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  7.53it/s]

Input: As far as the question of disposal costs is concerned , the fact is that very different systems are currently being applied here
Output: As far as the question of disposal costs is concerned , the fact is as very different systems are being adopted here
BLEU score: 0.7108332982740264

Input: If you stop talking to them and don't keep them as a FB <unk> what's to worry about ?
Output: If you stop talking to them and don't keep them as a FB someday what's to worry about ?
BLEU score: 0.8492326635760689

Input: I'm lucky I'm in a field where I can be curious for a living !
Output: I'm lucky I'm in a field where I can be curious for a living !
BLEU score: 1.0

Input: My target demographic is specific but it is very <unk> women who like women
Output: My target demographic is specific but it is very emphatically women who like women
BLEU score: 0.7825422900366437

Input: It can be hard to trust that people <unk> but many really do
Output: It can be hard to trust that people brag but m


