In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
from torch_geometric.loader import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool,global_max_pool
from torch.utils.data import random_split
from tqdm import trange

In [None]:
# Preprocessed QM9 input graphs. torch.load unpickles the file, so only point it
# at trusted data.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
# pickled PyG objects; pass weights_only=False for this trusted file if needed.
raw_data = torch.load(
    f="/home/chengc/workspace/cc/new_exp_0427/data/qm9_ir_broaden/qm9_input.pt"
)

In [None]:
class _GNNBase(nn.Module):
    """Shared GCN regressor behind ``GNN_COS`` and ``GNN_MSE``.

    Architecture: five ``GCNConv`` layers (ReLU after each), a graph-level
    readout that concatenates global mean- and max-pooling, two ReLU ``Linear``
    layers, and a final ``Linear`` projection to ``OUT_CHANNELS`` values per
    graph. Subclasses only decide which loss ``train``/``compute_loss`` uses.

    The two public classes used to be line-for-line copies; sharing the body
    keeps them from drifting apart. Attribute names (conv1..conv5, lin1, lin2,
    out, loss_function1, loss_function2) are unchanged so previously saved
    ``state_dict`` checkpoints load as before.
    """

    IN_CHANNELS = 45   # node-feature width expected by conv1
    OUT_CHANNELS = 50  # length of the predicted vector per graph

    def __init__(self, hidden_channels=None):
        """Build layers, both loss modules and an Adam optimizer.

        Args:
            hidden_channels: width of every hidden layer. When ``None`` (the
                original no-argument call), the module-level
                ``hidden_channels`` global is used, exactly as before.
        """
        super().__init__()
        if hidden_channels is None:
            # Backward compatibility: the original classes read this global.
            hidden_channels = globals()["hidden_channels"]
        # Seed before creating layers so initial weights are reproducible.
        # Layers are created in the original order so the RNG draws match.
        torch.manual_seed(42)

        self.conv1 = GCNConv(self.IN_CHANNELS, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.conv5 = GCNConv(hidden_channels, hidden_channels)

        # Input is [mean-pool | max-pool], hence 2 * hidden_channels.
        self.lin1 = nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, hidden_channels)
        self.out = nn.Linear(hidden_channels, self.OUT_CHANNELS)

        self.loss_function1 = nn.CosineEmbeddingLoss(margin=0)
        self.loss_function2 = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, weight_decay=5e-4)
        self.counter = 0
        self.progress = []  # per-step training loss values appended by train_step

    def forward(self, x, edge_index, batch):
        """Return a ``(num_graphs, OUT_CHANNELS)`` prediction for a batch of graphs.

        Inputs are moved to the device the parameters live on, rather than to a
        notebook-level ``device`` global (which the original required to exist
        and to match where the model had been moved).
        """
        dev = self.out.weight.device
        x = x.to(dev)
        edge_index = edge_index.to(dev)
        batch = batch.to(dev)

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = conv(x, edge_index).relu()

        x = torch.cat((global_mean_pool(x, batch), global_max_pool(x, batch)), dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x).relu()
        return self.out(x)

    def _loss(self, outputs, y):
        """Loss between ``(B, OUT_CHANNELS)`` predictions and targets; set by subclasses."""
        raise NotImplementedError

    def compute_loss(self, data):
        """Forward ``data`` and return the subclass loss against ``data.y`` (no optimizer step).

        ``data.y`` is reshaped to ``(data.batch_size, OUT_CHANNELS)``, as in the
        original ``train``.
        """
        outputs = self.forward(data.x.float(), data.edge_index, data.batch)
        y = data.y.to(outputs.device).view(data.batch_size, self.OUT_CHANNELS).float()
        return self._loss(outputs.float(), y)

    def train_step(self, data):
        """Run one optimisation step on batch ``data``; record and return the loss."""
        self.optimizer.zero_grad()
        loss = self.compute_loss(data)
        self.progress.append(loss.item())
        loss.backward()
        self.optimizer.step()
        return loss

    def train(self, data=True, mode=None):
        """Run one training step on ``data`` or, for a bool, set train/eval mode.

        The original override replaced ``nn.Module.train(mode)``, so
        ``model.eval()`` (which calls ``self.train(False)``) and
        ``model.train()`` broke. A bool (or ``mode=``) is now forwarded to
        ``nn.Module.train``; anything else is treated as a batch, preserving
        the original ``model.train(batch)`` call form and its loss return value.
        """
        if mode is not None:
            data = mode
        if isinstance(data, bool):
            return super().train(data)
        return self.train_step(data)

    def pred(self, data):
        """Return the model's predictions for batch ``data`` (autograd not disabled)."""
        return self.forward(data.x.float(), data.edge_index, data.batch)


