# Masked auto-encoder

version = 1


In [1]:
import numpy as np
import pandas as pd
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
from datetime import datetime
import json
import torch.nn.functional as F  
import matplotlib
matplotlib.use('Agg') 
import gc

In [None]:
protein_dict = pd.read_pickle('../B4PPI-main/data/medium_set/embeddings/embeddings_merged.pkl')

In [None]:
class ProteinDataset(Dataset):
    def __init__(self, protein_dict, max_len=1502):
        self.keys = list(protein_dict.keys())
        self.protein_dict = protein_dict
        self.max_len = max_len
        self.scale_factor = 20 
        
    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]
        seq = self.protein_dict[key]  # shape: (seq_len, 960)
        seq_len = seq.shape[0]

        need_cleanup = False
        # padding to max_len
        if seq_len < self.max_len:
            pad_size = (self.max_len - seq_len, 960)
            padding = torch.zeros(pad_size, dtype=seq.dtype)
            seq = torch.cat([seq, padding], dim=0)
            need_cleanup = True
        else:
            seq = seq[:self.max_len]  # truncate if needed
        #print(f"输入数据统计 - 均值: {seq.mean():.4f} 方差: {seq.var():.4f}")
        #均值为0，方差0.0002-0.001之间；说明特征之间相似度过高
        seq = seq * self.scale_factor
        if need_cleanup:
            del padding
        return seq.clone()  # shape: (max_len, 960)

### mask 75%

without taking into account the influence of padding

