In [1]:
import json
import os
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from omegaconf import OmegaConf
from torch.cuda.amp import GradScaler as GradScaler
from torch.cuda.amp import autocast as autocast
from torch.utils.data import DataLoader

import genova
from genova.utils.BasicClass import Residual_seq
import torch.optim as optim

In [2]:
# Load the JSON file that is passed as `dictionary=` to GenovaDataset below.
dictionary_path = 'genova/utils/dictionary'
with open(dictionary_path) as dictionary_file:
    dictionary = json.load(dictionary_file)

In [3]:
# Model/data configuration and the spectrum header table for the pretraining set.
# NOTE(review): absolute, user-specific data paths — consider a configurable DATA_DIR.
cfg = OmegaConf.load('configs/genova_dda_light.yaml')
spec_header = pd.read_csv(
    '/data/z37mao/genova/pretrain_data_sparse/genova_psm.csv',
    index_col='index',
)

# Keep only rows from the Cerebellum and HeLa experiments.
experiment_names = ['Cerebellum', 'HeLa']
spec_header = spec_header[spec_header['Experiment Name'].isin(experiment_names)]
# Alternative size filters that were tried previously:
#spec_header = spec_header[spec_header['Node Number']<=512]
#spec_header = spec_header[spec_header['path_matrix_dense_num']<=1e6]

# Restrict training to rows whose 'Node Number' is at most 256.
small_spec = spec_header[spec_header['Node Number'] <= 256]

dataset = genova.data.GenovaDataset(
    cfg,
    dictionary=dictionary,
    spec_header=small_spec,
    dataset_dir_path='/data/z37mao/genova/pretrain_data_sparse/',
)
collate_fn = genova.data.GenovaCollator(cfg, mode='train')
dl = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=collate_fn,
    num_workers=4,
    shuffle=True,
)

In [4]:
# Training objects: GPU model, AdamW optimizer, cross-entropy loss, and the
# AMP gradient scaler used together with autocast in the training loop.
model = genova.models.Genova(cfg).cuda()
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()
scaler = GradScaler()
# To resume from a saved checkpoint, uncomment:
#checkpoint = torch.load('save/model.pt')
#model.load_state_dict(checkpoint['model_state_dict'])
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [5]:
def encoder_input_cuda(encoder_input):
    """Move every tensor in a two-level nested encoder input to the GPU, in place.

    ``encoder_input`` is indexed as ``encoder_input[section][key]``; each inner
    value that is a ``torch.Tensor`` is replaced by its ``.cuda()`` copy and all
    other values are left untouched.

    Returns the same (mutated) ``encoder_input`` object.
    """
    for section_key in encoder_input:
        section = encoder_input[section_key]
        for key in section:
            value = section[key]
            if isinstance(value, torch.Tensor):
                section[key] = value.cuda()
    return encoder_input

def decoder_input_cuda(decoder_input):
    """Move every tensor value of ``decoder_input`` to the GPU, in place.

    Values that are ``torch.Tensor`` instances are replaced by their ``.cuda()``
    copies (in key iteration order); other values are left untouched.

    Returns the same (mutated) ``decoder_input`` object.
    """
    tensor_keys = [key for key in decoder_input
                   if isinstance(decoder_input[key], torch.Tensor)]
    for key in tensor_keys:
        decoder_input[key] = decoder_input[key].cuda()
    return decoder_input

