In [1]:
import numpy as np
from scripts.applications.prediction_models.models.syntax_infused_model  import *
from scripts.applications.prediction_models.models.sequential_model import *
from scripts.applications.prediction_models.models.neural_cky import *
from scripts.language_processing.language_builder.neural_builder.models.tn_pcfg import TNPCFG
from scripts.language_processing.language_builder.neural_builder.models.utils import rebuild_T_from_head_left_right
from scripts.applications.prediction_models.models.word_embedding import *

def generate_random_eeg_data(cnt_channels = 19, cnt_time_samples = 1000):
    """Return a random array of fake EEG amplitudes.

    The result has shape (cnt_channels, cnt_time_samples). Samples are drawn
    uniformly from [0, 1) and scaled by 0.1 * 90, i.e. into roughly [0, 9).
    The original inline comment gives the unit as mV.
    """
    shape = (cnt_channels, cnt_time_samples)
    uniform_noise = np.random.random_sample(shape)
    return 0.1 * uniform_noise * 90

def generate_random_eeg_microstates(cnt_channels = 19, cnt_time_samples = 1000, cnt_states = None):
    """Return a random sequence of microstate labels, one per time sample.

    Args:
        cnt_channels: Unused; kept so existing positional callers keep working.
        cnt_time_samples: Length of the returned 1-D label array.
        cnt_states: Number of distinct labels; values are drawn from
            [0, cnt_states). When None, the notebook-level ``cnt_words`` is
            used, which is what the function always did before this
            parameter existed.

    Returns:
        1-D integer ndarray of length ``cnt_time_samples``.

    Raises:
        NameError: if ``cnt_states`` is None and ``cnt_words`` has not been
            defined yet (it is assigned in the configuration cell).
    """
    if cnt_states is None:
        # Legacy fallback on the notebook global; pass cnt_states explicitly
        # to make the call independent of cell execution order.
        cnt_states = cnt_words
    return np.random.randint(0, cnt_states, cnt_time_samples)

def generate_random_segmented_eeg_data(cnt_channels = 19, cnt_time_samples = 1000):
    """Return random EEG data cut into consecutive segments along the time axis.

    A single split probability ``p`` is drawn uniformly from [0, 1); every
    interior time index (1 .. cnt_time_samples - 1) then independently becomes
    a cut point with probability ``p``.

    Args:
        cnt_channels: Number of rows (channels) in the generated data.
        cnt_time_samples: Number of columns (time samples) before splitting.

    Returns:
        A list of 2-D arrays, each of shape (cnt_channels, segment_length).
        Concatenating them along axis 1 gives back the generated data.
        Amplitudes come from ``generate_random_eeg_data`` (unit noted there
        as mV).
    """
    # Fixes vs. the previous version: the helper is named
    # generate_random_eeg_data (generate_dummy_eeg_data does not exist),
    # np.random is a module and cannot be called, and segments must be column
    # slices data[:, a:b] rather than the single element data[a, b].
    eeg_data = generate_random_eeg_data(cnt_channels, cnt_time_samples)
    split_probability = np.random.random()
    # Only interior indices can be cut points; guard against a negative size.
    cnt_candidates = max(cnt_time_samples - 1, 0)
    is_cut = np.random.random(cnt_candidates) < split_probability
    cut_points = np.flatnonzero(is_cut) + 1
    return np.split(eeg_data, cut_points, axis = 1)

def generate_random_corpus(word_count, cnt_article, article_max_length):
    """Generate a random corpus of word-id sequences.

    Args:
        word_count: Vocabulary size; word ids are drawn from [0, word_count).
        cnt_article: Number of articles to generate.
        article_max_length: Inclusive upper bound on the article length.

    Returns:
        A list of ``cnt_article`` 1-D integer arrays. Each article has a
        length drawn uniformly from [1, article_max_length].

    Raises:
        ValueError: raised by ``np.random.randint`` when
            ``article_max_length`` < 1 or ``word_count`` < 1 (empty range),
            provided at least one article is requested.
    """
    # The length used to be drawn from [0, article_max_length), which produced
    # empty articles and never reached the stated maximum (with a maximum of 1
    # every article was empty). Draw from [1, article_max_length] instead.
    return [
        np.random.randint(0, word_count, size = np.random.randint(1, article_max_length + 1))
        for _ in range(cnt_article)
    ]
    

