In [1]:
# Install the required libraries
!pip install  bitsandbytes  accelerate gradio

Collecting bitsandbytes
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.32.1-py3-none-any.whl (314 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m314.1/314.1 kB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->bitsandbytes)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->bitsandbytes)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->bitsandbytes)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->bitsandbytes)
  Using cached nvid

# Set Up the Environment

In [1]:
import io
import math
import random
import tokenize
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import bitsandbytes as bnb
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline



# _fn = "final"
LENGTH = 512  # CodeBERT input length: every input is padded/truncated to this many tokens
BATCH_SIZE = 16  # NOTE(review): not used in the visible notebook (training leftover?)
STEP_SIZE = 1  # NOTE(review): not used in the visible notebook
NUM_LABELS = 14  # SAFE + 13 CWE classes (see index2target)

RANDOM_SEED = 42

EMBEDDING_SIZE = 768  # channels of the BERT hidden states fed into CNNEncoder
HIDDEN_SIZE = 512  # feature size seen by Middle / Decoder
CNN_FILTERS = 128  # output channels of each of CNNEncoder's four conv branches

# CNNEncoder concatenates four branches of CNN_FILTERS channels and Middle consumes them as HIDDEN_SIZE.
assert HIDDEN_SIZE == 4 * CNN_FILTERS, "HIDDEN_SIZE must equal 4 * CNN_FILTERS"

BERT_MODEL_ID = "microsoft/codebert-base"
# UAE_MODEL_ID = "WhereIsAI/UAE-Code-Large-V1"

# A CUDA GPU is expected: the 4-bit bitsandbytes LLM loaded later is meant to run on a GPU,
# so on a CPU-only runtime later cells are likely to fail.
USE_CUDA = torch.cuda.is_available()

if USE_CUDA:
    print("You are using cuda. Good!")
else:
    print('You are NOT using cuda! Some problems may occur.')

# Seed every RNG library the notebook imports so runs are reproducible.
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

You are using cuda. Good!


# Load Our Classifier (AEGIS)

In [2]:
# Define BERT layer
class BertLayer(nn.Module):
    """Wrapper around the pretrained encoder named by ``BERT_MODEL_ID``.

    ``forward`` returns the encoder's ``last_hidden_state`` together with a
    ``None`` placeholder as the second element of the tuple.
    """

    def __init__(self):
        super().__init__()
        # The attribute name is part of the checkpoint's state_dict keys; do not rename it.
        self.bert_model = AutoModel.from_pretrained(BERT_MODEL_ID)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, None]:
        outputs = self.bert_model(input_ids, attention_mask)
        return outputs['last_hidden_state'], None

# Define CNN Encoder
class CNNEncoder(nn.Module):
    """Multi-width 1-D convolution over token features.

    Four parallel ``Conv1d`` branches (kernel sizes 2, 3, 5 and 1, all with
    ``padding="same"``) map ``EMBEDDING_SIZE`` input channels to ``CNN_FILTERS``
    channels each, followed by ReLU. The branch outputs are concatenated on the
    feature axis in the order (kernel 1, kernel 2, kernel 3, kernel 5), giving
    ``4 * CNN_FILTERS`` features per position.
    """

    def __init__(self):
        super().__init__()
        self.embedding_size = EMBEDDING_SIZE
        self.filter_number = CNN_FILTERS
        self.activation = nn.ReLU()

        def make_conv(width: int) -> nn.Conv1d:
            return nn.Conv1d(
                in_channels=self.embedding_size,
                out_channels=self.filter_number,
                kernel_size=width,
                padding="same",
            )

        # Attribute names (and creation order) match the original so checkpoint keys line up.
        self.conv1 = make_conv(2)
        self.conv2 = make_conv(3)
        self.conv3 = make_conv(5)
        self.conv4 = make_conv(1)

    def forward(self, bert_last_hidden: torch.Tensor) -> torch.Tensor:
        # Conv1d expects channels on dim 1, so swap the last two axes going in and coming out.
        channels_first = bert_last_hidden.transpose(1, 2)
        branches = (self.conv4, self.conv1, self.conv2, self.conv3)
        stacked = torch.cat([self.activation(conv(channels_first)) for conv in branches], dim=1)
        # .contiguous() reproduces the memory layout of concatenating already-transposed outputs.
        return stacked.transpose(1, 2).contiguous()

