In [1]:
import os
import time

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
def reset_state(state=None) -> None:
    """Remove every key from the Streamlit session state.

    Args:
        state: Mutable mapping to clear. Defaults to ``st.session_state``; a plain
            dict can be passed instead (e.g. for testing).
    """
    if state is None:
        state = st.session_state
    # Snapshot the keys before deleting: mutating a mapping while iterating over it
    # raises RuntimeError for a dict, and relying on the session-state proxy to hand
    # out a copy is an implementation detail.
    for key in list(state.keys()):
        del state[key]

In [None]:
# Runtime configuration. Both values can be overridden through environment variables
# so the notebook is not tied to one machine's filesystem layout; the defaults are the
# original hardcoded values. LLM_MODEL_NAME may also be a Hugging Face Hub id such as
# "mistralai/Mistral-7B-Instruct-v0.2".
DEVICE: str = os.environ.get("LLM_DEVICE", "cuda")
MODEL_NAME: str = os.environ.get("LLM_MODEL_NAME", "/disk2/elvys/Mistral-7B-Instruct-v0.2")

In [None]:
class LLM:
    """Wrapper around a Hugging Face causal language model used as a chat model.

    Loads the weights in float16 on a single device and answers chat-formatted
    messages with greedy (``do_sample=False``) generation.
    """

    def __init__(self, model_name: "str | None" = None, device: "str | None" = None) -> None:
        """Load the model and tokenizer.

        Args:
            model_name: Local path or Hub id passed to ``from_pretrained``.
                Defaults to the module-level ``MODEL_NAME``.
            device: Torch device the model and its inputs are placed on.
                Defaults to the module-level ``DEVICE``.
        """
        model_name = MODEL_NAME if model_name is None else model_name
        self.device = DEVICE if device is None else device
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(self.device)
        # Inference only: disable dropout and other training-mode behaviour.
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def __call__(self, messages: list, max_new_tokens: int = 1100) -> str:
        """Generate a reply for a chat conversation.

        Args:
            messages: Chat turns as ``{"role": ..., "content": ...}`` dicts; they are
                formatted with the tokenizer's chat template.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The first sequence from ``batch_decode`` (special tokens are not skipped).
            For a decoder-only model this presumably still contains the prompt, which is
            why callers cut the text at the ``[/INST]`` marker.
        """
        model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        # Restrict scaled-dot-product attention to the flash kernel during generation.
        # NOTE(review): ``torch.backends.cuda.sdp_kernel`` is deprecated in recent torch
        # releases in favour of ``torch.nn.attention.sdpa_kernel``.
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            generated_ids = self.model.generate(model_inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Decoding is pure tokenizer work and does not need the attention-kernel context.
        decoded = self.tokenizer.batch_decode(generated_ids)
        return decoded[0]

In [None]:
# Let the app span the full browser width instead of Streamlit's centered column.
st.set_page_config(layout="wide")

In [None]:
@st.cache_resource
def load_llm_model() -> LLM:
    """Construct the chat model, cached by ``st.cache_resource``.

    The decorator caches the returned object, so script reruns reuse the already
    loaded weights instead of reloading them.
    """
    model = LLM()
    return model


llm_model = load_llm_model()

In [None]:
# Initialise each session-state key on its own, so a state in which only one of them
# was cleared still ends up with both keys present.
if "disabled" not in st.session_state:
    st.session_state.disabled = False
if "messages" not in st.session_state:
    st.session_state.messages = []

# All sidebar widgets share the same disabled flag.
widgets_disabled = st.session_state.disabled

with st.sidebar:
    st.title('Rene: Investment assistant')
    assistant_type = st.selectbox('Select assistant type:', ["AI Technical analysis", "Chatbot"], index=0, disabled=widgets_disabled)
    analysis_type = st.selectbox('AI Analysis Style:', ["Analytical", "Advisory"], index=0, disabled=widgets_disabled)
    experience_user = st.selectbox('Knowledge Level:', ["Novice", "Specialist"], index=0, disabled=widgets_disabled)

In [None]:
# NOTE(review): `selected_asset`, `utils`, `prompts`, `prices`, `smas`,
# `support_resistances`, `oscillators_values`, `active_patterns` and
# `anticipated_patterns` are not defined in any visible cell; on a fresh kernel this
# cell raises NameError until they are provided.
if assistant_type == "AI Technical analysis":
    with st.sidebar:
        button = st.button('Generate AI Technical analysis', disabled=st.session_state.disabled)
    if button:
        # NOTE(review): the widgets above were already rendered in this run, so this
        # flag only affects widgets drawn later in the same run.
        st.session_state.disabled = True
        try:
            if selected_asset == "General":
                st.markdown(utils.convert_str_to_markdown("Select an asset to get the AI Technical analysis."))
            else:
                start = time.perf_counter()
                prompt = prompts.get_prompt(
                    selected_asset,
                    experience_user,
                    analysis_type,
                    f"Provide the financial analysis of {selected_asset}",
                    prices[selected_asset],
                    smas[selected_asset],
                    support_resistances[selected_asset],
                    oscillators_values[selected_asset],
                    active_patterns[selected_asset],
                    anticipated_patterns[selected_asset],
                )
                print("PROMPT:", prompt)
                chat_model_response = llm_model([{"role": "user", "content": prompt}])
                # Keep only the text after the last "[/INST]" marker and before the first "</s>".
                chat_model_response = chat_model_response.split("[/INST]")[-1].split("</s>")[0]
                st.markdown(utils.convert_str_to_markdown(chat_model_response))
                print("Inference duration:", time.perf_counter() - start)
        finally:
            # Always re-enable the widgets, even if prompt building or generation
            # raised; otherwise the session would stay disabled on every later rerun.
            st.session_state.disabled = False


In [None]:
# `streamlit run` only accepts .py files, so export this notebook to a script first and
# launch that. The guard skips the launch when the exported script itself is executed
# by Streamlit, where `get_ipython` does not exist, so it never tries to relaunch itself.
try:
    _ipython = get_ipython()
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.system("jupyter nbconvert --to script streamlit101.ipynb && streamlit run streamlit101.py")