In [2]:
# Notebook-wide sizes. They are stated once here and reused below so the
# top-level entries and the nested model arguments cannot drift apart.
cnt_channels = 19
cnt_words = 10
word_emb_size = 200

# Central configuration; 'syntax_infused_model_args' is the argument dict
# later handed to the grammar / syntax models.
main_configuration = dict(
    cnt_channels = cnt_channels,
    cnt_words = cnt_words,
    word_emb_size = word_emb_size,
    syntax_infused_model_args = dict(
        NT = 20,
        T = 65,
        s_dim = 50,
        r_dim = 75,
        word_emb_size = word_emb_size,
        cnt_words = cnt_words,
        summary_parameters = True,
    ),
)

In [3]:
# Sequential branch of the combined model; its input width is the number of
# EEG channels.
LNN_electrode_value_based_prediction_model = LNNElectrodeValueBasedPredictionModel(
    ncp_input_size = cnt_channels,
    hidden_size = 100,
    output_size = 1,
    sequence_length = 1,
)

# Combining stage, later passed to SyntaxInfusedModel as `combining_model`.
combine_model = SimpleConcatCombing()

In [4]:
# Word-embedding model over the `cnt_words` vocabulary.
embeding_model = WordEmbeddingModel(
    vocab_size = cnt_words,
    embedding_dim = word_emb_size,
    context_size = 2,
)

# Weight of the model's `embeddings` attribute; passed on to the grammar
# model through the argument dict.
word_embeddings = embeding_model.embeddings.weight

In [5]:
# `args` is the same dict object as the nested entry in main_configuration
# (no copy is made), so every key added to `args` is also visible there.
args = main_configuration['syntax_infused_model_args']
args.update(word_embeddings = word_embeddings)

In [6]:
# Random corpus of 1000 articles with a length bound of 2000 word ids each.
corpus = generate_random_corpus(
    word_count = cnt_words,
    cnt_article = 1000,
    article_max_length = 2000,
)

# The first article is the sentence used for grammar inference below.
sentence_for_inference = corpus[0]

In [7]:
# Build the grammar model and run it on the sampled sentence, shaped as a
# batch containing one sequence.
tn_pcfg = TNPCFG(args = args)
inference = tn_pcfg.forward(
    torch.Tensor(
        np.array(sentence_for_inference).reshape((1, len(corpus[0])))
    )
)

# Re-index the 'unary' scores by word id instead of by sentence position.
# Each row of the model's unary output belongs to one position in the sentence
# and holds one score per terminal symbol (T values) for the word found there.
# Positions that hold the same word are expected to carry identical rows, so
# each word's scores are stored once, as a column of a (T, cnt_words) matrix.
# Words that never occur in the sentence keep an all-zero column.
inference_unary = np.zeros(shape = (args['T'], args['cnt_words']))
original_unary = inference['unary'][0].detach().numpy()
sequence_length = len(original_unary)
for i in range(sequence_length):
    inference_unary[:, sentence_for_inference[i]] = original_unary[i]

def rebuild_T_from_head_left_right(head, left, right, NT, T):
    """Rebuild the dense rule tensor from its three rank-r factor matrices.

    For every factor column r, the Kronecker product of head[:, r],
    left[:, r] and right[:, r] is accumulated; the flat sum is then reshaped
    to (NT, NT + T, NT + T).

    NOTE: this definition shadows the function of the same name imported at
    the top of the notebook.

    Args:
        head: torch tensor with NT rows and one column per factor.
        left: torch tensor with NT + T rows and the same number of columns.
        right: torch tensor with NT + T rows and the same number of columns.
        NT: size of the first axis of the result.
        T: together with NT gives the size (NT + T) of the other two axes.

    Returns:
        float64 ndarray of shape (NT, NT + T, NT + T).
    """
    # Convert each factor matrix once instead of once per column.
    head_np = head.detach().numpy()
    left_np = left.detach().numpy()
    right_np = right.detach().numpy()

    symbol_count = NT + T
    flat_rules = np.zeros((NT * symbol_count * symbol_count))
    for rank_idx in range(head.shape[1]):
        head_left = np.kron(head_np[:, rank_idx], left_np[:, rank_idx])
        flat_rules += np.kron(head_left, right_np[:, rank_idx])
    return flat_rules.reshape((NT, symbol_count, symbol_count))

