# Lockstep, the full recipe, Part 2: Training

In [1]:
import pandas as pd
import gc
from pathlib import Path
import os
from collections import defaultdict
from fastparquet import ParquetFile

verbosity = 5

twibot_path = r"/dataset/twibot22"
twibot_user = r"/dataset/twibot22/user.json"
twibot_label = r"/dataset/twibot22/label.csv"
twibot_graph_file = f"{twibot_path}/edge.csv"

generated_data_output = r"/dataset/twibot22/generated_data" # output is saved in this directory
ls_selected_indices_output = os.path.join(generated_data_output, "selected.csv")
ls_userdata_output_parquet = "assembled_user_details.parquet"
url_model_output = os.path.join(generated_data_output,"url_model.pkl")
url_model_prediction_output = os.path.join(generated_data_output,"url_predict.csv")


gnn_output_folder = os.path.join(generated_data_output,"gs_output")
gnn_output_dataset_name = os.path.join(gnn_output_folder,"gnn_graph_data.pt")
gnn_output_checkpoint_name = os.path.join(gnn_output_folder,"graphsage_model.ckpt")
gnn_output_relmap_name = os.path.join(gnn_output_folder,"relationship_map.pkl")
gnn_output_map_name = os.path.join(gnn_output_folder,"index_map.json")
gnn_output_predictions_name =  os.path.join(gnn_output_folder,"gs_predictions.csv")
gnn_output_model_path = os.path.join(gnn_output_folder,"graphsage_final.pt")

# Files in the path specified by twibot_path, that begin with %twibot_node_identifier_str%, will be assumed as node files and converted if needed.
twibot_node_identifier_str = "tweet_" 
NODE_FILE_LIST = list(filter(lambda fileName: twibot_node_identifier_str in fileName, 
                                        [child.name for child in Path(generated_data_output).iterdir()]))



concurrent_max_workers = 2

sample_set_constrain_minimum_posts = 25
sample_set_save_selected_indices = True
sample_set_size_per_label = 50000 # per label, sample this many users
sample_set_stratification = True # If at any point during selection our set becomes unbalanced, should we stratify?
sample_set_randomization = True # Should the _shuffle method be run on rows before the data is used on anything more?
sample_set_sampling_strategy = (1,0) # (1over/0under, 1major/0minor)

gnn_dataset_build_num_pass = 3
gnn_doing_experiment = False
gnn_data_load_from_file = True
gnn_load_from_checkpoint = True
gnn_model_type = "graphsage"  # Options: "dropedge_gcn", "gcn", or "graphsage"

scores = {}

if not Path(gnn_output_folder).exists:
    os.mkdir(gnn_output_folder)
    
def debug_print(m, level=5, r=None):
    if level <= verbosity:
        print(m)
        if r:
            raise r
    
def is_data(name, _dir=generated_data_output):
    file_path = os.path.join(_dir, f"{name}.parquet")
    return os.path.exists(file_path)
    
def get_data(name, _dir=generated_data_output, pqargs={}, **kwargs):
    if is_data(name, _dir):
        file_path = os.path.join(_dir, f"{name}.parquet")
        print(f"Loading existing data from {file_path}")
        #return pd.read_parquet(file_path)
        pf = ParquetFile(file_path, **pqargs)
        return pf.to_pandas(**kwargs)
    return False
        
def save_data(name, _dir=generated_data_output, df=None, **kwargs):
    if df is None:
            raise ValueError("No dataframe provided to save.")
    file_path = os.path.join(_dir, f"{name}.parquet")
    print(f"Saving data to {file_path}")
    os.makedirs(_dir, exist_ok=True)  # Ensure the directory exists
    df.to_parquet(file_path, **kwargs)
    return df      
    
def _shuffle(df):
    return df.sample(frac = 1)
    
shuffle_method = _shuffle

# To quietly stop cell execution
class StopExecution(Exception):
    def _render_traceback_(self):
        return []


import json
def get_post_counts():
    tweetNodeFilesParquet = list(filter(lambda fileName: twibot_node_identifier_str in fileName, 
                                        [child.name for child in Path(generated_data_output).iterdir()]))
    post_count_dict = defaultdict(int)
    debug_print(f"Called: get_post_counts", 5)
    for targetFile in tweetNodeFilesParquet:
        targetInput = Path(f"{generated_data_output}/{targetFile}")
        try:
            debug_print("Looking in " + targetInput.__str__(), 5)
            pf = ParquetFile(targetInput)
            df = pf.to_pandas(columns=['author_id'])
            for uid in df['author_id']:
                post_count_dict[uid] = post_count_dict[uid] + 1
            del pf, df
            gc.collect()
        except Exception as e:
            debug_print(f"Failed to load node parquet: {e}", 5)
            raise RuntimeError("Error processing Parquet files.")
    debug_print(f"Completed: get_post_counts", 5)
    return post_count_dict
    
def get_post_chunks(cols = '*', index="author_id", pqargs={}, pdkwargs={}, margs={}):
    # Result: Dataframe, index with one or more features.
    # Index: from users
    global NODE_FILE_LIST
    result_builder = None
    for targetFile in NODE_FILE_LIST:
        targetInput = Path(f"{generated_data_output}/{targetFile}")
        debug_print(f"Extracting from {targetInput.__str__()}...", 5)
        
        if cols != '*':
            # Set the columns to pull from the parquet, either through pqargs directly or here, through cols
            pdkwargs['columns'] = cols
            
        pdkwargs['index'] = index       
        try:
            pfinput = ParquetFile(targetInput, **pqargs)    
            process_group = pfinput.to_pandas(**pdkwargs)  
            result_builder = pd.concat([result_builder, process_group])         
        except Exception as e:
            debug_print(f"Failed to load node parquet: {e}", 5)
            raise RuntimeError("Error processing Parquet files.")
    return result_builder


# Lockstep, the full recipe, Part 2: Training

## Stage 1: Sample Selection, Constraint, Stratification

In [2]:
import numpy as np
from sklearn.utils import resample

def stratify_samples(df, label_col='label', target_col='sampled_post_count', strategy='oversample'):
    """
    Stratify samples to handle class imbalances.

    Parameters:
        df (pd.DataFrame): Input data with labels and target column.
        label_col (str): Column name for class labels.
        target_col (str): Column name for target data (if needed).
        strategy (str): 'oversample' or 'undersample'. Default is 'oversample'.

    Returns:
        pd.DataFrame: Stratified data.
        
    """

    bots = df[df[label_col] == 'bot']
    humans = df[df[label_col] == 'human']

    if strategy == 'oversample':
        # Oversample the minority class
        if len(bots) < len(humans):
            bots = resample(bots, replace=True, n_samples=len(humans), random_state=42)
        else:
            humans = resample(humans, replace=True, n_samples=len(bots), random_state=42)
    elif strategy == 'undersample':
        # Undersample the majority class
        if len(bots) < len(humans):
            humans = resample(humans, replace=False, n_samples=len(bots), random_state=42)
        else:
            bots = resample(bots, replace=False, n_samples=len(humans), random_state=42)
    else:
        raise ValueError("Strategy must be either 'oversample' or 'undersample'")

    # Combine and shuffle
    stratified_df = pd.concat([bots, humans]).sample(frac=1, random_state=42)
    stratified_df = stratified_df.loc[~stratified_df.index.duplicated()]
    return stratified_df

