In [1]:
import json
import torch
import genova
from datetime import datetime
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from collections import OrderedDict
from genova.utils.BasicClass import Residual_seq
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler as GradScaler
import torch.nn as nn
import torch.optim as optim

In [2]:
# Load the symbol -> index dictionary used by the dataset and for decoding predictions.
# JSON text is UTF-8 by specification; the encoding is stated explicitly so that
# decoding does not depend on the locale of the machine running the notebook.
with open('genova/utils/dictionary', encoding='utf-8') as f:
    dictionary = json.load(f)

In [None]:
# Training DataLoader restricted to the 'Cerebellum' and 'HeLa' experiments.
cfg = OmegaConf.load('configs/genova_dda_light.yaml')

spec_header = pd.read_csv(
    '/data/z37mao/genova/pretrain_data_sparse/genova_psm.csv',
    index_col='index',
)
# Keep a row when its experiment name matches either of the two experiments.
spec_header = spec_header[
    (spec_header['Experiment Name'] == 'Cerebellum')
    | (spec_header['Experiment Name'] == 'HeLa')
]
# Alternative filters that are currently disabled:
#spec_header = spec_header[spec_header['Node Number']<=512]
#spec_header = spec_header[spec_header['path_matrix_dense_num']<=1e6]

# Only rows with at most 256 nodes are used.
small_spec = spec_header.loc[spec_header['Node Number'] <= 256]

dataset = genova.data.GenovaDataset(
    cfg,
    dictionary=dictionary,
    spec_header=small_spec,
    dataset_dir_path='/data/z37mao/genova/pretrain_data_sparse/',
)
collate_fn = genova.data.GenovaCollator(cfg, mode='train')
dl = DataLoader(
    dataset,
    batch_size=16,
    collate_fn=collate_fn,
    num_workers=8,
    shuffle=True,
)

In [3]:
# Training DataLoader for the 'PXD008844' experiment, plus a freshly initialised model on the GPU.
cfg = OmegaConf.load('configs/genova_dda_light.yaml')

spec_header = pd.read_csv(
    '/data/z37mao/genova/pretrain_data_sparse/genova_psm.csv',
    index_col='index',
)
spec_header = spec_header.loc[spec_header['Experiment Name'] == 'PXD008844']
# Only rows with at most 256 nodes are used.
small_spec = spec_header.loc[spec_header['Node Number'] <= 256]

dataset = genova.data.GenovaDataset(
    cfg,
    dictionary=dictionary,
    spec_header=small_spec,
    dataset_dir_path='/data/z37mao/genova/pretrain_data_sparse/',
)
collate_fn = genova.data.GenovaCollator(cfg, mode='train')
dl = DataLoader(
    dataset,
    batch_size=16,
    collate_fn=collate_fn,
    num_workers=8,
    shuffle=True,
)

model = genova.models.Genova(cfg).cuda()
# Disabled: snapshot of the untrained weights.
#torch.save(model.state_dict(),'/data/z37mao/save/GenovaPrototype.pt')

In [4]:
# Convert the model to half precision; the checkpoint loaded in the following cells
# is copied into this fp16 model.
# NOTE(review): the training cell further down steps this model through GradScaler,
# which is designed for fp32 parameters under autocast — confirm that unscaling fp16
# gradients is accepted, otherwise skip this cell when training.
model = model.half()

In [5]:
# Read the saved training checkpoint.
# map_location='cpu' keeps the checkpoint tensors in host memory: without it they are
# restored onto the device they were saved from, so a second copy of the weights
# would stay on the GPU for as long as `checkpoint` is alive. load_state_dict()
# copies the values into the model's own (GPU) parameters regardless of source device.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load('/data/z37mao/save/Genova_model.pt', map_location='cpu')

In [6]:
# Rebuild the state dict with the first 7 characters removed from every key.
# NOTE(review): presumably this strips a 'module.' prefix left by a DataParallel/DDP
# wrapper — the slice is unconditional, so confirm every key really carries it.
model_state_dict = OrderedDict(
    (k[7:], v)
    for k, v in checkpoint['model_state_dict'].items()
)

In [7]:
# Copy the re-keyed checkpoint weights into the model; the returned key-matching
# report is displayed as the cell output.
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [None]:
# Alternative loading path: combine checkpoint entries (re-rooted under 'encoder.')
# with the model's own non-encoder entries, then load the merged state dict.
# NOTE(review): this iterates `checkpoint` as a flat name -> tensor mapping, whereas
# the cell that builds `model_state_dict` indexes it with 'model_state_dict' —
# confirm which checkpoint file this cell is meant to be run against.
checkpoint_temp = []
# Every checkpoint entry except those under 'output_ffn' gets the 'encoder.' prefix.
for key, v in checkpoint.items():
    if key.split('.')[0] != 'output_ffn': checkpoint_temp.append(('encoder.'+key, v))
# Entries outside 'encoder' are taken from the model's current state.
for key, v in model.state_dict().items():
    if key.split('.')[0] != 'encoder': checkpoint_temp.append((key, v))
