In [1]:
import os
import argparse
from ml_collections import ConfigDict
import yaml
import time

import copy
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch_geometric.transforms import Compose
from tqdm import tqdm
import wandb
import numpy as np
import random

from data.data_preprocess import HeteroAddLaplacianEigenvectorPE, SubSample
from data.dataset import LPDataset
from data.utils import args_set_bool, collate_fn_ip
from models.hetero_gnn import TripartiteHeteroGNN_
from trainer import Trainer

In [2]:
# Experiment hyperparameters, collected in a single dict.
# Number of outer-loop iterations: "ipm_steps"
# Number of inner-loop iterations: "num_conv_layers"
#
# Every key appears exactly once: a repeated key in a dict literal is not an
# error in Python, the later value silently replaces the earlier one.
var_dict = {
    # optimisation
    "weight_decay": 0,
    "micro_batch": 4,
    "batchsize": 128,
    # model architecture
    "hidden": 180,
    "num_conv_layers": 2,
    "num_pred_layers": 4,
    "num_mlp_layers": 4,
    "share_lin_weight": True,
    "conv_sequence": 'cov',
    # loss weighting
    "loss_weight_x": 1.0,
    "loss_weight_obj": 3.43,
    "loss_weight_cons": 5.8,
    "losstype": 'l2',
    # NOTE(review): the training loop in this notebook iterates range(1) and
    # does not read "runs" -- confirm whether this value is still needed.
    "runs": 3,
    "lappe": 0,
    "conv": 'gcnconv',
    "epoch": 1,
    "ipm_alpha": 0.7,
    "ipm_steps": 16,
    "dropout": 0,
    "share_conv_weight": True,
    "use_norm": True,
    "use_res": True,
    "lr": 1.e-5,
}

In [3]:
# Open a wandb run and attach the hyperparameter dict as its config.
# Replace both 'xxx' placeholders with your own wandb project and entity.
wandb.init(
    entity="xxx",
    project='xxx',
    config=var_dict,
)

In [4]:
# Report whether a wandb run is currently active.
print('wandb not running' if wandb.run is None else 'wandb running')

In [5]:
# Preprocess the raw instances (the ./raw/raw folder) into the training and
# test datasets, then wrap them in data loaders.
# `train_ins` / `test_ins` name the instance group; the name is part of the
# processed-data path, which keeps the two datasets apart.

train_ins = 'train_b4'
test_ins = 'test_b4'

# Read the preprocessing settings from the shared config so the data cannot
# drift away from the values the model and trainer are built with.
ipm = var_dict['ipm_steps']
lappe = var_dict['lappe']


def _build_lp_dataset(instance_name, ipm_steps, pe_dim):
    """Build the LPDataset for one instance group, e.g. 'train_b4'.

    Args:
        instance_name: instance-group name; used as the suffix of `extra_path`.
        ipm_steps: passed to `SubSample` and written into `extra_path`.
        pe_dim: passed as `k` to `HeteroAddLaplacianEigenvectorPE` and written
            into `extra_path`.

    Returns:
        The LPDataset rooted at 'raw' for that instance group.
    """
    return LPDataset('raw',
                     extra_path=f'{1}restarts_'
                                f'{pe_dim}lap_'
                                f'{ipm_steps}steps'
                                f'{"_upper_" + str(instance_name)}',
                     upper_bound=1,
                     rand_starts=1,
                     pre_transform=Compose([HeteroAddLaplacianEigenvectorPE(k=pe_dim),
                                            SubSample(ipm_steps)]))


train_dataset = _build_lp_dataset(train_ins, ipm, lappe)
test_dataset = _build_lp_dataset(test_ins, ipm, lappe)

# Train and test use different datasets. The validation set is the last 10% of
# the training dataset, so training must use only the first 90%: otherwise the
# validation samples are also trained on, and the validation metric that
# drives the LR scheduler and best-model selection is biased.
val_start = int(len(train_dataset) * 0.9)

# Settings shared by all three loaders.
loader_kwargs = dict(batch_size=var_dict['batchsize'],
                     num_workers=1,
                     collate_fn=collate_fn_ip)

train_loader = DataLoader(train_dataset[:val_start],
                          shuffle=True,
                          **loader_kwargs)
val_loader = DataLoader(train_dataset[val_start:],
                        shuffle=True,
                        **loader_kwargs)
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         **loader_kwargs)

In [6]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [7]:
# Seed every random number generator used in this notebook so that a run can
# be repeated. Each generator is seeded exactly once.
seed = 2026     # 2026, 2027, 2028
random.seed(seed)                  # Python random module
np.random.seed(seed)               # NumPy module
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)   # if you are using multi-GPU

# Turn off cuDNN autotuning and ask for deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Result accumulators; the run loop appends one value to each list per run.
best_val_objgap_mean, best_val_consgap_mean = [], []
test_objgap_mean, test_consgap_mean, test_objgap_nocon_mean = [], [], []

