In [49]:
import csv
from docx import Document
from textstat import textstat
from torch import nn
from transformers import XLNetTokenizerFast, XLNetModel, BertTokenizer, BertForMaskedLM
import numpy as np
import spacy
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from nltk.tokenize import sent_tokenize
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.classification import MultilabelPrecision, MultilabelRecall, MultilabelAccuracy, MultilabelF1Score, MultilabelHammingDistance, Accuracy, Precision, Recall, F1Score, BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
import sklearn
from keybert import KeyBERT
import time
from tqdm import tqdm

In [130]:
nlp = spacy.load("en_core_web_lg")
xlnet_tokenizer = XLNetTokenizerFast.from_pretrained("xlnet-base-cased")
xlnet_model = XLNetModel.from_pretrained("xlnet-base-cased", output_attentions=True)
xlnet_model.eval()
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
bert_model.eval()
kw_model = KeyBERT()

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetModel: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This 

In [67]:
# a = torch.rand((1,2,3))
# a.view

# from torch.nn.modules.container import ModuleList
# import copy
#
# class WordSentenceSplitTransformerEncoder(nn.TransformerEncoder):
#
#     def __init__(self, encoder_layers, num_layers, norm=None, enable_nested_tensor=False):
#         super(nn.TransformerEncoder, self).__init__()
#         self.layers = ModuleList(encoder_layers)
#         self.num_layers = num_layers
#         self.norm = norm
#         self.enable_nested_tensor = enable_nested_tensor

# def hook(module, fea_in, fea_out):
#     print(fea_in[0].shape)
#     print(module.get_parameter())
#     print('hooking!')
#
# model = ClassificationModel(24)
# tgt_name = 'transformer_encoder.layers.1.self_attn'
# for name, module in model.named_modules():
#     if name == tgt_name:
#         module.register_forward_hook(hook=hook)

<function Tensor.view>

In [157]:
class ClassificationModel(torch.nn.Module):

    def __init__(self, transformer_in_feature_size=768, transformer_out_feature_size=64, difficulty_feature_size=3, topic_feature_size=1, eye_feature_size=12, head_num=10, layer_count=2):
        super(ClassificationModel, self).__init__()

        feature_size = transformer_out_feature_size + difficulty_feature_size + topic_feature_size + eye_feature_size
        self.transformer_linear = nn.Linear(in_features=transformer_in_feature_size, out_features=transformer_out_feature_size)
        self.normalize = nn.LayerNorm
        self.transformer_encoder_ahead_layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=feature_size, nhead=head_num, batch_first=True),
            num_layers=layer_count - 1,
        )
        self.transformer_encoder_word_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=head_num, batch_first=True)
        self.transformer_encoder_sent_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=head_num, batch_first=True)

        # self.mask = torch.tensor([])

        self.word_normal_linear = nn.Linear(in_features=feature_size, out_features=1)
        self.sent_normal_linear = nn.Linear(in_features=feature_size, out_features=1)
        self.word_normal_activation = nn.Sigmoid()
        self.sent_normal_activation = nn.Sigmoid()

        self.word_multilabel_linear = nn.Linear(in_features=feature_size, out_features=1)
        self.sent_multilabel_linear = nn.Linear(in_features=feature_size, out_features=2)
        self.word_multilabel_activation = nn.Sigmoid()
        self.sent_multilabel_activation = nn.Sigmoid()

    def forward(self, transformer_features, difficulty_features, topic_features, eye_features, sent_ranges):
        current_device = next(self.parameters()).device
        compressed_transformer_features = self.transformer_linear(transformer_features)
        word_features = torch.concat((compressed_transformer_features, eye_features, difficulty_features, topic_features), dim=-1)
        word_features = self.transformer_encoder_ahead_layers(word_features)

        word_count = transformer_features.shape[-2]

        word_mask = torch.concat([torch.zeros((8, word_count, word_count)), torch.ones((2, word_count, word_count))]).to(current_device)
        sent_mask = torch.concat([torch.ones((2, word_count, word_count)), torch.zeros((8, word_count, word_count))]).to(current_device)
        if len(transformer_features.shape) > 2:
            batch_size = transformer_features.shape[0]
            word_mask = word_mask.repeat((batch_size, 1, 1))
            sent_mask = sent_mask.repeat((batch_size, 1, 1))
            word_level_encoded_features = self.transformer_encoder_word_layer(word_features, src_mask=word_mask)
            sent_level_encoded_features = self.transformer_encoder_sent_layer(word_features, src_mask=sent_mask)
            sent_level_encoded_features = torch.concat(
                [torch.mean(sent_level_encoded_features[ : , start_id : end_id + 1, : ], dim=-2, keepdim=True) for [start_id, end_id] in sent_ranges],
                dim=-2,
            )
            # print(sent_level_encoded_features.shape)
        else:
            word_level_encoded_features = self.transformer_encoder_word_layer(word_features, src_mask=word_mask)
            sent_level_encoded_features = self.transformer_encoder_sent_layer(word_features, src_mask=sent_mask)
            sent_level_encoded_features = torch.concat(
                [torch.mean(sent_level_encoded_features[start_id : end_id + 1, : ], dim=-2, keepdim=True) for [start_id, end_id] in sent_ranges],
                dim=-2,
            )

        word_normal_features = self.word_normal_linear(word_level_encoded_features)
        sent_normal_features = self.sent_normal_linear(sent_level_encoded_features)
        word_normal_classifications = self.word_normal_activation(word_normal_features)
        sent_normal_classifications = self.sent_normal_activation(sent_normal_features)

        word_multilabel_features = self.word_multilabel_linear(word_level_encoded_features)
        sent_multilabel_features = self.sent_multilabel_linear(sent_level_encoded_features)
        word_multilabel_classifications = self.word_multilabel_activation(word_multilabel_features)
        sent_multilabel_classifications = self.sent_multilabel_activation(sent_multilabel_features)

        return word_normal_classifications, sent_normal_classifications, word_multilabel_classifications, sent_multilabel_classifications

