In [1]:
from allennlp.data.tokenizers import Token, Tokenizer, PretrainedTransformerTokenizer

import nltk
import numpy as np
from os import listdir
from os.path import join as pathjoin
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
import tqdm

from minGPT.mingpt.model import GPT, GPTConfig
from minGPT.mingpt.trainer import Trainer, TrainerConfig
# Sampling helper and seeding utility from minGPT.
from minGPT.mingpt.utils import sample, set_seed
# Make runs deterministic: seed through minGPT's helper (presumably covers the
# Python/NumPy/torch RNGs -- confirm in minGPT.mingpt.utils) and also seed NumPy's
# global RNG explicitly, since np.random.choice is used below to pick first words.
set_seed(42)
np.random.seed(42)

In [2]:
import math
from torch.utils.data import Dataset


def detokenize(tokens):
    """Turn a token sequence back into a plain string.

    The first and last elements are dropped (presumably sentence-boundary
    special tokens -- confirm against the tokenizer), the remaining tokens are
    stringified and joined with single spaces, and every ' ##' continuation
    marker is removed so word pieces are glued back onto the preceding piece.
    """
    inner_tokens = tokens[1:-1]
    spaced_text = ' '.join(map(str, inner_tokens))
    return spaced_text.replace(' ##', '')

class BPEDataset(Dataset):
    """Sliding-window next-token dataset over a sequence of hashable tokens.

    The vocabulary is the sorted set of distinct items in ``data``; each item is
    mapped to its rank in that ordering. Item ``idx`` of the dataset is the pair
    ``(x, y)`` where ``x`` encodes ``data[idx : idx + block_size]`` and ``y`` is
    the same window shifted one position to the right.

    Attributes:
        stoi: dict mapping token -> integer id.
        itos: dict mapping integer id -> token.
        block_size: window length of each ``x`` / ``y`` tensor.
        vocab_size: number of distinct tokens in ``data``.
        data: the token sequence passed to the constructor (kept by reference).
    """

    def __init__(self, data, block_size):
        """Build the vocabulary from ``data`` and remember the window size.

        Args:
            data: sliceable sequence of hashable, mutually orderable tokens.
            block_size: number of tokens in each input window.
        """
        # sorted() accepts any iterable, so no intermediate list() is needed.
        chars = sorted(set(data))
        data_size, vocab_size = len(data), len(chars)
        # NOTE: the message says "characters" but the items are whatever `data` holds.
        print('data has %d characters, %d unique.' % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        """Number of full (block_size + 1)-token windows available.

        Clamped at 0: a negative value would make ``len(dataset)`` raise
        ``ValueError`` when ``data`` is shorter than ``block_size``.
        """
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` LongTensor pair for the window starting at ``idx``."""
        # grab a chunk of (block_size + 1) tokens from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every token to an integer id
        dix = [self.stoi[s] for s in chunk]

        # y is x shifted by one: the target at each position is the next token
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

In [None]:
def generate_topic_dataset(train_text_file, state_dict_file, n_layer=8, n_head=8, n_embd=512,
                           texts_count=10, text_len=500):
    """Yield ``(generated_text, topic)`` pairs sampled from a trained GPT model.

    This is a generator function: no file is read and no model is loaded until
    the result is iterated.

    NOTE: the body relies on three names that must already exist in the notebook
    namespace and are not defined in the visible cells:
    ``tokenizer`` (object with a ``tokenize(str)`` method), ``block_size``
    (context length handed to ``BPEDataset``) and ``topics`` (mapping of
    topic -> iterable of keyword strings).

    Args:
        train_text_file: path of the text file used to rebuild the vocabulary.
            The vocabulary size derived from it is passed to ``GPTConfig``, so
            it presumably has to be the text the checkpoint was trained on.
        state_dict_file: path of the saved model ``state_dict``.
        n_layer, n_head, n_embd: model hyper-parameters forwarded to ``GPTConfig``.
        texts_count: number of texts to generate per topic.
        text_len: number of sampling steps passed to ``sample`` for each text.

    Yields:
        Tuples ``(completion, topic)`` where ``completion`` is the sampled token
        sequence joined with spaces and with ' ##' continuation markers removed.

    Raises:
        ValueError: if texts are requested for a topic none of whose keywords
            occur in the training vocabulary.
    """
    # Close the file deterministically instead of leaking the handle.
    with open(train_text_file, 'r') as text_file:
        text_sentences = nltk.tokenize.sent_tokenize(text_file.read())
    # [1:-1] drops the first and last token of every sentence (presumably the
    # tokenizer's boundary special tokens -- confirm). A flat comprehension
    # avoids routing token objects through NumPy and copes with zero sentences.
    tokens = [
        str(token)
        for sent in text_sentences
        for token in tokenizer.tokenize(sent)[1:-1]
    ]
    train_dataset = BPEDataset(tokens, block_size)
    tokens_set = set(train_dataset.stoi.keys())
    print("dataset is loaded")

    mconf = GPTConfig(
        train_dataset.vocab_size, train_dataset.block_size,
        n_layer=n_layer, n_head=n_head, n_embd=n_embd
    )
    model = GPT(mconf)
    # Load onto CPU first so a checkpoint saved from a GPU also loads on a
    # CPU-only machine; load_state_dict copies the values into the model.
    model.load_state_dict(torch.load(state_dict_file, map_location='cpu'))
    print("model is loaded")

    # The trainer is only used here for its `device` attribute (presumably it
    # also moves the model to that device -- confirm in minGPT's Trainer).
    tconf = TrainerConfig(num_workers=1)
    trainer = Trainer(model, train_dataset, None, tconf)

    for topic, topic_keywords in topics.items():
        # Sort: set iteration order of strings varies with the interpreter's hash
        # seed, which would make np.random.choice irreproducible despite seeding.
        first_word_candidates = sorted(set(topic_keywords) & tokens_set)
        if texts_count > 0 and not first_word_candidates:
            raise ValueError(
                "no keyword of topic %r occurs in the training vocabulary" % (topic,)
            )

        for text_id in range(texts_count):
            # Seed the generation with one random in-vocabulary topic keyword.
            context = [np.random.choice(first_word_candidates)]
            x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
            y = sample(model, x, text_len, temperature=1.0, sample=True, top_k=10)[0]
            completion = ' '.join([train_dataset.itos[int(i)] for i in y]).replace(' ##', '')
            yield completion, topic

In [None]:
def read_classifier():
    """Placeholder with no implementation yet; always returns ``None``.

    TODO: presumably meant to load the classifier -- implement.
    """
    return None

def read_generator(generator_dir):
    """Placeholder with no implementation yet; always returns ``None``.

    ``generator_dir`` is accepted but currently unused.
    TODO: presumably meant to load a generator from that directory -- implement.
    """
    return None

In [None]:
def get_generator_prob_loss():
    """Placeholder with no implementation yet; always returns ``None``.

    TODO: presumably meant to compute a probability-based generator loss -- implement.
    """
    return None

def get_generator_full_loss():
    """Placeholder with no implementation yet; always returns ``None``.

    TODO: presumably meant to compute the generator's full loss -- implement.
    """
    return None