In [1]:
import hydra
import wandb
import torch
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
import genova
import numpy as np
import pandas as pd
from omegaconf import OmegaConf, open_dict
from genova.utils.BasicClass import Residual_seq
from torch.utils.data import DataLoader

In [2]:
hydra.initialize('configs')
cfg = hydra.compose('config.yaml')
with open_dict(cfg):
    cfg.task = 'optimum_path'

In [3]:
spec_header = pd.read_csv('/home/z37mao/genova_dataset_index.csv',low_memory=False,index_col='Spec Index')
#spec_header = spec_header[spec_header['MSGP File Name']=='1_3.msgp']
#spec_header = spec_header[spec_header['Node Number']<=512]

In [4]:
spec_header = spec_header[spec_header['MSGP File Name']=='1_3.msgp']

In [None]:
spec_header = spec_header[spec_header['Node Number']<=128]

In [5]:
model = genova.models.Genova(cfg).to(3)
ds = genova.data.GenovaDataset(cfg,spec_header=spec_header,dataset_dir_path='/home/z37mao/')
sampler = genova.data.GenovaBatchSampler(cfg,torch.device(3),0.9,spec_header,[0,128], model)
collate_fn = genova.data.GenovaCollator(cfg)
dl = DataLoader(ds, batch_sampler=sampler, collate_fn=collate_fn, pin_memory=True, num_workers=2)
dl = genova.data.DataPrefetcher(dl,torch.device(3))
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(model.parameters(),lr=2e-4)
scaler = GradScaler()

In [6]:
epochs = 5
loss_cum = 0
for epoch in range(epochs):
    print('new epoch')
    try: print(loss.item())
    except: pass
    for i, batch in enumerate(dl, start=1):
        encoder_input, decoder_input, graph_probability, label, label_mask = batch
        optimizer.zero_grad()
        with autocast():
            output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
            output = output.log_softmax(-1)
            loss = loss_fn(output[label_mask],label[label_mask])
        #break
        #print(loss)
        loss_cum += loss.item()
        if i%100==0:
            print(loss_cum/100)
            #wandb.log({'loss':loss_cum})
            loss_cum = 0
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

new epoch
3.345377576351166
2.3930571484565735
2.1479326736927034
2.0068275344371798
new epoch
1.9844505786895752
2.2607165849208832
1.942398247718811
1.9117928123474122
1.874899628162384
new epoch
1.9498313665390015
2.2033796596527098
1.833137058019638
1.8242606961727141
1.8842390859127045
new epoch
2.1033236980438232
2.088594515323639
1.817513861656189
1.7799456572532655
1.818541531562805
new epoch
2.072784423828125
2.052454866170883
1.7916455519199372
1.7810451555252076
1.7796060848236084


In [7]:
epochs = 5
loss_cum = 0
for epoch in range(epochs):
    print('new epoch')
    try: print(loss.item())
    except: pass
    for i, batch in enumerate(dl, start=1):
        encoder_input, decoder_input, graph_probability, label, label_mask = batch
        optimizer.zero_grad()
        with autocast():
            output = model(encoder_input=encoder_input, decoder_input=decoder_input, graph_probability=graph_probability)
            output = output.log_softmax(-1)
            loss = loss_fn(output[label_mask],label[label_mask])
        #break
        #print(loss)
        loss_cum += loss.item()
        if i%100==0:
            print(loss_cum/100)
            #wandb.log({'loss':loss_cum})
            loss_cum = 0
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

new epoch
1.6874823570251465
1.7391903650760652
1.7110785686969756
1.762934627532959
1.7356312024593352
new epoch
1.9749747514724731
2.050465053319931
1.7123709964752196
1.713335702419281
1.7026486313343048
new epoch
1.7378637790679932
1.8952631175518035
1.6893843185901642
1.7154736602306366
1.7092554116249083
new epoch
1.9898345470428467


KeyboardInterrupt: 

In [None]:
i = 0
while True:
    print('new epoch')
    if i == 2: break
    for index in sampler:
        print(len(index))
    i += 1