In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
from torch_geometric.loader import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv,EdgeConv
from torch_geometric.nn import global_mean_pool,global_max_pool
from torch.utils.data import random_split
import torch.nn.functional as F
from tqdm import trange

In [None]:
# Pre-processed QM9 input graphs (presumably a sequence of PyG `Data` objects
# carrying node features `x`, `edge_index` and a 50-value target `y` -- TODO confirm).
DATA_PATH = "/home/chengc/workspace/cc/new_exp_0427/data/qm9_ir_broaden/qm9_input.pt"
# PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
# arbitrary Python objects such as PyG Data. This file is our own output, so a
# full unpickle is acceptable here -- never load untrusted files this way.
raw_data = torch.load(DATA_PATH, weights_only=False)

In [None]:
# 80/20 split with a fixed generator seed so the split is reproducible across runs.
num_train = int(len(raw_data) * 0.8)
num_val = len(raw_data) - num_train  # held-out graphs, used as the validation/test set
train_data, test_data = random_split(
    dataset=raw_data,
    lengths=[num_train, num_val],
    generator=torch.Generator().manual_seed(42),
)

batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# The whole held-out set is evaluated as a single batch.
test_loader = DataLoader(test_data, batch_size=len(test_data))

# Prefer the second GPU (cuda:1) as before, but fall back to cuda:0 on a
# single-GPU machine instead of failing later with an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Hyperparameters, read as notebook globals by the GNN class in this cell.
hidden_channels, lr = 512, 1e-4   # layer width / Adam learning rate
cos_ratio, mse_ratio = 0, 1       # loss weights: cosine-embedding term, MSE term
class GNN(nn.Module):
    """Graph-level regressor: 5 GCN layers -> mean||max pooling -> 2-layer MLP -> 50 outputs.

    Node features must have 45 channels (input size of the first ``GCNConv``);
    the head emits a 50-value vector per graph (presumably the broadened IR
    spectrum of the molecule -- TODO confirm against the data-prep code).

    The module owns its loss functions and Adam optimizer, so one training step
    is ``model.train(batch)``. Hyperparameters are read from the notebook
    globals ``hidden_channels``, ``lr``, ``cos_ratio`` and ``mse_ratio``.
    """

    def __init__(self):
        super().__init__()
        # Seed before creating layers so weight initialisation is reproducible.
        # NOTE: this reseeds torch's *global* RNG as a side effect.
        torch.manual_seed(42)

        # Attribute names (and creation order) determine the saved state_dict
        # keys and the initial weights -- keep them unchanged.
        self.conv1 = GCNConv(45, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.conv5 = GCNConv(hidden_channels, hidden_channels)

        # lin1 consumes the concatenation of mean- and max-pooled embeddings.
        self.lin1 = nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, hidden_channels)
        self.out = nn.Linear(hidden_channels, 50)

        self.loss_function1 = nn.CosineEmbeddingLoss(margin=0)
        self.loss_function2 = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=5e-4)
        self.counter = 0    # not used inside this class
        self.progress = []  # per-step training losses, appended by train()

    def forward(self, x, edge_index, batch):
        """Predict one output vector per graph.

        Args:
            x: node feature matrix (45 channels per node).
            edge_index: graph connectivity in PyG ``edge_index`` form.
            batch: graph id of every node, as produced by the PyG DataLoader.

        Returns:
            Tensor with one row of 50 values per graph in ``batch``.
        """
        # Follow the model's own parameters rather than a global device, so the
        # model works wherever it was moved (or if it was never moved at all).
        dev = self.out.weight.device
        x, edge_index, batch = x.to(dev), edge_index.to(dev), batch.to(dev)

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = conv(x, edge_index).relu()

        x = torch.cat((global_mean_pool(x, batch), global_max_pool(x, batch)), dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x).relu()
        return self.out(x)

    def _loss(self, outputs, y, num_graphs):
        """Weighted loss ``cos_ratio * cosine-embedding + mse_ratio * MSE``.

        ``y`` is reshaped to ``(num_graphs, 50)``, so it must hold exactly one
        50-value target per graph in the batch.
        """
        pred = outputs.to(torch.float32)
        target = y.to(pred.device).view(num_graphs, self.out.out_features).to(torch.float32)
        loss = mse_ratio * self.loss_function2(pred, target)
        if cos_ratio:
            # Label +1 for every pair: CosineEmbeddingLoss becomes 1 - cos(pred, target).
            same = torch.ones(num_graphs, device=pred.device)
            loss = loss + cos_ratio * self.loss_function1(pred, target, same)
        return loss

    def train(self, data=True):
        """Run one optimisation step on ``data`` and return its loss.

        This name shadows ``nn.Module.train(mode)``; a bool argument is therefore
        forwarded to the standard implementation so ``model.eval()`` and
        ``model.train(False)`` keep working.

        Args:
            data: a PyG mini-batch with ``x``, ``edge_index``, ``batch`` and ``y``,
                or a bool training-mode flag (standard ``nn.Module`` behaviour).

        Returns:
            The 0-dim loss tensor for a batch, or ``self`` for a bool flag.
        """
        if isinstance(data, bool):
            return super().train(data)

        super().train(True)  # ensure training mode for this step
        self.optimizer.zero_grad()
        outputs = self.forward(data.x.float(), data.edge_index, data.batch)
        loss = self._loss(outputs, data.y, data.num_graphs)
        self.progress.append(loss.item())
        loss.backward()
        self.optimizer.step()
        return loss

    @torch.no_grad()
    def test(self, data):
        """Loss on an evaluation batch, without gradient tracking or parameter updates.

        The reshape uses ``data.num_graphs``, so batches of any size work.

        Returns:
            0-dim loss tensor with no autograd graph attached.
        """
        self.eval()  # training mode is re-enabled by the next train() step
        outputs = self.forward(data.x.float(), data.edge_index, data.batch)
        return self._loss(outputs, data.y, data.num_graphs)

    def pred(self, data):
        """Return raw predictions (one 50-value row per graph) for a PyG batch; gradients are tracked."""
        return self.forward(data.x.float(), data.edge_index, data.batch)
    

