### Automated data flagging

Copyright &copy; 2024 Praneeth Vadlapati

In [1]:
import io
import os
import time
from urllib.parse import urlparse

from datasets import load_dataset
import pandas as pd
import replicate
# import requests
from tavily import TavilyClient

for key in ['TAVILY_API_KEY', 'GROQ_API_KEY']:
	if not os.getenv(key):
		raise Exception(f'{key} not found in environment or .env file')

from common_functions import get_filename, dataset_name, latest_dump_name, \
				flags_list, safe_flag, harm_categories, \
				print_progress, print_error, get_bot_response, \
				is_not_na, print_modelname

# Whether to process existing file or fetch data into a new file
PROCESS_EXISTING_FILE = True

### Collect the data

In [2]:
new_data_filename = None
skip_index = 0
last_existing_file_index = -1
for index in range(1000):
	new_data_filename = get_filename(index)
	if os.path.exists(new_data_filename) and os.stat(new_data_filename).st_size > 0:
		last_existing_file_index = index
		skip_index += len(pd.read_csv(new_data_filename))
	else:
		flagged_data_filename = get_filename(index, 'flagged')
		filtered_data_filename = get_filename(index, 'filtered')
		short_text_filename = get_filename(index, 'shortened')
		break  # found file that doesn't exist or is empty

if PROCESS_EXISTING_FILE and last_existing_file_index != -1:
	index = last_existing_file_index
	new_data_filename = get_filename(index)
	flagged_data_filename = get_filename(index, 'flagged')
	filtered_data_filename = get_filename(index, 'filtered')
	short_text_filename = get_filename(index, 'shortened')
	full_df = pd.read_csv(new_data_filename)
else:  # if the file exists, load it
	start_time = time.time()
	dataset = load_dataset(dataset_name, name=latest_dump_name, split='train', streaming=True)
	if skip_index:
		dataset = dataset.skip(skip_index)

	limit = 100
	data = []
	curr_size = 0
	for i, item in enumerate(dataset):
		# print(i)
		if curr_size >= limit:
			break
		data.append(item)
		print_progress()
		curr_size += 1

	time_min, time_sec = divmod(time.time() - start_time, 60)
	print(f'Time taken: {time_min:.0f} min {time_sec:.0f} sec')

	full_df = pd.DataFrame(data)
	full_df = full_df[full_df['language'] == 'en']  # drop rows with language not 'en'
	if not full_df.empty:
		full_df.to_csv(new_data_filename, index=False)

print('New data file:', os.path.basename(new_data_filename))
print('Flagged data file:', os.path.basename(flagged_data_filename))
full_df.head(2)

New data file: New_data - 0.full.csv
Flagged data file: New_data - 0.flagged.csv


Unnamed: 0,text,id,url
0,We want to know how to best serve you. Please ...,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback
1,Architectural Control Committee Policies and F...,<urn:uuid:77695799-0774-42a1-8eaa-5efbe154c4e0>,http://aberdeencreekfl.com/ACCBusiness/Procedu...


### Flagging the data:

#### Essential functions

In [3]:
# Emulating LlamaGuard using LLM
categories_str = '\n'.join([f'{code}: {category_label}' for code, category_label in harm_categories.items()])
categories_codes = list(harm_categories.keys())

def check_unsafe_LLM(text, user_prompt='', max_retries=3):
	instruction_prompt = (
		'Classify the following text based on predefined categories related to content safety. \n'
		'Respond with `safe` if the text does not fall into any unsafe categories. \n'
		'If the text is unsafe, respond with `unsafe` followed by the relevant category code(s) like S1.\n\n'
		f'Categories:\n{categories_str}\n\n'
		f'Text to classify:\n{text}\n{user_prompt} \n'
		f'Respond in `backticks` like `safe`, `unsafe\n{categories_codes[0]}`. \n'
	)
	messages = [{'role': 'user', 'content': instruction_prompt}]

	for _ in range(max_retries):
		try:
			response = get_bot_response(messages, process_backticks=True)
			if 'unsafe' in response:
				response = response.replace('unsafe', '').strip()  # Extract code if unsafe
				return response or True
			else:
				if not response:
					raise Exception('No response')
				return safe_flag
		except Exception as e:
			print_error(e)
			return None

In [4]:
replicate_not_working = True

