In [4]:
import os
import warnings

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TextClassificationPipeline
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Directory holding the fine-tuned BERT checkpoint.
# Set the BERT_MODEL_PATH environment variable to load it from another
# location; when unset, the original hard-coded local path is used.
model_path = os.environ.get(
    'BERT_MODEL_PATH',
    '/Users/anthonydavid/Documents/Etudes/alternance_ML_engineer/OpenClassrooms/projets/projet_5/2024-08-19_15-10-53-bert-model',
)

# Automatic device selection: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Show which backend was picked.
print(device)

mps


In [6]:
# NOTE(review): the tokenizer is loaded from the stock 'bert-base-uncased'
# checkpoint rather than from model_path — confirm the fine-tuned model was
# trained with this same vocabulary.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load the sequence-classification checkpoint from model_path and move it to
# the selected device; `.to()` returns the module itself, so `model` is bound
# to the same object as before.
model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)

In [7]:
# Inference pipeline wrapping the loaded model and tokenizer.
# - function_to_apply='sigmoid' requests a sigmoid over the model outputs
#   instead of the pipeline's default choice.
# - NOTE(review): `return_all_scores` appears to be deprecated in recent
#   transformers releases in favour of `top_k` — confirm before upgrading.
#   Later cells pass `top_k` explicitly on every call except the first demo.
pipe = TextClassificationPipeline(
    model=model, 
    tokenizer=tokenizer, 
    return_all_scores=False, 
    device=device,
    task="multi_label_classification",
    function_to_apply='sigmoid'
)

In [8]:
# Sample question (already lower-cased, punctuation stripped) reused by the
# demo cells below. Adjacent literals are concatenated into one string.
example_text = (
    'i read in tcp ip network administration by o that typing the route n '
    'command should bring up a routing table when i typed it into the '
    'terminal on a mac it returned the following usage route dnqtv command '
    'modifiers args what is the correct command to use to see the routing '
    'table in my terminal'
)

In [9]:
# What the pipe output looks like (no top_k passed)
pipe(example_text)

[{'label': 'javascript', 'score': 0.10410106927156448}]

In [10]:
%%time
# What the pipe output looks like when 5 answers are requested
pipe('how do i install python on arch linux ? i cant understand the docs', top_k=5)

CPU times: user 87.5 ms, sys: 49.3 ms, total: 137 ms
Wall time: 427 ms


[{'label': 'python', 'score': 0.5512714385986328},
 {'label': 'r', 'score': 0.18645761907100677},
 {'label': 'pandas', 'score': 0.15008153021335602},
 {'label': 'dataframe', 'score': 0.11528932303190231},
 {'label': 'numpy', 'score': 0.06310949474573135}]

# Wrapper qui recommandera les résultats les plus pertinents

In [11]:
def pred_fn(text, pipeline, thresh=0.5, max_answers=10):
    """
    Return the labels that the pipeline scores strictly above a threshold.

    Args:
    - text: Input forwarded unchanged to `pipeline`.
    - pipeline: Callable invoked as `pipeline(text, top_k=max_answers)`; its
      result is iterated and each item is read through its 'score' and
      'label' keys.
    - thresh: Exclusive lower bound on the score of a kept label.
    - max_answers: Value forwarded to the pipeline as `top_k`.

    Returns:
    - list: Labels of the kept items, in the order the pipeline produced them.
    """
    kept_labels = []
    for candidate in pipeline(text, top_k=max_answers):
        # Strict comparison: a score exactly equal to thresh is dropped.
        if candidate['score'] > thresh:
            kept_labels.append(candidate['label'])
    return kept_labels

In [12]:
# Try pred_fn on the Python install question with the default thresh and max_answers
pred_fn('how do i install python on arch linux ? i cant understand the docs', pipe)

['python']

In [13]:
# Same call on the routing-table question stored in example_text
pred_fn(example_text, pipe)

[]

## Entry Point

In [14]:
def pred_ep(text, thresh=0.5, max_answers=10):
    """
    Predicts the most relevant tags for a given text.

    Entry point bound to the notebook-level `pipe` pipeline. The filtering
    itself is delegated to `pred_fn` so the two functions cannot drift apart.

    Args:
    - text (str): The input text.
    - thresh (float): Exclusive lower bound on a tag's score. Defaults to 0.5,
      the value this function previously hard-coded.
    - max_answers (int): Passed to the pipeline as `top_k`. Defaults to 10,
      the value this function previously hard-coded.

    Returns:
    - list: A list of recommended tags.
    """
    return pred_fn(text, pipe, thresh=thresh, max_answers=max_answers)

In [15]:
# Test pred_ep with example inputs. The second input is the routing-table
# question already stored in example_text, reused here instead of repeating
# the literal.
print(pred_ep('how do i install python on arch linux? I can’t understand the docs'))
print(pred_ep(example_text))

['python']
[]
