In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from multiprocessing import Pool
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
from TGA.utils import Dataset
from segtok import segmenter as sg_sgm
from segtok import tokenizer as sg_tkn
from string import punctuation

In [2]:
# Open the dataset stored under a local directory (the path suggests 20 Newsgroups).
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
dataset = Dataset('/home/mangaravite/Documentos/datasets/classification/datasets/20ng/')
# Presumably an iterator over 5 folds, each including a validation split —
# TODO confirm against TGA.utils.Dataset.get_fold_instances.
g = dataset.get_fold_instances(5, with_val=True)
# Only the first item produced by `g` is used in the cells below.
fold = next(g)
# Display the fold's field names and the number of training documents.
fold._fields, len(fold.X_train)

(('X_train', 'y_train', 'X_test', 'y_test', 'X_val', 'y_val'), 11296)

In [3]:
# Set of characters that `preprocess_term` treats as punctuation:
# the ASCII punctuation from `string.punctuation` plus a few extra symbols.
pont = set(punctuation)
pont.update({'≤', '…', '®', '@', '•'})
# NOTE(review): the empty string is kept exactly as in the original cell. It looks
# like a character lost in copy/export; be aware that `s.endswith('')` is True for
# every string, so any suffix test against the members of `pont` always succeeds.
pont.add('')
# A string is iterated character by character, so this adds ']', '-', '$', ' ' and
# '>' individually (the space therefore counts as punctuation as well).
# `set.update` replaces the former `_ = list(map(pont.add, ...))`, which built a
# throw-away list and rebound `_`, IPython's "last output" variable.
pont.update("]-------$ ]-------$> >")


In [4]:
def _flatten_terms(item):
    """Yield the string terms held in `item`, skipping None and unpacking nested iterables."""
    if item is None:
        return
    if isinstance(item, str):
        yield item
        return
    for sub_item in item:
        yield from _flatten_terms(sub_item)


def preprocess(text):
    """Split `text` into sentences of normalised terms.

    Sentences are produced by ``sg_sgm.split_multi`` and tokens by
    ``sg_tkn.web_tokenizer``; every token goes through ``preprocess_term``.

    Parameters
    ----------
    text : str
        Raw document text.

    Returns
    -------
    list of list of str
        One flat list of terms per sentence (a sentence list may be empty).

    Notes
    -----
    ``preprocess_term`` can return None, a string, or a tuple, and a tuple may
    itself hold None or another tuple (e.g. a hyphenated word with an empty
    part, or a possessive built on a hyphenated word). The results are
    flattened recursively and None is dropped, so that only strings reach the
    vocabulary; the previous one-level flattening let None and tuples through
    as terms.
    """
    return [
        [
            term
            for token in sg_tkn.web_tokenizer(sentence)
            for term in _flatten_terms(preprocess_term(token))
        ]
        for sentence in sg_sgm.split_multi(text)
    ]

def preprocess_term(term):
    """Normalise a single token.

    Parameters
    ----------
    term : str
        Token to normalise; it is lower-cased first.

    Returns
    -------
    None, str or tuple of str
        * None for an empty token;
        * a string for plain words, punctuation runs (collapsed to a single
          character), numbers (digits masked with ``d``) and anything else
          (digits masked with ``D``);
        * a flat tuple of strings when the token is split into several terms
          (contractions, ``s'`` possessives, hyphenated words).
    """
    term = term.lower()
    if not term:
        return None
    if term.isalpha():
        return term

    if all(c in pont for c in term):
        # Punctuation-only token: keep one character — the token itself when it
        # has one character, the first character of a pair, or the most
        # frequent character of a longer run.
        if len(term) > 2:
            return Counter(term).most_common(1)[0][0]
        return term[0]

    if is_number(term):
        # Drop thousands separators, then mask every digit.
        return re.sub(r'\d', 'd', term.replace(',', ''))

    # NOTE(review): the pattern is not anchored, so an apostrophe followed by one
    # of these letters anywhere in the token matches, and the split happens at
    # the LAST apostrophe.
    if re.search(r".+'(s|ll|t|ve|d|re|m|l|r|v)", term):
        # The part before the apostrophe is kept as-is. The original code also
        # made a recursive call here whose result was immediately overwritten;
        # that dead call has been removed.
        head, tail = term.rsplit("'", 1)
        return (head, "'" + tail)

    if term.endswith("s'"):
        stem = preprocess_term(term[:-2])
        # Keep the result flat and free of None: the stem may be missing, a
        # single string, or already a tuple of terms.
        if stem is None:
            return ("'s",)
        if isinstance(stem, str):
            return (stem, "'s")
        return tuple(stem) + ("'s",)

    if all(c.isalpha() or c == '-' for c in term):
        # Hyphenated word: one term per non-empty part (leading, trailing or
        # doubled hyphens produce empty parts, which are skipped instead of
        # being emitted as None).
        parts = (preprocess_term(sub_word) for sub_word in term.split('-'))
        return tuple(part for part in parts if part is not None)

    # Test the last character directly instead of calling `endswith` with every
    # member of `pont`: `pont` contains the empty string and `endswith('')` is
    # always True, which made this branch unconditional and the digit-masking
    # fallback below unreachable.
    if term[-1] in pont:
        return preprocess_term(term[:-1])

    return re.sub(r'\d', 'D', term)

def is_number(term):
    """Tell whether `term` is made only of numeric characters, '.' and ','.

    At least one numeric character (per ``str.isnumeric``) is required, so an
    empty string or a run of separators is not a number.
    """
    found_numeric = False
    for char in term:
        if char.isnumeric():
            found_numeric = True
        elif char not in ('.', ','):
            return False
    return found_numeric