In [None]:
#  Masked Autoencoder with Transformer
class TransformerMAE(nn.Module):
    def __init__(self,
                 input_dim=960,
                 embed_dim=512,
                 mask_ratio=0.75,
                 num_layers=4,
                 nhead=16,
                 ff_dim=2048,
                 max_len=1502):
        super().__init__()
        self.mask_ratio = mask_ratio

        # ---- embed & mask token & pos embed ----
        self.embed = nn.Linear(input_dim, embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim))

        # ---- Transformer encoder ----
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=ff_dim,
            batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # ---- decoder head (MLP) ----
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, input_dim)
        )

        # ---- (embed_dim -> input_dim) ----
        self.compress_head = nn.Linear(embed_dim, input_dim)

    def forward(self, x):
        """
        x: Tensor (B, L, 960)
        return:
          - recon: Tensor (B, L, 960)   # 重建整个序列
          - compressed: Tensor (B, 960) # 池化后的压缩向量
          - mask_idx: LongTensor (B, num_mask)
        """
        B, L, _ = x.shape
        x = self.embed(x)  # (B, L, E)

        num_mask = int(L * self.mask_ratio)
        noise = torch.rand(B, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        mask_idx = ids_shuffle[:, :num_mask]

        for i in range(B):
            x[i, mask_idx[i]] = self.mask_token

        x = x + self.pos_embed  # (1,L,E) broadcast to (B,L,E)
        enc_out = self.encoder(x)  # (B, L, E)
        recon = self.decoder(enc_out)  # (B, L, 960)
        pooled = enc_out.mean(dim=1)       # (B, E)
        compressed = self.compress_head(pooled)  # (B, 960)

        return recon, compressed, mask_idx


In [None]:


def mae_loss(recon, orig, mask_idx):

    pred = recon.gather(1, mask_idx.unsqueeze(-1).expand(-1, -1, recon.size(-1)))
    target = orig.gather(1, mask_idx.unsqueeze(-1).expand(-1, -1, orig.size(-1)))
    scale_factor = 20
    loss = F.huber_loss(pred * scale_factor, target * scale_factor, delta=0.5)
    return loss

def plot_reconstruction(orig, recon, mask_idx, epoch, batch_idx, ts):
    plt.figure(figsize=(12, 4))
    orig_np = orig[0, :, 0].cpu().numpy().copy() 
    recon_np = recon[0, :, 0].cpu().numpy().copy()
    plt.plot(orig_np, label='Original', alpha=0.7, linewidth=2)
    plt.plot(recon_np,'--', label='Reconstructed', linewidth=1.5)
    mask_pos = mask_idx[0].numpy()
    plt.scatter(mask_pos, orig[0, mask_pos, 0].numpy(),
                color='red', marker='x', s=50, label='Masked Positions')
    
    plt.legend()
    plt.title(f"Epoch {epoch} Batch {batch_idx} - Reconstruction")
    plt.xlabel('Sequence Position')
    plt.ylabel('Feature Value')
    plt.grid(True, alpha=0.3)

    os.makedirs(f'logs/recon_plots_{ts}', exist_ok=True)
    plt.savefig(f'logs/recon_plots_{ts}/epoch{epoch}_batch{batch_idx}.png')
    plt.close()
    del orig_np, recon_np

In [None]:

def train(protein_dict):
    dataset = ProteinDataset(protein_dict)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2,persistent_workers=True,  # for cpu limitation
    pin_memory=False)   

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = TransformerMAE().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=1,eta_min=1e-5)

    os.makedirs('logs', exist_ok=True)
    os.makedirs('models', exist_ok=True)
    ts = datetime.now().strftime('%Y%m%d-%H%M%S')
    log_path = f'logs/mae_train_{ts}.json'
    best_path = f'models/mae_best_{ts}.pth'

    best_loss = float('inf')
    history = []

    epochs = 80
    all_epoch_losses=[]
    for epoch in range(1, epochs+1):
        model.train()
        losses = []
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device).float()  # (B, L, 960)
            with torch.cuda.amp.autocast(enabled=False):  
                recon, compressed, mask_idx = model(batch)
            loss = mae_loss(recon, batch, mask_idx)
            del recon, compressed

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)  
            losses.append(loss.item())

            if batch_idx % 50 == 0:
                print(f'Epoch {epoch}/{epochs}  Batch {batch_idx}  Loss {loss.item():.4f}')
            if batch_idx == 200 : 
                with torch.no_grad():
                    sample = batch[:1].to(device)  
                    recon, _, mask_idx = model(sample)
                    
                    plot_reconstruction(
                        sample.cpu(), 
                        recon.cpu(), 
                        mask_idx.cpu(),
                        epoch=epoch,
                        batch_idx=batch_idx,
                        ts=ts
                    )

        epoch_loss = np.mean(losses)
        history.append(epoch_loss)

        # save
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss
            }, best_path)
            print(f'>>> 新最佳模型，Epoch={epoch}  Loss={best_loss:.4f}')

        with open(log_path, 'a') as f:
            log = {
                'epoch': epoch,
                'loss': epoch_loss,
                'best_loss': best_loss,
                'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }
            f.write(json.dumps(log) + '\n')

        print(f'--- Epoch {epoch} 完成，Avg Loss={epoch_loss:.4f}，Best Loss={best_loss:.4f} ---')
        all_epoch_losses.append(epoch_loss)
        gc.collect()        
        torch.cuda.empty_cache()  
        

    plt.figure(figsize=(10, 6))
    plt.plot(all_epoch_losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Training Loss Curve (Current: {epoch_loss:.4f})')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'logs/loss_curve_{ts}.png') 
    plt.close()  

    final_path = f'models/mae_final_{ts}.pth'
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': history[-1]
    }, final_path)
    print(f'训练结束，最终模型已保存到 {final_path}')

    return history


In [8]:
history = train(protein_dict)

  with torch.cuda.amp.autocast(enabled=False):  # 禁用混合精度（CPU无用）


