In [1]:
# https://arxiv.org/abs/1610.02415

# https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html

import torch
print(torch.__version__)
import torch.nn.functional as F
import torch.nn as nn
import torch.distributed as dist

import torch_geometric
import torch_geometric.nn as gnn
#print(torch_geometric.__version__)
from torch_geometric.datasets import QM9
import GCL.augmentors
import GCL.augmentors as A

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, mean_squared_error
from sklearn.linear_model import RidgeClassifierCV, LogisticRegression, LinearRegression

1.13.1+cu117


In [2]:

parameters = {}
parameters['batch_size'] = 64

In [3]:
whole_dataset = QM9(root = 'data/')

#print(whole_dataset.get_summary())
#print(dir(whole_dataset))
#print(whole_dataset.len())

n = whole_dataset.len()
tr_n = 2000 # Number of QM9 to use as training data

all_inds = range(n)
tr_inds, val_inds = train_test_split(all_inds, train_size = tr_n)
print(len(tr_inds), len(val_inds))
print(type(tr_inds), type(tr_inds[0]))


train_sampler = torch.utils.data.SubsetRandomSampler(tr_inds)
val_sampler = torch.utils.data.SubsetRandomSampler(val_inds)

# We need to make a train and validation set since QM9 does not provide them
train_set = torch.utils.data.Subset(whole_dataset, tr_inds)
val_set = torch.utils.data.Subset(whole_dataset, val_inds)

train_loader = torch_geometric.loader.DataLoader(train_set, batch_size = parameters['batch_size'],
                                                shuffle = True, num_workers = 2,)
                                                #sampler = train_sampler)
big_train_loader = torch_geometric.loader.DataLoader(train_set, batch_size = int(1e9),
                                                shuffle = True, num_workers = 2,)

val_loader = torch_geometric.loader.DataLoader(val_set, batch_size=2048,
                                            shuffle=False, num_workers=2,)
                                              #sampler = val_sampler)


2000 128831
<class 'list'> <class 'int'>


In [4]:
qm9_index = {0: 'Dipole moment',
1: 'Isotropic polarizability',
2: 'Highest occupied molecular orbital energy',
3: 'Lowest unoccupied molecular orbital energy',
4: 'Gap between previous 2',
5: 'Electronic spatial extent',
6: 'Zero point vibrational energy',
7: 'Internal energy at 0K',
8: 'Internal energy at 298.15K',
9: 'Enthalpy at 298.15K',
10: 'Free energy at 298.15K',
11: 'Heat capavity at 298.15K',
12: 'Atomization energy at 0K',
13: 'Atomization energy at 298.15K',
14: 'Atomization enthalpy at 298.15K',
15: 'Atomization free energy at 298.15K',
16: 'Rotational constant A',
17: 'Rotational constant B',
18: 'Rotational constant C',}

print(qm9_index.items())

dict_items([(0, 'Dipole moment'), (1, 'Isotropic polarizability'), (2, 'Highest occupied molecular orbital energy'), (3, 'Lowest unoccupied molecular orbital energy'), (4, 'Gap between previous 2'), (5, 'Electronic spatial extent'), (6, 'Zero point vibrational energy'), (7, 'Internal energy at 0K'), (8, 'Internal energy at 298.15K'), (9, 'Enthalpy at 298.15K'), (10, 'Free energy at 298.15K'), (11, 'Heat capavity at 298.15K'), (12, 'Atomization energy at 0K'), (13, 'Atomization energy at 298.15K'), (14, 'Atomization enthalpy at 298.15K'), (15, 'Atomization free energy at 298.15K'), (16, 'Rotational constant A'), (17, 'Rotational constant B'), (18, 'Rotational constant C')])