# just for test
BATCH_NUM = 32
WORD_COUNT = 538
TRANSFORMER_IN_FEAT_SIZE = 768
TRANSFORMER_OUT_FEAT_SIZE = 64
EYE_FEAT_SIZE = 12
DIFFICULTY_FEAT_SIZE = 3
TOPIC_FEAT_SIZE = 1

test_model = ClassificationModel(TRANSFORMER_IN_FEAT_SIZE, TRANSFORMER_OUT_FEAT_SIZE, DIFFICULTY_FEAT_SIZE, TOPIC_FEAT_SIZE, EYE_FEAT_SIZE)
transformer_feats = torch.rand(BATCH_NUM, WORD_COUNT, TRANSFORMER_IN_FEAT_SIZE)
eye_feats = torch.rand(BATCH_NUM, WORD_COUNT, EYE_FEAT_SIZE)
difficulty_feats = torch.rand(BATCH_NUM, WORD_COUNT, DIFFICULTY_FEAT_SIZE)
topic_feats = torch.rand(BATCH_NUM, WORD_COUNT, TOPIC_FEAT_SIZE)
sent_rs = [(0, 99), (100, 199), (200, 299), (300, 399), (400, 499), (500, 537)]
word_normal_cls, sent_normal_cls, word_multilabel_cls, sent_multilabel_cls = test_model(transformer_feats, difficulty_feats, topic_feats, eye_feats, sent_rs)
word_normal_cls.shape, sent_normal_cls.shape, word_multilabel_cls.shape, sent_multilabel_cls.shape

(torch.Size([32, 538, 1]),
 torch.Size([32, 6, 1]),
 torch.Size([32, 538, 1]),
 torch.Size([32, 6, 2]))

In [None]:
def get_docx_text(docx_path):
    document = Document(docx_path)
    read_text = ''
    for pa in document.paragraphs:
        read_text += pa.text + ' '
    return read_text