def limit_samples_per_label(df, label_col='label', sample_set_size_per_label=1000):
    """
    Limit the number of samples per label in the DataFrame.

    Parameters:
        df (pd.DataFrame): Input data with labels.
        label_col (str): Column name for class labels.
        sample_set_size_per_label (int): Maximum number of samples per label.

    Returns:
        pd.DataFrame: DataFrame with limited samples per label.
    """
    limited_dfs = []
    for label in df[label_col].unique():
        label_df = df[df[label_col] == label]
        limited_label_df = label_df.sample(n=min(sample_set_size_per_label, len(label_df)), random_state=42)
        limited_dfs.append(limited_label_df)

    limited_df = pd.concat(limited_dfs).sample(frac=1, random_state=42)
    return limited_df.loc[~limited_df.index.duplicated()]



def mask_selected_users_by_sampled_post_count(df, count):
    return pd.Series(user_detail_data['sampled_post_count'] >= count)


def _label_counts(df):
    return (df.loc[df['label'] == 'bot'].shape[0], df.loc[df['label'] == 'human'].shape[0])
    
debug_print(f"Loading data from parquet at {generated_data_output}",1)
user_detail_data = get_data("assembled_user_details", index='id')

debug_print(f"Loaded parquet. Column names: {user_detail_data.dtypes}",5)
user_detail_data['sampled_post_count'] = user_detail_data['sampled_post_count'].fillna(0).astype('int32')
max_sampled_posts = user_detail_data.loc[user_detail_data['sampled_post_count'] > 0, 'sampled_post_count'].max()
min_sampled_posts = user_detail_data.loc[user_detail_data['sampled_post_count'] > 0, 'sampled_post_count'].min()

debug_print(f"Min/Max sampled posts on a user: {min_sampled_posts}, ({max_sampled_posts})", 3)

display(user_detail_data.head(2))
# ls_selected_indices_output = rf"{generated_data_output}/selected.csv"
# If we have a saved file specifying the indices of our selected samples, we can just load those and skip the following sections e.g. we are continuing an experiment
# If not, continue and then depending on the sample_set_save_selected_indices flag, save the indices for the next run.

if os.path.exists(ls_selected_indices_output):
    debug_print(f"Found saved selected indices at {ls_selected_indices_output}. Loading...", 1)
    selected_indices = pd.read_csv(ls_selected_indices_output)
    selected_user_data = user_detail_data.loc[selected_indices['selected_indices']]
    debug_print(f"Loaded {len(selected_indices)} selected samples from saved file.", 5)
else:
    debug_print(f"No saved indices found. Proceeding with filtering and selection.", 1)
    
    # Randomization
    if sample_set_randomization:
        user_detail_data = _shuffle(user_detail_data)
        print("Sample rows randomized.")

    bot_users_count, human_users_count = _label_counts(user_detail_data)
    debug_print(f"{user_detail_data.shape[0]} samples available. {bot_users_count} are bots, and {human_users_count} are humans (labeled).", 5)

    # Filter out samples to users that have at least X number of sampled posts in the post data.
    constrained_user_data_mask = mask_selected_users_by_sampled_post_count(user_detail_data, sample_set_constrain_minimum_posts)
    selected_user_data = user_detail_data.loc[constrained_user_data_mask] if constrained_user_data_mask is not None else user_detail_data

    selected_bot_users_count = selected_user_data.loc[selected_user_data['label'] == 'bot'].shape[0]
    selected_human_users_count = selected_user_data.loc[selected_user_data['label'] == 'human'].shape[0]
    debug_print(f"Of these, {selected_user_data.shape[0]} meet sampled minimum post constraints. {selected_bot_users_count} bots meet this constraint, and {selected_human_users_count} humans meet this constraint.", 5)

    # Calculate balance
    balance = (selected_bot_users_count / float(selected_human_users_count)) - (selected_human_users_count / float(selected_bot_users_count))
    skew = abs(balance * 100)
    debug_print(f"The balance between human and bot users is skewed by about ~{skew}% on the side of {'humans' if balance < 0 else 'bots'}", 3)

    # Stratification
    if sample_set_stratification:
        selected_user_data = stratify_samples(
            selected_user_data,
            label_col='label',
            target_col='sampled_post_count',
            strategy='undersample'  # Change to 'oversample' if preferred
        )
        stratified_bot_count, stratified_human_count = _label_counts(selected_user_data)
        debug_print(f"Post-stratification: {selected_user_data.shape[0]} samples available. {stratified_bot_count} bots, {stratified_human_count} humans.", 5)
        balance_after = (stratified_bot_count / float(stratified_human_count)) - (stratified_human_count / float(stratified_bot_count))
        skew_after = abs(balance_after * 100)
        debug_print(f"Post-stratification skew is ~{skew_after}%.", 3)

    # Truncate data to limit samples per label
    bot_users_count, human_users_count = _label_counts(selected_user_data)
    sample_set_size_per_label = min(sample_set_size_per_label, min(bot_users_count, human_users_count))
    debug_print(f"Truncating data to sample_set_size_per_label: {sample_set_size_per_label}", 1)
    selected_user_data = limit_samples_per_label(
        selected_user_data,
        label_col='label',
        sample_set_size_per_label=sample_set_size_per_label
    )
    debug_print(f"After truncation: {selected_user_data.shape[0]} samples available.", 5)
    limited_bot_count, limited_human_count = _label_counts(selected_user_data)
    debug_print(f"Truncated data contains {limited_bot_count} bots and {limited_human_count} humans.", 5)

    # Save the selected indices if the flag is set
    if sample_set_save_selected_indices:
        debug_print(f"Saving selected indices to {ls_selected_indices_output}", 1)
        selected_indices = pd.DataFrame({'selected_indices': selected_user_data.index})
        selected_indices.to_csv(ls_selected_indices_output, index=False)
        debug_print(f"Saved {len(selected_user_data)} selected indices to {ls_selected_indices_output}.", 5)

Loading data from parquet at /dataset/twibot22/generated_data
Loading existing data from /dataset/twibot22/generated_data/assembled_user_details.parquet
Loaded parquet. Column names: created_at                           datetime64[ns, UTC]
description                                       object
location                                          object
name                                              object
url                                               object
username                                          object
label                                             object
followers_count                                    int64
following_count                                    int64
tweet_count                                        int64
listed_count                                       int64
url.urls                                          object
description.urls                                  object
description.mentions                              object
description.hashtag

