In [1]:
import os

# Runtime configuration, applied before the gliner import below so the
# libraries see these values when they load.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "true",
        # Enumerate GPUs by PCI bus id, then expose only device "1".
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
        # Weights & Biases project name.
        "WANDB_PROJECT": "mscft_ner",
    }
)
# os.environ["WANDB_LOG_MODEL"] = "true"
# os.environ["WANDB_WATCH"] = "none"

import json
import random
from glob import glob
from itertools import chain

import numpy as np
import pandas as pd
from gliner import GLiNER

In [2]:
# Local fine-tuned checkpoint; the commented value is presumably the Hub model
# it was fine-tuned from (alternative value for PATH).
PATH = "models/checkpoint-1848"  # "urchade/gliner_small-v2.1"

# NOTE(review): `load_tokenizer=True` was previously tried here and left disabled.
model = GLiNER.from_pretrained(PATH)
model = model.cuda()

config.json not found in /data/home/eak/learning/zindi_challenge/micro_rec/models/checkpoint-1848


In [None]:
# Load the labelled evaluation split; the context manager closes the file and the
# explicit encoding avoids depending on the platform's locale default.
with open("data/accepted_data/TestCleaned.json", encoding="utf-8") as test_file:
    raw_test = json.load(test_file)

# Give every gold span the same label name.
location_name = "location"  # alternative: "disaster related location"
for record in raw_test:
    for span in record["ner"]:
        # NOTE(review): assumes each span is a mutable list whose last item is the label.
        span[-1] = location_name

# Space-joined pre-tokenized words: the strings the model is run on later.
texts = [" ".join(record["tokenized_text"]) for record in raw_test]

def get_expected(raw):
    """Build the reference string for one annotated example.

    The first two items of each ``raw["ner"]`` entry are read as inclusive token
    indices into ``raw["tokenized_text"]``. The covered tokens of each span are
    joined with spaces. The resulting mentions are sorted lexicographically and
    joined with single spaces. Duplicates are kept, and an example with no spans
    yields "".
    """
    tokens = raw["tokenized_text"]
    mentions = [" ".join(tokens[span[0]:span[1] + 1]) for span in raw["ner"]]
    mentions.sort()
    return " ".join(mentions)
# One reference string per example: its gold mentions, sorted and space-joined.
ners = [get_expected(record) for record in raw_test]
test_data = pd.DataFrame(zip(texts, ners), columns=["text", "location"])
print(len(test_data))

# Fixed seed so the preview is identical on every Restart & Run All.
test_data.sample(5, random_state=42)

1645


Unnamed: 0,text,location
1194,The best thing central Floridians can do with ...,
637,Largest evacuation in 25 years underway in # F...,Florida
1035,"Our @ ArtofLiving 2nd Truck , with 4 Tons of F...",Ernakulam
746,"Verizon customers in # SouthTexas , our data r...",
565,Venezuela to send humanitarian aid to hurrican...,Haiti Venezuela


In [4]:
from tqdm import tqdm

texts = test_data["text"].tolist()
bsize = 256

# Predict in chunks of `bsize` texts with a deliberately low cut-off (0.05);
# the stricter threshold is chosen later by the threshold search.
predictions = [
    model.batch_predict_entities(
        texts[offset:offset + bsize],
        ["disaster related location"],
        threshold=.05,
    )
    for offset in tqdm(range(0, len(texts), bsize))
]

len(predictions)

  0%|          | 0/7 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


100%|██████████| 7/7 [00:08<00:00,  1.26s/it]


7

In [5]:
thr = .05

# `predictions` holds one list per batch, each containing one entity list per text.
# Flatten the batch level in linear time (repeated `sum(..., start=[])` is quadratic).
per_text_predictions = list(chain.from_iterable(predictions))

# Keep entities scoring at least `thr` and tag each with the index of its source
# text. Fresh dicts are built so the raw model output in `predictions` is untouched.
all_predictions = [
    {**entity, "raw_id": raw_id}
    for raw_id, entities in enumerate(per_text_predictions)
    for entity in entities
    if entity["score"] >= thr
]
all_predictions[:5]

