In [17]:
import os
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch.optim as optim
import time
import pickle
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset ,ConcatDataset
from MyDataSet import MultiMaskTimeSeriesDataset

### Model layers

In [18]:
class GraphSAGE_Mean(nn.Module):
    '''GraphSAGE-style layer with a self branch and an adjacency-aggregated branch.

    forward(inputs, adj):
        inputs: [batch_size, num_nodes, input_dim]
        adj:    [batch_size, num_nodes, num_nodes]
        output: [batch_size, num_nodes, 2 * output_dim]
                (self branch and aggregated branch concatenated on the last axis)
    '''

    def __init__(self, input_dim, output_dim,  dropout=0., act=F.relu):
        '''
        :param input_dim: size of the last axis of `inputs`
        :param output_dim: width of EACH branch; the layer output width is 2 * output_dim
        :param dropout: dropout rate applied to `inputs` (only while in training mode)
        :param act: activation applied to the concatenated branches
        '''
        super(GraphSAGE_Mean, self).__init__()
        self.dropout = dropout
        self.act = act

        # w1 transforms each node's own features, w2 the adjacency-aggregated features
        self.w1 = nn.Parameter(torch.Tensor(input_dim, output_dim))
        self.w2 = nn.Parameter(torch.Tensor(input_dim, output_dim))
        nn.init.xavier_uniform_(self.w1)
        nn.init.xavier_uniform_(self.w2)

    def forward(self, inputs, adj):
        # training=self.training: without it F.dropout also drops units in eval() mode
        inputs = F.dropout(inputs, p=self.dropout, training=self.training)
        self_feats = torch.matmul(inputs, self.w1)
        # adj @ inputs is a neighbour *mean* only if adj is row-normalised
        # NOTE(review): confirm the dataset supplies normalised adjacency
        neigh_feats = torch.matmul(torch.bmm(adj, inputs), self.w2)
        return self.act(torch.cat([self_feats, neigh_feats], dim=2))
    
    
class Linear(nn.Module):
    '''Dense layer computing act(dropout(inputs) @ w + b) over the last axis.

    input_dim / output_dim: size of the last axis before / after the layer
    use_bias: when False, no bias parameter `b` is created
    dropout: dropout probability applied to the inputs, only in training mode
    act: callable applied to the affine output (ReLU by default)
    '''

    def __init__(self, input_dim, output_dim, use_bias=True, dropout=0., act=nn.ReLU()):
        super(Linear, self).__init__()
        self.dropout = dropout
        self.act = act

        weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        nn.init.xavier_uniform_(weight)
        self.w = weight

        self.use_bias = use_bias
        if use_bias:
            bias = nn.Parameter(torch.Tensor(output_dim))
            nn.init.zeros_(bias)
            self.b = bias

    def forward(self, inputs):
        dropped = F.dropout(inputs, p=self.dropout, training=self.training)
        affine = dropped @ self.w
        if self.use_bias:
            affine = affine + self.b
        return self.act(affine)

### Generator & Discriminator

In [19]:
class GraphSAGE(nn.Module):
    '''Two stacked GraphSAGE_Mean layers sharing the same adjacency.

    The output's last axis has size 2 * out_dim, because every GraphSAGE_Mean
    layer concatenates a self branch and an aggregated branch.
    '''

    def __init__(self, input_dim,out_dim,dropout=0.,act=nn.ReLU()):
        super(GraphSAGE, self).__init__()
        self.layer1 = GraphSAGE_Mean(input_dim, out_dim, dropout=dropout, act=act)
        # layer1 emits 2 * out_dim features, so layer2 consumes that width
        self.layer2 = GraphSAGE_Mean(2 * out_dim, out_dim, dropout=dropout, act=act)

    def forward(self, inputs ,adj):
        hidden = self.layer1(inputs, adj)
        return self.layer2(hidden, adj)
    
