In [1]:
import torch
from torch import nn
import os.path as osp
import GCL.losses as L
import GCL.augmentors as A
import torch.nn.functional as F
import torch_geometric.transforms as T

from tqdm import tqdm
from torch.optim import Adam
from GCL.eval import from_predefined_split, LREvaluator
from GCL.models import DualBranchContrast

In [2]:
from torch_geometric.nn import GCN

### import data

In [4]:
import scipy.sparse as sp
import numpy as np
import json

import torch
import torch.nn.functional as F

In [5]:
# load data
adj = sp.load_npz('../CSE881_data_2024/adj.npz')
features  = np.load('../CSE881_data_2024/features.npy')
labels = np.load('../CSE881_data_2024/labels.npy')
splits = json.load(open('../CSE881_data_2024/splits.json'))
idx_train, idx_test = splits['idx_train'], splits['idx_test']

In [6]:
# transfer adjacency matrix into edge index
from torch_geometric.utils import from_scipy_sparse_matrix

edge_index = from_scipy_sparse_matrix(adj)
print("There are", edge_index[0].size(1), "edges in total in the graph\n")

print(torch.unique(edge_index[1]))
print("These edges are not weighted.")

There are 10100 edges in total in the graph

tensor([1.])
These edges are not weighted.


In [7]:
print("There are", len(features), "nodes in the graph.")
num_classes = len(np.unique(labels))
print("Each node can be one of", num_classes, "classes.")
print("Training set size:", len(idx_train))
print("Test set size:", len(idx_test))

There are 2480 nodes in the graph.
Each node can be one of 7 classes.
Training set size: 496
Test set size: 1984


In [8]:
device = 'cuda'

In [9]:
features = torch.from_numpy(features).float()
num_features = len(features[0])
print("Number of features:", num_features)

Number of features: 1390


In [10]:
features = features.to(device)
edge_index = edge_index[0].long().to(device)
edge_weight = edge_index[1].float().to(device)

### Split Valiadation Set

In [12]:
# split 20% of training set as validation set
from sklearn.model_selection import train_test_split

idx_train_sub, idx_val, train_labels_sub, val_labels = train_test_split(
    idx_train, labels, test_size=0.2, random_state=123, stratify=labels)

print("Training subset size:", len(idx_train_sub))
print("Validation set size:", len(idx_val))

Training subset size: 396
Validation set size: 100


In [13]:
train_labels_sub = torch.from_numpy(train_labels_sub).long().to(device)
val_labels = torch.from_numpy(val_labels).long().to(device)
labels = torch.from_numpy(labels).long().to(device)

### main

In [15]:
class Encoder(torch.nn.Module):
    def __init__(self, encoder, augmentor, hidden_dim, proj_dim):
        super(Encoder, self).__init__()
        self.encoder = encoder
        self.augmentor = augmentor

        self.fc1 = torch.nn.Linear(hidden_dim, proj_dim)
        self.fc2 = torch.nn.Linear(proj_dim, hidden_dim)

    def forward(self, x, edge_index, edge_weight=None):
        aug1, aug2 = self.augmentor
        x1, edge_index1, edge_weight1 = aug1(x, edge_index, edge_weight)
        x2, edge_index2, edge_weight2 = aug2(x, edge_index, edge_weight)
        z = self.encoder(x, edge_index, edge_weight)
        z1 = self.encoder(x1, edge_index1, edge_weight1)
        z2 = self.encoder(x2, edge_index2, edge_weight2)
        return z, z1, z2

    def project(self, z: torch.Tensor) -> torch.Tensor:
        z = F.elu(self.fc1(z))
        return self.fc2(z)

In [16]:
num_nodes = 2480

# create extra_pos_mask
# initialize a 2480 x 2480 matrix of False
extra_pos_mask = torch.zeros((num_nodes, num_nodes), dtype=torch.bool).to(device)

# create a temporary full label tensor initialized with a dummy label and place the known labels
full_labels = torch.full((num_nodes,), -1, dtype=train_labels_sub.dtype).to(device)
full_labels[idx_train_sub] = train_labels_sub

# iterate through each known label and update the label_matrix to True by finding nodes with the same label
for i, label in zip(idx_train_sub, train_labels_sub):
    same_label_indices = torch.where(full_labels == label)[0]
    extra_pos_mask[i, same_label_indices] = True
    extra_pos_mask[same_label_indices, i] = True
extra_pos_mask.fill_diagonal_(False)

# pos_mask: [N, 2N] for both inter-view and intra-view samples
extra_pos_mask = torch.cat([extra_pos_mask, extra_pos_mask], dim=1).to(device)
# fill inter-view positives only; pos_mask for intra-view samples should have False in diagonal
extra_pos_mask.fill_diagonal_(True)
    
