In [1]:
import dataset
from models import LorentzNet
import torch
from torch import nn, optim
import json, time
import utils
import numpy as np
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

In [2]:
log_interval = 100
epochs = 10
val_interval = 1
logdir = './logs'

In [3]:
def run(epoch, loader, partition):
    if partition == 'train':
        model.train()
    else:
        model.eval()

    res = {'time':0, 'correct':0, 'loss': 0, 'counter': 0, 'acc': 0,
           'loss_arr':[], 'correct_arr':[],'label':[],'score':[]}

    tik = time.time()
    loader_length = len(loader)

    for i, (label, p4s, nodes, atom_mask, edge_mask, edges) in enumerate(loader):
        if partition == 'train':
            optimizer.zero_grad()

        batch_size, n_nodes, _ = p4s.size()
        atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
        atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device)
        edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)
        nodes = nodes.view(batch_size * n_nodes, -1).to(device,dtype)
        edges = [a.to(device) for a in edges]
        label = label.to(device, dtype).long()

        pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
                         edge_mask=edge_mask, n_nodes=n_nodes)
        
        predict = pred.max(1).indices
        correct = torch.sum(predict == label).item()
        loss = loss_fn(pred, label)
        
        if partition == 'train':
            loss.backward()
            optimizer.step()
        elif partition == 'test':
            # save labels and probilities for ROC / AUC
            score = torch.nn.functional.softmax(pred, dim = -1)
            res['label'].append(label)
            res['score'].append(score)

        res['time'] = time.time() - tik
        res['correct'] += correct
        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size
        res['loss_arr'].append(loss.item())
        res['correct_arr'].append(correct)

        if i != 0 and i % log_interval == 0:
            running_loss = sum(res['loss_arr'][-log_interval:])/len(res['loss_arr'][-log_interval:])
            running_acc = sum(res['correct_arr'][-log_interval:])/(len(res['correct_arr'][-log_interval:])*batch_size)
            avg_time = res['time']/res['counter'] * batch_size
            tmp_counter = res['counter']
            tmp_loss = res['loss'] / tmp_counter
            tmp_acc = res['correct'] / tmp_counter
            print(">> %s \t Epoch %d/%d \t Batch %d/%d \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
                  (partition, epoch + 1, epochs, i, loader_length, running_loss, running_acc, tmp_acc, avg_time))

    torch.cuda.empty_cache()
    if partition == 'test':
        res['label'] = torch.cat(res['label']).unsqueeze(-1)
        res['score'] = torch.cat(res['score'])
        res['score'] = torch.cat((res['label'],res['score']),dim=-1)
    res['loss'] = res['loss'] / res['counter']
    res['acc'] = res['correct'] / res['counter']
    return res

In [4]:
def train(res):
    ### training and validation
    for epoch in range(0, epochs):
        train_res = run(epoch, dataloaders['train'], partition='train')
        print("Time: train: %.2f \t Train loss %.4f \t Train acc: %.4f" % (train_res['time'],train_res['loss'],train_res['acc']))
        if epoch % val_interval == 0:
            torch.save(model.state_dict(), f"{logdir}/checkpoint-epoch-{epoch}.pt")
            with torch.no_grad():
                val_res = run(epoch, dataloaders['val'], partition='val')
            res['lr'].append(optimizer.param_groups[0]['lr'])
            res['train_time'].append(train_res['time'])
            res['val_time'].append(val_res['time'])
            res['train_loss'].append(train_res['loss'])
            res['train_acc'].append(train_res['acc'])
            res['val_loss'].append(val_res['loss'])
            res['val_acc'].append(val_res['acc'])
            res['epochs'].append(epoch)

                ## save best model
            if val_res['acc'] > res['best_val']:
                print("New best validation model, saving...")
                torch.save(model.state_dict(), f"{logdir}/best-val-model.pt")
                res['best_val'] = val_res['acc']
                res['best_epoch'] = epoch

            print("Epoch %d/%d finished." % (epoch, epochs))
            print("Train time: %.2f \t Val time %.2f" % (train_res['time'], val_res['time']))
            print("Train loss %.4f \t Train acc: %.4f" % (train_res['loss'], train_res['acc']))
            print("Val loss: %.4f \t Val acc: %.4f" % (val_res['loss'], val_res['acc']))
            print("Best val acc: %.4f at epoch %d." % (res['best_val'],  res['best_epoch']))

            

        ## adjust learning rate
        if (epoch < 31):
            lr_scheduler.step(metrics=val_res['acc'])
        else:
            for g in optimizer.param_groups:
                g['lr'] = g['lr']*0.5


