In [1]:
import math

import hydra
import numpy as np
import pandas as pd
import torch
import wandb
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf, open_dict
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader

import genova
from genova.utils.BasicClass import Residual_seq

In [2]:
# Compose the experiment config from the 'configs' directory and force the
# 'optimum_path' task.
# hydra.initialize raises if Hydra is already initialised in this kernel, so the
# cell could not be re-run; clearing the global Hydra state first makes it
# idempotent (clear() is a no-op when nothing was initialised).
GlobalHydra.instance().clear()
hydra.initialize('configs')
cfg = hydra.compose('config.yaml')
# open_dict lifts OmegaConf struct mode so a key that may not exist in the
# composed config can still be set.
with open_dict(cfg):
    cfg.task = 'optimum_path'

In [3]:
# Spectrum index table (indexed by 'Spec Index'), restricted to the rows whose
# 'MSGP File Name' is '1_3.msgp'.
# Optional extra filter, currently disabled:
#     .loc[lambda frame: frame['Node Number'] <= 512]
spec_header = (
    pd.read_csv(
        '/home/z37mao/genova_dataset_index.csv',
        low_memory=False,
        index_col='Spec Index',
    )
    .loc[lambda frame: frame['MSGP File Name'] == '1_3.msgp']
)

In [4]:
# Model, data pipeline and optimisation objects used by the training loop.
DEVICE = 'cuda'
SEED = 0

# Parameter initialisation (and presumably any shuffling inside the sampler)
# draws from the global RNGs; seed them so reruns of the notebook are comparable.
torch.manual_seed(SEED)
np.random.seed(SEED)

model = genova.models.Genova(cfg).to(DEVICE)
ds = genova.data.GenovaDataset(cfg, spec_header=spec_header, dataset_dir_path='/home/z37mao/')
# NOTE(review): the meaning of 0.95 and of the [0, 128, 256, 512] boundaries is
# defined inside GenovaBatchSampler (not visible here) — document them once confirmed.
sampler = genova.data.GenovaBatchSampler(cfg, DEVICE, 0.95, spec_header, [0, 128, 256, 512], model)
collate_fn = genova.data.GenovaCollator(cfg)
loader = DataLoader(
    ds,
    batch_sampler=sampler,
    collate_fn=collate_fn,
    pin_memory=True,
    num_workers=4,
    prefetch_factor=4,
)
# Wrap the loader for the device; presumably copies batches to DEVICE ahead of
# use — see genova.data.DataPrefetcher.
dl = genova.data.DataPrefetcher(loader, DEVICE)

# The model output is passed through log_softmax before the loss, and KLDivLoss
# expects log-probabilities as input; 'batchmean' matches the mathematical
# definition of KL divergence.
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# Loss scaling for fp16 autocast training.
scaler = GradScaler()

In [5]:
def train(dl, loss_fn, optimizer, scaler, model, num_epochs=40, log_interval=100):
    """Mixed-precision training loop, written as a generator.

    Nothing runs until the generator is iterated. After every ``log_interval``
    optimisation steps it yields ``(mean_loss, total_step)``. ``mean_loss`` is the
    average loss of exactly those ``log_interval`` steps. ``total_step`` is the
    number of steps completed so far.

    Args:
        dl: iterable that is iterated once per epoch. Each item is unpacked as
            ``(encoder_input, decoder_input, graph_probability, label, label_mask)``.
        loss_fn: called as ``loss_fn(log_probs[label_mask], label[label_mask])``,
            where ``log_probs`` is the model output after ``log_softmax(-1)``.
        optimizer: optimizer stepped through ``scaler``.
        scaler: GradScaler-like object providing ``scale``, ``step`` and ``update``.
        model: module called with keyword arguments ``encoder_input``,
            ``decoder_input`` and ``graph_probability``. It is switched to train
            mode before the first step.
        num_epochs: number of passes over ``dl`` (default 40).
        log_interval: number of steps averaged per yielded value (default 100).

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
        FloatingPointError: if a step produces a NaN or infinite loss.
    """
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    model.train()
    total_step = 0
    loss_cum = 0.0
    for _ in range(num_epochs):
        print('new epoch')
        for encoder_input, decoder_input, graph_probability, label, label_mask in dl:
            optimizer.zero_grad()
            with autocast():
                output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
                output = output.log_softmax(-1)
                loss = loss_fn(output[label_mask], label[label_mask])
            loss_value = loss.item()
            # `x != float('nan')` is always True (NaN compares unequal to everything),
            # so test finiteness explicitly.
            if not math.isfinite(loss_value):
                raise FloatingPointError(f'non-finite loss {loss_value} at step {total_step + 1}')
            loss_cum += loss_value
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_step += 1
            # Report only after the step's loss is included, so the average
            # covers exactly `log_interval` losses.
            if total_step % log_interval == 0:
                yield loss_cum / log_interval, total_step
                loss_cum = 0.0

In [6]:
# train() is a generator: this only creates it. Training starts when the
# generator is iterated.
a = train(dl=dl, loss_fn=loss_fn, optimizer=optimizer, scaler=scaler, model=model)

In [None]:
# Drive the training generator and print each (averaged loss, step count) pair it yields.
for loss, total_step in a:
    print(f'{loss} {total_step}')

new epoch
3.5507740998268127 100
2.5701814818382265 200
2.451424291133881 300
2.3095486307144166 400
2.058634560108185 500
new epoch
2.1135887134075166 600
2.1388607692718504 700
2.0987807536125183 800
2.11384706735611 900
2.1082447481155397 1000
2.0386118972301484 1100
new epoch
2.049446624517441 1200
2.047368302345276 1300
2.002882659435272 1400
2.041422873735428 1500
2.0184627091884613 1600
new epoch
1.9236434638500213 1700
1.9879311096668244 1800
1.9574641633033751 1900
2.068376977443695 2000
1.913192287683487 2100
1.899888870716095 2200
new epoch
2.026212674379349 2300
1.8480445039272309 2400
1.9620160698890685 2500
1.995361157655716 2600
1.9354383969306945 2700
1.980180379152298 2800
new epoch
1.8846852684020996 2900
1.870929878950119 3000
1.9276260113716126 3100
1.8680173552036285 3200
1.9122110176086426 3300
new epoch
1.9589819991588593 3400
1.8494812154769897 3500
1.8488487565517426 3600
1.9164234125614166 3700
1.9303339445590972 3800
1.9521107399463653 3900
new epoch
1.837466