In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model import LayerNormNet
from data_utils import transfer_dataset
from losses import criterion
import os
from torch.utils.tensorboard import SummaryWriter  # 引入 TensorBoard

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load dataset
embed_path='../data/DMS_substitutionsesm_embed'
labels_path='../data/dataset_cor_random.csv'
dataset = transfer_dataset(embed_path, labels_path)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=dataset.collate_fn)

# Initialize model, optimizer, and loss function
hidden_dim = 512
out_dim = 128
model = LayerNormNet(hidden_dim, out_dim, device=device, dtype=torch.float).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-1)
log_dir = './runs'
writer = SummaryWriter(log_dir)
# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_sr = 0
    for batch in dataloader:
        qurry, data, label = [x.to(device) for x in batch]
        
        # Forward pass
        output1 = model(qurry)
        output2 = model(data)
        
        # Compute loss
        sr,loss = criterion(output1, output2, label)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
        total_loss += loss.item()
        total_sr+=sr
    writer.add_scalar('Train/Loss', total_loss/len(dataloader), epoch)
    print(f'Epoch [{epoch+1}/{num_epochs}], spearman: {total_sr/len(dataloader):.4f}, Loss: {total_loss/len(dataloader):.4f}')

  embed_list = [torch.load(os.path.join(embed_path,embed)).cpu().numpy() for embed in embeds]


Epoch [1/1000], spearman: -0.0031, Loss: 204.9690
Epoch [2/1000], spearman: -0.0017, Loss: 204.9641
Epoch [3/1000], spearman: 0.0039, Loss: 204.9639
Epoch [4/1000], spearman: 0.0028, Loss: 204.9640
Epoch [5/1000], spearman: -0.0054, Loss: 204.9644
Epoch [6/1000], spearman: -0.0054, Loss: 204.9645
Epoch [7/1000], spearman: -0.0024, Loss: 204.9641
Epoch [8/1000], spearman: -0.0081, Loss: 204.9640
Epoch [9/1000], spearman: 0.0084, Loss: 204.9633
Epoch [10/1000], spearman: -0.0003, Loss: 204.9640
Epoch [11/1000], spearman: 0.0052, Loss: 204.9633
Epoch [12/1000], spearman: -0.0006, Loss: 204.9644
Epoch [13/1000], spearman: 0.0013, Loss: 204.9630
Epoch [14/1000], spearman: 0.0052, Loss: 204.9640
Epoch [15/1000], spearman: -0.0050, Loss: 204.9639
Epoch [16/1000], spearman: -0.0028, Loss: 204.9638
Epoch [17/1000], spearman: -0.0012, Loss: 204.9632
Epoch [18/1000], spearman: -0.0039, Loss: 204.9632
Epoch [19/1000], spearman: -0.0019, Loss: 204.9634
Epoch [20/1000], spearman: -0.0002, Loss: 204.

KeyboardInterrupt: 