In [2]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My Drive/276/proj

Mounted at /content/drive
/content/drive/My Drive/276/proj


In [3]:
!pip install dgl dglgo -f https://data.dgl.ai/wheels/repo.html -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m40.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.5/63.5 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.3/45.3 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.7/61.7 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.4/116.4 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m24.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━

In [4]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
import random

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [5]:
# Load the graphs and the labels
graphs, label_dict = dgl.load_graphs('/content/drive/My Drive/276/proj/graphs/dgl_graphs.bin')
labels = label_dict['labels']

In [8]:
graphs[0]

Graph(num_nodes=367, num_edges=2077,
      ndata_schemes={'b_factor': Scheme(shape=(), dtype=torch.float64)}
      edata_schemes={})

In [17]:
combined_dataset = list(zip(graphs, labels))
random.shuffle(combined_dataset)
split_index = int(len(combined_dataset) * 0.8)

trainset = combined_dataset[:split_index]
testset = combined_dataset[split_index:]

In [13]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    labels = torch.tensor(labels)
    return batched_graph, labels

In [15]:
from dgl.dataloading import GraphDataLoader

In [18]:
# Create the train and test dataloaders
train_dataloader = GraphDataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate)
test_dataloader = GraphDataLoader(testset, batch_size=32, shuffle=False, collate_fn=collate)

In [10]:
# Prepare dataset
combined_dataset = list(zip(graphs, labels))
random.shuffle(combined_dataset)
split_index = int(len(combined_dataset) * 0.8)
trainset = combined_dataset[:split_index]
testset = combined_dataset[split_index:]

# Define the GraphSAGE model
class GraphSAGENet(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GraphSAGENet, self).__init__()
        self.layer1 = SAGEConv(in_feats, hidden_size, 'mean')
        self.layer2 = SAGEConv(hidden_size, num_classes, 'mean')
    def forward(self, g, features):
        h = F.relu(self.layer1(g, features))
        h = self.layer2(g, h)
        return h

# Initialize the model
# Replace `num_features` and `num_classes` with actual numbers
num_features = graphs[0].ndata['b_factor'].shape[0]
num_classes = len(set(labels))
model = GraphSAGENet(num_features, 64, num_classes)

In [19]:
# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(2):
    for batched_graph, labels in train_dataloader:
        features = batched_graph.ndata['b_factor']
        logits = model(batched_graph, features)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate the model
model.eval()
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        features = batched_graph.ndata['b_factor']
        logits = model(batched_graph, features)
        # Add evaluation code here (e.g., calculating accuracy)

RuntimeError: ignored