In [None]:
from comet_ml import Experiment
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler, BatchSampler, SequentialSampler
from torch.utils.data import Dataset, DataLoader
import numpy as np
from bs4 import BeautifulSoup
import re
import pandas as pd
import math
import os
from collections import Counter
from functools import reduce
import glob
from matplotlib import pyplot as plt
from nltk import RegexpTokenizer
from nltk.corpus import stopwords as stopwords_list
import pickle
import sys
import random
from nltk.tag.stanford import CoreNLPPOSTagger
# from nltk.tag.stanford import StanfordTagger as CoreNLPPOSTagger
import bs4

In [None]:
# Pick the device: the GPU when one is available, otherwise the CPU.
# torch.cuda.current_device() raises when no CUDA device is available, so only query it then.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Using CUDA device", torch.cuda.current_device())
device


In [None]:
# Number of tokens per segment; passed as `segment_size` to every ArticleDataset below.
sequence_length = 100


In [None]:
class ArticleTextData():
    """Articles, their authors, texts and cited references, for authorship attribution.

    There are two ways to construct an instance:

    * From XML files (``authors`` is None): headers, contents and reference
      lists are parsed from the ``headers/``, ``contents/`` and
      ``references/`` subdirectories of ``filedir``. Then the vocabulary,
      class index and references index are built. ``vocabulary`` is ignored
      in this mode; pass ``vocabulary_file`` to load a pickled vocabulary.
    * From precomputed structures (``authors`` given): everything is taken
      from the keyword arguments as-is. :meth:`random_split` builds its
      subsets this way.

    NOTE: ``authors`` and ``texts`` below are *class* attributes.
    ``_collect_authors`` fills the class-level ``authors`` dict, so all
    instances built from files share (and accumulate into) it.
    ``_build_class_index`` always counts author frequencies over that
    class-level dict rather than over an instance's own subset.
    """
    authors = {}
    texts = {}
    tokenizer = RegexpTokenizer(r'\w+')
    stop = stopwords_list.words("english")
    # Set view of `stop` for O(1) membership tests while tokenizing.
    _stop_set = frozenset(stop)

    def __init__(self, filedir=None, filedir_submitted=None, max_classes=None, min_class_freq=None, vocabulary_file=None,
                 vocabulary=None, authors=None, texts=None, keys_mapping_file=None,
                 stopwords_index=None, stopwords_dim=None,
                 references=None, texts_submitted=None,
                 classes=None, nr_labels=None, references_index=None,
                 tail_authors=False, middle_authors=False,
                 references_dim=None, class_weights=None, just_abstract=False,
                 just_contexts=False):
        """Build the dataset from files or from precomputed structures.

        Args:
            filedir: directory with ``headers/``, ``contents/`` and
                ``references/`` subdirectories (file mode).
            filedir_submitted: optional directory with the same layout that
                holds the submitted versions of the articles (file mode).
            max_classes, min_class_freq, tail_authors, middle_authors: author
                selection options, see :meth:`_build_class_index`.
            vocabulary_file: pickled ``word -> id`` dict used instead of
                building the vocabulary (file mode).
            keys_mapping_file: lines ``target_id,submitted_id``. Files in
                ``filedir_submitted`` are named by the second column and are
                re-keyed to the first (file mode).
            just_abstract: only use title and abstract text.
            just_contexts: use the text around citations instead of the full text.
            authors, texts, vocabulary, stopwords_index, stopwords_dim,
            references, references_index, references_dim, classes, nr_labels,
            class_weights, texts_submitted: precomputed structures
                (precomputed mode).
        """
        self.max_classes = max_classes
        self.min_class_freq = min_class_freq
        self.just_abstract = just_abstract
        self.just_contexts = just_contexts
        if authors is not None:
            # Precomputed mode (an empty dict is still precomputed mode, not file mode).
            self.vocabulary = vocabulary
            self.vocabulary_size = self._vocabulary_size(vocabulary)
            self.authors = authors
            self.texts = texts
            self.stopwords_index = stopwords_index
            self.stopwords_dim = stopwords_dim
            self.references = references
            self.references_index = references_index
            self.references_dim = references_dim
            self.classes = classes
            self.nr_labels = nr_labels
            self.class_weights = class_weights
            self.texts_submitted = texts_submitted
            # Only file-built instances parse submitted references; keep the
            # attribute defined so that random_split() also works on subsets.
            self.references_submitted = {}
        else:
            keys_mapping = self._read_keys_mapping(keys_mapping_file) if keys_mapping_file else None
            self._build_dataset(filedir, filedir_submitted, keys_mapping, vocabulary_file=vocabulary_file)
            self.classes, self.nr_labels, self.class_weights = self._build_class_index(
                max_classes=max_classes, min_class_freq=self.min_class_freq,
                tail_authors=tail_authors, middle_authors=middle_authors)
            self._build_references_index()
        self.article_indices = [i for i in self.authors if i in self.texts]

    def __len__(self):
        """Number of articles that have both author metadata and a text."""
        return len(self.article_indices)

    @staticmethod
    def _read_keys_mapping(keys_mapping_file):
        """Read ``target_id,submitted_id`` lines into a ``submitted_id -> target_id`` dict."""
        keys_mapping = {}
        with open(keys_mapping_file) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. at the end of the file
                target_id, submitted_id = line.split(",")
                keys_mapping[submitted_id] = target_id
        return keys_mapping

    @staticmethod
    def _vocabulary_size(vocabulary):
        """Size of an embedding table for ``vocabulary``: the largest word id plus two.

        Id 0 is padding, and ArticleDataset maps unknown words to
        ``vocabulary_size - 1``, so that id must not belong to any word. For
        vocabularies with ids ``1..N``, the former ``len(vocabulary) + 1``
        (and ``len(vocabulary)`` for loaded vocabularies) made the unknown id
        coincide with a real word's id.
        """
        return max(vocabulary.values(), default=0) + 2

    @classmethod
    def _collect_references(cls, filedir, keys_mapping=None):
        """Parse ``<filedir>/references/*`` into ``{article_id: {cited_title: [author names]}}``.

        A cited entry whose title repeats an earlier one in the same file
        replaces it. With ``keys_mapping``, article ids are translated through it.
        """
        references_dir = os.path.join(filedir, "references")
        references = {}
        for filename in os.listdir(references_dir):
            article_index = filename.split(".")[0]
            with open(os.path.join(references_dir, filename)) as infile:
                soup = BeautifulSoup(infile.read(), 'xml')
            articles = {}
            for article_node in soup.find_all('biblStruct'):
                title = article_node.find_all("title")[0].get_text()
                articles[title] = []
                for author_node in article_node.find_all("author"):
                    names = [node.get_text().strip() for node in author_node.find_all("forename")]
                    names += [node.get_text().strip() for node in author_node.find_all("surname")]
                    if names:
                        articles[title].append(" ".join(names))
            key = keys_mapping[article_index] if keys_mapping else article_index
            references[key] = articles
        return references

    @classmethod
    def _normalize_authors(cls, authors):
        """Normalize author names to ``"<first initial> <last name>"``.

        Names that cannot be normalized (empty or whitespace-only strings) are
        reported on stderr and skipped. Previously such an entry was replaced
        by the preceding author's name, or raised UnboundLocalError when it
        came first.
        """
        normalized_authors = []
        for author in authors:
            author_names = author.split()
            if not author_names:
                sys.stderr.write("Could not normalize name %s\n" % author)
                continue
            normalized_authors.append(" ".join([author_names[0][0], author_names[-1]]))
        return normalized_authors

    @classmethod
    def _collect_authors(cls, filedir):
        """Parse ``<filedir>/headers/*`` into the class-level ``authors`` dict.

        For every header file ``<id>.<ext>``, this stores
        ``cls.authors[id] = {'title': ..., 'authors': [full names]}``. A full
        name is the forenames followed by the surnames. Empty name parts are
        skipped (they used to raise IndexError), and trailing exponent digits
        are stripped from surnames. Suspicious names are reported on stderr.
        """
        headers_dir = os.path.join(filedir, "headers")
        for filename in os.listdir(headers_dir):
            article_index = filename.split(".")[0]
            with open(os.path.join(headers_dir, filename)) as infile:
                soup = BeautifulSoup(infile.read(), 'xml')
            metadata = {'title': soup.find_all('title')[0].get_text(), 'authors': []}
            cls.authors[article_index] = metadata
            for author_node in soup.find_all('author'):
                names = []
                for name_node in author_node.find_all('forename'):
                    name = name_node.get_text().strip()
                    if not name:
                        continue
                    # the first (not middle) name is 1-letter long
                    if not names and len(name.split()[0]) == 1:
                        sys.stderr.write("Suspicious name: %s, article %s\n" % (name, article_index))
                    names.append(name)
                for name_node in author_node.find_all('surname'):
                    # exponents: drop trailing digits (and the whitespace before them)
                    name = name_node.get_text().strip().rstrip('0123456789').strip()
                    if name:
                        names.append(name)
                if len(names) > 3:
                    sys.stderr.write("Suspicious name: %s\n" % " ".join(names))
                if names:
                    # TODO: test of name containing non-letter characters
                    metadata['authors'].append(" ".join(names))

    def _collect_texts(self, filedir, keys_mapping=None):
        """Extract article texts (or citation contexts) from ``<filedir>/contents/*.xml``.

        Citation (``<ref>``) elements in the body are replaced by a space.
        Before that, up to 50 characters of plain text on each side of them
        are kept as "citation contexts". Returns ``{article_id: text}``; with
        ``self.just_contexts`` the values are the joined contexts instead.
        With ``keys_mapping``, article ids are translated through it for both
        dictionaries. Previously the contexts kept the untranslated id.

        NOTE(review): ``find_all()`` with no arguments returns every
        descendant tag. Text nested several tags deep (e.g. body > div > p)
        is therefore appended once per enclosing tag, and a ``<title>`` with
        only plain text (no child tags) contributes nothing. Confirm whether
        this is intended before changing it; it alters the vocabulary and
        every segment.
        """
        excluded_tags = {'ref'}
        context_window = 50
        ref_contexts = {}
        texts = {}
        for filename in glob.glob(filedir + "/contents/*.xml"):
            article_index = os.path.basename(filename).split(".")[0]
            with open(filename) as infile:
                soup = BeautifulSoup(infile.read(), 'xml')
            body = soup.find_all('text')[0].find_all('body')[0]
            abstract = soup.find_all('abstract')[0]
            title = soup.find_all('title')[0]
            ref_contexts_local = []
            for refel in body.find_all('ref'):
                before = refel.previous
                if before and type(before) is bs4.element.NavigableString:
                    ref_contexts_local.append(before[-context_window:])
                after = refel.next_sibling
                if after and type(after) is bs4.element.NavigableString:
                    ref_contexts_local.append(after[:context_window])
                refel.replace_with(" ")
            elements = title.find_all() + abstract.find_all()
            if not self.just_abstract:
                elements.extend(body.find_all())
            all_texts = [el.get_text() for el in elements if el.name not in excluded_tags]
            text = " ".join(filter(None, all_texts))
            # fix syllabified words
            text = re.sub("- ", "", text)  # TODO: check
            key = keys_mapping[article_index] if keys_mapping else article_index
            texts[key] = text
            ref_contexts[key] = " ".join(ref_contexts_local)
        return ref_contexts if self.just_contexts else texts

    def _make_subset(self, authors, texts, classes, nr_labels, class_weights,
                     references=None, texts_submitted=None):
        """Build a precomputed-mode instance that shares this instance's indices."""
        return ArticleTextData(
            authors=authors, texts=texts, vocabulary=self.vocabulary,
            max_classes=self.max_classes, stopwords_index=self.stopwords_index,
            stopwords_dim=self.stopwords_dim,
            references=self.references if references is None else references,
            classes=classes, nr_labels=nr_labels, references_index=self.references_index,
            references_dim=self.references_dim, class_weights=class_weights,
            texts_submitted=texts_submitted)

    def random_split(self, proportion1, proportion2=0, max_classes=None,
            min_class_freq=None, tail_authors=False, middle_authors=False):
        """Split the articles into disjoint subsets, stratified by author.

        Each author in ``self.classes`` first seeds one of their articles into
        every subset, least prolific authors first. The remaining texts are
        then sampled at random:

        * 3-way split (``proportion2`` non-zero): the first subset is filled up
          to ``proportion1`` and the second up to ``proportion2`` of all texts.
          The rest goes to the third subset.
        * 2-way split: the first subset is filled up to ``proportion1``, and
          everything else goes to the second subset. Previously
          ``proportion2 == 0`` left the second subset with its seeds only.

        Only authors with articles in every non-empty subset stay valid. Texts
        without any valid author are dropped, and the class index is rebuilt
        over the remaining authors with the given selection options.

        Returns:
            ``(data1, data2)`` for a 2-way split, or
            ``(data1, data2, data3, data4)`` for a 3-way split. ``data4``
            holds the submitted versions of the ``data3`` texts.
        """
        indices1, indices2, indices3 = set(), set(), set()
        articles_for_author = {}
        for article, article_metadata in self.authors.items():
            for author in article_metadata['authors']:
                if author in self.classes:
                    articles_for_author.setdefault(author, []).append(article)

        def first_not_in(articles, *excluded_sets):
            # first article that belongs to none of `excluded_sets`, or None
            return next((a for a in articles if not any(a in s for s in excluded_sets)), None)

        for author, articles in sorted(articles_for_author.items(), key=lambda item: len(item[1])):
            seed = first_not_in(articles, indices2, indices3)
            if seed is not None:
                indices1.add(seed)
            seed = first_not_in(articles, indices1, indices3)
            if seed is not None:
                indices2.add(seed)
            if proportion2:  # we want to split it 3-ways
                seed = first_not_in(articles, indices1, indices2)
                if seed is not None:
                    indices3.add(seed)

        # Distribute what is left. random.sample needs a sequence (sets are
        # rejected since Python 3.11); sorting also makes draws reproducible
        # under random.seed().
        n_texts = len(self.texts)
        remaining = set(self.texts) - indices1 - indices2 - indices3
        n1_remaining = max(int(proportion1 * n_texts) - len(indices1), 0)
        sampled1 = set(random.sample(sorted(remaining), k=min(n1_remaining, len(remaining))))
        indices1 |= sampled1
        remaining -= sampled1
        if proportion2:
            n2_remaining = max(int(proportion2 * n_texts) - len(indices2), 0)
            sampled2 = set(random.sample(sorted(remaining), k=min(n2_remaining, len(remaining))))
            indices2 |= sampled2
            indices3 |= remaining - sampled2
        else:
            indices2 |= remaining

        # Keep only authors that occur in every (non-empty) subset.
        valid_authors = set()
        min_occs = 1000000
        for author, articles in articles_for_author.items():
            articles = set(articles)
            occs1 = len(articles & indices1)
            occs2 = len(articles & indices2)
            occs3 = len(articles & indices3)
            if occs1 and occs2 and (occs3 or not indices3):
                valid_authors.add(author)
                min_occs = min(min_occs, occs1 + occs2 + occs3)

        # Seeds come from the headers and samples from the texts, so guard both lookups.
        authors1 = {i: self.authors[i] for i in indices1 if i in self.authors}
        authors2 = {i: self.authors[i] for i in indices2 if i in self.authors}
        authors3 = {i: self.authors[i] for i in indices3 if i in self.authors}

        def texts_with_valid_authors(indices, split_authors):
            # texts of `indices` having at least one valid author
            return {i: self.texts[i] for i in indices
                    if i in self.texts and i in split_authors
                    and any(a in valid_authors for a in split_authors[i]['authors'])}

        texts1 = texts_with_valid_authors(indices1, authors1)
        texts2 = texts_with_valid_authors(indices2, authors2)
        texts3 = texts_with_valid_authors(indices3, authors3)
        submitted_source = self.texts_submitted or {}
        texts_submitted = {i: submitted_source[i] for i in texts3 if i in submitted_source}

        print("minimum articles per author", min_occs)

        # Valid authors that still have at least one text in some subset.
        valid_authors_with_texts = set()
        for split_texts, split_authors in ((texts1, authors1), (texts2, authors2), (texts3, authors3)):
            for article in split_texts:
                valid_authors_with_texts.update(
                    a for a in split_authors[article]['authors'] if a in valid_authors)

        # Rebuild the class index using the authors that are left
        classes, nr_labels, class_weights = self._build_class_index(
            valid_authors=valid_authors_with_texts, max_classes=max_classes,
            min_class_freq=min_class_freq, tail_authors=tail_authors, middle_authors=middle_authors)

        print("valid authors", len(valid_authors), "with texts", len(valid_authors_with_texts))
        data1 = self._make_subset(authors1, texts1, classes, nr_labels, class_weights)
        data2 = self._make_subset(authors2, texts2, classes, nr_labels, class_weights)
        if proportion2:
            data3 = self._make_subset(authors3, texts3, classes, nr_labels, class_weights,
                                      texts_submitted=texts_submitted)
            data4 = self._make_subset(authors3, texts_submitted, classes, nr_labels, class_weights,
                                      references=self.references_submitted,
                                      texts_submitted=texts_submitted)
            return data1, data2, data3, data4
        return data1, data2

    def _build_vocabulary(self, min_freq=50, max_words=50000, vocabulary_file=None):
        """Build (or load) the ``word -> id`` vocabulary and the stopword index.

        Ids start at 1 (0 is padding). Words are taken in decreasing frequency
        over all texts, with stopwords excluded. This stops at the first word
        rarer than ``min_freq`` or once ``max_words`` ids have been assigned.
        Purely numeric tokens are skipped.
        """
        self.stopwords_index = {word: i for i, word in enumerate(self.stop)}
        self.stopwords_dim = len(self.stopwords_index)
        if vocabulary_file:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(vocabulary_file, "rb") as f:
                self.vocabulary = pickle.load(f)
            self.vocabulary_size = self._vocabulary_size(self.vocabulary)
            return
        self.vocabulary = {}
        word_freqs = Counter()
        for text in self.texts.values():
            word_freqs.update(self._tokenize(text, filter_stopwords=True))
        next_id = 1  # starts with 1, keeping 0 for padding
        for word, freq in word_freqs.most_common():
            if (min_freq and freq < min_freq) or next_id > max_words:
                print("min freq is", freq)
                break
            if re.match("^[0-9]+$", word):
                continue
            self.vocabulary[word] = next_id
            next_id += 1
        self.vocabulary_size = self._vocabulary_size(self.vocabulary)

    @classmethod
    def _tokenize(cls, text, filter_stopwords=False, keep_only_stopwords=False):
        """Lower-case and split ``text`` into ``\\w+`` tokens.

        With ``keep_only_stopwords`` only stopwords are returned. Otherwise,
        stopwords are removed when ``filter_stopwords`` is set, and one-character
        tokens are dropped.
        """
        tokens = cls.tokenizer.tokenize(text.lower())
        if keep_only_stopwords:
            return [w for w in tokens if w in cls._stop_set]
        if filter_stopwords:
            tokens = [w for w in tokens if w not in cls._stop_set]
        return [w for w in tokens if len(w) > 1]

    @classmethod
    def _stopwords(cls, tokenized_text):
        """Return the stopwords of ``tokenized_text``, in order."""
        return [w for w in tokenized_text if w in cls._stop_set]

    def _build_references_index(self):
        """Assign an index to every normalized cited author.

        Only reference lists of articles with at least one author in
        ``self.classes`` are considered.
        """
        self.references_index = {}
        for article, references in self.references.items():
            if article not in self.authors:
                continue
            # only consider references in our dataset (covered by our limited classes)
            if not any(a in self.classes for a in self.authors[article]['authors']):
                continue
            for authors in references.values():
                for author in self._normalize_authors(authors):
                    self.references_index.setdefault(author, len(self.references_index))
        self.references_dim = len(self.references_index)

    @classmethod
    def _build_class_index(cls, max_classes=None, min_class_freq=None, valid_authors=None,
                           top_authors=False, tail_authors=False, middle_authors=False):
        """Select the authors used as classes and index them.

        Frequencies are counted over the class-level ``cls.authors`` (one
        count per article). Candidates are chosen by the first matching mode:

        * ``middle_authors``: authors passing ``min_class_freq`` and
          ``valid_authors``, by decreasing frequency. The centred window of
          ``max_classes`` of them is kept (all of them if ``max_classes`` is None).
        * ``tail_authors``: the ``max_classes`` least frequent authors passing
          ``min_class_freq``.
        * ``top_authors`` or a truthy ``max_classes``: the ``max_classes``
          most frequent authors.
        * otherwise: all authors. This case used to raise NameError.

        Candidates failing ``min_class_freq`` or ``valid_authors`` are then
        dropped. The middle and tail windows previously kept
        ``max_classes + 1`` authors.

        Returns:
            ``(classes, nr_labels, class_weights)``: ``author -> index``, the
            number of classes, and ``1 / frequency`` per class index.
        """
        class_freq = Counter()
        for article_metadata in cls.authors.values():
            class_freq.update(article_metadata['authors'])

        if middle_authors:
            candidates = [(a, f) for a, f in class_freq.most_common()
                          if (not min_class_freq or f >= min_class_freq)
                          and (not valid_authors or a in valid_authors)]
            total_authors = len(candidates)
            print("total authors", total_authors)
            if max_classes is not None:
                start = max((total_authors - max_classes + 1) // 2, 0)
                candidates = candidates[start:start + max_classes]
        elif tail_authors:
            candidates = []
            for a, f in sorted(class_freq.items(), key=lambda t: t[1]):
                if max_classes is not None and len(candidates) >= max_classes:
                    break
                if not min_class_freq or f >= min_class_freq:
                    candidates.append((a, f))
        elif top_authors or max_classes:
            candidates = class_freq.most_common(max_classes)
        else:
            candidates = class_freq.most_common()

        classes = {}
        for author, freq in candidates:
            if min_class_freq and freq < min_class_freq:
                continue
            if valid_authors and author not in valid_authors:
                continue
            classes[author] = len(classes)
        nr_labels = len(classes)
        class_weights = np.array([1. / class_freq[author] for author in classes], dtype=float)
        return classes, nr_labels, class_weights

    def _build_dataset(self, filedir="/home/ana/code/research/acl_authorship/extracted", filedir_submitted=None,
                       keys_mapping=None,
                       min_word_freq=5, max_article_len=20000,
                       vocabulary_file=None):
        """Parse all files under ``filedir`` (and ``filedir_submitted``), then build the vocabulary.

        ``keys_mapping`` only applies to the submitted files.
        ``max_article_len`` is accepted but not used here.
        """
        self._collect_authors(filedir)
        self.texts = self._collect_texts(filedir)
        self.references = self._collect_references(filedir)
        if filedir_submitted:
            self.texts_submitted = self._collect_texts(filedir_submitted, keys_mapping)
            self.references_submitted = self._collect_references(filedir_submitted, keys_mapping)
        else:
            self.texts_submitted = {}
            self.references_submitted = {}
        self._build_vocabulary(min_freq=min_word_freq, vocabulary_file=vocabulary_file)


In [None]:
class ArticleDataset(Dataset):
    """PyTorch dataset of fixed-size token segments cut from an ArticleTextData.

    Each article with at least one known class label is tokenized and split
    into consecutive, non-overlapping segments of ``segment_size`` tokens.
    Every stored item (in ``self.data``) is the tuple
    ``(segment, stopword_counts, POS_tags, reference_counts, label, article)``.
    ``label`` is either a binary vector over all classes (``multilabel=True``)
    or a single class id, with one item per label of the article.
    """

    def __init__(self, article_data, multilabel, segment_size=1500,
                 pad=True, pretrained_embeddings_file=None, embeddings_size=None, max_article_len=20000,
                 valid_labels=None, extract_POS=True, context_window=5, context_max_words=1000):
        """
        Args:
            article_data: the ArticleTextData to read texts, labels and indices from.
            multilabel: store binary label vectors instead of one item per label.
            segment_size: tokens per segment.
            pad: keep (and zero-pad) segments shorter than ``segment_size``.
            pretrained_embeddings_file: optional pickle, loaded into
                ``self.pretrained_embeddings``.
            embeddings_size: stored as ``self.embeddings_size``.
            max_article_len: articles with more tokens than this are skipped.
            valid_labels: if not None, only these class ids are kept. An empty
                collection keeps none; it used to mean "no filtering".
            extract_POS: tag segments with ``self.parser``. No tagger is
                configured, so this currently yields empty tag lists.
            context_window, context_max_words: stored, but not used by this class.
        """
        self.article_data = article_data
        self.multilabel = multilabel
        self.segment_size = segment_size
        self.pad = pad
        self.context_window = context_window
        self.context_max_words = context_max_words
        self.nr_labels = article_data.nr_labels
        self.max_article_len = max_article_len
        self.parser = None  # CoreNLPPOSTagger()
        self.extract_POS = extract_POS
        # Row 0 of the one-hot POS matrix is never set by a known tag.
        self.POS_index = {pos: i + 1 for i, pos in enumerate([
            'VBZ', 'VBG', 'PRP', 'VBN', 'IN', 'WRB', 'JJR', 'SYM', 'NNS', 'RB', 'TO', 'WDT', 'WP$',
            'RP', 'CD', 'NNPS', 'RBS', 'PRP$', 'NN', 'PDT', 'EX', 'FW', 'UH', 'WP', 'DT', 'VBP',
            'MD', 'CC', 'JJS', 'RBR', 'VB', 'JJ', 'LS', 'VBD', 'NNP'])}
        self.pretrained_embeddings = None
        self.embeddings_size = embeddings_size
        if pretrained_embeddings_file:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(pretrained_embeddings_file, "rb") as f:
                self.pretrained_embeddings = pickle.load(f)
        self.valid_labels = valid_labels
        self._encode_items()

    def _binarize_labels(self, n, labels):
        """Binary vector of length ``n`` with a 1 at every non-negative label."""
        binarized_labels = np.zeros(n)
        for l in labels:
            if l >= 0:
                binarized_labels[l] = 1
        return binarized_labels

    def _article_references(self, article):
        """Count the normalized cited-author names in ``article``'s reference list."""
        references = Counter()
        try:
            referenced_articles = self.article_data.references[article]
        except KeyError:
            sys.stderr.write("No references for article %s\n" % article)
            return references
        for authors in referenced_articles.values():
            references.update(self.article_data._normalize_authors(authors))
        return references

    def _tag_segment(self, segment):
        """POS tags of ``segment``; [] when no tagger is configured or tagging fails (best effort)."""
        if self.parser is None:
            return []
        try:
            return [pos for (_word, pos) in self.parser.tag(segment)]
        except Exception:
            return []

    def _encode_items(self):
        """Fill ``self.data``, ``self.weights`` (single-label only) and ``self.source_articles``."""
        self.data = []
        self.weights = []
        self.source_articles = []
        data = self.article_data
        for article, text in data.texts.items():
            labels = [data.classes.get(a, -1) for a in data.authors[article]['authors']]
            if self.valid_labels is not None:
                labels = [l for l in labels if l in self.valid_labels]
            if not any(label >= 0 for label in labels):
                continue
            tokenized_text = data._tokenize(text, filter_stopwords=False)
            if not tokenized_text or len(tokenized_text) > self.max_article_len:
                continue
            references = self._article_references(article)

            for start in range(0, len(tokenized_text), self.segment_size):
                segment = tokenized_text[start:start + self.segment_size]
                if len(segment) < self.segment_size and not self.pad:
                    continue
                text_POS = self._tag_segment(segment) if self.extract_POS else []
                # Note: make sure you don't remove stopwords from segment/tokenized_text before doing this
                text_stopwords = Counter(w for w in segment if w in data.stopwords_index)
                if self.multilabel:
                    binarized_labels = self._binarize_labels(self.nr_labels, labels)
                    self.data.append((segment, text_stopwords, text_POS, references, binarized_labels, article))
                    self.source_articles.append(article)
                else:
                    for label in labels:
                        if label < 0:
                            continue
                        self.data.append((segment, text_stopwords, text_POS, references, label, article))
                        self.weights.append(data.class_weights[label])
                        self.source_articles.append(article)

    def __len__(self):
        return len(self.data)

    def _random_embedding(self, n):
        """Vector of ``n`` values drawn from N(0, (1/3)^2)."""
        return np.random.randn(n) / 3.

    def __getitem__(self, i):
        """Encode item ``i`` as ``((word_ids, stopwords, POS, references), label, article)``.

        Word ids use 0 for padding and ``vocabulary_size - 1`` for unknown
        words. They are produced whether or not pretrained embeddings were
        loaded; previously ``encoded_text`` was left undefined
        (UnboundLocalError) in that case.
        """
        segment, stopwords, POS, references, label, article = self.data[i]
        data = self.article_data

        unknown_id = data.vocabulary_size - 1
        encoded_text = [data.vocabulary.get(word, unknown_id) for word in segment]
        if self.pad and len(encoded_text) < self.segment_size:
            encoded_text += [0] * (self.segment_size - len(encoded_text))

        encoded_stopwords = np.zeros(data.stopwords_dim)
        for word, count in stopwords.items():
            encoded_stopwords[data.stopwords_index[word]] = count

        # One-hot POS matrix (tags x positions). The tagger sometimes merges
        # tokens (fewer tags: trailing columns stay zero) or splits them
        # (more tags: truncated). Tags outside POS_index are skipped instead
        # of raising KeyError.
        encoded_POS = np.zeros((max(self.POS_index.values()) + 1, self.segment_size))
        for position, tag in enumerate(POS[:self.segment_size]):
            row = self.POS_index.get(tag)
            if row is not None:
                encoded_POS[row, position] = 1

        encoded_references = np.zeros(data.references_dim)
        for name, count in references.items():
            index = data.references_index.get(name)
            if index is not None:
                encoded_references[index] = count

        return ((torch.tensor(encoded_text, dtype=torch.int64),
                 torch.tensor(encoded_stopwords, dtype=torch.float32),
                 torch.tensor(encoded_POS, dtype=torch.float32),
                 torch.tensor(encoded_references, dtype=torch.float32)),
                torch.tensor(label, dtype=torch.int64),
                article)

In [None]:
%%time
# NOTE(review): `vocabulary=` is only read when `authors=` is passed; in this
# file-building mode it is ignored and the vocabulary is rebuilt from the texts.
# Use `vocabulary_file=` to load a pickled vocabulary instead.
text_dataset = ArticleTextData(filedir="acl_pre2014/", max_classes=None,min_class_freq=3,
                              vocabulary='vocabulary.pkl')
# Note: min_class_freq refers to number of articles not datapoints

In [None]:
# Three-way split (80% / 10% / remaining ~10%) restricted to 200 mid-frequency
# authors, plus the "submitted" versions of the test articles.
(training_articles, validation_articles,
 test_articles, test_articles_submitted) = text_dataset.random_split(
    0.8, 0.1, max_classes=200, middle_authors=True)

In [None]:
%%time

training_dataset = ArticleDataset(training_articles, multilabel=False, segment_size=sequence_length,
                                 extract_POS=False)


In [None]:
# Labels present in the training set; used below to remove validation labels
# that never occur in training. They are read straight from the stored items
# (segment, stopwords, POS, references, label, article) instead of encoding
# every segment to tensors and printing one line per item.
training_labels = [label for (*_features, label, _article) in training_dataset.data]
training_labels_set = set(training_labels)


In [None]:
validation_dataset = ArticleDataset(validation_articles, multilabel=True, segment_size=sequence_length,
                                    valid_labels=training_labels_set, extract_POS=True)
validation_dataset_single = ArticleDataset(validation_articles, multilabel=False, segment_size=sequence_length,
                                           valid_labels=training_labels_set, extract_POS=True)

# Multilabel items store a binary label vector: collect the indices of its non-zero entries.
validation_labels = [
    int(label_id)
    for (*_features, label_vector, _article) in validation_dataset.data
    for label_id in np.flatnonzero(label_vector)
]
# Single-label items store the class id directly.
validation_labels2 = [label for (*_features, label, _article) in validation_dataset_single.data]


In [None]:
%%time
# Remove labels in training set that don't occur in the validation set

training_dataset = ArticleDataset(training_articles, multilabel=False, segment_size=sequence_length,
                                 extract_POS=True, valid_labels=set(validation_labels))
training_labels = []
for i, x in enumerate(training_dataset):
    features, labels, article = x
    training_labels.append(labels.item())
    
    
training_labels_set = set(training_labels)

In [None]:
# Multilabel test segments; every class is kept (no label filtering).
test_dataset = ArticleDataset(
    test_articles,
    multilabel=True,
    segment_size=sequence_length,
    valid_labels=None,
    extract_POS=True,
)

In [None]:
# Class ids of every multilabel test item (non-zero entries of its label vector).
test_labels = [
    int(label_id)
    for (*_features, label_vector, _article) in test_dataset.data
    for label_id in np.flatnonzero(label_vector)
]

In [None]:
# Sanity checks: training covers every validation label, both splits share the
# class and reference indices, and no article is in both splits.
assert not set(validation_labels) - set(training_labels), \
    "validation contains labels that never occur in training"
assert training_dataset.article_data.references_index == validation_dataset.article_data.references_index, \
    "training and validation use different reference indices"
assert training_dataset.article_data.classes == validation_dataset.article_data.classes, \
    "training and validation use different class indices"
assert not training_articles.texts.keys() & validation_articles.texts.keys(), \
    "training and validation splits share articles"


In [None]:
# Same checks for the test split: no article leak and a shared class index.
assert not training_articles.texts.keys() & test_articles.texts.keys(), \
    "training and test splits share articles"
# Not asserted: test labels are unfiltered (valid_labels=None), so they may
# include classes that never occur in the training segments.
assert training_dataset.article_data.classes == test_dataset.article_data.classes, \
    "training and test use different class indices"


In [None]:
# Same checks for the submitted versions of the test articles.
# The original cell used `articles_from_files` and `test_datapoint`, which no
# earlier cell defines (NameError on a fresh kernel). The submitted test split
# returned by random_split is used instead — TODO confirm this was the intent.
assert not training_dataset.article_data.texts.keys() & test_articles_submitted.texts.keys(), \
    "training and submitted-test splits share articles"
assert training_dataset.article_data.classes == test_articles_submitted.classes, \
    "training and submitted-test use different class indices"

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("datasets_serialized/acl_vocabulary_vectors6.pkl", "rb") as f:
    pretrained_embeddings = pickle.load(f)
# Same file as above: keep the alias instead of unpickling it a second time.
vectors = pretrained_embeddings
# NOTE(review): `vocabulary` is not used in this cell; rows follow the training vocabulary.
with open("datasets_serialized/acl_vocabulary6.pkl", "rb") as f:
    vocabulary = pickle.load(f)
embedding_size = 300

# Row i holds the vector of the training-vocabulary word with id i. Rows without
# a word (padding id 0, unknown id) stay zero. Words missing from the pretrained
# vectors get a random N(0, (1/3)^2) vector, i.e. ~99.7% of values fall in [-1, 1].
training_vocabulary = training_dataset.article_data.vocabulary
pretrained_embedding_weights = np.zeros((training_dataset.article_data.vocabulary_size, embedding_size))
for word, word_id in training_vocabulary.items():
    vector = pretrained_embeddings.get(word)
    pretrained_embedding_weights[word_id] = (
        vector if vector is not None else np.random.randn(embedding_size) / 3.)

pretrained_embedding_weights.shape



In [None]:
# Token frequencies over all training segments. This is named `word_counts`
# because the DataLoader cell below rebinds `words` to a batch tensor, which
# would silently overwrite the counts.
word_counts = Counter(
    token
    for (segment, *_rest) in training_dataset.data
    for token in segment
)

In [None]:
# Sanity-check the class-weighted sampler: draw batches and count how often each label appears.
training_sampler = WeightedRandomSampler(training_dataset.weights, len(training_dataset))

dataloader = DataLoader(training_dataset, batch_size=50,
                        num_workers=4,
                        sampler=training_sampler)
LAST_BATCH_INDEX = 2000  # the loop stops after processing the batch with this index
all_labels_test = Counter()
texts_lengths = []
for i_batch, sample_batched in enumerate(dataloader):
    features, labels, article = sample_batched
    words, stopwords, pos, references = features
    if i_batch == 0:
        # Show tensor shapes once instead of printing on every batch.
        print(words.shape, stopwords.shape, references.shape, labels.shape)
    all_labels_test.update(labels.tolist())
    if i_batch >= LAST_BATCH_INDEX:
        break

ax = pd.Series(all_labels_test).plot(kind='hist')
ax.set(title="Label frequency in sampled training batches",
       xlabel="Times a label was drawn", ylabel="Number of labels");


In [None]:
class Net(nn.Module):
    '''CNN classifier over up to four feature groups of an article.

    `forward` receives a sequence `features` indexed as:
      0 - word input: token ids when `pretrained_embeddings_size` is 0 (embedded by
          `self.embedding`), otherwise already-embedded vectors of width
          `pretrained_embeddings_size`; presumably (batch, seq, ...) - TODO confirm.
      1 - stopword features, concatenated unchanged (`stopwords_dim` columns).
      2 - POS features, fed directly to a Conv1d with `pos_dim` input channels, so
          presumably shaped (batch, pos_dim, seq) - TODO confirm against the dataset.
      3 - reference features (`references_dim` columns), optionally projected by a
          ReLU layer of width `references_hidden_dim`.
    A truthy `ignore_features[i]` removes group i from the model.
    '''

    def __init__(self, seq_length=1500, filter_size=4, nr_filters=100, nr_classes=1, hidden_layers=(),
                 vocab_size=20001, embedding_size=30, convolutions=True, pretrained_embeddings_size=300,
                 stopwords_dim=179, references_dim=31889, references_hidden_dim=30, pos_dim=32,
                 dropout=0, nr_filters_pos=20, filter_size_pos=4,
                 ignore_features=(False, False, False, False, False), pretrained_embedding_weights=None):
        '''Build the layers.

        `pretrained_embedding_weights` (array-like of shape (vocab_size, embedding_size)),
        when given and no pre-embedded input is used (`pretrained_embeddings_size` == 0),
        initialises `self.embedding.weight`; the weights remain trainable.
        '''
        super(Net, self).__init__()
        self.seq_length = seq_length
        self.seq_length_contexts = 1000
        self.filter_size = filter_size
        self.nr_filters = nr_filters
        self.nr_classes = nr_classes
        self.filter_size_pos = filter_size_pos
        self.nr_filters_pos = nr_filters_pos
        self.pos_dim = pos_dim
        self.convolutions = convolutions
        self.embeddings_size = embedding_size
        self.dropout = dropout
        self.ignore_features = ignore_features
        self.pretrained_embeddings_size = pretrained_embeddings_size
        if self.dropout:
            self.dropout_layer1 = nn.Dropout(p=self.dropout)
            self.dropout_layer2 = nn.Dropout(p=self.dropout)
        # Width of the concatenated vector feeding the hidden/output layers.
        prev_dim = 0

        if not ignore_features[0]:
            if self.pretrained_embeddings_size:
                # Inputs arrive already embedded: no embedding layer is created.
                self.embeddings_size = self.pretrained_embeddings_size
            else:
                self.embedding = nn.Embedding(vocab_size, embedding_size)
                if pretrained_embedding_weights is not None:
                    # nn.Embedding stores its table in `.weight`; copy into it in place.
                    with torch.no_grad():
                        self.embedding.weight.copy_(torch.as_tensor(
                            pretrained_embedding_weights, dtype=self.embedding.weight.dtype))
            prev_dim = self.embeddings_size
            if convolutions:
                self.conv1 = nn.Conv1d(in_channels=self.embeddings_size, out_channels=self.nr_filters,
                                       kernel_size=self.filter_size)
                self.conv_output_dim = self.seq_length - self.filter_size + 1
                # Pooling window equals the conv output length for `seq_length` inputs.
                self.pool = nn.MaxPool1d(kernel_size=self.conv_output_dim)
                prev_dim = self.nr_filters

        if not ignore_features[2]:
            self.pos_conv = nn.Conv1d(in_channels=self.pos_dim, out_channels=self.nr_filters_pos,
                                      kernel_size=self.filter_size_pos)
            self.pos_conv_output_dim = self.seq_length - self.filter_size_pos + 1
            self.pos_pool = nn.MaxPool1d(kernel_size=self.pos_conv_output_dim)
            prev_dim += self.nr_filters_pos

        if not ignore_features[3]:
            if references_hidden_dim:
                self.references_hidden_layer = nn.Linear(references_dim, references_hidden_dim)
            else:
                self.references_hidden_layer = None
                references_hidden_dim = references_dim
            prev_dim += references_hidden_dim
        else:
            self.references_hidden_layer = None

        self.hidden = nn.ModuleList([])
        if not ignore_features[1]:
            prev_dim += stopwords_dim
        for hidden_layer in hidden_layers:
            # Falsy sizes (e.g. 0) mean "no hidden layer".
            if not hidden_layer:
                continue
            self.hidden.append(nn.Linear(prev_dim, hidden_layer))
            prev_dim = hidden_layer
        # output layer
        self.fc = nn.Linear(prev_dim, nr_classes, bias=True)

    def forward(self, features):
        '''Return unnormalised class scores of shape (batch, nr_classes).

        The caller's `features` sequence is not modified.
        '''
        # Shallow copy: intermediate activations must not overwrite the caller's inputs
        # (e.g. a DataLoader batch), otherwise a second call would re-embed activations.
        features = list(features)

        if not self.ignore_features[0]:
            if not self.pretrained_embeddings_size:
                features[0] = self.embedding(features[0])
                if self.dropout:
                    features[0] = self.dropout_layer1(features[0])

            if self.convolutions:
                # Conv1d expects channels first: move the embedding axis to dim 1.
                words = features[0].transpose(1, 2)
                words = F.relu(self.conv1(words))
                words = self.pool(words)
                # flatten
                features[0] = words.view(-1, self.num_flat_features(words))
            else:
                # Bag of embeddings: sum over dim 1.
                features[0] = features[0].sum(dim=1)

        if not self.ignore_features[2]:
            pos = F.relu(self.pos_conv(features[2]))
            pos = self.pos_pool(pos)
            features[2] = pos.view(-1, self.num_flat_features(pos))

        if not self.ignore_features[3] and self.references_hidden_layer is not None:
            features[3] = F.relu(self.references_hidden_layer(features[3]))

        x = torch.cat([f for i, f in enumerate(features) if not self.ignore_features[i]], dim=1)
        for hidden in self.hidden:
            x = F.relu(hidden(x))

        if self.dropout:
            x = self.dropout_layer2(x)
        x = self.fc(x)
        return x

    def num_flat_features(self, x):
        '''Number of elements per sample, i.e. the product of all non-batch dimensions.'''
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
    


In [None]:
def average_precision_recall(predicted_ranking, true_authors):
    '''Computes ranking metrics for a given article.

    For each cutoff k = 1..len(predicted_ranking), precision@k and recall@k are computed
    over the distinct authors among the top-k predictions. The first returned value is
    the mean of precision@k over all cutoffs (not the standard "average precision",
    which averages only at ranks of correct predictions).

    Args:
        predicted_ranking: predicted authors, best first.
        true_authors: the article's true authors (recall denominator is its length).

    Returns:
        (mean precision@k, mean recall@k, reciprocal rank of the first correct
        prediction (0 if none), list of precision@k, list of recall@k).
        An empty ranking yields (0.0, 0.0, 0, [], []); with no true authors every
        recall@k is 0.0.
    '''
    true_set = set(true_authors)
    nr_true = len(true_authors)
    found = set()  # distinct correct authors seen so far
    precisions = []
    recalls = []
    reciprocal_rank = 0
    for rank, author in enumerate(predicted_ranking, start=1):
        if author in true_set:
            found.add(author)
            # first rank where at least one prediction is correct
            if reciprocal_rank == 0:
                reciprocal_rank = 1. / rank
        hits = len(found)
        precisions.append(float(hits) / rank)
        recalls.append(float(hits) / nr_true if nr_true else 0.0)
    if not precisions:
        return 0.0, 0.0, 0, precisions, recalls
    return sum(precisions) / len(precisions), sum(recalls) / len(recalls), reciprocal_rank, precisions, recalls


In [None]:
def evaluate(net, data, topn=10, loss=None,  gpu=False, max_batches=20):
    '''Evaluate `net` on a shuffled sample of batches from `data`.

    Args:
        net: model called as `net(features)`; it is left in eval mode.
        data: dataset yielding `(features, labels, article)`.
        topn: number of top-scoring classes counted as predictions.
        loss: criterion. When given, labels are treated as single class ids and per-batch
            losses are collected. When None, labels are treated as multi-hot vectors and
            MAP/MAR/MRR are computed with `average_precision_recall`.
        gpu: move inputs to CUDA before the forward pass.
        max_batches: evaluation stops after the batch with index `max_batches` + 1.

    Returns:
        dict with 'accuracy' (fraction of examples with at least one true class among the
        top `topn`; None for empty data), 'loss' (list of batch losses), and 'MAP', 'MAR',
        'MRR' (None unless computed).
    '''
    loader = DataLoader(data, batch_size=50,
                        shuffle=True, num_workers=1)
    total_correct = 0
    total = 0
    losses = []
    precisions = []
    recalls = []
    RRs = []
    net.eval()
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        # `batch_index` must not be shadowed by the per-example loop below, otherwise
        # the max_batches check sees the example index instead.
        for batch_index, batch in enumerate(loader):
            features, labels, article = batch
            if gpu:
                outputs = net([f.cuda() for f in features])
            else:
                outputs = net(features)
            if loss:
                target = labels.cuda() if gpu else labels
                losses.append(loss(outputs, target).item())
            predictions = torch.sort(outputs, descending=True, dim=1)[1][:, :topn].cpu()
            if not loss:  # it means these are multi labels
                labels_list = [l.nonzero().flatten() for l in labels]
                for prediction, true_labels in zip(predictions, labels_list):
                    precision, recall, RR, _, _ = average_precision_recall(prediction.tolist(),
                                                                           true_labels.tolist())
                    precisions.append(precision)
                    recalls.append(recall)
                    RRs.append(RR)
            else:
                labels_list = labels.tolist()
            correct = [len(np.intersect1d(predictions[j], labels_list[j])) > 0 for j in range(len(predictions))]
            total_correct += sum(correct)
            total += len(correct)
            if batch_index > max_batches:
                break
    MAP = float(sum(precisions)) / len(precisions) if precisions else None
    MAR = float(sum(recalls)) / len(recalls) if recalls else None
    MRR = float(sum(RRs)) / len(RRs) if RRs else None
    accuracy = float(total_correct) / total if total else None
    return {'accuracy': accuracy, 'loss': losses, 'MAP': MAP, 'MAR': MAR, 'MRR': MRR}

In [None]:
# Default training configuration. train() reads the model/optimizer keys; in this
# notebook 'lr_decrease_step', 'lr_decrease_rate', 'momentum', 'just_abstract' and
# 'just_context' are not read by train() and are only logged with the experiment.
hyperparams = dict(
    epochs=150,
    learning_rate=0.0001,
    batch_size=128,
    hidden_layers=0,  # 0 -> no hidden layer
    weight_decay=0.0001,
    convolutions=True,
    filters=300,
    filter_width=9,
    nr_filters_pos=50,
    references_hidden_dim=1000,
    lr_decrease_step=5,
    lr_decrease_rate=10,
    dropout=0.5,
    optimizer='adam',
    momentum=0.9,
    just_abstract=False,
    just_context=False,
    ignore_features=[False, False, False, False],  # words, stopwords, POS, references
)

In [None]:
# Module-level logs that train() appends to (the tuning loop resets them per trial).
losses, accuracies, training_accuracies, validation_losses = [], [], [], []
# Per-epoch cap on training batches, and validation frequency (in batches).
max_batches = 1000000
eval_freq = 200
# Tallies of predicted classes vs. true labels seen during training.
all_predictions, all_labels = Counter(), Counter()

def train(hyperparams, training_data, validation_data, validation_data_single, training_validation_data=None,
          gpu=False, log_experiment=False, experiment_tags=(), cometml_api_key=""):
    '''Train a `Net` on `training_data`, evaluating on the validation sets every `eval_freq` batches.

    Args:
        hyperparams: dict providing 'filter_width', 'filters', 'hidden_layers',
            'nr_filters_pos', 'convolutions', 'references_hidden_dim', 'dropout',
            'ignore_features', 'optimizer', 'learning_rate', 'weight_decay', 'epochs' and
            'batch_size'. The whole dict is also logged to comet.ml.
        training_data: training dataset; its `weights` drive a WeightedRandomSampler and
            its metadata (labels, vocabulary, POS index, feature dims) sizes the model.
        validation_data: multi-label validation set (accuracy, MAP, MAR, MRR).
        validation_data_single: single-label validation set (loss).
        training_validation_data: unused; kept for backward compatibility.
        gpu: train on CUDA.
        log_experiment: when true, the comet.ml experiment is enabled and returned.
        experiment_tags: extra comet.ml tags.
        cometml_api_key: comet.ml API key.

    Returns:
        `(experiment, net)` when `log_experiment` is true, otherwise `net`.

    Side effects: appends to the module-level `losses`, `training_accuracies`,
    `accuracies`, `validation_losses`, `all_predictions` and `all_labels`, and pickles the
    model with the best validation accuracy to "best_modelx.pkl".
    '''
    net = Net(filter_size=hyperparams['filter_width'], nr_filters=hyperparams['filters'],
              hidden_layers=[hyperparams['hidden_layers']],
              nr_classes=training_data.nr_labels, nr_filters_pos=hyperparams['nr_filters_pos'],
              vocab_size=training_data.article_data.vocabulary_size, convolutions=hyperparams['convolutions'],
              seq_length=sequence_length, stopwords_dim=training_data.article_data.stopwords_dim,
              references_dim=training_data.article_data.references_dim,
              references_hidden_dim=hyperparams['references_hidden_dim'],
              pos_dim=max(training_data.POS_index.values()) + 1,
              embedding_size=300, pretrained_embeddings_size=0, dropout=hyperparams['dropout'],
              ignore_features=hyperparams['ignore_features'])
    experiment = Experiment(api_key=cometml_api_key,
                            project_name="general", workspace="", disabled=not log_experiment)
    print(net)

    experiment.add_tags(list(experiment_tags) + ["correct", "correct_pos"])
    experiment.log_parameters(hyperparams)
    experiment.log_parameters({"seq_length": sequence_length,
                               "vocab_size": training_data.article_data.vocabulary_size,
                               "nr_classes": training_data.nr_labels})
    if gpu:
        net = net.cuda()
    loss = torch.nn.CrossEntropyLoss()
    #Optimizer
    optimizer_names = {
        'adam': torch.optim.Adam,
        'adagrad': torch.optim.Adagrad,
        'sgd': torch.optim.SGD,
        'adadelta': torch.optim.Adadelta
    }
    optimizer_class = optimizer_names[hyperparams['optimizer']]
    optimizer = optimizer_class(net.parameters(), lr=hyperparams['learning_rate'],
                                weight_decay=hyperparams['weight_decay'])
    max_acc = 0

    # Cosine schedule with T_max=100, stepped once per validation round.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
    # Built from the argument (not a notebook global) so any dataset can be trained on.
    sampler = WeightedRandomSampler(training_data.weights, len(training_data))
    for e in range(hyperparams['epochs']):
        train_loader = DataLoader(training_data, batch_size=hyperparams['batch_size'],
                                  sampler=sampler, num_workers=4)
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()
            features, labels, article = batch

            if gpu:
                outputs = net([f.cuda() for f in features])
                loss_size = loss(outputs, labels.cuda())
            else:
                outputs = net(features)
                loss_size = loss(outputs, labels)
            batch_loss = loss_size.item()
            print("Loss:", batch_loss)
            losses.append(batch_loss)
            loss_size.backward()
            optimizer.step()

            # Top-10 training accuracy (labels assumed to be single class ids).
            predictions = outputs.detach().cpu().sort(descending=True)[1][:, :10]
            labels_list = labels.tolist()

            correct = [len(np.intersect1d(predictions[j], labels_list[j])) for j in range(len(predictions))]
            # Divide by the real batch size: the last batch of an epoch can be smaller.
            training_accuracy = sum(correct) / float(len(correct))
            training_accuracies.append(training_accuracy)

            all_predictions.update(c for p in predictions for c in p.tolist()[:10])
            all_labels.update(labels_list)

            experiment.log_metric("train_acc", float(training_accuracy))
            experiment.log_metric("loss", float(batch_loss))

            if i % eval_freq == 0:
                metrics = evaluate(net, validation_data, topn=10, gpu=gpu)
                metrics2 = evaluate(net, validation_data_single, loss=loss, topn=10, gpu=gpu)

                validation_losses.extend(metrics2['loss'])
                print(metrics)
                print(metrics2)
                experiment.log_metric("acc", float(metrics['accuracy']))
                avg_eval_loss = float(sum(metrics2['loss'])) / len(metrics2['loss'])
                experiment.log_metric("eval_loss", avg_eval_loss)
                experiment.log_metric("MAP", float(metrics['MAP']))
                experiment.log_metric("MAR", float(metrics['MAR']))
                experiment.log_metric("MRR", float(metrics['MRR']))

                if float(metrics['accuracy']) > max_acc:
                    max_acc = float(metrics['accuracy'])
                    with open("best_modelx.pkl", "wb+") as f:
                        pickle.dump(net, f)

                scheduler.step()
                # evaluate() put the model in eval mode; resume training mode.
                net.train()
            # Latest validation accuracy, recorded once per training batch (it only
            # changes every `eval_freq` batches).
            accuracies.append(metrics['accuracy'])

            if i > max_batches:
                break
            print("batch", i, "epoch", e, "learning rate", optimizer.param_groups[0]['lr'])

    if log_experiment:
        return experiment, net
    else:
        return net

# Single training run with the default configuration; with log_experiment=False
# train() returns just the network.
net = train(
    hyperparams,
    training_data=training_dataset,
    validation_data=validation_dataset,
    validation_data_single=validation_dataset_single,
    gpu=True,
    log_experiment=False,
)

In [None]:
def _top_distinct_classes(scores, nr_labels, topn):
    '''Return up to `topn` distinct class ids, best first, from a concatenated score vector.

    `scores` concatenates several per-datapoint output vectors of length `nr_labels`, so
    position p belongs to class p % nr_labels.
    '''
    ranked_positions = scores.sort(descending=True)[1].tolist()
    top_classes = []
    for position in ranked_positions:
        label = position % nr_labels
        if label not in top_classes:
            top_classes.append(label)
            if len(top_classes) >= topn:
                break
    return top_classes


def evaluate_per_article(net, data, topn=10, loss=None,  gpu=False, max_batches=1000, cutoff=10):
    '''Evaluate `net` per article by pooling the scores of all of an article's datapoints.

    Args:
        net: model called as `net(features)`; it is put in eval mode.
        data: dataset yielding `(features, labels, articles)` in order; labels are
            treated as multi-hot rows whose non-zero positions are the true classes.
        topn: number of distinct classes predicted per article.
        loss: optional criterion; when given, per-batch losses are collected.
        gpu: move inputs to CUDA before the forward pass.
        max_batches: reading stops after the batch with index `max_batches` + 1
            (same convention as `evaluate`).
        cutoff: unused; kept for backward compatibility.

    Returns:
        dict with 'accuracy' (fraction of articles with a true class among the top
        `topn`), 'loss' (list of batch losses) and 'MAP', 'MAR', 'MRR'. The metrics
        are None when no article was evaluated.
    '''
    loader = DataLoader(data, batch_size=50,
                        shuffle=False, num_workers=0)
    outputs_for_article = {}
    labels_for_article = {}
    losses = []
    net.eval()
    # Inference only: no autograd graphs, and scores are kept on the CPU so the pooled
    # per-article tensors do not pin GPU memory for the whole dataset.
    with torch.no_grad():
        for batch_index, batch in enumerate(loader):
            features, labels, articles = batch

            if gpu:
                outputs = net([f.cuda() for f in features])
            else:
                outputs = net(features)
            if loss:
                target = labels.cuda() if gpu else labels
                losses.append(loss(outputs, target).item())
            outputs = outputs.cpu()
            for j, article in enumerate(articles):
                output = outputs[j]
                if article in outputs_for_article:
                    outputs_for_article[article] = torch.cat((outputs_for_article[article], output))
                else:
                    outputs_for_article[article] = output
                labels_for_article[article] = [l[0] for l in labels[j].nonzero().tolist()]
            if batch_index > max_batches:
                break

    nr_labels = data.article_data.nr_labels
    precisions = []
    recalls = []
    RRs = []
    total_correct = 0
    for article, scores in outputs_for_article.items():
        true_labels = labels_for_article[article]
        top_predictions = _top_distinct_classes(scores, nr_labels, topn)
        precision, recall, RR, _, _ = average_precision_recall(top_predictions, true_labels)
        precisions.append(precision)
        recalls.append(recall)
        RRs.append(RR)
        if len(np.intersect1d(top_predictions, true_labels)) > 0:
            total_correct += 1
    total = len(outputs_for_article)

    MAP = float(sum(precisions)) / len(precisions) if precisions else None
    MAR = float(sum(recalls)) / len(recalls) if recalls else None
    MRR = float(sum(RRs)) / len(RRs) if RRs else None
    accuracy = float(total_correct) / total if total else None
    return {'accuracy': accuracy, 'loss': losses, 'MAP': MAP, 'MAR': MAR, 'MRR': MRR}

In [None]:
# Load a previously trained model. pickle can execute code from the file, so only
# load trusted model files. The context manager closes the handle.
with open("datasets_serialized/model_best_922.pkl", "rb") as model_file:
    trained_net = pickle.load(model_file)

In [None]:
# Test on the test set: article-level scores first, then per-datapoint scores.
# Module.cuda() moves the parameters in place and returns the same module.
trained_net.cuda()
for evaluation_fn in (evaluate_per_article, evaluate):
    print(evaluation_fn(trained_net, test_dataset, gpu=True, max_batches=100000))

In [None]:
from comet_ml import Optimizer
optimizer = Optimizer("") # key here
# Declare your hyper-parameters:
params = """
learning_rate real [0.00001, 0.01] [0.0001]
batch_size integer [100, 512] [256]
hidden_layers categorical {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} [500]
weight_decay real [0.001, 0.1] [0.01]
lr_decrease_step integer [1, 20] [5]
lr_decrease_rate integer [2, 10] [10]
filters categorical {50, 70, 100, 150, 175, 200} [200]
references_hidden_dim categorical {10, 50, 100, 200, 300, 500, 700, 1000, 1200, 1500, 1700, 2000} [500]
dropout real [0.0, 0.5] [0.4]
optimizer categorical {adam, adagrad, sgd} [adam]
"""
# Search dimension currently disabled (kept for ablation runs):
# ignore_features categorical {[False, False, False, False], [False, True, True, True], [True, False, True, True], [True, True, False, True], [True, True, True, False]} [[False, False, False, False]]
optimizer.set_params(params)

# Every dimension declared in `params` is copied from a suggestion into `hyperparams`,
# including 'filters', so each searched value actually reaches train().
TUNED_HYPERPARAMS = ("learning_rate", "batch_size", "hidden_layers", "weight_decay",
                     "lr_decrease_step", "lr_decrease_rate", "filters", "optimizer",
                     "references_hidden_dim", "dropout")

all_losses = []
all_accuracies = []
all_training_accuracies = []
all_validation_losses = []
all_hyperparams = []

# The per-trial logs below are written into this directory.
os.makedirs("experiments", exist_ok=True)
while True:
    # Fresh module-level logs for train() to append to during this trial.
    losses = []
    accuracies = []
    training_accuracies = []
    validation_losses = []
    # Get a suggestion
    suggestion = optimizer.get_suggestion()

    # Test the model
    for name in TUNED_HYPERPARAMS:
        hyperparams[name] = suggestion[name]
    # train() returns (experiment, net) only with log_experiment=True; the experiment id
    # names the log file written below.
    experiment, net = train(hyperparams, training_dataset, validation_dataset, validation_dataset_single,
                            gpu=True, log_experiment=True,
                            experiment_tags=["tune", "correct_conv", "ablation"])
    recent_accuracies = accuracies[-100:]
    score = sum(recent_accuracies) / len(recent_accuracies) if recent_accuracies else 0.0
    # Report the score back
    suggestion.report_score("avg_accuracy", score)

    all_losses.append(losses)
    all_accuracies.append(accuracies)
    all_validation_losses.append(validation_losses)
    all_training_accuracies.append(training_accuracies)
    # Snapshot: `hyperparams` is mutated in place on every trial.
    all_hyperparams.append(dict(hyperparams))
    with open("experiments/experiment" + experiment.id, "wb+") as log:
        pickle.dump({'losses': losses, 'accuracies': accuracies,
                     'training_accuracies': training_accuracies, 'hyperparams': hyperparams}, log)

### Author statistics

In [None]:
def get_authors_list(dataset):
    '''Count how often each known author appears in the author lists of the dataset's articles.

    Only names in `dataset.article_data.classes` are counted. Articles whose authors are
    all unknown contribute nothing.
    NOTE(review): per an earlier note, articles are kept as long as at least one author is
    a class, so such articles should not occur; not verifiable from this cell.
    '''
    # Set membership instead of a potentially long list scan per author.
    known_authors = set(dataset.article_data.classes)
    # Flatten lazily (a `sum(lists, [])` concatenation is quadratic).
    return Counter(author
                   for article in dataset.article_data.authors.values()
                   for author in article['authors']
                   if author in known_authors)

def get_references(dataset):
    '''Count how often each known author is cited by the articles in the dataset.

    Only reference entries of articles present in `dataset.article_data.article_indices`
    are used. Each entry's values are lists of cited author names; names found in
    `dataset.article_data.classes` are kept, normalized with
    `ArticleTextData._normalize_authors`, and counted.
    '''
    article_data = dataset.article_data
    known_authors = set(article_data.classes)
    # Single flattening pass (nested `sum(..., [])` concatenations are quadratic).
    cited_authors = [
        author
        for article, article_references in article_data.references.items()
        if article in article_data.article_indices
        for author_list in article_references.values()
        for author in author_list
        if author in known_authors
    ]
    return Counter(ArticleTextData._normalize_authors(cited_authors))


# Author and citation frequencies pooled over the training, test and validation
# splits: the training Counters are updated in place with the other splits' counts.
training_authors_freq = get_authors_list(training_dataset)
test_authors_freq = get_authors_list(test_dataset)
validation_authors_freq = get_authors_list(validation_dataset)
for split_counts in (test_authors_freq, validation_authors_freq):
    training_authors_freq.update(split_counts)

training_references = get_references(training_dataset)
test_references = get_references(test_dataset)
validation_references = get_references(validation_dataset)
for split_counts in (test_references, validation_references):
    training_references.update(split_counts)
training_references.most_common()


In [None]:
def _first_last_name(name):
    '''Drop middle names: keep the first and last whitespace-separated tokens.

    Single-token names are returned unchanged rather than doubled ("X" -> "X X").
    '''
    parts = name.split()
    if len(parts) < 2:
        return name
    return " ".join([parts[0], parts[-1]])


def _collision_key(name):
    '''Normalized first/last-name key under which `collisions_counter` counts `name`.'''
    return ArticleTextData._normalize_authors([_first_last_name(name)])[0]


all_authors_list = list(training_authors_freq.keys())
print(len(all_authors_list))

# Removing middle names
all_authors_seminormalized = [_first_last_name(a) for a in all_authors_list]
all_authors_normalized = ArticleTextData._normalize_authors(all_authors_seminormalized)
print(len(all_authors_normalized))

collisions_counter = Counter(all_authors_normalized)
# [a for a,f in collisions_counter.items() if f > 1]
# Look each author up with the same first/last-name key the counter was built from;
# looking up the normalized full name misses collisions that differ by middle name.
collisions_emnlp = [a for a in all_authors_list
                    if collisions_counter[_collision_key(a)] > 1]
collisions_emnlp = sorted(collisions_emnlp, key=lambda n: n.split()[-1])
# open("collisions_emnlp", "w+").write("\n".join(collisions_emnlp))

In [None]:
# Read the collision list (one name per line) and show the entries that are classes.
with open("collisions_emnlp", "r") as collisions_file:
    collisions_emnlp_full = collisions_file.read().split("\n")
[name for name in collisions_emnlp_full if name in training_dataset.article_data.classes]

In [None]:
# Name the pooled author frequencies computed above for the ACL panel of the plot.
# Presumably run while the ACL datasets are loaded - TODO confirm. This aliases (does not
# copy) `training_authors_freq`.
authors_freq_acl = training_authors_freq

In [None]:
# Name the pooled author frequencies computed above for the EMNLP panel of the plot.
# Presumably run while the EMNLP datasets are loaded - TODO confirm. This aliases (does not
# copy) `training_authors_freq`.
authors_freq_emnlp = training_authors_freq

In [None]:
# Side-by-side bar charts of articles per author, most prolific authors first.
fig, ax = plt.subplots(1, 2)
fig.set_size_inches(8, 4)
panels = [(ax[0], authors_freq_acl, 'ACL authors'),
          (ax[1], authors_freq_emnlp, 'EMNLP authors')]
for panel_ax, freqs, xlabel in panels:
    ranked_counts = sorted(freqs.values(), reverse=True)
    panel_ax.bar(range(len(freqs)), ranked_counts, width=2.75, log=False)
    panel_ax.set_xlabel(xlabel)
ax[0].set_ylabel('Articles')


In [None]:
import csv


def write_author_freqs(path, author_freqs):
    '''Write `author_freqs` (a Counter) to `path` as "author,count" rows, most frequent first.

    csv.writer quotes names containing commas or quotes, so every row keeps exactly two
    fields; plain names are written as "name,count" with "\\n" line endings.
    '''
    with open(path, "w", newline="") as out:
        writer = csv.writer(out, lineterminator="\n")
        for author, freq in author_freqs.most_common():
            writer.writerow([author, freq])


write_author_freqs("acl_author_freqs.csv", authors_freq_acl)
write_author_freqs("emnlp_author_freqs.csv", authors_freq_emnlp)