TP Bensafi Sarra

In [102]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html
!pip install ogb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.data import DataLoader , Data
from torch_geometric.utils import to_networkx

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cpu.html
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [105]:
#Ces graphe qui representes des molécules on veut savoir si une molécule inhibe ou non la réplication du virus VIH

In [88]:
from torch_geometric.data import DataLoader
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
import torch
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder
from torch_geometric.nn import GCNConv
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from tqdm import tqdm
from torch_geometric.nn import GraphConv , SAGEConv , GATConv , ChebConv , TAGConv

class GCN(torch.nn.Module):

    def __init__(self, num_classes_end, emb_dim):
        super(GCN, self).__init__()
        
        #Pour encoder ces caractéristiques d'entrée brutes , pour intégrer des caractéristiques brutes d'atomes
        self.atom_encoder = AtomEncoder(emb_dim)
        self.conv1=GCNConv(emb_dim, emb_dim)
        self.batch_norm1=torch.nn.BatchNorm1d(emb_dim)

        self.conv2=TAGConv(emb_dim, emb_dim)
        self.batch_norm2=torch.nn.BatchNorm1d(emb_dim)
        
        '''self.conv3=GATConv(emb_dim, emb_dim)
        self.batch_norm3=torch.nn.BatchNorm1d(emb_dim)'''

        self.lin1=torch.nn.Linear(emb_dim, emb_dim)
        self.bns_lin1= torch.nn.BatchNorm1d(emb_dim)

        self.lin2=torch.nn.Linear(emb_dim, num_classes_end)
        self.dropout = 0.5

    def reset_parameters(self):
        
            self.conv1.reset_parameters()
            self.batch_norm1.reset_parameters()
            self.lin1.reset_parameters()
            self.bns_lin1.reset_parameters()

    def forward(self, batched_data):

        x, edge_index, batch = batched_data.x, batched_data.edge_index, batched_data.batch
        AE= self.atom_encoder(x)
        
        h = self.conv1(AE, edge_index)
        h = self.batch_norm1(h)         
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training = self.training)

        h = self.conv2(AE, edge_index)
        h = self.batch_norm2(h)         
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training = self.training)


        '''
        h=self.conv3(h,edge_index)
        h = self.batch_norm3(h)         
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training = self.training)
        # ajout du dropout'''

        #j'effectue le pooling qui donne la valeurs moyenne des voisins
        x = global_mean_pool(h, batch)
        x = self.lin1(x)
        x = self.bns_lin1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x= self.lin2(x)
        
        return x



In [89]:
# permet d'obtenir une classification binaire
loss_fn = torch.nn.BCEWithLogitsLoss()#Binary cross-entropy with logits loss combines a Sigmoid layer and the BCELoss in one single class.

def train(model, loader, optimizer):
    model.train()

    for step, batch in enumerate(tqdm(loader)):
        batch = batch.to(device)
        pred = model(batch)
        optimizer.zero_grad()
        loss = loss_fn(pred.to(torch.float32), batch.y.to(torch.float32))
        loss_final = loss
        loss.backward()
        optimizer.step()     
    
    return loss.item()

In [90]:
def eval(model, loader, evaluator):
    model.eval()

    y_true = []
    y_pred = []

    for step, batch in enumerate(tqdm(loader)):
        batch = batch.to(device)

        with torch.no_grad():
            pred = F.sigmoid(model(batch))
            
        y_true.append(batch.y.view(pred.shape).detach().cpu())
        y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim = 0).numpy()
    y_pred = torch.cat(y_pred, dim = 0).numpy()
    

    input_dict = {
        "y_true": y_true, 
        "y_pred": y_pred
        }

    return evaluator.eval(input_dict)["rocauc"]

In [91]:
#-------------Data
# In molhiv each graph represents a molecule, where nodes are atoms, and edges are chemical bonds
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dataset = PygGraphPropPredDataset(name = "ogbg-molhiv")
print(dataset[0])
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name = "ogbg-molhiv")