In [None]:
# Mixed-precision (AMP) training loop.
# Every `detect_period` steps the mean masked loss over that window is printed;
# every `checkpoint_period` steps (counted within an epoch) a timestamped
# checkpoint is written under save/.
num_epochs = 5
detect_period = 200
checkpoint_period = 10000
os.makedirs('save', exist_ok=True)  # torch.save does not create missing directories
model.train()  # in case another cell left the model in eval mode
loss_detect = 0
for epoch in range(num_epochs):
    # Reset the window each epoch: otherwise losses left over from an epoch
    # whose length is not a multiple of detect_period would be added to the
    # first window of the next epoch but still divided by detect_period.
    loss_detect = 0
    for i, (encoder_input, decoder_input, labels) in enumerate(dl, start=1):
        encoder_input = encoder_input_cuda(encoder_input)
        decoder_input = decoder_input_cuda(decoder_input)
        labels = labels.cuda()
        # Positions with label 0 are excluded from the loss
        # (presumably padding — TODO confirm against GenovaCollator).
        mask = labels != 0
        optimizer.zero_grad()
        with autocast():
            output = model(encoder_input=encoder_input, decoder_input=decoder_input)
            loss = loss_fn(output[mask], labels[mask])
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_detect += loss.item()
        if i % detect_period == 0:
            print(loss_detect / detect_period)
            loss_detect = 0
        if i % checkpoint_period == 0:
            # Saved after the optimizer step, so the file reflects `i` updates.
            optim_state = optimizer.state_dict()
            torch.save({'model_state_dict': model.state_dict(),
                        # 'optimizer_state_dict' is the key the resume code in the
                        # model-setup cell reads; 'optim_state_dict' is kept so
                        # scripts written against older checkpoints still work.
                        'optimizer_state_dict': optim_state,
                        'optim_state_dict': optim_state,
                        # Needed to resume AMP training with the same loss scale.
                        'scaler_state_dict': scaler.state_dict(),
                        'epoch': epoch,
                        'step': i},
                       'save/model_large_{}.pt'.format(datetime.now().strftime('%Y-%m-%d_%H-%M')))

2.945377013683319
2.756175673007965
2.5744961524009704
2.522887713909149
2.4932640194892883
2.469462194442749
2.474999210834503
2.4739758920669557
2.458248772621155
2.491481537818909
2.4509156107902528
2.462229964733124
2.455927391052246
2.451265037059784
2.4515697526931763
2.4565942072868348


In [None]:
# Rebuild the model and restore trained weights for evaluation.
model = genova.models.Genova(cfg).cuda()
# A freshly constructed nn.Module is in training mode; switch to inference
# behavior so train-only layers (e.g. dropout, if present) don't affect evaluation.
model.eval()
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch versions that support it). The training
# cell writes timestamped 'save/model_large_<timestamp>.pt' files; confirm
# 'save/model.pt' is the intended checkpoint.
checkpoint = torch.load('save/model.pt')
loss_fn = nn.CrossEntropyLoss()
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Evaluation: sum of per-batch masked cross-entropy and per-batch token accuracy
# (named "recall" here) over at most `max_eval_batches` batches; the next cell
# divides by the number of batches `i`.
# NOTE(review): `dl` is the shuffled *training* DataLoader built earlier, so these
# are not held-out metrics — use a validation loader for real evaluation.
max_eval_batches = 2000
model.eval()  # inference behavior for train-only layers such as dropout, if any
loss_detect = 0
recall_detect = 0
for i, (encoder_input, decoder_input, labels) in enumerate(dl, start=1):
    encoder_input = encoder_input_cuda(encoder_input)
    decoder_input = decoder_input_cuda(decoder_input)
    labels = labels.cuda()
    mask = labels != 0  # label-0 positions are excluded, as in training
    # Comma-separated context managers enter BOTH no_grad and autocast.
    # Writing `torch.no_grad() and autocast()` would enter only `autocast()`
    # (`and` returns its second operand when the first is truthy), leaving
    # autograd enabled and building a graph for every batch.
    with torch.no_grad(), autocast():
        output = model(encoder_input=encoder_input, decoder_input=decoder_input)
        loss = loss_fn(output[mask], labels[mask])
        # Fraction of unmasked positions whose argmax prediction equals the label.
        recall = (torch.argmax(output[mask], -1) == labels[mask]).float().mean()
    loss_detect += loss.item()
    recall_detect += recall.item()
    if i == max_eval_batches:
        break

In [None]:
# Averages over the `i` batches evaluated above: (mean loss, mean token accuracy).
mean_loss = loss_detect / i
mean_recall = recall_detect / i
print(mean_loss, mean_recall)