# Import libraries

In [1]:
%load_ext autoreload
%autoreload 2
from deepgd import *

# Config

In [2]:
# Experiment configuration. `config.name` is also the folder/file stem used by
# `paths` for checkpoints, tensorboard runs, visualizations, logs and metrics.
cuda_idx = 0  # GPU index; set to None to force CPU
canonicalization = CanonicalizationByStress()  # shared layout normalizer for generator and discriminator
config = StaticConfig({
    "name": 'GAN(gan=rgan,data=best(xing,random),canon=stress,conv=3,share=6,embed=10)',  # must not contain spaces
    "uid": None,
    "link": None,  # written as a header line into the log file before training
    "generator": {
        "params": {
            "num_blocks": 9,
            "normalize": canonicalization
        },
        "pretrained": {
            "name": None,  # run name whose checkpoints are used to warm-start (None = don't)
            "epoch": -1,  # resolved with get_ckpt_epoch: -1 = latest, 0 = don't load
        },
        "optim": torch.optim.AdamW,
        "lr" : {
            "initial": 1e-3,
            "decay": 0.99,  # ExponentialLR gamma, stepped once per epoch
        },
    },
    "discriminator": {
        "params": {
            "conv": [2, 16, 16, 16],
            "dense": [2],
            "shared_depth": 6,
            "enet_depth": 10,
            "enet_width": 64,
            "aggr": "add",
            "normalize": canonicalization
        },
        "pretrained": {
            "name": None,
            "epoch": -1,
        },
        "optim": torch.optim.AdamW,
        "lr" : {
            "initial": 1e-3,
            "decay": 0.99,
        },
        "noise": {
            "std": 0,  # Gaussian noise added to generated layouts: std * decay ** epoch (0 disables)
            "decay": 0.95,
        },
        "repeat": 1,  # discriminator passes over the training set per epoch
        "complete": True,  # forwarded as DiscriminatorDataConverter(complete_graph=...)
        "adaptive": True  # NOTE(review): not read anywhere in this notebook
    },
    "alternate": "epoch",  # 'epoch': full D pass then full G pass; 'iteration': G step after every D step
    "batchsize": 24,
    "epoch": {
        "start": -1,  # -1 resumes after the latest checkpoint of this run
        "end": None,  # None: train until interrupted
    },
    "log_interval": 1,  # validate and checkpoint every N epochs
    "test": {
        "name": "test",  # suffix of the visualization folder
        "epoch": -1,  # NOTE(review): the Test section hard-codes test_epoch = -1 instead
    },
    "gan_flavor": "rgan",  # one of sgan, wgan, rgan, ragan, dgdv2 (see get_gan_loss)
    "gp_weight": 0,  # > 0 enables the gradient-penalty step in train_dis
})
data_config = StaticConfig({
    "sparse": False,
    "pivot": None,
    "init": "pmds",
    "edge": {
        "index": "full_edge_index",
        "attr": "full_edge_attr",
    },
})
# NOTE(review): `loss_fns` and `ctrler_params` are not used by the active code in
# this notebook (loss_fns only appears in commented-out tensorboard logging).
loss_fns = {
    Stress(): 1
}
ctrler_params = {
    "tau": 0.95,
    "beta": 1,
    "exploit_rate": 0.5,
    "warmup": 2,
}
# Output locations. The lambdas are evaluated lazily, so they always reflect the
# current `config` values.
paths = StaticConfig({
    "root": "artifacts",
    "checkpoints": lambda: f"{paths.root}/checkpoints/{config.name}",
    "gen_pretrain": lambda: f"{paths.root}/checkpoints/{config.generator.pretrained.name}",
    "dis_pretrain": lambda: f"{paths.root}/checkpoints/{config.discriminator.pretrained.name}",
    "tensorboard": lambda: f"{paths.root}/tensorboards/{config.name}",
    "visualization": lambda: f"{paths.root}/visualizations/{config.name}_{config.test.name}",
    "log": lambda: f"{paths.root}/logs/{config.name}.log",
    "metrics": lambda suffix: f"{paths.root}/metrics/{config.name}_{suffix}.pickle",
})

In [3]:
# The model name is used verbatim in file and folder names (see `paths`),
# so names containing spaces are rejected up front.
if config.name.find(" ") != -1:
    raise Exception("Space is not allowed in model name.")

# Prepare

## Get log and TensorBoard commands

In [4]:
# Shell command for following this run's training log from a terminal.
print(f"cd {os.getcwd()} && tail -n1000 -f '{paths.log()}'")

cd /users/PAS0027/osu10203/deepgd && tail -n1000 -f 'artifacts/logs/GAN(gan=rgan,data=best(xing,random),canon=stress,conv=3,share=6,embed=10).log'


In [5]:
# Command for uploading this run's TensorBoard logs.
# NOTE(review): TensorBoard.dev has been discontinued; this command may no longer work.
print(f"tensorboard dev upload --logdir '{paths.tensorboard()}'")

tensorboard dev upload --logdir 'artifacts/tensorboards/GAN(gan=rgan,data=best(xing,random),canon=stress,conv=3,share=6,embed=10)'


## Set globals

In [6]:
# Select the compute device. NVML is initialized only when a GPU is used; `cuda`
# then holds the NVML handle of that device (None on CPU).
if cuda_idx is not None and torch.cuda.is_available():
    device = f'cuda:{cuda_idx}'
    pynvml.nvmlInit()
    cuda = pynvml.nvmlDeviceGetHandleByIndex(cuda_idx)
else:
    device = 'cpu'
    cuda =  None
np.set_printoptions(precision=2)
# NOTE(review): silences *all* RuntimeWarnings (e.g. numpy divide-by-zero) for the session.
warnings.filterwarnings("ignore", category=RuntimeWarning)

## Load data

In [7]:
# Load the Rome graphs and build one data object per graph; both steps are
# cached under cache/. With init='pmds' and `pmds_list`, the precomputed PivotMDS
# layouts are presumably used as initial positions (confirm in generate_data_list).
G_list = load_G_list(data_path='data/rome', index_file='data_index.txt', cache='G_list', cache_prefix='cache/')
data_list = generate_data_list(G_list, 
                               sparse=data_config.sparse, 
                               pivot_mode=data_config.pivot,
                               init_mode=data_config.init,
                               edge_index=data_config.edge.index,
                               edge_attr=data_config.edge.attr,
                               pmds_list=np.load('layouts/rome/pmds.npy', allow_pickle=True),
                               gviz_list=np.load('layouts/rome/gviz.npy', allow_pickle=True),
                               noisy_layout=True,
                               device='cpu', 
                               cache=True,
                               cache_prefix='cache/')