# Define Positional Encoding
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding followed by dropout.

    The table is a ``(max_len, 1, d_model)`` buffer named ``pe``. It is saved in
    the module's state_dict, so its name and shape must stay unchanged for
    existing checkpoints to load.

    Args:
        d_model: feature size of the inputs (assumed even, since sin/cos fill alternate slots).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of rows in the table.
        batch_first: if False (default, original behaviour), row ``i`` of the table is
            added along dim 0 of ``x``, i.e. ``x`` is treated as ``(seq, batch, d_model)``.
            If True, ``x`` is treated as ``(batch, seq, d_model)`` and row ``t`` is added
            to sequence position ``t`` of every batch item.

    NOTE(review): ``Middle`` passes batch-first tensors but keeps the default
    ``batch_first=False``, so there the table is indexed by the batch dimension.
    It is left that way because the loaded checkpoint was presumably trained with it.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.batch_first:
            # (seq, 1, d) -> (1, seq, d) so it broadcasts across the batch dimension.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Define Transformer-based Middle layer
class Middle(nn.Module):
    """Transformer encoder stack applied to the CNN features.

    Scales the inputs by ``sqrt(HIDDEN_SIZE)``, adds positional encoding, then
    runs a 2-layer, 2-head ``nn.TransformerEncoder`` (``batch_first=True``) with
    the given key-padding mask.
    """

    def __init__(self):
        super().__init__()
        self.activation = nn.ReLU()  # NOTE(review): not used in forward
        self.pos_encoder = PositionalEncoding(HIDDEN_SIZE, dropout=0.1)
        layer = nn.TransformerEncoderLayer(
            HIDDEN_SIZE,
            nhead=2,
            batch_first=True,
            dim_feedforward=2048,
            activation="relu",
            dropout=0.1,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)

    def forward(self, encodings: torch.Tensor, input_mask: torch.Tensor) -> torch.Tensor:
        # input_mask is True at padding positions (AEGIS passes attention_mask == 0).
        scaled = encodings * math.sqrt(HIDDEN_SIZE)
        # NOTE(review): PositionalEncoding indexes its table with x.size(0), which is the batch
        # dimension for these batch-first tensors; presumably the checkpoint was trained this way.
        with_positions = self.pos_encoder(scaled)
        return self.transformer_encoder(with_positions, src_key_padding_mask=input_mask)