class ReadingArticle:

    all_difficulty_values = [[], [], []]

    word_fam_map = {}
    with open('/home/wtpan/memx4edu-code/mrc2.dct', 'r') as fp:
        i = 0
        for line in fp:
            line = line.strip()

            word, phon, dphon, stress = line[51:].split('|')

            w = {
                'wid': i,
                'nlet': int(line[0:2]),
                'nphon': int(line[2:4]),
                'nsyl': int(line[4]),
                'kf_freq': int(line[5:10]),
                'kf_ncats': int(line[10:12]),
                'kf_nsamp': int(line[12:15]),
                'tl_freq': int(line[15:21]),
                'brown_freq': int(line[21:25]),
                'fam': int(line[25:28]),
                'conc': int(line[28:31]),
                'imag': int(line[31:34]),
                'meanc': int(line[34:37]),
                'meanp': int(line[37:40]),
                'aoa': int(line[40:43]),
                'tq2': line[43],
                'wtype': line[44],
                'pdwtype': line[45],
                'alphasyl': line[46],
                'status': line[47],
                'var': line[48],
                'cap': line[49],
                'irreg': line[50],
                'word': word,
                'phon': phon,
                'dphon': dphon,
                'stress': stress
            }
            if word not in word_fam_map:
                word_fam_map[word] = w['fam']
            word_fam_map[word] = max(word_fam_map[word], w['fam'])
            i += 1

    def __init__(self, article_id, article_text):
        self.id = article_id
        self.text = article_text
        self._spacy_doc = nlp(article_text)
        self._word_list = [token.text for token in self._spacy_doc]
        self._transformer_features = self._generate_word_embedding()
        self._difficulty_features = self._generate_word_difficulty()
        self._topic_features = self._generate_topic_rate(top_n=5)
        self._sentence_word_mapping = self._generate_sentence_mapping()

    def _generate_word_embedding(self):
        inputs = xlnet_tokenizer(self.text, return_tensors='pt')
        word_token_mapping = self.generate_token_mapping(self._word_list, inputs.tokens())
        outputs = xlnet_model(**inputs)
        token_embeddings = outputs[0].squeeze()
        word_embeddings = []
        for start_id, end_id in word_token_mapping:
            if start_id <= end_id:
                word_embeddings.append(torch.mean(token_embeddings[start_id : end_id + 1, :], 0, dtype=torch.float32))
            else:
                word_embeddings.append(torch.zeros((token_embeddings.shape[1],), dtype=torch.float32))
        return word_embeddings

    def _generate_sentence_mapping(self):
        sentences = sent_tokenize(self.text)
        sentence_word_mapping = self.generate_token_mapping(sentences, self._word_list)
        return sentence_word_mapping

    @staticmethod
    def generate_token_mapping(string_list, token_list):
        string_pos = 0
        string_idx = 0
        token_string_idx_list = []
        max_cross_count = 3
        for token_idx, token in enumerate(token_list):
            original_token = token.replace('▁', '')
            flag = False
            while string_idx < len(string_list) and string_list[string_idx][string_pos : string_pos + len(original_token)] != original_token:
                string_pos += 1
                if string_pos >= len(string_list[string_idx]):
                    cross_count = 1
                    prefix = string_list[string_idx]
                    pre_length = len(string_list[string_idx])
                    while cross_count <= max_cross_count and string_idx + cross_count < len(string_list):
                        prefix += string_list[string_idx + cross_count]
                        new_string_pos = 0
                        while new_string_pos + len(original_token) <= len(prefix) and new_string_pos < len(string_list[string_idx]):
                            if prefix[new_string_pos : new_string_pos + len(original_token)] == original_token and new_string_pos + len(original_token) > len(string_list[string_idx]):
                                string_pos = new_string_pos + len(original_token) - pre_length
                                flag = True
                                break
                            new_string_pos += 1
                        if flag:
                            break
                        pre_length += len(string_list[string_idx + cross_count])
                        cross_count += 1
                    if flag:
                        for delta_idx in range(cross_count + 1):
                            token_string_idx_list.append((token_idx, string_idx + delta_idx))
                        string_idx += cross_count
                        if string_idx < len(string_list) and string_pos == len(string_list[string_idx]):
                            string_pos = 0
                            string_idx += 1
                        break
                    else:
                        string_pos = 0
                        string_idx += 1
            if flag:
                continue
            if string_idx < len(string_list) and string_pos == len(string_list[string_idx]):
                string_pos = 0
                string_idx += 1
            if string_idx >= len(string_list):
                continue
            token_string_idx_list.append((token_idx, string_idx))
            string_pos += len(original_token)

        # for token_idx, string_idx in token_string_idx_list:
        #     print(inputs.tokens()[token_idx], string_list[string_idx])

        string_token_mapping = [(float('inf'), 0)] * len(string_list)
        for token_idx, string_idx in token_string_idx_list:
            string_token_mapping[string_idx] = (min(string_token_mapping[string_idx][0], token_idx), max(string_token_mapping[string_idx][1], token_idx))

        return string_token_mapping

    @staticmethod
    def get_word_familiar_rate(word_text):
        capital_word = word_text.upper()
        return ReadingArticle.word_fam_map.get(capital_word, 0)

    def _generate_word_difficulty(self):
        word_difficulties = []
        for token in self._spacy_doc:
            if token.is_alpha and not token.is_stop:
                fam = self.get_word_familiar_rate(token.text)
                if fam == 0:
                    fam = self.get_word_familiar_rate(token.lemma_)
                syllable = textstat.syllable_count(token.text)
                length = len(token.text)
                score_tensor = torch.tensor([
                    # float(textstat.syllable_count(token.text) > 2),
                    # float(len(token.text) > 7),
                    # float(fam < 482),
                    float(syllable),
                    float(length),
                    float(fam)
                ])
                ReadingArticle.all_difficulty_values[0].append(syllable)
                ReadingArticle.all_difficulty_values[1].append(length)
                ReadingArticle.all_difficulty_values[2].append(fam)
            else:
                score_tensor = torch.zeros((3,))
            word_difficulties.append(score_tensor)
        return word_difficulties

    def _generate_topic_rate(self, top_n):
        keywords = kw_model.extract_keywords(self.text, keyphrase_ngram_range=(1,1), stop_words=None, top_n=top_n)
        keywords = {k: v for k, v in keywords}
        # print(keywords)
        topic_rate = []
        for token in self._spacy_doc:
            topic_rate.append(torch.tensor([keywords.get(token.text, 0.)]))
        return topic_rate

    def get_word_filter_id_set(self, only_alpha=True, filter_digit=True, filter_punctuation=True, filter_stop_words=False):
        word_filter_id_set = set()
        for word in self._spacy_doc:
            if only_alpha and not word.is_alpha:
                word_filter_id_set.add(word.i)
            if filter_digit and word.is_digit:
                word_filter_id_set.add(word.i)
            if filter_punctuation and word.is_punct:
                word_filter_id_set.add(word.i)
            if filter_stop_words and word.is_stop:
                word_filter_id_set.add(word.i)
        return word_filter_id_set

    def get_word_list(self):
        return self._word_list

    def get_transformer_features(self):
        return self._transformer_features

    def get_difficulty_features(self):
        values = []
        mean_values = [np.mean(values) for values in ReadingArticle.all_difficulty_values]
        std_values = [np.std(values) for values in ReadingArticle.all_difficulty_values]
        for feature in self._difficulty_features:
            values.append(torch.tensor([(feature[column] - mean_values[column]) / std_values[column] for column in range(3)], dtype=torch.float32))
        return values
        # return self._difficulty_features

    def get_topic_features(self):
        return self._topic_features

    def get_sentence_word_mapping(self):
        return self._sentence_word_mapping