# Split used throughout: [0, 10000) train, [11000, end) validation;
# [10000, 11000) is evaluated in the Test section. The loaders are rebuilt
# after `gt_pos` is overwritten further down.
train_loader = LazyDeviceMappingDataLoader(data_list[:10000], batch_size=config.batchsize, shuffle=True, device=device)
val_loader = LazyDeviceMappingDataLoader(data_list[11000:], batch_size=config.batchsize, shuffle=False, device=device)

Load from 'cache/G_list.pickle'
Load from 'cache/generate_data_list(list,sparse=False,pivot_mode=None,init_mode=pmds,edge_index=full_edge_index,edge_attr=full_edge_attr,pmds_list=ndarray,gviz_list=ndarray,noisy_layout=True,device=cpu).pickle'




In [8]:
def draw_layout(G, method, draw=True):
    """Compute (and optionally draw) a layout of ``G`` with the named method.

    Args:
        G: a networkx graph.
        method: ``'fa2'`` for ForceAtlas2 (via ``get_fa2_layout``); the prefix
            of any ``networkx.drawing.layout.<method>_layout`` function (e.g.
            ``'spring'``, ``'kamada_kawai'``); otherwise a Graphviz program
            name (e.g. ``'neato'``, ``'dot'``) passed to ``graphviz_layout``.
        draw: if True, also draw the graph with ``nx.draw`` using the layout.

    Returns:
        The layout object produced by the selected backend.
    """
    if method == 'fa2':
        layout = get_fa2_layout(G)
    else:
        # Fall back to Graphviz only when networkx has no such layout function.
        # Errors raised while *computing* a networkx layout (or a Ctrl-C)
        # propagate instead of being retried as a meaningless Graphviz program.
        nx_layout_fn = getattr(nx.drawing.layout, f'{method}_layout', None)
        if nx_layout_fn is not None:
            layout = nx_layout_fn(G)
        else:
            layout = nx.drawing.nx_agraph.graphviz_layout(G, prog=method)
    if draw:
        nx.draw(G, pos=layout)
    return layout

In [9]:
# Candidate layout methods; each name matches a `layouts/rome/<method>.npy`
# file of precomputed layouts.
methods = [
    'neato', 'dot', 'fdp', 'sfdp', 'twopi', 'circo',
    'shell', 'spring', 'circular', 'spectral', 'kamada_kawai',
    'fa2', 'pmds',
]

In [10]:
from functools import lru_cache

@lru_cache(maxsize=None)
def load_pos(method):
    """Load the precomputed Rome layouts of ``method`` (memoized per name).

    Reads ``layouts/rome/<method>.npy`` with ``allow_pickle=True``. Repeated
    calls with the same name return the same cached array object, so callers
    must not mutate it.
    """
    layout_file = f'layouts/rome/{method}.npy'
    return np.load(layout_file, allow_pickle=True)

In [11]:
# best_list = []
# best_layout_list = []
# for idx, (G, data) in enumerate(zip(tqdm(G_list), data_list)):
#     xing, stress, layout = {}, {}, {}
#     for m in methods:
#         batch = Batch.from_data_list([data])
#         pos = load_pos(m)
#         p = CanonicalizationByStress()(torch.tensor(pos[idx]).float(), batch)
#         x = Xing()(p, batch).item()
#         s = Stress()(p, batch).item()
#         xing[m] = x
#         stress[m] = s
#         layout[m] = p.numpy()
#         # plt.figure()
#         # graph_vis(G, pos[idx])
#         # plt.title(f'{m} stress={s:.2f} xing={x}')
#     best, *_ = sorted(methods, key=lambda m: (xing[m], stress[m]))
#     best_list.append(best)
#     best_layout_list.append(layout[best])
#     print(f'{best}, xing={xing[best]}, stress={stress[best]:.2f}')

In [12]:
# pickle.dump(best_list, open('layouts/rome/best[xing,stress].pkl', 'wb'))
# np.save('layouts/rome/best[xing,stress].npy', best_layout_list)

In [13]:
# Replace each graph's ground-truth positions with the precomputed
# "best(xing,random)" layout. The file must hold one layout per entry of
# `data_list`, in the same order.
best_layout_list = np.load('layouts/rome/best[xing,random].npy', allow_pickle=True)
# `zip` would silently stop at the shorter input and leave some graphs with
# their previous `gt_pos`, so fail loudly on a length mismatch instead.
if len(best_layout_list) != len(data_list):
    raise ValueError(f"Expected {len(data_list)} layouts, got {len(best_layout_list)}.")
for data, layout in zip(tqdm(data_list), best_layout_list):
    data.gt_pos = torch.tensor(layout)

  0%|          | 0/11531 [00:00<?, ?it/s]

In [14]:
# Rebuild the loaders now that every `data.gt_pos` has been replaced (same split:
# [0, 10000) train, [11000, end) validation).
train_loader = LazyDeviceMappingDataLoader(data_list[:10000], batch_size=config.batchsize, shuffle=True, device=device)
val_loader = LazyDeviceMappingDataLoader(data_list[11000:], batch_size=config.batchsize, shuffle=False, device=device)

In [15]:
# for m in methods:
#     layouts = []
#     for G in tqdm(G_list):
#         layout = draw_layout(G, method=m, draw=False)
#         layouts.append(np.array(list(layout.values())))
#     np.save(f'layouts/rome/{m}.npy', layouts)

In [16]:
# from collections import Counter

# letter_counts = Counter(best_list)
# df = pd.DataFrame.from_dict(letter_counts, orient='index')
# df.plot(kind='bar', figsize=[12, 8])

## Create folders

In [17]:
# Create every output folder this run writes to.
mkdirs(paths.checkpoints(), paths.tensorboard(), paths.visualization(), f"{paths.root}/logs", f"{paths.root}/metrics")

## Define models and load checkpoints