Unnamed: 0_level_0,created_at,description,location,name,url,username,label,followers_count,following_count,tweet_count,...,tweet_urls_total,tweet_hashtags_total,avg_hashtags_in_tweet,avg_urls_in_tweet,tweet_urls_top_x,tweet_hashtags_top_x,tweet_has_hashtag_weekday_entropy,tweet_has_hashtag_hour_entropy,tweet_has_url_weekday_entropy,tweet_has_url_hour_entropy
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1217628182611927040,2020-01-16 02:02:55+00:00,Theoretical Computer Scientist. See also https...,"Cambridge, MA",Boaz Barak,https://t.co/BoMip9FF17,boazbaraktcs,human,7316,215,3098,...,39940.0,27598.0,24.208772,35.035088,"{'twitter.com': 110, 'horoscoponegro.com': 29,...","{'Aries': 323, 'ARIES': 286, 'aries': 3, 'Part...",,,,
2664730894,2014-07-02 17:56:46+00:00,creative _,🎈,olawale 💨,,wale_io,human,123,1090,1823,...,5154.0,4005.0,4.009009,5.159159,"{'twitter.com': 5, 'xkcd.com': 1, 'swag.github...",{},,,,


Found saved selected indices at /dataset/twibot22/generated_data/selected.csv. Loading...
Loaded 100000 selected samples from saved file.


## Step 2: Convert edge file into relationship dict

In [3]:
import pickle

ENTITY_USER = 0
ENTITY_GROUP = 1
ENTITY_POST = 2
entity_types = { 'u' : ENTITY_USER, 'l': ENTITY_GROUP, 't': ENTITY_POST, 'user': ENTITY_USER, 'group':ENTITY_GROUP, 'post': ENTITY_POST}

def strip_id(x):
    prefix_map = {'u': 'u', 'l': 'l', 't': 't'}
    return int(x[1:]) if x and x[0] in prefix_map else -1

def entity_type_from_id(any_str_id):
    return entity_types.get(any_str_id[0], -1)
    
REL_TYPE_DISC = 0
REL_TYPE_FOLLWD = 1
REL_TYPE_REPLD = 2
REL_TYPE_RTD = 3
REL_TYPE_POSTED = 4
REL_TYPE_PINNED = 5
REL_TYPE_MEMB = 6
REL_TYPE_QUOTED = 7
REL_TYPE_CONTAINS = 8
REL_TYPE_FOLLWG = 9
REL_TYPE_MENT = 10
REL_TYPE_FOLLWR = 11
REL_TYPE_LIKED = 12
REL_TYPE_OWN = 13

relationship_types = {
                      'discuss':REL_TYPE_DISC,
                      'followed':REL_TYPE_FOLLWD,
                      'replied_to':REL_TYPE_REPLD,
                      'retweeted':REL_TYPE_RTD,
                      'post':REL_TYPE_POSTED,
                      'pinned':REL_TYPE_PINNED,
                      'membership':REL_TYPE_MEMB,
                      'quoted':REL_TYPE_QUOTED,
                      'contain':REL_TYPE_CONTAINS,
                      'following':REL_TYPE_FOLLWG,
                      'mentioned':REL_TYPE_MENT,
                      'followers':REL_TYPE_FOLLWR,
                      'like':REL_TYPE_LIKED,
                      'own':REL_TYPE_OWN
                    }


def dd():
    return defaultdict(set)
    
def convert_edge_file_to_rel_map(relationship_map_file='relationship_map.pkl'):
    
    """
    Convert edge file to relationship_map
    If the relationship_map file exists, load it. Otherwise, parse the edge parquet and save the result.
    Args:
        relationship_map_file (str): Path to the file to save/load the relationship map.

    Returns:
        dict: The relationship map.
    """

    global selected_user_data
    # Check if the relationship map file exists
    if os.path.exists(relationship_map_file):
        print(f"Loading relationship map from {relationship_map_file}...")
        with open(relationship_map_file, 'rb') as f:
            relationship_map = pickle.load(f)
        print(f"Loaded relationship map with {len(relationship_map)} nodes.")
        return relationship_map

    print("Relationship map file not found. Parsing edge parquet to create new map.")
    
    relationship_map = defaultdict(dd)
    next_pass_targets = set('u'+str(x) for x in selected_user_data.index.values) 
    print(f"Total passes scheduled: {gnn_dataset_build_num_pass} on {len(next_pass_targets)} initial target entities.")
    relevant_data = get_data("edges")

    try:
        for i in range(gnn_dataset_build_num_pass):
            print(f"Pass {i}...")

            sub_pass = [ relevant_data.loc[relevant_data['id1'].isin(next_pass_targets)] ]
            next_pass_targets.clear()            
            
            # src, dest
            def graph_dir_pass(rw, k1, k2):                 
                relationship_map[rw[k1]][str(relationship_types[rw['relationship']])].add(rw[k2])
                next_pass_targets.add(rw[k2])
                
            sub_pass[0].apply(graph_dir_pass, args=('id1', 'id2'), axis=1)             
            print(f"Entities collected: {len(next_pass_targets)}")
            
        edge_counter = sum(len(v) for v in relationship_map.values())
        print("All passes completed. Nodes collected:", len(relationship_map), "Edges collected:", edge_counter)
        
    except Exception as e:
        print(f"Problem collecting relationship edges: {str(e)} Freeing resources.")
        del relevant_data, relationship_map
        gc.collect()
        raise StopExecution
        
    # Save the relationship map to file
    print(f"Saving relationship map to {relationship_map_file}...")
    with open(relationship_map_file, 'wb') as f:
        pickle.dump(relationship_map, f)
    print("Relationship map saved successfully.")
    del relevant_data
    gc.collect()
    return relationship_map


gnn_rel_map = convert_edge_file_to_rel_map(relationship_map_file=gnn_output_relmap_name)


print("Constrinaed Preview (Number of entities in relationship types limited to 10):")
def new_encoder(unknown_var):
    if type(unknown_var) is set:
        return str(list(unknown_var)[0:10])
    if type(unknown_var) is list:
        return str(list(unknown_var[0:10]))
    return str(unknown_var)
print(json.dumps(dict(list(gnn_rel_map.items())[0:2]),indent=1, default=new_encoder))


