In [2]:
from dataloader import DataLoader

loader = DataLoader(hs_code=282520)

_initialize_data took 110.1829 seconds


In [3]:
df = loader.get_data()

In [35]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
from sklearn.preprocessing import StandardScaler
import hdbscan
import umap
import math

#####################################
# Helper Functions
#####################################
def filter_top_countries(df: pd.DataFrame, top_percent: float = 0.4):
    """
    Filter the dataset to only include countries that are in the top `top_percent`
    in terms of total trade volume.

    Parameters
    ----------
    df : pd.DataFrame
        The input dataframe with 'export_country', 'import_country', and 'v'.
    top_percent : float
        Fraction of countries to keep, based on total trade volume. 0.5 means top 50%.

    Returns
    -------
    df_filtered : pd.DataFrame
        The filtered dataframe containing only the top-percent countries.
    """
    exp_sum = df.groupby('export_country')['v'].sum()
    imp_sum = df.groupby('import_country')['v'].sum()

    total_trade = exp_sum.add(imp_sum, fill_value=0).sort_values(ascending=False)

    cutoff_index = int(len(total_trade)*top_percent)
    cutoff_countries = total_trade.index[:cutoff_index]

    df_filtered = df[df['export_country'].isin(cutoff_countries) & df['import_country'].isin(cutoff_countries)].copy()
    return df_filtered

#####################################
# Step 1: Data Preparation
#####################################
def build_snapshots(
    df: pd.DataFrame,
    node_features = ['gdpcap_o','gdpcap_d','wto_o','wto_d','eu_o','eu_d','pop_o','pop_d'],
    edge_features = [
        'v','comlang_off','comlang_ethno','comcol','col45',
        'comleg_pretrans','comleg_posttrans','col_dep_ever',
        'empire','sibling_ever','scaled_sci_2021','comrelig','distw_harmonic'
    ],
    device: torch.device = torch.device('cpu')
):
    """
    Convert df into a list of PyG Data objects, one per snapshot (year).
    Focuses on the specified node and edge features only, after filtering top countries.

    Parameters
    ----------
    df : pd.DataFrame
        The merged dataset containing trade, gravity, and additional attributes.
    node_features : list of str
        Node-level features to use (aggregated by country).
    edge_features : list of str
        Edge-level features to use directly on edges.
    device : torch.device, optional
        The computation device (CPU or GPU).

    Returns
    -------
    snapshots : list of Data
        A list of PyTorch Geometric Data objects, one per year.
    all_countries : np.ndarray
        Array of all unique country names.
    """
    df = filter_top_countries(df, top_percent=0.5)

    years = sorted(df['t'].unique())
    all_countries = np.union1d(df['export_country'].unique(), df['import_country'].unique())
    country_to_id = {c: i for i, c in enumerate(all_countries)}
    num_nodes = len(all_countries)
    
    for col in node_features+edge_features:
        df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)

    df['yearly_total_v'] = df.groupby('t')['v'].transform('sum')
    df['v'] = df['v'] / (df['yearly_total_v'] + 1e-9)

    node_df = df.groupby('export_country')['gdpcap_d'].mean().rename('avg_gdpcap_exp').to_frame()
    node_df['avg_gdpcap_imp'] = df.groupby('import_country')['gdpcap_o'].mean()
    node_df['wto_member'] = df[['export_country','wto_o']].drop_duplicates().groupby('export_country')['wto_o'].max()
    node_df['eu_member'] = df[['export_country','eu_o']].drop_duplicates().groupby('export_country')['eu_o'].max()
    node_df['avg_pop'] = df.groupby('export_country')['pop_o'].mean()
    node_df = node_df.fillna(0)
    
    for c in all_countries:
        if c not in node_df.index:
            node_df.loc[c] = [0,0,0,0,0]

    node_mat = node_df[['avg_gdpcap_exp','avg_gdpcap_imp','wto_member','eu_member','avg_pop']].values
    scaler_node = StandardScaler()
    node_mat = scaler_node.fit_transform(node_mat)

    snapshots = []
    scaler_edge = StandardScaler()

    edge_mat_all = df[edge_features].values
    edge_mat_all = scaler_edge.fit_transform(edge_mat_all)

    for year in years:
        sub = df[df['t']==year].copy()
        sub['src'] = sub['export_country'].map(country_to_id)
        sub['dst'] = sub['import_country'].map(country_to_id)

        sub_edge_mat = sub[edge_features].values
        sub_edge_mat = scaler_edge.transform(sub_edge_mat)

        edge_index = torch.tensor([sub['src'].values, sub['dst'].values], dtype=torch.long)
        edge_attr = torch.tensor(sub_edge_mat, dtype=torch.float)
        
        x = torch.tensor(node_mat, dtype=torch.float)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        data = data.to(device)
        snapshots.append(data)

    return snapshots, all_countries