class ReadingExperiment:

    all_values = {
        'word_understand': [],
        'reading_times': [],
        'number_of_fixations': [],
        'second_pass_dwell_time_of_sentence': [],
        'total_dwell_time_of_sentence': [],
        'reading_times_of_sentence': [],
        'saccade_times_of_sentence': [],
        'forward_times_of_sentence': [],
        'backward_times_of_sentence': [],
        # 'saccade_times_of_para': [],
        # 'forward_saccade_times_of_para': [],
        # 'backward_saccade_times_of_para': [],
    }

    mean_values = {}
    std_values = {}

    def __init__(self, experiment_id, user, article_id, timestamp):
        self.id = experiment_id
        self.user = user
        self.article_id = article_id
        self.timestamp = timestamp
        self.reading_records = []
        self.default_record = {
            'word': '',
            'word_understand': 0.,
            'word_watching': 0.,
            'sentence_understand': 0.,
            'mind_wandering': 0.,
            'reading_times': 0.,
            'number_of_fixations': 0.,
            'second_pass_dwell_time_of_sentence': 0.,
            'total_dwell_time_of_sentence': 0.,
            'reading_times_of_sentence': 0.,
            'saccade_times_of_sentence': 0.,
            'forward_times_of_sentence': 0.,
            'backward_times_of_sentence': 0.,
            # 'saccade_times_of_para': 0.,
            # 'forward_saccade_times_of_para': 0.,
            # 'backward_saccade_times_of_para': 0.,
        }
        self.article_record_word_id_map = {}
        self.record_article_word_id_map = {}

    def add_reading_record(self, **kwargs):
        reading_record = self.default_record.copy()
        reading_record.update(**kwargs)
        for key in ReadingExperiment.all_values:
            ReadingExperiment.all_values[key].append(reading_record[key])
        self.reading_records.append(reading_record)

    def get_word_list(self):
        return [record['word'] for record in self.reading_records]

    def get_normalized_features(self, columns):
        values = []
        if not ReadingExperiment.mean_values or not ReadingExperiment.std_values:
            # print('haha first time')
            ReadingExperiment.mean_values = {key: np.mean(values) for key, values in ReadingExperiment.all_values.items()}
            ReadingExperiment.std_values = {key: np.std(values) for key, values in ReadingExperiment.all_values.items()}
        for record in self.reading_records:
            values.append(torch.tensor([(record[column] - ReadingExperiment.mean_values[column]) / ReadingExperiment.std_values[column] if ReadingExperiment.std_values[column] != 0. else 0. for column in columns], dtype=torch.float32))
        return values

    def get_values(self, columns, reverse=False):
        values = []
        if reverse:
            for record in self.reading_records:
                values.append([1 - record[column] for column in columns])
        else:
            for record in self.reading_records:
                values.append([record[column] for column in columns])
        return values