class Generator(nn.Module):
    '''Imputation generator: two GraphSAGE branches followed by a 3-layer MLP head.

    Branch layer1 runs on `adj`, branch layer2 on `adj_ori`. Each branch emits
    2 * output_dim features; their concatenation (4 * output_dim) feeds h1 -> h2 -> h3.
    h1 and h2 use Linear's defaults (ReLU, no dropout). The `dropout` and `act`
    arguments reach only the GraphSAGE branches. h3 ends in a Sigmoid, so
    outputs lie in (0, 1).

    NOTE(review): h3 maps linear_hidden[2] -> linear_hidden[-1]; with a 3-entry
    list those are the same entry, so h3 is a square layer.
    '''

    def __init__(self, input_dim, output_dim, linear_hidden, dropout=0., act=nn.ReLU()):
        super(Generator, self).__init__()
        # The MLP input must match the width of the two concatenated branches (2 * 2 * output_dim)
        assert linear_hidden[0] == output_dim*4, 'The first hidden layer should be out_dim*4'

        self.layer1 = GraphSAGE(input_dim, output_dim, dropout=dropout, act=act)
        self.layer2 = GraphSAGE(input_dim, output_dim, dropout=dropout, act=act)

        width_in, width_h1, width_h2 = linear_hidden[0], linear_hidden[1], linear_hidden[2]
        self.h1 = Linear(width_in, width_h1)
        self.h2 = Linear(width_h1, width_h2)
        self.h3 = Linear(width_h2, linear_hidden[-1], act=nn.Sigmoid())

    def forward(self, inputs,adj,adj_ori):
        branches = [self.layer1(inputs, adj), self.layer2(inputs, adj_ori)]
        x = torch.cat(branches, dim=-1)
        for head in (self.h1, self.h2, self.h3):
            x = head(x)
        return x

class Discriminator(nn.Module):
    '''Three-layer MLP critic applied over the last axis of its input.

    input_dim: size of the input's last axis
    linear_hidden: [hidden1, hidden2, out]; the final layer has no activation,
        so the output is an unbounded score
    dropout: applied to the inputs of all three layers
    act: activation for the first two layers
    '''

    def __init__(self, input_dim,linear_hidden, dropout=0., act=nn.ReLU()):
        super(Discriminator, self).__init__()
        self.h1 = Linear(input_dim, linear_hidden[0], dropout=dropout, act=act)
        self.h2 = Linear(linear_hidden[0], linear_hidden[1], dropout=dropout, act=act)
        # nn.Identity instead of `lambda x: x`: same output and same state_dict,
        # but the module stays picklable (e.g. torch.save(model), multiprocessing)
        self.h3 = Linear(linear_hidden[1], linear_hidden[2], dropout=dropout, act=nn.Identity())

    def forward(self, inputs):
        x = self.h1(inputs)
        x = self.h2(x)
        return self.h3(x)

In [None]:

# Smoke test: build tiny models on random tensors and check the output shapes.
torch.manual_seed(0)
adj = torch.randn(3,5,5)
adj_ori = torch.randn(3,5,5)
input_dim = 2
output_dim = 2
linear_hidden = [output_dim*4, 16, 1]
dropout = 0.5
act = nn.ReLU()
gen = Generator(input_dim, output_dim, linear_hidden,  dropout, act)
dis = Discriminator(output_dim*2, linear_hidden, dropout, act)

gen_out = gen(torch.randn(3, 5, input_dim), adj, adj_ori)
dis_out = dis(torch.randn(3, 5, output_dim*2))
# Generator head ends in linear_hidden[-1] features, discriminator head in linear_hidden[2]
assert tuple(gen_out.shape) == (3, 5, linear_hidden[-1]), gen_out.shape
assert tuple(dis_out.shape) == (3, 5, linear_hidden[2]), dis_out.shape
print(gen_out.shape)
dis_out.shape
                    


### Trainer