'''
x (Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. 

edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges].

edge_attr (Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)

y (Tensor, optional) – Graph-level or node-level ground-truth labels with arbitrary shape. (default: None)
'''

Data(edge_index=[2, 40], edge_attr=[40, 3], x=[19, 9], y=[1, 1], num_nodes=19)


'\nx (Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. \n\nedge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges].\n\nedge_attr (Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)\n\ny (Tensor, optional) – Graph-level or node-level ground-truth labels with arbitrary shape. (default: None)\n'

In [92]:
#-------------parameters
epochs = 10
emb_dim = 100
batch_size = 64



In [93]:
#-------------Loader
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=batch_size, shuffle=False)



In [94]:
#-------------- Call model
model = GCN(num_classes_end=dataset.num_tasks,emb_dim=emb_dim).to(device)
model.reset_parameters()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [95]:
#-------------- Train and test 
allAccval=[]
allAcctest=[]
validAvg=0
testAvg=0
for epoch in range(1, 1 + epochs):
  print(f"[Epoch {epoch}]")
  # train
  loss = train(model, train_loader, optimizer)  
  
  # loss
  print(f'Loss : {loss:.2f}')    
 
  #val acc
  validAcc = eval(model, valid_loader, evaluator)
  print(f'Val Accurcy : {validAcc:2f}')

  #test acc
testAcc = eval(model, test_loader, evaluator)
print(f'Test acc : {testAcc:2f}')





[Epoch 1]


100%|██████████| 515/515 [00:45<00:00, 11.31it/s]


Loss : 0.07


100%|██████████| 65/65 [00:02<00:00, 26.43it/s]


Val Accurcy : 0.668516
[Epoch 2]


100%|██████████| 515/515 [00:28<00:00, 17.99it/s]


Loss : 0.04


100%|██████████| 65/65 [00:02<00:00, 28.18it/s]


Val Accurcy : 0.686070
[Epoch 3]


100%|██████████| 515/515 [00:27<00:00, 18.60it/s]


Loss : 0.64


100%|██████████| 65/65 [00:01<00:00, 36.19it/s]


Val Accurcy : 0.754542
[Epoch 4]


100%|██████████| 515/515 [00:26<00:00, 19.39it/s]


Loss : 0.54


100%|██████████| 65/65 [00:01<00:00, 33.02it/s]


Val Accurcy : 0.688226
[Epoch 5]


100%|██████████| 515/515 [00:24<00:00, 21.08it/s]


Loss : 0.03


100%|██████████| 65/65 [00:01<00:00, 37.48it/s]


Val Accurcy : 0.730682
[Epoch 6]


100%|██████████| 515/515 [00:24<00:00, 21.23it/s]


Loss : 0.03


100%|██████████| 65/65 [00:01<00:00, 37.67it/s]


Val Accurcy : 0.774578
[Epoch 7]


100%|██████████| 515/515 [00:24<00:00, 21.18it/s]


Loss : 0.03


100%|██████████| 65/65 [00:01<00:00, 35.54it/s]


Val Accurcy : 0.768073
[Epoch 8]


100%|██████████| 515/515 [00:26<00:00, 19.58it/s]


Loss : 0.03


100%|██████████| 65/65 [00:01<00:00, 35.36it/s]


Val Accurcy : 0.770830
[Epoch 9]


100%|██████████| 515/515 [00:26<00:00, 19.32it/s]


Loss : 0.01


100%|██████████| 65/65 [00:01<00:00, 35.26it/s]


Val Accurcy : 0.757459
[Epoch 10]


100%|██████████| 515/515 [00:24<00:00, 20.96it/s]


Loss : 0.52


100%|██████████| 65/65 [00:01<00:00, 34.18it/s]


Val Accurcy : 0.758629


100%|██████████| 65/65 [00:02<00:00, 30.63it/s]

Test acc : 0.747448





In [96]:
#dropout = 0.001 => 0.724 acc
#dropot  = 0.01   => 0.706 acc