In [1]:
from typing import List, Optional, Tuple,Dict
import torch
import outlines
from enum import Enum
from pydantic import BaseModel, constr, conint
from dotenv import load_dotenv
import os

# Copy variables from a local .env file into os.environ, so that
# os.getenv("HF_TOKEN") below can find a token that is not set in the shell.
load_dotenv()

# Llama 3.1 8B Instruct in bfloat16 on the GPU, wrapped as an outlines model.
# NOTE(review): the device is hard-coded to "cuda"; this cell fails on CPU-only machines.
model = outlines.models.transformers(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    device="cuda",
    model_kwargs={
        "torch_dtype": torch.bfloat16,
        "token": os.getenv("HF_TOKEN"),
    },
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:

class Severity(str, Enum):
    # Severity the model assigns to each flagged snippet (ProblematicSnippet.severity).
    # The `str` mixin makes members compare equal to their string values.
    # convert_tokens_to_output maps minimal -> 0.5 and significant -> 1.0.
    # Comments are used instead of a docstring because this enum is part of
    # the schema the JSON generator is built from.
    minimal = "minimal"
    significant = "significant"

class ProblematicSnippet(BaseModel):
    # One passage of the comment that the model flags for the requested attribute.
    original_text: str  # quoted text; located in the post afterwards by update_snippet_positions
    severity: Severity
    # Character offsets into the post text. None by default; they are
    # assigned after generation by update_snippet_positions.
    # NOTE(review): these fields are presumably part of the generation schema
    # too, so the model may emit its own values for them. Confirm whether they
    # should be excluded from the schema.
    start: Optional[conint(ge=0)] = None
    end: Optional[conint(ge=0)] = None


class TextAnalysis(BaseModel):
    # Top-level object the constrained JSON `generator` is built from:
    # the snippets of the comment that exhibit the requested attribute.
    problematic_snippets: List[ProblematicSnippet]


class Token(BaseModel):
    # One model token of the post with its character span, built by
    # tokenize_post_text. Not referenced by TextAnalysis, so it is not part
    # of the generation schema.
    text: str  # decoded token text
    start: conint(ge=0)  # running sum of the lengths of earlier decoded tokens
    end: conint(ge=0)  # exclusive end: start + len(text)
    severity: Optional[Severity] = None  # set by assign_severity_to_tokens; None means not flagged


# Constrained JSON generator: callers read `.problematic_snippets` from its
# result, i.e. each output is a TextAnalysis.
# Low-temperature multinomial sampling (0.1) keeps generation close to greedy.
# NOTE(review): no random seed is set, so outputs can still vary between runs.
generator = outlines.generate.json(
    model,
    TextAnalysis,
    sampler=outlines.samplers.multinomial(temperature=0.1),
)

prompt_template = """
{attribute_definition}
You are provided with a comment which contains {attribute}. Output the snippets which lead to this classification.
----------
Comment in question:
{comment}
"""


def generate_text_analysis(attribute_definition: str, attribute: str, comment: str):
    """Ask the model which snippets of ``comment`` exhibit ``attribute``.

    Fills ``prompt_template`` and runs the module-level JSON ``generator`` on
    it. Callers read ``.problematic_snippets`` from the result (a
    ``TextAnalysis``).

    Args:
        attribute_definition: Definition text placed at the top of the prompt.
        attribute: Attribute name used in the instruction sentence.
        comment: The comment to analyse.
    """
    # NOTE(review): the prompt reaches the generator as raw text, without the
    # instruct model's chat template. Consider applying the template.
    filled_prompt = prompt_template.format(
        attribute_definition=attribute_definition,
        attribute=attribute,
        comment=comment,
    )
    return generator(filled_prompt)

In [3]:
from datasets import load_dataset

ds = load_dataset("timonziegenbein/appropriateness-corpus")["test"]

In [4]:
import re
from constants import ATTRIBUTE_DEFINITIONS

def analyse_post(post_text: str) -> Dict[str,List[Tuple[str, float]]]:
    """Score ``post_text`` against every attribute in ``ATTRIBUTE_DEFINITIONS``.

    Returns a dict keyed by attribute name, in ``ATTRIBUTE_DEFINITIONS``
    iteration order. Each value is the per-token ``(text, score)`` list
    produced by ``analyse_post_for_attribute``. Runs one LLM generation per
    attribute.
    """
    return {
        attribute_name: analyse_post_for_attribute(post_text, attribute_name, attribute_definition)
        for attribute_name, attribute_definition in ATTRIBUTE_DEFINITIONS.items()
    }

def analyse_post_for_attribute(post_text: str, attribute_name: str, attribute_definition: str) -> List[Tuple[str, float]]:
    """Score every token of ``post_text`` for one appropriateness attribute.

    Pipeline:
    1. Ask the LLM which snippets of the post exhibit the attribute.
    2. Locate those snippets in the post.
    3. Tokenize the post with the model's tokenizer.
    4. Copy each snippet's severity onto the tokens it covers.

    Args:
        post_text: The comment to analyse.
        attribute_name: Attribute name. Fills ``{attribute}`` in the prompt
            ("You are provided with a comment which contains {attribute}").
        attribute_definition: Attribute definition. Fills
            ``{attribute_definition}`` at the top of the prompt.

    Returns:
        One ``(token_text, score)`` pair per token, as produced by
        ``convert_tokens_to_output``.
    """
    # Pass by keyword. generate_text_analysis takes the definition first
    # (attribute_definition, attribute, comment), so a positional
    # (name, definition, ...) call swaps the two inside the prompt.
    analysis_result = generate_text_analysis(
        attribute_definition=attribute_definition,
        attribute=attribute_name,
        comment=post_text,
    )
    update_snippet_positions(analysis_result.problematic_snippets, post_text)
    token_list = tokenize_post_text(post_text)
    assign_severity_to_tokens(token_list, analysis_result.problematic_snippets)
    return convert_tokens_to_output(token_list)


def update_snippet_positions(snippets: "List[ProblematicSnippet]", post_text: str) -> None:
    """Locate each snippet in ``post_text`` and store its character span.

    For each snippet whose ``original_text`` occurs verbatim in ``post_text``,
    ``start``/``end`` are set to the span of its first occurrence, so that
    ``post_text[start:end] == original_text``. Any offsets the model may have
    produced for these fields are overwritten.

    Some snippets are empty or do not occur verbatim (e.g. the model
    paraphrased or re-cased the quote). They cannot be mapped onto tokens, so
    they are removed from ``snippets`` in place instead of raising. The rest
    of the analysis still runs.

    Args:
        snippets: Snippets produced by the model. Mutated in place.
        post_text: The text the snippets were extracted from.
    """
    located = []
    for snippet in snippets:
        needle = snippet.original_text
        # Literal search. The snippet is arbitrary prose, so re.search would
        # treat '?', '(', '.', '*' ... as regex syntax and either mis-locate
        # the snippet or raise re.error.
        index = post_text.find(needle) if needle else -1
        if index == -1:
            continue
        snippet.start = index
        snippet.end = index + len(needle)
        located.append(snippet)
    # Slice assignment keeps the caller's list object (e.g. the model's
    # problematic_snippets) and updates it in place.
    snippets[:] = located


def tokenize_post_text(post_text: str) -> List[Token]:
    """Split ``post_text`` into model tokens with character offsets.

    Each token id is decoded on its own, and tokens that decode to an empty
    string are dropped. Offsets are running sums of the decoded lengths:
    ``start`` is the combined length of the previous kept tokens, and ``end``
    is ``start + len(text)``.

    NOTE(review): offsets are derived from decoded lengths, not from
    ``post_text``. They match positions in ``post_text`` only if
    concatenating the decoded pieces reproduces it exactly. Pieces that decode
    differently on their own (e.g. a multi-byte character split across
    tokens) would shift every later offset.
    """
    tokenizer = model.tokenizer
    # Indexed [0][0]: presumably encode() returns (input_ids, attention_mask)
    # with a leading batch dimension. TODO confirm for the outlines version in use.
    input_ids = tokenizer.encode(post_text)[0][0]

    tokens = []
    cursor = 0
    for token_id in input_ids:
        piece = tokenizer.decode([token_id])[0]
        if not piece:
            # Empty decodes (presumably special tokens such as BOS) take up no text.
            continue
        tokens.append(Token(text=piece, start=cursor, end=cursor + len(piece)))
        cursor += len(piece)
    return tokens


def assign_severity_to_tokens(token_list: "List[Token]", snippets: "List[ProblematicSnippet]") -> None:
    """Mark every token covered by a located snippet with that snippet's severity.

    A token counts as covered when it overlaps ``[snippet.start, snippet.end)``,
    ignoring the token's leading whitespace. Word tokens usually carry the
    preceding space (``" what"``), while a located snippet starts at the word
    itself. A strict containment test would therefore skip the first word of
    nearly every snippet.

    Snippets without offsets (``start`` or ``end`` is ``None``) are ignored.
    When several snippets cover a token, the most severe one wins
    (``significant`` over ``minimal``), regardless of snippet order. Tokens
    that no snippet covers keep their current ``severity``.

    Args:
        token_list: Tokens with ``text``/``start``/``end``. Mutated in place.
        snippets: Snippets with ``start``/``end``/``severity``.
    """
    # Severity is a str-valued Enum, so its members match these string keys.
    severity_rank = {"minimal": 1, "significant": 2}
    located = [s for s in snippets if s.start is not None and s.end is not None]

    for token in token_list:
        leading_ws = len(token.text) - len(token.text.lstrip())
        text_start = token.start + leading_ws
        for snippet in located:
            if not (text_start < snippet.end and snippet.start < token.end):
                continue
            current = token.severity
            if current is None or severity_rank.get(snippet.severity, 0) > severity_rank.get(current, 0):
                token.severity = snippet.severity


def convert_tokens_to_output(token_list: List[Token]) -> List[Tuple[str, float]]:
    """Turn annotated tokens into ``(text, score)`` pairs, in token order.

    Scores: ``significant`` -> 1.0, ``minimal`` -> 0.5. Unflagged tokens
    (``severity`` is ``None``) and any unmapped severity score 0.0.
    """
    scores_by_severity = {
        Severity.minimal: 0.5,
        Severity.significant: 1.0,
    }
    pairs = []
    for token in token_list:
        score = scores_by_severity.get(token.severity, 0.0)
        pairs.append((token.text, score))
    return pairs


[('students', 0.0),
 (' should', 0.0),
 (' wear', 0.0),
 (' what', 0.5),
 (' they', 0.5),
 (' like', 0.5),
 (' and', 0.0),
 (' feel', 0.0),
 (' free', 0.5),
 (' about', 0.5),
 (' their', 0.5),
 (' clothes', 0.5)]