In [18]:
class EdgeFeatureDiscriminator(nn.Module):
    """Graph-level discriminator that scores a layout through edge features.

    Node features are constant (``ones_like(batch.pos)``). The layout enters
    only through per-edge features: the normalized endpoint positions plus
    ``edge_attr[:, :1]``. A shared MLP (``enet``) embeds these features. They
    then feed a stack of ``GNNLayer`` convolutions, a global pooling step, and
    an optional dense head.

    Args:
        conv: node feature widths of the GNN stack. ``conv[0]`` must equal the
            width of ``batch.pos``, because the input features are ``ones_like(pos)``.
        dense: output widths of the dense head applied after pooling. A falsy
            value (e.g. ``[]``) means no head (identity).
        shared_depth: number of ``DenseLayer``s in the shared edge-feature MLP.
        enet_depth: ``depth`` passed to each block's ``EdgeNet``.
        enet_width: width of the edge embedding and of the shared MLP.
        aggr: ``'add'``, ``'mean'`` or ``'max'``. Passed to every ``GNNLayer``
            and also selects the matching global pooling.
        root_weight: forwarded to every ``GNNLayer``.
        normalize: callable ``(pos, batch) -> pos`` applied before edge features
            are computed. Identity when omitted.

    Raises:
        ValueError: if ``aggr`` has no matching global pooling function.
    """
    def __init__(self, 
                 conv, 
                 dense,
                 shared_depth,
                 enet_depth,
                 enet_width,
                 aggr='add', 
                 root_weight=True,
                 normalize=None):
        super().__init__()
        # Shared edge-feature MLP: raw edge features -> `enet_width`. Every layer
        # except the last uses skip/bn/activation; the last one is plain.
        self.enet = nn.Sequential(*[
            DenseLayer(in_dim=in_d,
                       out_dim=out_d,
                       skip=nonlin,
                       bn=nonlin,
                       act=nonlin,
                       dp=None)
            for in_d, out_d, nonlin 
            in zip([self._get_feature_dim()] + [enet_width] * (shared_depth-1),
                   [enet_width] * shared_depth,
                   [True] * (shared_depth-1) + [False])     
        ])
        # Edge-conditioned GNN stack over widths conv[0] -> ... -> conv[-1];
        # the last block is plain (no skip/bn/activation).
        self.blocks = nn.ModuleList([
            GNNLayer(nfeat_dims=(in_d, out_d),
                     efeat_dim=enet_width,
                     edge_net=EdgeNet(nfeat_dims=(in_d, out_d), 
                                      efeat_dim=enet_width, 
                                      depth=enet_depth, 
                                      width=enet_width),
                     aggr=aggr,
                     dense=False,
                     skip=nonlin,
                     bn=nonlin,
                     act=nonlin,
                     root_weight=root_weight) 
            for in_d, out_d, nonlin 
            in zip(conv[:-1], conv[1:], [True] * (len(conv)-2) + [False])     
        ])
        # Graph-level pooling that matches the blocks' aggregation. An
        # unsupported `aggr` is rejected here rather than leaving `pool` as
        # None, which would only fail later inside `forward`.
        pools = {
            'mean': gnn.global_mean_pool,
            'add': gnn.global_add_pool,
            'max': gnn.global_max_pool,
        }
        if aggr not in pools:
            raise ValueError(f"Unsupported aggr {aggr!r}; expected one of {sorted(pools)}.")
        self.pool = pools[aggr]
        # Optional dense head on the pooled graph embedding (identity when
        # `dense` is falsy); again only the last layer is plain.
        self.dense = nn.Identity() if not dense else nn.Sequential(*[
            DenseLayer(in_dim=in_d,
                       out_dim=out_d,
                       skip=nonlin,
                       bn=nonlin,
                       act=nonlin,
                       dp=None)
            for in_d, out_d, nonlin 
            in zip(conv[-1:] + dense[:-1], dense, [True] * (len(dense)-1) + [False])     
        ])
        self.normalize = normalize or IdentityTransformation()
    
    def forward(self, batch):
        """Score each graph in ``batch`` from its current ``pos`` layout.

        Node embeddings are pooled per graph over ``batch.batch`` and then
        passed through the dense head.
        """
        # Constant node features: the layout only enters via the edge features.
        x = torch.ones_like(batch.pos)
        e = self.enet(self._get_features(self._get_edge_info(batch, layout='pos')))
        for block in self.blocks:
            x = block(x, e, batch)
        x = self.pool(x, batch.batch)
        x = self.dense(x)
        return x
    
    def _get_edge_info(self, batch, layout='gt_pos'):
        """Per-edge geometry of ``batch[layout]`` after ``self.normalize``.

        Returns a dict with:

        - ``src`` / ``dst``: endpoint positions from ``get_edges``.
        - ``v`` / ``u``: the outputs of ``l2_normalize(dst - src,
          return_norm=True)``, presumably the unit direction and the edge
          length.
        - ``d``: the first column of ``batch.edge_attr``.
        """
        pos = self.normalize(batch[layout].float(), batch)
        src, dst = get_edges(pos, batch)
        v, u = l2_normalize(dst - src, return_norm=True)
        d = batch.edge_attr[:, :1]
        return {
            "src": src,
            "dst": dst,
            "v": v,
            "u": u,
            "d": d,
        }

    def _get_features(self, edges):
        """Concatenate ``src``, ``dst`` and ``d`` into the raw edge feature matrix."""
        return torch.cat([edges['src'], edges['dst'], edges['d']], dim=1)
        
    def _get_feature_dim(self):
        """Width of ``_get_features`` output, measured on a dummy single-edge input."""
        return self._get_features({
            "src": torch.zeros(1, 2),
            "dst": torch.zeros(1, 2),
            "v": torch.zeros(1, 2),
            "u": torch.zeros(1, 1),
            "d": torch.zeros(1, 1),
        }).shape[-1]

In [19]:
class StressDiscriminator(nn.Module):
    """Non-learned critic that scores each layout by its negated stress.

    ``forward`` normalizes ``batch.pos`` and returns ``-Stress(reduce=None)`` of
    the result, so higher means lower stress.

    Args:
        normalize: callable ``(pos, batch) -> pos`` applied before computing
            stress. When omitted (or None), each instance gets its own
            ``CanonicalizationByStress()`` instead of sharing one default object
            created at definition time.
        **kwargs: accepted and ignored. Presumably this lets the class be
            constructed with a learned discriminator's keyword arguments
            (confirm).
    """
    def __init__(self, normalize=None, **kwargs):
        super().__init__()
        # NOTE(review): placeholder parameter, presumably so the module owns at
        # least one parameter (e.g. for .parameters()/device queries) — confirm.
        self.dummy = nn.Parameter(torch.zeros(1))
        self.normalize = normalize if normalize is not None else CanonicalizationByStress()
        self.stress = Stress(reduce=None)

    def forward(self, batch):
        return -self.stress(self.normalize(batch.pos, batch), batch)

In [20]:
def get_ckpt_epoch(folder, epoch):
    """Resolve a possibly relative checkpoint epoch for ``folder``.

    A non-negative ``epoch`` is returned unchanged. A negative value counts back
    from the newest checkpoint in ``folder``: the result is
    ``latest + epoch + 1``, so ``-1`` gives the latest epoch and ``-2`` gives
    ``latest - 1``. ``latest`` is 0 when the folder holds no checkpoints.

    The folder, including any missing parents, is created if it does not exist.
    Only file names containing ``epoch_<digits>.`` count as checkpoints; any
    other file is ignored.
    """
    os.makedirs(folder, exist_ok=True)
    if epoch >= 0:
        return epoch
    # Raw string: "\d" in a normal string literal is an invalid escape sequence.
    matches = (re.search(r'(?<=epoch_)(\d+)(?=\.)', name) for name in os.listdir(folder))
    saved_epochs = [int(m.group(1)) for m in matches if m is not None]
    last_epoch = max(saved_epochs) if saved_epochs else 0
    return last_epoch + epoch + 1