# Define Decoder
class Decoder(nn.Module):
    """Self-attention + masked mean pooling + linear classifier head.

    ``forward`` returns the per-class logits and the attention weights of the
    self-attention layer (averaged over heads, as ``nn.MultiheadAttention`` does by default).
    """

    def __init__(self, num_labels: int = NUM_LABELS, dropout_p: float = 0.5):
        super().__init__()
        self.num_labels = num_labels
        self.self_attention = nn.MultiheadAttention(embed_dim=HIDDEN_SIZE, num_heads=8, dropout=0.1, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)
        self.layer_norm = nn.LayerNorm(HIDDEN_SIZE)
        self.score = nn.Linear(HIDDEN_SIZE, self.num_labels)

    def forward(self, encodings: torch.Tensor, input_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_mask is True at padding positions; `valid` marks the real tokens.
        valid = ~input_mask
        queries = encodings.clone()
        context, attention_weights = self.self_attention(
            queries, queries, queries, key_padding_mask=input_mask, need_weights=True
        )
        residual = self.layer_norm(self.dropout(context)) + queries

        # Mean over non-padding positions only.
        token_counts = valid.sum(1).unsqueeze(1)
        pooled = (residual * valid.unsqueeze(2)).sum(1) / token_counts
        return self.score(self.dropout(pooled)), attention_weights

# Define the AEGIS model
class AEGIS(nn.Module):
    """Vulnerability classifier: BERT encoder -> CNN -> Transformer -> attention decoder.

    ``forward`` returns ``(logits, attention_weights)`` from the decoder.
    """

    def __init__(self, num_labels: int = NUM_LABELS, dropout_p: float = 0.5):
        super().__init__()
        self.bert_layer = BertLayer()
        self.cnn_encoder = CNNEncoder()
        self.middle = Middle()
        self.decoder = Decoder(num_labels=num_labels, dropout_p=dropout_p)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # True where the token is padding (attention_mask == 0); shared by both masked stages.
        padding_mask = attention_mask == 0
        hidden, _ = self.bert_layer(input_ids, attention_mask)
        features = self.cnn_encoder(hidden)
        contextual = self.middle(features, padding_mask)
        return self.decoder(contextual, padding_mask)

In [3]:
from google.colab import drive

# Define our mapping from/to CWEs IDs (index 0 is the "no vulnerability" class)
index2target = {0: 'SAFE',
  1: 'CWE-78',
  2: 'CWE-605',
  3: 'CWE-502',
  4: 'CWE-377',
  5: 'CWE-20',
  6: 'CWE-259',
  7: 'CWE-330',
  8: 'CWE-703',
  9: 'CWE-319',
  10: 'CWE-400',
  11: 'CWE-89',
  12: 'CWE-327',
  13: 'CWE-22'}
target2index = {v:k for k,v in index2target.items()}


# Load our pretrained weights
drive.mount('/content/drive')

aegis = AEGIS(num_labels=len(target2index))
# SECURITY: the checkpoint is a pickled nn.Module (hence `.state_dict()` below). Unpickling can
# execute arbitrary code, so only load files you trust. weights_only=False is explicit because
# newer PyTorch releases default to weights_only=True, which refuses pickled modules.
# map_location="cpu" lets the file load on CPU-only runtimes; the model is moved to the GPU below.
checkpoint = torch.load("/content/drive/MyDrive/AEGIS/aegis-scl_v4.pkl", map_location="cpu", weights_only=False)
aegis.load_state_dict(checkpoint.state_dict())
del checkpoint  # only the copied weights are needed
if USE_CUDA:
    aegis.cuda()
# Inference only: switch off dropout (0.5 in the decoder by default) so predictions are deterministic.
aegis.eval()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


# Load Our Inspector LLM (acts as a cybersecurity expert)

In [4]:
# Define the model and tokenizer
# The chat model used as the "inspector" that explains or fixes flagged code (called via fix_text).
model_name = "Qwen/CodeQwen1.5-7B-Chat"

# Load the tokenizer
llm_tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights in 4-bit NF4 with nested ("double") quantization to reduce GPU memory;
# computation inside the quantized layers runs in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# device_map="auto" lets accelerate decide where the quantized weights are placed.
# NOTE(review): trust_remote_code=True runs code from the model repository; confirm it is still required.
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    trust_remote_code=True
)

# Text-generation pipeline over the already-loaded quantized model; your_query() calls it.
# NOTE(review): the model is already dispatched by from_pretrained above, so device_map here is
# probably redundant; confirm before relying on it.
text_generation_pipeline = pipeline(
    "text-generation",
    model=model_4bit,
    tokenizer=llm_tokenizer,
    use_cache=True,
    device_map="auto",
)