# Dense rule tensor for the first (only) batch element of the inference.
T = rebuild_T_from_head_left_right(
    inference['head'][0],
    inference['left'][0],
    inference['right'][0],
    args['NT'],
    args['T'],
)

build td pcfg
device = cpu, NT = 20, T = 65, V = 10, s_dim = 50, r = 75, word_emb_size = 200
begin forward>>>>>>>>>>>>>, input shape = torch.Size([1, 789])
b, n = 1, 789
torch.Size([65, 10])
torch.Size([1, 789, 65]) torch.Size([1, 20]) torch.Size([1, 20, 75]) torch.Size([1, 85, 75]) torch.Size([1, 85, 75])


In [8]:
# Hand the inferred grammar (start scores, per-word unary scores, rule tensor)
# plus the beam-search strategy and the two feature-combining models to the
# neural CYK model through the shared argument dict.
args.update({
    'grammar_starts': inference['root'].detach().numpy()[0],
    'grammar_preterminates': inference_unary,
    'grammar_double_nonterminates': T,
    'beam_search_strategy': select_tops,
    'preterminate_feature_generation_model': NN_CYK_FeatureCombingModel_Preterminate(),
    'merge_model': NN_CYK_FeatureCombingModel_NonPreterminate(),
})
nn_cyk_model = NN_CYK_Model(args)

In [9]:
# NOTE(review): 228 is a magic number; presumably it has to equal the width of
# the combined feature vector fed to the prediction head — confirm and derive
# it from the configuration if so.
simple_full_connection_prediction_model = FCPrediction(
    input_size = 228,
)

In [10]:
# Wire the four stages together: sequential branch, syntax branch, combiner
# and prediction head.
syntax_infused_model = SyntaxInfusedModel(
    sequential_model = LNN_electrode_value_based_prediction_model,
    syntax_model = nn_cyk_model,
    combining_model = combine_model,
    prediction_model = simple_full_connection_prediction_model,
)

Fix base sequential prediction model...


In [11]:
from scripts.applications.prediction_models.models.train.train import ModelTrainer

# Trainer for the combined model: Adam (lr 0.001), batch size 1, 10 epochs,
# no gradient clipping, no automatic checkpointing.
# NOTE(review): `optim` is not imported explicitly in this notebook;
# presumably it arrives through one of the star imports — confirm.
trainer = ModelTrainer(dict(
    model = syntax_infused_model,
    optimizer = optim.Adam(syntax_infused_model.parameters(), lr = 0.001),
    train_arg = dict(
        shuffle_trainning_set = True,
        train_only = True,
        batch_size_train = 1,
        is_unsupervisor_learning = False,
        clip_gradient = False,
        clip = 0,
        is_print_loss_per_batch = False,
        max_epoches = 10,
    ),
    save_model_automatically = False,
))

In [20]:
# Transposed so that rows are time samples and columns are channels.
random_eeg_data = np.transpose(generate_random_eeg_data(cnt_channels = cnt_channels))

# NOTE(review): after the transpose, shape[1] is the channel count, so this
# draws one binary label per channel, while `random_words` below has one entry
# per time sample (shape[0]). Presumably the labels should follow shape[0]
# as well — confirm against what the trainer expects.
random_labels = np.random.randint(low = 0, high = 2, size = random_eeg_data.shape[1])

# One random word id per time sample.
random_words = np.random.randint(low = 0, high = cnt_words, size = random_eeg_data.shape[0])

In [21]:
# Train on a single sample (EEG matrix + word sequence) with its label tensor.
# NOTE(review): the saved output of this cell is `TypeError: not a sequence`;
# the cause lies in how the trainer consumes these inputs and is not visible
# from this notebook.
trainer.train(
    [{'eeg_data': random_eeg_data, 'words': random_words}],
    [torch.Tensor(random_labels)],
)

TypeError: not a sequence

In [40]:
# Experiment: wrap a list of two identical sample dicts in a StackDataset
# (the list is passed as its single positional argument) and batch both
# samples together.
sd = torch.utils.data.DataLoader(
    torch.utils.data.StackDataset(
        [{'eeg_data': random_eeg_data, 'words': random_words} for _ in range(2)]
    ),
    batch_size = 2,
)

In [48]:
# Iterate over the DataLoader and print, for every batch, the length of the
# 'words' entry found in the first element of the batch. The saved output is
# a single line, `2`.
for s in sd:
    print(len(s[0]['words']))

2