# create extra_neg_mask
# initialize a 2480 x 2480 matrix of True
extra_neg_mask = torch.ones((num_nodes, num_nodes), dtype=torch.bool).to(device)

# iterate through each known label and update the label_matrix to False by finding nodes with the same label
for i, label in zip(idx_train_sub, train_labels_sub):
    same_label_indices = torch.where(full_labels == label)[0]
    extra_neg_mask[i, same_label_indices] = False
    extra_neg_mask[same_label_indices, i] = False

# set the diagonal to False since a sample cannot be a negative of itself
extra_neg_mask.fill_diagonal_(False)

# neg_mask: [N, 2N] for both inter-view and intra-view samples
extra_neg_mask = torch.cat([extra_neg_mask, extra_neg_mask], dim=1).to(device)

In [17]:
def train(encoder_model, contrast_model, optimizer, features=features, 
         edge_index=edge_index, edge_weight=edge_weight, labels=labels,
         extra_pos_mask=extra_pos_mask, extra_neg_mask=extra_neg_mask):
    
    encoder_model.train()
    optimizer.zero_grad()
    z, z1, z2 = encoder_model(features, edge_index, edge_weight)
    h1, h2 = [encoder_model.project(x) for x in [z1, z2]]
    
    loss = contrast_model(h1=h1, h2=h2, extra_pos_mask=extra_pos_mask, extra_neg_mask=extra_neg_mask)
    loss.backward()
    optimizer.step()
    return loss.item()    

In [18]:
class LogisticRegression(nn.Module):
    def __init__(self, num_features, num_classes):
        super(LogisticRegression, self).__init__()
        self.fc = nn.Linear(num_features, num_classes)
        torch.nn.init.xavier_uniform_(self.fc.weight.data)


    def forward(self, x):
        x = self.fc(x)
        return x

def test(encoder_model, input_dim, num_class, features=features, edge_index=edge_index, 
         labels=labels, edge_weight=edge_weight):
    
    encoder_model.eval()
    with torch.no_grad():
        z, _, _ = encoder_model(features, edge_index, edge_weight)
        z = z.detach()
    
    classifier = LogisticRegression(input_dim, num_class).to(device)
    optimizer = Adam(classifier.parameters(), lr=0.02)
    criterion = torch.nn.CrossEntropyLoss()
    
    for epoch in range(400):
        classifier.train()
        optimizer.zero_grad()
        logits = classifier(z)  
        loss = criterion(logits[idx_train_sub], train_labels_sub)
        loss.backward()
        optimizer.step()
        
        if epoch % 40 == 0:
            classifier.eval()
            with torch.no_grad():
                logits_val = classifier(z)
                loss_val = criterion(logits_val[idx_val], val_labels)
                preds = logits_val.argmax(dim=1)
                correct = preds[idx_val].eq(val_labels).sum().item()
                accuracy = correct / len(val_labels)
                print(f'Epoch: {epoch}, Train Loss: {loss:.4f}, Val Loss: {loss_val:.4f}, Val Accuracy: {accuracy:.4f}')
        
    classifier.eval()
    with torch.no_grad():
        logits = classifier(z)
        preds = logits.argmax(dim=1)
        correct = preds[idx_val].eq(val_labels).sum().item()
        accuracy = correct / len(val_labels)
        print(f'Final Val Accuracy: {accuracy:.4f}')

    return accuracy

In [19]:
aug1 = A.Compose([A.EdgeRemoving(pe=0.1), A.FeatureMasking(pf=0.1), A.FeatureDropout(pf=0.1), A.NodeDropping(pn=0.2)])
aug2 = A.Compose([A.EdgeRemoving(pe=0.1), A.FeatureMasking(pf=0.1), A.FeatureDropout(pf=0.1), A.NodeDropping(pn=0.2)])

In [20]:
gconv = GCN(in_channels=num_features, hidden_channels=512, 
               out_channels=32, num_layers=3, dropout=0.5).to(device)

In [21]:
encoder_model = Encoder(encoder=gconv, augmentor=(aug1, aug2), hidden_dim=32, proj_dim=32).to(device)
contrast_model = DualBranchContrast(loss=L.InfoNCE(tau=0.2), mode='L2L', intraview_negs=True).to(device)

In [22]:
optimizer = Adam(encoder_model.parameters(), lr=0.001)

In [23]:
with tqdm(total=700, desc='(T)') as pbar:
    for epoch in range(1, 701):
        loss = train(encoder_model, contrast_model, optimizer=optimizer)
        pbar.set_postfix({'loss': loss})
        pbar.update()
        if epoch % 100 == 0:
            test(encoder_model, input_dim=32, num_class=7)

(T):  14%|█▍        | 100/700 [00:04<00:23, 25.76it/s, loss=8.38]