def start_epoch():
    """Epoch this run resumes from.

    ``config.epoch.start`` is resolved against the run's own checkpoint folder
    (-1 gives the latest saved epoch, or 0 when there is none).
    """
    ckpt_folder, requested = paths.checkpoints(), config.epoch.start
    return get_ckpt_epoch(ckpt_folder, requested)

In [21]:
# Build the generator and its optimizer/scheduler, then restore weights and
# optimizer state. Resume this run's latest checkpoint when one exists;
# otherwise optionally warm-start from a pretrained run.
generator = Generator(**config.generator.params[...]).to(device)
gen_start_epoch = start_epoch()
# The initial lr is pre-decayed as if `gen_start_epoch` scheduler steps had
# already run (a restored optimizer state, if any, overrides it).
generator_optimizer = config.generator.optim(generator.parameters(), lr=config.generator.lr.initial * config.generator.lr.decay ** gen_start_epoch)
generator_scheduler = torch.optim.lr_scheduler.ExponentialLR(generator_optimizer, gamma=config.generator.lr.decay)
if gen_start_epoch != 0:
    gen_ckpt_dir = paths.checkpoints()
    gen_ckpt_epoch = gen_start_epoch
elif config.generator.pretrained.name is not None and config.generator.pretrained.epoch != 0:
    # Pretrained checkpoints live in the pretrained run's folder. Both the epoch
    # lookup and the file paths must therefore use `paths.gen_pretrain()`.
    gen_ckpt_dir = paths.gen_pretrain()
    gen_pretrained_epoch = get_ckpt_epoch(gen_ckpt_dir, config.generator.pretrained.epoch)
    gen_ckpt_epoch = gen_pretrained_epoch
else:
    gen_ckpt_epoch = None
if gen_ckpt_epoch is not None:
    # Load generator
    gen_ckpt_file = f"{gen_ckpt_dir}/gen_epoch_{gen_ckpt_epoch}.pt"
    print(f"Loading from {gen_ckpt_file}...")
    generator.load_state_dict(torch.load(gen_ckpt_file, map_location=torch.device(device)))
    # Load generator optimizer
    gen_optim_ckpt_file = f"{gen_ckpt_dir}/gen_optim_epoch_{gen_ckpt_epoch}.pt"
    print(f"Loading from {gen_optim_ckpt_file}...")
    generator_optimizer.load_state_dict(torch.load(gen_optim_ckpt_file, map_location=torch.device(device)))

Loading from artifacts/checkpoints/GAN(gan=rgan,data=best(xing,random),canon=stress,conv=3,share=6,embed=10)/gen_epoch_663.pt...
Loading from artifacts/checkpoints/GAN(gan=rgan,data=best(xing,random),canon=stress,conv=3,share=6,embed=10)/gen_optim_epoch_663.pt...


In [22]:
# Build the discriminator and its optimizer/scheduler, then restore weights and
# optimizer state. Resume this run's latest checkpoint when one exists;
# otherwise optionally warm-start from a pretrained run.
discriminator = EdgeFeatureDiscriminator(**config.discriminator.params[...]).to(device)
dis_start_epoch = start_epoch()
# The initial lr is pre-decayed as if `dis_start_epoch` scheduler steps had
# already run (a restored optimizer state, if any, overrides it).
discriminator_optimizer = config.discriminator.optim(discriminator.parameters(), lr=config.discriminator.lr.initial * config.discriminator.lr.decay ** dis_start_epoch)
discriminator_scheduler = torch.optim.lr_scheduler.ExponentialLR(discriminator_optimizer, gamma=config.discriminator.lr.decay)
if dis_start_epoch != 0:
    dis_ckpt_dir = paths.checkpoints()
    dis_ckpt_epoch = dis_start_epoch
elif config.discriminator.pretrained.name is not None and config.discriminator.pretrained.epoch != 0:
    # Pretrained checkpoints live in the pretrained run's folder. Both the epoch
    # lookup and the file paths must therefore use `paths.dis_pretrain()`.
    dis_ckpt_dir = paths.dis_pretrain()
    dis_pretrained_epoch = get_ckpt_epoch(dis_ckpt_dir, config.discriminator.pretrained.epoch)
    dis_ckpt_epoch = dis_pretrained_epoch
else:
    dis_ckpt_epoch = None
if dis_ckpt_epoch is not None:
    # Load discriminator
    dis_ckpt_file = f"{dis_ckpt_dir}/dis_epoch_{dis_ckpt_epoch}.pt"
    print(f"Loading from {dis_ckpt_file}...")
    discriminator.load_state_dict(torch.load(dis_ckpt_file, map_location=torch.device(device)))
    # Load discriminator optimizer
    dis_optim_ckpt_file = f"{dis_ckpt_dir}/dis_optim_epoch_{dis_ckpt_epoch}.pt"
    print(f"Loading from {dis_optim_ckpt_file}...")
    discriminator_optimizer.load_state_dict(torch.load(dis_optim_ckpt_file, map_location=torch.device(device)))

Loading from artifacts/checkpoints/GAN(gan=rgan,data=best(xing,random),canon=stress,conv=3,share=6,embed=10)/dis_epoch_663.pt...
Loading from artifacts/checkpoints/GAN(gan=rgan,data=best(xing,random),canon=stress,conv=3,share=6,embed=10)/dis_optim_epoch_663.pt...


# Train

In [23]:
# Criteria and helpers shared by the training loop.
stress_criterion = StressDiscriminator().to(device)  # negated stress; logged as 'stress' in train_gen
val_criterion = Stress(reduce=None)  # default criterion of validate()
xing_criterion = Xing(reduce=None)
dis_convert = DiscriminatorDataConverter(complete_graph=config.discriminator.complete, normalize=config.discriminator.params.normalize)
# NOTE(review): currently unused — all tensorboard logging in the loop is commented out.
tensorboard = SummaryWriter(log_dir=paths.tensorboard())
# Training continues with the epoch after the latest checkpoint.
epoch = start_epoch() + 1