class ReadingDataset(Dataset):

    def __init__(self, reading_articles_path, reading_experiments_path, eye_feature_names, eye_feature_size, word_label_names, sent_label_names, word_watching_label_name, article_word_mismatch_thr, record_word_mismatch_thr):
        self.eye_feature_names = eye_feature_names
        self.eye_feature_size = eye_feature_size
        if eye_feature_size < len(eye_feature_names):
            raise Exception
        self.word_label_names = word_label_names
        self.sent_label_names = sent_label_names
        self.word_watching_label_name = word_watching_label_name
        self.article_word_mismatch_thr = article_word_mismatch_thr
        self.record_word_mismatch_thr = record_word_mismatch_thr

        self.article_record_word_ids_mapping = {}

        self.article_id_map = {}
        with open(reading_articles_path) as fp:
            cr = csv.reader(fp)
            for row in cr:
                article_id = int(row[0])
                article_text = row[1]
                article = ReadingArticle(article_id, article_text)
                self.article_id_map[article_id] = article

        self.experiment_id_timestamp_map = {}
        df = pd.read_csv(reading_experiments_path)
        all_records = df.to_dict('records')
        for record in all_records:
            experiment_id = record['experiment_id']
            user = record['user']
            article_id = record['article_id']
            timestamp = record['time']
            # word_watching = record['word_watching']
            # timestamp = -1
            if experiment_id not in self.experiment_id_timestamp_map:
                self.experiment_id_timestamp_map[experiment_id] = {}
            if timestamp not in self.experiment_id_timestamp_map[experiment_id]:
                self.experiment_id_timestamp_map[experiment_id][timestamp] = ReadingExperiment(experiment_id, user, article_id, timestamp)
            self.experiment_id_timestamp_map[experiment_id][timestamp].add_reading_record(**record)
        # print(self.experiment_id_timestamp_map)
        self.experiments = [v for vs in self.experiment_id_timestamp_map.values() for v in vs.values()]
        # print(self.experiments)

        ti = time.time()
        self.experiment_data = []
        self.experiment_lengths = []
        for experiment in tqdm(self.experiments):
            eye_features = experiment.get_normalized_features(self.eye_feature_names)
            word_multilabel_single_labels = experiment.get_values(self.word_label_names, reverse=True)
            sent_multilabel_single_labels = experiment.get_values(self.sent_label_names, reverse=True)
            word_watching_situations = experiment.get_values([self.word_watching_label_name])
            article = self.article_id_map[experiment.article_id]
            transformer_features = article.get_transformer_features()
            difficulty_features = article.get_difficulty_features()
            topic_features = article.get_topic_features()
            word_filter_id_set = article.get_word_filter_id_set()
            sentence_word_mapping = article.get_sentence_word_mapping()
            article_word_list = article.get_word_list()
            article_record_word_ids = self._match_article_record_word_list(article, experiment)

            filtered_transformer_features = []
            filtered_difficulty_features = []
            filtered_topic_features = []
            filtered_eye_features = []

            filtered_word_watching_situations = []
            filtered_sent_watching_situations = [False]

            filtered_word_normal_labels = []
            filtered_sent_normal_labels = [True]

            filtered_word_multilabel_labels = []
            filtered_sent_multilabel_labels = [[False, False]]

            filtered_sent_word_mapping = [(0, 0)]

            filtered_word_list = []

            current_sentence_id = 0
            current_word_id = 0
            for article_word_id, record_word_id in article_record_word_ids:
                if article_word_id not in word_filter_id_set:
                    filtered_transformer_features.append(transformer_features[article_word_id])
                    filtered_difficulty_features.append(difficulty_features[article_word_id])
                    filtered_topic_features.append(topic_features[article_word_id])
                    filtered_eye_features.append(torch.concat([eye_features[record_word_id], torch.zeros(self.eye_feature_size - len(self.eye_feature_names))]))

                    filtered_word_watching_situations.append(word_watching_situations[record_word_id][0])

                    filtered_word_normal_labels.append(sum(word_multilabel_single_labels[record_word_id]) == 0)

                    filtered_word_multilabel_labels.append(word_multilabel_single_labels[record_word_id])

                    filtered_word_list.append(article_word_list[article_word_id])

                    if article_word_id > sentence_word_mapping[current_sentence_id][1]:
                        while article_word_id > sentence_word_mapping[current_sentence_id][1]:
                            current_sentence_id += 1
                        filtered_sent_word_mapping.append((current_word_id, current_word_id))
                        filtered_sent_watching_situations.append(word_watching_situations[record_word_id][0])
                        filtered_sent_normal_labels.append(sum(sent_multilabel_single_labels[record_word_id]) == 0)
                        filtered_sent_multilabel_labels.append(torch.tensor(sent_multilabel_single_labels[record_word_id]))
                    else:
                        filtered_sent_word_mapping[-1] = (filtered_sent_word_mapping[-1][0], current_word_id)
                        filtered_sent_watching_situations[-1] |= word_watching_situations[record_word_id][0]
                        filtered_sent_normal_labels[-1] &= sum(sent_multilabel_single_labels[record_word_id]) == 0
                        filtered_sent_multilabel_labels[-1] = tuple(x | y for x, y in zip(filtered_sent_multilabel_labels[-1], sent_multilabel_single_labels[record_word_id]))
                    current_word_id += 1

            # filtered_word_watching_situations = [torch.tensor(word_watching_situation, dtype=torch.float) for word_watching_situation in filtered_word_watching_situations]
            # filtered_sent_watching_situations = [torch.tensor(sent_watching_situation, dtype=torch.float) for sent_watching_situation in filtered_sent_watching_situations]
            filtered_word_normal_labels = [torch.tensor([1.] if word_normal_label else [0.]) for word_normal_label in filtered_word_normal_labels]
            filtered_sent_normal_labels = [torch.tensor([1.] if sent_normal_label else [0.]) for sent_normal_label in filtered_sent_normal_labels]
            filtered_word_multilabel_labels = [torch.tensor(word_multilabel_label, dtype=torch.float) for word_multilabel_label in filtered_word_multilabel_labels]
            filtered_sent_multilabel_labels = [torch.tensor(sent_multilabel_label, dtype=torch.float) for sent_multilabel_label in filtered_sent_multilabel_labels]
            # print(article_record_word_ids)
            # print(filtered_transformer_features)
            self.experiment_data.append((
                torch.stack(filtered_transformer_features).detach(),
                torch.stack(filtered_difficulty_features),
                torch.stack(filtered_topic_features),
                torch.stack(filtered_eye_features),

                torch.stack(filtered_word_normal_labels),
                torch.stack(filtered_sent_normal_labels),
                torch.stack(filtered_word_multilabel_labels),
                torch.stack(filtered_sent_multilabel_labels),

                torch.tensor(filtered_sent_word_mapping),

                # torch.tensor(filtered_word_list),

                torch.tensor(filtered_word_watching_situations),
                torch.tensor(filtered_sent_watching_situations),
            ))

        print(time.time() - ti)
            # self.experiment_lengths.append(len(filtered_transformer_features))

    def __len__(self):
        # return sum(self.experiment_lengths)
        return len(self.experiments)

    def __getitem__(self, idx):
        # current_exp_idx = 0
        # while idx - self.experiment_lengths[current_exp_idx] >= 0:
        #     idx -= self.experiment_lengths[current_exp_idx]
        #     current_exp_idx += 1
        # return self.experiment_data[idx], torch.tensor([idx])
        return self.experiment_data[idx]

    def _match_article_record_word_list(self, article, experiment):
        if (article.id, experiment.id) in self.article_record_word_ids_mapping:
            return self.article_record_word_ids_mapping[(article.id, experiment.id)]
        article_word_list = article.get_word_list()
        record_word_list = experiment.get_word_list()

        article_record_word_ids = []

        article_word_id, record_word_id = 0, 0
        while article_word_id < len(article_word_list) and record_word_id < len(record_word_list):
            if article_word_list[article_word_id] == record_word_list[record_word_id]:
                article_record_word_ids.append((article_word_id, record_word_id))
                # print(article_word_id, record_word_id, article_word_list[article_word_id])
                article_word_id += 1
                record_word_id += 1
            else:
                delta_id = 1
                flag = False
                while delta_id < self.article_word_mismatch_thr and article_word_id + delta_id < len(article_word_list):
                    cur_article_word_id = article_word_id + delta_id
                    if article_word_list[cur_article_word_id] == record_word_list[record_word_id]:
                        article_record_word_ids.append((cur_article_word_id, record_word_id))
                        # print(cur_article_word_id, record_word_id, article_word_list[cur_article_word_id])
                        article_word_id = cur_article_word_id + 1
                        record_word_id += 1
                        flag = True
                        break
                    delta_id += 1
                if flag:
                    continue

                delta_id = 1
                flag = False
                while delta_id < self.record_word_mismatch_thr and record_word_id + delta_id < len(record_word_list):
                    cur_record_word_id = record_word_id + delta_id
                    if article_word_list[article_word_id] == record_word_list[cur_record_word_id]:
                        article_record_word_ids.append((article_word_id, cur_record_word_id))
                        # print(article_word_id, cur_record_word_id, article_word_list[article_word_id])
                        article_word_id += 1
                        record_word_id = cur_record_word_id + 1
                        flag = True
                        break
                    delta_id += 1
                if flag:
                    continue

                article_word_id += 1
                record_word_id += 1

        self.article_record_word_ids_mapping[(article.id, experiment.id)] = article_record_word_ids
        return article_record_word_ids


