In [2]:
#import GCL
#from GCL.examples import InfoGraph

In [54]:
import torch
import os.path as osp
import GCL.losses as L
import GCL.augmentors as A
import torch_geometric.transforms as T
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from torch.optim import Adam
from GCL.eval import get_split, LREvaluator
from GCL.models.contrast_model import WithinEmbedContrast
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import QM9
# from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR
from torch_geometric.loader import DataLoader

from sklearn.linear_model import LinearRegression
from sklearn.metrics import f1_score

device = 'cuda'

class GConv(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GConv, self).__init__()
        self.act = torch.nn.PReLU()
        self.bn = torch.nn.BatchNorm1d(2 * hidden_dim, momentum=0.01)
        self.conv1 = GCNConv(input_dim, 2 * hidden_dim, cached=False)
        self.conv2 = GCNConv(2 * hidden_dim, hidden_dim, cached=False)

    def forward(self, x, edge_index, edge_weight=None):
        
        print('into gconv', x.shape, edge_index.shape)
        z = self.conv1(x, edge_index, edge_weight)
        z = self.bn(z)
        z = self.act(z)
        
        print('out1', z.shape)
        z = self.conv2(z, edge_index, edge_weight)
        print('out GCONV', z.shape)
        return z

class GCN(torch.nn.Module):
    def __init__(self, n_features):
        super().__init__()
        
        self.rep_dim = 128
        #self.emb_dim = 64
        
        self.conv1 = GCNConv(n_features, self.rep_dim // 2)
        self.bn1 = nn.BatchNorm1d(self.rep_dim // 2)
        self.a1 = nn.LeakyReLU(0.02)
        
        self.conv2 = GCNConv(self.rep_dim // 2, self.rep_dim) # To Rep Space
        self.bn2 = nn.BatchNorm1d(self.rep_dim)
        self.a2 = nn.LeakyReLU(0.02)
        
        self.conv3 = GCNConv(self.rep_dim, self.rep_dim * 2) # To Emb Space
        self.bn3 = nn.BatchNorm1d(self.rep_dim * 2)
        
        self.fc1 = nn.Linear(self.rep_dim * 2, 999) # Linear to rep?
        
    def forward(self, x, edge_index, donut):
        #x = x.float().to(device)
        #edge_index = data[1].to(device)
        
        #print(x.dtype)
        #print(edge_index.dtype)
        #x, edge_index = data.x.float(), data.edge_index
        
        x = self.conv1(x, edge_index)
        x = self.a1(self.bn1(x))
        x = F.dropout(x, training=self.training)
        
        x = self.conv2(x, edge_index)
        #x = self.a2(self.bn2(x))
        #x = F.dropout(x, training=self.training)
        x_rep = self.bn2(x)
        x_emb = self.conv3(x_rep, edge_index)

        # Can have the -> rep and -> emb layers be linear layers on the graph conv output
        x_fc1 = self.fc1(x_emb)
        #print('from conv3 to linear output', x_fc1.shape)
        
        return x_rep, x_emb

class Encoder(torch.nn.Module):
    # Encoder is not itself a model, it holds the augs and forward function of the GConv model
    def __init__(self, encoder, augmentor):
        super(Encoder, self).__init__()
        self.encoder = encoder
        self.augmentor = augmentor

    def forward(self, x, edge_index, edge_weight=None):
        aug1, aug2 = self.augmentor # unpack two augmentations
        x1, edge_index1, edge_weight1 = aug1(x, edge_index, edge_weight)
        x2, edge_index2, edge_weight2 = aug2(x, edge_index, edge_weight)
        
        #print('x, x1, x2', x.shape, x1.shape, x2.shape)
        
        # Encoder passes GConv over each of untransformed x, aug1(x), and aug2(x)
        z, _ = self.encoder(x, edge_index, edge_weight)
        _, z1 = self.encoder(x1, edge_index1, edge_weight1)
        _, z2 = self.encoder(x2, edge_index2, edge_weight2)
        
        #print('z, z1, z2', z.shape, z1.shape, z2.shape)
        return z, z1, z2


def train(encoder_model, contrast_model, batch, optimizer):
    encoder_model.train()
    optimizer.zero_grad()
    
    print('in training, ', batch.x.shape, batch.edge_index.shape, batch.y.shape)
    _, z1, z2 = encoder_model(batch.x.to(device), batch.edge_index.to(device), batch.edge_attr.to(device))
    loss = contrast_model(z1, z2)
    loss.backward()
    optimizer.step()
    return loss.item()


def validation(gcn, loader):
    gcn.eval()
    score = 0
    with torch.no_grad():
        for bind, batch in enumerate(loader):
            
            print('encoding', batch.x.shape, batch.edge_index.shape)
            rep, _ = gcn(batch.x.to(device), batch.edge_index.to(device), batch.edge_attr.to(device))
            
            print('out of enc', rep.shape)
            rep = rep.cpu()
            
            print(rep.shape, batch.y.shape)
            lm = LinearRegression().fit(rep, batch.y)
            score = score + f1_score(y, lm.predict(rep))
                
    return score / (bind+1)


def main():
    device = torch.device('cuda')
    path = osp.join(osp.expanduser('~'), 'datasets', 'QM9')

    train_dataset = QM9(root = 'datasets/', transform=T.NormalizeFeatures()) # subset false -> 250k graphs

    train_loader = DataLoader(train_dataset, batch_size=99, shuffle=True)
    
    aug1 = A.Compose([A.EdgeRemoving(pe=0.5), A.FeatureMasking(pf=0.1)])
    aug2 = A.Compose([A.EdgeRemoving(pe=0.5), A.FeatureMasking(pf=0.1)])
    
    #print(train_dataset.num_node_features)
    #gconv = GConv(input_dim=train_dataset.num_node_features, hidden_dim=256).to(device)
    gconv = GCN(n_features = train_dataset.num_node_features)
    encoder_model = Encoder(encoder=gconv, augmentor=(aug1, aug2)).to(device)
    contrast_model = WithinEmbedContrast(loss=L.BarlowTwins()).to(device)

    optimizer = Adam(encoder_model.parameters(), lr=5e-4)
#     scheduler = LinearWarmupCosineAnnealingLR(
#         optimizer=optimizer,
#         warmup_epochs=400,
#         max_epochs=4000)

    for epoch in range(1, 2):
        for bind, batch in enumerate(train_loader):
            
            print(batch.y.shape)
            loss = train(encoder_model, contrast_model, batch, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            print(loss)
            
            if bind > 10:
                break
            
    score = validation(gconv, train_loader)
    print(f'score {score}')


if __name__ == '__main__':
    main()

torch.Size([99, 19])
in training,  torch.Size([1786, 11]) torch.Size([2, 3720]) torch.Size([99, 19])
219.05809020996094
torch.Size([99, 19])
in training,  torch.Size([1770, 11]) torch.Size([2, 3686]) torch.Size([99, 19])
207.98971557617188
torch.Size([99, 19])
in training,  torch.Size([1788, 11]) torch.Size([2, 3740]) torch.Size([99, 19])




202.4779510498047
torch.Size([99, 19])
in training,  torch.Size([1806, 11]) torch.Size([2, 3742]) torch.Size([99, 19])
223.43893432617188
torch.Size([99, 19])
in training,  torch.Size([1759, 11]) torch.Size([2, 3694]) torch.Size([99, 19])
220.12379455566406
torch.Size([99, 19])
in training,  torch.Size([1760, 11]) torch.Size([2, 3646]) torch.Size([99, 19])
180.06825256347656
torch.Size([99, 19])
in training,  torch.Size([1772, 11]) torch.Size([2, 3690]) torch.Size([99, 19])
198.85830688476562
torch.Size([99, 19])
in training,  torch.Size([1763, 11]) torch.Size([2, 3626]) torch.Size([99, 19])
165.28704833984375
torch.Size([99, 19])
in training,  torch.Size([1756, 11]) torch.Size([2, 3620]) torch.Size([99, 19])
193.99253845214844
torch.Size([99, 19])
in training,  torch.Size([1805, 11]) torch.Size([2, 3738]) torch.Size([99, 19])
160.2373046875
torch.Size([99, 19])
in training,  torch.Size([1824, 11]) torch.Size([2, 3762]) torch.Size([99, 19])
173.57655334472656
torch.Size([99, 19])
in tr

ValueError: Found input variables with inconsistent numbers of samples: [1744, 99]

    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | Target | Property                         | Description                                                                       | Unit                                        |
    | 0      | :math:`\mu`                      | Dipole moment                                                                     | :math:`\textrm{D}`                          |
    | 1      | :math:`\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |
    | 2      | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\textrm{eV}`                         |
    | 3      | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\textrm{eV}`                         |
    | 4      | :math:`\Delta \epsilon`          | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}`                         |
    
    | 5      | :math:`\langle R^2 \rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |
    | 6      | :math:`\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\textrm{eV}`                         |
    | 7      | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\textrm{eV}`                         |
    | 8      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\textrm{eV}`                         |
    | 9      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\textrm{eV}`                         |
    | 10     | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\textrm{eV}`                         |
    | 11     | :math:`c_{\textrm{v}}`           | Heat capavity at 298.15K                                                          | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
    | 12     | :math:`U_0^{\textrm{ATOM}}`      | Atomization energy at 0K                                                          | :math:`\textrm{eV}`                         |
    | 13     | :math:`U^{\textrm{ATOM}}`        | Atomization energy at 298.15K                                                     | :math:`\textrm{eV}`                         |
    | 14     | :math:`H^{\textrm{ATOM}}`        | Atomization enthalpy at 298.15K                                                   | :math:`\textrm{eV}`                         |
    | 15     | :math:`G^{\textrm{ATOM}}`        | Atomization free energy at 298.15K                                                | :math:`\textrm{eV}`                         |
    | 16     | :math:`A`                        | Rotational constant                                                               | :math:`\textrm{GHz}`                        |
    | 17     | :math:`B`                        | Rotational constant                                                               | :math:`\textrm{GHz}`                        |
    | 18     | :math:`C`                        | Rotational constant    