# `checkpoint` is rebound to the merged dict, so running this cell a second time
# would add the 'encoder.' prefix again.
checkpoint = OrderedDict(checkpoint_temp)
model.load_state_dict(checkpoint)

In [None]:
# Freeze the encoder: none of its parameters will receive gradients.
for param in model.encoder.parameters():
    param.requires_grad_(False)

# AdamW is handed every model parameter (frozen ones included) at a learning rate
# of 1e-5; the GradScaler is used for loss scaling in the mixed-precision loop.
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-5,
)
scaler = GradScaler()

In [8]:
def encoder_input_cuda(encoder_input):
    """Move every tensor of a two-level encoder input mapping onto the GPU.

    ``encoder_input`` maps section names to mappings of named values. Each value
    that is a ``torch.Tensor`` is replaced in place by its CUDA copy; float32
    tensors are additionally converted to half precision. Non-tensor values are
    left untouched. The same (mutated) mapping is returned.
    """
    for section in encoder_input.values():
        for name, value in section.items():
            if not isinstance(value, torch.Tensor):
                continue
            on_gpu = value.cuda()
            # Only float32 payloads are down-cast; other dtypes keep their type.
            section[name] = on_gpu.half() if value.dtype == torch.float32 else on_gpu
    return encoder_input

def decoder_input_cuda(decoder_input):
    """Move every tensor value of the decoder input mapping onto the GPU.

    Tensor values are replaced in place by their CUDA copies (dtype unchanged);
    any other value is left as is. The same (mutated) mapping is returned.
    """
    for name, value in decoder_input.items():
        if isinstance(value, torch.Tensor):
            decoder_input[name] = value.cuda()
    return decoder_input

In [None]:
# Mixed-precision fine-tuning: five passes over `dl`.
# The criterion is created here so the cell does not depend on a later cell
# having been executed first (it was previously only defined further down).
loss_fn = nn.CrossEntropyLoss()
detect_period = 200  # report the running mean loss every `detect_period` steps
for epoch in range(5):
    loss_detect = 0
    for i, (encoder_input, decoder_input, labels) in enumerate(dl,start=1):
        encoder_input = encoder_input_cuda(encoder_input)
        decoder_input = decoder_input_cuda(decoder_input)
        labels = labels.cuda()
        # Positions whose label is 0 are excluded from the loss (presumably padding).
        valid = labels != 0
        optimizer.zero_grad()
        with autocast():
            output = model(encoder_input=encoder_input,decoder_input=decoder_input)
            loss = loss_fn(output[valid],labels[valid])
        # Scale the loss before backward, then let the scaler step the optimizer
        # and update its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loss_detect += loss.item()
        if i % detect_period == 0:
            print(epoch, i, loss_detect/detect_period)
            loss_detect = 0

In [None]:
# Persist model and optimizer state under a file name stamped with the current
# date and time (minute resolution).
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optim_state_dict': optimizer.state_dict(),
    },
    'save/model_large_{}.pt'.format(datetime.now().strftime('%Y-%m-%d_%H-%M')),
)

In [None]:
# Build a fresh model on the GPU and restore the weights saved by the training run.
model = genova.models.Genova(cfg).cuda()
# map_location='cpu' keeps the checkpoint tensors in host memory instead of holding
# a second copy of the weights (and the optimizer state) on the GPU;
# load_state_dict() copies the values into the model's GPU parameters.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load('save/model_large_2022-02-22_16-51.pt', map_location='cpu')
loss_fn = nn.CrossEntropyLoss()
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Rebuild the 'PXD008844' loader, this time yielding one sample per batch.
spec_header = pd.read_csv(
    '/data/z37mao/genova/pretrain_data_sparse/genova_psm.csv',
    index_col='index',
)
spec_header = spec_header.loc[spec_header['Experiment Name'] == 'PXD008844']
# Only rows with at most 256 nodes are used.
small_spec = spec_header.loc[spec_header['Node Number'] <= 256]

dataset = genova.data.GenovaDataset(
    cfg,
    dictionary=dictionary,
    spec_header=small_spec,
    dataset_dir_path='/data/z37mao/genova/pretrain_data_sparse/',
)
collate_fn = genova.data.GenovaCollator(cfg, mode='train')
dl = DataLoader(
    dataset,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=True,
)

In [None]:
# Ask PyTorch's caching allocator to release GPU memory it holds but is not using.
torch.cuda.empty_cache()

In [None]:
# Decode predictions and labels for up to 2000 batches into dictionary symbols.
# The index -> symbol table is built here so the cell does not depend on a later
# cell having been executed first (it was previously only defined further down).
reverse_dict = {dictionary[aa]:aa for aa in dictionary}
loss_detect = 0
recall_detect = 0
recall_pep_detect = 0
predict = []
label = []
for i, (encoder_input, decoder_input, labels) in enumerate(dl,start=1):
    encoder_input = encoder_input_cuda(encoder_input)
    decoder_input = decoder_input_cuda(decoder_input)
    labels = labels.cuda()
    with torch.no_grad():
        with autocast():
            output = model(encoder_input=encoder_input,decoder_input=decoder_input)
            # Only the first sequence of each batch is decoded and its last position
            # is dropped. `idx` is used so the loop counter `i` is not shadowed.
            predict.append([reverse_dict[idx.item()] for idx in torch.argmax(output,-1)[0,:-1].cpu()])
            label.append([reverse_dict[idx.item()] for idx in labels[0,:-1].cpu()])
    if i==2000: break