In [24]:
def gradient_penalty(interpolated, discriminator, weight=10):
    """WGAN-GP style gradient penalty of ``discriminator`` at ``interpolated``.

    Differentiates the discriminator scores with respect to
    ``interpolated.pos``. Per graph, it sums the node-wise squared gradient
    norms (``global_add_pool`` over ``interpolated.batch``) and takes the square
    root. It returns ``weight * mean((norm - 1) ** 2)``.

    Args:
        interpolated: batch whose ``pos`` is differentiated. ``requires_grad``
            is enabled on it in place.
        discriminator: callable mapping the batch to scores.
        weight: penalty coefficient.
    """
    interpolated.pos.requires_grad_()
    prob_interpolated = discriminator(interpolated)
    gradients = autograd.grad(outputs=prob_interpolated, 
                              inputs=interpolated.pos,
                              grad_outputs=torch.ones_like(prob_interpolated),
                              create_graph=True, 
                              retain_graph=True, 
                              allow_unused=True)[0]
    if gradients is None:
        # With allow_unused=True, autograd returns None when the scores do not
        # depend on pos at all; that is a zero gradient.
        gradients = torch.zeros_like(interpolated.pos)
    # Pool with the interpolated batch's own node-to-graph index, not a global
    # `batch` that happens to be in scope.
    gradients_norm = torch.sqrt(gnn.global_add_pool(gradients.square().sum(dim=1), interpolated.batch) + 1e-8)
    return weight * ((gradients_norm - 1) ** 2).mean()

In [25]:
def get_gp_loss(batch, fake_pos, weight):
    """Gradient-penalty loss for the discriminator, or ``0`` when ``weight <= 0``.

    The penalty is evaluated on ``dis_convert(batch, fake_pos, mix)``, where
    ``mix`` is drawn fresh from ``random.random()``. ``mix`` is presumably the
    real/fake interpolation coefficient (confirm).
    """
    if not weight > 0:
        return 0
    mix = random.random()
    interp = dis_convert(batch, fake_pos, mix)
    return gradient_penalty(interp, discriminator, weight).mean()

In [26]:
def get_sgan_loss(batch, fake_pos, mode='discriminator'):
    """Standard (cross-entropy) GAN loss over paired real/fake discriminator scores.

    The discriminator scores the merged real+fake batch. ``view(2, -1).T``
    puts real scores in column 0 and fake scores in column 1, assuming
    ``merge_batch`` places the real graphs first. Each row is treated as 2-class
    logits. The target class is 0 (real) when training the discriminator and
    1 (fake) when training the generator.

    Raises:
        Exception: for any other ``mode``.
    """
    real = dis_convert(batch)
    fake = dis_convert(batch, fake_pos)
    pred = discriminator(merge_batch(real, fake)).view(2, -1).T
    if mode == 'discriminator':
        target = 0
    elif mode == 'generator':
        target = 1
    else:
        raise Exception(f"Unknown mode: {mode!r}")
    # Labels must live on the logits' device; CPU labels fail against CUDA logits.
    label = torch.full((pred.shape[0],), target, dtype=torch.long, device=pred.device)
    return nn.CrossEntropyLoss()(pred, label)

In [27]:
def get_rgan_loss(batch, fake_pos, mode='discriminator'):
    """Relativistic GAN loss on paired real/fake discriminator scores.

    Discriminator: ``mean(-logsigmoid(real - fake))``.
    Generator: ``mean(-logsigmoid(fake - real))``.
    Any other ``mode`` raises ``Exception`` after the discriminator forward pass.
    """
    merged = merge_batch(dis_convert(batch), dis_convert(batch, fake_pos))
    # Row 0 of view(2, -1) holds the real graphs' scores and row 1 the fake ones
    # (real graphs are assumed to come first in the merged batch).
    real_pred, fake_pred = discriminator(merged).view(2, -1)
    if mode == 'generator':
        margin = fake_pred - real_pred
    elif mode == 'discriminator':
        margin = real_pred - fake_pred
    else:
        raise Exception
    return (-F.logsigmoid(margin)).mean()

In [28]:
def get_wgan_loss(batch, fake_pos, mode='discriminator'):
    """Wasserstein critic loss on paired real/fake discriminator scores.

    The discriminator minimizes ``mean(fake - real)``; the generator minimizes
    ``mean(real - fake)``. Any other ``mode`` raises ``Exception``.
    """
    merged = merge_batch(dis_convert(batch), dis_convert(batch, fake_pos))
    # Row 0: real graphs' scores, row 1: fake graphs' scores.
    real_pred, fake_pred = discriminator(merged).view(2, -1)
    if mode == 'discriminator':
        return (fake_pred - real_pred).mean()
    if mode == 'generator':
        return (real_pred - fake_pred).mean()
    raise Exception

In [29]:
def get_ragan_loss(batch, fake_pos, mode='discriminator'):
    """Relativistic average GAN loss on paired real/fake discriminator scores.

    Each score is compared against the batch mean of the opposite class. The
    class being pushed up is real for the discriminator and fake for the
    generator. Any other ``mode`` raises ``Exception``.
    """
    merged = merge_batch(dis_convert(batch), dis_convert(batch, fake_pos))
    # Row 0: real graphs' scores, row 1: fake graphs' scores.
    real_pred, fake_pred = discriminator(merged).view(2, -1)
    if mode == 'discriminator':
        favored, other = real_pred, fake_pred
    elif mode == 'generator':
        favored, other = fake_pred, real_pred
    else:
        raise Exception
    losses = - F.logsigmoid(favored - other.mean()) - F.logsigmoid(favored.mean() - other)
    return losses.mean()

In [30]:
def get_dgdv2_loss(batch, fake_pos, mode='discriminator'):
    """Loss computed from the discriminator's scores on the fake layouts only.

    Discriminator mode returns ``-criterion(pred, get_gt(batch, fake_pos))``.
    Generator mode returns ``-pred.sum()``, computed as column sums that are then summed.
    Any other ``mode`` raises ``Exception``.

    NOTE(review): ``get_gt`` and ``criterion`` are not defined anywhere in this
    notebook. They presumably come from ``from deepgd import *``; confirm
    before selecting ``gan_flavor='dgdv2'``.
    """
    fake = dis_convert(batch, fake_pos)
    pred = discriminator(fake)
    if mode == 'discriminator':
        gt = get_gt(batch, fake_pos)
        loss = criterion(pred, gt)
    elif mode == 'generator':
        losses = pred.sum(dim=0)
#         losses = torch.tensor(config.importance).to(device) * losses #/ losses.detach()
        loss = losses.sum()
    else:
        raise Exception
    # Negated so that minimizing the returned value maximizes `loss`.
    return -loss

In [31]:
def get_gan_loss(batch, fake_pos, mode='discriminator'):
    """Dispatch to the loss function selected by ``config.gan_flavor``.

    Supported flavors are 'sgan', 'wgan', 'rgan', 'ragan' and 'dgdv2'. Any other
    value raises ``KeyError``.
    """
    loss_fn = {
        "sgan": get_sgan_loss,
        "wgan": get_wgan_loss,
        "rgan": get_rgan_loss,
        "ragan": get_ragan_loss,
        "dgdv2": get_dgdv2_loss,
    }[config.gan_flavor]
    return loss_fn(batch, fake_pos, mode)