def your_query(query: str, max_new_tokens: int = 512) -> str:
    """
    Generate a response for the given query using the text generation pipeline.

    Args:
        query (str): The full prompt to send to the model.
        max_new_tokens (int): Upper bound on the number of *generated* tokens.
            Unlike ``max_length``, this does not count the prompt, so long code
            snippets in the prompt do not eat into (or exhaust) the reply budget.

    Returns:
        str: The pipeline's ``generated_text``. The pipeline's default
        ``return_full_text=True`` is used, so this is the prompt followed by the
        completion (fix_text strips the prompt prefix).
    """
    sequences = text_generation_pipeline(
        query,
        max_new_tokens=max_new_tokens,
        eos_token_id=llm_tokenizer.eos_token_id,
    )
    return sequences[0]['generated_text']


def format_messages(messages: List[Dict[str, str]]) -> str:
    """
    Format a chat transcript as a ChatML prompt, the format used by Qwen /
    CodeQwen chat models (the inspector LLM loaded in this notebook).

    Each message becomes ``<|im_start|>{role}\\n{content}<|im_end|>\\n`` (content is
    stripped of surrounding whitespace), and a final ``<|im_start|>assistant\\n``
    opens the assistant turn so the model answers as the assistant.

    The previous implementation mixed Llama-2 markers (``[INST]``, ``<s>``) with
    ChatML tokens, never opened an assistant turn, and raised IndexError for a
    transcript containing only a system message.

    Args:
        messages (List[Dict[str, str]]): Messages with 'role' (e.g. "system",
            "user", "assistant") and 'content'.

    Returns:
        str: The formatted prompt.
    """
    turns = [
        f"<|im_start|>{message['role']}\n{message['content'].strip()}<|im_end|>\n"
        for message in messages
    ]
    # Leave the assistant turn open so generation continues as the assistant's reply.
    turns.append("<|im_start|>assistant\n")
    return "".join(turns)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

# Utils

In [7]:
# Initialize tokenizer for the AEGIS classifier. It is loaded from BERT_MODEL_ID, the same
# checkpoint BertLayer loads, so the token ids match the encoder's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_ID)

def remove_comments_and_docstrings(source: str) -> str:
    """
    Remove comments and docstrings from the provided Python source code.

    Comments are dropped. A string token is treated as a docstring and dropped
    when it directly follows an INDENT or NEWLINE token or starts in column 0
    (this heuristic also drops column-0 strings on continuation lines).
    Blank lines are removed from the result.

    If ``source`` cannot be tokenized (e.g. an unclosed bracket or inconsistent
    indentation in a pasted snippet), the source is returned with only blank
    lines removed instead of an empty string, so the classifier still sees the
    user's code.

    Args:
        source (str): The source code to process.

    Returns:
        str: The source code with comments and docstrings removed.
    """
    pieces: List[str] = []
    prev_toktype = tokenize.INDENT
    last_lineno = -1
    last_col = 0
    try:
        for tok in tokenize.generate_tokens(io.StringIO(source).readline):
            start_line, start_col = tok.start
            end_line, end_col = tok.end
            if start_line > last_lineno:
                last_col = 0
            if start_col > last_col:
                # Re-create the horizontal whitespace the tokenizer skipped over.
                pieces.append(" " * (start_col - last_col))
            if tok.type == tokenize.COMMENT:
                pass
            elif tok.type == tokenize.STRING:
                # Keep the string unless it looks like a docstring.
                if prev_toktype not in (tokenize.INDENT, tokenize.NEWLINE) and start_col > 0:
                    pieces.append(tok.string)
            else:
                pieces.append(tok.string)
            prev_toktype = tok.type
            last_col = end_col
            last_lineno = end_line
        text = "".join(pieces)
    except (tokenize.TokenError, SyntaxError) as e:
        print(f"Could not tokenize input ({e}); using it without stripping comments.")
        text = source
    return '\n'.join(line for line in text.splitlines() if line.strip())

