In [15]:
import seaborn as sns
from sklearn.preprocessing import StandardScaler
import spacy
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import GATConv, GATv2Conv, HeteroConv
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import GloVe
import torch.nn as nn
import torch.nn.functional as F
import re
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# PARAMETERS
#-------------------------
DATA = r'CSV_Files\sydneysiege.csv'
TEXT = 'text'
PLOT_COLOUR = 'label'
NUM_POINTS = 3000
EPOCHS = 200
INCLUDE_STOPWORDS = True
REMOVE_ATS = False
EMBEDDING_DIM = 100
LR = 0.0001
WEIGHT_DECAY = 5e-4
EMBEDDING_METHOD = "Glove" # Options: "Glove", "BERT", "BERT_Sentence"
#-------------------------

# Load Venv in terminal: source .venv/Scripts/activate
Data_Part = pd.read_csv(DATA)
df = Data_Part.copy()


# Convert 'Created_At' column to datetime
df['created_at'] = pd.to_datetime(df['created_at'], errors='coerce')  # Convert to datetime, handle any conversion errors
# Create a new column for quarterly periods
df['Quarter'] = df['created_at'].dt.to_period('Q')  # Converts to quarterly periods (e.g., 2015Q1, 2015Q2)

df.head()


`clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884


Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.


Converting to PeriodArray/Index representation will drop timezone information.



Unnamed: 0,tweet_id,user_id,user_screen_name,user_followers_count,user_friends_count,text,created_at,retweet_count,favorite_count,lang,geo,hashtags,mentions,urls,label,reply_to_tweet_id,Quarter
0,544267656597995521,258627226,KristyMayr7,1672,731,BREAKING: Hostages are being forced to hold an...,2014-12-14 23:08:15+00:00,445,54,en,,7NEWS,,,1,,2014Q4
1,544269152198721536,443573208,ottomanscribe,1192,607,@KristyMayr7 that is not an IS flag specifically.,2014-12-14 23:14:12+00:00,1,1,en,,,KristyMayr7,,1,5.442677e+17,2014Q4
2,544269478406529024,2700238212,tlcrosemiller11,42,234,@KristyMayr7 omg I'm watching it on @morningsh...,2014-12-14 23:15:30+00:00,0,0,en,,,"KristyMayr7,morningshowon7",,1,5.442677e+17,2014Q4
3,544270016770633728,760246262,tihrigby,147,846,"@KristyMayr7 @DeepPolitics dudes, not the flag...",2014-12-14 23:17:38+00:00,0,1,en,,,"KristyMayr7,DeepPolitics",http://www.pri.org/sites/default/files/ISISfla...,1,5.442677e+17,2014Q4
4,544270351619067904,2670053628,SloaneSW7,299,614,@KristyMayr7 @bluebuzzbird How pathetic! These...,2014-12-14 23:18:58+00:00,0,0,en,,,"KristyMayr7,bluebuzzbird",,1,5.442677e+17,2014Q4


In [16]:
class TextPreprocessor:
    def __init__(self, target_column=TEXT, contraction_dict=None):
        self.target_column = target_column
        self.tokenizer = get_tokenizer("basic_english")
        
        # Initialize spaCy model
        self.nlp = spacy.load('en_core_web_sm')
        default_stop_words = self.nlp.Defaults.stop_words
        
        important_words = {'not', 'no', 'never', 'none', 'nobody', 'nothing', 'nowhere', 'neither', 'nor', 'cannot', 'without', 'hardly', 'barely'}
        self.stop_words = default_stop_words - important_words

        # Dictionary has been taken for UWA NLP lab 5
        self.contraction_dict = contraction_dict or  {
            "ain't": "is not", "aren't": "are not", "can't": "cannot", "'cause": "because",
            "could've": "could have", "couldn't": "could not", "didn't": "did not", "doesn't": "does not",
            "don't": "do not", "hadn't": "had not", "hasn't": "has not", "haven't": "have not",
            "he'd": "he would", "he'll": "he will", "he's": "he is", "how'd": "how did",
            "how'd'y": "how do you", "how'll": "how will", "how's": "how is", "I'd": "I would",
            "I'd've": "I would have", "I'll": "I will", "I'll've": "I will have", "I'm": "I am",
            "I've": "I have", "i'd": "i would", "i'd've": "i would have", "i'll": "i will",
            "i'll've": "i will have", "i'm": "i am", "i've": "i have", "isn't": "is not",
            "it'd": "it would", "it'd've": "it would have", "it'll": "it will", "it'll've": "it will have",
            "it's": "it is", "let's": "let us", "ma'am": "madam", "mayn't": "may not",
            "might've": "might have", "mightn't": "might not", "mightn't've": "might not have",
            "must've": "must have", "mustn't": "must not", "mustn't've": "must not have",
            "needn't": "need not", "needn't've": "need not have", "o'clock": "of the clock",
            "oughtn't": "ought not", "oughtn't've": "ought not have", "shan't": "shall not",
            "sha'n't": "shall not", "shan't've": "shall not have", "she'd": "she would",
            "she'd've": "she would have", "she'll": "she will", "she'll've": "she will have",
            "she's": "she is", "should've": "should have", "shouldn't": "should not",
            "shouldn't've": "should not have", "so've": "so have", "so's": "so as",
            "this's": "this is", "that'd": "that would", "that'd've": "that would have",
            "that's": "that is", "there'd": "there would", "there'd've": "there would have",
            "there's": "there is", "here's": "here is", "they'd": "they would",
            "they'd've": "they would have", "they'll": "they will", "they'll've": "they will have",
            "they're": "they are", "they've": "they have", "to've": "to have", "wasn't": "was not",
            "we'd": "we would", "we'd've": "we would have", "we'll": "we will", "we'll've": "we will have",
            "we're": "we are", "we've": "we have", "weren't": "were not", "what'll": "what will",
            "what'll've": "what will have", "what're": "what are", "what's": "what is",
            "what've": "what have", "when's": "when is", "when've": "when have", "where'd": "where did",
            "where's": "where is", "where've": "where have", "who'll": "who will",
            "who'll've": "who will have", "who's": "who is", "who've": "who have", "why's": "why is",
            "why've": "why have", "will've": "will have", "won't": "will not", "won't've": "will not have",
            "would've": "would have", "wouldn't": "would not", "wouldn't've": "would not have",
            "y'all": "you all", "y'all'd": "you all would", "y'all'd've": "you all would have",
            "y'all're": "you all are", "y'all've": "you all have", "you'd": "you would",
            "you'd've": "you would have", "you'll": "you will", "you'll've": "you will have",
            "you're": "you are", "you've": "you have"
        }
        
        self.re_pattern = re.compile(r'[^\w\s]')  # Compile the regex once.
        self.at_pattern = re.compile(r'@\S+')


    def clean_text(self, df):
        # Regex Pattern and Lowercase
        df[self.target_column] = df[self.target_column].str.lower()
        if not REMOVE_ATS:
            df[self.target_column] = df[self.target_column].str.replace(self.at_pattern, '', regex=True)

        if self.contraction_dict:
            for word, new_word in self.contraction_dict.items():
                df[self.target_column] = df[self.target_column].str.replace(word, new_word, regex=False)
        
        # Regex for Punctuation Removal
        df[self.target_column] = df[self.target_column].str.replace(self.re_pattern, '', regex=True)
        
        # Remove stopwords using spaCy
        if INCLUDE_STOPWORDS:
            df['cleaned_text'] = df[self.target_column].apply(lambda x: ' '.join(word for word in x.split() if word not in self.stop_words))
        else:
            df['cleaned_text'] = df[self.target_column].apply(lambda x: ' '.join(word for word in x.split()))
        
        # Tokenize cleaned texts
        df['tokenized_text'] = df['cleaned_text'].apply(lambda text: self.tokenizer(text))
        return df['tokenized_text']

text_preprocessor = TextPreprocessor()
df[TEXT] = text_preprocessor.clean_text(df.copy())
df = df[df[TEXT].apply(lambda tokens: all(len(word) <= 100 for word in tokens))]
df.reset_index(drop=True, inplace=True)

In [17]:
from gensim.models import Word2Vec
from torchtext.vocab import GloVe
import matplotlib.pyplot as plt
import nbformat

def plot_updated_embeddings(df_embeddings, color_column='label', n=1000, name='gat'):
    import plotly.graph_objs as go
    import plotly.io as pio
    from sklearn.manifold import TSNE
    import matplotlib.pyplot as plt
    import numpy as np
    import time

    pio.renderers.default = 'browser'

    start_time = time.time()

    # Sample 'n' observations
    df_sampled = df_embeddings.sample(n=n, random_state=42)

    # Extract the embeddings and convert to a NumPy array
    embedding_matrix = np.array(df_sampled['embedding'].tolist())

    # Reduce the dimensionality to 3D using t-SNE
    tsne = TSNE(n_components=3, perplexity=30, n_iter=300, random_state=42)
    reduced_embeddings_tsne = tsne.fit_transform(embedding_matrix)


    #---------------------------------------------------------------

    # Compute the mean and standard deviation of the reduced embeddings
    mean = np.mean(reduced_embeddings_tsne, axis=0)
    std = np.std(reduced_embeddings_tsne, axis=0)

    # Compute the z-score for each point
    z_scores = np.abs((reduced_embeddings_tsne - mean) / std)

    # Set a threshold for outliers (e.g., z-score > 3)
    threshold = 3  # adjust as needed
    mask = (z_scores < threshold).all(axis=1)

    # Filter the embeddings and other arrays
    filtered_embeddings = reduced_embeddings_tsne[mask]
    df_sampled_filtered = df_sampled[mask].reset_index(drop=True)

    #---------------------------------------------------------------


    # Prepare colors based on the specified column
    unique_values = df_sampled_filtered[color_column].unique()
    color_map = plt.cm.get_cmap('tab10', len(unique_values))

    # Create a mapping from unique values to colors
    color_dict = {value: color_map(i) for i, value in enumerate(unique_values)}

    # Apply the mapping to the DataFrame column
    colors = df_sampled_filtered[color_column].map(color_dict)

    # Convert colors to a list
    colors_list = colors.tolist()

    # Create hover text for each point
    hover_text = df_sampled_filtered.apply(
        lambda row: f"{color_column}: {row[color_column]}, Text: {' '.join(row[TEXT])}",
        axis=1
    )

    # Create a trace for the 3D scatter plot
    trace = go.Scatter3d(
        x=filtered_embeddings[:, 0],
        y=filtered_embeddings[:, 1],
        z=filtered_embeddings[:, 2],
        mode='markers',
        marker=dict(
            size=3,
            color=colors_list,  # Color based on the specified column
            opacity=0.7
        ),
        text=hover_text.tolist()  # Hover text
    )

    # Set up the layout
    layout = go.Layout(
        title=f't-SNE - {name} Embeddings colored by {color_column}',
        scene=dict(
            xaxis_title='Component 1',
            yaxis_title='Component 2',
            zaxis_title='Component 3'
        ),
        margin=dict(l=0, r=0, b=0, t=40),
    )

    # Create the figure
    fig = go.Figure(data=[trace], layout=layout)

    # Show the plot
    pio.show(fig)

    # Save as HTML file
    fig.write_html(f"interactive_plot_{name}_{color_column}.html")

    # End Timer with note after
    end_time = time.time()
    print(f"Time taken for {n} points: {end_time - start_time:.2f} seconds")

In [18]:
# Load pre-trained GloVe embeddings
if EMBEDDING_METHOD == "Glove":
    glove = GloVe(name='6B', dim=EMBEDDING_DIM, cache=r'C:\Users\matth\OneDrive\Desktop\1. DATA SCIENCE MASTER\Research_CITS5014\Sentence Embeddings\.vector_cache')

def get_average_embedding(tokens, glove, word2vec_model=None):
    embeddings = []
    for token in tokens:
        if token in glove.stoi:
            embeddings.append(glove[token].numpy())  # Use GloVe if available
        elif word2vec_model and token in word2vec_model.wv:
            embeddings.append(word2vec_model.wv[token])  # Use Word2Vec as a fallback
    if len(embeddings) > 0:
        return np.mean(embeddings, axis=0)
    else:
        return np.zeros(glove.dim)  # Return a zero vector if no tokens are found
    

def get_bert_embedding(tokens):
    # Join tokens back to text
    text = ' '.join(tokens)
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=EMBEDDING_DIM)
    outputs = bert_model(**inputs)
    # Use the [CLS] token embedding
    cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().detach().numpy()
    return cls_embedding

def get_bert_sentence_embedding(tokens):
    # Join tokens back to text
    text = ' '.join(tokens)
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=EMBEDDING_DIM)
    outputs = bert_model(**inputs)
    # Use the [CLS] token embedding
    cls_embedding = outputs.last_hidden_state.squeeze().detach().numpy()
    return cls_embedding

## Generate Text Embeddings

In [19]:
# rename text column to 'tokenized_text'

if EMBEDDING_METHOD == "BERT":
    df['embedding'] = df['text'].apply(get_bert_embedding)
elif EMBEDDING_METHOD == "Glove":
    df['embedding'] = df['text'].apply(lambda tokens: get_average_embedding(tokens, glove))
elif EMBEDDING_METHOD == "BERT_Sentence":
    df['embedding'] = df['text'].apply(get_bert_sentence_embedding)


df.shape

(23996, 18)

## Filter Out Tweets with 'Unknown' Labels

In [20]:

# Convert label column to integer type
df['label'] = df['label'].astype(int)
print(df['label'].value_counts())
plot_updated_embeddings(df, color_column=PLOT_COLOUR, n=NUM_POINTS, name='before_gat')

label
0    15320
1     8676
Name: count, dtype: int64



'n_iter' was renamed to 'max_iter' in version 1.5 and will be removed in 1.7.


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.



Time taken for 3000 points: 3.31 seconds


## Map Tweet IDs and User IDs to Node Indices

In [21]:
# Create a mapping from tweet IDs to indices
tweet_id_to_idx = {tweet_id: idx for idx, tweet_id in enumerate(df['tweet_id'].unique())}
df['tweet_idx'] = df['tweet_id'].map(tweet_id_to_idx)

# Create a mapping from user IDs to indices
user_id_to_idx = {user_id: idx for idx, user_id in enumerate(df['user_id'].unique())}
df['user_idx'] = df['user_id'].map(user_id_to_idx)


## Build the Heterogeneous Graph

In [22]:
# Initialize a HeteroData object
data = HeteroData()

# Add tweet nodes with their embeddings
data['tweet'].x = torch.tensor(np.stack(df['embedding'].values), dtype=torch.float)

# Add user nodes (initialize with zero features)
num_users = len(user_id_to_idx)
data['user'].num_nodes = num_users
data['user'].x = torch.zeros(num_users, data['tweet'].x.size(1))

# Add 'writes' edges from users to tweets
data['user', 'writes', 'tweet'].edge_index = torch.tensor([
    df['user_idx'].values,
    df['tweet_idx'].values
], dtype=torch.long)

# Prepare the 'reply_to' edges
reply_df = df[df['reply_to_tweet_id'].notnull()].copy()
reply_df['reply_to_tweet_idx'] = reply_df['reply_to_tweet_id'].map(tweet_id_to_idx)
reply_df = reply_df[reply_df['reply_to_tweet_idx'].notnull()].astype({'reply_to_tweet_idx': int})

# Add 'reply_to' edges from tweets to tweets
data['tweet', 'reply_to', 'tweet'].edge_index = torch.tensor([
    reply_df['tweet_idx'].values,
    reply_df['reply_to_tweet_idx'].values
], dtype=torch.long)


## Plotly for heterogeneous graph

In [23]:
import networkx as nx
import plotly.graph_objs as go
from plotly.offline import plot

def visualize_original_graph_interactive(data, edge_type=('user', 'writes', 'tweet'), node_type='tweet'):
    """
    Visualize the original graph interactively using Plotly with edge thickness and node sizes.

    Args:
        data (HeteroData): The heterogeneous graph data.
        edge_type (tuple): The edge type to visualize.
        node_type (str): The node type to set sizes.
    """
    # Create a NetworkX graph
    G = nx.DiGraph()

    # Add nodes
    tweet_indices = data['tweet'].num_nodes
    user_indices = data['user'].num_nodes
    # Assign unique identifiers for nodes to distinguish user and tweet nodes
    G.add_nodes_from([f"user_{i}" for i in range(user_indices)], type='user')
    G.add_nodes_from([f"tweet_{i}" for i in range(tweet_indices)], type='tweet')

    # Add edges
    src, dst = data[edge_type].edge_index
    edges = [(f"user_{s.item()}", f"tweet_{d.item()}") for s, d in zip(src, dst)]
    G.add_edges_from(edges)

    # Compute edge widths (uniform if no weights are available)
    edge_widths = [1 for _ in G.edges()]

    # Compute node sizes based on degree (as a proxy for importance)
    degree_dict = dict(G.degree())
    node_sizes = [degree_dict[node] * 10 for node in G.nodes()]  # Adjust scaling as needed

    # Define layout using spring layout for better visualization
    pos = nx.spring_layout(G, k=0.15, iterations=20, seed=42)

    # Extract node positions
    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

    # Extract edge positions
    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    # Create edge trace
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Create node trace
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=[],  # To be filled with node degrees
            size=node_sizes,
            colorbar=dict(
                thickness=15,
                title='Node Degree',
                xanchor='left',
                titleside='right'
            ),
            line_width=2
        )
    )

    # Assign colors based on node degree
    node_adjacencies = [degree_dict[node] for node in G.nodes()]
    node_trace.marker.color = node_adjacencies

    # Create hover text
    node_text = []
    for node in G.nodes():
        node_type_str, node_id = node.split('_')
        text = f"Type: {node_type_str}<br>ID: {node_id}<br>Degree: {degree_dict[node]}"
        node_text.append(text)
    node_trace.text = node_text

    # Create the figure
    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        title='<br>Original Graph',
                        titlefont_size=16,
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20,l=5,r=5,t=40),
                        annotations=[ dict(
                            text="Interactive Network Graph",
                            showarrow=False,
                            xref="paper", yref="paper") ],
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                   )

    # Optionally, save the plot as an HTML file
    plot(fig, filename='original_graph_interactive.html')
    # Display the figure
    fig.show()

visualize_original_graph_interactive(data)

KeyboardInterrupt: 

## Assign Labels and Create Data Splits

In [24]:
# Assign labels to tweet nodes
data['tweet'].y = torch.tensor(df['label'].values, dtype=torch.long)

# Split indices for training, validation, and testing
train_idx, test_idx = train_test_split(
    df.index, test_size=0.2, random_state=42, stratify=df['label']
)
train_idx, val_idx = train_test_split(
    train_idx, test_size=0.25, random_state=42, stratify=df.loc[train_idx, 'label']
)

# Create boolean masks for the splits
def create_mask(idx, size):
    mask = torch.zeros(size, dtype=torch.bool)
    mask[idx] = True
    return mask

data['tweet'].train_mask = create_mask(train_idx, len(df))
data['tweet'].val_mask = create_mask(val_idx, len(df))
data['tweet'].test_mask = create_mask(test_idx, len(df))


## Define the Heterogeneous GAT Model

In [33]:
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GATv2Conv, HeteroConv
import torch.nn as nn

class HeteroGATWithAttention(nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, heads=2, conv_type='GAT'):
        """
        Initializes the HeteroGAT model with attention weight storage.

        Args:
            metadata: Metadata for the heterogeneous graph.
            hidden_channels (int): Number of hidden units.
            out_channels (int): Number of output units.
            heads (int, optional): Number of attention heads. Defaults to 2.
            conv_type (str, optional): Type of GAT convolution ('GAT' or 'GATv2'). Defaults to 'GAT'.
        """
        super(HeteroGATWithAttention, self).__init__()
        self.hidden_channels = hidden_channels
        self.heads = heads
        self.conv_type = conv_type

        # Select the convolution class based on conv_type
        if conv_type == 'GAT':
            ConvLayer = GATConv
        elif conv_type == 'GATv2':
            ConvLayer = GATv2Conv
        else:
            raise ValueError(f"Unsupported conv_type '{conv_type}'. Use 'GAT' or 'GATv2'.")

        # Define GAT layers for each relation type using the selected ConvLayer
        self.conv1 = HeteroConv({
            ('tweet', 'reply_to', 'tweet'): ConvLayer((-1, -1), hidden_channels, heads=heads, add_self_loops=True),
            ('user', 'writes', 'tweet'): ConvLayer((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
        }, aggr='sum')

        self.conv2 = HeteroConv({
            ('tweet', 'reply_to', 'tweet'): ConvLayer((-1, -1), hidden_channels, heads=heads, add_self_loops=True),
            ('user', 'writes', 'tweet'): ConvLayer((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
        }, aggr='sum')

        # Linear layer for output (adjusted input dimension)
        self.lin = nn.Linear(hidden_channels * heads, out_channels)

        # To store attention weights
        self.attention_weights = {}

    def forward(self, x_dict, edge_index_dict):
        """
        Forward pass of the HeteroGAT model. Stores attention weights.

        Args:
            x_dict (dict): Node feature dictionary.
            edge_index_dict (dict): Edge index dictionary.

        Returns:
            torch.Tensor: Output for 'tweet' nodes.
        """
        # First GAT layer
        x_dict_updated, attn_weights_1 = self.conv1(x_dict, edge_index_dict)
        # Store attention weights
        self.attention_weights['layer1'] = attn_weights_1

        # Merge updated features with original features
        x_dict = {
            key: F.elu(x_dict_updated[key][0].view(x_dict_updated[key][0].size(0), -1)) if key in x_dict_updated else x_dict[key]
            for key in x_dict.keys()
        }

        # Second GAT layer
        x_dict_updated, attn_weights_2 = self.conv2(x_dict, edge_index_dict)
        # Store attention weights
        self.attention_weights['layer2'] = attn_weights_2

        x_dict = {
            key: F.elu(x_dict_updated[key][0].view(x_dict_updated[key][0].size(0), -1)) if key in x_dict_updated else x_dict[key]
            for key in x_dict.keys()
        }

        # Output layer for tweet nodes
        out = self.lin(x_dict['tweet'])
        return out


## Initialize the Model, Optimizer, and Loss Function

In [34]:
#GAT
#---------------------------------------------------------------
# Get metadata from the data object
metadata = data.metadata()

# Instantiate the model
model = HeteroGATWithAttention(metadata, hidden_channels=64, out_channels=2, conv_type='GAT')

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# Define the loss function
criterion = nn.CrossEntropyLoss()
#---------------------------------------------------------------

#GATV2
#---------------------------------------------------------------
# Get metadata from the data object
metadata_V2 = data.metadata()

# Instantiate the model
model_V2 = HeteroGATWithAttention(metadata, hidden_channels=64, out_channels=2, conv_type='GATv2')

# Define the optimizer
optimizer_V2 = torch.optim.Adam(model_V2.parameters(),  lr=LR, weight_decay=WEIGHT_DECAY)

# Define the loss function
criterion_V2 = nn.CrossEntropyLoss()
#---------------------------------------------------------------


There exist node types ({'user'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behavior.



## Define Training and Evaluation Functions 

In [35]:
# GAT
#--------------------------------------------------------------- 
# Training function
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    out = out[data['tweet'].train_mask]
    y = data['tweet'].y[data['tweet'].train_mask]
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()
    return loss.item()

# Evaluation function
@torch.no_grad()
def evaluate(mask):
    model.eval()
    out = model(data.x_dict, data.edge_index_dict)
    pred = out.argmax(dim=1)
    correct = pred[mask] == data['tweet'].y[mask]
    accuracy = correct.sum().item() / mask.sum().item()
    return accuracy
#---------------------------------------------------------------

# GATV2
#---------------------------------------------------------------
# Training function
def train_V2():
    model_V2.train()
    optimizer_V2.zero_grad()
    out = model_V2(data.x_dict, data.edge_index_dict)
    out = out[data['tweet'].train_mask]
    y = data['tweet'].y[data['tweet'].train_mask]
    loss = criterion_V2(out, y)
    loss.backward()
    optimizer_V2.step()
    return loss.item()

# Evaluation function
@torch.no_grad()
def evaluate_V2(mask):
    model_V2.eval()
    out = model_V2(data.x_dict, data.edge_index_dict)
    pred = out.argmax(dim=1)
    correct = pred[mask] == data['tweet'].y[mask]
    accuracy = correct.sum().item() / mask.sum().item()
    return accuracy
#---------------------------------------------------------------


## Train the Model

In [36]:
import matplotlib.pyplot as plt

# Training and Plotting

# ---------------------------------------------------------------
# Define the number of epochs
epochs = EPOCHS

# Initialize lists to store accuracy values for GAT
train_acc_list = []
val_acc_list = []

# Initialize lists to store accuracy values for GATv2
train_acc_list_V2 = []
val_acc_list_V2 = []

# Training Loop for GAT
print("Training GAT Model:")
for epoch in range(1, epochs + 1):
    loss = train()
    train_acc = evaluate(data['tweet'].train_mask)
    val_acc = evaluate(data['tweet'].val_mask)
    
    # Append accuracies for plotting later
    train_acc_list.append(train_acc)
    val_acc_list.append(val_acc)
    
    print(f'GAT Epoch {epoch:02d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

# Training Loop for GATv2
print("\nTraining GATv2 Model:")
for epoch in range(1, epochs + 1):
    loss = train_V2()
    train_acc = evaluate_V2(data['tweet'].train_mask)
    val_acc = evaluate_V2(data['tweet'].val_mask)
    
    # Append accuracies for plotting later
    train_acc_list_V2.append(train_acc)
    val_acc_list_V2.append(val_acc)
    
    print(f'GATv2 Epoch {epoch:02d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

# ---------------------------------------------------------------
# Plotting the Accuracies Side by Side
# ---------------------------------------------------------------
# Create a figure with two subplots side by side
fig, axes = plt.subplots(1, 2, figsize=(20, 6), sharey=True)

# -----------------------
# Plot GAT Accuracies
# -----------------------
axes[0].plot(range(1, epochs + 1), train_acc_list, label='Train Accuracy', color='blue')
axes[0].plot(range(1, epochs + 1), val_acc_list, label='Validation Accuracy', linestyle='--', color='orange')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('GAT: Train and Validation Accuracy over Epochs')
axes[0].legend()
axes[0].grid(True)

# -----------------------
# Plot GATv2 Accuracies
# -----------------------
axes[1].plot(range(1, epochs + 1), train_acc_list_V2, label='Train Accuracy', color='green')
axes[1].plot(range(1, epochs + 1), val_acc_list_V2, label='Validation Accuracy', linestyle='--', color='red')
axes[1].set_xlabel('Epoch')
# Y-label is shared; no need to set for the second subplot
axes[1].set_title('GATv2: Train and Validation Accuracy over Epochs')
axes[1].legend()
axes[1].grid(True)

# Adjust layout for better spacing
plt.tight_layout()

# Display the plots
plt.show()

Training GAT Model:


ValueError: not enough values to unpack (expected 2, got 1)

## Retrieve embeddings

In [141]:
#GAT
#---------------------------------------------------------------
@torch.no_grad()
def get_tweet_embeddings(model, data):
    model.eval()
    x_dict = data.x_dict
    edge_index_dict = data.edge_index_dict

    # First GAT layer
    x_dict_updated = model.conv1(x_dict, edge_index_dict)
    x_dict = {
        key: F.elu(x_dict_updated[key]) if key in x_dict_updated else x_dict[key]
        for key in x_dict.keys()
    }

    # Second GAT layer
    x_dict_updated = model.conv2(x_dict, edge_index_dict)
    x_dict = {
        key: F.elu(x_dict_updated[key]) if key in x_dict_updated else x_dict[key]
        for key in x_dict.keys()
    }

    embeddings = x_dict['tweet']
    return embeddings


# Step 1: Extract embeddings
embeddings = get_tweet_embeddings(model, data)
embeddings_np = embeddings.cpu().numpy()

# Step 2: Prepare DataFrame with embeddings
df_embeddings = df.copy()
df_embeddings['embedding'] = list(embeddings_np)

# Step 3: Plot the embeddings
plot_updated_embeddings(df_embeddings, color_column=PLOT_COLOUR, n=NUM_POINTS,name='after_gat')
#---------------------------------------------------------------

#GATV2
#---------------------------------------------------------------
@torch.no_grad()
def get_tweet_embeddings_V2(model, data):
    model.eval()
    x_dict = data.x_dict
    edge_index_dict = data.edge_index_dict

    # First GAT layer
    x_dict_updated = model.conv1(x_dict, edge_index_dict)
    x_dict = {
        key: F.elu(x_dict_updated[key]) if key in x_dict_updated else x_dict[key]
        for key in x_dict.keys()
    }

    # Second GAT layer
    x_dict_updated = model.conv2(x_dict, edge_index_dict)
    x_dict = {
        key: F.elu(x_dict_updated[key]) if key in x_dict_updated else x_dict[key]
        for key in x_dict.keys()
    }

    embeddings = x_dict['tweet']
    return embeddings


# Step 1: Extract embeddings
embeddings_V2 = get_tweet_embeddings_V2(model_V2, data)
embeddings_np_V2 = embeddings_V2.cpu().numpy()

# Step 2: Prepare DataFrame with embeddings
df_embeddings_V2 = df.copy()
df_embeddings_V2['embedding'] = list(embeddings_np_V2)

# Step 3: Plot the embeddings
plot_updated_embeddings(df_embeddings_V2, color_column=PLOT_COLOUR, n=NUM_POINTS,name='after_gatv2')
#---------------------------------------------------------------




'n_iter' was renamed to 'max_iter' in version 1.5 and will be removed in 1.7.


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.



Time taken for 800 points: 0.82 seconds



'n_iter' was renamed to 'max_iter' in version 1.5 and will be removed in 1.7.


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.



Time taken for 800 points: 1.00 seconds


In [143]:
visualize_original_graph_interactive(data)

In [None]:
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GATv2Conv, HeteroConv
import torch.nn as nn
from torch_geometric.utils import remove_self_loops, to_networkx
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import plotly.graph_objs as go
from plotly.offline import plot
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict

# ============================
# Your Existing Model Code
# ============================

# Define GATConvWithAttention
class GATConvWithAttention(GATConv):
    def __init__(self, in_channels, out_channels, heads=1, concat=True, **kwargs):
        super(GATConvWithAttention, self).__init__(
            in_channels, out_channels, heads=heads, concat=concat, **kwargs
        )
        self.attentions = []
        self.dst_nodes = []

    def message(self, **kwargs):
        # Extract necessary components from kwargs
        alpha = kwargs.get('alpha')  # Attention coefficients
        edge_index_i = kwargs.get('edge_index_i')  # Destination node indices

        if alpha is not None and edge_index_i is not None:
            self.attentions.append(alpha)
            self.dst_nodes.append(edge_index_i)

        # Call the parent class's message method
        return super(GATConvWithAttention, self).message(**kwargs)

# Define HeteroGATWithAttention
class HeteroGATWithAttention(nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, heads=2, conv_type='GAT'):
        super(HeteroGATWithAttention, self).__init__()
        self.hidden_channels = hidden_channels
        self.heads = heads
        self.conv_type = conv_type

        # Select the convolution class based on conv_type
        if conv_type == 'GAT':
            ConvLayer = GATConvWithAttention  # Use the custom GATConvWithAttention
        elif conv_type == 'GATv2':
            ConvLayer = GATv2Conv  # Placeholder: Implement a similar custom class if needed
        else:
            raise ValueError(f"Unsupported conv_type '{conv_type}'. Use 'GAT' or 'GATv2'.")

        # Define GAT layers for each relation type using the selected ConvLayer
        self.conv1 = HeteroConv({
            ('tweet', 'reply_to', 'tweet'): ConvLayer(
                (-1, -1), hidden_channels, heads=heads, add_self_loops=True
            ),
            ('user', 'writes', 'tweet'): ConvLayer(
                (-1, -1), hidden_channels, heads=heads, add_self_loops=False
            ),
        }, aggr='sum')

        self.conv2 = HeteroConv({
            ('tweet', 'reply_to', 'tweet'): ConvLayer(
                (-1, -1), hidden_channels, heads=heads, add_self_loops=True
            ),
            ('user', 'writes', 'tweet'): ConvLayer(
                (-1, -1), hidden_channels, heads=heads, add_self_loops=False
            ),
        }, aggr='sum')

        # Linear layer for output (adjusted input dimension)
        self.lin = nn.Linear(hidden_channels * heads, out_channels)

    def forward(self, x_dict, edge_index_dict):
        # Reset attention weights before forward pass
        for conv in self.conv1.convs.values():
            if isinstance(conv, GATConvWithAttention):
                conv.attentions = []
                conv.dst_nodes = []
        for conv in self.conv2.convs.values():
            if isinstance(conv, GATConvWithAttention):
                conv.attentions = []
                conv.dst_nodes = []

        # First GAT layer
        x_dict_updated = self.conv1(x_dict, edge_index_dict)
        x_dict = {
            key: F.elu(x_dict_updated[key]) if key in x_dict_updated else x_dict[key]
            for key in x_dict.keys()
        }

        # Second GAT layer
        x_dict_updated = self.conv2(x_dict, edge_index_dict)
        x_dict = {
            key: F.elu(x_dict_updated[key]) if key in x_dict_updated else x_dict[key]
            for key in x_dict.keys()
        }

        # Output layer for tweet nodes
        out = self.lin(x_dict['tweet'])
        return out

    def get_attention_weights(self):
        attention_weights = {}
        # Collect attention weights from conv1
        for edge_type, conv in self.conv1.convs.items():
            if isinstance(conv, GATConvWithAttention):
                attention_weights.setdefault('layer1', {})
                attention_weights['layer1'][edge_type] = (torch.cat(conv.attentions, dim=0), torch.cat(conv.dst_nodes, dim=0))
        # Collect attention weights from conv2
        for edge_type, conv in self.conv2.convs.items():
            if isinstance(conv, GATConvWithAttention):
                attention_weights.setdefault('layer2', {})
                attention_weights['layer2'][edge_type] = (torch.cat(conv.attentions, dim=0), torch.cat(conv.dst_nodes, dim=0))
        return attention_weights

# Initialize the HeteroGATWithAttention model
# Note: Ensure that 'metadata', 'LR', 'WEIGHT_DECAY', 'EPOCHS', and 'data' are defined appropriately
model_attn = HeteroGATWithAttention(metadata, hidden_channels=64, out_channels=2, conv_type='GAT')

# Define the optimizer and loss function
optimizer_attn = torch.optim.Adam(model_attn.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
criterion_attn = nn.CrossEntropyLoss()

# Define training and evaluation functions
def train_attn():
    model_attn.train()
    optimizer_attn.zero_grad()
    out = model_attn(data.x_dict, data.edge_index_dict)
    out = out[data['tweet'].train_mask]
    y = data['tweet'].y[data['tweet'].train_mask]
    loss = criterion_attn(out, y)
    loss.backward()
    optimizer_attn.step()
    return loss.item()

@torch.no_grad()
def evaluate_attn(mask):
    model_attn.eval()
    out = model_attn(data.x_dict, data.edge_index_dict)
    pred = out.argmax(dim=1)
    correct = pred[mask] == data['tweet'].y[mask]
    accuracy = correct.sum().item() / mask.sum().item()
    return accuracy

# Training Loop
epochs = EPOCHS
train_acc_list_attn = []
val_acc_list_attn = []

print("Training GAT Model with Attention Extraction:")
for epoch in range(1, epochs + 1):
    loss = train_attn()
    train_acc = evaluate_attn(data['tweet'].train_mask)
    val_acc = evaluate_attn(data['tweet'].val_mask)
    
    # Append accuracies for plotting later
    train_acc_list_attn.append(train_acc)
    val_acc_list_attn.append(val_acc)
    
    # Print progress every 100 epochs and the first epoch
    if epoch % 100 == 0 or epoch == 1:
        print(f'GAT Attn Epoch {epoch:02d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

# ============================
# Visualization Code
# ============================

# Ensure the model is in evaluation mode
model_attn.eval()

# Extract attention weights
attention_weights = model_attn.get_attention_weights()

# Convert PyTorch Geometric data to NetworkX graph
# We'll manually construct the graph to include attention weights
G = nx.Graph()

# Add all nodes from the 'tweet' and 'user' node types
# Assuming node indices are unique across types or appropriately mapped
# Adjust this part if node indices overlap or need specific handling
for node_type in data.x_dict:
    num_nodes = data.x_dict[node_type].size(0)
    G.add_nodes_from([f"{node_type}_{i}" for i in range(num_nodes)], node_type=node_type)

# Initialize a dictionary to hold edge weights
edge_weights = defaultdict(list)

# Iterate through each layer and edge type to collect attention weights
for layer, edge_types in attention_weights.items():
    for edge_type, (alphas, dst_nodes_attn) in edge_types.items():
        src_node_type, relation, dst_node_type = edge_type
        edge_index = data.edge_index_dict[edge_type].cpu().numpy()  # Shape: [2, num_edges]
        src_nodes = edge_index[0]
        dst_nodes = edge_index[1]
        alphas = alphas.cpu().numpy()
        dst_nodes_attn = dst_nodes_attn.cpu().numpy()

        # Assuming the order of alphas corresponds to the order of edges in edge_index
        for src, dst, alpha in zip(src_nodes, dst_nodes, alphas):
            # Create unique node identifiers
            src_node = f"{src_node_type}_{src}"
            dst_node = f"{dst_node_type}_{dst}"
            # Use a tuple of node identifiers as the edge key
            edge_key = (src_node, dst_node)
            edge_weights[edge_key].append(alpha)

# Assign average attention weights to edges
for edge, alphas in edge_weights.items():
    avg_alpha = np.mean(alphas)
    G.add_edge(edge[0], edge[1], weight=avg_alpha)

# Normalize edge weights for visualization (thickness)
weights = [G[u][v]['weight'] for u, v in G.edges()]
if weights:
    max_weight = max(weights)
    min_weight = min(weights)
else:
    max_weight = 1
    min_weight = 0

# Define edge thickness range
min_thickness = 1
max_thickness = 10

# Normalize weights to thickness
normalized_weights = [
    min_thickness + (w - min_weight) / (max_weight - min_weight) * (max_thickness - min_thickness)
    if max_weight != min_weight else min_thickness
    for w in weights
]

# Assign normalized thickness to edges
for (edge, thickness) in zip(G.edges(), normalized_weights):
    G.edges[edge]['thickness'] = thickness

# Determine node sizes based on the sum of incoming attention weights
node_sizes = {}
for node in G.nodes():
    incoming_weights = [G[u][node]['weight'] for u in G.predecessors(node)] if G.is_directed() else [G[u][node]['weight'] for u in G.neighbors(node)]
    total_weight = sum(incoming_weights) if incoming_weights else 1
    node_sizes[node] = total_weight

# Normalize node sizes for visualization
sizes = list(node_sizes.values())
if sizes:
    max_size = max(sizes)
    min_size = min(sizes)
else:
    max_size = 1
    min_size = 0

# Define node size range
min_node_size = 10
max_node_size = 50

normalized_node_sizes = [
    min_node_size + (s - min_size) / (max_size - min_size) * (max_node_size - min_node_size)
    if max_size != min_size else min_node_size
    for s in sizes
]

# Assign normalized sizes to nodes
for node, size in zip(G.nodes(), normalized_node_sizes):
    G.nodes[node]['size'] = size

# Create positions for all nodes using a layout algorithm
pos = nx.spring_layout(G, k=0.15, iterations=20, seed=42)  # Seed for reproducibility

# Prepare edge traces grouped by thickness to minimize the number of Plotly traces
thickness_to_edges = defaultdict(list)
for edge in G.edges(data=True):
    thickness = edge[2]['thickness']
    thickness_to_edges[thickness].append(edge)

# Create Plotly edge traces
edge_traces = []
for thickness, edges in thickness_to_edges.items():
    edge_x = []
    edge_y = []
    for edge in edges:
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
    trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=thickness, color='#888'),
        hoverinfo='none',
        mode='lines'
    )
    edge_traces.append(trace)

# Extract node information
node_x = []
node_y = []
node_size = []
node_color = []
node_text = []

for node in G.nodes(data=True):
    x, y = pos[node[0]]
    node_x.append(x)
    node_y.append(y)
    node_size.append(node[1]['size'])
    node_color.append(node[1]['size'])
    node_text.append(f'Node: {node[0]}<br>Size: {node[1]["size"]:.2f}')

# Create node trace
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers',
    hoverinfo='text',
    marker=dict(
        showscale=True,
        colorscale='YlGnBu',
        color=node_color,
        size=node_size,
        colorbar=dict(
            thickness=15,
            title='Node Size',
            xanchor='left',
            titleside='right'
        ),
        line_width=2
    ),
    text=node_text
)

# Create the figure
fig = go.Figure()

# Add all edge traces
for trace in edge_traces:
    fig.add_trace(trace)

# Add node trace
fig.add_trace(node_trace)

# Update layout
fig.update_layout(
    title='Network Graph with Attention-based Edge Thickness and Node Sizes',
    titlefont_size=16,
    showlegend=False,
    hovermode='closest',
    margin=dict(b=20, l=5, r=5, t=40),
    annotations=[
        dict(
            text="",
            showarrow=False,
            xref="paper", yref="paper"
        )
    ],
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
)

# Display the plot
fig.show()

# Optional: Save the plot as an HTML file
# plot(fig, filename='network_graph.html')



There exist node types ({'user'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behavior.



Training GAT Model with Attention Extraction:
GAT Attn Epoch 01, Loss: 0.6816, Train Acc: 0.5800, Val Acc: 0.5749
GAT Attn Epoch 100, Loss: 0.3938, Train Acc: 0.8500, Val Acc: 0.7305
GAT Attn Epoch 200, Loss: 0.1283, Train Acc: 0.9800, Val Acc: 0.7126


RuntimeError: torch.cat(): expected a non-empty list of Tensors

## Test the model

In [112]:
#GAT
#---------------------------------------------------------------
# Evaluate the model on the test set
test_acc = evaluate(data['tweet'].test_mask)
print(f'Test Accuracy for base GAT: {test_acc:.4f}')
#------------------------------------------------

#GATV2
#---------------------------------------------------------------
# Evaluate the model on the test set
test_acc_V2 = evaluate_V2(data['tweet'].test_mask)
print(f'Test Accuracy for GATv2: {test_acc_V2:.4f}')
#---------------------------------------------------------------

Test Accuracy for base GAT: 0.7305
Test Accuracy for GATv2: 0.6826
