In [1]:
import torch as pt


def _load_and_report(path):
    """Deserialize the object stored at ``path`` with torch and print its length.

    NOTE: ``torch.load`` unpickles the file, so only point this at files you
    trust.
    """
    loaded = pt.load(path)
    print(len(loaded))
    return loaded


# Test queries, validation queries, and the full search library, in that order.
mols_test = _load_and_report('./data/mine/test_11499.pt')
mols_val = _load_and_report('./data/mine/val_11825.pt')
mols_all = _load_and_report('./data/mine/mols_all.pt')

11499
11825
2253216


In [None]:
import numpy as np

def _nominal_masses(mols):
    """Return ``metadata['nominal_mass']`` of every entry in ``mols`` as a float array."""
    return np.array([float(entry.metadata['nominal_mass']) for entry in mols])


# One mass vector per collection, aligned index-for-index with that collection.
mass_all = _nominal_masses(mols_all)
mass_test = _nominal_masses(mols_test)
mass_val = _nominal_masses(mols_val)

In [2]:
from torch.utils.data import DataLoader
from utils.data import SpecDataset, SpecDataset_finetune, collate_fun_emb, collate_fun_finetune


# Every embedding loader shares the same settings: large fixed-order batches
# (shuffle=False) collated by collate_fun_emb.
_emb_loader_kwargs = dict(batch_size=2048, shuffle=False,
                          num_workers=8, collate_fn=collate_fun_emb)

# Full library.
dataset_lib = SpecDataset(mols_all)
loader_lib = DataLoader(dataset_lib, **_emb_loader_kwargs)

# Validation queries.
dataset_val = SpecDataset(mols_val)
loader_val = DataLoader(dataset_val, **_emb_loader_kwargs)

# Test queries.
dataset_test = SpecDataset(mols_test)
loader_test = DataLoader(dataset_test, **_emb_loader_kwargs)

# Base fine-tuning dataset built from the (validation, library) pair; the
# training cell wraps it again with a per-epoch mapping.
dataset_finetune = SpecDataset_finetune((mols_val, mols_all))

In [14]:
import torch as pt
import torch.optim as optim
from utils.model import Spec2Emb
from tqdm import tqdm
from utils.tools import gen_embeddings, build_idx, evaluate, find_nearest_hit_nhit, save_model


# ---------------------------------------------------------------------------
# Fine-tune Spec2Emb on validation-derived triplets and evaluate on the test
# set after every epoch.
#
# This cell relies on names created by earlier cells: `np`, `mass_all`,
# `mass_val`, `mass_test`, `mols_*`, `loader_*` and `dataset_finetune`.
# ---------------------------------------------------------------------------
gpu = 6  # device index passed to .to() / helper functions
model = Spec2Emb().to(gpu)
# Start from the pretrained base checkpoint (deserialized on CPU, then copied
# into the model's parameters).
model.load_state_dict(pt.load('./model/base_peak0.01_epoch4.pth', map_location='cpu'))
epochs = 3
batch_size = 32
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Naming: "ft" = finetune; "base" refers to the unmodified model architecture.
model_name = 'ft_p0.4_mass'
# Best [top1, top10] observed so far for each test setting.
max_metrics = {'expanded': [0, 0], 'insilico': [0, 0], 'expanded_mass': [0, 0], 'insilico_mass': [0, 0]}
# The "insilico" evaluations search only the first `n_insilico` library rows.
# NOTE(review): presumably the in-silico spectra occupy the head of
# `mols_all` -- confirm against how mols_all.pt was assembled.
n_insilico = 2146690


def _save_if_improved(key, top1, top10, epoch):
    """Checkpoint the model when BOTH hit rates beat the best stored under ``key``.

    ``key`` must be one of the keys of ``max_metrics``. Going through a single
    helper keeps the compared key and the updated key identical (the original
    code compared/updated ``'expand'`` while the dict key is ``'expanded'``,
    which raises ``KeyError``).
    """
    best_top1, best_top10 = max_metrics[key]
    if top1 > best_top1 and top10 > best_top10:
        max_metrics[key] = [top1, top10]
        save_model(model, model_name, epoch)