def process_input(code: str, tokenizer, device: Optional[str] = None) -> Tuple[str, torch.Tensor, torch.Tensor]:
    """
    Process the input code by removing comments and docstrings, and tokenize it for the model.

    The cleaned code is padded/truncated to ``LENGTH`` tokens with special tokens added.

    Args:
        code (str): The source code to process.
        tokenizer: The tokenizer to use for processing the code.
        device: Where to place the tensors (a device string or ``torch.device``).
            Defaults to "cuda" when ``USE_CUDA`` is set and "cpu" otherwise, matching
            where the AEGIS model is placed. Previously this always called ``.cuda()``,
            which fails on CPU-only runtimes.

    Returns:
        Tuple[str, torch.Tensor, torch.Tensor]: The processed code, input IDs, and attention mask.
    """
    code = remove_comments_and_docstrings(code)

    encoded = tokenizer(code, add_special_tokens=True, max_length=LENGTH,
                        return_attention_mask=True, padding='max_length',
                        truncation=True, return_tensors='pt')

    if device is None:
        device = "cuda" if USE_CUDA else "cpu"
    return code, encoded.input_ids.to(device), encoded.attention_mask.to(device)



# Define the function to get line scores based on attention weights
def get_line_scores(code: str, attention_mask: List[int], attention_weights: torch.Tensor) -> List[float]:
    """
    Calculate the attention score for each line of code.

    Each non-padding token's score is the total attention it receives from all
    non-padding query positions. A line's score is the sum over the tokens whose
    first character falls on that line.

    Tokens are mapped to lines with the character offsets of a single
    whole-string tokenization (same settings as process_input, minus padding).
    The previous version re-tokenized each line on its own. Byte-level BPE can
    merge a newline with the following indentation, so per-line token counts did
    not match the model's tokens and scores drifted onto the wrong lines.

    Args:
    - code (str): The (comment-stripped) code that was fed to the model.
    - attention_mask: 1/0 mask over token positions (list or numpy array).
    - attention_weights (torch.Tensor): Attention matrix for one example,
      assumed (seq_len, seq_len) as query x key; TODO confirm for other callers.

    Returns:
    - List[float]: One score per element of ``code.split("\\n")``. Lines cut off
      by truncation get 0.0.
    """
    import bisect

    keep = torch.as_tensor(np.asarray(attention_mask).astype(bool)).to(attention_weights.device)
    # Drop padded keys (columns), then padded queries (rows).
    masked_attention_weights = attention_weights[:, keep][keep, :]
    # Total attention received by each key token.
    token_attention = masked_attention_weights.sum(dim=0)

    lines = code.split("\n")
    # Character index at which each line starts (+1 for the "\n" separator).
    line_starts = [0]
    for line in lines[:-1]:
        line_starts.append(line_starts[-1] + len(line) + 1)

    encoding = tokenizer(code, add_special_tokens=True, max_length=LENGTH, truncation=True,
                         return_offsets_mapping=True, return_special_tokens_mask=True)

    line_scores = [0.0] * len(lines)
    for token_idx, ((char_start, _), is_special) in enumerate(
            zip(encoding["offset_mapping"], encoding["special_tokens_mask"])):
        if is_special or token_idx >= token_attention.shape[0]:
            continue
        line_idx = bisect.bisect_right(line_starts, char_start) - 1
        line_scores[line_idx] += float(token_attention[token_idx])
    return line_scores

# Define the function to normalize scores
def normalize_scores(danger_scores: List[float]) -> List[float]:
    """
    Min-max normalize a list of danger scores to the range [0, 1].

    Args:
        danger_scores (List[float]): List of danger scores to normalize.

    Returns:
        List[float]: Normalized danger scores. An empty input gives an empty list;
        if all scores are equal, every entry is 0.0.
    """
    if not danger_scores:
        return []
    min_score = min(danger_scores)
    max_score = max(danger_scores)
    if min_score == max_score:
        return [0.0 for _ in danger_scores]
    span = max_score - min_score
    return [(score - min_score) / span for score in danger_scores]

