In [1]:
from abc import ABC, abstractmethod
from collections import Counter

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import make_pipeline
from sklearn.metrics import accuracy_score
from pickle import dump, load
from typing import List

#### Classifier learning

In [2]:
categories = fetch_20newsgroups().target_names
twenty_train = fetch_20newsgroups(subset='train', categories=categories, shuffle=True, random_state=42)
twenty_test = fetch_20newsgroups(subset='test', categories=categories, shuffle=True, random_state=42)

pipe = make_pipeline(
    CountVectorizer(),
    TfidfTransformer()
)

clf = SGDClassifier(loss='modified_huber', penalty='l2', alpha=1e-4, random_state=42, max_iter=10, tol=None)

data = pipe.fit_transform(twenty_train.data)
clf.fit(data, twenty_train.target)

data_test = pipe.transform(twenty_test.data)

print('accuracy train:', accuracy_score(clf.predict(data), twenty_train.target))
print('accuracy test:', accuracy_score(clf.predict(data_test), twenty_test.target))

dump(pipe, open('pipe.pkl', 'wb'))
dump(clf, open('clf.pkl', 'wb'))

accuracy train: 0.9990277532260916
accuracy test: 0.852761550716941


#### Stop words & common words

In [6]:
# If on colab, uncomment the following line:
# ! wget https://raw.githubusercontent.com/igorbrigadir/stopwords/master/en/alir3z4.txt --output-document=stopwords.txt
with open('stopwords.txt') as stop_words_file:
    STOP_WORDS_ALIR3Z4 = stop_words_file.read().split('\n')

# If on colab, uncomment the following line:
# ! wget https://raw.githubusercontent.com/first20hours/google-10000-english/master/google-10000-english-no-swears.txt --output-document=popular-words.txt

with open('popular-words.txt') as popular_words_file:
    POPULAR_WORDS = popular_words_file.read().split('\n')

POPULAR_TAGS = list(set(POPULAR_WORDS) - set(STOP_WORDS_ALIR3Z4))

--2022-05-16 09:57:31--  https://raw.githubusercontent.com/igorbrigadir/stopwords/master/en/alir3z4.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7678 (7.5K) [text/plain]
Saving to: ‘stopwords.txt’


2022-05-16 09:57:31 (74.0 MB/s) - ‘stopwords.txt’ saved [7678/7678]

--2022-05-16 09:57:31--  https://raw.githubusercontent.com/first20hours/google-10000-english/master/google-10000-english-no-swears.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 75153 (73K) [text/plain]
Saving to: ‘popular-words.txt’


2022-05-16 09:57:31 

#### Tagging implementation

In [7]:
class BaseTagger(ABC):

    def get_tags(self, texts: List[str]) -> List[List[str]]:
        """['Text1', 'Text2', ...] -> [['text1_tag1', text1_tag2], ...]"""
        result = []
        for text in texts:
            tags = self.get_tags_from_text(text)
            result.append(tags)
        return result
    
    def text_to_words(self, text: str) -> List[str]:
        words = ''.join(
            (c if c in self.words_alphabet else ' ')
            for c in text.lower()
        ).split()
        return words 


class BasePredefinedTagsTagger(BaseTagger, ABC):
    def __init__(self, tags: List[str]):
        self.tags = tags

In [8]:
class MostFrequentWordsTagger(BaseTagger):
    default_stop_words = STOP_WORDS_ALIR3Z4
    words_alphabet = 'abcdefghijklmnopqrstuvwxyz-\''

    def __init__(self, stop_words: list = None, max_tags_per_text: int = 5):
        self.stop_words = stop_words or self.default_stop_words
        self.max_tags_per_text = max_tags_per_text

    def get_tags_from_text(self, text: str) -> List[str]:
        words = self.text_to_words(text)
        words = [
            word
            for word in words
            if word not in self.stop_words and len(word) > 2
        ]

        words_counter = Counter(words)

        # TODO improve heuristics
        tags = []
        result = words_counter.most_common()
        if len(result) == 0:
            return []

        word, max_count = result[0]
        i = 0
        while result[i][1] == max_count:
            tags.append(result[i][0])
            i += 1

        return tags[:self.max_tags_per_text]