Epoch: 0, Train Loss: 3.3788, Val Loss: 2.7310, Val Accuracy: 0.1400
Epoch: 40, Train Loss: 1.5374, Val Loss: 1.5558, Val Accuracy: 0.4300
Epoch: 80, Train Loss: 1.3435, Val Loss: 1.3526, Val Accuracy: 0.4800
Epoch: 120, Train Loss: 1.2484, Val Loss: 1.2484, Val Accuracy: 0.5300
Epoch: 160, Train Loss: 1.1823, Val Loss: 1.1733, Val Accuracy: 0.5600
Epoch: 200, Train Loss: 1.1309, Val Loss: 1.1146, Val Accuracy: 0.5800
Epoch: 240, Train Loss: 1.0889, Val Loss: 1.0675, Val Accuracy: 0.6000
Epoch: 280, Train Loss: 1.0537, Val Loss: 1.0292, Val Accuracy: 0.6300


(T):  14%|█▍        | 101/700 [00:04<00:23, 25.76it/s, loss=8.4] 

Epoch: 320, Train Loss: 1.0235, Val Loss: 0.9978, Val Accuracy: 0.6300
Epoch: 360, Train Loss: 0.9974, Val Loss: 0.9717, Val Accuracy: 0.6300
Final Val Accuracy: 0.6500


(T):  28%|██▊       | 199/700 [00:08<00:20, 24.57it/s, loss=7.96]

Epoch: 0, Train Loss: 3.1012, Val Loss: 2.5677, Val Accuracy: 0.2300
Epoch: 40, Train Loss: 0.7733, Val Loss: 0.7211, Val Accuracy: 0.7300
Epoch: 80, Train Loss: 0.7338, Val Loss: 0.6836, Val Accuracy: 0.7600
Epoch: 120, Train Loss: 0.7149, Val Loss: 0.6684, Val Accuracy: 0.7400
Epoch: 160, Train Loss: 0.7011, Val Loss: 0.6574, Val Accuracy: 0.7400
Epoch: 200, Train Loss: 0.6903, Val Loss: 0.6491, Val Accuracy: 0.7400
Epoch: 240, Train Loss: 0.6814, Val Loss: 0.6424, Val Accuracy: 0.7500
Epoch: 280, Train Loss: 0.6740, Val Loss: 0.6367, Val Accuracy: 0.7500
Epoch: 320, Train Loss: 0.6677, Val Loss: 0.6319, Val Accuracy: 0.7500


(T):  29%|██▉       | 202/700 [00:09<00:46, 10.61it/s, loss=7.99]

Epoch: 360, Train Loss: 0.6621, Val Loss: 0.6278, Val Accuracy: 0.7600
Final Val Accuracy: 0.7700


(T):  43%|████▎     | 299/700 [00:13<00:16, 24.00it/s, loss=7.83]

Epoch: 0, Train Loss: 2.9851, Val Loss: 2.7800, Val Accuracy: 0.2600
Epoch: 40, Train Loss: 0.6941, Val Loss: 0.6427, Val Accuracy: 0.7700
Epoch: 80, Train Loss: 0.6549, Val Loss: 0.5981, Val Accuracy: 0.7900
Epoch: 120, Train Loss: 0.6386, Val Loss: 0.5809, Val Accuracy: 0.7900
Epoch: 160, Train Loss: 0.6272, Val Loss: 0.5652, Val Accuracy: 0.7900
Epoch: 200, Train Loss: 0.6180, Val Loss: 0.5511, Val Accuracy: 0.7900
Epoch: 240, Train Loss: 0.6101, Val Loss: 0.5396, Val Accuracy: 0.7900
Epoch: 280, Train Loss: 0.6033, Val Loss: 0.5302, Val Accuracy: 0.8000
Epoch: 320, Train Loss: 0.5971, Val Loss: 0.5226, Val Accuracy: 0.8000


(T):  43%|████▎     | 302/700 [00:14<00:37, 10.49it/s, loss=7.88]

Epoch: 360, Train Loss: 0.5914, Val Loss: 0.5163, Val Accuracy: 0.8100
Final Val Accuracy: 0.8100


(T):  57%|█████▋    | 400/700 [00:18<00:11, 25.10it/s, loss=7.8] 

Epoch: 0, Train Loss: 2.0088, Val Loss: 1.6415, Val Accuracy: 0.5500
Epoch: 40, Train Loss: 0.6295, Val Loss: 0.5471, Val Accuracy: 0.8300
Epoch: 80, Train Loss: 0.5920, Val Loss: 0.5433, Val Accuracy: 0.8200
Epoch: 120, Train Loss: 0.5733, Val Loss: 0.5357, Val Accuracy: 0.8100
Epoch: 160, Train Loss: 0.5596, Val Loss: 0.5377, Val Accuracy: 0.8200
Epoch: 200, Train Loss: 0.5484, Val Loss: 0.5461, Val Accuracy: 0.8200
Epoch: 240, Train Loss: 0.5389, Val Loss: 0.5572, Val Accuracy: 0.8100
Epoch: 280, Train Loss: 0.5305, Val Loss: 0.5658, Val Accuracy: 0.8200
Epoch: 320, Train Loss: 0.5228, Val Loss: 0.5711, Val Accuracy: 0.8200
Epoch: 360, Train Loss: 0.5158, Val Loss: 0.5738, Val Accuracy: 0.8200


