In [None]:
class EarlyStopping:
    """Stop training early once the validation loss stops improving.

    Each call compares the new validation loss with the best one seen so far.
    An improvement (a decrease of at least ``delta``) saves the model's
    ``state_dict`` to ``os.path.join(save_path, name)`` and resets the patience
    counter. Otherwise the counter is incremented, and once it reaches
    ``patience``, ``early_stop`` is set to True.
    """

    def __init__(self, save_path, name, patience=50, verbose=False, delta=0):
        """
        Args:
            save_path (str): Directory in which the best checkpoint is written.
            name (str): File name of the checkpoint inside ``save_path``.
            patience (int): Number of consecutive calls without improvement
                tolerated before ``early_stop`` is set.
                Default: 50
            verbose (bool): If True, prints a message every time a new best
                checkpoint is saved.
                Default: False
            delta (float): Minimum decrease in validation loss that counts as
                an improvement. With 0, an equal loss also counts.
                Default: 0
        """
        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.name = name

    def __call__(self, val_loss, model):
        """Record one validation result and checkpoint ``model`` on improvement.

        Args:
            val_loss (float): Validation loss of the current evaluation.
            model: Module whose ``state_dict()`` is saved when the loss improves.
        """
        # Work with a "higher is better" score so the comparison reads naturally.
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we see is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Loss did not decrease by at least ``delta``.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save ``model.state_dict()`` and remember ``val_loss`` as the new minimum.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        path = os.path.join(self.save_path, self.name)
        # Overwrites the file, so it always holds the best parameters seen so far.
        torch.save(model.state_dict(), path)
        self.val_loss_min = val_loss

In [None]:
class MultiHeadCrossAttention(nn.Module):
    """Cross-attention between ``num_heads`` chunks of each feature vector.

    Each input vector of size ``input_dim`` is projected (query from ``array1``,
    key/value from ``array2``) and split into ``num_heads`` chunks of size
    ``head_dim``. Scaled dot-product attention is then computed *between these
    chunks*: for every leading index (e.g. each batch row), the ``num_heads``
    query chunks attend over the ``num_heads`` key/value chunks. The result is
    flattened back to ``input_dim``. There is no output projection.

    Args:
        input_dim (int): Feature size of both inputs and of the output.
        num_heads (int): Number of chunks; must be positive and divide ``input_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``input_dim``.
    """

    def __init__(self, input_dim, num_heads):
        super(MultiHeadCrossAttention, self).__init__()
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.input_dim = input_dim
        self.head_dim = input_dim // num_heads

        # Linear projections for query, key and value.
        self.query_linear = nn.Linear(self.input_dim, self.input_dim)
        self.key_linear = nn.Linear(self.input_dim, self.input_dim)
        self.value_linear = nn.Linear(self.input_dim, self.input_dim)

    def forward(self, array1, array2):
        """
        Args:
            array1: Query source of shape ``(..., input_dim)``.
            array2: Key/value source of shape ``(..., input_dim)``. Its leading
                dims must match (or broadcast with) those of ``array1``.

        Returns:
            Tensor of shape ``array1.shape`` (``(..., input_dim)``).
        """
        query = self.query_linear(array1)
        key = self.key_linear(array2)
        value = self.value_linear(array2)

        # Split the last dim into (num_heads, head_dim). Keys/values use their own
        # leading shape, and any number of leading dims is supported.
        query = query.reshape(*array1.shape[:-1], self.num_heads, self.head_dim)
        key = key.reshape(*array2.shape[:-1], self.num_heads, self.head_dim)
        value = value.reshape(*array2.shape[:-1], self.num_heads, self.head_dim)

        # (..., H, d) @ (..., d, H) -> (..., H, H): chunk-to-chunk scores.
        attention_scores = torch.matmul(query, key.transpose(-2, -1))
        attention_scores = attention_scores / (self.head_dim ** 0.5)
        attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)

        # Weighted sum of value chunks: (..., H, d).
        weighted_sum = torch.matmul(attention_weights, value)

        # Concatenate the chunks back into a single feature vector.
        return weighted_sum.reshape(*array1.shape[:-1], self.input_dim)

