In [19]:
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, Trainer, TrainingArguments, BitsAndBytesConfig, \
    DataCollatorForLanguageModeling, Trainer, TrainingArguments, logging
from torch import cuda, bfloat16
import transformers
from sklearn.metrics import precision_score, recall_score, f1_score
from metrics import calc_mets_my, calculate_metrics2
from datasets import Dataset, load_dataset


In [2]:
PROJECT = "Llama3-8B-QLora-Omni"
MODEL_NAME = 'meta-llama/Meta-Llama-3-8B-Instruct'


In [3]:
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
device

'cuda:0'

In [4]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)


model_config = transformers.AutoConfig.from_pretrained(
    MODEL_NAME,
    token=True
)

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto',
    token=True
)
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps

In [6]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    MODEL_NAME,
    token=True
)
PAD_TOKEN = tokenizer.eos_token
tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
tokenizer.padding_side = "right"

In [7]:
memory_used = model.get_memory_footprint()
print("Memory footprint: {} GB".format(memory_used/1024/1024/1024))

Memory footprint: 5.207535028457642 GB


In [8]:
def prepare_prompt_simple(row):
    """
    Prepares a natural language prompt for the entity resolution task.
    :param row: A tuple with two entities and the expected result.
    :return: A formatted prompt string.
    """
    
    entity_1, entity_2 = row['e1'], row['e2']
    # print(entity_1)
    prompt = f"""Do the two place descriptions refer to the same real-world place? Answer with 'Yes' if they do and 'No' if they do not.
    Place 1: {entity_1}
    Place 2: {entity_2}
    Answer: """
    return prompt

In [9]:
def prepare_prompt_attribute_value_distance(row):
    """
    Prepares a natural language prompt for the entity resolution task.
    :param row: A tuple with two entities and the expected result.
    :return: A formatted prompt string.
    """
    
    entity_1, entity_2, distance = row['e1'], row['e2'], row['distance']
    # print(entity_1)
    prompt = f"""Two place descriptions and the geographic distance between them is provided. Do the two place descriptions refer to the same real-world place? Answer with 'Yes' if they do and 'No' if they do not.
    Place 1: {entity_1}
    Place 2: {entity_2}
    Distance: {distance}
    Answer: """
    return prompt

In [10]:
logging.set_verbosity_error()

In [11]:
def zero_shot_inference(model, tokenizer, prompts, max_new_tokens):
    """
    Performs zero-shot inference using the model.
    :param model: The loaded quantized model.
    :param tokenizer: Tokenizer for the model.
    :param prompts: List of input prompts.
    :return: Model predictions (Yes/No).
    """
    results = []
    
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # outputs = model.pipeline(inputs.input_ids)
        outputs = model.generate(
            inputs.input_ids, 
            max_length=100,  # Maximum length of the generated text
            max_new_tokens= max_new_tokens,
            num_return_sequences=1,  # Number of sequences to generate
            no_repeat_ngram_size=2,  # Avoid repeating phrases
            temperature=0.01,  # Controls randomness; lower is less random
            top_k=50,  # Top-k sampling
        )
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # prediction = tokenizer.decode(outputs[:, inputs.shape[1]:])
        results.append(prediction.strip())
        
    return results

In [21]:
def calculate_metrics(predictions, labels):
 
    # Convert "Yes" to 1 and "No" to 0 for predicted labels
    predicted = [1 if label == "Yes" else 0 if label == "No" else 3 for label in predictions]
    
    # Ensure ground truth is already in binary format
    ground_truth = [int(x) for x in labels]
    # Calculate metrics
    precision = precision_score(ground_truth, predicted)
    recall = recall_score(ground_truth, predicted)
    f1 = f1_score(ground_truth, predicted)
    
    return {
        "Precision": precision,
        "Recall": recall,
        "F1 Score": f1
    }

In [22]:
# Select prompt to test zero shot. Select between "simple", "attribute_value", "plm" and "attribute_value_dist"
test_prompt = "plm"