# `with` guarantees the log file is flushed and closed even if an epoch fails.
with open('ft_p0.4_mass.txt', 'w') as f:
    for epoch in range(epochs):
        print(f'==================================Finetune_epoch{epoch+1}======================================')
        f.write('\nFinetune_epoch%d\n' % (epoch+1))

        # --- Mine training triplets from the current validation retrieval ---
        embeddings_lib = gen_embeddings(model, loader_lib, gpu, power=0.4)
        embeddings_val = gen_embeddings(model, loader_val, gpu, power=0.4)
        # Overwrite the last embedding column with the nominal mass.
        embeddings_lib[:, -1] = mass_all
        embeddings_val[:, -1] = mass_val
        I, _ = build_idx(embeddings_lib, embeddings_val, gpu, topk=200)  # clears the cache internally
        top1_val, top10_val = evaluate(mols_val, I, mols_all, f, 'Validation')
        vals, hits, nhits = find_nearest_hit_nhit(I, mols_val, mols_all)
        dataset_ft = SpecDataset_finetune(dataset_finetune, mapping=(vals, hits, nhits))
        loader_ft = DataLoader(dataset_ft, batch_size, shuffle=True, num_workers=8, collate_fn=collate_fun_finetune)

        # --- Five passes over the mined triplets ---
        # NOTE(review): the model is left in train mode for the evaluation
        # below unless gen_embeddings switches it to eval itself -- confirm.
        model.train()
        for j in range(5):
            finetune_loss = []
            for i, Data in enumerate(tqdm(loader_ft, unit='batch')):
                # Flatten the nested batch and move every tensor to the device.
                # The three 3-tensor groups presumably correspond to
                # (query, hit, non-hit) -- verify against collate_fun_finetune.
                Data = [d.to(gpu) for data in Data for d in data]
                optimizer.zero_grad()
                loss = model((Data[:3], Data[3:6], Data[6:9]), mode='finetune', power=0.4)
                finetune_loss.append(loss.item())
                loss.backward()
                optimizer.step()
                if (i+1) % 300 == 0:
                    # Report the running mean of the last 300 batches, then reset.
                    loss = np.mean(finetune_loss)
                    print(f'Total Loss: {loss}')
                    finetune_loss = []

        # --- Test-set evaluation without mass information ---
        print(f'===================================Test_epoch{epoch+1}======================================')
        f.write('\n\nTest_epoch%d\n' % (epoch+1))
        embeddings_lib = gen_embeddings(model, loader_lib, gpu, power=0.4)
        embeddings_test = gen_embeddings(model, loader_test, gpu, power=0.4)
        I_expand, _ = build_idx(embeddings_lib, embeddings_test, gpu, topk=200)
        top1_expand, top10_expand = evaluate(mols_test, I_expand, mols_all, f, 'expanded')
        _save_if_improved('expanded', top1_expand, top10_expand, epoch)
        I_insilico, _ = build_idx(embeddings_lib[:n_insilico], embeddings_test, gpu, topk=200)
        top1_insilico, top10_insilico = evaluate(mols_test, I_insilico, mols_all, f, 'insilico')
        _save_if_improved('insilico', top1_insilico, top10_insilico, epoch)

        # --- Test-set evaluation with the mass written into the last column ---
        print('\nWith Mass:')
        f.write('With Mass:\n')
        embeddings_lib[:, -1] = mass_all
        embeddings_test[:, -1] = mass_test
        I_expand, _ = build_idx(embeddings_lib, embeddings_test, gpu, topk=200)
        top1_expand, top10_expand = evaluate(mols_test, I_expand, mols_all, f, 'expanded')
        _save_if_improved('expanded_mass', top1_expand, top10_expand, epoch)
        I_insilico, _ = build_idx(embeddings_lib[:n_insilico], embeddings_test, gpu, topk=200)
        top1_insilico, top10_insilico = evaluate(mols_test, I_insilico, mols_all, f, 'insilico')
        _save_if_improved('insilico_mass', top1_insilico, top10_insilico, epoch)
        print('================================================================================================')

Searching time:  0:00:01.613016
Validation library
Top1 hit rate: 43.89%
Top10 hit rate: 82.76%


 89%|████████▉ | 312/349 [00:05<00:00, 91.33batch/s]

Total Loss: 0.9904110113779704


100%|██████████| 349/349 [00:06<00:00, 49.97batch/s]
 89%|████████▉ | 312/349 [00:06<00:00, 86.51batch/s]

Total Loss: 0.9859908823172251


100%|██████████| 349/349 [00:07<00:00, 45.48batch/s]
 90%|█████████ | 315/349 [00:07<00:00, 87.88batch/s]

Total Loss: 0.9802273492018382


100%|██████████| 349/349 [00:08<00:00, 39.27batch/s]
 91%|█████████ | 316/349 [00:07<00:00, 86.34batch/s]