class MLP(nn.Module):
    """Stack of Linear layers, each optionally followed by BatchNorm1d -> ReLU -> Dropout.

    Submodules are registered inside ``self.mlp`` under the names ``Linear_i``,
    ``BN_i``, ``ReLU_i`` and ``Dropout_i``.

    Args:
        hidden_size (Sequence[int]): Layer widths ``[in, h1, ..., out]``. Must
            contain at least two entries. ``Linear_i`` maps
            ``hidden_size[i] -> hidden_size[i + 1]``.
        last_activation (bool): If True, the final Linear layer is also followed
            by BN/ReLU/Dropout. Otherwise the output is left linear.
        dropout (float): Dropout probability used after every activation.
            Default: 0.1

    Raises:
        ValueError: If ``hidden_size`` has fewer than two entries.
    """

    def __init__(self, hidden_size, last_activation=False, dropout=0.1):
        super(MLP, self).__init__()
        if len(hidden_size) < 2:
            raise ValueError(
                f"hidden_size needs at least 2 entries (in, out), got {len(hidden_size)}"
            )
        last_index = len(hidden_size) - 2
        layers = []
        for i, (in_dim, out_dim) in enumerate(zip(hidden_size, hidden_size[1:])):
            layers.append(("Linear_%d" % i, nn.Linear(in_dim, out_dim)))
            # Hidden layers always get an activation block; the final layer only on request.
            if i < last_index or last_activation:
                layers.append(("BN_%d" % i, nn.BatchNorm1d(out_dim)))
                layers.append(("ReLU_%d" % i, nn.ReLU(inplace=True)))
                layers.append(("Dropout_%d" % i, nn.Dropout(p=dropout)))

        # Build the container once, after every layer has been collected.
        self.mlp = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.mlp(x)