In [10]:
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.rep_dim = 128
        self.emb_dim = 256
        
        # Data under graph
        self.conv1 = GCNConv(whole_dataset.num_node_features, self.rep_dim // 2)
        self.bn1 = nn.BatchNorm1d(self.rep_dim // 2)
        self.a1 = nn.LeakyReLU(0.02)
        
        self.conv2 = GCNConv(self.rep_dim // 2, self.rep_dim) # To Rep Space
        self.bn2 = nn.BatchNorm1d(self.rep_dim)
        
        # Projection to representation
        self.mpool1 = gnn.global_mean_pool
        self.fc1 = nn.Linear(self.rep_dim, self.rep_dim)
        
        # Graph 2
        self.conv3 = GCNConv(self.rep_dim, self.rep_dim * 2) # To Emb Space
        self.bn3 = nn.BatchNorm1d(self.rep_dim * 2)
        
        # Projection to embedding
        self.mpool2 = gnn.global_mean_pool
        self.fc2 = nn.Linear(self.emb_dim, self.emb_dim) # Linear to rep?
        
    def forward(self, data, binds):
        x = data[0].float().to(device)
        edge_index = data[1].to(device)
        
        # Input graph to GConv
        x = self.conv1(x, edge_index)
        x = self.a1(self.bn1(x))
        x = F.dropout(x, training=self.training)
        
        x = self.bn2(self.conv2(x, edge_index))
        
        # GConv outputs projected to representation space
        #print('before pool: ', x.shape)
        x_rep = self.mpool1(x, binds)
        #print('pooled: ', x_rep.shape)
        
        x_rep = self.fc1(x_rep)
        #print('projected: ', x_rep.shape, 'gconv', x.shape)
        
        x_emb = self.bn3(self.conv3(x, edge_index))
        #print('x emb after conv3', x_emb.shape)
        x_emb = self.mpool2(x_emb, binds)
        #print('after pool', x_emb.shape)
        x_emb = self.fc2(x_emb)
        #print('after fc2', x_emb.shape)
        
        return x_rep, x_emb

device = 'cuda'

model = GCN().to(device)

sim_coeff = 25
std_coeff = 25
cov_coeff = 1

aug = A.RandomChoice([#A.RWSampling(num_seeds=1000, walk_length=10),
                      A.NodeDropping(pn=0.1),
                      A.FeatureMasking(pf=0.1),
                      A.EdgeRemoving(pe=0.1)],
                     num_choices=1)
val_aug = A.RandomChoice([], num_choices = 0)

def off_diagonal(x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def VicRegLoss(x, y):
    # https://github.com/facebookresearch/vicreg/blob/4e12602fd495af83efd1631fbe82523e6db092e0/main_vicreg.py#L184
    # x, y are output of projector(backbone(x and y))
    repr_loss = F.mse_loss(x, y)

    x = x - x.mean(dim=0)
    y = y - y.mean(dim=0)

    std_x = torch.sqrt(x.var(dim=0) + 0.0001)
    std_y = torch.sqrt(y.var(dim=0) + 0.0001)
    std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

    cov_x = (x.T @ x) / (parameters['batch_size'] - 1)
    cov_y = (y.T @ y) / (parameters['batch_size'] - 1)
    cov_loss = off_diagonal(cov_x).pow_(2).sum().div(
        x.shape[1]
    ) + off_diagonal(cov_y).pow_(2).sum().div(x.shape[1])
    
    # self.num_features -> rep_dim?
    loss = (
        sim_coeff * repr_loss
        + std_coeff * std_loss
        + cov_coeff * cov_loss
    )
    return loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.002, weight_decay=5e-4)

n_epochs = 100
for epoch in range(0,n_epochs+1):
    epoch_losses = []
    for batch in train_loader:
        optimizer.zero_grad()

        batch_inds = batch.batch.to(device)

        # batch of graphs has edge attribs, node attribs - (n_nodes, n_features+1) -> concat (n_nodes, attrib1)

        batch.x = batch.x.float()#.to(device)
        #batch.edge_index = batch.edge_index.to(device)

        # Barlow - get 2 random views of batch
        b1 = aug(batch.x, batch.edge_index, batch.edge_attr)
        b2 = aug(batch.x, batch.edge_index, batch.edge_attr)

        # Embed each batch (ignoring representations)
        r1, e1 = model(b1, batch_inds)
        r2, e2 = model(b2, batch_inds)

        loss = VicRegLoss(e1, e2)
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.data.item())
        
    print('epoch train loss', sum(epoch_losses) / len(epoch_losses))

    if epoch % 10 == 0:
        
        # Downstream supervised loss
        for batch in big_train_loader: # take entire train set
            with torch.no_grad():
                # Embed training set under model
                rep_tr, _ = model(val_aug(batch.x, batch.edge_index, batch.edge_attr), batch.batch.to(device))
                
                
                for val_batch in val_loader:
                    # Embed validation set under model
                    rep_val, _ = model(val_aug(val_batch.x, val_batch.edge_index, val_batch.edge_attr), val_batch.batch.to(device))
                    
                    # For each task in QM9
                    for tar_ind in range(batch.y.shape[1]):
                        # Fit a model on model representation of train set

                        #print(rep_tr.shape, batch.y[tar_ind].shap)
                        lm = LinearRegression().fit(rep_tr.cpu(), batch.y[:,tar_ind])
                        # Test the model on model repersentation of val set
                        tar_yhat = lm.predict(rep_val.cpu())
                        score = mean_squared_error(val_batch.y[:,tar_ind], tar_yhat)
                        print(qm9_index[tar_ind], score)
        
        # VicReg Validation Loss
        for batch in val_loader:
            with torch.no_grad():
                # VicReg validation loss
                b1 = aug(batch.x, batch.edge_index, batch.edge_attr)
                b2 = aug(batch.x, batch.edge_index, batch.edge_attr)
                r1, e1 = model(b1, batch.batch.to(device))
                r2, e2 = model(b2, batch.batch.to(device))
                
                val_loss = VicRegLoss(e1, e2)
                print('val loss', val_loss)




epoch train loss 20.78702187538147
Dipole moment 59.56189
Isotropic polarizability 614.3458
Highest occupied molecular orbital energy 2.4603133
Lowest unoccupied molecular orbital energy 0.93259454
Gap between previous 2 3.211514
Electronic spatial extent 123137.516
Zero point vibrational energy 0.1655652
Internal energy at 0K 109755190.0
Internal energy at 298.15K 109758296.0
Enthalpy at 298.15K 109736510.0
Free energy at 298.15K 109767160.0
Heat capavity at 298.15K 10.847432
Atomization energy at 0K 75.61366
Atomization energy at 298.15K 76.07681
Atomization enthalpy at 298.15K 77.113625
Atomization free energy at 298.15K 61.381744
Rotational constant A 2.4902701
Rotational constant B 55.046738
Rotational constant C 12.99593
Dipole moment 76.1948
Isotropic polarizability 764.5828
Highest occupied molecular orbital energy 2.862842
Lowest unoccupied molecular orbital energy 0.9801656
Gap between previous 2 3.685186
Electronic spatial extent 136473.34
Zero point vibrational energy 0.197

Dipole moment 53.69593
Isotropic polarizability 547.1973
Highest occupied molecular orbital energy 2.179466
Lowest unoccupied molecular orbital energy 0.8889209
Gap between previous 2 2.8045912
Electronic spatial extent 118771.52
Zero point vibrational energy 0.16585997
Internal energy at 0K 95443500.0
Internal energy at 298.15K 95448664.0
Enthalpy at 298.15K 95438630.0
Free energy at 298.15K 95449190.0
Heat capavity at 298.15K 11.044163
Atomization energy at 0K 67.73904
Atomization energy at 298.15K 68.18904
Atomization enthalpy at 298.15K 69.11359
Atomization free energy at 298.15K 55.128983
Rotational constant A 2.5729475
Rotational constant B 47.509827
Rotational constant C 11.221441
Dipole moment 98.7581
Isotropic polarizability 993.67236
Highest occupied molecular orbital energy 3.7478144
Lowest unoccupied molecular orbital energy 0.96186733
Gap between previous 2 4.3454924
Electronic spatial extent 165513.94
Zero point vibrational energy 0.2040233
Internal energy at 0K 177787100

Dipole moment 83.81416
Isotropic polarizability 812.81714
Highest occupied molecular orbital energy 3.1276157
Lowest unoccupied molecular orbital energy 0.9620706
Gap between previous 2 3.84688
Electronic spatial extent 134954.95
Zero point vibrational energy 0.17947525
Internal energy at 0K 147014220.0
Internal energy at 298.15K 147019780.0
Enthalpy at 298.15K 147005570.0
Free energy at 298.15K 147008700.0
Heat capavity at 298.15K 11.718472
Atomization energy at 0K 88.48368
Atomization energy at 298.15K 88.937225
Atomization enthalpy at 298.15K 90.169785
Atomization free energy at 298.15K 71.30876
Rotational constant A 2.362928
Rotational constant B 74.021255
Rotational constant C 17.522005
Dipole moment 44.309513
Isotropic polarizability 454.29395
Highest occupied molecular orbital energy 1.8287215
Lowest unoccupied molecular orbital energy 0.95256615
Gap between previous 2 2.5455956
Electronic spatial extent 108987.74
Zero point vibrational energy 0.14862919
Internal energy at 0K 79

Dipole moment 76.71602
Isotropic polarizability 778.2931
Highest occupied molecular orbital energy 3.014328
Lowest unoccupied molecular orbital energy 0.9108783
Gap between previous 2 3.7228804
Electronic spatial extent 138257.2
Zero point vibrational energy 0.17317021
Internal energy at 0K 139589340.0
Internal energy at 298.15K 139599380.0
Enthalpy at 298.15K 139581570.0
Free energy at 298.15K 139596720.0
Heat capavity at 298.15K 12.407175
Atomization energy at 0K 87.1593
Atomization energy at 298.15K 87.63763
Atomization enthalpy at 298.15K 88.83921
Atomization free energy at 298.15K 70.42886
Rotational constant A 2.3788986
Rotational constant B 69.43685
Rotational constant C 16.442312
Dipole moment 48.260994
Isotropic polarizability 480.9295
Highest occupied molecular orbital energy 1.9752243
Lowest unoccupied molecular orbital energy 0.96933043
Gap between previous 2 2.6566224
Electronic spatial extent 105154.266
Zero point vibrational energy 0.16468853
Internal energy at 0K 867464

Dipole moment 70.67614
Isotropic polarizability 713.5532
Highest occupied molecular orbital energy 2.768471
Lowest unoccupied molecular orbital energy 0.9017178
Gap between previous 2 3.3654764
Electronic spatial extent 136127.84
Zero point vibrational energy 0.16963738
Internal energy at 0K 129350550.0
Internal energy at 298.15K 129354400.0
Enthalpy at 298.15K 129344730.0
Free energy at 298.15K 129352640.0
Heat capavity at 298.15K 11.468906
Atomization energy at 0K 82.46213
Atomization energy at 298.15K 82.926895
Atomization enthalpy at 298.15K 84.06401
Atomization free energy at 298.15K 66.66803
Rotational constant A 16.660826
Rotational constant B 64.791824
Rotational constant C 15.278155
Dipole moment 56.95991
Isotropic polarizability 567.3563
Highest occupied molecular orbital energy 2.209167
Lowest unoccupied molecular orbital energy 0.9419839
Gap between previous 2 3.0250578
Electronic spatial extent 116789.734
Zero point vibrational energy 0.16010046
Internal energy at 0K 10189

Dipole moment 67.83937
Isotropic polarizability 671.5127
Highest occupied molecular orbital energy 2.6570373
Lowest unoccupied molecular orbital energy 0.8833792
Gap between previous 2 3.2872028
Electronic spatial extent 123929.664
Zero point vibrational energy 0.1594261
Internal energy at 0K 122417830.0
Internal energy at 298.15K 122423200.0
Enthalpy at 298.15K 122412070.0
Free energy at 298.15K 122412640.0
Heat capavity at 298.15K 11.128297
Atomization energy at 0K 79.06416
Atomization energy at 298.15K 79.50961
Atomization enthalpy at 298.15K 80.59884
Atomization free energy at 298.15K 63.97081
Rotational constant A 2.705501
Rotational constant B 61.038445
Rotational constant C 14.453486
Dipole moment 73.86493
Isotropic polarizability 742.0419
Highest occupied molecular orbital energy 2.8426785
Lowest unoccupied molecular orbital energy 0.9397851
Gap between previous 2 3.517953
Electronic spatial extent 130812.47
Zero point vibrational energy 0.18039595
Internal energy at 0K 1336381



val loss tensor(2522.8528, device='cuda:0')
val loss tensor(1471.1406, device='cuda:0')
val loss tensor(860.6014, device='cuda:0')
val loss tensor(2362.8931, device='cuda:0')
val loss tensor(1337.2581, device='cuda:0')
val loss tensor(1378.4548, device='cuda:0')
val loss tensor(2069.4468, device='cuda:0')
val loss tensor(3427.5303, device='cuda:0')
val loss tensor(1448.7499, device='cuda:0')
val loss tensor(3362.5557, device='cuda:0')
val loss tensor(1519.8036, device='cuda:0')
val loss tensor(1362.6510, device='cuda:0')
val loss tensor(2302.8291, device='cuda:0')
val loss tensor(2399.4634, device='cuda:0')
val loss tensor(1488.8633, device='cuda:0')
val loss tensor(2045.1171, device='cuda:0')
val loss tensor(1898.0886, device='cuda:0')
val loss tensor(2367.9731, device='cuda:0')
val loss tensor(1876.4165, device='cuda:0')
val loss tensor(2093.4548, device='cuda:0')
val loss tensor(2462.3916, device='cuda:0')
val loss tensor(2309.0117, device='cuda:0')
val loss tensor(2043.6755, device

KeyboardInterrupt: 

In [None]:
print(breaker)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = GCN().to(device)
#data = train_dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


aug = A.RandomChoice([#A.RWSampling(num_seeds=1000, walk_length=10),
                      A.NodeDropping(pn=0.1),
                      A.FeatureMasking(pf=0.1),
                      A.EdgeRemoving(pe=0.1)],
                     num_choices=1)

val_aug = A.RandomChoice([], num_choices = 0)


def barlow(batch):
    # Return two random views of input batch
    return aug(batch[0], batch[1]), aug(batch[0], batch[1])

def off_diagonal(x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

class FullGatherLayer(torch.autograd.Function):
    """
    Gather tensors from all process and support backward propagation
    for the gradients across processes.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
    
def VicRegLoss(x, y):
    # https://github.com/facebookresearch/vicreg/blob/4e12602fd495af83efd1631fbe82523e6db092e0/main_vicreg.py#L184
    # x, y are output of projector(backbone(x and y))
    repr_loss = F.mse_loss(x, y)

    x = x - x.mean(dim=0)
    y = y - y.mean(dim=0)

    std_x = torch.sqrt(x.var(dim=0) + 0.0001)
    std_y = torch.sqrt(y.var(dim=0) + 0.0001)
    std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

    cov_x = (x.T @ x) / (parameters['batch_size'] - 1)
    cov_y = (y.T @ y) / (parameters['batch_size'] - 1)
    cov_loss = off_diagonal(cov_x).pow_(2).sum().div(
        x.shape[1]
    ) + off_diagonal(cov_y).pow_(2).sum().div(x.shape[1])
    
    # self.num_features -> rep_dim?
    loss = (
        sim_coeff * repr_loss
        + std_coeff * std_loss
        + cov_coeff * cov_loss
    )
    return loss

sim_coeff = 25
std_coeff = 25
cov_coeff = 1

model.train()
for epoch in range(5):
    
    epo_losses = []
    for batch in train_loader:
        #batch = batch.to(device)
        batch.x = batch.x.float()#.to(device)
        #batch.edge_index = batch.edge_index.to(device)

        optimizer.zero_grad()
        
        # Barlow - get 2 random views of batch
        b1 = aug(batch.x, batch.edge_index, batch.edge_attr)
        b2 = aug(batch.x, batch.edge_index, batch.edge_attr)
        
                
        # Embed each batch (ignoring representations)
        [r1, e1], [r2, e2] = model.pair_emb_rep(b1, b2)

        # VicReg loss on projections
        loss = VicRegLoss(e1, e2)
        
        loss.backward()
        optimizer.step()
        
        epo_losses.append(loss.data.item())
        
    print(sum(epo_losses) / len(epo_losses))
    
    ############################
    ## Per-epoch validation step:
    
    GCL.eval


    # Embed Training Samples:
    train_batch = next(iter(train_big_subset))
    #print('train batch', train_batch)
    train_batch = val_aug(train_batch.x, train_batch.edge_index, train_batch.edge_attr) # val_aug is an empty augmentation
    #print('train_batch augd', train_batch)

    with torch.no_grad():
        tr_rep, _ = model.forward(train_batch)
    #print(tr_rep.shape)

    # Train linear model on embedded samples:
    ridge_mod = RidgeClassifierCV(cv = 4).fit(tr_rep, y_train)
    linear_mod = LogisticRegression(penalty = None).fit(tr_rep, y_train)

    # Embed validation samples:
    val_batch = next(iter(val_loader))
    #print('val batch', val_batch)
    val_batch = val_aug(val_batch.x, val_batch.edge_index, val_batch.edge_attr) # val_aug is an empty augmentation
    #print('val_batch augd', val_batch)
    
    with torch.no_grad():
        val_rep, _ = model.forward(val_batch)
    #print(val_rep.shape)

    # Test linear model on embedded samples:
    ridge_score = f1_score(ridge_mod.predict(val_rep), y_val)
    linear_score = f1_score(linear_mod.predict(val_rep), y_val)
    
    print(f'Classifier Scores at Epoch {epoch}:', round(linear_score, 3), round(ridge_score, 3))

In [None]:
if False: # Update for some downstream? Keep in mind this idea of graph masking
    # Evaluate
    model.eval()
    pred = model(data).argmax(dim=1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    print(f'Accuracy: {acc:.4f}')