In [32]:
def train_dis(batch, epoch):
    """Run one discriminator update on ``batch`` with the generator frozen.

    Generates fake layouts and optionally adds Gaussian noise whose std decays
    as ``std * decay ** epoch``. Steps the discriminator on the GAN loss chosen
    by ``config.gan_flavor``. When ``config.gp_weight > 0`` it takes a second,
    separate optimizer step on the gradient penalty. Updates the HUD and
    progress bar.
    """
    generator.requires_grad_(False)
    discriminator.zero_grad()
    generator_output = generator(batch)
    if config.discriminator.noise.std > 0:
        generator_output = generator_output + torch.randn_like(generator_output) * config.discriminator.noise.std * config.discriminator.noise.decay ** epoch
    discriminator_loss = get_gan_loss(batch, generator_output, mode='discriminator')
    
    # train discriminator
    discriminator_loss.backward()
    discriminator_optimizer.step()

    # gradient penalty
    # NOTE(review): the penalty gets its own optimizer step instead of being added
    # to discriminator_loss (so AdamW updates twice per batch); confirm intent.
    if config.gp_weight > 0:
        discriminator.zero_grad()
        gp_loss = get_gp_loss(batch, generator_output, config.gp_weight)
        gp_loss.backward()
        discriminator_optimizer.step()

    hud['dis_loss'] = format(discriminator_loss.item(), '.2e')
    pbar().update()

def train_gen(batch, epoch):
    """Run one generator update on ``batch`` against the current discriminator.

    Re-enables generator gradients and adds the same decaying noise used in
    ``train_dis``. Steps the generator optimizer on the generator-side GAN
    loss. Afterwards, without gradients, it records on the HUD the mean
    ``stress_criterion`` value (negated stress) and the mean critic score of the
    generated layouts.
    """
    generator.requires_grad_(True)
    generator.zero_grad()
    # generator_loss.backward() below also accumulates discriminator grads; they
    # are not stepped here and are cleared by the next train_dis call.
    discriminator.zero_grad()
    generator_output = generator(batch)
    if config.discriminator.noise.std > 0:
        generator_output = generator_output + torch.randn_like(generator_output) * config.discriminator.noise.std * config.discriminator.noise.decay ** epoch
    generator_loss = get_gan_loss(batch, generator_output, mode='generator') 
    
    #train generator
    generator_loss.backward()
    generator_optimizer.step()

    with torch.no_grad():
        dis_batch = dis_convert(batch, generator_output)
        stress = stress_criterion(dis_batch).mean()  # negated stress (StressDiscriminator)
        critic = discriminator(dis_batch).mean()
    hud.append({'gen_loss': format(generator_loss.item(), '.2e'),
                'stress': format(stress.item(), '.2e'),
                'critic': format(critic.item(), '.2e')})
    pbar().update()

def cuda_memsafe_map(fn, *iterables, summary=False):
    """Lazily yield ``fn(*items)`` for each tuple of ``zip(*iterables)``, skipping OOM items.

    A call that raises an out-of-memory ``RuntimeError`` is reported and
    skipped, so nothing is yielded for it. Any other exception propagates. The
    CUDA cache is emptied after every item.

    Args:
        fn: function applied to each tuple of items.
        *iterables: iterables zipped together.
        summary: if True, print how many items failed once iteration ends.

    Yields:
        The results of the successful calls, in input order.
    """
    total, failed = 0, 0
    for items in zip(*iterables):
        try:
            result = fn(*items)
        except RuntimeError as err:
            # Only swallow out-of-memory errors. Shape/dtype/device bugs also
            # surface as RuntimeError and must not be hidden as "skipped batches".
            if 'out of memory' not in str(err):
                raise
            print('CUDA memory overflow! Skip batch...')
            del items  # drop the reference to the failed batch before emptying the cache
            failed += 1
        else:
            # Yield outside the try so errors thrown into the generator at this
            # point are not mistaken for a failed batch.
            yield result
        torch.cuda.empty_cache()
        total += 1
    if summary:
        print(f'Iteration finished. {failed} out of {total} failed!')
    
def validate(model, data_loader, criterion=val_criterion):
    """Evaluate ``model`` on ``data_loader`` against each batch's ``gt_pos``.

    The prediction and ``gt_pos`` are both canonicalized with
    ``CanonicalizationByStress`` before ``criterion(layout, batch)`` is applied.

    Returns:
        ``(mean_loss, mean_spc)`` averaged over the evaluated batches. The spc
        is ``(loss - gt_loss) / max(loss, gt_loss, 1e-5)`` and is positive when
        the prediction scores higher than the ground truth. Returns
        ``(nan, nan)`` when no batch could be evaluated (empty loader, or every
        batch hit CUDA OOM).
    """
    def val_one_batch(batch):
        batch = preprocess_batch(model, batch)
        pred = CanonicalizationByStress()(model(batch), batch)
        gt = CanonicalizationByStress()(batch.gt_pos, batch)
        loss = criterion(pred, batch)
        gt_loss = criterion(gt, batch)
        # Denominator floored at 1e-5 to avoid 0/0 when both losses vanish.
        spc = (loss - gt_loss) / torch.maximum(torch.maximum(loss, gt_loss), torch.ones_like(loss)*1e-5)
        return loss.mean().item(), spc.mean().item()
    results = list(cuda_memsafe_map(val_one_batch, data_loader))
    if not results:
        # zip(*[]) would otherwise fail with a confusing unpacking error.
        return float('nan'), float('nan')
    loss_all, spc_all = zip(*results)
    return np.mean(loss_all), np.mean(spc_all)

def log(msg):
    """Prefix ``msg`` with the global ``epoch`` and write it to the run log and the log tab."""
    msg = f"[{epoch:03}] {msg}"
    # Close the file on every call instead of leaking one handle per message.
    with open(paths.log(), "a") as log_file:
        print(msg, file=log_file)
    with log_out: 
        print(msg)

# Append a header line with the run's `config.link` to the log file; the context
# manager closes the handle.
with open(paths.log(), "a") as log_file:
    print(f"{'='*10} {config.link} {'='*10}", file=log_file)
# Training UI: a tab bar with progress/HUD status, a sample-layout plot and the log.
hud = Hud()
pbar = Wrapper(tqdm, total=len(train_loader)*2, smoothing=0)
plot_out = Output()
log_out = Output()
tabs = {"status": VBox([pbar, hud]), 
        "plot": HBox([plot_out], layout=Layout(height='500px', overflow_y='auto')),
        "log": HBox([log_out], layout=Layout(height='500px', overflow_y='auto'))}
tab_bar = Tab(children=list(tabs.values()))
for tab_idx, tab_name in enumerate(tabs):
    tab_bar.set_title(tab_idx, tab_name)