#####################################
# Dynamic GNN Autoencoder
#####################################
class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_dim=64, out_channels=32, dropout=0.1):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_channels)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return x

class DynamicGNNAutoencoder(nn.Module):
    def __init__(self, num_nodes, in_channels, hidden_dim=64, out_channels=32, edge_dim=13, dropout=0.4):
        super().__init__()
        self.num_nodes = num_nodes
        self.encoder = GCNEncoder(in_channels, hidden_dim, out_channels, dropout)
        self.rnn = nn.GRU(out_channels, out_channels, batch_first=False)
        self.out_dim = out_channels

        self.mlp = nn.Sequential(
            nn.Linear(2*out_channels, 64),
            nn.ReLU(),
            nn.Linear(64, edge_dim+1)
        )

        self.dropout = nn.Dropout(dropout)

    def decode(self, z, edge_index):
        u, v = edge_index
        z_uv = torch.cat([z[u], z[v]], dim=-1)
        out = self.mlp(z_uv)
        return out

    def forward(self, x, edge_index):
        z = self.encoder(x, edge_index)
        return z

    def forward_timestep(self, data, h):
        z_curr = self.encoder(data.x, data.edge_index)
        z_curr = z_curr.unsqueeze(0)
        if h is None:
            h = z_curr
        else:
            h, _ = self.rnn(z_curr, h)
        return h

    def compute_loss(self, z, data):
        z = z.squeeze(0)
        pos_edges = data.edge_index
        pos_out = self.decode(z, pos_edges)
        pos_labels = torch.cat([torch.ones(pos_out.size(0),1,device=z.device), data.edge_attr], dim=-1)

        neg_edge_index = negative_sampling(pos_edges, num_nodes=self.num_nodes, num_neg_samples=pos_edges.size(1), method='dense')
        neg_out = self.decode(z, neg_edge_index)
        neg_labels = torch.cat([torch.zeros(neg_out.size(0),1,device=z.device), 
                                torch.zeros((neg_out.size(0), data.edge_attr.size(1)),device=z.device)], dim=-1)

        out = torch.cat([pos_out, neg_out], dim=0)
        labels = torch.cat([pos_labels, neg_labels], dim=0)

        adj_pred = out[:,0]
        adj_true = labels[:,0]
        adj_loss = F.binary_cross_entropy_with_logits(adj_pred, adj_true)

        attr_pred = out[:,1:]
        attr_true = labels[:,1:]
        attr_loss = F.mse_loss(attr_pred, attr_true)

        total_loss = 0.2*adj_loss + 0.8*attr_loss
        return total_loss

#####################################
# Training the Model
#####################################
def train_dynamic_model(snapshots, num_nodes, in_channels, device=torch.device('cpu'), epochs=100, stability_weight=0.9):
    edge_dim = 13
    model = DynamicGNNAutoencoder(num_nodes, in_channels=in_channels, edge_dim=edge_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)

    for epoch in range(1, epochs+1):
        model.train()
        optimizer.zero_grad()
        loss_total = 0
        h = None
        prev_z = None
        for data in snapshots:
            h = model.forward_timestep(data, h)
            loss = model.compute_loss(h, data)

            if prev_z is not None:
                stability_loss = F.mse_loss(h, prev_z)
                loss = loss + stability_weight * stability_loss

            prev_z = h.clone().detach()
            loss_total += loss
        loss_total = loss_total / len(snapshots)
        loss_total.backward()
        optimizer.step()

        if epoch % 20 == 0:
            print(f"Epoch {epoch}, Loss: {loss_total.item():.4f}")
    return model