(T):  58%|█████▊    | 404/700 [00:18<00:27, 10.65it/s, loss=7.84]

Final Val Accuracy: 0.8200


(T):  71%|███████▏  | 499/700 [00:22<00:07, 25.30it/s, loss=7.82]

Epoch: 0, Train Loss: 3.3592, Val Loss: 2.6670, Val Accuracy: 0.3400
Epoch: 40, Train Loss: 0.6372, Val Loss: 0.6036, Val Accuracy: 0.8200
Epoch: 80, Train Loss: 0.5756, Val Loss: 0.5560, Val Accuracy: 0.8300
Epoch: 120, Train Loss: 0.5487, Val Loss: 0.5363, Val Accuracy: 0.8400
Epoch: 160, Train Loss: 0.5316, Val Loss: 0.5273, Val Accuracy: 0.8400
Epoch: 200, Train Loss: 0.5193, Val Loss: 0.5243, Val Accuracy: 0.8400
Epoch: 240, Train Loss: 0.5098, Val Loss: 0.5231, Val Accuracy: 0.8500
Epoch: 280, Train Loss: 0.5018, Val Loss: 0.5226, Val Accuracy: 0.8400


(T):  72%|███████▏  | 502/700 [00:23<00:18, 10.56it/s, loss=7.74]

Epoch: 320, Train Loss: 0.4950, Val Loss: 0.5224, Val Accuracy: 0.8400
Epoch: 360, Train Loss: 0.4890, Val Loss: 0.5223, Val Accuracy: 0.8400
Final Val Accuracy: 0.8500


(T):  86%|████████▌ | 599/700 [00:27<00:04, 24.49it/s, loss=7.71]

Epoch: 0, Train Loss: 5.3429, Val Loss: 4.3637, Val Accuracy: 0.1600
Epoch: 40, Train Loss: 0.6309, Val Loss: 0.5988, Val Accuracy: 0.8300
Epoch: 80, Train Loss: 0.5522, Val Loss: 0.5610, Val Accuracy: 0.8500
Epoch: 120, Train Loss: 0.5273, Val Loss: 0.5400, Val Accuracy: 0.8400
Epoch: 160, Train Loss: 0.5119, Val Loss: 0.5379, Val Accuracy: 0.8600
Epoch: 200, Train Loss: 0.5007, Val Loss: 0.5404, Val Accuracy: 0.8500
Epoch: 240, Train Loss: 0.4921, Val Loss: 0.5438, Val Accuracy: 0.8500
Epoch: 280, Train Loss: 0.4850, Val Loss: 0.5465, Val Accuracy: 0.8500


(T):  86%|████████▌ | 601/700 [00:27<00:10,  9.79it/s, loss=7.68]

Epoch: 320, Train Loss: 0.4791, Val Loss: 0.5483, Val Accuracy: 0.8500
Epoch: 360, Train Loss: 0.4739, Val Loss: 0.5493, Val Accuracy: 0.8500
Final Val Accuracy: 0.8500


(T): 100%|██████████| 700/700 [00:31<00:00, 25.47it/s, loss=7.73]

Epoch: 0, Train Loss: 4.4184, Val Loss: 4.0197, Val Accuracy: 0.1700
Epoch: 40, Train Loss: 0.5771, Val Loss: 0.5871, Val Accuracy: 0.8400
Epoch: 80, Train Loss: 0.5168, Val Loss: 0.5675, Val Accuracy: 0.8600
Epoch: 120, Train Loss: 0.4960, Val Loss: 0.5539, Val Accuracy: 0.8500
Epoch: 160, Train Loss: 0.4832, Val Loss: 0.5512, Val Accuracy: 0.8500
Epoch: 200, Train Loss: 0.4737, Val Loss: 0.5497, Val Accuracy: 0.8500
Epoch: 240, Train Loss: 0.4660, Val Loss: 0.5489, Val Accuracy: 0.8500
Epoch: 280, Train Loss: 0.4594, Val Loss: 0.5483, Val Accuracy: 0.8500


(T): 100%|██████████| 700/700 [00:32<00:00, 21.60it/s, loss=7.73]

Epoch: 320, Train Loss: 0.4535, Val Loss: 0.5474, Val Accuracy: 0.8500
Epoch: 360, Train Loss: 0.4482, Val Loss: 0.5462, Val Accuracy: 0.8500
Final Val Accuracy: 0.8500



