# Text generation model comparisons

There are 3 candidate text generation models to utilise:
- GPT-2 (Downloaded for local inference)
- Gemma-2b-it (Downloaded for local inference)
- Mixtral-8x7B-Instruct-v0.1 (API call)

In [1]:
prompt = "You are an analyst from GXS Bank. Help me describe what you see in terms of trend with this JSON format: {\"Positive Insights\": <Paragraph>, \"Negative Insights\": <Paragraph>, \"Topic Insights\": <Paragraph>}. Do not put extra words like 'Based on...'. Output STRICTLY in JSON ONLY. Do not talk about null data.The following is the overall data acquired from our banking application: {}'Jan 2024': 4,'Feb 2024': 3.6,'Mar 2024': 3.2,'Apr 2024': 2,'May 2024': 3.8,'Jun 2024': 4.5,'July 2024': 4.5,'Aug 2024': 4.2, 'Sep 2024': 3,'Oct 2024': 4,'Nov 2024': 3.7,'Dec 2024': 4.1}"

## GPT-2

In [11]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")

def generate_gpt2(prompt):
    input_ids = gpt2_tokenizer.encode(prompt, return_tensors="pt", truncation=True)
    output = gpt2_model.generate(input_ids, max_length=1024, num_return_sequences=1)

    generated_text = gpt2_tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text[len(prompt):].strip()

In [12]:
gpt2_output = generate_gpt2(prompt)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


## Gemma-2b-it

In [4]:
from huggingface_hub import login, logout
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
import os

In [5]:
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [13]:
gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
gemma_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")

def generate_gemma(prompt):
    input_ids = gemma_tokenizer(prompt, return_tensors="pt")
    output = gemma_model.generate(**input_ids, max_new_tokens=256)

    generated_text = gemma_tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text[len(prompt):].strip()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
gemma_output = generate_gemma(prompt)

## Mixtral

In [17]:
from h2ogpte import H2OGPTE

In [22]:
key = "REDACTED"
client = H2OGPTE(address='https://h2ogpte.genai.h2o.ai', api_key=key)
session_id = client.create_chat_session()

def generate_mixtral(prompt):
    with client.connect(session_id) as session:
        response = session.query(prompt, timeout=70, rag_config={"rag_type": "llm_only", "llm": "mistral-large-latest", "max_new_tokens": 1024})
        return response.content

In [23]:
mixtral_output = generate_mixtral(prompt)

## Evaluation

### RUNTIME

GPT-2: 30s
Gemma-2b-it: 16mins 32s
Mixtral-8x7B-Instruct-v0.1 (API Call): 14s

Comparing runtime, using API for cloud inferencing is clearly the fastest.

### Quality of output

In [25]:
gpt2_output

"The following is the overall data acquired from our banking application: {}'Jan 2024': 4,'Feb 2024': 3.6,'Mar 2024': 3.2,'Apr 2024': 2,'May 2024': 3.8,'Jun 2024': 4.5,'July 2024': 4.5,'Aug 2024': 4.2, 'Sep 2024': 3,'Oct 2024': 4,'Nov 2024': 3.7,'Dec 2024': 4.1}\n\nThe following is the overall data acquired from our banking application: {}'Jan 2024': 4,'Feb 2024': 3.6,'Mar 2024': 3.2,'Apr 2024': 2,'May 2024': 3.8,'Jun 2024': 4.5,'July 2024': 4.5,'Aug 2024': 4.2, 'Sep 2024': 3,'Oct 2024': 4,'Nov 2024': 3.7,'Dec 2024': 4.1}\n\nThe following is the overall data acquired from our banking application: {}'Jan 2024': 4,'Feb 2024': 3.6,'Mar 2024': 3.2,'Apr 2024': 2,'May 2024': 3.8,'Jun 2024': 4.5,'July 2024': 4.5,'Aug 2024': 4.2, 'Sep 2024': 3,'Oct 2024': 4,'Nov 2024': 3.7,'Dec 2024': 4.1}\n\nThe following is the overall data acquired from our banking application: {}'Jan 2024': 4,'Feb 2024': 3.6,'Mar 2024': 3.2,'Apr 2024': 2,'May 2024': 3.8,'Jun 2024': 4.5,'July 2024': 4.5,'Aug 2024': 4.2, 'Se

In [26]:
gemma_output

'**Output:**\n```json\n{\n  "Positive Insights": null,\n  "Negative Insights": null,\n  "Topic Insights": null\n}\n```'

In [24]:
mixtral_output

'{\n"Positive Insights": "There is a positive trend in bank application usage from April 2024 to June 2024, with a peak in June at 4.5. Additionally, there is a consistent level of usage from July to December 2024, ranging from 3.7 to 4.5.",\n"Negative Insights": "There is a decrease in bank application usage from March 2024 to May 2024, with a low of 3.2 in March and 3.8 in May.",\n"Topic Insights": "Overall, the bank application usage shows a fluctuating trend throughout the year, with a general increase in usage from January to June 2024 and a slight decrease in the second half of the year."\n}'

It is evident that mixtral has the most coherent output. GPT-2 simply repeats the prompt, and Gemma could not do proper analysis.