In [None]:
# Automatically reload imported modules before each cell executes, so edits
# to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Set up

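Place this notebook in the `grammar_compositionality/` folder so that the relative paths used below (`data/`, `config/`, `scripts/`, `src/`, `trained_models/`) resolve correctly.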

In [None]:
!pip install -U wandb
!pip install transformers

In [None]:
!pip install pyyaml==5.4.1

In [None]:
# Authenticate the Weights & Biases CLI; credentials are entered at the
# prompt, nothing is stored in the notebook.
!wandb login

# Data setting

In [None]:
import numpy as np
import pickle

# Synthetic LDA topic model: the vocabulary is split into contiguous,
# equal-sized, non-overlapping blocks of word ids, one block per topic.
num_words = 100
num_topics = 10
num_words_per_topic = num_words // num_topics

# topic_term_matrix[t, w] = P(word w | topic t): uniform over topic t's block.
# The mass per word is 1 / num_words_per_topic (rather than
# num_topics / num_words) so that every row sums to 1 even when num_words is
# not an exact multiple of num_topics.
topic_term_matrix = np.zeros((num_topics, num_words))
# topic_model[w] = index of the single topic that can emit word w.
topic_model = {}
for topic in range(num_topics):
    first_word = num_words_per_topic * topic
    last_word = first_word + num_words_per_topic  # exclusive
    topic_term_matrix[topic, first_word:last_word] = 1.0 / num_words_per_topic
    for word in range(first_word, last_word):
        topic_model[word] = topic

# Each row is later used as `p=` for sampling, so it must be a distribution.
assert np.allclose(topic_term_matrix.sum(axis=1), 1.0), "rows must sum to 1"

print('topic_term_matrix')
print(topic_term_matrix)

topic_term_matrix
[[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
  0.1 0.1 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0

In [None]:
# Number of sentences generated for each data split.
num_train_sentences = 10000
num_dev_sentences = 10000
num_test_sentences = 10000
# Bounds on the number of word tokens per sentence (used by the data generator).
sentence_len_min = 100 # values noted by the author: 10, 100 (presumably short vs. long setting - confirm)
sentence_len_max = 150 # values noted by the author: 40, 150 (presumably short vs. long setting - confirm)
# Dirichlet concentration parameters, one entry per topic, for the
# per-sentence topic mixture.
alpha = [0.1] * num_topics

## Only run if you need to re-generate the data

In [None]:
import os

# Persist the word -> topic mapping for later use.
# Create the output directory first: open(..., 'wb') does not create missing
# parent directories and would otherwise raise FileNotFoundError.
os.makedirs("./trained_models", exist_ok=True)
with open(f"./trained_models/topic{num_topics}_word{num_words}.pkl", 'wb') as f:
    pickle.dump(topic_model, f)

In [None]:
import random
from scipy.stats import dirichlet


def write_lda_data(fn, num_sentences, seed=None):
    """Sample sentences from the LDA generative process and write them to `fn`.

    Each sentence is written on its own line as space-separated word ids
    followed by the literal token ``END``.  For every sentence:

    1. its length is drawn uniformly from
       ``[sentence_len_min, sentence_len_max]`` (both ends inclusive);
    2. a topic mixture is drawn from ``Dirichlet(alpha)``;
    3. every token first draws a topic from that mixture and then a word from
       that topic's row of ``topic_term_matrix``.

    Reads the notebook globals ``sentence_len_min``, ``sentence_len_max``,
    ``alpha``, ``num_topics``, ``num_words`` and ``topic_term_matrix``.

    Args:
        fn: Path of the text file to create (truncated if it already exists).
        num_sentences: Number of sentences (lines) to write.
        seed: Optional seed for the random generator.  Passing the same seed
            reproduces the same file; ``None`` (the default) draws fresh
            entropy on every call.
    """
    # One local generator for all draws, so a seed fully determines the file
    # and no global random state is touched.
    rng = np.random.default_rng(seed)
    with open(fn, 'wt') as f:
        for _ in range(num_sentences):
            sentence_len = rng.integers(sentence_len_min, sentence_len_max, endpoint=True)
            topic_distr = rng.dirichlet(alpha)
            # Draw the topic of every token position at once ...
            topics = rng.choice(num_topics, size=sentence_len, p=topic_distr)
            # ... then fill in the words topic by topic (one vectorized draw
            # per distinct topic instead of one scalar draw per token).
            words = np.empty(sentence_len, dtype=int)
            for topic in np.unique(topics):
                positions = topics == topic
                words[positions] = rng.choice(
                    num_words, size=positions.sum(), p=topic_term_matrix[topic])
            f.write("".join(f"{word} " for word in words) + "END\n")


# Use a different seed per split: the splits have equal sentence counts, so a
# shared seed would make them identical.
data_prefix = f"data/topic{num_topics}_word{num_words}_long"
write_lda_data(f"{data_prefix}.train", num_train_sentences, seed=0)
write_lda_data(f"{data_prefix}.dev", num_dev_sentences, seed=1)
write_lda_data(f"{data_prefix}.test", num_test_sentences, seed=2)




# Generate config

In [None]:
def gen_vocab_str(num_words, indent=26):
    """Render the vocabulary as a YAML ``vocab:`` mapping for the config file.

    The four special tokens ``PAD``, ``MASK``, ``START`` and ``END`` receive
    ids 0-3; word ``w`` (for ``w`` in ``range(num_words)``) receives id
    ``w + 4``.

    Args:
        num_words: Number of ordinary (non-special) word tokens.
        indent: Number of spaces placed before every entry line.  It must be
            larger than the column at which the returned string is
            interpolated into the config text, so that the entries parse as
            children of the ``vocab:`` key.  Defaults to 26.

    Returns:
        The string ``vocab:`` followed, for each token, by a newline, the
        indentation, and ``'<token>': <id>``.
    """
    tokens = ['PAD', 'MASK', 'START', 'END'] + list(range(num_words))
    pad = ' ' * indent
    entries = [f"{pad}'{token}': {idx}" for idx, token in enumerate(tokens)]
    return '\n'.join(['vocab:'] + entries)

In [None]:
# Preview the YAML vocab block that gets interpolated into the config text.
print(gen_vocab_str(num_words))

vocab:
                          'PAD': 0
                          'MASK': 1
                          'START': 2
                          'END': 3
                          '0': 4
                          '1': 5
                          '2': 6
                          '3': 7
                          '4': 8
                          '5': 9
                          '6': 10
                          '7': 11
                          '8': 12
                          '9': 13
                          '10': 14
                          '11': 15
                          '12': 16
                          '13': 17
                          '14': 18
                          '15': 19
                          '16': 20
                          '17': 21
                          '18': 22
                          '19': 23
                          '20': 24
                          '21': 25
                          '22': 26
                          '23': 27
                          

In [None]:
import itertools

# Hyper-parameter grid: one YAML config file is written per combination.
hiddenlayers_grid = [1]
num_heads_grid = [1]
optimizer_grid = ['Adam']  # ['Adam', 'SGD']
lr_grid = [0.01]

# Build the corpus paths from num_topics / num_words, exactly as the
# data-generation cell names its output files, so the config always points at
# the data generated for the current settings.
corpus_prefix = f"data/topic{num_topics}_word{num_words}_long"

for hiddenlayers, num_heads, optimizer, lr in itertools.product(
        hiddenlayers_grid, num_heads_grid, optimizer_grid, lr_grid):
    config_text = f"""
            corpus:
                train_corpus_loc: {corpus_prefix}.train
                dev_corpus_loc: {corpus_prefix}.dev
                test_corpus_loc:  {corpus_prefix}.test
            language:
                name: lda
                num_topics: {num_topics}
                num_words: {num_words}
                {gen_vocab_str(num_words)}
                dev_sample_count:  {num_dev_sentences}
                test_sample_count: {num_test_sentences}
                train_sample_count: {num_train_sentences}
            lm:
                embedding_dim: {(num_words + 4)}
                hidden_dim: {(num_words + 4)}
                lm_type: BertForMaskedLMCustom
                residual: False  # TODO whether the self attention has residual connections
                attn_output_fc: False  # TODO whether the self attention output has a fully connected layer
                bert_intermediate: False  # TODO whether the BertLayer has a BertIntermediate (FC) sub-layer
                bert_output: False  # TODO whether the BertLayer has a BertOutput (FC with residual) sub-layer
                bert_head_transform: False  # whether the BertLMPredictionHead has a transform (FC) sub-layer
                layer_norm: False  # whether the model has LayerNorm
                num_layers: {hiddenlayers}
                save_path: lm.params
                num_heads: {num_heads}
                embedding_type: none
                token_embedding_type: one_hot  # trained or one_hot
                freeze_uniform_attention: True  # TODO freeze W^K and W^Q to 0 
                freeze_id_value_matrix: False  # TODO freeze W^V to I
                freeze_block_value_matrix: False  # TODO
                freeze_decoder_to_I: True
                no_softmax: False  # remove the final softmax layer and change the loss to MSELoss
            reporting:
                reporting_loc: ./trained_models/lda_bert_simplified_one_hot/  # TODO
                reporting_methods:
                - constraints
                plot_attention_dir: ./plot_attention/lda_bert_simplified_one_hot/  # TODO
                inspect_results_dir: ./inspect_results/lda_bert_simplified_one_hot/  # TODO
                num_sentences_to_plot: 5
                random: False  # TODO 
                log_all_steps_until: 0  # log all the first several steps to wandb
            training:
                batch_size: 40
                dropout: 0.0
                optimizer: {optimizer}  # Adam or SGD
                learning_rate: {lr}
                weight_decay: 0.0
                max_epochs: 20  # LIKELY TOO LOW, JUST A DEMO
                seed: 0
                objective: default  # default or contrastive or multi
                mask_prob: 0.15  # Should almost always be 0.0 for GPT
                mask_correct_prob: 0.1  # the proportion of "masked" tokens that show the correct token
                mask_random_prob: 0.1  # the proportion of "masked" tokens that show a random token
                zero_init_attn: False  # init W^K, W^Q, W^V to near 0
                zero_init_emb_dec: False  # init embedding and decoder to near 0
                zero_init_noise: 0.0  # noise for `near 0`
            experiment:
                repeat: 1  # number of times to re-train the model
            """
    # NOTE: `optimizer` is not part of the file name, so listing more than one
    # optimizer in the grid makes later configs overwrite earlier ones. The
    # name is kept as-is because the shell scripts build the same pattern.
    with open(f"config/bert_lda_hiddenlayers{hiddenlayers}_heads{num_heads}_lr{lr}_one_hot.yaml", 'wt') as f:
        f.write(config_text)

# Train

In [None]:
# Show the training script that the next cell runs.
!cat scripts/train_lda_bert.sh

#!/bin/bash

for hiddenlayers in 1
do
  for num_heads in 1
  do
    for lr in 0.01
    do
      python3 src/run_lm.py "config/bert_lda_hiddenlayers"$hiddenlayers"_heads"$num_heads"_lr"$lr"_one_hot.yaml"
    done
  done
done


In [None]:
# Run training: the script calls src/run_lm.py once per generated config file.
!sh scripts/train_lda_bert.sh

  args = yaml.load(open(config_file))
Getting dataset from data/topic10_word100_long.train
Getting dataset from data/topic10_word100_long.dev
Getting dataset from data/topic10_word100_long.test
Construct the language model with args {'corpus': {'train_corpus_loc': 'data/topic10_word100_long.train', 'dev_corpus_loc': 'data/topic10_word100_long.dev', 'test_corpus_loc': 'data/topic10_word100_long.test'}, 'language': {'name': 'lda', 'num_topics': 10, 'num_words': 100, 'vocab': {'PAD': 0, 'MASK': 1, 'START': 2, 'END': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, '10': 14, '11': 15, '12': 16, '13': 17, '14': 18, '15': 19, '16': 20, '17': 21, '18': 22, '19': 23, '20': 24, '21': 25, '22': 26, '23': 27, '24': 28, '25': 29, '26': 30, '27': 31, '28': 32, '29': 33, '30': 34, '31': 35, '32': 36, '33': 37, '34': 38, '35': 39, '36': 40, '37': 41, '38': 42, '39': 43, '40': 44, '41': 45, '42': 46, '43': 47, '44': 48, '45': 49, '46': 50, '47': 51, '48': 52, '49'

# Plot attention or compute statistics

In [None]:
# Show the inspection script that is run further down.
!cat scripts/inspect_result_lda_bert.sh

#!/bin/bash

for hiddenlayers in 1
do
  for num_heads in 1
  do
    for lr in 0.01
    do
      python3 src/inspect_result.py "config/bert_lda_hiddenlayers"$hiddenlayers"_heads"$num_heads"_lr"$lr"_one_hot.yaml"
    done
  done
done


In [None]:
# Show the source of the Python entry point invoked by the inspection script.
!cat src/inspect_result.py

from argparse import ArgumentParser
from transformers import BertModel, GPT2Model
from seaborn import heatmap
import gensim
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import random
import torch

from run_lm import create_args, init_lm
import dataset
from topic_model import load_topic_model, get_top_term_topics
import utils


def cut_sentences_at_length(batch, max_sentence_len):
    assert type(batch) is tuple
    new_batch0 = batch[0][:, :max_sentence_len].clone().detach()  # observations
    new_batch1 = batch[1][:, :max_sentence_len].clone().detach()  # labels
    if batch[2] is not None:
        new_batch2 = batch[2][:, :max_sentence_len].clone().detach()  # attention_mask_all
    else:
        new_batch2 = None
    new_batch3 = [max_sentence_len]  # sentence length
    return new_batch0, new_batch1, new_batch2, new_batch3


def prepare_dev_data(dataset, num_sentences_to_plot, max_sentence_len=None):
    """
    Note: one batch contains `arg.training.

In [None]:
# Run the inspection: the script calls src/inspect_result.py once per config file.
!sh scripts/inspect_result_lda_bert.sh

  args = yaml.load(open(config_file))
Getting dataset from data/topic10_word100_long.train
[computing labels]: 100% 10000/10000 [00:00<00:00, 20244.14it/s]
Getting dataset from data/topic10_word100_long.dev
[computing labels]: 100% 10000/10000 [00:00<00:00, 20693.30it/s]
Getting dataset from data/topic10_word100_long.test
[computing labels]: 100% 10000/10000 [00:00<00:00, 13213.52it/s]
config BertConfig {
  "attention_probs_dropout_prob": 0.0,
  "attn_output_fc": false,
  "bert_head_transform": false,
  "bert_intermediate": false,
  "bert_output": false,
  "classifier_dropout": null,
  "freeze_block_value_matrix": false,
  "freeze_decoder_to_I": true,
  "freeze_id_value_matrix": false,
  "freeze_uniform_attention": true,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 104,
  "initializer_range": 0.02,
  "intermediate_size": 104,
  "layer_norm": false,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 6000,
  "model_type": "bert",
  "num_attention_heads": 1,