In [23]:
dataset_folder_path = ['datasets2\\NZER_'+ test_prompt+ '\\auck\\', 
                       'datasets2\\NZER_'+ test_prompt+ '\\hope\\', 
                       'datasets2\\NZER_'+ test_prompt+ '\\norse\\',
                       'datasets2\\NZER_'+ test_prompt+ '\\north\\', 
                       'datasets2\\NZER_'+ test_prompt+ '\\palm\\', 
                       'datasets2\\GEOD_OSM_FSQ_'+ test_prompt+ '\\edi\\', 
                       'datasets\\GEOD_OSM_FSQ_'+ test_prompt+ '\\pit\\', 
                       'datasets\\GEOD_OSM_FSQ_'+ test_prompt+ '\\sin\\', 
                       'datasets\\GEOD_OSM_FSQ_'+ test_prompt+ '\\tor\\', 
                       'datasets\\GEOD_OSM_YELP_'+ test_prompt+ '\\edi\\', 
                       'datasets\\GEOD_OSM_YELP_'+ test_prompt+ '\\pit\\', 
                       'datasets\\GEOD_OSM_YELP_'+ test_prompt+ '\\sin\\', 
                       'datasets\\GEOD_OSM_YELP_'+ test_prompt+ '\\tor\\', 
                       'datasets\\SGN_'+test_prompt+'\\swiss\\']

In [26]:
# File path to the input data
for dataset_folder in dataset_folder_path:
    
    # print(dataset_folder.split('\\')[-3:])
    
    print(dataset_folder.split("\\"))
    dataset_output_path_1, dataset_output_path_2 = dataset_folder.split("\\")[-3], dataset_folder.split("\\")[-2]
        
    dataset = load_dataset(
        "json",
        data_files={"train": dataset_folder+"train.json", "valid": dataset_folder+"valid.json", "test": dataset_folder+"test.json"},
    )
    
    # data, labels = parse_file(dataset_folder)
    labels = [x['answer'] for x in dataset["test"]]
    if test_prompt=="attribute_value_dist":
        prompts = [prepare_prompt_attribute_value_distance(row) for row in dataset['test']]
    else:
        prompts = [prepare_prompt_simple(row) for row in dataset['test']]
    print(prompts[0])
    print(labels[0])
    predictions = zero_shot_inference(model, tokenizer, prompts, 1)
    predictions = [x.split(" ")[-1].strip() for x in predictions] 
    # predictions = [1 if label in ["Yes", "yes"] else 2 if label in ["no", "No"] else 3 for label in predictions]
    print(len(predictions), len(labels))
    print(calculate_metrics(predictions, [1 if lbl =="Yes" else 0 if lbl =="No" else 3 for lbl in labels]))

['datasets2', 'NZER_plm', 'auck', '']
Do the two place descriptions refer to the same real-world place? Answer with 'Yes' if they do and 'No' if they do not.
    Place 1: COL name VAL Tautini COL type VAL farmstead COL latitude VAL -40.18825 COL longitude VAL 176.14021 
    Place 2: COL name VAL Waikoukou Stream COL type VAL Stream COL latitude VAL -40.08742665596005 COL longitude VAL 176.28878276483226 
    Answer: 
No
601 601
{'Precision': 0.7142857142857143, 'Recall': 0.25, 'F1 Score': 0.37037037037037035}
['datasets2', 'NZER_plm', 'hope', '']
Do the two place descriptions refer to the same real-world place? Answer with 'Yes' if they do and 'No' if they do not.
    Place 1: COL name VAL Silver Stream COL type VAL stream COL latitude VAL -44.04833 COL longitude VAL 168.67008 
    Place 2: COL name VAL Deep Dale COL type VAL Valley COL latitude VAL -44.025083 COL longitude VAL 168.672056 
    Answer: 
