In [102]:
import torch
from transformers import AutoTokenizer, MambaForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import pandas as pd
import os

In [103]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [104]:
data = pd.read_csv('data/disaster_tweets/train.csv')

In [105]:
data.head()

Unnamed: 0,id,keyword,location,text,target
0,1,,,Our Deeds are the Reason of this #earthquake M...,1
1,4,,,Forest fire near La Ronge Sask. Canada,1
2,5,,,All residents asked to 'shelter in place' are ...,1
3,6,,,"13,000 people receive #wildfires evacuation or...",1
4,7,,,Just got sent this photo from Ruby #Alaska as ...,1


In [106]:
data = data[data["id"].between(1, 700)]

In [107]:
data.shape

(485, 5)

In [108]:
strategy = "zero_shot"

In [109]:
model_size = "130m"
# model_size = "370m"
# model_size = "790m"
# model_size = "1.4b"
# model_size = "2.8b"

In [110]:
# model = MambaForCausalLM.from_pretrained(f"state-spaces/mamba-{model_size}-hf").to(device)
model = MambaLMHeadModel.from_pretrained(os.path.expanduser("state-spaces/mamba-{model_size}"), device="cuda", dtype=torch.bfloat16)



In [111]:
tokenizer = AutoTokenizer.from_pretrained(f"state-spaces/mamba-{model_size}-hf")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [112]:
prompt_template_zero_shot = """
Instructions:

You have to analyze the following tweet and to determine if it speaks about a real desaster or not. Answer with "1" if the tweet speaks about a real disaster and with "0" if not. Don't add any other information in your answer.

--------------------------
Tweet:

{text}
--------------------------
Your answer (only a "1" or a "0"):
"""

In [113]:
prompt_template_few_shot = """
Instructions:
Your task is to analyze the following tweet and determine if it is talking about a real disaster. A real disaster can include, but is not limited to, events such as earthquakes, hurricanes, fires, floods, major accidents, etc. If the tweet refers to a real disaster, respond with 1. If not, respond with 0.

Your response should only be the number 1 or 0.

Considerations:
Real Disasters: Significant events that impact people, property, or the environment.
Not Disasters: Personal opinions, jokes, fake news, or events that do not qualify as a disaster.

Examples:
Tweet: "A 7.5 magnitude earthquake has shaken the city, causing significant damage and injuries."
Expected Response: 1

Tweet: "I'm so tired that my house looks like a disaster after last night's party!"
Expected Response: 0

Tweet: "Uncontrolled wildfire in the north of the country. Evacuate immediately."
Expected Response: 1

Tweet: "It rained a lot yesterday, but today is sunny and beautiful."
Expected Response: 0

Tweet to Analyze:
Tweet: "{text}"

Response:
Result (1 or 0):
"""

In [114]:
prompt_template = prompt_template_zero_shot if strategy == "zero_shot" else prompt_template_few_shot

In [115]:
predictions = []
with torch.no_grad():
    for index, row in data.iterrows():
        prompt = prompt_template.format(text=row['text'])
        encodings = tokenizer(prompt, return_tensors="pt")
        input_ids = encodings.to(device)
        #outputs = model(**input_ids, max_new_tokens=1)
        outputs = model.generate(input_ids=input_ids, max_length=1, return_dict_in_generate=True, output_scores=True, temperature=0.1, top_k=10, top_p=0.1)()
        #p = tokenizer.decode(outputs.logits.argmax(dim=-1)[0], skip_special_tokens=True)
        p = tokenizer.decode(outputs[0][0])
        # predictions.append(p)
        try:
            predictions.append(int(p))
        except:
            print(f"{index}: {p}")
        if index % 50:
            print(index)



AttributeError: 

In [None]:
data["predictions"] = predictions

In [None]:
f1_score(data["target"], data["predictions"])

0.0