Loading relationship map from /dataset/twibot22/generated_data/gs_output/relationship_map.pkl...
Loaded relationship map with 10290934 nodes.
Constrinaed Preview (Number of entities in relationship types limited to 10):
{
 "u251372955": {
  "11": "['u51934480', 'u76966092', 'u1711571520', 'u218379543', 'u187044118', 'u56986684', 'u1008681429423837185', 'u354582942', 'u1244426959322599428', 'u96117173']",
  "9": "['u76966092', 'u51934480', 'u1008681429423837185', 'u56986684', 'u218379543', 'u386538440', 'u899480529296379904', 'u8062702', 'u34373064', 'u14777850']",
  "12": "['t1438602507018719232', 't1359932716863651844', 't1321741120989265920', 't1376208963243507723', 't1442699035178065922', 't1308750685865992203', 't1481429296501739520', 't1335726528253079553', 't1375154351975530498', 't1329492903283073024']",
  "4": "['t1278210024763011073', 't1272197193940770816', 't1271050880532520962', 't1280016344965152768', 't1254289206664454146', 't1273113212867788802', 't1290331795637055489', 

## Step 3: Normalize Features

In [4]:
from datetime import datetime, timezone
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler

def datetime_to_age_in_seconds(dt):
    if pd.api.types.is_datetime64_any_dtype(dt):
        dt = dt.date
    if not isinstance(dt, datetime):
        raise ValueError("The input must be a datetime object.")
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)  # Correct use of timezone.utc
    now = datetime.now(timezone.utc)  # Correct use of timezone.utc
    age_in_seconds = int((now - dt).total_seconds())
    return age_in_seconds



#--------------
# User Features
#--------------
debug_print("Normalizing user information...", 3)
target_user_feature_keys = [
    'followers_count',
    'following_count',
    'listed_count',
    'tweet_count',
    'tweet_following_ratio',
    'created_at',
    'profile_desc_len'
]

copied_user_details_for_graph = user_detail_data.loc[:, target_user_feature_keys].copy()
copied_user_details_for_graph.replace([np.inf, -np.inf], np.nan, inplace=True)
copied_user_details_for_graph.dropna(inplace=True)
copied_user_details_for_graph['account_age'] = copied_user_details_for_graph['created_at'].apply(datetime_to_age_in_seconds)
copied_user_details_for_graph.drop('created_at', axis=1, inplace=True)

scaler = preprocessing.MinMaxScaler()
transformed = scaler.fit_transform(copied_user_details_for_graph)
normalized_details_for_graph = pd.DataFrame(
    transformed, 
    index=copied_user_details_for_graph.index, 
    columns=copied_user_details_for_graph.columns
)

debug_print("Sample user features for graph node embedding (before normalization):", 5)
if verbosity >= 5: 
    display(copied_user_details_for_graph.head(2))

debug_print("Sample user features for graph nodes after normalization:", 5)
if verbosity >= 5: 
    display(normalized_details_for_graph.head(2))


#--------------
# Post Features
#--------------

debug_print("Getting post information...", 3)
relevant_data = get_post_chunks(cols=["id","quote_count","like_count","retweet_count","reply_count"], index="id")
display(relevant_data.head(5))

relevant_data.fillna({'like_count': 0, 'quote_count': 0, 'retweet_count': 0, 'reply_count': 0}, inplace=True)

debug_print("Finding relevant node IDs...", 3)
post_ids = [strip_id(node) for node in gnn_rel_map if entity_type_from_id(node) == ENTITY_POST]
valid_post_ids = list(set(relevant_data.index).intersection(post_ids))
valid_posts = relevant_data.loc[valid_post_ids]

print("Normalizing...")
scaler = preprocessing.MinMaxScaler()
transformed = scaler.fit_transform(valid_posts)
normalized_thread_feature_data = pd.DataFrame(
    transformed, 
    index=valid_posts.index, 
    columns=valid_posts.columns
)

# Convert to a dictionary for `thread_feature_data`
#thread_feature_data = normalized_thread_feature_data[['like_count', 'quote_count', 'retweet_count', 'reply_count']].to_dict(orient='index')

debug_print("Completed. Sample of normalized thread data:",3)
if verbosity >= 5: 
    display(normalized_thread_feature_data.head(10))


del relevant_data

Normalizing user information...
Sample user features for graph node embedding (before normalization):


Unnamed: 0_level_0,followers_count,following_count,listed_count,tweet_count,tweet_following_ratio,profile_desc_len,account_age
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1217628182611927040,7316,215,69,3098,14.409302,92,158635127
2664730894,123,1090,0,1823,1.672477,10,333451496


Sample user features for graph nodes after normalization:


Unnamed: 0_level_0,followers_count,following_count,listed_count,tweet_count,tweet_following_ratio,profile_desc_len,account_age
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1217628182611927040,5.601278e-05,5.1e-05,0.000128,6.1e-05,8.576703e-07,0.290221,0.041333
2664730894,9.417129e-07,0.000262,0.0,3.6e-05,9.954916e-08,0.031546,0.147461


Getting post information...
Extracting from /dataset/twibot22/generated_data/tweet_7.parquet...
Extracting from /dataset/twibot22/generated_data/tweet_8.parquet...
Extracting from /dataset/twibot22/generated_data/tweet_4.parquet...
Extracting from /dataset/twibot22/generated_data/tweet_1.parquet...
Extracting from /dataset/twibot22/generated_data/tweet_6.parquet...
Extracting from /dataset/twibot22/generated_data/tweet_2.parquet...
Extracting from /dataset/twibot22/generated_data/tweet_5.parquet...
Extracting from /dataset/twibot22/generated_data/tweet_0.parquet...
Extracting from /dataset/twibot22/generated_data/tweet_3.parquet...


Unnamed: 0_level_0,quote_count,like_count,retweet_count,reply_count
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1499547298828886016,0,4,0,0
1499544740148264962,0,1,0,0
1499542095224360962,0,2,0,0
1499539485092155402,0,1,0,0
1499535965089632262,0,2,0,0


Finding relevant node IDs...
Normalizing...
Completed. Sample of normalized thread data:


Unnamed: 0_level_0,quote_count,like_count,retweet_count,reply_count
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
35819845282955264,0.0,3.303699e-07,0.0,0.0
32800663024959489,0.0,0.0,0.0,0.0
26434597420007424,0.0,0.0,0.0,0.0
24628078877605888,0.0,0.0,0.0,0.0
2130691131179009,0.0,0.0,0.0,0.0
34938672109322240,0.0,0.0,0.0,0.0
727225741606913,0.0,0.0,0.0,0.0
32959274892132353,0.0,0.0,2.612459e-07,0.0
7940735991021568,0.0,0.0,0.0,0.0
20547389144170496,0.0,0.0,0.0,0.0


# Stage 2: Train Embedded Models 
### Part A: URL Scoring

Bots have a purpose, and that purpose generally means needing to spread external information. Sometimes that means using external URLS. If we put all those URLS people post in a bag, can we spot the patterns in how bots post URLs versus humans?

In [14]:
import os
import gc
from collections import Counter
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.ensemble import RandomForestClassifier
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import joblib  # For saving and loading models

class URLConfidenceModel:
    def __init__(self, exclude_domains=None, domain_bag_min_frequency=2, domain_bag_max_domains=20):
        self.exclude_domains = exclude_domains or {"twitter.com", "www.twitter.com", "t.co"}
        self.domain_bag_min_frequency = domain_bag_min_frequency
        self.domain_bag_max_domains = domain_bag_max_domains
        self.domain_index = None
        self.model = None

    def preprocess_data(self, data, one_hot_encode=True):
        if not one_hot_encode:
            raise NotImplementedError("Only one-hot encoding is supported in this version.")

        domain_counts = Counter(
            domain
            for row in data['tweet_urls_top_x'].apply(dict)
            for domain, freq in row.items()
        )

        filtered_domains = {
            domain for domain, count in domain_counts.items()
            if count >= self.domain_bag_min_frequency and domain not in self.exclude_domains
        }

        print(f"{len(filtered_domains)} domains selected for modeling.")

        def calculate_domain_frequencies(row):
            domain_data = {}
            idx = 0
            for domain, freq in row.items():
                if domain in filtered_domains:
                    domain_data[domain] = {'frequency': freq}
                    idx += 1
                    if idx >= self.domain_bag_max_domains:
                        break
            return domain_data

        X_one_hot_dict = data['tweet_urls_top_x'].apply(calculate_domain_frequencies)

        attributes = ['frequency']
        self.domain_index = {
            (domain, attr): idx
            for idx, (domain, attr) in enumerate(
                (domain, attr) for domain in filtered_domains for attr in attributes
            )
        }

        rows, cols, values = [], [], []
        for row_idx, domain_data in enumerate(X_one_hot_dict):
            for domain, features in domain_data.items():
                for attr, value in features.items():
                    col_idx = self.domain_index[(domain, attr)]
                    rows.append(row_idx)
                    cols.append(col_idx)
                    values.append(value)

        X_sparse = csr_matrix((values, (rows, cols)), shape=(len(X_one_hot_dict), len(self.domain_index)))
        return X_sparse

    def train_model(self, X, y, refit=False, param_distributions=None, refit_num_iter=20):
        base_params = {
            'bootstrap': False,
            'random_state': 42
        }

        if refit:
            param_distributions = param_distributions or {
                'criterion': ['gini', 'entropy'],
                'n_estimators': [300, 350],
                'max_depth': [30, 20],
                'min_samples_split': [5, None],
                'min_samples_leaf': [2, None],
                'max_features': ["sqrt", None],
                'class_weight': ['balanced', None],
            }

            clf = RandomForestClassifier(**base_params)
            random_search = RandomizedSearchCV(
                clf, param_distributions=param_distributions, n_iter=refit_num_iter,
                scoring='f1_weighted', n_jobs=3, cv=3
            )
            random_search.fit(X, y)
            self.model = random_search.best_estimator_
            print(f"Best parameters: {random_search.best_params_}")
        else:
            params = {
                'n_estimators': 300,
                'min_samples_split': 15,
                'min_samples_leaf': 2,
                'max_features': "sqrt",
                'max_depth': 30,
                'class_weight': 'balanced',
                'criterion': 'entropy'
            }
            self.model = RandomForestClassifier(random_state=42, **params)
            self.model.fit(X, y)

    def evaluate_model(self, X_test, y_test):
        y_pred = self.model.predict(X_test)
        y_prob = self.model.predict_proba(X_test)

        acc = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred, average='weighted')
        roc_auc = roc_auc_score(y_test, y_prob[:, 1])

        print(f"Accuracy: {acc:.4f}")
        print(f"F1 Score: {f1:.4f}")
        print(f"ROC AUC: {roc_auc:.4f}")
     
        return acc, f1, roc_auc

    def make_predictions(self, X, user_ids, output_file=None):
        """
        Make predictions using the trained model and return associated user IDs, binary predictions,
        and probability scores for class '1'.
    
        Args:
            X (csr_matrix): Preprocessed feature matrix.
            user_ids (list or array-like): List of user IDs corresponding to the rows in X.
    
        Returns:
            DataFrame: A DataFrame containing user IDs, binary predictions, and probability scores.
        """
        if self.model is None:
            raise ValueError("The model has not been trained or loaded. Please train or load a model before making predictions.")
        
        # Predict class probabilities and binary labels
        y_prob = self.model.predict_proba(X)[:, 1]  # Probability scores for class '1'
        y_pred = self.model.predict(X)  # Binary predictions
    
        # Create a DataFrame with results
        results = pd.DataFrame({
            'user_id': user_ids,
            'prediction': y_pred,
            'probability_class_1': y_prob
        })

        if output_file is not None:
            results.to_csv(output_file, index=False)
        return results
        
    def save_model(self, filepath):
        joblib.dump(self.model, filepath)
        print(f"Model saved to {filepath}")

    def load_model(self, filepath):
        self.model = joblib.load(filepath)
        print(f"Model loaded from {filepath}")

    def plot_feature_importances(self, top_n=25):
        importances = self.model.feature_importances_
        indices = np.argsort(importances)[::-1]
        feature_names = [f"{domain}_{attr}" for domain, attr in self.domain_index.keys()]

        top_indices = [idx for idx in indices if "frequency" in feature_names[idx]][:top_n]

        plt.figure(figsize=(10, 6))
        plt.title(f"Top {top_n} Feature Importances", fontsize=14)
        plt.barh(range(len(top_indices)), importances[top_indices][::-1], color="b", align="center")
        plt.yticks(range(len(top_indices)), [feature_names[i] for i in top_indices[::-1]])
        plt.xlabel("Weight Coefficient", fontsize=12)
        plt.ylabel("Feature", fontsize=12)
        plt.tight_layout()
        plt.show()