In [9]:
class FindSpecialWordsTagger(BasePredefinedTagsTagger):

    default_tags_candidates = POPULAR_TAGS
    words_alphabet = 'abcdefghijklmnopqrstuvwxyz-\''

    def __init__(self, tags: List[str] = None, max_tags_per_text: int = 5):
        super().__init__(tags=tags or self.default_tags_candidates)
        self.max_tags_per_text = max_tags_per_text

    def get_tags_from_text(self, text: str) -> List[str]:
        words = self.text_to_words(text)
        words = [word for word in words if len(word) > 2]

        found_tags = []
        for tag in self.tags:
            found_tags.append((tag, words.count(tag)))

        found_tags.sort(key=lambda o: o[1], reverse=True)
        found_tags = found_tags[:self.max_tags_per_text]

        return [tag for tag, count in found_tags]

In [10]:
class SGDClassifierTagger(BasePredefinedTagsTagger):
    default_tags_candidates = fetch_20newsgroups().target_names

    def __init__(self, tags: List[str] = None):

        super().__init__(tags=tags or self.default_tags_candidates)
        
        self.pipe = load(open('pipe.pkl', 'rb'))
        self.clf = load(open('clf.pkl', 'rb'))

    def get_tags(self, texts: List[str], threshold=0.0,  verbose=False) -> List[List[str]]:
                
        X_new = self.pipe.transform(texts)

        predicted = self.clf.predict(X_new)
        probas = self.clf.predict_proba(X_new)
        
        tags = [[self.default_tags_candidates[j]] if probas[i][predicted[i]] > threshold else ['no suitable tag'] for i, j in enumerate(predicted)]
        
        if verbose==True:
            print('First choice tags with probas:')
            tags_verbose = [[self.default_tags_candidates[i]] for i in predicted]
            probas = [probas[i][predicted[i]] for i, _ in enumerate(predicted)]
            for i in range(len(texts)):
                print(f'{i}. tag: {tags_verbose[i]}, proba: %0.3f' % probas[i])   
            print()
                
        return tags

#### Demonstrating tests

In [11]:
example = '''
In software engineering, a software design pattern is a general, reusable solution to a commonly occurring problem 
within a given context in software design. It is not a finished design that can be transformed directly into source 
or machine code. Rather, it is a description or template for how to solve a problem that can be used in many different 
situations. Design patterns are formalized best practices that the programmer can use to solve common problems 
when designing an application or system.

Object-oriented design patterns typically show relationships and interactions between classes or objects, 
without specifying the final application classes or objects that are involved. Patterns that imply mutable state may be 
unsuited for functional programming languages. Some patterns can be rendered unnecessary in languages that have built-in 
support for solving the problem they are trying to solve, and object-oriented patterns are not necessarily suitable 
for non-object-oriented languages.

Design patterns may be viewed as a structured approach to computer programming intermediate between the levels 
of a programming paradigm and a concrete algorithm.
'''

In [12]:
print(MostFrequentWordsTagger().get_tags([example]))

[['design', 'patterns']]


In [13]:
print(FindSpecialWordsTagger().get_tags([example]))

[['patterns', 'design', 'software', 'languages', 'programming']]


In [14]:
print(SGDClassifierTagger().get_tags(['God is love', 'OpenGL on the GPU is fast']))

[['soc.religion.christian'], ['rec.autos']]


In [15]:
print(SGDClassifierTagger().get_tags(['God is love', 'OpenGL on the GPU is fast'], threshold=0.5))

[['soc.religion.christian'], ['no suitable tag']]


In [16]:
print(SGDClassifierTagger().get_tags(['God is love', 'OpenGL on the GPU is fast'], threshold=0.5, verbose=True))

First choice tags with probas:
0. tag: ['soc.religion.christian'], proba: 0.620
1. tag: ['rec.autos'], proba: 0.240

[['soc.religion.christian'], ['no suitable tag']]


In [17]:
print(SGDClassifierTagger().get_tags([example]))

[['sci.electronics']]


In [18]:
print(SGDClassifierTagger().get_tags([example], threshold=0.5))

[['no suitable tag']]


In [19]:
print(SGDClassifierTagger().get_tags([example], threshold=0.5, verbose=True))

First choice tags with probas:
0. tag: ['sci.electronics'], proba: 0.259

[['no suitable tag']]
