In [None]:
! pip install torch==2.1.0  torchvision==0.16.0 torchtext==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
! pip install torch_geometric==2.4
! pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
! pip install sentence-transformers
! pip install torcheval
! pip install matplotlib
! pip install pandas
! pip install tensorboard

In [None]:
from torch_geometric.data import HeteroData
import pandas as pd
import numpy as np 
from sklearn.metrics import f1_score
import torch 

    
from torch_geometric.data import HeteroData
# load data
import torch 
# Location of the pre-built training graph (Kaggle dataset input).
TRAIN_GRAPH_PATH = '/kaggle/input/twibot22-pyggraph/TwiBot22_Graph_with_degreecounts_train_with_y.pt'
# NOTE: torch.load unpickles arbitrary Python objects -- only point it at trusted files.
train_graph = torch.load(TRAIN_GRAPH_PATH)


In [None]:
# Define GNN Model
from torch_geometric.nn import HGTConv, Linear

from torch.nn import functional as F

class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer producing outputs for a single node type.

    Every node type is projected to ``hidden_channels`` by its own
    ``Linear(-1, hidden_channels)`` (PyG infers the input size lazily on the first
    forward pass), followed by ReLU. The resulting dict then goes through
    ``num_layers`` stacked ``HGTConv`` layers. Finally the embeddings of
    ``target_node_type`` are mapped to ``out_channels``.

    Args:
        hidden_channels: width of the input projections and of every HGTConv layer.
        out_channels: output size per target node.
        num_heads: attention heads per HGTConv layer.
        num_layers: number of stacked HGTConv layers.
        node_types: node types that get an input projection; every key of the
            ``x_dict`` given to ``forward`` must be one of them.
        data_metadata: graph metadata as returned by ``HeteroData.metadata()``,
            forwarded to each HGTConv.
        target_node_type: node type whose embeddings ``forward`` returns. Defaults
            to ``'user'``, the value this class previously hard-coded.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers, node_types, data_metadata,
                 target_node_type='user'):
        super().__init__()
        # Plain attribute (not a parameter/buffer), so state_dict keys are unchanged
        # and existing checkpoints still load.
        self.target_node_type = target_node_type

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            # -1 => input feature size is inferred lazily on the first forward pass.
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data_metadata,
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return ``out_channels`` values for each ``target_node_type`` node in ``x_dict``.

        Args:
            x_dict: mapping node type -> node feature matrix.
            edge_index_dict: mapping edge type -> edge index, as consumed by HGTConv.
        """
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return self.lin(x_dict[self.target_node_type])
    
    
    
# Two HGT layers with 8 attention heads over a 256-d hidden space, emitting a single
# logit per user node (trained below with BCEWithLogitsLoss).
model = HGT(
    hidden_channels=256,
    out_channels=1,
    num_heads=8,
    num_layers=2,
    node_types=train_graph.node_types,
    data_metadata=train_graph.metadata(),
)


In [None]:
# get weights after training
import gc 
def init(model, optimizer, state_dict_path, data, num_samples=None):
    """Materialize the model's lazy layers, then restore model and optimizer state.

    The per-node-type input projections of ``HGT`` are ``Linear(-1, ...)`` layers
    whose input size is only known after a forward pass. This function therefore
    runs one sampled minibatch per node type through the model (without gradients)
    before loading the checkpoint.

    Args:
        model: model to restore. Its train/eval mode on entry is restored on exit.
        optimizer: optimizer over ``model.parameters()``; its state is restored too.
        state_dict_path: checkpoint file containing ``'model_state_dict'`` and
            ``'optimizer_state_dict'`` entries (the format the training loop saves).
        data: heterogeneous graph to sample warm-up minibatches from. It is moved to
            CUDA via ``data.cuda()``. NOTE(review): for PyG data objects this appears
            to modify the caller's graph in place -- confirm if that matters.
        num_samples: per-iteration sample counts for ``HGTLoader``. When ``None``,
            the notebook-level ``num_neighbors`` is used (the previous behaviour).
    """
    # Local import: this cell is defined before the cell that imports HGTLoader.
    from torch_geometric.loader import HGTLoader

    if num_samples is None:
        num_samples = num_neighbors  # notebook global from the loader-config cell

    was_training = model.training
    data = data.cuda()  # done once, not once per node type
    with torch.no_grad():
        model.eval()
        for node_type in data.node_types:
            print(node_type)

            loader = HGTLoader(
                    data,
                    num_samples=num_samples,
                    batch_size=64,  # warm-up only: any batch that reaches every layer works
                    input_nodes=node_type,
                    num_workers=0,
                    pin_memory=True,
                    prefetch_factor=None,
                )
            minibatch = next(iter(loader), None)
            del loader
            gc.collect()
            if minibatch is None:  # node type without nodes: nothing to warm up
                continue

            model(minibatch.x_dict, minibatch.edge_index_dict)

    # Read the checkpoint once, after all lazy layers have been materialized,
    # instead of re-reading and re-applying it for every node type.
    model_and_optimizer = torch.load(state_dict_path)
    model.load_state_dict(model_and_optimizer['model_state_dict'])
    optimizer.load_state_dict(model_and_optimizer['optimizer_state_dict'])
    # Undo the model.eval() above so a caller's earlier model.train() is not lost.
    model.train(was_training)
   




In [None]:
# create minibatch loader
from torch_geometric.loader import HGTLoader
batch_size = 32  # number of seed ('user') nodes per training minibatch
num_node_types = len(train_graph.node_types)
# HGTLoader's `num_samples` is the number of nodes sampled *per node type* in each
# sampling iteration (one list entry per iteration/hop). The budget scales with the
# batch size and is split evenly across node types.
one_hop_neighbors = (20 * batch_size) // num_node_types
two_hop_neighbors = (20 * 8 * batch_size) // num_node_types
num_neighbors = [one_hop_neighbors, two_hop_neighbors]

# Cast every node-feature matrix to float32 so it matches the float32 weights of
# the model's Linear input projections.
for node_type in train_graph.node_types:
    train_graph[node_type].x = train_graph[node_type].x.float()
    



In [None]:
# train model 
from tqdm.auto import tqdm
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path

# Move the model to the GPU *before* constructing the optimizer, as the torch.optim
# documentation recommends.
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# More numerically stable than sigmoid + BCELoss because of the log-sum-exp trick:
# https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
criterion = torch.nn.BCEWithLogitsLoss()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_folder = Path(f'runs/{timestamp}')
writer = SummaryWriter(run_folder)  # TensorBoard logs; checkpoints are saved here too

gc.collect()
# Samples already seen by the checkpoint being resumed (encoded in its file name).
samples_seen = 7175872
init(model, optimizer, f'/kaggle/input/twibot22-hgt-models/model_samplesseen{samples_seen}.pt', train_graph)
# Enter training mode only *after* init: its warm-up pass calls model.eval(), which
# would undo a model.train() issued before it.
model.train()

train_graph = train_graph.cuda()
loader = HGTLoader(
        train_graph,
        num_samples=num_neighbors,
        batch_size=batch_size,
        input_nodes=('user', torch.arange(0, len(train_graph['user'].y))),
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        shuffle=True
    )
num_batches = len(loader)
seen = samples_seen  # running count of seed nodes trained on; used as the log/checkpoint step
for epoch in range(1):
    for i, minibatch in enumerate(loader):
        # HGTLoader places the sampled seed users first. The remaining user nodes are
        # sampled neighbours, so only the first `seed_count` predictions are supervised.
        seed_count = minibatch['user'].batch_size
        optimizer.zero_grad()
        out = model(minibatch.x_dict, minibatch.edge_index_dict)[:seed_count]
        target = minibatch['user'].y[:seed_count].unsqueeze(1).float()
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        seen += seed_count  # exact even when the last batch is smaller

        is_last_batch = i == num_batches - 1
        if i % 100 == 0 or is_last_batch:
            writer.add_scalar('Loss/train', loss.item(), seen)

        if i % 1000 == 0 or is_last_batch:
            torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    }, run_folder/f'model_samplesseen{seen}.pt')

# Make sure buffered TensorBoard events are written to disk.
writer.flush()