model = URLConfidenceModel()
load_url_model_from_file = False

X_lab = selected_user_data.index.array
X_sparse = model.preprocess_data(selected_user_data)
y = selected_user_data['label']
X_train, X_test, y_train, y_test = train_test_split(X_sparse, y, test_size=0.2, stratify=y)

if load_url_model_from_file:
    model.load_model(url_model_output)
else:   
    model.train_model(X_train, y_train)
    model.evaluate_model(X_test, y_test)
    model.save_model(url_model_output)

model.make_predictions(X_sparse,selected_user_data.index.array,output_file=url_model_prediction_output)


24123 domains selected for modeling.
Accuracy: 0.5719
F1 Score: 0.5651
ROC AUC: 0.6061
Model saved to /dataset/twibot22/generated_data/url_model.pkl


Unnamed: 0,user_id,prediction,probability_class_1
0,89612950,bot,0.478117
1,466835426,bot,0.495663
2,429208947,human,0.500569
3,1181476422327689216,human,0.518565
4,1381627326753038341,human,0.516660
...,...,...,...
99995,1486417203092148227,bot,0.489038
99996,372830746,human,0.552072
99997,44653376,human,0.515908
99998,2337682316,human,0.536984


# Stage 2 Part B: Relationship Modeling

There are patterns in how bots choose to reply to certain tweets, keywords, or users. If we see these interactions as graphed relationships, can we find patterns?

## Warnings: CUDA USE (if it exists)

## Step 1: Import torch, prepare constants, prepare cuda

In [3]:
import torch 
from cogdl.experiments import experiment, output_results
import torch
import os
from cogdl import experiment
from cogdl.datasets import build_dataset
from cogdl.models import build_model
from cogdl.options import get_default_args



dataset = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["PYTORCH_DEBUG"] = "1"

experiment_params =  {
                'epochs':200,
                'do_valid':False,
                'gnn_output_checkpoint_name':gnn_output_checkpoint_name,
                'n_trials':2,
                'hidden_size':[175],
                'batch_size':1024,
                #weight_decay=1e-4,
                #eval_step=5,
                'lr':0.01,
                'patience':20,
                'dropout':0.4,               # Reduced dropout for large graphs
                'sample_size':[15, 10],      # Increased neighbors sampled
                'num_layers':2,              # Deeper model
            }