def is_unsafe(bot_response, user_prompt='', max_retries=3):
	'Ask LlamaGuard if a text is safe'
	global replicate_not_working
	if replicate_not_working:
		check_unsafe_LLM(bot_response, user_prompt, max_retries=max_retries)
	for _ in range(max_retries):
		try:
			if not len(bot_response):
				return True

			output = replicate.run(
				'meta/meta-llama-guard-2-8b:b063023ee937f28e922982abdbf97b041ffe34ad3b35a53d33e1d74bb19b36c4',
				input={ 'prompt': user_prompt, 'assistant': bot_response },
			)
			if isinstance(output, str):
				if output == safe_flag:
					return False
				else:
					output = output.replace('unsafe', '').strip()
					return output  # returns reason like 'S5', 'S6', etc.
			else:
				raise Exception('Invalid output')
		except Exception as e:
			replicate_not_working = True
			print(f'Error: {e}. Retrying')

flagged_df = None

# if the file exists, load it
if os.path.exists(flagged_data_filename):
	flagged_df = pd.read_csv(flagged_data_filename)


if flagged_df is None or flagged_df.empty:
	# if flagged_df is None or flagged_df.empty:
	flagged_df = full_df.copy()
	flagged_df = flagged_df.dropna(subset=['text'])  # Filter out rows with empty text
	# columns_to_keep = ['id', 'url', 'date', 'text', 
	# 					'text_unsafe', 'domain_unsafe',
	# 					'domain_unindexed', 'flags']
	# # Keep only the columns in `columns_to_keep` that are also in `flagged_df`
	# available_columns = [col for col in columns_to_keep if col in flagged_df.columns]
	# flagged_df = flagged_df[available_columns]


def save_flagged_df(df=flagged_df):
	df.to_csv(flagged_data_filename, index=False)


tavily = TavilyClient(api_key=os.getenv('TAVILY_API_KEY'))

def is_indexed(domain, max_retries=3):
	'Search online to check if a domain is indexed'
	for _ in range(max_retries):
		try:
			result = tavily.search(query=f'site:{domain}')
			if not result or 'results' not in result:
				raise Exception('No results')
			indexed = len(result['results']) > 0
			if indexed is not None:
				return True if indexed else False
		except Exception as e:
			print(f'Error: {e}. Retrying')

#### Flagging unsafe text

(Typically takes 8 minutes for 100 rows)

In [5]:
def flag_unsafe_text(flagged_df):
	if 'text_unsafe' not in flagged_df.columns:
		flagged_df['text_unsafe'] = None

	# Identify rows with empty 'text_unsafe' value
	indices = flagged_df[flagged_df['text_unsafe'].isna()].index
	for i in indices:
		flagged_df.loc[i, 'text_unsafe'] = is_unsafe(flagged_df.loc[i, 'text']) or safe_flag
		print_progress()

	if not indices.empty:
		save_flagged_df()

start_time = time.time()
flag_unsafe_text(flagged_df)
print('Time taken:', time.time() - start_time)
print(f'Flagged data size: {flagged_df.shape}')
flagged_df.head(2)

Time taken: 0.0005154609680175781
Flagged data size: (100, 9)


Unnamed: 0,id,url,text,text_unsafe,domain_unsafe,domain_unindexed,flags,flag_reason,flags_singlerow
0,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback,We want to know how to best serve you. Please ...,safe,,,"sensitive_topic,unusable","Report discrimination, harassment, and sexual ...",safe
1,<urn:uuid:77695799-0774-42a1-8eaa-5efbe154c4e0>,http://aberdeencreekfl.com/ACCBusiness/Procedu...,Architectural Control Committee Policies and F...,safe,True,,safe,Useful knowledge,safe


#### Flagging unsafe domains

(Typically takes 5 minutes for 100 rows)

In [6]:
# use tavily api to find whether google indexes the page or not

def get_main_domain(domain):
	'Sometimes, only subdomains are provided. This function returns the main domain.'
	'Example: news.example.com -> example.com, www.example.co.uk -> example.co.uk, www.example.ac.ir -> example.ac.ir'
	parts = domain.split('.')
	# Handle cases with ccTLDs or multi-part TLDs (like co.uk, ac.ir)
	if len(parts) >= 3 and (parts[-2] in ['co', 'ac'] or len(parts[-1]) == 2):
		return '.'.join(parts[-3:])
	else:
		return '.'.join(parts[-2:])