In [52]:
def fit_tokenizer(texts):
    """Build the term -> id mapping and sentence-level id counts from `texts`.

    Every document is run through ``preprocess``. Ids are handed out in order
    of first appearance, starting at 1; ``'<UNK>'`` is added last with the next
    free id.

    Returns
    -------
    (dict, Counter)
        The vocabulary, and a counter giving for each id the number of
        sentences in which it occurs (an id is counted once per sentence).
    """
    term_to_id = {}
    sentence_freq = Counter()
    for document in map(preprocess, tqdm(texts)):
        for sentence in document:
            ids_in_sentence = set()
            for term in sentence:
                if term not in term_to_id:
                    term_to_id[term] = len(term_to_id) + 1
                ids_in_sentence.add(term_to_id[term])
            sentence_freq.update(ids_in_sentence)

    term_to_id['<UNK>'] = len(term_to_id) + 1
    return term_to_id, sentence_freq
def tokenizer(texts, vocab):
    """Lazily encode each document of `texts` as lists of term ids.

    Yields one item per document: a list of sentences, each a list of ids
    looked up in `vocab`. Terms missing from `vocab` get the id stored under
    ``'<UNK>'``; sentences without any term are left out.
    """
    for document in map(preprocess, tqdm(texts)):
        encoded_sentences = []
        for sentence in document:
            if len(sentence) == 0:
                continue
            encoded_sentences.append(
                [vocab.get(term, vocab['<UNK>']) for term in sentence]
            )
        yield encoded_sentences
# Fit the vocabulary and the sentence-level id counts on the fold's training documents.
vocab, counter = fit_tokenizer(fold.X_train)

  0%|          | 0/11296 [00:00<?, ?it/s]

In [79]:
def collate(X):
    """Encode a batch of raw documents as one padded tensor of sentences.

    The documents are tokenised with the module-level `vocab`. All sentences
    of all documents are stacked, in order, into a single LongTensor padded
    with 0 (batch first).

    Returns
    -------
    (torch.LongTensor, list of list of int)
        The padded sentence tensor, and for every document the lengths of its
        sentences (which also tells how many rows belong to each document).
    """
    encoded_docs = list(tokenizer(X, vocab))

    lengths_per_doc = []
    all_sentences = []
    for doc in encoded_docs:
        lengths_per_doc.append([len(sentence) for sentence in doc])
        all_sentences.extend(doc)

    sentence_tensors = [torch.LongTensor(sentence) for sentence in all_sentences]
    padded = pad_sequence(sentence_tensors, batch_first=True, padding_value=0)
    return padded, lengths_per_doc
def collate_train(param):
    """Collate a batch of ``(document, label)`` pairs.

    Returns the output of ``collate`` for the documents together with the
    labels as a LongTensor.
    """
    documents, labels = zip(*param)
    features = collate(documents)
    targets = torch.LongTensor(labels)
    return features, targets

In [None]:
# Build the model, optimiser, loss and learning-rate scheduler.
# NOTE(review): `ClassifierSentence`, `drop` and `device` are not defined in the
# visible cells; they must exist before this cell is run.

# Term ids start at 1, '<UNK>' holds the largest id (== len(vocab)) and 0 is the
# padding value used by `collate`, so ids span 0..len(vocab).
# The original passed `len(vocab_size)`, an undefined name.
vocab_size = len(vocab) + 1
# NOTE(review): the original call had an empty third positional argument
# (`300, ,` — a SyntaxError). It is presumably the number of target classes;
# TODO confirm against the ClassifierSentence signature.
n_class = len(set(fold.y_train))

ab = ClassifierSentence(vocab_size, 300, n_class,
                       initrange=0.3, drop=drop).to( device )
optimizer = optim.AdamW( ab.parameters(), lr=5e-3, weight_decay=5e-3)
loss_func_cel = nn.CrossEntropyLoss().to( device )
# Multiply the learning rate by 0.95 after more than 3 epochs without a decrease
# of the monitored value (mode='min').
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=.95,
                                                       patience=3, verbose=True)

In [63]:
# NOTE(review): in the visible cells `doc_tids` is only assigned inside `collate`,
# so this display relies on a variable left over from interactive debugging; it
# also prints the whole nested list of ids.
doc_tids

[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 5, 14, 15, 16, 17, 18, 19, 20],
  [21, 22, 23, 24, 25, 26, 27, 28, 5, 29, 30, 31, 32, 33, 20],
  [34, 35, 36, 37, 18, 38, 39, 20],
  [5, 5, 8, 5, 40, 41, 24, 42, 43, 44, 45, 46, 47, 48, 49, 50, 9, 51, 23, 20],
  [52,
   53,
   54,
   55,
   56,
   57,
   34,
   7,
   58,
   59,
   24,
   60,
   61,
   15,
   24,
   62,
   63,
   64,
   65,
   66,
   67,
   68,
   69,
   28,
   29,
   70,
   71,
   72,
   24,
   61,
   73,
   31,
   74,
   75,
   76,
   11,
   20],
  [24, 77, 63, 41, 9, 48, 78],
  [79, 21, 80, 69, 81, 82, 20],
  [58, 83, 84, 74, 85, 86, 69, 87, 88, 89, 1, 38, 90, 91, 25, 20],
  [92,
   43,
   93,
   94,
   95,
   14,
   64,
   96,
   97,
   24,
   98,
   99,
   100,
   101,
   102,
   103,
   69,
   24,
   104,
   105,
   106,
   107,
   20],
  [24, 108, 15, 84, 109, 110, 36, 111, 112, 113, 41, 96, 20]],
 [[114,
   115,
   7,
   116,
   11,
   117,
   13,
   13,
   34,
   70,
   118,
   24,
   119,
   120,
   121,
   24,
   