# No loss/recall is accumulated in this cell, so the three averages printed here are 0.
print(loss_detect/i,recall_detect/i,recall_pep_detect/i)

In [None]:
# Show the decoded label sequences collected by the decoding loop (one list per batch).
label

In [None]:
# Show the decoded arg-max predictions collected by the decoding loop (one list per batch).
predict

In [None]:
# Invert `dictionary` (symbol -> index) into an index -> symbol lookup table.
reverse_dict = {index: symbol for symbol, index in dictionary.items()}

In [None]:
# Decode the arg-max prediction of the most recent batch into dictionary symbols:
# first sequence only, last position dropped.
[reverse_dict[i.item()] for i in torch.argmax(output,-1)[0,:-1].cpu()]

In [None]:
# Element-wise comparison of the predicted indices with the labels of the most recent batch.
torch.argmax(output,-1)==labels

In [None]:
# Show the label tensor of the most recent batch.
labels

In [None]:
# Average loss, per-position recall and whole-batch exact-match rate over up to
# 2000 batches, evaluated under autocast without gradient tracking.
loss_detect = 0
recall_detect = 0
recall_pep_detect = 0
for i, (encoder_input, decoder_input, labels) in enumerate(dl, start=1):
    encoder_input = encoder_input_cuda(encoder_input)
    decoder_input = decoder_input_cuda(decoder_input)
    labels = labels.cuda()
    with torch.no_grad(), autocast():
        output = model(encoder_input=encoder_input, decoder_input=decoder_input)
        # Loss and recall only look at positions whose label is not 0.
        loss = loss_fn(output[labels != 0], labels[labels != 0])
        recall = (
            (output[labels != 0].argmax(dim=-1) == labels[labels != 0]).sum()
            / labels[labels != 0].shape[0]
        )
        # 1 when every position of the batch (label-0 positions included) matches, else 0.
        recall_pep = torch.all(output.argmax(dim=-1) == labels).sum()
    loss_detect += loss.item()
    recall_detect += recall.item()
    recall_pep_detect += recall_pep.item()
    if i == 2000:
        break
print(loss_detect/i, recall_detect/i, recall_pep_detect/i)

In [None]:
# Cross-entropy criterion used by the training and evaluation loops.
loss_fn = nn.CrossEntropyLoss()

In [9]:
# Average loss and per-position recall over up to 2000 batches, without autocast.
# A leftover unconditional `break` used to end the loop before anything was
# accumulated, so the cell always printed 0.0 0.0; it has been removed.
# NOTE(review): the recorded run of this cell ended in a CUDA out-of-memory error —
# a smaller batch size may be needed before it can run to completion.
loss_detect = 0
recall_detect = 0
for i, (encoder_input, decoder_input, labels) in enumerate(dl,start=1):
    encoder_input = encoder_input_cuda(encoder_input)
    decoder_input = decoder_input_cuda(decoder_input)
    labels = labels.cuda()
    # Positions whose label is 0 are excluded from both metrics (presumably padding).
    valid = labels != 0
    with torch.no_grad():
        output = model(encoder_input=encoder_input,decoder_input=decoder_input)
        loss = loss_fn(output[valid],labels[valid])
        recall = (torch.argmax(output[valid],-1)==labels[valid]).sum()/labels[valid].shape[0]
    loss_detect+=loss.item()
    recall_detect+=recall.item()
    if i==2000: break
print(loss_detect/i,recall_detect/i)

RuntimeError: CUDA out of memory. Tried to allocate 3.23 GiB (GPU 0; 15.78 GiB total capacity; 2.49 GiB already allocated; 2.67 GiB free; 2.55 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Print predicted indices next to the labels for the first ten batches.
for i, (encoder_input, decoder_input, labels) in enumerate(dl, start=1):
    encoder_input = encoder_input_cuda(encoder_input)
    decoder_input = decoder_input_cuda(decoder_input)
    labels = labels.cuda()
    with torch.no_grad(), autocast():
        output = model(encoder_input=encoder_input, decoder_input=decoder_input)
    print(output.argmax(dim=-1).cpu().numpy())
    print(labels.cpu().numpy())
    print()  # blank line between batches
    if i == 10:
        break

In [None]:
# Print the predicted indices and the labels of the most recent batch as NumPy arrays.
print(torch.argmax(output,-1).cpu().numpy())
print(labels.cpu().numpy())

In [None]:
# Show the symbol -> index dictionary loaded from 'genova/utils/dictionary'.
dictionary