In [5]:
def test(res):
    ### test on best model
    best_model = torch.load(f"{logdir}/checkpoint-epoch-9.pt", map_location=device)
    model.load_state_dict(best_model)
    with torch.no_grad():
        test_res = run(0, dataloaders['test'], partition='test')

    print("Test: Loss %.4f \t Acc %.4f"
          % (test_res['loss'], test_res['acc']))

In [6]:
### initialize cuda
# dist.init_process_group(backend='nccl')
device = torch.device("cuda:{}".format(0))
dtype = torch.float32

    ### load data
dataloaders = dataset.retrieve_dataloaders(
    200,
    num_data=1000_000,
    num_workers=4)

In [7]:
### create parallel model
model = LorentzNet(n_scalar = 2, n_hidden = 72, n_class = 2,
                   dropout = 0.2, n_layers = 6,
                   c_weight = 0.001)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = model.to(device)

### print model and dataset information

pytorch_total_params = sum(p.numel() for p in model.parameters())
print("Network Size:", pytorch_total_params)
for (split, dataloader) in dataloaders.items():
    print(f" {split} samples: {len(dataloader.dataset)}")

Network Size: 224072
 train samples: 800000
 val samples: 100000
 test samples: 100000


In [8]:
### optimizer
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

### lr scheduler
base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
lr_scheduler = utils.GradualWarmupScheduler(optimizer, multiplier=1,
                                            warmup_epoch=2,
                                            after_scheduler=base_scheduler) ## warmup

### loss function
loss_fn = nn.CrossEntropyLoss()