class GNN_COS(_GNNBase):
    """GCN regressor trained with cosine-embedding loss (target +1).

    With target +1 the loss is ``1 - cos(pred, y)``, which ignores the overall
    scale of the prediction.
    """

    def _loss(self, outputs, y):
        # CosineEmbeddingLoss needs a +1 label per row to pull pairs together.
        target = outputs.new_ones(outputs.size(0))
        return self.loss_function1(outputs, y, target)


class GNN_MSE(_GNNBase):
    """GCN regressor trained with mean-squared-error loss."""

    def _loss(self, outputs, y):
        return self.loss_function2(outputs, y)

In [None]:
# Width of every hidden layer; must match the width the checkpoints below were
# trained with (the model classes read this global).
hidden_channels = 512

# Pick the device *before* loading so checkpoints can be mapped onto it.
# Prefer the second GPU as before, but fall back to the first GPU / CPU rather
# than failing on hosts without a cuda:1.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

model_cos = GNN_COS()
model_mse = GNN_MSE()
# map_location lets checkpoints saved on a GPU load on whatever device we have.
model_mse.load_state_dict(
    torch.load("/home/chengc/workspace/cc/0815/model/qm_9_mse_model.pth", map_location=device)
)
model_cos.load_state_dict(
    torch.load("/home/chengc/workspace/cc/0815/model/qm_9_cos_model.pth", map_location=device)
)
model_mse.to(device)
model_cos.to(device)

In [None]:
# One graph per batch, in dataset order, so batch index i corresponds to raw_data[i].
data_loader = DataLoader(dataset=raw_data, batch_size=1, shuffle=False)

In [None]:
# Materialise every single-graph batch so later cells can index them directly.
d_data = list(data_loader)

In [None]:
N_BINS = 50  # length of each model's predicted vector (output width of both GNNs)

spec_100 = []   # per molecule: N_BINS MSE-model values followed by N_BINS cosine-model values
real_spec = []  # per molecule: the ground-truth ``y`` tensor

# Pure inference: no_grad avoids building an autograd graph for every forward
# pass, which the original loop did only to throw it away with .detach().
with torch.no_grad():
    for i in trange(len(d_data)):
        graph = d_data[i]
        mse_pred = model_mse.pred(graph).cpu()[0].tolist()
        cos_pred = model_cos.pred(graph).cpu()[0].tolist()
        spec_100.append(mse_pred + cos_pred)
        real_spec.append(graph.y)

# Largest value each model predicts over the whole dataset. The original loop
# started from 0 and so silently floored the result at 0; this is the true max
# (0 only if there is no data).
mseth = max((max(s[:N_BINS]) for s in spec_100), default=0)
costh = max((max(s[N_BINS:]) for s in spec_100), default=0)
# Recorded values: mseth ~ 63, costh ~ 0.0068

d_data[1]
#DataBatch(x=[1, 45], edge_index=[2, 0], edge_attr=[0, 10], y=[50], smi=[1], batch=[1], ptr=[2])
#DataBatch(x=[1, 45], edge_index=[2, 0], edge_attr=[0, 10], y=[50], smi=[1], batch=[1], ptr=[2])

In [None]:
from torch_geometric.data import Data
def twoin1(i, cos_scale=10000):
    """Pack molecule ``i``'s two model predictions into one PyG ``Data`` record.

    The first 50 entries of ``spec_100[i]`` (MSE model) are kept as-is; the last
    50 (cosine model) are multiplied by ``cos_scale``. On this dataset the MSE
    outputs peak around 63 and the cosine outputs around 0.0068 (values recorded
    in the notebook), so the default 10000 presumably brings both halves to a
    comparable range -- TODO confirm.

    Args:
        i: index into the notebook-level ``spec_100``, ``real_spec`` and ``d_data``.
        cos_scale: factor applied to the cosine-model half; the default matches
            the original hard-coded value.

    Returns:
        ``Data(x=<float tensor of spec_100[i] with scaled cosine half>,
        y=real_spec[i], smi=d_data[i].smi)``.
    """
    x = torch.tensor(spec_100[i])
    # Rows are [MSE values | cosine values]; only the cosine half is rescaled.
    x[50:] *= cos_scale
    # The original re-wrapped x in torch.tensor(x), which merely copied an
    # existing tensor and emitted a "copy construct from a tensor" UserWarning.
    return Data(x=x, y=real_spec[i], smi=d_data[i].smi)

In [None]:
# One combined Data record per molecule, in the same order as spec_100.
twoin1_data = [twoin1(idx) for idx in trange(len(spec_100))]

In [None]:
# Write the combined records to disk (pickled via torch.save).
torch.save(obj=twoin1_data, f="/home/chengc/workspace/cc/0815/qm_9_2in1.pt")