[{'start': 60,
  'end': 65,
  'text': 'Texas',
  'label': 'disaster related location',
  'score': 0.9944333434104919,
  'raw_id': 0},
 {'start': 83,
  'end': 92,
  'text': 'Islamabad',
  'label': 'disaster related location',
  'score': 0.9199880957603455,
  'raw_id': 1},
 {'start': 206,
  'end': 210,
  'text': 'NDMA',
  'label': 'disaster related location',
  'score': 0.17284485697746277,
  'raw_id': 1},
 {'start': 38,
  'end': 49,
  'text': 'Californias',
  'label': 'disaster related location',
  'score': 0.4454416036605835,
  'raw_id': 2},
 {'start': 112,
  'end': 119,
  'text': 'U . S .',
  'label': 'disaster related location',
  'score': 0.7139996886253357,
  'raw_id': 2}]

In [6]:
# One row per kept entity prediction.
preds = pd.DataFrame(all_predictions)

# Fixed seed so the preview is reproducible.
preds.sample(10, random_state=42)

Unnamed: 0,start,end,text,label,score,raw_id
2428,30,43,Triangle Area,disaster related location,0.606575,1428
1380,185,193,Zimbabwe,disaster related location,0.922686,810
242,188,198,Mozambique,disaster related location,0.991964,137
900,37,45,Nebraska,disaster related location,0.106825,514
2607,4,13,Johnstown,disaster related location,0.829947,1531
198,51,74,Cape Canaveral Hospital,disaster related location,0.630374,111
2642,0,8,FLOODING,disaster related location,0.079737,1555
2235,33,43,California,disaster related location,0.846981,1312
1588,5,22,automotive plants,disaster related location,0.099825,917
1944,110,116,Mexico,disaster related location,0.997574,1136


In [7]:
preds.loc[:, "label"].value_counts()  # only one label was requested, so expect one row

label
disaster related location    2794
Name: count, dtype: int64

In [8]:
def create_predictions(raw: pd.DataFrame):
    """Deduplicate one text's predicted entities and order them by surface text.

    Args:
        raw: predicted entities for a single source text, one per row; must
            have at least ``label``, ``text`` and ``score`` columns.

    Returns:
        list[dict]: one record per distinct ``(label, text)`` pair, namely the
        highest-scoring occurrence (the first one wins on ties), sorted by
        ``text``. An empty frame yields ``[]``.
    """
    best = {}
    for entity in raw.to_dict("records"):
        key = (entity["label"], entity["text"])
        kept = best.get(key)
        # Strict ">" keeps the earliest occurrence when scores tie; replacing a
        # value keeps the key's original insertion position.
        if kept is None or entity["score"] > kept["score"]:
            best[key] = entity

    # Sort by surface text only (not by label or start offset). sorted() is
    # stable, so equal texts under different labels keep first-seen order.
    return sorted(best.values(), key=lambda ent: ent["text"])

# Group predictions per source text. Explicitly selecting all columns (raw_id
# included) after groupby keeps raw_id in every record handed to create_predictions.
structured_preds = (
    preds
    .groupby("raw_id")[preds.columns]
    .apply(create_predictions)
)

In [9]:
structured_preds.sample(10, random_state=42)  # fixed seed: reproducible preview

