In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GAE
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch.optim.lr_scheduler import StepLR
import optuna
from tqdm import tqdm
import numpy as np
import pandas as pd
import polars as pl
from sklearn.preprocessing import StandardScaler
import gc

  from .autonotebook import tqdm as notebook_tqdm


cpu


In [4]:
# Pick the compute device: use the GPU when PyTorch can see one, otherwise
# fall back to the CPU. The chosen name is echoed so it shows up in the output.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

## Prepare Input Data
### Node Features

In [None]:
# --- Load the raw inputs -----------------------------------------------------
# node2vec embeddings come from a .npy file; the per-aid features come from a
# parquet file and are converted to a plain NumPy array.
node2vec_embeddings = np.load("./output/node_embeddings.npy")
aid_features = pl.read_parquet("./data/aid_features.parquet").to_numpy()
# Work-around for the kernel-dead problem: the aggregated-feature join below is
# disabled for now.
# aid_features_agg = pl.read_parquet("./data/aid_features_agg.parquet")
# aid_features_all = aid_features.join(aid_features_agg, on="aid", how="inner").drop("aid").to_numpy()
# NOTE(review): the disabled join keys on (and then drops) an "aid" column, which
# suggests aid_features.parquet may still contain that id column. If so it is
# being scaled and used as a feature here -- confirm.

# --- Standardise each source with its own scaler -----------------------------
scaler_node2vec = StandardScaler()
scaler_features = StandardScaler()
scaled_node2vec_embeddings = scaler_node2vec.fit_transform(node2vec_embeddings)
scaled_aid_features_all = scaler_features.fit_transform(aid_features)

# --- Stack the two scaled blocks column-wise into one feature matrix ---------
features_and_embeddings = np.hstack([scaled_node2vec_embeddings, scaled_aid_features_all])

In [None]:
# Drop the unscaled inputs now that the scaled copies exist, then ask the
# garbage collector to reclaim the memory right away.
del node2vec_embeddings
del aid_features  # , aid_features_all, aid_features_agg
gc.collect()

### edge_index

In [None]:
# Build the weighted edge list: one edge per distinct (session, aid) pair, with
# the number of rows for that pair as its weight.
data = pd.read_parquet('./data/train.parquet')
edge_weights = data.groupby(['session', 'aid']).size().reset_index(name='weight')

# Keep the columns as NumPy arrays instead of going through `.tolist()`:
# materialising every edge as Python objects multiplies memory use and is not
# needed to build the tensors. The names are kept so later cells that index or
# iterate over them still work.
edge_list = edge_weights[['session', 'aid']].to_numpy()
edge_weights_list = edge_weights['weight'].to_numpy()

# edge_index: shape (2, num_edges), row 0 = session, row 1 = aid.
# edge_attr:  shape (num_edges, 1), the pair count as float.
edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_weights_list, dtype=torch.float).view(-1, 1)

# NOTE(review): row 0 holds raw session ids and row 1 raw aid ids, and both are
# later used as row indices into the node-feature matrix, which is built from
# per-aid features. Unless sessions and aids share one index space this graph
# is mis-indexed -- confirm the intended node numbering.

In [None]:
# Store node features as float32 once. The scaled NumPy matrix is float64, and
# the training step casts every batch with `.float()` anyway, so keeping
# float64 here only doubles the memory held by the graph.
node_features = torch.tensor(features_and_embeddings, dtype=torch.float)

# Fail fast with a readable error if any edge endpoint has no feature row,
# rather than letting neighbour sampling index past the end of the node table.
if edge_index.numel() > 0:
    min_node_id = int(edge_index.min())
    max_node_id = int(edge_index.max())
    if min_node_id < 0 or max_node_id >= node_features.size(0):
        raise ValueError(
            f"edge_index references node ids in [{min_node_id}, {max_node_id}] "
            f"but only {node_features.size(0)} node feature rows are available"
        )

graph_data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
# Remember each node's global index so sampled subgraphs can be mapped back.
graph_data.n_id = torch.arange(graph_data.num_nodes)

## Neighbor Loader

In [None]:
# Mini-batch loader over `graph_data`: two sampling hops with up to 10
# neighbours each, 512 seed nodes per batch.
# NOTE(review): no `shuffle` argument is passed, so batches presumably come in
# the same node order every epoch -- confirm that is intended for training.
gSAGE_loader = NeighborLoader(
    data=graph_data,
    num_neighbors=[10, 10],
    batch_size=512,
)

## GraphSAGE

In [None]:
# use attr GraphSAGE
# ensemble GAE

class WeightedSAGEConv(SAGEConv):
    """SAGEConv whose neighbour messages can be scaled by a per-edge weight.

    ``SAGEConv.forward`` takes ``(x, edge_index, size)`` and has no edge-weight
    argument, so the weight must not be passed to it positionally (it would be
    read as ``size``). Instead the weight is stashed on the instance for the
    duration of one ``forward`` call and applied in ``message``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels, **kwargs)
        # Scratch slot read by `message`; only non-None while `forward` runs.
        self._edge_weight = None

    def forward(self, x, edge_index, edge_weight=None, size=None):
        """Run the convolution, optionally weighting each edge's message.

        Args:
            x: Either a single node-feature tensor or a ``(source, target)``
                pair of feature tensors.
            edge_index: Edge connectivity passed through to ``SAGEConv``.
            edge_weight: Optional per-edge weights with one entry per edge
                (any shape that flattens to ``num_edges``).
            size: Optional ``(num_source, num_target)`` pair passed through to
                ``SAGEConv.forward``.

        Returns:
            Whatever ``SAGEConv.forward`` returns for the target nodes.
        """
        if isinstance(x, torch.Tensor):
            # A lone tensor means source and target nodes are the same set.
            # Indexing it with [0]/[1] would pick two rows, not two node sets.
            x = (x, x)
        else:
            x = (x[0], x[1])

        self._edge_weight = edge_weight
        try:
            return super().forward(x, edge_index, size)
        finally:
            # Never let a stale weight leak into a later un-weighted call.
            self._edge_weight = None

    def message(self, x_j):
        """Return source features, scaled by the stashed edge weight if any.

        NOTE(review): this hook is only used when messages are computed per
        edge; a fused sparse-adjacency path would presumably bypass it.
        """
        edge_weight = getattr(self, '_edge_weight', None)
        if edge_weight is None:
            return x_j
        return x_j * edge_weight.view(-1, 1).to(x_j.dtype)

    
class GraphSAGE(torch.nn.Module):
    """Stack of ``WeightedSAGEConv`` layers producing per-node outputs.

    Layout: one input layer, ``num_layers - 2`` hidden layers and one output
    layer, with ReLU + dropout (p=0.5) after every layer except the last.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the intermediate layers.
        out_channels: Size of each output vector.
        num_layers: Total number of convolution layers; must be at least 2
            because an input and an output layer are always created.

    Raises:
        ValueError: If ``num_layers`` is smaller than 2.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        if num_layers < 2:
            # Two layers are always built below; a smaller value would make
            # `self.num_layers` disagree with the real depth.
            raise ValueError(f"num_layers must be >= 2, got {num_layers}")

        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()

        # Input layer
        self.convs.append(WeightedSAGEConv(in_channels, hidden_channels))

        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(WeightedSAGEConv(hidden_channels, hidden_channels))

        # Output layer
        self.convs.append(WeightedSAGEConv(hidden_channels, out_channels))

    def _post_layer(self, x, layer_idx):
        """Apply ReLU and dropout after every layer except the last one."""
        if layer_idx != self.num_layers - 1:
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
        return x

    def forward(self, x, adjs, edge_weight=None):
        """Encode node features.

        Args:
            x: Node feature tensor.
            adjs: Either
                * a single ``edge_index`` tensor for the whole (sub)graph, in
                  which case every layer is applied to all nodes in ``x``; or
                * an iterable with one ``(edge_index, edge_attr, size)`` triple
                  per layer, where ``size[1]`` is the number of target nodes
                  and the targets are the first ``size[1]`` rows of ``x``.
            edge_weight: Optional per-edge weights, used only with the
                single-tensor form; forwarded to each convolution as its
                ``edge_weight`` argument.

        Returns:
            Log-softmax over the last dimension of the final layer's output.
        """
        if isinstance(adjs, torch.Tensor):
            # Whole-subgraph path: iterating a (2, E) edge_index as if it were
            # a list of per-layer triples would fail to unpack.
            for i, conv in enumerate(self.convs):
                x = conv((x, x), adjs, edge_weight)
                x = self._post_layer(x, i)
        else:
            # Layer-wise bipartite path.
            # NOTE(review): the second tuple element is forwarded as the edge
            # weight; confirm the sampler really yields weights there and not
            # edge ids.
            for i, (edge_index, edge_attr, size) in enumerate(adjs):
                x_target = x[:size[1]]  # Target node features
                x = self.convs[i]((x, x_target), edge_index, edge_attr)
                x = self._post_layer(x, i)

        # NOTE(review): log-softmax outputs are all <= 0. When this module is
        # used as an embedding encoder (e.g. with an inner-product decoder),
        # every pairwise dot product is then non-negative -- confirm that
        # log-probabilities are really the intended output.
        return x.log_softmax(dim=-1)

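#         # This inference code was lifted from a Git repository and merged in, so as written it is not suited to 3-label classification inference.
#         # The two versions need to be compared and revised into a better one.
#         # Show them to the chat assistant one at a time, then combine the strengths of both and develop from there.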
#     def inference(self, x_all, subgraph_loader, device):
#         pbar = tqdm(total=x_all.size(0) * self.num_layers)
#         pbar.set_description('Evaluating')

#         for i in range(self.num_layers):
#             xs = []
#             for batch_size, n_id, adj in subgraph_loader:
#                 edge_index, edge_attr, size = adj.to(device)
#                 x = x_all[n_id].to(device)
#                 x_target = x[:size[1]]
#                 x = self.convs[i]((x, x_target), edge_index, edge_attr)
#                 if i != self.num_layers - 1:
#                     x = F.relu(x)
#                 xs.append(x.cpu())

#                 pbar.update(batch_size)

#             x_all = torch.cat(xs, dim=0)

#         pbar.close()
#         return x_all

In [None]:
# Encoder hyper-parameters. The input width is read from the node-feature
# matrix so it always matches the data.
num_features = graph_data.x.shape[1]
hidden_channels = 32
out_channels = 16
num_layers = 2

# Graph auto-encoder wrapping the GraphSAGE encoder, placed on the chosen device.
model = GAE(
    GraphSAGE(
        in_channels=num_features,
        hidden_channels=hidden_channels,
        out_channels=out_channels,
        num_layers=num_layers,
    )
).to(device)

# Adam with light weight decay; the step scheduler multiplies the learning
# rate by 0.1 every 2 scheduler steps.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.005, weight_decay=1e-5)
scheduler = StepLR(optimizer, step_size=2, gamma=0.1)

In [None]:
def train(loader):
    """Run one optimisation pass over ``loader``.

    Uses the notebook-level ``model``, ``optimizer`` and ``device``. For every
    sampled subgraph the node features are encoded and the reconstruction loss
    is computed with the subgraph's own edges as positive examples.

    Args:
        loader: Iterable of subgraph batches exposing ``x`` and ``edge_index``;
            must also support ``len()``.

    Returns:
        Tuple ``(mean_loss, model)``: the loss averaged over ``len(loader)``
        batches, and the (updated) notebook-level model.
    """
    # Make sure dropout is active even if the model was switched to eval mode
    # by an earlier cell.
    model.train()

    total_loss = 0
    for subgraph in tqdm(loader):
        # Move each tensor to the device once and reuse it for both the
        # encoder and the loss, instead of transferring edge_index twice.
        x = subgraph.x.float().to(device)
        edge_index = subgraph.edge_index.to(device)

        optimizer.zero_grad()
        z = model.encode(x, edge_index)
        loss = model.recon_loss(z, pos_edge_index=edge_index)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader), model

In [None]:
# Train for 10 epochs, then persist the trained model.
for epoch in range(0, 10):

    loss, model = train(gSAGE_loader)
    # Advance the learning-rate schedule once per epoch. The scheduler was
    # created alongside the optimizer but never stepped, so the learning rate
    # stayed constant.
    # NOTE(review): with step_size=2 and gamma=0.1 the rate drops 10x every two
    # epochs -- confirm this decay is as aggressive as intended for 10 epochs.
    scheduler.step()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
# NOTE(review): this pickles the whole model object, so loading it later needs
# the same class definitions importable; saving `model.state_dict()` is the
# more portable option if the consumers of this file can be updated.
torch.save(model,"graphSage_model")

In [None]:
# kernel dead problem