In [1]:
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

In [3]:
tokenizer: AutoTokenizer = None
model: PeftModel = None
model_loaded: bool = False

In [4]:
def load_model():
    global tokenizer, model, model_loaded, load_error
    
    if model_loaded:
        return 
    
    try:
        model_path = "./joke_model_output/final_model"
        base_model_name = "meta-llama/Llama-3.2-3B-Instruct"

        if not os.path.exists(model_path):
            load_error = f"Nie znaleziono modelu w {model_path}"
            return load_error
        

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token


        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB
            print(f"Dostępna pamięć GPU: {gpu_memory:.1f} GB")
            
            if gpu_memory < 8:  
                device_map = "cpu"
                torch_dtype = torch.float32
                offload_folder = None
            elif gpu_memory < 16:  
                device_map = "auto"
                torch_dtype = torch.float16
                offload_folder = "./model_offload"
            else: 

                device_map = "auto"
                torch_dtype = torch.float16
                offload_folder = None
        else:
            device_map = "cpu"
            torch_dtype = torch.float32
            offload_folder = None
        
        
        load_kwargs = {
            "torch_dtype": torch_dtype,
            "device_map": device_map,
            "trust_remote_code": True,
            "low_cpu_mem_usage": True
        }
        
        if offload_folder:
            load_kwargs["offload_folder"] = offload_folder
            load_kwargs["offload_state_dict"] = True
        
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            **load_kwargs
        )

        model = PeftModel.from_pretrained(base_model, model_path)

        
    except Exception as e:
        load_error = f"Błąd ładowania modelu: {str(e)}"
        model = None
        tokenizer = None
        model_loaded = False
        print(load_error)

In [5]:
def generate_joke(prompt: str, temperature: float = 0.7, max_length: int = 150) -> str:
    global tokenizer, model, model_loaded
    if not model_loaded:
        print("Generating...")
        load_model()
    if not prompt.strip():
        return "Send prompt. Please."
    
    '''
    message = [
        {
            "role" : "system",
            "content" : "Jesteś pomocnym asystentem, który opowiada śmieszne polskie dowcipy. Odpowiadasz tylko żarty bez dodatkowych komentarzy. Nie uzywaj wulgsyzmow."
        },
        {
            "role" : "user",
            "content" : prompt
        }
    ]'''
    system_msg = "Jesteś pomocnym asystentem, który opowiada śmieszne polskie dowcipy. Odpowiadasz tylko dowcipem, bez dodatkowych komentarzy. Nie używaj wulgaryzmów"
        
    formatted_prompt = f"System: {system_msg}\nUser: {prompt}\nAssistant:"

    '''formatted_prompt = tokenizer.apply_chat_template(
        message,
        tokenize=False,
        add_generation_prompt = True 
    )'''
    inputs = tokenizer(
            formatted_prompt, 
            return_tensors="pt", 
            truncation=True, 
            max_length=200,  
            add_special_tokens=True
        )
    device = model.device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask'),
                max_new_tokens=min(max_length, 100),  
                temperature=max(0.3, min(temperature, 1.0)),  
                do_sample=True,
                top_p=0.8,  
                top_k=40,   
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,  
                num_return_sequences=1,
                early_stopping=True,  
                no_repeat_ngram_size=2  
            )

    input_length = inputs['input_ids'].shape[1]
    generated_tokens = outputs[0][input_length:]
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    result = generated_text.strip()
        
    prefixes_to_remove = ["Assistant:", "Asystent:", "System:", "User:", "Użytkownik:"]
    for prefix in prefixes_to_remove:
        if result.startswith(prefix):
            result = result[len(prefix):].strip()
        
    end_markers = ['\nUser:', '\nSystem:', '\nAssistant:', '\nUżytkownik:']
    for marker in end_markers:
        if marker in result:
            result = result.split(marker)[0].strip()
    if len(result.split('\n')) > 3: 
        lines = result.split('\n')[:2] 
        result = '\n'.join(lines)
        
    if '.' in result:
        sentences = result.split('.')
        if len(sentences) > 1:
            result = '. '.join(sentences[:2]) + '.'
        

    result = result.strip()
    return result
    

In [11]:
def generate_interface() :
    with gr.Blocks(title="Generator Dowcipów", theme=gr.themes.Soft()) as app:
        gr.Markdown("# Generator Polskich Dowcipów")
        gr.Markdown("*Wytrenowany model do generowania śmiesznych polskich dowcipów*")
        
        
        load_model_btn = gr.Button("Załaduj Model", variant="secondary")
        
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Opowiedz mi polski dowcip...",
                    value="Opowiedz mi polski dowcip",
                    lines=3
                )
                
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.1, 
                        maximum=2.0, 
                        value=0.7, 
                        step=0.1,
                        label="Temperature (kreatywność)"
                    )
                    max_length = gr.Slider(
                        minimum=50, 
                        maximum=300, 
                        value=150, 
                        step=10,
                        label="Max Length"
                    )
                
                with gr.Row():
                    generate_btn = gr.Button("Generuj Dowcip", variant="primary", size="lg")
                    clear_btn = gr.Button("Wyczyść", size="lg")
            
            with gr.Column(scale=1):
                output = gr.Textbox(
                    label="Dowcip",
                    lines=12,
                    interactive=False,
                    placeholder="Tutaj pojawi się dowcip..."
                )
        
        # Przykłady
        gr.Examples(
            examples=[
                ["Opowiedz mi polski dowcip", 0.7, 150],
                ["Znasz jakiś dobry żart?", 0.8, 180], 
                ["Powiedz dowcip o programistach", 0.6, 120],
                ["Opowiedz coś śmiesznego o kotach", 0.7, 200],
                ["Masz jakiś żart o pracy?", 0.7, 150]
            ],
            inputs=[prompt, temperature, max_length]
        )
        
        # Przypisz funkcje
        load_model_btn.click(
            fn=load_model,
            outputs=status_display
        )
        
        generate_btn.click(
            fn=generate_joke,
            inputs=[prompt, temperature, max_length],
            outputs=output
        )
        
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[prompt, output]
        )
    
    return app

In [7]:
from huggingface_hub import login
from dotenv import load_dotenv
load_dotenv()
login(token=os.getenv("HF_TOKEN"))

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [10]:
app = generate_interface()
app.launch()

NameError: name 'check_model_status' is not defined