<a href="https://colab.research.google.com/github/Crisitunity-Lab/ARDC-Project/blob/main/Sandbox/Prototype_Falcon_7B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Falcon 7B
Falcon 7B is a causal decoder-only model built by [TII](https://www.tii.ae/). It has 7 billion parameters. It is available under the Apache 2.0 licence.

This notebook passes a text query (for example, a tweet) to the model and receives a set of attributes describing that text in return. The aim is to bring structure to unstructured data.

In [1]:
# Check that this notebook is running in a high-RAM environment
from psutil import virtual_memory

# Total memory reported by psutil, converted from bytes to gigabytes (1e9 bytes)
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# Below 20 GB the runtime is reported as not being high-RAM
print('Not using a high-RAM runtime' if ram_gb < 20 else 'You are using a high-RAM runtime!')

Your runtime has 54.8 gigabytes of available RAM

You are using a high-RAM runtime!


In [2]:
# Install libraries required for project
!pip install accelerate
!pip install bitsandbytes
!pip install einops
!pip install peft
!pip install transformers

Collecting transformers
  Downloading transformers-4.32.1-py3-none-any.whl (7.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m39.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m43.3 MB/s[0m eta [36m0:00:0

In [4]:
import bitsandbytes
from peft import PeftConfig, PeftModel
import torch
from transformers import AutoModelForCausalLM, GenerationConfig, AutoTokenizer, BitsAndBytesConfig

In [53]:
# Parameters shared by the cells below
max_output_len = 120  # length limit passed to get_response as max_response_len
model_path = "tiiuae/falcon-7b-instruct"  # path to model
compute_dtype = torch.float16  # handed to get_bnb_config as the compute dtype

In [45]:
def get_bnb_config(compute_dtype, bit=4, bit_qtype="nf4"):
  """Build the BitsAndBytesConfig used to quantize the model at load time.

  Args:
    compute_dtype: Value passed as `bnb_4bit_compute_dtype`; only used when
      `bit` is 4.
    bit: Quantization width. Must be 4 or 8.
    bit_qtype: Value passed as `bnb_4bit_quant_type`; only used when `bit`
      is 4.

  Returns:
    A `BitsAndBytesConfig` for 4-bit loading (with double quantization
    enabled) or for 8-bit loading.

  Raises:
    ValueError: If `bit` is neither 4 nor 8.
  """
  # Validate first so an unsupported width fails with a clear message instead
  # of an UnboundLocalError on `config` at the return statement.
  if bit not in (4, 8):
    raise ValueError("bit must be 4 or 8, got {!r}".format(bit))

  if bit == 4:
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=bit_qtype,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

  return BitsAndBytesConfig(
      load_in_8bit=True
  )

In [47]:
def get_model(model_path, config):
  """Load a causal language model using the supplied quantization config.

  Args:
    model_path: Identifier or path handed to
      `AutoModelForCausalLM.from_pretrained`.
    config: Value passed through as `quantization_config`.

  Returns:
    The model object returned by `from_pretrained`.
  """
  load_kwargs = dict(
      quantization_config=config,
      device_map="auto",
      # NOTE(review): trust_remote_code=True is kept as-is; confirm the model
      # repository is one you trust.
      trust_remote_code=True,
  )
  return AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

In [48]:
# Build the quantization config (default bit width), then load the model with it
bnb_config = get_bnb_config(compute_dtype=compute_dtype)
model = get_model(model_path=model_path, config=bnb_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [49]:
# Load the tokenizer from the same model path that was used for the model
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [55]:
def get_response(model, query, max_response_len=120, text_tokenizer=None):
  """Greedily generate a text response to `query`, one token at a time.

  Args:
    model: Callable language model; `model(input_ids)` must return an object
      whose `.logits` tensor is indexed as [batch, position, vocab].
    query: Prompt text. Must be non-empty.
    max_response_len: Maximum number of tokens to generate. The prompt's own
      tokens do not count against this budget, so a long prompt can no longer
      silently produce an empty response.
    text_tokenizer: Optional tokenizer to use. When omitted, the
      notebook-level `tokenizer` is used, as before.

  Returns:
    The generated tokens decoded to a string with special tokens skipped.
    Generation stops at the end-of-sequence token or when the token budget
    is used up, whichever comes first.

  Raises:
    ValueError: If `query` is empty or None.
  """
  if not query:
    raise ValueError("No Query Supplied")

  tok = text_tokenizer if text_tokenizer is not None else tokenizer
  # Follow the model's own device instead of assuming "cuda"; keep "cuda" as
  # the fallback for model objects that do not expose a `device` attribute.
  device = getattr(model, "device", "cuda")

  sequence = tok(query, return_tensors="pt").input_ids.to(device)
  generated_ids = []

  # Generation needs no gradients; no_grad avoids autograd bookkeeping and the
  # memory it would hold on to across the repeated forward passes.
  with torch.no_grad():
    for _ in range(max_response_len):
      next_token_logits = model(sequence).logits[:, -1, :]
      # Greedy choice: highest-scoring token, reshaped to [1, 1] for concatenation
      next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

      token = next_token_id[0].item()
      generated_ids.append(token)
      if token == tok.eos_token_id:
        break

      sequence = torch.cat([sequence, next_token_id.to(device)], dim=-1)

  # Decode the whole response once: decoding token by token can break
  # characters that span several tokens.
  return tok.decode(generated_ids, skip_special_tokens=True)

In [51]:
# Try with a single tweet as the input query: one prompt that asks for all
# of the attributes at once
input_text = "In a CSV format return, the sentiment, topic, subtopic, location and the crisis for which this tweet relates: RT @CBCAlerts: Canmore, Alta. declares state of emergency due to flooding  - with some residents being moved to community centre #Alberta"

In [70]:
# Run the single-tweet query through the model and show the reply
response = get_response(model, input_text, max_output_len)
print(response)


Sentiment: Negative
Topic: Weather
Subtopic: Flooding
Location: Canmore, Alberta
Crisis: Flooding


In [95]:
# Sample tweets used as inputs for the per-attribute prompts
input_0 = "RT @CBCAlerts: Canmore, Alta. declares state of emergency due to flooding  - with some residents being moved to community centre #Alberta"
input_1 = "RT @GlobalCalgary: If you are in #Canmore and need help, an emergency line has been set up. Pls call 403-678-1551. #abstorm #abflood"
input_2 = "RT @metrocalgary: UPDATE: Latest from the @cityofcalgary and surrounding area on #abflood: http://t.co/lkA9L9zSvT #yyc"
input_3 = "RT @nenshi: Major risk of flooding in Calgary. Follow directions here: http://t.co/7dLx8aZptf and stay tuned. Please RT widely. #YYC"

In [96]:
# Collect the sample tweets so they can be processed in a loop
input_list = [input_0, input_1, input_2, input_3]

In [98]:
# Attributes to extract from each tweet; each name is slotted into the prompt below
attributes = ["social sentiment", "topic", "secondary topic", "location", "country", "crisis type"]

# The loop variable is `tweet` rather than `input` so the builtin `input` is
# not shadowed, and the prompt gets its own name so the `input_text` set by
# the single-tweet cell is not overwritten.
for tweet in input_list:

  print(tweet)

  # One model call per attribute, each asking for a short answer
  for attribute in attributes:
    prompt_prefix = "In 3 words, or less, return the {0} of the tweet: ".format(attribute)
    attribute_prompt = prompt_prefix + tweet

    response = get_response(model,
                            query=attribute_prompt,
                            max_response_len=max_output_len)
    print(response)

RT @GlobalCalgary: If you are in #Canmore and need help, an emergency line has been set up. Pls call 403-678-1551. #abstorm #abflood

Helpful

Emergency

Flooding

Canmore

Canada

Emergency
RT @metrocalgary: UPDATE: Latest from the @cityofcalgary and surrounding area on #abflood: http://t.co/lkA9L9zSvT #yyc

Neutral

Flooding

Flooding

Calgary

Canada
 #abflood
Flooding
RT @nenshi: Major risk of flooding in Calgary. Follow directions here: http://t.co/7dLx8aZptf and stay tuned. Please RT widely. #YYC

Alert

Flooding

Flooding

Calgary

Canada

Flooding