graphsage_run_params = {
        'dropout':0.3,               # Reduced dropout for large graphs
        'sample_size':[15, 10],      # Increased neighbors sampled        
        'hidden_size':[175],    # Hidden feature dimension
        'num_layers':2,      # Number of layers
        'aggr':'mean'
}

print(f"Using device: {device}")

Using device: cuda


## Step 2: Convert intermediate format dictionary into CogDL appropriate Graph

In [4]:
import os
import torch
import json
import gc
from cogdl.data import Graph
from cogdl.datasets import NodeDataset, generate_random_graph
from cogdl.utils.graph_utils import to_undirected, remove_self_loops

def get_rel_map_graph(relationship_map, normalized_user_features, normalized_thread_features):
    """
    Create a graph representation from relationship map and features.
    
    Args:
        relationship_map (dict): Node relationships where keys are node IDs and values are dictionaries 
                                 with relationship types and their connected nodes (sets).
        normalized_user_features (DataFrame): Pre-normalized user feature data indexed by user ID.
        normalized_thread_features (DataFrame): Pre-normalized thread feature data indexed by thread ID.

    Returns:
        entity_index_map (dict): Mapping from original entity IDs to sequential node indices.
        Graph: A CogDL Graph object with features, labels, edges, and edge attributes.
    """

    print("Creating graph...")

    # Precompute sizes
    num_nodes = len(relationship_map)
    feature_size = max(normalized_user_features.shape[1], normalized_thread_features.shape[1]) + 1  # +1 for entity type
    num_edges = sum(len(rel_dict.get(rel, [])) for rel_dict in relationship_map.values() for rel in rel_dict if '_' not in rel)

    # Initialize tensors
    node_features = torch.full((num_nodes, feature_size), -1.0, dtype=torch.float)
    labels = torch.full((num_nodes,), -1, dtype=torch.long)
    edges = torch.empty((2, num_edges), dtype=torch.long)
    edge_features = torch.empty((num_edges,), dtype=torch.long)

    
    # Precompute entity index map. Our node IDs should be sequential ints.
    entity_index_map = {uid: idx for idx, uid in enumerate(relationship_map.keys())}

    # Populate tensors
    edge_idx = 0
    for node_idx, (uid1, rel_dict) in enumerate(relationship_map.items()):
        entity_type = entity_type_from_id(uid1)
        stripped_id = strip_id(uid1)

        # Assign features and labels based on entity type
        if entity_type == ENTITY_USER and stripped_id in normalized_user_features.index:
            feature_values = [entity_type] + normalized_user_features.loc[stripped_id].tolist()
        elif entity_type == ENTITY_POST and stripped_id in normalized_thread_features.index:
            feature_values = [entity_type] + normalized_thread_features.loc[stripped_id].tolist()
        else:
            feature_values = [-1] * feature_size  # Default values for missing data
        
        # Pad or truncate to match feature size
        feature_values = feature_values[:feature_size] + [-1] * (feature_size - len(feature_values))
        node_features[node_idx] = torch.tensor(feature_values, dtype=torch.float)
        labels[node_idx] = (
            1 if entity_type == ENTITY_USER and stripped_id in user_detail_data.index
                 and user_detail_data.loc[stripped_id, 'label'] == 'bot'
            else 0 if entity_type == ENTITY_USER and stripped_id in user_detail_data.index
                 and user_detail_data.loc[stripped_id, 'label'] == 'human'
            else -1
        )

        # Process edges
        for rel, uid_list in rel_dict.items():
            if '_' in rel:
                continue  # Skip backreferences
            valid_uids = [uid2 for uid2 in uid_list if uid2 in entity_index_map]
            for uid2 in valid_uids:
                edges[0, edge_idx] = node_idx
                edges[1, edge_idx] = entity_index_map[uid2]
                edge_features[edge_idx] = int(rel)
                edge_idx += 1

    # Trim unused edge entries
    edges = edges[:, :edge_idx]
    edge_features = edge_features[:edge_idx]

    print(f"Graph created. Nodes: {num_nodes}, Edges: {edge_idx}")
    return entity_index_map, Graph(
        x=node_features,
        y=labels,
        edge_index=edges,
        edge_attr=edge_features
    )

