In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from models.gat_encoder import GATEncoder
from Dataset.dataset import GraphDataset

In [7]:
# Initialize the dataset and dataloader.
# Use forward slashes: they work on Windows as well as POSIX, and avoid the
# invalid "\p" escape sequence a backslash-separated literal would contain.
data_path = "Dataset/processed.jsonl"  # Replace with your dataset path
dataset = GraphDataset(data_path)
# batch_size=1: each batch holds exactly one graph; default collation adds a
# leading batch dimension of size 1 that the training loop squeezes away.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [8]:
# Initialize the GAT model.
# Seed torch's global RNG before constructing the model so weight initialisation
# is reproducible on Restart & Run All. DataLoader shuffling and the random dummy
# targets in the training loop draw from the same global RNG, so they become
# deterministic too when cells run in order.
SEED = 42
torch.manual_seed(SEED)

# in_channels=768 must match the width of each node-feature vector yielded by
# GraphDataset (presumably a 768-d embedding — TODO confirm against the dataset).
# The first layer concatenates 4 heads of 64 features, so the second layer sees
# 64 * 4 = 256 inputs (see the printed module repr).
gat_model = GATEncoder(in_channels=768, hidden_channels=64, out_channels=32, heads=4)
gat_model.train()  # Set model to training mode; the returned module is displayed as the cell output

GATEncoder(
  (gat1): GATConv(768, 64, heads=4)
  (gat2): GATConv(256, 32, heads=1)
)

In [9]:
# Placeholder objective: element-wise squared error, averaged over all entries.
criterion = nn.MSELoss(reduction="mean")

# Adam over every trainable parameter of the GAT encoder, learning rate 1e-3.
optimizer = optim.Adam(params=gat_model.parameters(), lr=1e-3)

In [12]:
# Training loop
num_epochs = 100
num_batches = len(dataloader)
if num_batches == 0:
    # Fail with a clear message instead of a ZeroDivisionError in the epoch average below.
    raise ValueError("dataloader is empty - check the dataset path and contents")

for epoch in range(num_epochs):
    # Re-assert training mode every epoch so this cell does not depend on
    # whatever mode earlier (or out-of-order) cells left the model in.
    gat_model.train()
    total_loss = 0.0
    for batch in dataloader:
        # Extract data from the batch; batch_size=1, so drop the leading batch dim.
        node_features = batch["node_features"].squeeze(0)  # [num_nodes, in_channels]
        edge_list = batch["edge_list"].squeeze(0)  # [2, num_edges]

        # Forward pass
        output_embeddings = gat_model(node_features, edge_list)  # [num_nodes, out_channels]

        # Dummy target: fresh U(0, 1) noise with the output's shape, resampled every step.
        # NOTE: because the target changes each step, the best the model can do under MSE
        # is predict ~0.5 everywhere, so the loss can only approach Var(U(0, 1)) = 1/12 ≈ 0.083.
        # Replace with a real objective for meaningful training.
        target_embeddings = torch.rand_like(output_embeddings)

        # Compute the loss
        loss = criterion(output_embeddings, target_embeddings)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() converts to a Python float so the autograd graph is not kept alive.
        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / num_batches}")

Epoch 1/100, Loss: 0.09261189848184585
Epoch 2/100, Loss: 0.09108399301767349
Epoch 3/100, Loss: 0.09358458667993545
Epoch 4/100, Loss: 0.09253397211432457
Epoch 5/100, Loss: 0.09454121440649033
Epoch 6/100, Loss: 0.0928214319050312
Epoch 7/100, Loss: 0.09186646267771721
Epoch 8/100, Loss: 0.0905759021639824
Epoch 9/100, Loss: 0.09223682284355164
Epoch 10/100, Loss: 0.09323505610227585
Epoch 11/100, Loss: 0.0933595433831215
Epoch 12/100, Loss: 0.08988827094435692
Epoch 13/100, Loss: 0.09441400915384293
Epoch 14/100, Loss: 0.08874301686882972
Epoch 15/100, Loss: 0.08900878131389618
Epoch 16/100, Loss: 0.09275438413023948
Epoch 17/100, Loss: 0.08944416120648384
Epoch 18/100, Loss: 0.09182324558496475
Epoch 19/100, Loss: 0.09244096651673317
Epoch 20/100, Loss: 0.09296470880508423
Epoch 21/100, Loss: 0.0898807942867279
Epoch 22/100, Loss: 0.08978739380836487
Epoch 23/100, Loss: 0.08881771266460418
Epoch 24/100, Loss: 0.08743634670972825
Epoch 25/100, Loss: 0.08699331060051918
Epoch 26/100,