In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory holding the fine-tuned classifier.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
model_path = '/Users/anthonydavid/Workspace/Openclassrooms/projet_5/2024-08-27_23-50-35-bert-model'

# Automatic device selection: prefer CUDA, then Apple MPS, otherwise CPU.
# The conditions are evaluated in that order and stop at the first match.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(device)

mps


In [3]:
# The tokenizer is taken from the stock 'bert-base-uncased' checkpoint, while
# the classifier itself is read from the local directory in `model_path`.
# NOTE(review): presumably the model was fine-tuned with this same tokenizer — confirm.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Move the model to the selected device; rebinding the name also keeps the
# cell from echoing the model repr.
model = model.to(device)

In [4]:
from transformers import TextClassificationPipeline
# Wrap the model and tokenizer in a text-classification pipeline.
pipe = TextClassificationPipeline(
    task="multi_label_classification",
    # What to run, and where
    model=model,
    tokenizer=tokenizer,
    device=device,
    # Scoring: a sigmoid is requested for the outputs, so label scores are
    # independent of each other (they do not have to sum to 1).
    function_to_apply='sigmoid',
    # Only the best label is returned unless `top_k` is given at call time.
    return_all_scores=False,
)

In [5]:
# Sample question reused by several cells below (already lower-cased and
# stripped of punctuation). Adjacent literals are concatenated by Python,
# so the value is a single string.
example_text = (
    'i read in tcp ip network administration by o that typing the route n '
    'command should bring up a routing table when i typed it into the '
    'terminal on a mac it returned the following usage route dnqtv command '
    'modifiers args what is the correct command to use to see the routing '
    'table in my terminal'
)

In [6]:
# What the raw pipeline output looks like (a list holding one label/score dict)
pipe(example_text)

[{'label': 'angular', 'score': 0.38646113872528076}]

In [7]:
%%time
# What the pipeline output looks like when 5 answers are requested
pipe('how do i install python on arch linux ? i cant understand the docs', top_k=5)

CPU times: user 85.1 ms, sys: 14.1 ms, total: 99.1 ms
Wall time: 349 ms


[{'label': 'python', 'score': 0.9854111671447754},
 {'label': 'visual-studio-code', 'score': 0.15291842818260193},
 {'label': 'python-3.x', 'score': 0.1475229412317276},
 {'label': 'django', 'score': 0.05575046315789223},
 {'label': 'docker', 'score': 0.012701679021120071}]

# Wrapper qui recommandera les résultats les plus pertinents

In [8]:
def pred_fn(text, pipeline, thresh=0.5, max_answers=10):
    """
    Recommend the tags whose score exceeds a threshold.

    Args:
    - text (str): The input text handed to the pipeline.
    - pipeline (callable): Called as ``pipeline(text, top_k=max_answers)``;
      must return an iterable of dicts with 'label' and 'score' keys.
    - thresh (float): Score a label must strictly exceed to be kept.
    - max_answers (int): Number of candidates requested from the pipeline.

    Returns:
    - list: Labels of the retained candidates, in the order the pipeline
      returned them (empty when none exceeds ``thresh``).
    """
    candidates = pipeline(text, top_k=max_answers)

    kept_labels = []
    for candidate in candidates:
        # Strict comparison: a score exactly equal to the threshold is dropped
        if candidate['score'] > thresh:
            kept_labels.append(candidate['label'])

    return kept_labels

In [9]:
# Same question as the timing cell above; only labels scoring above the default 0.5 threshold are returned
pred_fn('how do i install python on arch linux ? i cant understand the docs', pipe)

['python']

In [10]:
# In the recorded run no label clears the 0.5 threshold for this sample, hence the empty list
pred_fn(example_text, pipe)

[]

## Entry Point

In [11]:
def pred_ep(text, thresh=0.5, max_answers=10):
    """
    Predicts the most relevant tags for a given text.

    Entry-point wrapper around the notebook-level ``pipe`` classification
    pipeline: the pipeline is asked for ``max_answers`` candidates and only
    those whose score is strictly greater than ``thresh`` are kept.

    Args:
    - text (str): The input text.
    - thresh (float, optional): Score a tag must strictly exceed to be
      recommended. Defaults to 0.5 (the value that used to be hard-coded).
    - max_answers (int, optional): Number of candidates requested from the
      pipeline via ``top_k``. Defaults to 10 (previously hard-coded).

    Returns:
    - list: A list of recommended tags, in the order the pipeline returned
      them; empty when no candidate exceeds ``thresh``.
    """
    # `pipe` is looked up in the global namespace at call time
    pipe_output = pipe(text, top_k=max_answers)

    # Filter the results based on the threshold
    return [
        prediction['label']
        for prediction in pipe_output
        if prediction['score'] > thresh
    ]

In [12]:
# Test pred_ep with example inputs.
# The first question contains a typographic apostrophe (U+2019); it was
# previously stored as mis-encoded bytes.
print(pred_ep('how do i install python on arch linux? I can’t understand the docs'))
# Reuse the sample question defined earlier instead of repeating the literal
print(pred_ep(example_text))

['python']
[]


In [13]:
# A question that names two technologies; the recorded run returns both tags
print(pred_ep('how do i install python in a docker container?'))

['python', 'docker']