def flag_unsafe_domains(flagged_df):
	'Flag domains that are unsafe or unindexed, as unsafe domains ' \
    'are less likely to be indexed by a search engine'
	urls = flagged_df['url'].dropna().unique()
	domains = [urlparse(url).netloc for url in urls]
	# get main domains from subdomains
	domains = [get_main_domain(domain) for domain in domains]
	domains = list(set(domains))  # remove duplicates

	if 'domain_unsafe' not in flagged_df.columns:
		flagged_df['domain_unsafe'] = None

		unsafe_domains = set()
		for domain in domains:
			if is_unsafe(domain):
				unsafe_domains.add(domain)
				print_progress()
		print(f'Unsafe domains: {unsafe_domains}')

		for i, row in flagged_df.iterrows():
			if row['domain_unsafe'] is None:
				domain = urlparse(row['url']).netloc
				flagged_df.loc[i, 'domain_unsafe'] = domain in unsafe_domains

		save_flagged_df()

	if 'domain_unindexed' not in flagged_df.columns:
		flagged_df['domain_unindexed'] = None

		unindexed_domains = set()
		for domain in domains:
			if not is_indexed(domain):
				unindexed_domains.add(domain)
				print_progress()
		print(f'Unindexed domains: {unindexed_domains}')

		for i, row in flagged_df.iterrows():
			if row['domain_unindexed'] is None:
				domain = urlparse(row['url']).netloc
				flagged_df.loc[i, 'domain_unindexed'] = domain in unindexed_domains

		save_flagged_df()


start_time = time.time()
flag_unsafe_domains(flagged_df)
print('Time taken:', time.time() - start_time)
print(f'Flagged data size: {flagged_df.shape}')
flagged_df.head(2)

Time taken: 0.0010066032409667969
Flagged data size: (100, 9)


Unnamed: 0,id,url,text,text_unsafe,domain_unsafe,domain_unindexed,flags,flag_reason,flags_singlerow
0,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback,We want to know how to best serve you. Please ...,safe,,,"sensitive_topic,unusable","Report discrimination, harassment, and sexual ...",safe
1,<urn:uuid:77695799-0774-42a1-8eaa-5efbe154c4e0>,http://aberdeencreekfl.com/ACCBusiness/Procedu...,Architectural Control Committee Policies and F...,safe,True,,safe,Useful knowledge,safe


#### Flagging unwanted rows like Sensitive Topics and Biased data

(Typically takes 5-10 minutes for 100 rows)

In [7]:
flags_to_detect = ', '.join(flags_list[:-1])
flag_df_columns = ['id', 'flags', 'flag_reason']
flag_df_columns_str = ','.join(flag_df_columns)
unusable_df_columns = ['id', 'unusable_flag', 'unusable_flag_reason']
unusable_df_columns_str = ','.join(unusable_df_columns)

def flag_chunk_using_LLM(df, indices_list, max_retries=3):
	csv_text = df[['id', 'text']].iloc[indices_list].to_csv(index=False)
	prompt = (
		'You are a content moderator. The text below will be used to fine-tune LLMs. '
		f'Fill `flags` column that contains one or more of flags to detect: `{flags_to_detect}`. \n'
		'If you flag a row, fill `flag_reason` column with a very short reason for flag choices. \n'
		f'Return only error-free CSV text back in triple backticks and no other text, like '
		f'```csv\n{flag_df_columns_str}\n<urn:uuid:aaaa>,"safe","safe"\n'
  		'<urn:uuid:aaab>,"scam,spam","Suggests potential manipulation, and crimes"\n``` \n'
		f'Input data: ```csv\n{csv_text}\n```. \n '
		f'Columns: `{flag_df_columns_str}`'
	)
	flag_result_df = None
	for _ in range(max_retries):
		try:
			response = get_bot_response(messages=[{ 'role': 'user', 'content': prompt }],
										process_backticks=True)
			df = pd.read_csv(io.StringIO(response))
			df = df[flag_df_columns]
			if df.empty:
				flag_result_df = None  # retry
			else:
				flag_result_df = df
				break
		except Exception as e:
			# print('Exception in flagging chunk using LLM:', e)
			print_error(e)

	unusable_flagging_prompt = (
		'You are a content moderator. The text below will be used to fine-tune LLMs. '
		'Use "unusable" as the flag if text does not convey new information/knowledge, or mark as "safe". \n'
		'Return csv text back in triple backticks. \n '
		f'Output columns (to be strictly followed): `{unusable_df_columns_str}` \n'
		f'Return only csv response and no other text like ```csv\n{unusable_df_columns_str}\n'
  		'<urn:uuid:aaaa>,"unusable","No useful/new info"\n<urn:uuid:aaab>,"safe","Useful knowledge"\n``` \n'
		f'Input data: ```\n{csv_text}\n```. \n '
		f'Columns: `{unusable_df_columns_str}`'
	)
	unusable_flag_df = None
	for _ in range(max_retries):
		try:
			response = get_bot_response(messages=[{ 'role': 'user', 'content': unusable_flagging_prompt }],
										process_backticks=True)
			df = pd.read_csv(io.StringIO(response))
			df = df[unusable_df_columns]
			if df.empty:
				unusable_flag_df = None
			else:
				unusable_flag_df = df
				break
		except Exception as e:
			print_error(e)
			# print('Exception in unusable-flagging', e)

	# Merge the two dataframes
	if flag_result_df is not None and unusable_flag_df is not None:
		df_combined = pd.merge(flag_result_df, unusable_flag_df, on='id', how='left')
		# Combine unusable_flag value using a comma
		df_combined['flags'] = df_combined.apply(
			lambda row: f"{row['flags']},{row['unusable_flag']}"
				if is_not_na(row['flags']) and is_not_na(row['unusable_flag'])
				else row['unusable_flag'] if is_not_na(row['unusable_flag'])
				else row['flags'] if is_not_na(row['flags'])
				else None,
			axis=1
		)
		df_combined['flag_reason'] = df_combined.apply(
			lambda row: f"{row['flag_reason']}. {row['unusable_flag_reason']}"
				if is_not_na(row['flag_reason']) and is_not_na(row['unusable_flag_reason'])
				else row['flag_reason'] if is_not_na(row['flag_reason'])
				else row['unusable_flag_reason'] if is_not_na(row['unusable_flag_reason'])
				else None,
			axis=1
		)
	elif flag_result_df is not None:
		df_combined = flag_result_df
	elif unusable_flag_df is not None:
		df_combined = unusable_flag_df.rename({'unusable_flag': 'flags'})
		df_combined['flag_reason'] = None
	else:
		raise Exception('Bot response failed for both undesirable flagging and unusable flagging')

	return df_combined[flag_df_columns]