raw_id
360     [{'start': 59, 'end': 64, 'text': 'Haiti', 'la...
804     [{'start': 87, 'end': 98, 'text': 'CicloneIdai...
691     [{'start': 38, 'end': 46, 'text': 'Chipinge', ...
728     [{'start': 57, 'end': 66, 'text': 'Sri Lanka',...
563     [{'start': 19, 'end': 32, 'text': 'Fort McMurr...
131     [{'start': 2, 'end': 17, 'text': 'HurricaneHar...
1357    [{'start': 165, 'end': 171, 'text': 'Grecia', ...
278     [{'start': 75, 'end': 83, 'text': 'Pakistan', ...
1625    [{'start': 59, 'end': 65, 'text': 'Greece', 'l...
1191    [{'start': 14, 'end': 22, 'text': 'Nebraska', ...
dtype: object

In [10]:
structured_preds.at[0]  # deduplicated, text-sorted entities for raw_id 0

[{'start': 60,
  'end': 65,
  'text': 'Texas',
  'label': 'disaster related location',
  'score': 0.9944333434104919,
  'raw_id': 0}]

In [11]:
structured_preds.at[912]  # predictions for raw_id 912

[{'start': 32,
  'end': 44,
  'text': 'Cyclone Idai',
  'label': 'disaster related location',
  'score': 0.09712756425142288,
  'raw_id': 912},
 {'start': 148,
  'end': 158,
  'text': 'Mozambique',
  'label': 'disaster related location',
  'score': 0.9948888421058655,
  'raw_id': 912}]

In [12]:
test_data.iloc[912].to_numpy()  # [text, gold location string] for row 912

array(['As the official death toll from Cyclone Idai rises sharply , the BBCs @ Pumza_Fihlani witnesses the trail of destruction that has left thousands in Mozambique stranded and homeless [ tap to expand ]',
       'Mozambique'], dtype=object)

In [13]:
# `test_data` was built from a list, so this only guarantees a 0..n-1 RangeIndex,
# matching the positional `raw_id`s attached to the predictions.
test_data = test_data.reset_index(drop=True)

# Number of whitespace-separated words in each gold location string.
n_position = test_data["location"].astype(str).str.split().str.len()

n_position.describe()

count    1645.000000
mean        1.316717
std         1.415931
min         0.000000
25%         0.000000
50%         1.000000
75%         2.000000
max        19.000000
Name: location, dtype: float64

In [14]:
test_data.loc[n_position > 5]  # examples whose gold string has more than five words

Unnamed: 0,text,location
2,# BREAKING : Death toll from Northern Californ...,Butte County Northern Californias U . S .
167,Our relief efforts continue to help those impa...,Nebraska Nebraska North Carolina to Omaha
169,# CycloneIdai yet another alarm bell about the...,Boston Chicago D . C . Malawi Mozambique NYC Z...
209,Why has the media only focused on Irmas damage...,Florida U . S . Virgin Islands
261,FLOOD WATCH REMAINS IN EFFECT THROUGH TUESDAY ...,Fayette Maryland Maryland Pennsylvania Pennsyl...
283,"# KodaguFloodRelief Velayudhan , along with hi...",Karnatakas Kerala Kodagu Kodagu Kushalnagar Ma...
358,"Kudumbasree needs 26750 buckets , mugs , mops ...",Alappuzha Ernakulam Kottayam Pathanamthitta Th...
378,"RT @ TheREALTMGIJane : Lyall Bay , Wellington ...",Lyall Bay New Zealand NewZealand Wellington
411,Kerala Chief Minister Pinarayi Vijayan will to...,Alappuzha Chalakudy Chengannur Kerala Kozhench...
464,# CycloneIdai : Marowanyati Dam in Murambinda ...,Chimanimani Chimanimani Chipinge Murambinda Mw...


In [15]:
test_data.loc[1024].to_numpy()  # [text, gold location string] for row 1024

array(['Current # GFS # hurricane models for # Dorian show it raking along the coast from Florida , Georgia , South Carolina , North Carolina . and Virginia . Depending on how offshore the eye sits , could be a quite a task to handle such a span of damage across multiple states .',
       'Florida Georgia North Carolina South Carolina Virginia'],
      dtype=object)

In [16]:
structured_preds.at[1024]  # predictions for raw_id 1024

[{'start': 39,
  'end': 45,
  'text': 'Dorian',
  'label': 'disaster related location',
  'score': 0.27948373556137085,
  'raw_id': 1024},
 {'start': 82,
  'end': 89,
  'text': 'Florida',
  'label': 'disaster related location',
  'score': 0.7103641629219055,
  'raw_id': 1024},
 {'start': 92,
  'end': 99,
  'text': 'Georgia',
  'label': 'disaster related location',
  'score': 0.700630784034729,
  'raw_id': 1024},
 {'start': 119,
  'end': 133,
  'text': 'North Carolina',
  'label': 'disaster related location',
  'score': 0.6223258972167969,
  'raw_id': 1024},
 {'start': 102,
  'end': 116,
  'text': 'South Carolina',
  'label': 'disaster related location',
  'score': 0.6113390326499939,
  'raw_id': 1024},
 {'start': 140,
  'end': 148,
  'text': 'Virginia',
  'label': 'disaster related location',
  'score': 0.7333858013153076,
  'raw_id': 1024}]

In [17]:
preds.sample(5, random_state=42)  # fixed seed: reproducible preview

Unnamed: 0,start,end,text,label,score,raw_id
1984,127,135,Nebraska,disaster related location,0.991057,1164
1965,54,68,Princess Nokia,disaster related location,0.067721,1150
527,38,49,Puerto Rico,disaster related location,0.976933,298
868,82,98,Hurricane Marias,disaster related location,0.242108,496
1883,101,110,Camp Fire,disaster related location,0.506868,1093


In [18]:
test_data.sample(5, random_state=42)  # fixed seed: reproducible preview

Unnamed: 0,text,location
1087,RT @ sos_children : SOS Childrens Villages has...,Ecuador
28,"In the meantime , the death toll from Northern...",California
1064,# thisweeksdonation is to globalgivings Mexico...,Mexico
1504,CANADA ALBERTA : # FortMacFire damage : oil fi...,ALBERTA CANADA
567,RT @ RachelNBanfield : Northlanders : Civil De...,Northland


In [19]:
structured_preds.sample(5, random_state=42)  # fixed seed: reproducible preview

raw_id
1069    [{'start': 16, 'end': 23, 'text': 'Chinese', '...
82      [{'start': 41, 'end': 55, 'text': 'Santa Fe Ri...
1396    [{'start': 10, 'end': 15, 'text': 'Haiti', 'la...
447     [{'start': 38, 'end': 43, 'text': 'Haiti', 'la...
144     [{'start': 58, 'end': 64, 'text': 'Athens', 'l...
dtype: object

In [20]:
# Row-aligned gold strings; "@" stands in for examples with no gold location.
# NOTE(review): presumably because the WER metric rejects empty references; confirm.
expected = [loc if loc else "@" for loc in test_data["location"]]
# Only texts with at least one kept prediction appear here, so this can be shorter.
prediction_list = structured_preds.tolist()

len(expected), len(prediction_list)

(1645, 1495)

In [21]:
sum(not item for item in expected)  # should be 0: empty references were replaced by "@"

0

In [22]:
from evaluate import load

wer = load("wer")


def extract_predictions(thr=0.5, structured=None, references=None):
    """Return the WER of thresholded predictions against the gold location strings.

    Despite its name, this scores predictions rather than returning them. For
    every text it:
    - keeps the entities scoring strictly above ``thr``;
    - orders them by surface text, matching the sorted order used to build the
      references;
    - joins them with spaces and compares the result with the reference string
      for the same row.

    Args:
        thr: score cut-off; an entity needs ``score > thr`` to be kept.
        structured: iterable of per-text entity lists, each entity a dict with at
            least "text", "score" and "raw_id". Defaults to the notebook-level
            ``prediction_list``.
        references: gold strings indexed by ``raw_id``. Defaults to the
            notebook-level ``expected``.

    Returns:
        Whatever ``wer.compute`` returns for the built hypotheses; the threshold
        search minimises this value.
    """
    if structured is None:
        structured = prediction_list
    if references is None:
        references = expected

    hypothesis_by_row = {}
    for entities in structured:
        # NOTE(review): strict ">" here vs ">=" in the earlier score filter; worth unifying.
        kept = sorted(
            (ent for ent in entities if ent["score"] > thr),
            key=lambda ent: ent["text"],
        )
        if kept:
            hypothesis_by_row[kept[0]["raw_id"]] = " ".join(ent["text"] for ent in kept)

    # Rows with nothing above the threshold get the same "@" placeholder used for
    # empty references, so an empty-vs-empty pair counts as a match.
    hypotheses = [hypothesis_by_row.get(row, "@") for row in range(len(references))]
    return wer.compute(predictions=hypotheses, references=references)

extract_predictions()

0.26189569851541683

In [23]:
# Grid-search the score cut-off that minimises WER, printing each new best.
# NOTE(review): the threshold is tuned on the same labelled split that is scored,
# so `best_score` is an optimistic estimate; confirm a held-out split isn't needed.
best_score, best_thr = float("inf"), None
for thr in np.linspace(0.1, 1, 50):
    score = extract_predictions(thr)
    if score < best_score:
        best_score, best_thr = score, thr
        print(score, thr)

best_score, best_thr

0.4815378759040731 0.1
0.45260753711457935 0.11836734693877551
0.42177388656261894 0.13673469387755102
0.39665017129805863 0.15510204081632656
0.3844689760182718 0.17346938775510207
0.37381043014845833 0.19183673469387758
0.3490673772363913 0.21020408163265308
0.3338408831366578 0.2285714285714286
0.3193757137419109 0.2469387755102041
0.30376855728968405 0.2653061224489796
0.2953939855348306 0.2836734693877551
0.29425199847735056 0.3020408163265306
0.2835934526075371 0.3204081632653062
0.2805481537875904 0.3387755102040817
0.27864484202512374 0.3571428571428572
0.2775028549676437 0.3755102040816327
0.273315569090217 0.3938775510204082
0.2676056338028169 0.41224489795918373
0.26532165968785687 0.43061224489795924
0.2626570232204035 0.44897959183673475
0.26189569851541683 0.5040816326530613


0.15569823434991975 0.5408163265306123

0.18576322801674913 0.5224489795918368