No
2907 2907
{'Precision': 0.9, 'Recall': 0.08035714285714286, 'F1 Score': 0.1475409

Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Do the two place descriptions refer to the same real-world place? Answer with 'Yes' if they do and 'No' if they do not.
    Place 1: COL name VAL Mangatewai River Scenic Reserve COL type VAL Scenic Reserve COL latitude VAL -39.985556 COL longitude VAL 176.254722 
    Place 2: COL name VAL Ōtāwhao COL type VAL locality COL latitude VAL -40.0501529 COL longitude VAL 176.2682748 
    Answer: 
No
1783 1783
{'Precision': 1.0, 'Recall': 0.06666666666666667, 'F1 Score': 0.125}
['datasets2', 'NZER_plm', 'north', '']


Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Do the two place descriptions refer to the same real-world place? Answer with 'Yes' if they do and 'No' if they do not.
    Place 1: COL name VAL Waitiki Channel COL type VAL stream COL latitude VAL -34.51818 COL longitude VAL 172.88018 
    Place 2: COL name VAL Matapia COL type VAL Island COL latitude VAL -34.606333571762384 COL longitude VAL 172.79844153380557 
    Answer: 
No


KeyboardInterrupt: 

In [30]:
def prepare_prompt_gtminer_simple(row):
    """
    Prepares a natural language prompt for the entity resolution task.
    :param row: A tuple with two entities and the expected result.
    :return: A formatted prompt string.
    """
    
    entity_1, entity_2 = row['e1'], row['e2']
    # print(entity_1)
    prompt = f"""Two place descriptions are provided. Predict the relation between them. Answer only with ‘same_as’, ‘part_of’, ‘serves’ or ‘unknown’..
    Place 1: {entity_1}
    Place 2: {entity_2}
    Answer: """
    
    return prompt

In [31]:
def prepare_prompt_gtminer(row):
    """
    Prepares a natural language prompt for the entity resolution task.
    :param row: A tuple with two entities and the expected result.
    :return: A formatted prompt string.
    """
    
    entity_1, entity_2 = row['e1'], row['e2']
    # print(entity_1)
    prompt = f"""Two place descriptions are provided. Answer with 'same_as' if the first place is the same as the second place. Answer with 'part_of' if the first place is a part of the second place and is located inside the second place. Answer with 'serves' if the first place provides a service to the second place in terms of human mobility, assistance, etc. Answer with 'unknown' if the two places show none of these relations.
    Place 1: {entity_1}
    Place 2: {entity_2}
    Answer: """
    
    return prompt

In [32]:
def prepare_prompt_gtminer_distance(row):
    """
    Prepares a natural language prompt for the entity resolution task.
    :param row: A tuple with two entities and the expected result.
    :return: A formatted prompt string.
    """
    
    entity_1, entity_2, dist = row['e1'], row['e2'], row['distance']
    # print(entity_1)
    prompt = f"""Two place descriptions and the geographic distance between them are provided. Answer with 'same_as' if the first place is the same as the second place. Answer with 'part_of' if the first place is a part of the second place and is located inside the second place. Answer with 'serves' if the first place provides a service to the second place in terms of human mobility, assistance, etc. Answer with 'unknown' if the two places show none of these relations..
    Place 1: {entity_1}
    Place 2: {entity_2}
    Distance: {dist}
    Answer: """
    
    return prompt

In [33]:
# Select prompt to test zero shot for Geospatial relation mining task. Select between "simple", "attribute_value", "plm" and "attribute_value_dist"
test_prompt = "plm"

In [34]:
dataset_folder_path = ['datasets2\\GTMD_'+ test_prompt+ '\\mel\\', 
                       'datasets2\\GTMD_'+ test_prompt+ '\\sea\\', 
                       'datasets2\\GTMD_'+ test_prompt+ '\\sin\\',
                       'datasets2\\GTMD_'+ test_prompt+ '\\tor\\']

In [35]:


for dataset_folder in dataset_folder_path:
    
    
    
    print(dataset_folder.split("\\"))
    dataset_output_path_1, dataset_output_path_2 = dataset_folder.split("\\")[-3], dataset_folder.split("\\")[-2]
        
    dataset = load_dataset(
        "json",
        data_files={"train": dataset_folder+"train.json", "valid": dataset_folder+"valid.json", "test": dataset_folder+"test.json"},
    )
    
    # data, labels = parse_file(dataset_folder)
    labels = [1 if label == "same_as" else 2 if label == "part_of" else 3 if label == "serves" else 0 if label == "unknown" else 5 for label in dataset['test']['answer']]
    if test_prompt=="attribute_value_dist":
        prompts = [prepare_prompt_gtminer_distance(row) for row in dataset['test']]
    elif test_prompt=="simple":
        prompts = [prepare_prompt_gtminer_simple(row) for row in dataset['test']]
    else:
        prompts = [prepare_prompt_gtminer(row) for row in dataset['test']]
    print(prompts[0])
    print(dataset['test'][0]['answer'])
    predictions = zero_shot_inference(model, tokenizer, prompts, 2)
    predictions = [x.split(" ")[-1].strip() for x in predictions] 
    predictions = [1 if label in ["same_as", "same", "same-as"] else 2 if label in ["part_of", "part-of", "partof"] else 3 if label in ["serves", "served"] else 0 if label in ["unknown"] else 4 for label in predictions]
    print(len(predictions), len(labels))
    print(calculate_metrics2(predictions, labels))
    

['datasets2', 'GTMD_plm', 'mel', '']
Two place descriptions are provided. Answer with 'same_as' if the first place is the same as the second place. Answer with 'part_of' if the first place is a part of the second place and is located inside the second place. Answer with 'serves' if the first place provides a service to the second place in terms of human mobility, assistance, etc. Answer with 'unknown' if the two places show none of these relations.
    Place 1: COL name VAL JB Hi-Fi COL type VAL electronics COL address VAL nan COL latitude VAL -37.7681204 COL longitude VAL 145.304855 
    Place 2: COL name VAL Chirnside Homemaker Centre COL type VAL mall COL address VAL 282 Maroondah Highway 3116 COL latitude VAL -37.7663845 COL longitude VAL 145.3058855 
    Answer: 
part_of
1839 1839
P: 0.4786  |  R: 0.6586  |  F1: 0.5544
{'precision': 0.4786269430051813, 'recall': 0.6586452762923352, 'f1': 0.5543885971492873}
['datasets2', 'GTMD_plm', 'sea', '']


Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Two place descriptions are provided. Answer with 'same_as' if the first place is the same as the second place. Answer with 'part_of' if the first place is a part of the second place and is located inside the second place. Answer with 'serves' if the first place provides a service to the second place in terms of human mobility, assistance, etc. Answer with 'unknown' if the two places show none of these relations.
    Place 1: COL name VAL 32 Bar & Grill COL type VAL Sports Bars; American (Traditional) COL address VAL nan COL latitude VAL 47.70645953432845 COL longitude VAL -122.3245641011371 
    Place 2: COL name VAL Kraken Community Iceplex COL type VAL leisure COL address VAL 10601 5th Avenue Northeast 98125 COL latitude VAL 47.7062215 COL longitude VAL -122.3251917 
    Answer: 
part_of
4747 4747
P: 0.3026  |  R: 0.6159  |  F1: 0.4058
{'precision': 0.3025885900980146, 'recall': 0.6158567774936061, 'f1': 0.4057971014492754}
['datasets2', 'GTMD_plm', 'sin', '']


Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Two place descriptions are provided. Answer with 'same_as' if the first place is the same as the second place. Answer with 'part_of' if the first place is a part of the second place and is located inside the second place. Answer with 'serves' if the first place provides a service to the second place in terms of human mobility, assistance, etc. Answer with 'unknown' if the two places show none of these relations.
    Place 1: COL name VAL Garrett Gourmet Popcorn COL type VAL Candy Stores COL address VAL 541 Orchard Rd #01-K1 Liat Towers 238881 COL latitude VAL 1.3053032 COL longitude VAL 103.8305236 
    Place 2: COL name VAL Liat Towers COL type VAL Shopping Centers COL address VAL 541 Orch Rd 238881 COL latitude VAL 1.3051056 COL longitude VAL 103.8307274 
    Answer: 
part_of
7852 7852
P: 0.3356  |  R: 0.5645  |  F1: 0.421
{'precision': 0.33562071116656267, 'recall': 0.5645330535152151, 'f1': 0.42097026604068855}
['datasets2', 'GTMD_plm', 'tor', '']


Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Two place descriptions are provided. Answer with 'same_as' if the first place is the same as the second place. Answer with 'part_of' if the first place is a part of the second place and is located inside the second place. Answer with 'serves' if the first place provides a service to the second place in terms of human mobility, assistance, etc. Answer with 'unknown' if the two places show none of these relations.
    Place 1: COL name VAL New College Library COL type VAL library COL address VAL 20 Willcocks Street COL latitude VAL 43.6617897 COL longitude VAL -79.4001637 
    Place 2: COL name VAL Engineering Library COL type VAL library COL address VAL nan COL latitude VAL 43.6601686 COL longitude VAL -79.3949804 
    Answer: 
unknown
5101 5101
P: 0.3712  |  R: 0.6408  |  F1: 0.4701
{'precision': 0.37122687439143137, 'recall': 0.6407563025210085, 'f1': 0.4700986436498151}