# Define the function to color text based on danger scores
def color_text_by_danger(text_lines: List[str], danger_scores: List[float]) -> str:
    """
    Colorize text lines based on normalized danger scores.

    Scores are min-max normalized first. Each line becomes a ``<pre>`` element
    coloured on a linear ramp from green (score 0) to red (score 1). Line text is
    HTML-escaped, because it is user-supplied code: characters such as ``<``,
    ``>`` and ``&`` would otherwise be parsed as markup, which breaks the display
    and allows HTML injection.

    Args:
        text_lines (List[str]): List of text lines to color.
        danger_scores (List[float]): Raw danger scores, one per text line.

    Returns:
        str: HTML content with colored text lines.
    """
    from html import escape

    normalized_scores = normalize_scores(danger_scores)

    def get_color(score: float) -> str:
        """
        Get an RGB color string based on the given score.

        Args:
            score (float): Normalized danger score.

        Returns:
            str: RGB color string in the format 'rgb(red,green,blue)'.
        """
        red = int(255 * score)
        green = int(255 * (1 - score))
        return f'rgb({red},{green},0)'

    return "".join(
        f'<pre style="color:{get_color(score)}">{escape(line)}</pre>'
        for line, score in zip(text_lines, normalized_scores)
    )

# Define the text classification function
def classify_text(code: str) -> Tuple[Dict[str, float], str]:
    """
    Perform text classification on the given code snippet.

    Args:
        code (str): Input code snippet to classify.

    Returns:
        Tuple[Dict[str, float], str]: Per-label softmax probabilities (keyed by
        the names in ``index2target``), and the comment-stripped code rendered as
        HTML coloured line by line by attention score. When the top label is
        "SAFE", every line gets the same score.
    """
    code, input_ids, attention_mask = process_input(code, tokenizer)

    # Pure inference: disable dropout (the model would otherwise be in training mode)
    # and skip building an autograd graph.
    aegis.eval()
    with torch.no_grad():
        score, attention_weights = aegis(input_ids, attention_mask)

    probs = F.softmax(score, dim=1).squeeze().tolist()
    class_proba_dict = {index2target[i]: prob for i, prob in enumerate(probs)}
    most_probable_class = max(class_proba_dict, key=class_proba_dict.get)

    line_scores = get_line_scores(code, attention_mask.cpu().squeeze().numpy(), attention_weights.squeeze().cpu())
    if most_probable_class == "SAFE":
        # Uniform scores normalize to 0, so every line gets the same (lowest-danger) colour.
        line_scores = [1 for _ in line_scores]

    colored_code = color_text_by_danger(code.split('\n'), line_scores)
    return class_proba_dict, colored_code

# Define the function to fix text based on probabilities
def fix_text(code: str, normalized_probabilities: Dict[str, float]) -> str:
    """
    Ask the inspector LLM to explain (and, if vulnerable, fix) a code snippet.

    The prompt depends on the most probable label. "SAFE" asks why the code is
    safe; any other label asks whether the code is vulnerable to that CWE and
    for a fix.

    Args:
        code (str): Input code snippet to fix.
        normalized_probabilities (Dict[str, float]): label -> probability, as
            produced by classify_text. May be empty or None, e.g. when Insights
            is pressed before Analyze has produced a classification.

    Returns:
        str: The model's reply with the prompt prefix removed, or a short notice
        when there is no classification to explain yet.
    """
    if not normalized_probabilities:
        # max() on an empty/None mapping would raise; tell the user what to do instead.
        return "No classification yet: run Analyze first, then ask for Insights."

    most_probable_class = max(normalized_probabilities, key=normalized_probabilities.get)
    if most_probable_class == "SAFE":
        prompt = f"Why is this code safe from any vulnerability?\nCode:{code}"
    else:
        prompt = f"If the code below is vulnerable to {most_probable_class}, explain why and suggest a fix.\nCode:{code}"

    dialog = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    formatted_prompt = format_messages(dialog)
    generated_text = your_query(formatted_prompt)
    # your_query returns prompt + completion; keep only the completion.
    return generated_text[len(formatted_prompt):].strip()