class FeedForward(nn.Module):
    """Feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model (int): Input and output feature size.
        d_ff (int): Hidden feature size.
        dropout (float): Dropout probability applied after the ReLU. Default: 0.1
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Registration order kept as linear1 -> relu -> dropout -> linear2.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)

class CrossTransformerBlock(nn.Module):
    """Post-norm block: cross-attention sublayer, then feed-forward sublayer.

    Computes ``h = norm1(x + mhsa(x, y))`` followed by ``norm2(h + ff(h))``.
    ``x`` provides the queries and ``y`` the keys/values.

    Args:
        hidden_d (int): Feature size (also the feed-forward hidden size).
        n_heads (int): Number of attention heads.
    """

    def __init__(self, hidden_d, n_heads):
        super().__init__()
        self.hidden_d, self.n_heads = hidden_d, n_heads

        # Same registration order as before: norm1, mhsa, ff, norm2.
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MultiHeadCrossAttention(num_heads=n_heads, input_dim=hidden_d)
        self.ff = FeedForward(d_model=hidden_d, d_ff=hidden_d)
        self.norm2 = nn.LayerNorm(hidden_d)

    def forward(self, x, y):
        attended = self.norm1(x + self.mhsa(x, y))
        return self.norm2(attended + self.ff(attended))

class SelfTransformerBlock(nn.Module):
    """Post-norm block: self-attention sublayer, then feed-forward sublayer.

    Uses MultiHeadCrossAttention with the same tensor as query and key/value
    source: ``h = norm1(x + mhsa(x, x))`` followed by ``norm2(h + ff(h))``.

    Args:
        hidden_d (int): Feature size (also the feed-forward hidden size).
        n_heads (int): Number of attention heads.
    """

    def __init__(self, hidden_d, n_heads):
        super().__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        # Same registration order as before: norm1, mhsa, ff, norm2.
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MultiHeadCrossAttention(num_heads=n_heads, input_dim=hidden_d)
        self.ff = FeedForward(d_model=hidden_d, d_ff=hidden_d)
        self.norm2 = nn.LayerNorm(hidden_d)

    def forward(self, x):
        residual = x + self.mhsa(x, x)
        h = self.norm1(residual)
        residual = h + self.ff(h)
        return self.norm2(residual)

In [None]:
class VAE(nn.Module):
    """Variational autoencoder with SelfTransformerBlock stacks in encoder and decoder.

    Pipeline: ``input_embedding`` (input_size -> latent_dim), then the encoder
    blocks, then separate ``calc_mean`` / ``calc_logvar`` heads. Next comes the
    reparameterized sample ``z``, then the decoder blocks, and finally
    ``recon_input`` (latent_dim -> input_size). Each MLP here is built with two
    sizes and the default ``last_activation=False``, so it is a single Linear
    layer.

    Args:
        input_size (int): Feature size of the input and of the reconstruction.
        latent_dim (int): Embedding / latent size. Attention splits it into
            ``num_heads`` chunks, so it should be divisible by ``num_heads``.
        num_heads (int): Heads per SelfTransformerBlock. Default: 8
        num_encoder_layers (int): Number of encoder blocks. Default: 1
        num_decoder_layers (int): Number of decoder blocks. Default: 12
    """

    def __init__(self, input_size, latent_dim, num_heads=8,
                 num_encoder_layers=1, num_decoder_layers=12):
        super(VAE, self).__init__()

        # Construction order is kept as-is so seeded initialisation is unchanged.
        self.input_embedding = MLP([input_size, latent_dim])

        self.encoder_self_attentions = nn.ModuleList(
            [SelfTransformerBlock(latent_dim, num_heads) for _ in range(num_encoder_layers)]
        )

        self.calc_mean = MLP([latent_dim, latent_dim])
        self.calc_logvar = MLP([latent_dim, latent_dim])

        self.decoder_self_attentions = nn.ModuleList(
            [SelfTransformerBlock(latent_dim, num_heads) for _ in range(num_decoder_layers)]
        )
        self.recon_input = MLP([latent_dim, input_size])

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(logvar * 0.5)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, curve):
        """Encode, sample and decode ``curve``.

        Args:
            curve: Input tensor, presumably of shape ``(batch, input_size)``
                (TODO confirm against callers).

        Returns:
            tuple: ``(recon_input, mu, logvar)``.
        """
        x = self.input_embedding(curve)
        for encoder_self_attention in self.encoder_self_attentions:
            x = encoder_self_attention(x)

        mu = self.calc_mean(x)
        logvar = self.calc_logvar(x)
        # NOTE(review): sampling also happens in eval mode; callers wanting a
        # deterministic reconstruction would need to decode ``mu`` instead.
        z = self.reparameterize(mu, logvar)

        for decoder_self_attention in self.decoder_self_attentions:
            z = decoder_self_attention(z)

        recon_input = self.recon_input(z)
        return recon_input, mu, logvar



In [None]:
class VAE_pure(nn.Module):
    """Minimal VAE without attention.

    ``calc_mean`` and ``calc_logvar`` map the input to the latent mean and
    log-variance, and ``recon_input`` maps the sampled latent back to input
    size. Each is an MLP built from two sizes with the default
    ``last_activation=False``, i.e. a single Linear layer.

    Args:
        input_size (int): Feature size of the input and of the reconstruction.
        latent_dim (int): Latent size.
    """

    def __init__(self, input_size, latent_dim):
        super().__init__()
        # Registration order: calc_mean, calc_logvar, recon_input.
        self.calc_mean = MLP([input_size, latent_dim])
        self.calc_logvar = MLP([input_size, latent_dim])
        self.recon_input = MLP([latent_dim, input_size])

    def reparameterize(self, mu, logvar):
        """Return ``mu + sigma * eps`` where ``sigma = exp(0.5 * logvar)`` and ``eps ~ N(0, I)``."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, curve):
        """Return ``(reconstruction, mu, logvar)`` for the input ``curve``."""
        mu, logvar = self.calc_mean(curve), self.calc_logvar(curve)
        reconstruction = self.recon_input(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar



In [None]:
class MLP_mapper(nn.Module):
    """Stack of Linear layers, each optionally followed by BatchNorm1d -> ReLU -> Dropout.

    Same layout as MLP but with dropout disabled by default. Submodules are
    registered inside ``self.mlp`` as ``Linear_i``, ``BN_i``, ``ReLU_i`` and
    ``Dropout_i``.

    Args:
        hidden_size (Sequence[int]): Layer widths ``[in, h1, ..., out]``. Must
            contain at least two entries.
        last_activation (bool): If True, the final Linear layer is also followed
            by BN/ReLU/Dropout.
        dropout (float): Dropout probability after each activation. Default: 0

    Raises:
        ValueError: If ``hidden_size`` has fewer than two entries.
    """

    def __init__(self, hidden_size, last_activation=False, dropout=0):
        super(MLP_mapper, self).__init__()
        if len(hidden_size) < 2:
            raise ValueError(
                f"hidden_size needs at least 2 entries (in, out), got {len(hidden_size)}"
            )
        last_index = len(hidden_size) - 2
        layers = []
        for i, (in_dim, out_dim) in enumerate(zip(hidden_size, hidden_size[1:])):
            layers.append(("Linear_%d" % i, nn.Linear(in_dim, out_dim)))
            # Hidden layers always get an activation block; the final layer only on request.
            if i < last_index or last_activation:
                layers.append(("BN_%d" % i, nn.BatchNorm1d(out_dim)))
                layers.append(("ReLU_%d" % i, nn.ReLU(inplace=True)))
                layers.append(("Dropout_%d" % i, nn.Dropout(p=dropout)))

        # Build the container once, after every layer has been collected.
        self.mlp = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.mlp(x)
    
class LatentMapper(nn.Module):
    """Map a latent vector to a ``joint_size`` output through a deep MLP_mapper.

    With the defaults, the layer widths are
    ``[latent_dim] + [2048] * 8 + [joint_size]``.

    Args:
        latent_dim (int): Input feature size.
        joint_size (int): Output feature size.
        hidden_dim (int): Width of every hidden layer. Default: 2048
        num_hidden_layers (int): Number of hidden layers. Default: 8
    """

    def __init__(self, latent_dim, joint_size, hidden_dim=2048, num_hidden_layers=8):
        super(LatentMapper, self).__init__()
        sizes = [latent_dim] + [hidden_dim] * num_hidden_layers + [joint_size]
        self.joint_mapper = MLP_mapper(sizes)

    def forward(self, x):
        return self.joint_mapper(x)