Epoch 1/10  Batch 0  Loss 0.1677
>>> 新最佳模型，Epoch=1  Loss=0.1197
--- Epoch 1 完成，Avg Loss=0.1197，Best Loss=0.1197 ---
Epoch 2/10  Batch 0  Loss 0.1188
>>> 新最佳模型，Epoch=2  Loss=0.1087
--- Epoch 2 完成，Avg Loss=0.1087，Best Loss=0.1087 ---
Epoch 3/10  Batch 0  Loss 0.0997
>>> 新最佳模型，Epoch=3  Loss=0.1072
--- Epoch 3 完成，Avg Loss=0.1072，Best Loss=0.1072 ---
Epoch 4/10  Batch 0  Loss 0.1199
>>> 新最佳模型，Epoch=4  Loss=0.1059
--- Epoch 4 完成，Avg Loss=0.1059，Best Loss=0.1059 ---
Epoch 5/10  Batch 0  Loss 0.0986
--- Epoch 5 完成，Avg Loss=0.1060，Best Loss=0.1059 ---
Epoch 6/10  Batch 0  Loss 0.0986
--- Epoch 6 完成，Avg Loss=0.1062，Best Loss=0.1059 ---
Epoch 7/10  Batch 0  Loss 0.1033
--- Epoch 7 完成，Avg Loss=0.1060，Best Loss=0.1059 ---
Epoch 8/10  Batch 0  Loss 0.0932
--- Epoch 8 完成，Avg Loss=0.1068，Best Loss=0.1059 ---
Epoch 9/10  Batch 0  Loss 0.1212
>>> 新最佳模型，Epoch=9  Loss=0.1054
--- Epoch 9 完成，Avg Loss=0.1054，Best Loss=0.1054 ---
Epoch 10/10  Batch 0  Loss 0.0927
>>> 新最佳模型，Epoch=10  Loss=0.1042
--- Epoch 10 完

## test

In [None]:
import torch
import pickle

def load_model(model, optimizer, model_path):
    checkpoint = torch.load(model_path, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, epoch, loss


model_path = 'models/mae_best_20250524-003857.pth' 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransformerMAE().to(device).float()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
model, optimizer, epoch, loss = load_model(model, optimizer, model_path)

print(f"加载模型：{model_path} (Epoch {epoch}, Loss {loss:.4f})")
model.eval()


加载模型：models/mae_best_20250524-003857.pth (Epoch 60, Loss 1.0562)


TransformerMAE(
  (embed): Linear(in_features=960, out_features=512, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Sequential(
    (0): Linear(in_features=512, out_features=2048, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=2048, out_features=960, bias=True)
  )
  (compress_head): Linea

In [None]:
max_len = 1502      
scale_factor = 20   
batch_size = 32     

new_dict = {}
keys = list(protein_dict.keys())

for start_idx in range(0, len(keys), batch_size):
    batch_keys = keys[start_idx:start_idx + batch_size]
    batch_seqs = [protein_dict[key] for key in batch_keys]
    
    processed_batch = []
    for seq in batch_seqs:
        seq_len = seq.shape[0]
        
        if isinstance(seq, torch.Tensor):
            seq = seq.float()  # 转换为float32
        else:
            seq = torch.tensor(seq, dtype=torch.float32)  # numpy数组转tensor
        
        if seq_len < max_len:
            pad = torch.zeros((max_len - seq_len, 960), dtype=torch.float32)
            processed_seq = torch.cat([seq, pad], dim=0)
        else:
            processed_seq = seq[:max_len]
        
        processed_batch.append(processed_seq * scale_factor)
    
    input_tensor = torch.stack(processed_batch).to(device).to(torch.float32)  # 显式指定类型
    
    with torch.no_grad():
        _, compressed, _ = model(input_tensor)
    
    compressed_np = compressed.cpu().numpy()
    for key, vec in zip(batch_keys, compressed_np):
        new_dict[key] = vec  # vec形状为(960,)

    del input_tensor, compressed
    torch.cuda.empty_cache() if torch.cuda.is_available() else None

In [None]:
with open('compressed_protein_features_fit01.pkl', 'wb') as f:
    pickle.dump(new_dict, f)