# One iteration of this loop is a complete train / validate / test cycle.
# NOTE(review): the loop runs exactly once; var_dict['runs'] is not consulted.
for run in range(1):

    # --- logging folder: logs/exp<N>/run<run> plus a copy of the config ---
    os.makedirs('logs', exist_ok=True)
    exist_runs = [d for d in os.listdir('logs') if d.startswith('exp')]
    # Start at the number of existing experiment folders, but skip any index
    # that is already taken. The numbering is not contiguous once an older
    # folder has been deleted, and os.mkdir would raise FileExistsError.
    exp_idx = len(exist_runs)
    while os.path.exists(os.path.join('logs', f'exp{exp_idx}')):
        exp_idx += 1
    log_folder_name = f'logs/exp{exp_idx}'
    os.mkdir(log_folder_name)
    with open(os.path.join(log_folder_name, 'config.yaml'), 'w') as outfile:
        yaml.dump(var_dict, outfile, default_flow_style=False)

    run_dir = os.path.join(log_folder_name, f'run{run}')
    os.mkdir(run_dir)

    # --- model, optimizer, scheduler, trainer ---
    model = TripartiteHeteroGNN_(ipm_steps=var_dict['ipm_steps'],
                                 conv=var_dict['conv'],
                                 in_shape=2,
                                 pe_dim=var_dict['lappe'],
                                 hid_dim=var_dict['hidden'],
                                 num_conv_layers=var_dict['num_conv_layers'],
                                 num_pred_layers=var_dict['num_pred_layers'],
                                 num_mlp_layers=var_dict['num_mlp_layers'],
                                 dropout=var_dict['dropout'],
                                 share_conv_weight=var_dict['share_conv_weight'],
                                 share_lin_weight=var_dict['share_lin_weight'],
                                 use_norm=var_dict['use_norm'],
                                 use_res=var_dict['use_res'],
                                 conv_sequence=var_dict['conv_sequence']).to(device)

    # Snapshot of the weights with the best validation objective gap so far;
    # starts as the freshly initialised weights.
    best_model = copy.deepcopy(model.state_dict())

    optimizer = optim.Adam(model.parameters(), lr=var_dict['lr'], weight_decay=var_dict['weight_decay'])
    # Halve the learning rate when the validation metric has not improved for
    # more than `patience` epochs; never go below 1e-6.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1.e-6)

    trainer = Trainer(device,
                      'primal+objgap+constraint',
                      var_dict['losstype'],
                      var_dict['micro_batch'],
                      var_dict['ipm_steps'],
                      var_dict['ipm_alpha'],
                      loss_weight={'primal': var_dict['loss_weight_x'],
                                   'objgap': var_dict['loss_weight_obj'],
                                   'constraint': var_dict['loss_weight_cons']})

    # --- training epochs ---
    pbar = tqdm(range(var_dict['epoch']))
    curr = time.time()
    for epoch in pbar:
        train_loss, primal_loss, obj_loss, cons_loss = trainer.train_(train_loader, model, optimizer)
        with torch.no_grad():

            val_gaps, val_constraint_gap, val_gaps_nocon = trainer.eval_metrics_(val_loader, model)

            # Metric used to cache the best model: mean of the last column of
            # the validation gaps.
            # NOTE(review): the last column is presumably the final IPM step --
            # confirm against Trainer.eval_metrics_.
            cur_mean_gap = val_gaps[:, -1].mean().item()
            cur_cons_gap_mean = val_constraint_gap[:, -1].mean().item()
            if scheduler is not None:
                scheduler.step(cur_mean_gap)

            # A checkpoint is written after every epoch ...
            torch.save(model.state_dict(), os.path.join(run_dir, str(epoch)+'_model.pt'))

            # ... and an extra one whenever the validation objective gap improves.
            if trainer.best_val_objgap > cur_mean_gap:
                trainer.patience = 0
                trainer.best_val_objgap = cur_mean_gap
                trainer.best_val_consgap = cur_cons_gap_mean
                best_model = copy.deepcopy(model.state_dict())

                torch.save(model.state_dict(), os.path.join(run_dir, str(epoch)+'_best_model.pt'))

        # Read the learning rate from the optimizer itself, after the scheduler
        # step, so this line does not depend on a scheduler being present.
        current_lr = optimizer.param_groups[0]["lr"]

        pbar.set_postfix({'train_loss': train_loss,
                          'primal_loss': primal_loss,
                          'obj_loss': obj_loss,
                          'cons_loss': cons_loss,
                          'val_obj': cur_mean_gap,
                          'val_cons': cur_cons_gap_mean,
                          'lr': current_lr})
        log_dict = {'train_loss': train_loss,
                    'primal_loss': primal_loss,
                    'obj_loss': obj_loss,
                    'cons_loss': cons_loss,
                    'val_obj_gap_last_mean': cur_mean_gap,
                    'val_cons_gap_last_mean': cur_cons_gap_mean,
                    'lr': current_lr}

        wandb.log(log_dict)
    print('time:', time.time()-curr)

    best_val_objgap_mean.append(trainer.best_val_objgap)
    best_val_consgap_mean.append(trainer.best_val_consgap)

    # --- test with the best weights seen during validation ---
    model.load_state_dict(best_model)

    with torch.no_grad():

        test_gaps, test_cons_gap, test_gaps_nocon = trainer.eval_metrics_(test_loader, model)

    test_objgap_mean.append(test_gaps[:, -1].mean().item())
    test_consgap_mean.append(test_cons_gap[:, -1].mean().item())
    test_objgap_nocon_mean.append(test_gaps_nocon[:, -1].mean().item())

    wandb.log({'test_objgap': test_objgap_mean[-1]})
    wandb.log({'test_consgap': test_consgap_mean[-1]})
    wandb.log({'test_objgap_nocon': test_objgap_nocon_mean[-1]})


# Aggregate the per-run results, log the summary to wandb and close the run.
objgap_avg = np.mean(test_objgap_mean)
consgap_avg = np.mean(test_consgap_mean)

summary = {
    'best_val_objgap': np.mean(best_val_objgap_mean),
    'test_objgap_mean': objgap_avg,
    'test_objgap_std': np.std(test_objgap_mean),
    'test_consgap_mean': consgap_avg,
    'test_consgap_std': np.std(test_consgap_mean),
    # objective gap + constraint gap; this is the value used for the sweep
    'test_hybrid_gap': objgap_avg + consgap_avg,
}
wandb.log(summary)
wandb.finish()
