<a href="https://colab.research.google.com/github/Crisitunity-Lab/ARDC-Project/blob/main/Sandbox/Prototype_Falcon_7B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Falcon 7B
Falcon 7B is a causal decoder-only model built by [TII](https://www.tii.ae/). It has 7 billion parameters. It is available under the Apache 2.0 licence.

This code is designed to pass in a query and receive a set of attributes of the text in return. The aim is to be able to bring structure to unstructured data.

In [1]:
# Confirm the notebook is attached to a high-RAM runtime before loading the model
from psutil import virtual_memory

# Total system memory, converted from bytes to (decimal) gigabytes
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# Anything reporting under 20 GB is treated as a standard (not high-RAM) runtime
print('You are using a high-RAM runtime!' if ram_gb >= 20 else 'Not using a high-RAM runtime')

Your runtime has 54.8 gigabytes of available RAM

You are using a high-RAM runtime!


In [2]:
# Install libraries required for project.
# NOTE: no versions are pinned here, so each run installs whatever release pip
# resolves at that time; results may differ between runs.
!pip install accelerate
!pip install bitsandbytes
!pip install einops
!pip install peft
!pip install transformers

Collecting accelerate
  Downloading accelerate-0.22.0-py3-none-any.whl (251 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/251.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m112.6/251.2 kB[0m [31m3.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m251.2/251.2 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.22.0
Collecting bitsandbytes
  Downloading bitsandbytes-0.41.1-py3-none-any.whl (92.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.41.1
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.6 MB/

In [3]:
import bitsandbytes
from peft import PeftConfig, PeftModel
import torch
from transformers import AutoModelForCausalLM, GenerationConfig, AutoTokenizer, BitsAndBytesConfig

In [4]:
# Set parameters to be used throughout the project
max_output_len = 120 # maximum length of output from model
model_path="tiiuae/falcon-7b-instruct" # path to model
compute_dtype = getattr(torch, "float16")

In [5]:
def get_bnb_config(compute_dtype, bit=4, bit_qtype="nf4",):
  """Build the bitsandbytes quantisation config used when loading the model.

  Args:
    compute_dtype: forwarded as ``bnb_4bit_compute_dtype``; only used when
      ``bit`` is 4.
    bit: quantisation width, either 4 or 8.
    bit_qtype: forwarded as ``bnb_4bit_quant_type``; only used when ``bit``
      is 4.

  Returns:
    A ``BitsAndBytesConfig`` for 4-bit loading (with double quantisation
    enabled) or for 8-bit loading.

  Raises:
    ValueError: if ``bit`` is neither 4 nor 8. Previously this case fell
      through and failed with an UnboundLocalError on ``config``.
  """
  if bit == 4:
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=bit_qtype,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

  if bit == 8:
    return BitsAndBytesConfig(
        load_in_8bit=True
    )

  raise ValueError("bit must be 4 or 8, got {!r}".format(bit))

In [6]:
def get_model(model_path, config):
  """Load a causal language model from ``model_path`` using ``config`` for quantisation.

  Args:
    model_path: model identifier or path passed to ``from_pretrained``.
    config: quantisation config, passed as ``quantization_config``.

  Returns:
    The object returned by ``AutoModelForCausalLM.from_pretrained``.
  """
  load_options = dict(
      quantization_config=config,
      device_map="auto",
      # NOTE(review): this permits custom code shipped with the model repo to
      # be executed; pin a revision if the source is not fully trusted.
      trust_remote_code=True,
  )
  return AutoModelForCausalLM.from_pretrained(model_path, **load_options)

In [7]:
# Build the default (4-bit) quantisation config, then load the model with it
bnb_config = get_bnb_config(compute_dtype=compute_dtype)
model = get_model(model_path=model_path, config=bnb_config)

Downloading (…)lve/main/config.json:   0%|          | 0.00/667 [00:00<?, ?B/s]

Downloading (…)/configuration_RW.py:   0%|          | 0.00/2.61k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/tiiuae/falcon-7b-instruct:
- configuration_RW.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)main/modelling_RW.py:   0%|          | 0.00/47.5k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/tiiuae/falcon-7b-instruct:
- modelling_RW.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)model.bin.index.json:   0%|          | 0.00/16.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/4.48G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [8]:
# Load the tokenizer from the same path as the model. get_response reads this
# module-level name directly, so it must be defined before get_response is called.
tokenizer = AutoTokenizer.from_pretrained(model_path)

Downloading (…)okenizer_config.json:   0%|          | 0.00/220 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.73M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/281 [00:00<?, ?B/s]

In [9]:
def get_response(model, query, max_response_len=120):
  """Greedily generate a reply to ``query``, one token at a time.

  Uses the module-level ``tokenizer`` to encode the query and decode the reply.

  Args:
    model: callable that takes a tensor of token ids and returns an object
      whose ``logits`` attribute is indexed as [batch, position, vocabulary].
      If it exposes a ``device`` attribute the prompt is moved there;
      otherwise "cuda" is used.
    query: prompt text. Must be non-empty.
    max_response_len: cap on the TOTAL sequence length in tokens, prompt
      included (not only the generated part). A prompt that is already this
      long yields an empty string. Larger values use more memory, because
      the whole sequence is re-fed to the model at every step.

  Returns:
    The generated text, with special tokens removed.

  Raises:
    ValueError: if ``query`` is empty or None.
  """
  if query is None or len(query) < 1:
    raise ValueError("No Query Supplied")

  # Follow the model's own device rather than assuming a GPU is present
  device = getattr(model, "device", "cuda")
  sequence = tokenizer(query, return_tensors="pt").input_ids.to(device)

  eos_token_id = tokenizer.eos_token_id
  generated_ids = []

  # Inference only: skip autograd bookkeeping to save memory
  with torch.no_grad():
    while sequence.shape[1] < max_response_len:
      next_token_logits = model(sequence).logits[:, -1, :]
      # Greedy choice: take the highest-scoring token
      next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
      sequence = torch.cat([sequence, next_token_id.to(sequence.device)], dim=-1)

      token = next_token_id[0].item()
      generated_ids.append(token)

      if token == eos_token_id:
        break

  # Decode all generated ids in one call so that text spanning several tokens
  # is rebuilt from the complete sequence rather than piece by piece
  return tokenizer.decode(generated_ids, skip_special_tokens=True)

In [None]:
# Try with a single tweet as an input query: one prompt that asks for every
# attribute at once, with the tweet appended after the instruction
input_text = "In a CSV format return, the sentiment, topic, subtopic, location and the crisis for which this tweet relates: RT @CBCAlerts: Canmore, Alta. declares state of emergency due to flooding  - with some residents being moved to community centre #Alberta"

In [None]:
# Run the single-tweet prompt through the model and show the reply
response = get_response(model, query=input_text, max_response_len=max_output_len)
print(response)


Sentiment: Negative
Topic: Weather
Subtopic: Flooding
Location: Canmore, Alberta
Crisis: Flooding


In [17]:
# Four sample tweets used as raw inputs (no instruction text attached yet)
input_0 = "RT @CBCAlerts: Canmore, Alta. declares state of emergency due to flooding  - with some residents being moved to community centre #Alberta"
input_1 = "RT @GlobalCalgary: If you are in #Canmore and need help, an emergency line has been set up. Pls call 403-678-1551. #abstorm #abflood"
input_2 = "RT @metrocalgary: UPDATE: Latest from the @cityofcalgary and surrounding area on #abflood: http://t.co/lkA9L9zSvT #yyc"
input_3 = "RT @nenshi: Major risk of flooding in Calgary. Follow directions here: http://t.co/7dLx8aZptf and stay tuned. Please RT widely. #YYC"

In [18]:
# Collect the sample tweets so they can be processed in a loop
input_list = [input_0, input_1, input_2, input_3]

In [19]:
# Attributes to extract from each tweet; the model is queried once per attribute
attributes = ["social sentiment", "topic", "secondary topic", "location", "country", "crisis type"]

# The loop variable is named `tweet` so the builtin `input` is not shadowed
for tweet in input_list:

  print(tweet)

  for attribute in attributes:
    # Instruction asking for a short answer, followed by the tweet itself
    input_prefix = "In 3 words, or less, return the {0} of the tweet: ".format(attribute)
    input_text = input_prefix + tweet

    response = get_response(model,
                            query=input_text,
                            max_response_len=max_output_len)
    print(response)

RT @CBCAlerts: Canmore, Alta. declares state of emergency due to flooding  - with some residents being moved to community centre #Alberta
 #Flooding
Alert

Flooding

Flooding

Canmore

Canada
 #Flooding
Flooding
RT @GlobalCalgary: If you are in #Canmore and need help, an emergency line has been set up. Pls call 403-678-1551. #abstorm #abflood

Helpful

Emergency

Flooding

Canmore

Canada

Emergency
RT @metrocalgary: UPDATE: Latest from the @cityofcalgary and surrounding area on #abflood: http://t.co/lkA9L9zSvT #yyc

Neutral

Flooding

Flooding

Calgary

Canada
 #abflood
Flooding
RT @nenshi: Major risk of flooding in Calgary. Follow directions here: http://t.co/7dLx8aZptf and stay tuned. Please RT widely. #YYC

Alert

Flooding

Flooding

Calgary

Canada

Flooding


In [10]:
import pandas as pd

# Sample tweets as (id, text) records, with ids numbered from 1
df = pd.DataFrame(
    [
        (1, "RT @CBCAlerts: Canmore, Alta. declares state of emergency due to flooding  - with some residents being moved to community centre #Alberta"),
        (2, "RT @GlobalCalgary: If you are in #Canmore and need help, an emergency line has been set up. Pls call 403-678-1551. #abstorm #abflood"),
        (3, "RT @metrocalgary: UPDATE: Latest from the @cityofcalgary and surrounding area on #abflood: http://t.co/lkA9L9zSvT #yyc"),
        (4, "RT @nenshi: Major risk of flooding in Calgary. Follow directions here: http://t.co/7dLx8aZptf and stay tuned. Please RT widely. #YYC"),
    ],
    columns=["id", "text"],
)

In [13]:
# Render the frame as plain text (header kept, index column dropped) so it can
# be embedded in a prompt
df_str = df.to_string(justify="right", index_names=False, index=False, header=True)

# Instruction asking for a per-row sentiment, followed by the rendered table
question = "For each row in the below dataset return the row id and the sentiment of the text in 3 words or less. \n {}".format(df_str)

In [14]:
# Send the whole-table question to the model with a length cap of 500
# (instead of max_output_len) and show the reply
response = get_response(model, query=question, max_response_len=500)
print(response)


  5                    RT @cityofcalgary: #abflood: Calgary is currently experiencing a significant flood event. Please stay safe and follow the instructions of your local authorities. #yyc
  6      RT @CBCAlerts: Calgary is currently experiencing a significant flood event. Please stay safe and follow the instructions of your local authorities. #Alberta
  7                    RT @CBCAlerts: Calgary is currently experiencing a significant flood event. Please stay safe and follow the instructions of your local authorities. #Alberta
  8                    RT @CBCAlerts: Calgary is currently experiencing a significant flood event. Please stay safe and follow the instructions of your local authorities. #Alberta
  9                    RT @CBCAlerts: Calgary is currently experiencing a significant flood event. Please stay safe and follow the instructions of your local authorities. #Alberta
  10                    RT @CBCAlerts: Calgary is currently experiencing a significant flood event. Ple