In [None]:
# Build the network and place its parameters on the selected device
# (nn.Module.to returns the module itself). The bare `model` on the last
# line displays the architecture, as before.
model = GNN().to(device)
model

In [None]:
# Training loop: 5000 epochs. The model with the lowest validation loss is
# written to `model_path`; a rolling checkpoint is written every 100 epochs.
loss_min = float('inf')  # best validation loss so far; inf => the first epoch is always saved
model_path = '/home/chengc/workspace/cc/0815/model/qm_9_mse_model.pth'
model_path_ckp = '/home/chengc/workspace/cc/0815/model/qm_9_mse_model_ckp.pth'
for _path in (model_path, model_path_ckp):
    if os.path.dirname(_path):
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(_path), exist_ok=True)

for epoch in trange(5000):
    # One pass over the training set; `loss` still holds the last batch's loss.
    train_losses = []
    for data in train_loader:
        loss = model.train(data)
        train_losses.append(float(loss))
    # The epoch mean is a much less noisy progress signal than the last batch alone.
    train_loss = sum(train_losses) / len(train_losses) if train_losses else float('nan')

    # Validation: `val_l` is created once per epoch, outside the batch loop, so
    # every batch contributes. no_grad avoids keeping an autograd graph over
    # the whole test set alive through `val_loss` / `loss_min`.
    val_l = []
    with torch.no_grad():
        for datat in test_loader:
            val_l.append(model.test(datat))
    val_loss = torch.stack(val_l, dim=0).mean(dim=0).item()

    if val_loss < loss_min:
        print(f"Epoch {epoch} | Training loss {train_loss}| Validation loss {val_loss}")
        torch.save(model.state_dict(), model_path)
        print("save model")
        loss_min = val_loss
    if epoch % 100 == 0:
        print(f"_Epoch {epoch} | Training loss {train_loss}| Validation loss {val_loss}")
        torch.save(model.state_dict(), model_path_ckp)

In [None]:
#Epoch 692 | Training loss 2.353151559829712| Validation loss 4.63758659362793
#save model