In [1]:
import hydra
import wandb
import torch
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
import genova
import numpy as np
import pandas as pd
from omegaconf import OmegaConf, open_dict
from genova.utils.BasicClass import Residual_seq
from torch.utils.data import DataLoader

In [2]:
# Load the Hydra config from ./configs and switch the task to 'optimum_path'.
# Hydra keeps a process-wide GlobalHydra singleton and hydra.initialize raises
# ValueError if it is already initialized, so clear it first; this makes the
# cell safe to re-run in the same kernel (clear() is a no-op on a fresh one).
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
hydra.initialize('configs')
cfg = hydra.compose('config.yaml')
# open_dict temporarily lifts OmegaConf struct mode so `task` can be set even
# if it is not declared in config.yaml.
with open_dict(cfg):
    cfg.task = 'optimum_path'

In [None]:
# Per-spectrum index table, indexed by its 'Spec Index' column.
# low_memory=False makes pandas parse the file in one pass, so each column gets
# a single inferred dtype instead of per-chunk guesses (no mixed-type warnings).
# NOTE(review): absolute, user-specific path — consider a configurable DATA_DIR.
spec_header = pd.read_csv(
    '/home/z37mao/genova_dataset_index.csv',
    index_col='Spec Index',
    low_memory=False,
)
# Optional narrowing filters used during experimentation:
#   spec_header[spec_header['MSGP File Name'] == '1_3.msgp']
#   spec_header[spec_header['Node Number'] <= 512]

In [None]:
# Keep only the rows whose 'MSGP File Name' is '1_3.msgp'.
in_target_file = spec_header['MSGP File Name'].eq('1_3.msgp')
spec_header = spec_header.loc[in_target_file]

In [None]:
# Build the model, data pipeline, loss, optimizer and AMP gradient scaler.
# Every GPU placement goes through this one handle, so the model, the batch
# sampler and the prefetcher cannot end up on different devices.
device = torch.device('cuda', 3)

model = genova.models.Genova(cfg).to(device)
# NOTE(review): absolute, user-specific dataset path — consider making it configurable.
ds = genova.data.GenovaDataset(cfg, spec_header=spec_header, dataset_dir_path='/home/z37mao/')
# NOTE(review): 0.9 and [0, 128, 256, 512] are passed straight through to
# GenovaBatchSampler; their meaning (possibly node-count bucket edges) is
# defined in genova.data — TODO confirm and name them as constants.
sampler = genova.data.GenovaBatchSampler(cfg, device, 0.9, spec_header, [0, 128, 256, 512], model)
collate_fn = genova.data.GenovaCollator(cfg)
dl = DataLoader(ds, batch_sampler=sampler, collate_fn=collate_fn, pin_memory=True, num_workers=2)
# Wrap the loader so batches are delivered for `device` (see genova.data.DataPrefetcher).
dl = genova.data.DataPrefetcher(dl, device)
# KLDivLoss takes log-probabilities as input; 'batchmean' divides the summed
# loss by input.size(0).
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scaler = GradScaler()

In [None]:
# Mixed-precision training loop (autocast forward + GradScaler backward).
epochs = 5
LOG_EVERY = 100  # print the mean loss over this many consecutive steps

# Guard against an earlier cell having left the model in eval mode.
model.train()
for epoch in range(epochs):
    print(f'epoch {epoch + 1}/{epochs}')
    # Reset the running sum every epoch: `i` restarts at 1 below, so a partial
    # window left over from the previous epoch would otherwise be folded into
    # (and inflate) the first average printed in this epoch.
    loss_cum = 0.0
    for i, batch in enumerate(dl, start=1):
        encoder_input, decoder_input, graph_probability, label, label_mask = batch
        optimizer.zero_grad()
        with autocast():
            output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
            # KLDivLoss expects log-probabilities as its input.
            output = output.log_softmax(-1)
            # Only positions selected by label_mask contribute to the loss.
            loss = loss_fn(output[label_mask], label[label_mask])
        # Accumulate on-device; calling .item() every step forces a GPU sync.
        loss_cum += loss.detach().float()
        if i % LOG_EVERY == 0:
            print((loss_cum / LOG_EVERY).item())
            # TODO: wandb.log({'loss': ...}) once a wandb run is initialised.
            loss_cum = 0.0
        # GradScaler scales the loss to avoid fp16 gradient underflow;
        # scaler.step unscales and skips the update if inf/NaN grads appear.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()