display(tab_bar)
# Main loop. Every `log_interval` epochs, validate and plot the model as it
# stands. Then train one epoch, step both lr schedulers, and checkpoint. Stops
# when `epoch == config.epoch.end`; it never stops when `end` is None.
while True:
    if epoch % config.log_interval == 0:
        generator.eval()
        with torch.no_grad():
            val_stress, val_stress_spc = validate(model=generator, data_loader=val_loader)
            val_xing, val_xing_spc = validate(model=generator, data_loader=val_loader, criterion=xing_criterion)
            with plot_out:
                fig = plt.figure()
                graph_vis(G_list[11100], generator(make_batch(data_list[11100]).to(device)).cpu())
                plt.show()
        # tensorboard.add_scalars('loss', {'train': train_loss, 
        #                                  'validation': val_loss}, epoch)
        # for i, fn in enumerate(loss_fns):
        #     tensorboard.add_scalars(type(fn).__name__, {'train': train_loss_comp[i].item(), 
        #                                           'validation': val_loss_comp[i].item()}, epoch)
        hud.append({
            'val_stress': format(val_stress, '.2f'),
            'val_stress_spc': format(val_stress_spc, '.2%'),
            'val_xing': format(val_xing, '.2f'),
            'val_xing_spc': format(val_xing_spc, '.2%'),
        })
        log(f"stress={hud.data['val_stress']}({hud.data['val_stress_spc']}) xing={hud.data['val_xing']}({hud.data['val_xing_spc']})")
        
    # handle.update(tab_bar)
    pbar().reset()
    pbar().set_description(desc=f"[epoch {epoch}/{config.epoch.end}]")
    hud(title=f"epoch {epoch}")
    generator.train()
    discriminator.train()
    # proper: proper layout
    for _ in range(config.discriminator.repeat):
        for batch in train_loader:
            train_dis(batch, epoch)
            if config.alternate == 'iteration':
                train_gen(batch, epoch)

    if config.alternate == 'epoch':
        for batch in train_loader:
            train_gen(batch, epoch)

    discriminator_scheduler.step()
    generator_scheduler.step()
    
    if epoch % config.log_interval == 0:
        torch.save(generator.state_dict(), f"{paths.checkpoints()}/gen_epoch_{epoch}.pt")
        torch.save(generator_optimizer.state_dict(), f"{paths.checkpoints()}/gen_optim_epoch_{epoch}.pt")
        torch.save(discriminator.state_dict(), f"{paths.checkpoints()}/dis_epoch_{epoch}.pt")
        torch.save(discriminator_optimizer.state_dict(), f"{paths.checkpoints()}/dis_optim_epoch_{epoch}.pt")

    if epoch == config.epoch.end:
        break
    epoch += 1