Total Loss: 0.9670295816659927


100%|██████████| 349/349 [00:08<00:00, 40.06batch/s]
 89%|████████▊ | 309/349 [00:07<00:00, 85.25batch/s]

Total Loss: 0.9505515831708908


100%|██████████| 349/349 [00:08<00:00, 39.27batch/s]






Searching time:  0:00:01.574095
Expanded library
Top1 hit rate: 42.83%
Top10 hit rate: 83.90%
Searching time:  0:00:01.495802
In-silico library
Top1 hit rate: 43.11%
Top10 hit rate: 84.33%
With Mass:
Searching time:  0:00:01.567990
Expanded library
Top1 hit rate: 50.80%
Top10 hit rate: 90.93%
Searching time:  0:00:01.491236
In-silico library
Top1 hit rate: 51.15%
Top10 hit rate: 91.20%
Searching time:  0:00:01.608956
Validation library
Top1 hit rate: 53.82%
Top10 hit rate: 91.70%


 85%|████████▍ | 311/366 [00:06<00:00, 89.74batch/s]

Total Loss: 0.976724262436231


100%|██████████| 366/366 [00:08<00:00, 44.85batch/s]
 86%|████████▌ | 313/366 [00:07<00:00, 87.21batch/s]

Total Loss: 0.9719607601563136


100%|██████████| 366/366 [00:08<00:00, 41.08batch/s]
 84%|████████▍ | 308/366 [00:07<00:00, 84.14batch/s]

Total Loss: 0.9674709280331929


100%|██████████| 366/366 [00:08<00:00, 41.54batch/s]
 86%|████████▌ | 314/366 [00:07<00:00, 83.48batch/s]

Total Loss: 0.965538561741511


100%|██████████| 366/366 [00:09<00:00, 40.18batch/s]
 85%|████████▍ | 310/366 [00:07<00:00, 80.20batch/s]

Total Loss: 0.9619528927405675


100%|██████████| 366/366 [00:09<00:00, 39.25batch/s]






Searching time:  0:00:01.582163
Expanded library
Top1 hit rate: 43.31%
Top10 hit rate: 84.16%
Searching time:  0:00:01.503104
In-silico library
Top1 hit rate: 43.58%
Top10 hit rate: 84.52%
With Mass:
Searching time:  0:00:01.568346
Expanded library
Top1 hit rate: 51.52%
Top10 hit rate: 91.59%
Searching time:  0:00:01.491743
In-silico library
Top1 hit rate: 51.87%
Top10 hit rate: 91.95%
Searching time:  0:00:01.630666
Validation library
Top1 hit rate: 58.49%
Top10 hit rate: 93.29%


 85%|████████▌ | 313/368 [00:05<00:00, 88.22batch/s]

Total Loss: 0.9660742004712423


100%|██████████| 368/368 [00:07<00:00, 47.16batch/s]
 86%|████████▌ | 316/368 [00:07<00:00, 92.68batch/s]

Total Loss: 0.9614268527428309


100%|██████████| 368/368 [00:08<00:00, 41.46batch/s]
 86%|████████▌ | 317/368 [00:07<00:00, 86.32batch/s]

Total Loss: 0.9588480953375499


100%|██████████| 368/368 [00:09<00:00, 39.06batch/s]
 86%|████████▌ | 316/368 [00:07<00:00, 88.67batch/s]

Total Loss: 0.9558197156588236


100%|██████████| 368/368 [00:09<00:00, 39.47batch/s]
 85%|████████▍ | 312/368 [00:07<00:00, 85.54batch/s]

Total Loss: 0.9522421328226726


100%|██████████| 368/368 [00:09<00:00, 39.13batch/s]






Searching time:  0:00:01.556401
Expanded library
Top1 hit rate: 42.18%
Top10 hit rate: 83.54%
Searching time:  0:00:01.506676
In-silico library
Top1 hit rate: 42.45%
Top10 hit rate: 83.92%
With Mass:
Searching time:  0:00:01.573775
Expanded library
Top1 hit rate: 50.44%
Top10 hit rate: 91.38%
Searching time:  0:00:01.501219
In-silico library
Top1 hit rate: 50.91%
Top10 hit rate: 91.76%


In [12]:
# Housekeeping after a run: ask PyTorch's CUDA caching allocator to release
# cached blocks it is no longer using. Tensors that are still referenced
# (e.g. `model`, the embeddings) are not freed by this call.
pt.cuda.empty_cache()