EYE_FEATURE_NAMES = [
    'reading_times',
    'number_of_fixations',
    'second_pass_dwell_time_of_sentence',
    'total_dwell_time_of_sentence',
    'reading_times_of_sentence',
    'saccade_times_of_sentence',
    'forward_times_of_sentence',
    'backward_times_of_sentence',
    # 'saccade_times_of_para',
    # 'forward_saccade_times_of_para',
    # 'backward_saccade_times_of_para',
]
WORD_LABEL_NAMES = [
    'word_understand',
]
SENT_LABEL_NAMES = [
    'sentence_understand',
    'mind_wandering',
]
WORD_WATCHING_LABEL_NAME = 'word_watching'
dataset = ReadingDataset(
    reading_articles_path='/home/wtpan/memx4edu-code/training_data/article.csv',
    # reading_experiments_path='/home/wtpan/memx4edu-code/training_data/2022-10-28_4test.csv',
    reading_experiments_path='/home/wtpan/memx4edu-code/training_data/2022-10-28.csv',
    eye_feature_names=EYE_FEATURE_NAMES,
    eye_feature_size=12,
    word_label_names=WORD_LABEL_NAMES,
    sent_label_names=SENT_LABEL_NAMES,
    word_watching_label_name=WORD_WATCHING_LABEL_NAME,
    article_word_mismatch_thr=10,
    record_word_mismatch_thr=3,
)
len(dataset)

