In [1]:
import os
import pandas as pd
import numpy as np
import torch
import pytorch_lightning as pl

In [2]:
import gensim.downloader as dl
from gensim.models import KeyedVectors

In [3]:
# Name of the pretrained embedding in the gensim-data catalogue.
pretrained_weights_name = "word2vec-google-news-300"
# Location where gensim's downloader caches the archive for this model.
model_dl_path = os.path.join(
    dl.BASE_DIR, pretrained_weights_name, f"{pretrained_weights_name}.gz")

# Only the log message depends on whether the archive is already cached:
# dl.load() itself downloads the model when needed and then loads it.
if os.path.exists(model_dl_path):
    print(f"Loading model from {model_dl_path}")
else:
    print(f"Model will be downloaded at {model_dl_path}")

# Single load call driven by the variable, so the name cannot drift between branches.
gnews_embeddings = dl.load(pretrained_weights_name)


Loading model from /home/shawon/gensim-data/word2vec-google-news-300/word2vec-google-news-300.gz


In [4]:
# Keys known to the pretrained embedding, as exposed by `index_to_key`.
vocabulary = gnews_embeddings.index_to_key
# Number of entries in the vocabulary; displayed as the cell output.
vocab_len = len(vocabulary)
vocab_len

3000000

In [5]:
# https://github.com/Oneplus/Tweebank

# Path of the Tweebank training split in CoNLL-U format.
train_file = os.path.join(
    "/mnt/Others/experiments/datasets/Tweebank-dev/converted/"
    "en-ud-tweet-train.fixed.conllu")

# The tweets contain non-ASCII characters (e.g. "…"), so decode explicitly as
# UTF-8 instead of relying on the platform's default encoding.
# A missing file surfaces as FileNotFoundError from open().
with open(train_file, encoding="utf-8") as f:
    data = f.readlines()

In [6]:
# Group the raw lines into one list per tweet: a bare "\n" line closes the
# current block.
tweets = list()
buffer = list()
for tw in data:
    if tw == "\n":
        # one partition here
        tweets.append(buffer)
        buffer = []
    else:
        # keep appending
        buffer.append(tw)

# Keep the final block when the file does not end with a blank line;
# otherwise the last tweet would be dropped silently.
if buffer:
    tweets.append(buffer)
    buffer = []

tweets[0]

['# tweet_id = feb_jul_16.1463316480\n',
 "# text = RT @USER991: Dear diary,       I've been rapping in 3 accents and no longer know which one is truly mine. I am a sadting - Drake URL217…\n",
 '1\tRT\trt\tX\t_\t_\t10\tdiscourse\t_\t_\n',
 '2\t@USER991\t@USER\tX\t_\t_\t1\tdiscourse\t_\tSpaceAfter=No\n',
 '3\t:\t:\tPUNCT\t_\t_\t1\tpunct\t_\t_\n',
 '4\tDear\tdear\tADJ\t_\t_\t5\tamod\t_\t_\n',
 '5\tdiary\tdiary\tNOUN\t_\t_\t10\tvocative\t_\tSpaceAfter=No\n',
 '6\t,\t,\tPUNCT\t_\t_\t10\tpunct\t_\t_\n',
 '7\tI\ti\tPRON\t_\t_\t10\tnsubj\t_\tSpaceAfter=No\n',
 "8\t've\t've\tAUX\t_\t_\t10\taux\t_\t_\n",
 '9\tbeen\tbe\tAUX\t_\t_\t10\taux\t_\t_\n',
 '10\trapping\trap\tVERB\t_\t_\t0\troot\t_\t_\n',
 '11\tin\tin\tADP\t_\t_\t13\tcase\t_\t_\n',
 '12\t3\tNUMBER\tNUM\t_\t_\t13\tnummod\t_\t_\n',
 '13\taccents\taccent\tNOUN\t_\t_\t10\tobl\t_\t_\n',
 '14\tand\tand\tCCONJ\t_\t_\t17\tcc\t_\t_\n',
 '15\tno\tno\tADV\t_\t_\t16\tadvmod\t_\t_\n',
 '16\tlonger\tlonger\tADV\t_\t_\t17\tadvmod\t_\t_\n',
 '17\tknow\

In [7]:
# CoNLL-U token rows are tab-separated and have 10 columns:
# ID - FORM (word) - LEMMA - UPOS (pos) - XPOS - FEATS - HEAD - DEPREL - DEPS - MISC

'4\tDear\tdear\tADJ\t_\t_\t5\tamod\t_\t_\n'.split("\t")


['4', 'Dear', 'dear', 'ADJ', '_', '_', '5', 'amod', '_', '_\n']

In [24]:
# need idx 1, 2,3 : word, lemma and pos

class ConlluRowInfo:
    """Word form, lemma and POS tag taken from a single CoNLL-U token row."""

    word: str
    lemma: str
    pos: str

    def __init__(self, word: str, lemma: str, pos: str) -> None:
        """Store the three fields as given.

        Args:
            word: surface form of the token.
            lemma: lemma of the token.
            pos: part-of-speech tag of the token.
        """
        self.word = word
        self.lemma = lemma
        self.pos = pos

    def __str__(self) -> str:
        """Return the fields rendered as a dict literal."""
        rep = {
            "word": self.word,
            "lemma": self.lemma,
            "pos": self.pos
        }
        return str(rep)

    def __repr__(self) -> str:
        """Mirror ``__str__``.

        Containers format their elements with ``repr()``, so without this a
        list of ``ConlluRowInfo`` prints as opaque ``<... object at 0x...>``.
        """
        return self.__str__()

In [9]:
from typing import List

class ConlluRow:
    """Holds the token records that belong to one block of a CoNLL-U file."""

    info: List[ConlluRowInfo]
    # text: str

    def __init__(self, infos: List[ConlluRowInfo]) -> None:
        """Keep a reference to ``infos`` (the list is not copied)."""
        self.info = infos

    def __str__(self) -> str:
        """Return ``"info : "`` followed by the string form of the token list."""
        return "info : " + str(self.info)

In [10]:
# Convert every raw tweet block into a ConlluRow of (word, lemma, pos) records.
structured_tweets = list()

for tweet in tweets:
    info_in_tweet = list()
    for infos in tweet:
        # Metadata lines ("# tweet_id = ...", "# text = ...") start with "#".
        # Skip them by content rather than by position, so a block with more or
        # fewer than two comment lines neither loses tokens nor parses comments.
        if infos.startswith("#"):
            continue

        fields = infos.split("\t")
        try:
            # columns 1, 2, 3 hold the word, its lemma and its POS tag
            word = fields[1]
            lemma = fields[2]
            tag = fields[3]
        except IndexError:
            # Row with too few columns: show it and move on.
            print(fields)
            continue

        info_in_tweet.append(ConlluRowInfo(word, lemma, tag))
    structured_tweets.append(ConlluRow(info_in_tweet))

In [47]:
# time to define the torch dataset

from torch.utils.data import Dataset
from tqdm._tqdm import trange, tqdm

class TweebankDataset(Dataset):
    """Map-style dataset over a Tweebank CoNLL-U file.

    Each item corresponds to one blank-line-delimited block of the file and is
    returned as a dict of plain Python lists: word2vec vocabulary indices,
    lemma strings and integer POS-tag ids.
    """

    # Index substituted for words that are missing from the word2vec vocabulary.
    # NOTE(review): 0 is also a regular position in `index_to_key`, so OOV words
    # collide with whatever key sits at index 0 — confirm this is intended.
    OOV_INDEX = 0

    def __init__(self, filename, w2v_weights=gnews_embeddings) -> None:
        """Read ``filename`` and prepare the tag and vocabulary lookups.

        Args:
            filename: path of the CoNLL-U file to read.
            w2v_weights: object exposing ``key_to_index`` and ``index_to_key``;
                defaults to the notebook-level ``gnews_embeddings`` (bound when
                the class is defined).
        """
        self.filename = filename

        self.w2v = w2v_weights
        self.data = None
        self.__read_data()

        # self.max_seq_len = 0

        # The position of a tag in this list is its integer id.
        self.UNIQUE_TAGS = ['PRON', 'NUM', 'NOUN', 'CCONJ', 'ADV', 'SCONJ',
                            'ADP', 'AUX', 'PROPN', 'SYM', 'DET',
                            'INTJ', 'PUNCT', 'X', 'ADJ', 'VERB', 'PART']
        self.tag_dict = dict()
        self.__encode_tags()

        self.number_tags = len(self.UNIQUE_TAGS)

        self.vocabulary = self.w2v.index_to_key  # type: ignore

    def __len__(self) -> int:
        """Return the number of blocks read from the file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded block at position ``idx``.

        Returns:
            dict with keys ``"words"`` (vocabulary indices, ``OOV_INDEX`` for
            unknown words), ``"lemmas"`` (strings) and ``"tags"`` (tag ids).

        Raises:
            KeyError: if a token carries a tag that is not in ``UNIQUE_TAGS``.
        """
        tokens = self.data[idx].info
        key_to_index = self.w2v.key_to_index

        # Separate names for the sample index and the tokens avoid shadowing
        # the ``idx`` parameter inside the conversion.
        words = [key_to_index.get(tok.word, self.OOV_INDEX) for tok in tokens]
        lemmas = [tok.lemma for tok in tokens]
        tags = [self.tag_dict[tok.pos] for tok in tokens]

        return {
            "words": words,
            "lemmas": lemmas,
            "tags": tags
        }

    def __encode_tags(self):
        """Fill ``tag_dict`` with ``tag -> position in UNIQUE_TAGS``."""
        for tag_id, tag in enumerate(self.UNIQUE_TAGS):
            self.tag_dict[tag] = tag_id

    def __read_data(self):
        """Parse ``self.filename`` into a list of ``ConlluRow`` in ``self.data``."""
        # Decode explicitly as UTF-8: the tweets contain non-ASCII characters
        # and the platform default encoding is not guaranteed to handle them.
        with open(self.filename, "r", encoding="utf-8") as f:
            raw_lines = f.readlines()

        # ============ read the text file =============
        # A blank line terminates the current block.
        blocks = list()
        current = list()
        for raw in tqdm(raw_lines):
            if not raw.strip():
                blocks.append(current)
                current = []
            else:
                current.append(raw)
        # Keep the last block when the file lacks a trailing blank line.
        if current:
            blocks.append(current)

        # ============== organize in objects ==============
        rows = list()
        for block in tqdm(blocks):
            line_info = list()
            for raw in block:
                # Skip "# ..." metadata lines by content instead of assuming
                # that exactly the first two lines of a block are comments.
                if raw.startswith("#"):
                    continue

                fields = raw.split("\t")
                if len(fields) < 4:
                    # Row with too few columns: show it and move on.
                    print(fields)
                    continue

                # columns 1, 2, 3 hold the word, its lemma and its POS tag
                line_info.append(ConlluRowInfo(fields[1], fields[2], fields[3]))

            rows.append(ConlluRow(line_info))

        self.data = rows


In [48]:
# Build the dataset from the training file (with the default word2vec weights)
# and print the first encoded sample as a quick sanity check.
dataset = TweebankDataset(train_file)
print(dataset[0])

29670it [00:00, 4080163.92it/s]
1639it [00:00, 97075.01it/s]

{'words': [31905, 0, 0, 12654, 14263, 0, 20, 190, 42, 40105, 1, 234, 22860, 0, 86, 951, 177, 48, 45, 4, 2604, 2747, 0, 20, 248, 0, 0, 0, 10297, 0, 0], 'lemmas': ['rt', '@USER', ':', 'dear', 'diary', ',', 'i', "'ve", 'be', 'rap', 'in', 'NUMBER', 'accent', 'and', 'no', 'longer', 'know', 'which', 'one', 'be', 'truly', 'mine', '.', 'i', 'be', 'a', 'sadting', '-', 'drake', 'URL', '…'], 'tags': [13, 13, 12, 14, 2, 12, 0, 7, 7, 15, 6, 1, 2, 3, 4, 4, 15, 10, 1, 7, 4, 0, 12, 0, 7, 10, 2, 12, 8, 13, 12]}





In [46]:
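# One-off tag discovery (kept for reference): the output below matches the UNIQUE_TAGS
# list hard-coded in TweebankDataset; presumably it was produced before tags were integer-encoded.
# Flattening a list of lists: https://stackabuse.com/python-how-to-flatten-list-of-lists/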


# import itertools

# all_tags = [data["tags"] for data in dataset]
# all_tags = list(itertools.chain(*all_tags))
# unique_tags = set(all_tags)
# print(list(unique_tags))

['PRON', 'NUM', 'NOUN', 'CCONJ', 'ADV', 'SCONJ', 'ADP', 'AUX', 'PROPN', 'SYM', 'DET', 'INTJ', 'PUNCT', 'X', 'ADJ', 'VERB', 'PART']