In [21]:
class Trainer():
    '''Adversarial trainer for the graph imputation Generator / Discriminator pair.

    Each batch performs one critic (Discriminator) update followed by one
    generator update. Each update has its own loss and its own optimizer, so
    neither model is stepped with gradients of the other model's objective.

    NOTE(review): the critic loss has the WGAN form -mean(D(real)) + mean(D(fake)),
    but no weight clipping or gradient penalty is applied in this class.
    '''

    def __init__(self,Generator,  Discriminator,  dataloader , Gen_lr, Dis_lr, alpha ,p_hint,device):
        '''
        Generator: generator model, called as Generator(x_mask, adj, adj_ori)
        Discriminator: critic model, called as Discriminator(x)
        dataloader: iterable yielding (x_raw, mask, adj_ori, adj) batches
        Gen_lr: RMSprop learning rate for the generator
        Dis_lr: RMSprop learning rate for the discriminator
        alpha: weight of the reconstruction MSE term in the generator loss
        p_hint: probability stored for sample_hint (not used by train/test)
        device: torch device (or 'cuda' / 'cpu') batches are moved to
        '''
        super(Trainer, self).__init__()
        self.Generator = Generator
        self.Discriminator = Discriminator
        self.dataloader = dataloader
        self.device = device
        self.Gen_lr = Gen_lr
        self.Dis_lr = Dis_lr
        self.alpha = alpha
        self.p_hint = p_hint

        self.optimizer_G = optim.RMSprop(Generator.parameters(), lr=Gen_lr)
        self.optimizer_D = optim.RMSprop(Discriminator.parameters(), lr=Dis_lr)

        # Loss curves sampled every 256 iterations during train()
        self.graph = {'iter':[],'G_loss': [], 'D_loss': [], 'MSE_loss': [], 'MSE_test_loss': []}
        self.total_iter = 0

    def sample_noise(self, dim):
        '''Uniform [0, 1) noise of shape `dim` on self.device.'''
        return torch.rand(dim).to(self.device)

    def sample_hint(self, dim, p):
        '''Hint vector: float16 tensor of shape `dim`. Each entry is 1 with
        probability 1 - p (uniform sample > p) and 0 otherwise.'''
        A=torch.rand(dim,dtype=torch.float16).to(self.device)
        A[A>p]=1.0
        A[A<=p]=0
        return A

    @staticmethod
    def _masked_mse(mask, x_raw, G_samples):
        '''MSE over entries where mask == 0, normalised by the fraction of such entries.

        Returns 0 instead of NaN when the mask contains no zeros.
        '''
        missing = 1 - mask
        numerator = torch.mean((missing * x_raw - missing * G_samples) ** 2)
        return numerator / torch.mean(missing).clamp(min=1e-8)

    def discriminator_loss(self,mask,x_mask ,x_raw ,adj ,adj_ori):
        '''Compute every loss for one batch in a single forward pass (used by test()).

        mask: 1 where the value is observed (raw), 0 where it must be generated
        x_mask: generator input (mask * x_raw in this class)
        x_raw: ground-truth data
        adj, adj_ori: adjacency tensors for the generator's two branches

        Returns (d_loss, g_loss, mse_loss, mse_test_loss). The gradients of these
        losses are NOT separated between the two models; training uses _train_step.
        '''
        D_logit_real = self.Discriminator(x_raw)
        G_samples = self.Generator(x_mask ,adj,adj_ori)
        D_logit_fake = self.Discriminator(G_samples)

        d_loss = -torch.mean(D_logit_real) + torch.mean(D_logit_fake)
        # Reconstruction error over every entry, observed and missing alike
        mse_loss = torch.mean((x_raw - G_samples) ** 2)
        g_loss = -torch.mean(D_logit_fake) + self.alpha * mse_loss
        mse_test_loss = self._masked_mse(mask, x_raw, G_samples)

        return d_loss, g_loss, mse_loss, mse_test_loss

    def _prepare_batch(self, batch):
        '''Move an (x_raw, mask, adj_ori, adj) batch to self.device as float and
        swap the last two axes of x_raw and mask.'''
        x_raw, mask, adj_ori, adj = (t.float().to(self.device) for t in batch)
        # Reorder x_raw and mask to [batch_size, num_nodes, time_dim]
        return x_raw.permute(0, 2, 1), mask.permute(0, 2, 1), adj_ori, adj

    def _train_step(self, x_raw, mask, adj_ori, adj):
        '''One critic update followed by one generator update on a prepared batch.

        Returns detached (D_loss, G_loss, MSE_loss, MSE_test_loss) tensors computed
        before the respective parameter updates.
        '''
        x_mask_noise = mask * x_raw
        G_samples = self.Generator(x_mask_noise, adj, adj_ori)

        # Critic step: the fakes are detached so D_loss sends no gradient into the generator.
        self.optimizer_D.zero_grad()
        D_loss = (-torch.mean(self.Discriminator(x_raw))
                  + torch.mean(self.Discriminator(G_samples.detach())))
        D_loss.backward()
        self.optimizer_D.step()

        # Generator step: re-score the fakes with the updated critic and step only G.
        # Critic gradients left by this backward are cleared by the next critic zero_grad().
        self.optimizer_G.zero_grad()
        MSE_loss = torch.mean((x_raw - G_samples) ** 2)
        G_loss = -torch.mean(self.Discriminator(G_samples)) + self.alpha * MSE_loss
        G_loss.backward()
        self.optimizer_G.step()

        with torch.no_grad():
            MSE_test_loss = self._masked_mse(mask, x_raw, G_samples)
        return D_loss.detach(), G_loss.detach(), MSE_loss.detach(), MSE_test_loss

    def _record(self, losses):
        '''Append the current iteration and the given losses to self.graph.'''
        D_loss, G_loss, MSE_loss, MSE_test_loss = losses
        self.graph['iter'].append(self.total_iter)
        self.graph['G_loss'].append(G_loss.item())
        self.graph['D_loss'].append(D_loss.item())
        self.graph['MSE_loss'].append(MSE_loss.item())
        self.graph['MSE_test_loss'].append(MSE_test_loss.item())

    def test(self , test_dataloader):
        '''Evaluate on test_dataloader without gradient tracking.

        Returns a dict with per-batch x_raw / mask tensors (kept on self.device) and
        per-batch float G_loss, MSE_loss and MSE_test_loss. The train/eval mode of
        both models is restored afterwards.
        '''
        was_training = (self.Generator.training, self.Discriminator.training)
        self.Generator.eval()
        self.Discriminator.eval()
        result = {'x_raw':[],'mask': [], 'G_loss': [], 'MSE_loss': [], 'MSE_test_loss': []}

        try:
            with torch.no_grad():
                for batch in test_dataloader:
                    x_raw, mask, adj_ori, adj = self._prepare_batch(batch)
                    x_mask_noise = mask * x_raw
                    _, G_loss, MSE_loss, MSE_test_loss = self.discriminator_loss(
                        mask, x_mask_noise, x_raw, adj, adj_ori)
                    result['x_raw'].append(x_raw)
                    result['mask'].append(mask)
                    result['G_loss'].append(G_loss.item())
                    result['MSE_loss'].append(MSE_loss.item())
                    result['MSE_test_loss'].append(MSE_test_loss.item())
        finally:
            self.Generator.train(was_training[0])
            self.Discriminator.train(was_training[1])

        return result

    def train(self, epochs,save_path=None):
        '''Train for `epochs` passes over self.dataloader.

        Losses are recorded every 256 iterations and printed every other epoch.
        With save_path set, a checkpoint is written every 20 epochs and the loss
        record is pickled to save_path/Train_record.pkl at the end.
        '''
        self.Generator.train()
        self.Discriminator.train()

        print('Training...')
        for it in tqdm(range(epochs)):
            losses = None  # stays None if the dataloader yields no batches
            for batch in self.dataloader:
                self.total_iter = self.total_iter + 1
                losses = self._train_step(*self._prepare_batch(batch))
                if self.total_iter % 256 == 0:
                    self._record(losses)

            if losses is not None and it % 2 == 0:
                D_loss, G_loss, _, MSE_test_loss = losses
                print('Epoch: {}, Generator Loss: {:.4f}, Discriminator Loss: {:.4f}, '
                      'MSE Test Loss: {:.4f}'.format(it, G_loss.item(), D_loss.item(),
                                                      MSE_test_loss.item()))

            if save_path is not None and (it+1)%20==0:
                os.makedirs(save_path, exist_ok=True)
                self.save_checkpoint(os.path.join(save_path,'checkpoint_{}.pth'.format(it)))

        if save_path is not None:
            # The directory may not exist yet when fewer than 20 epochs were run
            os.makedirs(save_path, exist_ok=True)
            self.save_graph(os.path.join(save_path,'Train_record.pkl'))

        print('Training finished!')

    def save_checkpoint(self,path):
        '''Save both models' and both optimizers' state dicts to `path`.'''
        torch.save({'Generator_state_dict': self.Generator.state_dict(),
                    'Discriminator_state_dict': self.Discriminator.state_dict(),
                    'optimizer_G_state_dict': self.optimizer_G.state_dict(),
                    'optimizer_D_state_dict': self.optimizer_D.state_dict()}, path)

    def save_graph(self,path):
        '''Pickle the recorded loss curves (self.graph) to `path`.'''
        with open(path, 'wb') as f:
            pickle.dump(self.graph, f)

    def load_checkpoint(self,path):
        '''Restore models and optimizers from a file written by save_checkpoint.

        map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        Only load trusted files: torch.load may unpickle arbitrary objects.
        '''
        checkpoint = torch.load(path, map_location=self.device)
        self.Generator.load_state_dict(checkpoint['Generator_state_dict'])
        self.Discriminator.load_state_dict(checkpoint['Discriminator_state_dict'])
        self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        self.optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

    def plot_graph(self,graph=None):
        '''Plot the four recorded curves, one per subplot, in a 2x2 grid.

        graph: dict with the keys of self.graph; defaults to self.graph.
        '''
        if graph is None:
            graph = self.graph
        # (key, y-axis label, title) per subplot
        panels = [('G_loss', 'Generator Loss', 'Generator Loss'),
                  ('D_loss', 'Discriminator Loss', 'Discriminator Loss'),
                  ('MSE_loss', 'MSE  Loss', 'MSE Loss'),
                  ('MSE_test_loss', 'MSE Test Loss', 'MSE Test Loss')]
        fig, axes = plt.subplots(2, 2, figsize=(16, 9))
        for ax, (key, ylabel, title) in zip(axes.flat, panels):
            ax.plot(graph['iter'], graph[key])
            ax.set_xlabel('Iteration')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
        plt.show()
        
    