# Define the function to update chat history
def update_chatbox(fixed_text: str, chat_history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """
    Append a timestamped entry to the chat history.

    The list passed in is modified in place and also returned.

    Args:
        fixed_text (str): Text (the LLM's answer) to record.
        chat_history (List[Tuple[str, str]]): Existing (timestamp, text) entries.

    Returns:
        List[Tuple[str, str]]: The same list, now ending with
        ``(current local time as "Month DD, YYYY HH:MM:SS", fixed_text)``.
    """
    timestamp = datetime.now().strftime("%B %d, %Y %H:%M:%S")
    chat_history.append((timestamp, fixed_text))
    return chat_history

# Define the function to regenerate fixed text and update chat history
def regenerate_fixed_text(code: str, normalized_probabilities: Dict[str, float], chat_history: Optional[List[Tuple[str, str]]] = None) -> Tuple[str, List[Tuple[str, str]]]:
    """
    Regenerate the fixed text based on vulnerability probabilities and update chat history.

    Args:
        code (str): Input code snippet to fix.
        normalized_probabilities (Dict[str, float]): Dictionary of normalized vulnerability probabilities.
        chat_history (List[Tuple[str, str]], optional): Existing chat history entries.
            When omitted, a new empty history is started for this call. The former
            ``= []`` default was a single list shared by every call, so entries from
            all users of the app accumulated in one transcript.

    Returns:
        Tuple[str, List[Tuple[str, str]]]: Tuple containing the fixed code snippet or explanation and updated chat history.
    """
    if chat_history is None:
        chat_history = []
    fixed_text = fix_text(code, normalized_probabilities)
    chat_history = update_chatbox(fixed_text, chat_history)
    return fixed_text, chat_history

# Define the function to get HTML content
def get_html(html: str) -> str:
    """Return ``html`` unchanged.

    Used as a click handler that copies the hidden HTML produced by the
    classifier into the visible line-level output panel.
    """
    rendered = html
    return rendered

# Set Up Our Framework Interface (with the help of Gradio)

In [8]:
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        code_input = gr.TextArea(placeholder="Your Python Code?", label="Input")
        class_output = gr.Label()
        # Holds the line-coloured HTML from classify_text until "Line-level Detection" is clicked.
        hidden_output = gr.Text(visible=False)
    analyze_button = gr.Button("Analyze")
    analyze_button.click(fn=classify_text, inputs=code_input, outputs=[class_output, hidden_output])

    gr.Examples(
        examples=[["print('Hello, World!')"]],
        inputs=code_input
    )

    with gr.Row():
        line_level_output = gr.Markdown(label="Line-level Analysis")

    line_level_button = gr.Button("Line-level Detection")
    line_level_button.click(fn=get_html, inputs=hidden_output, outputs=line_level_output)

    with gr.Row():
        insight_output = gr.Markdown(label="Suggestions")

    insight_button = gr.Button("Insights")
    # Start from an empty list so the handler always receives a list as history.
    chat_history_box = gr.Chatbot(value=[], label="Chat History")
    # Feed the Chatbot's current value back in as the history, so each browser session keeps its
    # own transcript instead of relying on regenerate_fixed_text's default argument. This matters
    # because launch() serves the app on a public share link.
    insight_button.click(
        fn=regenerate_fixed_text,
        inputs=[code_input, class_output, chat_history_box],
        outputs=[insight_output, chat_history_box],
    )

demo.launch(debug=True)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://019081c82251b8e66e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://019081c82251b8e66e.gradio.live




In [None]:
!pip install gradio

Collecting gradio
  Downloading gradio-4.36.1-py3-none-any.whl (12.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Collecting fastapi (from gradio)
  Downloading fastapi-0.111.0-py3-none-any.whl (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.0/92.0 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpy (from gradio)
  Downloading ffmpy-0.3.2.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gradio-client==1.0.1 (from gradio)
  Downloading gradio_client-1.0.1-py3-none-any.whl (318 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.1/318.1 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting httpx>=0.24.1 (from gradio)
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━

In [None]:
# yes
import gradio as gr
import random

# Define the text classification model
def classify_text(text):
    """Mock classifier: random probabilities over three sentiment labels, normalized to sum to 1.

    NOTE(review): this prototype cell redefines ``classify_text`` from the Utils
    section; running it replaces the real AEGIS classifier in the kernel namespace.
    """
    labels = ("Positive", "Negative", "Neutral")
    raw = {label: random.random() for label in labels}
    total = sum(raw.values())
    return {label: value / total for label, value in raw.items()}

# Define the text fixing model
def fix_text(text, most_probable_class):
    """Mock fixer: echo ``text`` in a fenced Python code block, then a placeholder answer.

    The fence needs a newline after the ``python`` info string and a closing fence
    on its own line. Previously the first line of ``text`` was glued to the info
    string, so Markdown did not render it as code.

    NOTE(review): this prototype cell redefines ``fix_text`` from the Utils section;
    running it replaces the real LLM-backed version in the kernel namespace.
    """
    return f"before: \n```python\n{text}\n```\nafter: This is the fixed text for {most_probable_class} ..."

# Define a function to handle the complete workflow
def classify_and_fix(text):
    """Run the mock classifier, pick its top label, and produce the mock fix.

    Returns:
        (probabilities dict, most probable label, fixed text)
    """
    probs = classify_text(text)
    top_class = max(probs, key=probs.get)
    return probs, top_class, fix_text(text, top_class)



# Create the Gradio interface (prototype with mock classifier/fixer)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        input_text = gr.Textbox(label="Input Text")
    with gr.Row():
        classify_button = gr.Button("Classify Text")
    with gr.Row():
        ll = gr.Label(label="Classification Probabilities")
    with gr.Row():
        fixed_text_output = gr.Markdown(label="Fixed Text")
    with gr.Row():
        regenerate_button = gr.Button("Regenerate Text")

    # Start empty so the handlers always receive a list as history.
    chatbox = gr.Chatbot(value=[], label="Chat History")

    # These helpers are module-level names (a `with` block is not a scope). The `_mock_` prefix
    # keeps this cell from overwriting update_chatbox / regenerate_fixed_text from the Utils section.
    def _mock_update_chatbox(fixed_text, chat_history):
        """Return a copy of the history with ("Fixed Text", fixed_text) appended."""
        updated = list(chat_history or [])
        updated.append(("Fixed Text", fixed_text))
        return updated

    # Initial classification and text fixing
    def _mock_handle_classification(text, chat_history=None):
        classification_probs, most_probable_class, fixed_text = classify_and_fix(text)
        chat_history = _mock_update_chatbox(fixed_text, chat_history)
        return classification_probs, most_probable_class, fixed_text, chat_history

    # Regenerate the fixed text based on the most probable class
    def _mock_regenerate_fixed_text(text, most_probable_class, chat_history=None):
        fixed_text = fix_text(text, most_probable_class)
        chat_history = _mock_update_chatbox(fixed_text, chat_history)
        return fixed_text, chat_history

    # Store the most probable class in a state variable to be reused by the regenerate button
    most_probable_class_state = gr.State()

    # The Chatbot is passed in as history so each session keeps its own transcript
    # (no shared mutable default list).
    classify_button.click(
        fn=_mock_handle_classification,
        inputs=[input_text, chatbox],
        outputs=[ll, most_probable_class_state, fixed_text_output, chatbox],
    )
    regenerate_button.click(
        fn=_mock_regenerate_fixed_text,
        inputs=[input_text, most_probable_class_state, chatbox],
        outputs=[fixed_text_output, chatbox],
    )

# Launch the Gradio app
demo.launch()


Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://7396c9ec6ae3994f6c.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