### initialize logs
res = {'epochs': [], 'lr' : [],
       'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],
       'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

In [9]:
train(res)

>> train 	 Epoch 1/10 	 Batch 100/4000 	 Loss 0.4916 	 Running Acc 0.778 	 Total Acc 0.776 	 Avg Batch Time 0.2052
>> train 	 Epoch 1/10 	 Batch 200/4000 	 Loss 0.4627 	 Running Acc 0.787 	 Total Acc 0.781 	 Avg Batch Time 0.2003
>> train 	 Epoch 1/10 	 Batch 300/4000 	 Loss 0.4592 	 Running Acc 0.791 	 Total Acc 0.785 	 Avg Batch Time 0.1995
>> train 	 Epoch 1/10 	 Batch 400/4000 	 Loss 0.4626 	 Running Acc 0.789 	 Total Acc 0.786 	 Avg Batch Time 0.1993
>> train 	 Epoch 1/10 	 Batch 500/4000 	 Loss 0.4596 	 Running Acc 0.788 	 Total Acc 0.786 	 Avg Batch Time 0.1995
>> train 	 Epoch 1/10 	 Batch 600/4000 	 Loss 0.4475 	 Running Acc 0.794 	 Total Acc 0.788 	 Avg Batch Time 0.1997
>> train 	 Epoch 1/10 	 Batch 700/4000 	 Loss 0.4554 	 Running Acc 0.790 	 Total Acc 0.788 	 Avg Batch Time 0.1999
>> train 	 Epoch 1/10 	 Batch 800/4000 	 Loss 0.4509 	 Running Acc 0.793 	 Total Acc 0.789 	 Avg Batch Time 0.2005
>> train 	 Epoch 1/10 	 Batch 900/4000 	 Loss 0.4415 	 Running Acc 0.798 	 Total

>> train 	 Epoch 2/10 	 Batch 2700/4000 	 Loss 0.4294 	 Running Acc 0.805 	 Total Acc 0.804 	 Avg Batch Time 0.2111
>> train 	 Epoch 2/10 	 Batch 2800/4000 	 Loss 0.4282 	 Running Acc 0.806 	 Total Acc 0.804 	 Avg Batch Time 0.2111
>> train 	 Epoch 2/10 	 Batch 2900/4000 	 Loss 0.4289 	 Running Acc 0.808 	 Total Acc 0.804 	 Avg Batch Time 0.2111
>> train 	 Epoch 2/10 	 Batch 3000/4000 	 Loss 0.4311 	 Running Acc 0.806 	 Total Acc 0.804 	 Avg Batch Time 0.2111
>> train 	 Epoch 2/10 	 Batch 3100/4000 	 Loss 0.4313 	 Running Acc 0.806 	 Total Acc 0.804 	 Avg Batch Time 0.2111
>> train 	 Epoch 2/10 	 Batch 3200/4000 	 Loss 0.4303 	 Running Acc 0.804 	 Total Acc 0.804 	 Avg Batch Time 0.2111
>> train 	 Epoch 2/10 	 Batch 3300/4000 	 Loss 0.4336 	 Running Acc 0.806 	 Total Acc 0.804 	 Avg Batch Time 0.2110
>> train 	 Epoch 2/10 	 Batch 3400/4000 	 Loss 0.4207 	 Running Acc 0.812 	 Total Acc 0.804 	 Avg Batch Time 0.2110
>> train 	 Epoch 2/10 	 Batch 3500/4000 	 Loss 0.4342 	 Running Acc 0.80

>> train 	 Epoch 4/10 	 Batch 800/4000 	 Loss 0.4044 	 Running Acc 0.824 	 Total Acc 0.824 	 Avg Batch Time 0.2111
>> train 	 Epoch 4/10 	 Batch 900/4000 	 Loss 0.3963 	 Running Acc 0.830 	 Total Acc 0.824 	 Avg Batch Time 0.2109
>> train 	 Epoch 4/10 	 Batch 1000/4000 	 Loss 0.4031 	 Running Acc 0.820 	 Total Acc 0.824 	 Avg Batch Time 0.2109
>> train 	 Epoch 4/10 	 Batch 1100/4000 	 Loss 0.4028 	 Running Acc 0.825 	 Total Acc 0.824 	 Avg Batch Time 0.2110
>> train 	 Epoch 4/10 	 Batch 1200/4000 	 Loss 0.4011 	 Running Acc 0.821 	 Total Acc 0.824 	 Avg Batch Time 0.2111
>> train 	 Epoch 4/10 	 Batch 1300/4000 	 Loss 0.3995 	 Running Acc 0.825 	 Total Acc 0.824 	 Avg Batch Time 0.2110
>> train 	 Epoch 4/10 	 Batch 1400/4000 	 Loss 0.4005 	 Running Acc 0.825 	 Total Acc 0.824 	 Avg Batch Time 0.2109
>> train 	 Epoch 4/10 	 Batch 1500/4000 	 Loss 0.3961 	 Running Acc 0.830 	 Total Acc 0.824 	 Avg Batch Time 0.2108
>> train 	 Epoch 4/10 	 Batch 1600/4000 	 Loss 0.4042 	 Running Acc 0.824 

>> train 	 Epoch 5/10 	 Batch 3400/4000 	 Loss 0.3836 	 Running Acc 0.832 	 Total Acc 0.831 	 Avg Batch Time 0.2105
>> train 	 Epoch 5/10 	 Batch 3500/4000 	 Loss 0.3986 	 Running Acc 0.826 	 Total Acc 0.831 	 Avg Batch Time 0.2105
>> train 	 Epoch 5/10 	 Batch 3600/4000 	 Loss 0.3963 	 Running Acc 0.829 	 Total Acc 0.831 	 Avg Batch Time 0.2105
>> train 	 Epoch 5/10 	 Batch 3700/4000 	 Loss 0.3932 	 Running Acc 0.829 	 Total Acc 0.831 	 Avg Batch Time 0.2106
>> train 	 Epoch 5/10 	 Batch 3800/4000 	 Loss 0.3873 	 Running Acc 0.833 	 Total Acc 0.831 	 Avg Batch Time 0.2105
>> train 	 Epoch 5/10 	 Batch 3900/4000 	 Loss 0.3930 	 Running Acc 0.831 	 Total Acc 0.831 	 Avg Batch Time 0.2106
Time: train: 842.30 	 Train loss 0.3917 	 Train acc: 0.8307
>> val 	 Epoch 5/10 	 Batch 100/500 	 Loss 0.3850 	 Running Acc 0.834 	 Total Acc 0.833 	 Avg Batch Time 0.0869
>> val 	 Epoch 5/10 	 Batch 200/500 	 Loss 0.3842 	 Running Acc 0.834 	 Total Acc 0.834 	 Avg Batch Time 0.0867
>> val 	 Epoch 5/10 

>> train 	 Epoch 7/10 	 Batch 1500/4000 	 Loss 0.3895 	 Running Acc 0.834 	 Total Acc 0.828 	 Avg Batch Time 0.2109
>> train 	 Epoch 7/10 	 Batch 1600/4000 	 Loss 0.3990 	 Running Acc 0.828 	 Total Acc 0.828 	 Avg Batch Time 0.2110
>> train 	 Epoch 7/10 	 Batch 1700/4000 	 Loss 0.3971 	 Running Acc 0.830 	 Total Acc 0.828 	 Avg Batch Time 0.2108
>> train 	 Epoch 7/10 	 Batch 1800/4000 	 Loss 0.3938 	 Running Acc 0.830 	 Total Acc 0.828 	 Avg Batch Time 0.2108
>> train 	 Epoch 7/10 	 Batch 1900/4000 	 Loss 0.3960 	 Running Acc 0.830 	 Total Acc 0.828 	 Avg Batch Time 0.2108
>> train 	 Epoch 7/10 	 Batch 2000/4000 	 Loss 0.3912 	 Running Acc 0.833 	 Total Acc 0.828 	 Avg Batch Time 0.2108
>> train 	 Epoch 7/10 	 Batch 2100/4000 	 Loss 0.3995 	 Running Acc 0.829 	 Total Acc 0.828 	 Avg Batch Time 0.2108
>> train 	 Epoch 7/10 	 Batch 2200/4000 	 Loss 0.4032 	 Running Acc 0.826 	 Total Acc 0.828 	 Avg Batch Time 0.2108
>> train 	 Epoch 7/10 	 Batch 2300/4000 	 Loss 0.3902 	 Running Acc 0.83

>> val 	 Epoch 8/10 	 Batch 200/500 	 Loss 0.3853 	 Running Acc 0.834 	 Total Acc 0.835 	 Avg Batch Time 0.0865
>> val 	 Epoch 8/10 	 Batch 300/500 	 Loss 0.3984 	 Running Acc 0.827 	 Total Acc 0.832 	 Avg Batch Time 0.0863
>> val 	 Epoch 8/10 	 Batch 400/500 	 Loss 0.3922 	 Running Acc 0.828 	 Total Acc 0.831 	 Avg Batch Time 0.0862
Epoch 7/10 finished.
Train time: 843.74 	 Val time 43.01
Train loss 0.3911 	 Train acc: 0.8316
Val loss: 0.3902 	 Val acc: 0.8310
Best val acc: 0.8311 at epoch 4.
>> train 	 Epoch 9/10 	 Batch 100/4000 	 Loss 0.3870 	 Running Acc 0.832 	 Total Acc 0.832 	 Avg Batch Time 0.2114
>> train 	 Epoch 9/10 	 Batch 200/4000 	 Loss 0.3844 	 Running Acc 0.834 	 Total Acc 0.833 	 Avg Batch Time 0.2108
>> train 	 Epoch 9/10 	 Batch 300/4000 	 Loss 0.3910 	 Running Acc 0.832 	 Total Acc 0.833 	 Avg Batch Time 0.2109
>> train 	 Epoch 9/10 	 Batch 400/4000 	 Loss 0.3952 	 Running Acc 0.830 	 Total Acc 0.832 	 Avg Batch Time 0.2108
>> train 	 Epoch 9/10 	 Batch 500/4000 	 

>> train 	 Epoch 10/10 	 Batch 2300/4000 	 Loss 0.3776 	 Running Acc 0.835 	 Total Acc 0.837 	 Avg Batch Time 0.2108
>> train 	 Epoch 10/10 	 Batch 2400/4000 	 Loss 0.3729 	 Running Acc 0.841 	 Total Acc 0.837 	 Avg Batch Time 0.2107
>> train 	 Epoch 10/10 	 Batch 2500/4000 	 Loss 0.3769 	 Running Acc 0.841 	 Total Acc 0.837 	 Avg Batch Time 0.2107
>> train 	 Epoch 10/10 	 Batch 2600/4000 	 Loss 0.3820 	 Running Acc 0.836 	 Total Acc 0.837 	 Avg Batch Time 0.2107
>> train 	 Epoch 10/10 	 Batch 2700/4000 	 Loss 0.3856 	 Running Acc 0.835 	 Total Acc 0.837 	 Avg Batch Time 0.2107
>> train 	 Epoch 10/10 	 Batch 2800/4000 	 Loss 0.3771 	 Running Acc 0.837 	 Total Acc 0.837 	 Avg Batch Time 0.2107
>> train 	 Epoch 10/10 	 Batch 2900/4000 	 Loss 0.3835 	 Running Acc 0.837 	 Total Acc 0.837 	 Avg Batch Time 0.2107
>> train 	 Epoch 10/10 	 Batch 3000/4000 	 Loss 0.3802 	 Running Acc 0.838 	 Total Acc 0.837 	 Avg Batch Time 0.2107
>> train 	 Epoch 10/10 	 Batch 3100/4000 	 Loss 0.3837 	 Running

In [10]:
test(res)

>> test 	 Epoch 1/10 	 Batch 100/500 	 Loss 0.3854 	 Running Acc 0.834 	 Total Acc 0.834 	 Avg Batch Time 0.0900
>> test 	 Epoch 1/10 	 Batch 200/500 	 Loss 0.3760 	 Running Acc 0.841 	 Total Acc 0.837 	 Avg Batch Time 0.0879
>> test 	 Epoch 1/10 	 Batch 300/500 	 Loss 0.3910 	 Running Acc 0.833 	 Total Acc 0.836 	 Avg Batch Time 0.0870
>> test 	 Epoch 1/10 	 Batch 400/500 	 Loss 0.3903 	 Running Acc 0.834 	 Total Acc 0.835 	 Avg Batch Time 0.0868
Test: Loss 0.3858 	 Acc 0.8355