class LockstepRelDataset(NodeDataset):
    def __init__(
        self,
        relationship_map=None,
        normalized_user_features=None,
        normalized_thread_features=None,
        path=gnn_output_dataset_name,
        index_map_path="index_map.json",
        sample_limit=0,
        train_ratio=0.6,
        val_ratio=0.1,
        stratify=True,
        stratification_mode="undersample",  # "undersample" or "oversample"
    ):
        """
        A custom NodeDataset class for creating and managing graph datasets.

        Args:
            relationship_map (dict): A map defining the relationships between nodes.
            path (str): Path to save/load the graph data.
            index_map_path (str): Path to save/load the node index mapping.
            sample_limit (int): Limit the number of samples to process (0 for no limit).
            train_ratio (float): Proportion of nodes used for training (0 < train_ratio < 1).
            val_ratio (float): Proportion of nodes used for validation (0 < val_ratio < 1).
            stratify (bool): Whether to stratify the dataset based on class labels.
            stratification_mode (str): "undersample" to downsample the majority class, 
                                       "oversample" to upsample the minority class.
        """
        self.path = path
        self.index_map_path = index_map_path
        self.sample_limit = sample_limit
        self.relationship_map = relationship_map
        self.index_mapping = None
        self.data_cached = None
        self.normalized_thread_features = normalized_thread_features
        self.normalized_user_features = normalized_user_features
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.stratify = stratify
        self.stratification_mode = stratification_mode

        if relationship_map is None and not os.path.exists(path):
            raise ValueError("Relationship map must be provided if data at the path does not exist.")

        if train_ratio + val_ratio >= 1.0:
            raise ValueError("The sum of train_ratio and val_ratio must be less than 1.0.")

        super(LockstepRelDataset, self).__init__(path, scale_feat=False, metric="accuracy")

    def stratify_indices(self, valid_indices, labels, ignore_labels=None):
        """
        Stratify the dataset by balancing class distributions, with an option to ignore specific labels.
    
        Args:
            valid_indices (Tensor): Indices of nodes with valid labels.
            labels (Tensor): Labels corresponding to the nodes.
            ignore_labels (list, optional): List of labels to ignore during stratification.
    
        Returns:
            Tensor: Stratified indices.
        """
        from collections import Counter
        from sklearn.utils import resample
    
        # Default to an empty list if no labels to ignore are specified
        if ignore_labels is None:
            ignore_labels = []
    
        # Filter out indices corresponding to ignored labels
        valid_mask = ~torch.isin(labels, torch.tensor(ignore_labels))
        valid_indices = valid_indices[valid_mask]
        labels = labels[valid_mask]
    
        # Group indices by class
        class_indices = {label.item(): valid_indices[labels == label] for label in torch.unique(labels)}
        min_class_size = min(len(indices) for indices in class_indices.values())
        max_class_size = max(len(indices) for indices in class_indices.values())
    
        # Apply stratification
        stratified_indices = []
        for label, indices in class_indices.items():
            if self.stratification_mode == "undersample":
                stratified_indices.append(indices[torch.randperm(len(indices))[:min_class_size]])
            elif self.stratification_mode == "oversample":
                extra_indices = indices[torch.randint(len(indices), (max_class_size - len(indices),))]
                stratified_indices.append(torch.cat([indices, extra_indices], dim=0))
            else:
                raise ValueError("Invalid stratification_mode. Choose 'undersample' or 'oversample'.")
    
        return torch.cat(stratified_indices, dim=0)
        
    def load_index_map(self):
        """
        Load the index mapping from the file specified in `index_map_path`.

        Returns:
            dict: The loaded index mapping.
        """
        if not os.path.exists(self.index_map_path):
            raise FileNotFoundError(f"Index map file not found at {self.index_map_path}")
        
        print(f"Loading index mapping from {self.index_map_path}...")
        with open(self.index_map_path, 'r') as f:
            self.index_mapping = json.load(f)
        print(f"Index mapping loaded successfully. Total entries: {len(self.index_mapping)}")

    def process(self):
        """
        Process the relationship map to generate graph data and index mapping.

        Returns:
            Graph: The processed graph object with node features, edges, and masks.
        """
        do_save = False

        # Load existing data if available
        if os.path.exists(self.path):
            print(f"Loading graph data from {self.path}")
            data = torch.load(self.path)
            self.load_index_map()
        else:
           
            # Create graph data from the relationship map
            if self.relationship_map is None:
                raise ValueError("Relationship map must be provided to generate graph data.")

            print("Processing relationship map to create graph...")
            do_save = True
            index_map, data = get_rel_map_graph(
                self.relationship_map,
                self.normalized_user_features,
                self.normalized_thread_features
            )
            self.index_mapping = index_map

            # Save the index mapping
            with open(self.index_map_path, 'w') as f:
                json.dump(self.index_mapping, f)
            print(f"Index mapping saved to {self.index_map_path}")

        # Cache data
        if self.data_cached is not None:
            print("Using cached data.")
            return self.data_cached

        print("Sampling graph...")
        # Masks for training, validation, and testing
        num_nodes = data.num_nodes
        labels = data.y
        valid_indices = (labels != -1).nonzero(as_tuple=True)[0]  # Nodes with valid labels

        if self.stratify:
            valid_indices = self.stratify_indices(valid_indices, labels[valid_indices], [-1])

        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(num_nodes, dtype=torch.bool)

        # Random sampling of valid nodes
        valid_indices = valid_indices[torch.randperm(len(valid_indices))]  # Shuffle valid indices
        num_valid = len(valid_indices)

        train_size = int(self.train_ratio * num_valid)  # Training nodes
        val_size = int(self.val_ratio * num_valid)      # Validation nodes
        test_size = num_valid - train_size - val_size   # Test nodes

        train_mask[valid_indices[:train_size]] = True
        val_mask[valid_indices[train_size:train_size + val_size]] = True
        test_mask[valid_indices[train_size + val_size:]] = True

        # Assign masks to the graph data
        data.train_mask = train_mask
        data.val_mask = val_mask
        data.test_mask = test_mask

        if do_save:
            print(f"Saving graph data to {self.path}")
            p = Path(self.path).parents[0]
            if not p.exists:
                os.mkdir(p)
            torch.save(data, self.path)

        self.data_cached = data
        gc.collect()
        return data


def print_graph_statistics(data):
    print("Graph Statistics:")
    print(f" - Number of nodes: {data.num_nodes}")
    num_edges = len(data.edge_index[0]) if isinstance(data.edge_index, tuple) else data.edge_index.shape[1]
    print(f" - Number of edges: {num_edges}")
    print(f" - Feature dimension: {data.x.shape[1]}")
    print(f" - Edge Attr dimension: {data.edge_attr.shape}")
    print(f" - Number of unique labels: {data.y.unique().numel()}")
    print(f" - Label distribution: {dict(zip(*torch.unique(data.y, return_counts=True)))}")
    print(f" - Training nodes: {data.train_mask.sum().item()}")
    print(f" - Validation nodes: {data.val_mask.sum().item()}")
    print(f" - Test nodes: {data.test_mask.sum().item()}")

def print_label_distribution(data):
    train_labels = data.y[data.train_mask]
    val_labels = data.y[data.val_mask]
    test_labels = data.y[data.test_mask]

    print("Train label distribution:", dict(zip(*torch.unique(train_labels, return_counts=True))))
    print("Validation label distribution:", dict(zip(*torch.unique(val_labels, return_counts=True))))
    print("Test label distribution:", dict(zip(*torch.unique(test_labels, return_counts=True))))

if gnn_data_load_from_file:
    print("Loading dataset from file...")
    try:
        dataset = LockstepRelDataset(path=gnn_output_dataset_name,
                                     index_map_path=gnn_output_map_name,
                                     train_ratio=0.6,
                                     val_ratio=0.2,
                                     stratify=True)
    except Exception as e:
        print(f"Failed to load from file...{e}")
else:
    dataset = LockstepRelDataset(relationship_map=gnn_rel_map,
                                 normalized_thread_features=normalized_thread_feature_data,
                                 normalized_user_features=normalized_details_for_graph,
                                 stratification_mode="oversample",
                                 train_ratio=0.6,
                                 val_ratio=0.2,
                                 stratify=True)

print_graph_statistics(dataset.data)
print_label_distribution(dataset.data)


Loading dataset from file...


  self.data = torch.load(path)


Graph Statistics:
 - Number of nodes: 10290934
 - Number of edges: 17037422
 - Feature dimension: 8
 - Edge Attr dimension: torch.Size([17037422])
 - Number of unique labels: 3
 - Label distribution: {tensor(-1): tensor(9747710), tensor(0): tensor(465201), tensor(1): tensor(78023)}
 - Training nodes: 355627
 - Validation nodes: 147787
 - Test nodes: 147873
Train label distribution: {tensor(0): tensor(279248), tensor(1): tensor(76379)}
Validation label distribution: {tensor(0): tensor(93002), tensor(1): tensor(54785)}
Test label distribution: {tensor(0): tensor(92951), tensor(1): tensor(54922)}


## STEP 6. Train GraphSAGE model

In [None]:
print(gnn_output_dataset_name)

In [None]:

if gnn_doing_experiment:
    try:
        # Run the experiment with GPU configuration
        results = experiment(
            dataset=dataset,
            model=gnn_model_type,
            **experiment_params,
            devices=[device]  # Pass device explicitly
        )

        # Save the model checkpoint to CPU for portability
        #model = results[1]  # Assuming the model is returned as the second item
        #torch.save(model.to("cpu").state_dict(), gnn_output_model_path)
        output_results(results)
        print("Experiment results:", results)
        
    except Exception as e:
        print("Problem encountered while training:")
        print(e)