In [149]:
transformer_features, difficulty_features, topic_features, eye_features, word_normal_labels, sent_normal_labels, word_multilabel_labels, sent_multilabel_labels, sent_word_mapping, word_watching_situations, sent_watching_situations = dataset[0]
sent_word_mapping

tensor([[  0,  70],
        [ 71, 111],
        [112, 144],
        [145, 163],
        [164, 179],
        [180, 187],
        [188, 207],
        [208, 218],
        [219, 253],
        [254, 269],
        [270, 288],
        [289, 316],
        [317, 336],
        [337, 393],
        [394, 423]])

In [185]:
train_count = int(0.8 * len(dataset))
val_count = len(dataset) - train_count
train_dataset, val_dataset = random_split(dataset, (train_count, val_count))
# TODO: batch_size
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1)
len(train_dataset)

43

In [144]:
for a in train_dataloader:
    transformer_features, difficulty_features, topic_features, eye_features, word_normal_labels, sent_normal_labels, word_multilabel_labels, sent_multilabel_labels, sent_word_mapping, word_list, word_watching_situations, sent_watching_situations = a
    print(word_watching_situations)
    break

[tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([1]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0])

In [125]:
a = torch.tensor([[[1,0], [1,0], [1,0], [1,0], [1,0]]], dtype=torch.float)
b = torch.tensor([[[1,0], [1,0], [1,0], [1,0], [1,0]]], dtype=torch.float)
w = torch.tensor([[[1,1], [1,1], [1,1], [1,1], [1,1]]], dtype=torch.float)
print(a.shape)
nn.BCELoss(weight=w)(a,b)

torch.Size([1, 5, 2])


tensor(0.)

In [186]:
writer = SummaryWriter('/home/wtpan/memx4edu-code/training_summary')

# valid_accuracy = MultilabelAccuracy(num_labels=3, average=None)
# valid_precision = MultilabelPrecision(num_labels=3, average=None)
# valid_recall = MultilabelRecall(num_labels=3, average=None)
# valid_f1_score = MultilabelF1Score(num_labels=3, average=None)
# valid_hamming_distance = MultilabelHammingDistance(num_labels=3, average=None)
validation_metrics = {
    'word_normal': {
        'accuracy': Accuracy(),
        'precision': Precision(),
        'recall': Recall(),
        'f1_score': F1Score(),
    },
    'sent_normal': {
        'accuracy': Accuracy(),
        'precision': Precision(),
        'recall': Recall(),
        'f1_score': F1Score(),
    },
}

# model = torch.load('/home/wtpan/memx4dog-code/cvr/saved_model/new_complete_model_0415_100000_0.652.pkl')
model = ClassificationModel().cuda()

