In [1]:
import pickle

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import GCNConv, GINConv, GATConv, GATv2Conv
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.nn import MessagePassing


In [3]:
file_path = 'reservoir_data/data_list_vae.pkl'
with open(file_path,'rb') as f:
    data = pickle.load(f)

In [4]:
import random
from sklearn.model_selection import train_test_split

train_set = train_test_split(data, test_size = 0.1)

In [7]:
train_loader = DataLoader(train_set, batch_size=16, num_workers=3)

In [6]:
class Encoder(torch.nn.Module):
    def __init__(self, n_features, dim_h, z_dim):
        super(GATv2, self).__init__()
        self.conv1 = GATv2Conv(n_features, dim_h, edge_dim = 1)
        self.conv2 = GATv2Conv(dim_h, dim_h, edge_dim = 1)
        self.conv3 = GATv2Conv(dim_h, dim_h, edge_dim = 1)
        self.lin = Linear(dim_h, z_dim*2)
    
    def forward(self, x, edge_index, edge_attr, batch):
        # Node embeddings 
        h = self.conv1(x, edge_index, edge_attr= edge_attr)
        h = h.relu()
        h = self.conv2(h, edge_index, edge_attr= edge_attr)
        h = h.relu()
        h = self.conv3(h, edge_index, edge_attr= edge_attr)

        # Graph-level readout
        hG = global_mean_pool(h, batch)
        h = self.lin(h)
        
        mu = h[:z_dim]
        logvar = h[z_dim:]
        
        return mu, logvar

In [None]:
class Decoder(torch.nn.Module):
    def __init__(self, z_dim, out_dim):
        super(GATv2, self).__init__()
        self.lin1 = Linear(z_dim, hidden_dim)
        self.lin2 = Linear(hidden_dim, hidden_dim)
        self.lin3 = Linear(hidden_dim, out_dim)

    def forward(self, z):
        # Node embeddings 
        h = self.lin1(z)
        h = h.relu()
        h = self.lin2(h)
        h = h.relu()
        h = self.lin3(h)
        
        return h

In [None]:
class Model(nn.Module):
    def __init__(self, Encoder, Decoder, device):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder
        self.device = device
        
    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(self.device)        # sampling epsilon        
        z = mean + var * epsilon                          # reparameterization trick
        return z
                
    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var)
        x_hat = self.Decoder(z)
        
        return x_hat, mean, log_var

In [21]:
for epoch in range(num_epochs):
    train_loss = 0
    for idx, data in enumerate(train_loader):
        
        optimizer.zero_grad()
        
        reconstruction, mu, log_var = vae(data[0])
        
        loss = loss_function(recon, x, mu, logvar)
        
        train_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
def train(model, criterion, epochs, loader, val_loader, save_path):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    
    losses = []
    
    best_val_loss = float('inf')
    
    iteration = 0
    
    update_iter = 500
    
    model.train()
    for epoch in range(epochs+1):
        running_reconstruction_loss = 0
        running_KLD_loss = 0
        total_loss = 0
        

        # Train on batches
        for data in loader:
            optimizer.zero_grad()
            
            reconstruction, mu, log_var = model(data.x, data.edge_index, data.edge_attr, data.batch)
            
            reconstruction_loss, KLD_loss = criterion(reconstruction, data.y, mu, logvar)
            
            running_reconstruction_loss += reconstruction_loss.item()
            running_KLD_loss += KLD_loss.item() 
            
            total_loss += running_reconstruction_loss + running_KLD_loss
            
            loss.backward()
            optimizer.step()

            if iteration % update_iter == 0:
                # Validation
                val_loss = test(model, criterion, val_loader)
                # Print metrics every 20 epochs
                print(f'Epoch {epoch:>3} | Train Loss: {total_loss/update_iter:} | Reconstruction Loss: {running_reconstruction_loss/update_iter} | KLD loss: {running_KLD_loss/update_iter}' )
                losses.append[total_loss/update_iter, running_reconstruction_loss/update_iter, running_KLD_loss/update_iter ]
                
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model_state = model.state_dict()
                    torch.save(best_model_state, save_path) 
                    
                running_reconstruction_loss = 0
                running_KLD_loss = 0
                total_loss = 0
                
            iteration += 1
            print(iteration, end='\r')
            
    return model, train_losses, val_losses

@torch.no_grad()
def test(model, criterion, loader):
    model.eval()
    loss = 0
    for data in loader:
        rec,mu,logvar = model(data.x, data.edge_index, data.edge_attr, data.batch)
        rec_loss, kld_loss = criterion(rec, data.y, mu, logvar) 
        loss += rec_los.item() + kld_loss.item()

    return loss / len(loader)
                    
            
    return model, losses


In [None]:
n_features = 125
hidden_dim = 32
z_dim = 16

edge_dim = 1
n_nodes = 18


encoder = Encoder(n_features,hidden_dim,z_dim)
decoder = Decoder(z_dim, n_features*n_nodes + n_nodes**2 + (n_nodes**2)*edge_dim)

model = Model(Encoder=encoder,Decoder=decoder)

In [22]:
device = 'cpu'
epochs = 10

save_path = 'reservoir_data/res16_best_model.pth'

gcn = gcn.to(device)

gcn, train_losses, val_losses = train(gcn, epochs, train_loader, test_loader, save_path)


#test_loss = test(gcn, test_loader)
#print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}%')
#print()


Epoch   0 | Train Loss: 0.00023566046729683876 | Val Loss: 0.00046019235742278397
Epoch   0 | Train Loss: 0.0003077992587350309 | Val Loss: 0.0005208745133131742
Epoch   0 | Train Loss: 0.0003799987316597253 | Val Loss: 0.0004448906984180212
Epoch   0 | Train Loss: 0.0004534148029051721 | Val Loss: 0.00045804408728145063
Epoch   0 | Train Loss: 0.0005272722919471562 | Val Loss: 0.0004370369715616107
Epoch   0 | Train Loss: 0.000599189312197268 | Val Loss: 0.00045981016592122614
Epoch   1 | Train Loss: 1.632763451198116e-05 | Val Loss: 0.0004718992277048528
Epoch   1 | Train Loss: 8.670402894495055e-05 | Val Loss: 0.0004062704392708838
Epoch   1 | Train Loss: 0.00014990601630415767 | Val Loss: 0.00044218855327926576
Epoch   1 | Train Loss: 0.00021457116235978901 | Val Loss: 0.00042080131242983043
Epoch   1 | Train Loss: 0.00028681097319349647 | Val Loss: 0.0004002690257038921
Epoch   1 | Train Loss: 0.0003533469862304628 | Val Loss: 0.0003725403221324086
Epoch   1 | Train Loss: 0.000417