#####################################
# Extract Final Embeddings & Cluster
#####################################
def get_final_embeddings(model, snapshots):
    model.eval()
    h = None
    with torch.no_grad():
        for data in snapshots:
            h = model.forward_timestep(data, h)
    return h.squeeze(0).cpu().numpy()

def cluster_embeddings(emb, min_cluster_size=10):
    reducer = umap.UMAP(random_state=42)
    emb_2d = reducer.fit_transform(emb)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size)
    labels = clusterer.fit_predict(emb_2d)
    return labels, emb_2d

#####################################
# FULL PIPELINE
#####################################
def run_temporal_community_detection(df: pd.DataFrame, device: torch.device = torch.device('cpu')):
    snapshots, all_countries = build_snapshots(df, device=device)
    in_channels = snapshots[0].x.size(1)
    num_nodes = snapshots[0].x.size(0)

    model = train_dynamic_model(snapshots, num_nodes, in_channels, device=device, epochs=200, stability_weight=0.9)
    emb = get_final_embeddings(model, snapshots)
    labels, emb_2d = cluster_embeddings(emb, min_cluster_size=5)

    int_to_node = {i: c for i, c in enumerate(all_countries)}
    communities = {}
    for i, lbl in enumerate(labels):
        if lbl not in communities:
            communities[lbl] = []
        communities[lbl].append(int_to_node[i])

    return communities, labels, emb_2d

In [39]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
communities, labels, emb_2d = run_temporal_community_detection(df.to_pandas(), device=device)
for c_id, members in communities.items():
    print(f"Community {c_id}: {members}")

total_trade = df['q'].sum()

country_to_community = {}
for c_id, members in communities.items():
    for m in members:
        country_to_community[m] = c_id

community_trade = {c_id: 0.0 for c_id in communities.keys()}

for idx, row in df.to_pandas().iterrows():
    exp_country = row['export_country']
    imp_country = row['import_country']
    trade_val = row['q']

    if exp_country in country_to_community and imp_country in country_to_community:
        c_id_exp = country_to_community[exp_country]
        c_id_imp = country_to_community[imp_country]
        if c_id_exp == c_id_imp:
            community_trade[c_id_exp] += trade_val

for c_id, total_c_trade in community_trade.items():
    share = (total_c_trade / total_trade) * 100 if total_trade > 0 else 0
    print(f"Community {c_id}: {share:.2f}% of total trade volume")

Epoch 20, Loss: 0.4760
Epoch 40, Loss: 0.4660
Epoch 60, Loss: 0.4438
Epoch 80, Loss: 0.4574
Epoch 100, Loss: 0.4386
Epoch 120, Loss: 0.4370
Epoch 140, Loss: 0.4239
Epoch 160, Loss: 0.4130
Epoch 180, Loss: 0.4010
Epoch 200, Loss: 0.4050
Community 4: ['Afghanistan', 'Azerbaijan', "Côte d'Ivoire", 'Guatemala', 'Iran', 'Jordan', 'Oman', 'Philippines', 'Syria', 'Trinidad and Tobago', 'Tunisia', 'United Arab Emirates', 'United Rep. of Tanzania', 'Uzbekistan', 'Venezuela', 'Viet Nam', 'Zimbabwe']
Community -1: ['Algeria', 'Argentina', 'Austria', 'China, Hong Kong SAR', 'Croatia', 'Czechia', 'Kenya', 'Malaysia', 'Mexico', 'Nepal', 'New Zealand', 'Other Asia, nes', 'Pakistan', 'Sweden']
Community 1: ['Australia', 'Belarus', 'Belgium', 'Brazil', 'Bulgaria', 'Denmark', 'Egypt', 'France', 'Germany', 'Greece', 'Hungary', 'Israel', 'Italy', 'Lithuania', 'Morocco', 'Netherlands', 'Poland', 'Portugal', 'Romania', 'Russian Federation', 'Slovenia', 'Spain', 'Switzerland', 'Türkiye', 'Ukraine']
Community

  warn(


Community 4: 0.08% of total trade volume
Community -1: 0.00% of total trade volume
Community 1: 14.06% of total trade volume
Community 2: 0.00% of total trade volume
Community 3: 0.05% of total trade volume
Community 0: 58.31% of total trade volume