### Training

In [22]:
# Directory of pickled training datasets; override with the GA_GAN_SOURCE_TRAIN env var.
source_train_path = os.environ.get('GA_GAN_SOURCE_TRAIN',
                                   r'D:\heyulinWorkPath\Models\GA-GAN\Data\source_train')
# sorted(): os.listdir order is arbitrary, which would make the ConcatDataset order
# (and thus runs) non-reproducible. Only regular files are loaded.
source_train_files = sorted(
    os.path.join(source_train_path, name)
    for name in os.listdir(source_train_path)
    if os.path.isfile(os.path.join(source_train_path, name))
)
source_train_data = []

for file_path in source_train_files:
    # pickle.load can execute arbitrary code: only load trusted data files.
    with open(file_path, 'rb') as f:
        source_train_data.append(pickle.load(f))

data = ConcatDataset(source_train_data)
data_loader = DataLoader(data, batch_size=32, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [23]:
input_dim = 4*12 

Discriminator = Discriminator(input_dim=input_dim,linear_hidden= [256, 64,1],
                              dropout=0.0, act=nn.ReLU()).to(device)
Generator = Generator(input_dim=input_dim, output_dim=64, 
                      linear_hidden=[64*4, 64*2, input_dim], dropout=0, act=nn.ReLU()).to(device)


In [None]:
trainer = Trainer(Generator, Discriminator, data_loader, Gen_lr=0.01, Dis_lr=0.01, alpha=100, p_hint=0.3, device=device)
trainer.train(epochs=2, save_path=r'D:\heyulinWorkPath\Models\GA-GAN\Checkpoints')