print('Start training...')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_len = len(train_dataset)
ALPHA, BETA = 1., 1.
WORD_WINDOW_SIZE = 5
SENT_WINDOW_SIZE = 1
for epoch in range(50000):
    model.train()
    losses = []
    mask_nums = []
    for step, (
            transformer_feats, difficulty_feats, topic_feats,
            eye_feats,
            word_normal_labels, sent_normal_labels,
            word_multilabel_labels, sent_multilabel_labels,
            sent_mapping,
            word_watching_situations, sent_watching_situations,
    ) in enumerate(train_dataloader):
        word_normal_cls, sent_normal_cls, _, sent_multilabel_cls = model(transformer_feats.cuda(), difficulty_feats.cuda(), topic_feats.cuda(), eye_feats.cuda(), sent_mapping[0])

        word_normal_weight = torch.tensor([[[0.01] if l else [1.] for l in word_normal_labels[0]]]).cuda()
        sent_normal_weight = torch.tensor([[[0.01] if l else [1.] for l in sent_normal_labels[0]]]).cuda()
        # word_multilabel_weight = torch.tensor([[0.] if l else [1.] for l in word_normal_labels[0]]).cuda()
        sent_multilabel_weight = torch.tensor([[[0.01, 0.01] if l else [1., 1.] for l in sent_normal_labels[0]]]).cuda()
        word_normal_position_weight = torch.tensor([[[1.] if word_watching_situation else [0.] for word_watching_situation in word_watching_situations[0]]]).cuda()
        sent_normal_position_weight = torch.tensor([[[1.] if sent_watching_situation else [0.] for sent_watching_situation in sent_watching_situations[0]]]).cuda()
        # word_multilabel_position_weight = torch.tensor([[1. if word_watching_situation else 0. for word_watching_situation in word_watching_situations[0]]]).cuda()
        sent_multilabel_position_weight = torch.tensor([[[1., 1.] if sent_watching_situation else [0., 0.] for sent_watching_situation in sent_watching_situations[0]]]).cuda()

        word_normal_loss = nn.BCELoss(weight=word_normal_weight * word_normal_position_weight)(word_normal_cls, word_normal_labels.cuda())
        sent_normal_loss = nn.BCELoss(weight=sent_normal_weight * sent_normal_position_weight)(sent_normal_cls, sent_normal_labels.cuda())
        # word_multilabel_loss = nn.BCELoss()(word_multilabel_cls.squeeze(), word_multilabel_labels.squeeze().cuda())
        sent_multilabel_loss = nn.BCELoss(weight=sent_multilabel_weight * sent_multilabel_position_weight)(sent_multilabel_cls, sent_multilabel_labels.cuda())

        loss = word_normal_loss + ALPHA * sent_normal_loss + BETA * sent_multilabel_loss

        losses.append(loss.detach().cpu().numpy())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_num = np.sum(losses) / train_len
    writer.add_scalar('loss', loss_num, epoch)
    print(epoch, loss_num)

    if (epoch + 1) % 1 == 0:
        model.eval()
        pred_list = []
        gt_list = []
        for (
                transformer_feats, difficulty_feats, topic_feats,
                eye_feats,
                word_normal_labels, sent_normal_labels,
                word_multilabel_labels, sent_multilabel_labels,
                sent_mapping,
                word_watching_situations, sent_watching_situations,
            ) in val_dataloader:
            word_normal_cls, sent_normal_cls, _, sent_multilabel_cls = model(transformer_feats.cuda(), difficulty_feats.cuda(), topic_feats.cuda(), eye_feats.cuda(), sent_mapping[0])

            current_position_predictions = {
                'word_normal': word_normal_cls[ : , torch.nonzero(word_watching_situations[0]).squeeze(), : ].detach().cpu(),
                'sent_normal': sent_normal_cls[ : , torch.nonzero(sent_watching_situations[0]).squeeze(), : ].detach().cpu(),
            }
            current_position_labels = {
                'word_normal': word_normal_labels[ : , torch.nonzero(word_watching_situations[0]).squeeze(), : ].int(),
                'sent_normal': sent_normal_labels[ : , torch.nonzero(sent_watching_situations[0]).squeeze(), : ].int(),
            }

            for x in validation_metrics:
                for y in validation_metrics[x]:
                    validation_metrics[x][y](current_position_predictions[x], current_position_labels[x])

            # pred_list.append(cur_pred[0].numpy())
            # gt_list.append(cur_label[0].numpy())

        validation_metrics_value = {'word_normal': {}, 'sent_normal': {}}
        for x in validation_metrics:
            for y in validation_metrics[x]:
                validation_metrics_value[x][y] = validation_metrics[x][y].compute().item()

        writer.add_scalars('word_normal', validation_metrics_value['word_normal'], epoch)
        writer.add_scalars('sent_normal', validation_metrics_value['sent_normal'], epoch)

        torch.save(model, '/home/wtpan/memx4edu-code/saved_model/model_1018_%06d.pkl' % epoch)

        for x in validation_metrics:
            for y in validation_metrics[x]:
                validation_metrics[x][y].reset()

Start training...
0 0.000761026733143385
1 0.0007286608912223994
2 0.0006973510391490404
3 0.0006709064595228017
4 0.0006443746821131817
5 0.0006221193213795507
6 0.0006008523321428964


KeyboardInterrupt: 

In [172]:
a = torch.tensor([[1, 0, 1, 0]])
b = torch.nonzero(a)
a[b]

IndexError: index 2 is out of bounds for dimension 0 with size 1

In [190]:
a = np.array([0.3,0.501,0.7])
np.around(a)

array([0., 1., 1.])

In [116]:
dataset.experiments[3].id

477

In [257]:
from torchmetrics.classification import MultilabelPrecision
target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
metric = MultilabelPrecision(num_labels=3, average='macro')
metric(preds, target)

tensor(0.5000)

In [15]:
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
src = torch.rand(1000, 512)
out = encoder_layer(src)
out.shape

torch.Size([1000, 512])