def flag_with_LLM(flagged_df):
	if 'flags' not in flagged_df.columns:
		flagged_df['flags'] = None
	if 'flag_reason' not in flagged_df.columns:
		flagged_df['flag_reason'] = None

	# split indices into chunks
	chunk_size = 3
	noflags_indices = flagged_df[flagged_df['flags'].isna() |
							  flagged_df['flag_reason'].isna()].index
	chunks = [noflags_indices[i:i + chunk_size] for i in range(0, len(noflags_indices), chunk_size)]
	print(f'Total chunks: {len(chunks)}')

	for chunk in chunks:
		try:
			new_flags_df = flag_chunk_using_LLM(flagged_df, indices_list=chunk)
			new_flags_df['flags'] = new_flags_df['flags'].fillna(safe_flag)
			# for each id in flags_df, update the flags column in flagged_df
			# for all listed indices, mark as safe
			for i in chunk:
				flagged_df.loc[i, 'flags'] = safe_flag
			for i, row in new_flags_df.iterrows():
				flagged_df.loc[flagged_df['id'] == row['id'], 'flags'] = row['flags']
				flagged_df.loc[flagged_df['id'] == row['id'], 'flag_reason'] = row['flag_reason']
			print_progress()
			save_flagged_df()
		except Exception as e:
			print('Exception in flag with LLM:', e)
			print_error(e)

	# Experimenting by flagging only 1 row at a time
	# if 'flags_singlerow' not in flagged_df.columns:
	# 	flagged_df['flags_singlerow'] = None
	# 	print('Created column `flags_singlerow`')
	# print('Processing single row flags for:', flagged_df['flags_singlerow'].isna().sum())
	# for i, row in flagged_df.iterrows():
	# 	try:
	# 		if pd.isna(row['flags_singlerow']):
	# 			new_flags_df = flag_chunk_using_LLM(flagged_df, indices_list=[i])
	# 			if not new_flags_df.empty:
	# 				new_flags_df['flags'] = new_flags_df['flags'].fillna(safe_flag)
	# 				flagged_df.loc[i, 'flags_singlerow'] = new_flags_df.loc[0, 'flags']
	# 			print_progress()
	# 			save_flagged_df()
	# 	except Exception as e:
	# 		print_error(e)

print_modelname()
start_time = time.time()
flag_with_LLM(flagged_df)
print('Time taken:', time.time() - start_time)
flagged_df.head(2)

Model: None
Total chunks: 0
Time taken: 0.0007460117340087891


Unnamed: 0,id,url,text,text_unsafe,domain_unsafe,domain_unindexed,flags,flag_reason,flags_singlerow
0,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback,We want to know how to best serve you. Please ...,safe,,,"sensitive_topic,unusable","Report discrimination, harassment, and sexual ...",safe
1,<urn:uuid:77695799-0774-42a1-8eaa-5efbe154c4e0>,http://aberdeencreekfl.com/ACCBusiness/Procedu...,Architectural Control Committee Policies and F...,safe,True,,safe,Useful knowledge,safe
