In [1]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
from torch_geometric.loader import DataLoader
import torch.nn as nn
from torch.utils.data import random_split
from tqdm import trange

In [2]:
hidden_channels = 256


class twoinone(nn.Module):
    """Fully connected regressor: 100 input features -> 50 outputs.

    Four hidden ``Linear`` layers of width ``hidden_channels * 2`` with ReLU
    activations, followed by a linear output layer.  The module also owns its
    loss (``MSELoss``) and optimizer (``Adam``) and offers ``train`` / ``val``
    / ``pred`` helpers.

    The ``train`` and ``val`` helpers expect a *batch* object that supports
    ``len(batch)`` and integer indexing, where every item exposes ``.x``
    (features) and ``.y`` (targets) tensors.
    NOTE(review): presumably a torch_geometric ``Batch`` -- confirm.
    """

    def __init__(self):
        super().__init__()
        # Seed the global torch RNG so the weight initialisation is repeatable.
        torch.manual_seed(42)

        width = hidden_channels * 2
        self.nn1 = nn.Linear(100, width)
        self.nn2 = nn.Linear(width, width)
        self.nn3 = nn.Linear(width, width)
        self.nn4 = nn.Linear(width, width)
        self.out = nn.Linear(width, 50)

        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, weight_decay=5e-4)
        self.counter = 0
        self.progress = []

    def forward(self, x):
        """Return the network output for the feature tensor ``x``.

        ``x`` is moved to the device the weights live on, so the module does
        not depend on a notebook-level ``device`` variable being defined.
        """
        x = x.to(self.nn1.weight.device)
        x = self.nn1(x).relu()
        x = self.nn2(x).relu()
        x = self.nn3(x).relu()
        x = self.nn4(x).relu()
        return self.out(x)

    def _mean_item_loss(self, dataloader):
        """Mean of the per-item MSE losses over every item in ``dataloader``."""
        target_device = self.nn1.weight.device
        losses = []
        for i in range(len(dataloader)):
            data = dataloader[i]
            outputs = self(data.x.float())
            y = data.y.to(target_device).float()
            losses.append(self.loss_function(outputs.float(), y))
        # torch.stack raises on an empty list, i.e. an empty batch is an error.
        return torch.stack(losses, dim=0).mean(dim=0)

    def train(self, dataloader=True):
        """Run one optimisation step on a batch, or toggle training mode.

        This method shadows ``nn.Module.train(mode)``, which PyTorch itself
        calls (``eval()`` is ``self.train(False)``).  To keep ``model.eval()``
        and ``model.train()`` working, a ``bool`` argument is forwarded to the
        base-class implementation.

        Args:
            dataloader: a batch (see class docstring) to take one optimizer
                step on, or a ``bool`` to set training/evaluation mode.

        Returns:
            The detached scalar loss tensor of the step, or ``self`` when
            called with a ``bool``.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)

        loss = self._mean_item_loss(dataloader)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Detach so callers that collect losses do not keep graph references.
        return loss.detach()

    def val(self, dataloader):
        """Return the mean per-item loss on a batch without updating weights."""
        # No gradients are needed for validation; skip building the graph.
        with torch.no_grad():
            return self._mean_item_loss(dataloader)

    def pred(self, data):
        """Return the network output for ``data.x``."""
        return self(data.x.float())

In [3]:
# Prefer the second GPU (index 1) when the machine has one; fall back to the
# first GPU on single-GPU machines (where 'cuda:1' is an invalid ordinal) and
# to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
model = twoinone()
model.to(device)

twoinone(
  (nn1): Linear(in_features=100, out_features=512, bias=True)
  (nn2): Linear(in_features=512, out_features=512, bias=True)
  (nn3): Linear(in_features=512, out_features=512, bias=True)
  (nn4): Linear(in_features=512, out_features=512, bias=True)
  (out): Linear(in_features=512, out_features=50, bias=True)
  (loss_function): MSELoss()
)

In [4]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load dataset files from a trusted source.
data = torch.load("/home/chengc/workspace/cc/0815/qm_9_2in1.pt")
# Fixed-seed split so train/validation/test membership is reproducible.
# test_data is not used in the cells visible here.
train_data, val_data, test_data = random_split(
    dataset=data,
    lengths=[100000, 20000, 7468],
    generator=torch.Generator().manual_seed(42),
)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
# Validation is not shuffled: the last batch is shorter than the others and
# the training loop averages per-batch means, so a fixed order keeps the
# validation loss comparable from epoch to epoch.
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)

In [15]:
# Best validation loss seen so far; start at +inf so the first epoch always
# produces a saved model, whatever the scale of the loss.
thresh = float("inf")
model_path = "/home/chengc/workspace/cc/0815/model/qm_9_2in1_model.pth"
model_path_ckp = "/home/chengc/workspace/cc/0815/model/qm_9_2in1_model_ckp.pth"
# Optional early stopping: set to an int N to stop after N consecutive epochs
# without a new best validation loss.  None keeps the full 2000-epoch run.
patience = None
epochs_since_best = 0

# torch.save does not create missing directories.
for _path in (model_path, model_path_ckp):
    os.makedirs(os.path.dirname(_path), exist_ok=True)

for epoch in trange(2000):
    # Each element yielded by the loader is one batch; model.train performs a
    # single optimizer step on it.  Keep plain floats so no autograd graph or
    # device memory is retained across the epoch.
    train_losses = [model.train(batch).item() for batch in train_loader]
    lossmean = sum(train_losses) / len(train_losses)

    # Validation needs no gradients.
    with torch.no_grad():
        val_losses = [model.val(batch).item() for batch in val_loader]
    vlossmean = sum(val_losses) / len(val_losses)

    if vlossmean < thresh:
        print(f"Save Model : Epoch {epoch} | Trainloss {lossmean}| Valiloss {vlossmean}")
        torch.save(model.state_dict(), model_path)
        thresh = vlossmean
        epochs_since_best = 0
    else:
        epochs_since_best += 1

    # Periodic checkpoint, independent of the validation result.
    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Trainloss {lossmean}| Valiloss {vlossmean}")
        torch.save(model.state_dict(), model_path_ckp)

    if patience is not None and epochs_since_best >= patience:
        print(f"Early stop : Epoch {epoch} | best Valiloss {thresh}")
        break

  0%|                                                                                                                                                | 1/2000 [04:04<136:01:06, 244.96s/it]

Save Model : Epoch 0 | Trainloss 4.895115375518799| Valiloss 4.592116832733154
Epoch 0 | Trainloss 4.895115375518799| Valiloss 4.592116832733154
Save Model : Epoch 1 | Trainloss 2.7575457096099854| Valiloss 4.522714614868164


  0%|▏                                                                                                                                               | 2/2000 [08:07<135:11:10, 243.58s/it]

Save Model : Epoch 2 | Trainloss 2.667605400085449| Valiloss 4.478421211242676


  0%|▏                                                                                                                                               | 3/2000 [12:20<137:32:14, 247.94s/it]

Save Model : Epoch 3 | Trainloss 2.6187219619750977| Valiloss 4.466000080108643


  0%|▎                                                                                                                                               | 4/2000 [16:36<139:09:39, 250.99s/it]

Save Model : Epoch 4 | Trainloss 2.5899758338928223| Valiloss 4.429870128631592


  0%|▍                                                                                                                                               | 6/2000 [25:01<139:50:13, 252.46s/it]

Save Model : Epoch 6 | Trainloss 2.547231674194336| Valiloss 4.425424575805664


  0%|▋                                                                                                                                               | 9/2000 [38:06<142:56:13, 258.45s/it]

Save Model : Epoch 9 | Trainloss 2.5078227519989014| Valiloss 4.413899898529053


  1%|█                                                                                                                                              | 14/2000 [59:57<144:12:10, 261.39s/it]

Save Model : Epoch 14 | Trainloss 2.4621458053588867| Valiloss 4.399220943450928


  1%|█                                                                                                                                            | 15/2000 [1:04:15<143:33:28, 260.36s/it]

Save Model : Epoch 15 | Trainloss 2.452589273452759| Valiloss 4.394375801086426


  1%|█▎                                                                                                                                           | 18/2000 [1:17:19<143:54:59, 261.40s/it]

Save Model : Epoch 18 | Trainloss 2.431213855743408| Valiloss 4.383485317230225


  1%|█▎                                                                                                                                           | 19/2000 [1:21:56<146:22:50, 266.01s/it]

Save Model : Epoch 19 | Trainloss 2.424161434173584| Valiloss 4.382511615753174


  1%|█▋                                                                                                                                           | 24/2000 [1:43:49<143:58:28, 262.30s/it]

Save Model : Epoch 24 | Trainloss 2.3949224948883057| Valiloss 4.380887508392334


  1%|██                                                                                                                                           | 29/2000 [2:06:09<146:08:38, 266.93s/it]

Save Model : Epoch 29 | Trainloss 2.367717981338501| Valiloss 4.372003078460693


  2%|██▎                                                                                                                                          | 32/2000 [2:19:33<146:15:57, 267.56s/it]

Save Model : Epoch 32 | Trainloss 2.3531835079193115| Valiloss 4.369251251220703


  2%|██▌                                                                                                                                          | 37/2000 [2:41:32<144:33:01, 265.10s/it]

Save Model : Epoch 37 | Trainloss 2.3296430110931396| Valiloss 4.362818241119385


  2%|███▍                                                                                                                                         | 48/2000 [3:30:57<147:24:58, 271.87s/it]

Save Model : Epoch 48 | Trainloss 2.287713050842285| Valiloss 4.337092399597168


  5%|███████                                                                                                                                     | 101/2000 [7:30:39<142:01:37, 269.25s/it]

Epoch 100 | Trainloss 2.110114574432373| Valiloss 4.369997978210449


 10%|█████████████▉                                                                                                                             | 201/2000 [13:58:55<113:09:48, 226.45s/it]

Epoch 200 | Trainloss 1.738579273223877| Valiloss 4.530508518218994


 15%|████████████████████▉                                                                                                                      | 301/2000 [20:36:09<108:35:10, 230.08s/it]

Epoch 300 | Trainloss 1.4397449493408203| Valiloss 4.675731182098389


 20%|███████████████████████████▊                                                                                                               | 401/2000 [27:32:50<117:02:17, 263.50s/it]

Epoch 400 | Trainloss 1.248345136642456| Valiloss 4.8093342781066895


 25%|██████████████████████████████████▊                                                                                                        | 501/2000 [35:12:52<110:53:23, 266.31s/it]

Epoch 500 | Trainloss 1.1258894205093384| Valiloss 4.886943340301514


 30%|█████████████████████████████████████████▊                                                                                                 | 601/2000 [42:32:44<103:18:54, 265.86s/it]

Epoch 600 | Trainloss 1.0490895509719849| Valiloss 5.0810346603393555


 32%|████████████████████████████████████████████▍                                                                                               | 634/2000 [45:04:08<97:06:17, 255.91s/it]


KeyboardInterrupt: 

In [None]:
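# Best checkpoint of the interrupted run above (last time the validation loss improved):
# Save Model : Epoch 48 | Trainloss 2.287713050842285| Valiloss 4.337092399597168
# After epoch 48 the training loss kept falling while the validation loss rose
# (e.g. 5.08 at epoch 600), i.e. the later epochs overfit.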