In [6]:
import torch
from cogdl.models.nn.gcn import GCN
from cogdl.models.nn.dropedge_gcn import DropEdge_GCN
from cogdl.models.nn.graphsage import Graphsage  # Import GraphSAGE
import pandas as pd


torch.manual_seed(1)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1)

def strip_model_root(chkp):
    new_state_dict = {}
    # Clean up state_dict keys if needed
    for key, value in chkp.items():
        new_key = key.replace('model.', '')  # Remove the 'model.' prefix
        new_state_dict[new_key] = value
    return new_state_dict


# Model setup based on selection
if gnn_model_type == "gcn":   
    checkpoint = torch.load(gnn_output_checkpoint_name)
    print(checkpoint)
    
    
    new_state_dict = strip_model_root(checkpoint)

    model = GCN(in_feats=23, hidden_size=64, out_feats=2, num_layers=2, dropout=0.2)

    # Load model state
    model.load_state_dict(new_state_dict)

elif gnn_model_type == "dropedge_gcn":
    checkpoint = torch.load(gnn_output_checkpoint_name)

    model = DropEdge_GCN(
        nfeat=23,                # Input feature dimension
        nhid=24,                 # Hidden feature dimension
        nclass=2,                # Output feature dimension
        nhidlayer=1,             # Number of hidden blocks
        dropout=0.3,             # Dropout ratio
        baseblock="mutigcn",     # Baseblock type
        inputlayer="gcn",        # Input layer type
        outputlayer="gcn",       # Output layer type
        nbaselayer=1,            # Number of layers in one hidden block
        activation=torch.relu,   # Activation function
        withbn=False,            # Use batch normalization
        withloop=False,          # Use self-feature modeling
        aggrmethod="default"     # Aggregation function for baseblock
    )

    # Load model state
    model.load_state_dict(checkpoint)

elif gnn_model_type == "graphsage":

    if os.path.exists(gnn_output_checkpoint_name):
        checkpoint = torch.load(gnn_output_checkpoint_name)    
        checkpoint = strip_model_root(checkpoint)
    
    model = Graphsage(
        num_features=dataset.data.x.shape[1],       # Input feature dimension
        num_classes=2,   
        **graphsage_run_params
    )
    # Load model state
    model.load_state_dict(checkpoint)

device = torch.device("cpu")
print(f"Using device: {device}")

data = dataset.data.to(device)  # Ensure data is on the same device as the model
model.to(device)
model.eval()

ind_map = dataset.index_mapping

if ind_map is None:
    # needs process()
    dataset.process()
    ind_map = dataset.index_mapping
    if ind_map is None:
        # actual problem
        raise StopExecution
node_to_entity_mapping = {idx: entity_id for entity_id, idx in ind_map.items()}

# Forward pass through the model
with torch.no_grad():
    output = model(data)
    predicted_labels = output.argmax(dim=1)  # Get the class with the highest score for each node
    probabilities = torch.softmax(output, dim=1)
    
    # Filter out nodes with label -1 (thread nodes)
    valid_node_indices = (data.y != -1).nonzero(as_tuple=True)[0]
    predictions_df = pd.DataFrame({
        'Node ID': valid_node_indices.cpu().numpy(),
        'Entity ID': [node_to_entity_mapping[i.item()] for i in valid_node_indices],
        'Predicted Label': predicted_labels[valid_node_indices].cpu().numpy(),
        'Probability Class 0': probabilities[valid_node_indices, 0].cpu().numpy(),
        'Probability Class 1': probabilities[valid_node_indices, 1].cpu().numpy()
    })
    
    # Output to a CSV file
    predictions_df.to_csv(gnn_output_predictions_name, index=False)
    print(f"Filtered predictions saved to '{gnn_output_predictions_name}'")
 

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np

true_labels = data.y

# Example: Calculate accuracy on the test set
correct = (predicted_labels[data.test_mask] == true_labels[data.test_mask]).sum().item()
total = data.test_mask.sum().item()
accuracy = correct / total if total != 0 else 0

# Print the accuracy on the test set
print(f"Test Accuracy: {accuracy:.4f}")


train_nodes = data.train_mask.nonzero(as_tuple=True)[0]  # Get indices of training nodes

# Example: print the first 5 training nodes with their features and labels
for idx in train_nodes[:5]:
    feature = data.x[idx]  # Features of the node
    label = data.y[idx]    # Label of the node
    print(f"Node ID: {idx.item()} | Feature: {feature} | Label: {label.item()}")


import numpy as np

# Convert features and labels to numpy arrays for correlation
features_np = data.x.cpu().numpy()  # Move features to CPU if on GPU
labels_np = data.y.cpu().numpy()    # Similarly move labels to CPU if needed

# Calculate correlation of each feature with the label
correlations = [np.corrcoef(features_np[:, i], labels_np)[0, 1] for i in range(features_np.shape[1])]

# Print correlations
for i, corr in enumerate(correlations):
    print(f"Feature {i} - Correlation with Label: {corr:.4f}")
    

  checkpoint = torch.load(gnn_output_checkpoint_name)
  data = torch.load(self.path)


Using device: cpu
Loading graph data from /dataset/twibot22/generated_data/gs_output/gnn_graph_data.pt
Loading index mapping from /dataset/twibot22/generated_data/gs_output/index_map.json...
Index mapping loaded successfully. Total entries: 10290934
Sampling graph...


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


Filtered predictions saved to '/dataset/twibot22/generated_data/gs_output/gs_predictions.csv'
Test Accuracy: 0.6745
Node ID: 0 | Feature: tensor([0.0000e+00, 6.7145e-06, 1.6582e-04, 1.4090e-04, 1.6425e-04, 7.1375e-07,
        4.7634e-01, 2.1228e-01]) | Label: 1
Node ID: 2 | Feature: tensor([0.0000e+00, 3.5892e-05, 2.0812e-04, 3.1887e-04, 4.2004e-05, 1.4548e-07,
        4.0063e-01, 2.3310e-01]) | Label: 0
Node ID: 3 | Feature: tensor([0.0000e+00, 1.2587e-05, 9.2285e-05, 7.9718e-05, 6.5434e-05, 5.1034e-07,
        2.0820e-01, 2.2603e-01]) | Label: 0
Node ID: 4 | Feature: tensor([0.0000e+00, 5.7062e-05, 3.1002e-05, 2.1505e-04, 5.5107e-06, 1.2729e-07,
        4.9211e-01, 3.6093e-02]) | Label: 1
Node ID: 6 | Feature: tensor([0.0000e+00, 1.1607e-05, 1.1632e-04, 4.0786e-05, 1.2865e-05, 7.9649e-08,
        2.9968e-01, 9.0938e-02]) | Label: 0
Feature 0 - Correlation with Label: -0.9366
Feature 1 - Correlation with Label: -0.0468
Feature 2 - Correlation with Label: -0.0462
Feature 3 - Correlatio

## End Training Part A.