Tab(children=(VBox(children=(Wrapper(), Hud())), HBox(children=(Output(),), layout=Layout(height='500px', over…

KeyboardInterrupt: 

# Test

In [33]:
# Reload a generator checkpoint of this run for evaluation (-1 = latest saved epoch).
# NOTE(review): `config.test.epoch` exists but is not used here.
test_epoch = -1

test_generator = Generator(**config.generator.params[...]).to(device)
test_ckpt_epoch = get_ckpt_epoch(paths.checkpoints(), test_epoch)
test_ckpt_file = f"{paths.checkpoints()}/gen_epoch_{test_ckpt_epoch}.pt"
print(f"Loading from {test_ckpt_file}...")
test_generator.load_state_dict(torch.load(test_ckpt_file, map_location=torch.device(device)))

Loading from artifacts/checkpoints/GAN(gan=rgan,data=best(xing,random),canon=stress,conv=3,share=6,embed=10)/gen_epoch_665.pt...


<All keys matched successfully>

In [34]:
# Rotation applied to both the predicted and ground-truth layouts before plotting.
rotate = RotateByPrincipalComponents()
def test_callback(*, idx, pred, metrics):
    """Plot the predicted and ground-truth layouts of graph ``idx`` with their metrics.

    Intended as the ``callback`` argument of ``test``. ``metrics`` must provide
    'stress', 'stress_spc', 'xing', 'xing_spc', 'gt_stress' and 'gt_xing'.
    """
    # graph_vis(G_list[idx], pred, file_name=f"{paths.visualization()}/{idx}_{metrics['stress']:.2f}_{metrics['resolution_score']:.2f}.png")
    pred = rotate(torch.tensor(pred), data_list[idx])
    graph_vis(G_list[idx], pred)
    plt.title(f"[pred] idx: {idx}, stress: {metrics['stress']:.2f}({metrics['stress_spc']:.2%}), xing: {metrics['xing']:.2f}({metrics['xing_spc']:.2%})")
    plt.show()
    gt_pos = rotate(data_list[idx].gt_pos, data_list[idx])
    graph_vis(G_list[idx], gt_pos, node_color='orange')
    plt.title(f"[gt] idx: {idx}, stress: {metrics['gt_stress']:.2f}, xing: {metrics['gt_xing']:.2f}")
    plt.show()
    
# Evaluate the trained generator on the held-out [10000, 11000) split and persist the metrics.
test_metrics = test(model=test_generator, 
                    criteria_list=[], 
                    dataset=data_list, 
                    idx_range=range(10000, 11000), 
#                     callback=test_callback,
                    gt_pos=None)
# Close the file deterministically so the pickle is fully flushed to disk.
with open(paths.metrics("test"), "wb") as metrics_file:
    pickle.dump(test_metrics, metrics_file)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [35]:
# Summarize the test-split metrics of the trained generator.
metrics = test_metrics
# (display label, key in `metrics`) pairs. Values are looked up by key rather
# than by the dict's insertion order minus its last entry, so a reordered or
# extended metrics dict cannot silently mislabel the columns.
metric_columns = [
    ('stress', 'stress'),
    ('stress_spc', 'stress_spc'),
    ('xing', 'xing'),
    ('xing_spc', 'xing_spc'),
    ('l1_angle', 'l1_angle'),
    ('l1_angle_spc', 'l1_angle_spc'),
    ('edge', 'edge'),
    ('edge_spc', 'edge_spc'),
    ('ring', 'ring'),
    ('ring_spc', 'ring_spc'),
    ('tsne', 'tsne'),
    ('tsne_spc', 'tsne_spc'),
    ('reso_score', 'resolution_score'),
    ('min_angle', 'min_angle'),
]
for label, key in metric_columns:
    print(f'{label}:', metrics[key].mean())
columns = [label for label, _ in metric_columns]
df = pd.DataFrame([{label: metrics[key].mean().item() for label, key in metric_columns}])
df.style.format({c: "{:.2f}" for c in columns if 'spc' not in c} | {c: "{:.2%}" for c in columns if 'spc' in c})

stress: tensor(350.2226)
stress_spc: tensor(0.0429, dtype=torch.float64)
xing: tensor(27.9300)
xing_spc: tensor(0.2003, dtype=torch.float64)
l1_angle: tensor(79.6410)
l1_angle_spc: tensor(0.0465, dtype=torch.float64)
edge: tensor(0.2071)
edge_spc: tensor(0.1398, dtype=torch.float64)
ring: tensor(260.3856)
ring_spc: tensor(0.0066, dtype=torch.float64)
tsne: tensor(0.2077)
tsne_spc: tensor(0.0495, dtype=torch.float64)
reso_score: tensor(0.5955)
min_angle: tensor(4.5453)


Unnamed: 0,stress,stress_spc,xing,xing_spc,l1_angle,l1_angle_spc,edge,edge_spc,ring,ring_spc,tsne,tsne_spc,reso_score,min_angle
0,350.22,4.29%,27.93,20.03%,79.64,4.65%,0.21,13.98%,260.39,0.66%,0.21,4.95%,0.6,4.55


In [36]:
# Per-baseline test metrics keyed by layout method (rebinds `metrics`).
metrics = {}

In [37]:
# NOTE(review): rebinds `methods` (the full method list defined earlier); only
# ForceAtlas2 is used as the comparison baseline here.
methods = ['fa2']

In [38]:
# Evaluate the trained generator with each baseline's precomputed layouts as ground truth.
for m in methods:
    metrics[m] = test(model=test_generator, 
                      criteria_list=[], 
                      dataset=data_list, 
                      idx_range=range(10000, 11000), 
                      callback=None,
                      gt_pos=load_pos(m))

  0%|          | 0/1000 [00:00<?, ?it/s]

In [39]:
# For each baseline: the mean of every metric except the last key, in key order.
mean_metrics = {
    method_name: [method_metrics[k].mean().item() for k in list(method_metrics)[:-1]]
    for method_name, method_metrics in metrics.items()
}

In [40]:
# Row labels of the baseline comparison table. They must match the order of
# the values in each `mean_metrics` list.
columns = [
    'stress',
    'stress_spc',
    'xing',
    'xing_spc',
    'l1_angle',
    'l1_angle_spc',
    'edge',
    'edge_spc',
    'ring',
    'ring_spc',
    'tsne',
    'tsne_spc',
    'reso_score',
    'min_angle'
]
# NOTE(review): set_axis(columns) labels rows by position, so this relies on the
# metrics dict's key order matching `columns`.
df = pd.DataFrame(mean_metrics).set_axis(columns).T
df.style.format({c: "{:.2f}" for c in columns if 'spc' not in c} | {c: "{:.2%}" for c in columns if 'spc' in c})

Unnamed: 0,stress,stress_spc,xing,xing_spc,l1_angle,l1_angle_spc,edge,edge_spc,ring,ring_spc,tsne,tsne_spc,reso_score,min_angle
fa2,350.23,-0.21%,27.94,9.04%,79.63,8.74%,0.21,-0.23%,260.39,2.39%,0.21,7.52%,0.6,4.53


# Large Graphs (Scalability Evaluation)

In [None]:
# Metadata of the large-graph scalability benchmark (one row per graph; the
# 'file', 'name' and 'n' columns are used below).
scalability = pd.read_csv("/__artifacts__/data/scalability.csv", index_col="index")
scalability

In [None]:
# Helpers for the large-graph evaluation: stress-based rescaling, the stress
# metric, and principal-component rotation before saving/plotting.
rescale = CanonicalizationByStress()
stressfn = Stress()
rotate = RotateByPrincipalComponents()

In [None]:
# Run the training-loop `generator` (not `test_generator`) on every large graph in
# `scalability`. Compare its stress with the ground-truth layout, save the
# predicted positions, and plot both layouts under /__artifacts__/gan_result/.
stress_list = []
spc_list = []
pmds_list = np.load("layouts/new_large_graph/pmds.npy", allow_pickle=True)
gviz_list = np.load("layouts/new_large_graph/gviz.npy", allow_pickle=True)
with torch.no_grad():
    for idx, col in tqdm(scalability.iterrows(), total=len(scalability)):
        # if idx not in [406, 516]: continue
        torch.cuda.empty_cache()
        G = load_mtx(col['file'])
        G.remove_edges_from(nx.selfloop_edges(G))
        data = generate_data_list(G, 
                                sparse=data_config.sparse, 
                                pivot_mode=data_config.pivot,
                                init_mode=data_config.init,
                                edge_index=data_config.edge.index,
                                edge_attr=data_config.edge.attr,
                                pmds_list=pmds_list[idx],
                                gviz_list=gviz_list[idx],
                                device=device)
        batch = Batch.from_data_list([data]).to(device)
        # generator.train()
        # generator(batch)
        generator.eval()
        pred = generator(batch)
        pos = rotate(rescale(pred, batch), batch)
        gt = rotate(rescale(batch.gt_pos, batch), batch)
        stress = stressfn(pos, batch).item()
        gt_stress = stressfn(gt, batch).item()
        # Relative stress difference vs. ground truth (positive = higher stress).
        # NOTE(review): division by zero if both stresses are 0.
        spc = (stress - gt_stress) / np.maximum(stress, gt_stress)
        stress_list.append(stress)
        spc_list.append(spc)

        np.save(f"/__artifacts__/gan_result/data/scalability_{idx}.npy", pos.cpu().numpy())
        # NOTE(review): `labels` and the font_* options have no effect while with_labels=False.
        graph_attr = dict(node_size=1, 
                        with_labels=False, 
                        labels=dict(zip(list(G.nodes), map(lambda n: n if type(n) is int else n[1:], list(G.nodes)))),
                        font_color="white", 
                        font_weight="bold",
                        font_size=12,
                        width=0.1)

        # gt_pos = pickle.load(open(f"/__artifacts__/data/scalability_{idx}_gt.pkl", "rb"))

        # NOTE(review): output folders are not created here, and figures are not
        # closed explicitly (plt.close) inside this loop.
        plt.figure(figsize=[12, 9])
        nx.draw(G, pos=gt.cpu().numpy(), node_color='orange', **graph_attr)
        plt.title(f"neato: large_{idx}")
        plt.axis("equal")
        plt.savefig(f"/__artifacts__/gan_result/output/{idx}_{col['name']}_{col['n']}_{spc}_nx.png", dpi=300)
        plt.show()

        plt.figure(figsize=[12, 9])
        graph_vis(G, pos.cpu().numpy(), **graph_attr)
        plt.title(f"dgd: large_{idx}, spc={spc:.2%}")
        plt.axis("equal")
        plt.savefig(f"/__artifacts__/gan_result/output/{idx}_{col['name']}_{col['n']}_{spc}_dgd.png", dpi=300)
        plt.show()