# Import libraries

In [1]:
%load_ext autoreload
%autoreload 2
from deepgd import *

# Config

In [2]:
# Index of the CUDA device to use; the "Set globals" cell falls back to CPU when this is None
# or when CUDA is unavailable.
cuda_idx = 0
# Single transform instance shared by the generator and discriminator params ("normalize").
canonicalization = IdentityTransformation()
# Experiment configuration. `name` is interpolated into every artifact path built in `paths`
# below, and must not contain spaces (checked in the next cell).
config = StaticConfig({
    "name": 'deepgd(conv=gat_simple(self_loop=false,heads=8),lr=1e-2)',
    "uid": None,
    "link": None,
    "generator": {
        # Unpacked as keyword arguments into Generator(...).
        "params": {
            "num_blocks": 9,
            "normalize": canonicalization
        },
        # Optional warm start: name of another run plus the epoch to take (-1 resolves to the latest).
        "pretrained": {
            "name": None,
            "epoch": -1,
        },
        "optim": torch.optim.AdamW,
        # Learning rate is initial * decay ** epoch (ExponentialLR with gamma=decay).
        "lr" : {
            "initial": 1e-2,
            "decay": 0.99,
        },
    },
    "discriminator": {
        # Unpacked as keyword arguments into StressDiscriminator(...), which only reads `normalize`
        # and accepts the remaining keys through **kwargs.
        "params": {
            "conv": [2, 16, 16, 16],
            "dense": [2],
            "shared_depth": 10,
            "enet_depth": 6,
            "enet_width": 64,
            "aggr": "add",
            "normalize": canonicalization
        },
        "pretrained": {
            "name": None,
            "epoch": -1,
        },
        "optim": torch.optim.AdamW,
        "lr" : {
            "initial": 1e-3,
            "decay": 0.99,
        },
        # Gaussian noise added to generator output during training: std * decay ** epoch (0 disables it).
        "noise": {
            "std": 0,
            "decay": 0.95,
        },
        "repeat": 1,
        "complete": True,
        "adaptive": True
    },
    # "epoch": the generator is trained once per epoch over the whole train loader.
    "alternate": "epoch",
    "batchsize": 64,
    # start=-1 resumes from the latest checkpoint; end=None means the loop never stops by itself.
    "epoch": {
        "start": -1,
        "end": None,
    },
    "log_interval": 1,
    "test": {
        "name": "test",
        "epoch": -1,
    },
    # Key into the loss dispatch table in get_gan_loss.
    "gan_flavor": "dgdv2",
    "gp_weight": 0,
})
# Options forwarded to generate_data_list in the "Load data" cell.
data_config = StaticConfig({
    "sparse": False,
    "pivot": None,
    "init": "pmds",
    "edge": {
        "index": "full_edge_index",
        "attr": "full_edge_attr",
    },
})
# Mapping of loss object -> weight. NOTE(review): only referenced from commented-out code in this notebook.
loss_fns = {
    Stress(): 1
}
# NOTE(review): not referenced anywhere else in the visible notebook — confirm whether still needed.
ctrler_params = {
    "tau": 0.95,
    "beta": 1,
    "exploit_rate": 0.5,
    "warmup": 2,
}
# Artifact locations. Entries are lambdas so they are resolved lazily against the current
# `paths.root` / `config` values at call time.
paths = StaticConfig({
    "root": "artifacts",
    "checkpoints": lambda: f"{paths.root}/checkpoints/{config.name}",
    "gen_pretrain": lambda: f"{paths.root}/checkpoints/{config.generator.pretrained.name}",
    "dis_pretrain": lambda: f"{paths.root}/checkpoints/{config.discriminator.pretrained.name}",
    "tensorboard": lambda: f"{paths.root}/tensorboards/{config.name}",
    "visualization": lambda: f"{paths.root}/visualizations/{config.name}_{config.test.name}",
    "log": lambda: f"{paths.root}/logs/{config.name}.log",
    "metrics": lambda suffix: f"{paths.root}/metrics/{config.name}_{suffix}.pickle",
})

In [3]:
# Fail fast on model names containing spaces: the name is embedded in every artifact path.
if " " in config.name:
    raise Exception("Space is not allowed in model name.")

# Prepare

## Get log command

In [4]:
# Print a ready-to-paste shell command that follows this run's log file.
print(f"cd {os.getcwd()} && tail -n1000 -f '{paths.log()}'")

cd /users/PAS0027/osu10203/deepgd && tail -n1000 -f 'artifacts/logs/deepgd(conv=gat_simple(self_loop=false,heads=8),lr=1e-2).log'


In [5]:
# Print a ready-to-paste shell command that uploads this run's TensorBoard directory.
print(f"tensorboard dev upload --logdir '{paths.tensorboard()}'")

tensorboard dev upload --logdir 'artifacts/tensorboards/deepgd(conv=gat_simple(self_loop=false,heads=8),lr=1e-2)'


## Set globals

In [6]:
# Choose the torch device. On GPU, also open an NVML handle for the same device index
# (`cuda`); on CPU that handle is None.
if cuda_idx is not None and torch.cuda.is_available():
    device = f'cuda:{cuda_idx}'
    pynvml.nvmlInit()
    cuda = pynvml.nvmlDeviceGetHandleByIndex(cuda_idx)
else:
    device = 'cpu'
    cuda =  None
# Notebook-wide display/warning settings: 2-decimal numpy printing, RuntimeWarnings silenced.
np.set_printoptions(precision=2)
warnings.filterwarnings("ignore", category=RuntimeWarning)

## Load data

In [7]:
# Load the graphs and build the per-graph data objects (both calls are given a cache prefix).
# NOTE(review): np.load(..., allow_pickle=True) unpickles arbitrary objects — only use trusted layout files.
G_list = load_G_list(data_path='data/rome', index_file='data_index.txt', cache='G_list', cache_prefix='cache/')
data_list = generate_data_list(G_list, 
                               sparse=data_config.sparse, 
                               pivot_mode=data_config.pivot,
                               init_mode=data_config.init,
                               edge_index=data_config.edge.index,
                               edge_attr=data_config.edge.attr,
                               pmds_list=np.load('layouts/rome/pmds.npy', allow_pickle=True),
                               gviz_list=np.load('layouts/rome/gviz.npy', allow_pickle=True),
                               noisy_layout=True,
                               device='cpu', 
                               cache=True,
                               cache_prefix='cache/')
# Split: the first 10000 graphs train, everything from index 11000 on validates.
# Indices 10000-10999 are left out here; the test cell evaluates on range(10000, 11000).
train_loader = LazyDeviceMappingDataLoader(data_list[:10000], batch_size=config.batchsize, shuffle=True, device=device)
val_loader = LazyDeviceMappingDataLoader(data_list[11000:], batch_size=config.batchsize, shuffle=False, device=device)

Load from 'cache/G_list.pickle'
Load from 'cache/generate_data_list(list,sparse=False,pivot_mode=None,init_mode=pmds,edge_index=full_edge_index,edge_attr=full_edge_attr,pmds_list=ndarray,gviz_list=ndarray,noisy_layout=True,device=cpu).pickle'




## Create folders

In [8]:
# Presumably creates every artifact directory used by this run (mkdirs is a project helper — confirm).
mkdirs(paths.checkpoints(), paths.tensorboard(), paths.visualization(), f"{paths.root}/logs", f"{paths.root}/metrics")

## Load checkpoints

In [9]:
class GATConvWithEFeat(gnn.GATConv):
    """GAT convolution whose attention logits also depend on per-edge features.

    On top of the parent `gnn.GATConv` source/target attention terms, a learned
    per-head vector `att_efeat` is dotted with each edge's feature vector and the
    result is added to the attention logit before the leaky-ReLU and softmax.

    Args:
        in_channels, out_channels, heads, concat, negative_slope, dropout,
        add_self_loops, bias, **kwargs: forwarded unchanged to `gnn.GATConv`.
        efeat_dim: length of the edge feature vector passed as `e` to `forward`.
    """
    def __init__(self,
                 in_channels,
                 out_channels, 
                 efeat_dim,
                 heads: int = 8,
                 concat: bool = False,
                 negative_slope: float = 0.2,
                 dropout: float = 0.0,
                 add_self_loops: bool = False,
                 bias: bool = True,
                 **kwargs):
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels, 
                         heads=heads,
                         concat=concat,
                         negative_slope=negative_slope,
                         dropout=dropout,
                         add_self_loops=add_self_loops,
                         bias=bias, **kwargs)
        # One attention vector per head over the edge features, Glorot-initialised.
        self.att_efeat = nn.Parameter(torch.Tensor(1, heads, efeat_dim))
        gnn.inits.glorot(self.att_efeat)

    def forward(self, x, edge_index, e, size=None, return_attention_weights=None):
        """Run attention-weighted message passing with edge features `e`.

        Args:
            x: node feature tensor, or a (source, target) tuple of tensors.
            edge_index: edge tensor or `torch_sparse.SparseTensor`.
            e: per-edge features; indexed as `e[:, None, :]` in `message`, so 2-D.
            size: optional size forwarded to `propagate`.
            return_attention_weights: when a bool, attention weights are returned
                together with the output.

        Returns:
            The node output tensor, or `(out, attention)` when
            `return_attention_weights` is a bool.
        """
        H, C = self.heads, self.out_channels

        # We first transform the input node features. If a tuple is passed, we
        # transform source and target node features via separate weights:
        if isinstance(x, torch.Tensor):
            assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = x_dst = self.lin_src(x).view(-1, H, C)
        else:  # Tuple of source and target node features:
            x_src, x_dst = x
            assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = self.lin_src(x_src).view(-1, H, C)
            if x_dst is not None:
                x_dst = self.lin_dst(x_dst).view(-1, H, C)

        x = (x_src, x_dst)

        # Next, we compute node-level attention coefficients, both for source
        # and target nodes (if present):
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
        alpha = (alpha_src, alpha_dst)

        # NOTE(review): self-loops are added to `edge_index` here but `e` is not extended
        # to match, so add_self_loops=True presumably breaks the edge-feature term in
        # `message` — confirm before enabling (the default in this class is False).
        if self.add_self_loops:
            if isinstance(edge_index, torch.Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                edge_index, _ = pyg.utils.remove_self_loops(edge_index)
                edge_index, _ = pyg.utils.add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, torch_sparse.SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
        out = self.propagate(edge_index, x=x, e=e, alpha=alpha, size=size)

        # `message` stashes the normalised attention in self._alpha; take it and reset.
        alpha = self._alpha
        assert alpha is not None
        self._alpha = None

        # Either concatenate the heads or average over them.
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, torch.Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, torch_sparse.SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out


    def message(self, x_j, e, alpha_j, alpha_i, index, ptr, size_i):
        """Compute per-edge messages: source features scaled by the attention weight."""
        # Given edge-level attention coefficients for source and target nodes,
        # we simply need to sum them up to "emulate" concatenation:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
#         print(e.shape, self.att_efeat.shape)
        # Edge-feature contribution: per-head dot product of `e` with `att_efeat`.
        alpha += (e[:, None, :].repeat(1, self.heads, 1) * self.att_efeat).sum(dim=-1)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = pyg.utils.softmax(alpha, index, ptr, size_i)
        self._alpha = alpha  # Save for later use.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

In [10]:
class GNNLayer(nn.Module):
    """One graph layer: edge-feature-aware GAT conv -> dense -> batch norm -> activation -> dropout.

    Args:
        nfeat_dims: either an `(in_dim, out_dim)` pair or a single number used for both.
        efeat_dim: length of the per-edge feature vector given to the convolution.
        aggr: aggregation mode forwarded to `GATConvWithEFeat`.
        edge_net: accepted for interface compatibility; not used by this layer.
        dense: add an `nn.Linear(out_dim, out_dim)` after the convolution.
        bn: apply `gnn.BatchNorm` after the dense stage.
        act: apply `nn.LeakyReLU` after normalisation.
        dp: dropout probability; a falsy value (None or 0) disables dropout.
        root_weight: accepted for interface compatibility; not used by this layer.
        skip: add a (projected) copy of the input to the output.
    """
    def __init__(self,
                 nfeat_dims,
                 efeat_dim,
                 aggr,
                 edge_net=None,
                 dense=False,
                 bn=True,
                 act=True,
                 dp=None,
                 root_weight=True,
                 skip=True):
        super().__init__()
        # Accept a pair or a scalar. Only the errors that "not a pair" can raise are
        # caught, so unrelated failures (and KeyboardInterrupt) are no longer swallowed.
        try:
            in_dim, out_dim = nfeat_dims[0], nfeat_dims[1]
        except (TypeError, IndexError):
            in_dim = out_dim = nfeat_dims
        self.conv = GATConvWithEFeat(in_dim, out_dim, efeat_dim, aggr=aggr)
        self.dense = nn.Linear(out_dim, out_dim) if dense else nn.Identity()
        self.bn = gnn.BatchNorm(out_dim) if bn else nn.Identity()
        self.act = nn.LeakyReLU() if act else nn.Identity()
        self.dp = nn.Dropout(dp) if dp else nn.Identity()
        self.skip = skip
        # The skip path needs a bias-free projection only when the width changes.
        self.proj = nn.Linear(in_dim, out_dim, bias=False) if in_dim != out_dim else nn.Identity()

    def forward(self, v, e, data):
        """Apply the layer to node features `v` with edge features `e` on `data.edge_index`."""
        v_ = v
        v = self.conv(v, data.edge_index, e)
        v = self.dense(v)
        v = self.bn(v)
        v = self.act(v)
        v = self.dp(v)
        return v + self.proj(v_) if self.skip else v

In [11]:
class GNNBlock(nn.Module):
    """A stack of `GNNLayer`s that share one recipe for building edge features.

    Edge features are the first `static_efeats` columns of `data.edge_attr`,
    optionally extended with "dynamic" features derived from node positions
    (edge length, edge direction, extra products, broadcast weights).

    dynamic_efeats selects which node features the dynamic part is computed from:
        skip:  block input, for every layer
        first: block input, for the first layer only (later layers get static features only)
        prev:  previous layer output, for the next layer
        orig:  original node positions (`data.pos`), for every layer

    Args:
        feat_dims: node feature widths; layer i maps feat_dims[i] -> feat_dims[i + 1].
        efeat_hid_dims: hidden widths of the per-layer edge network (None means none).
        efeat_hid_act / efeat_out_act: activation classes of that edge network.
        bn, act, dp, aggr, root_weight: forwarded to every `GNNLayer`.
        static_efeats: number of leading `data.edge_attr` columns to use.
        rich_efeats, euclidian, direction: toggles for the dynamic features.
        n_weights: number of broadcast weight features appended to each edge.
        residual: add the block input to the block output.
    """
    def __init__(self,
                 feat_dims,
                 efeat_hid_dims=None,
                 efeat_hid_act=nn.LeakyReLU,
                 efeat_out_act=nn.Tanh,
                 bn=False,
                 act=True,
                 dp=None,
                 aggr='mean',
                 root_weight=True,
                 static_efeats=1,
                 dynamic_efeats='skip',
                 rich_efeats=False,
                 euclidian=False,
                 direction=False,
                 n_weights=0,
                 residual=False):
        super().__init__()
        # None replaces the former mutable `[]` default; always work on a private list.
        efeat_hid_dims = list(efeat_hid_dims) if efeat_hid_dims is not None else []
        self.static_efeats = static_efeats
        self.dynamic_efeats = dynamic_efeats
        self.rich_efeats = rich_efeats
        self.euclidian = euclidian
        self.direction = direction
        self.n_weights = n_weights
        self.residual = residual
        self.gnn = nn.ModuleList()
        self.n_layers = len(feat_dims) - 1

        for idx, (in_feat, out_feat) in enumerate(zip(feat_dims[:-1], feat_dims[1:])):
            # Width of the direction vector equals the width of the features it is derived from.
            direction_dim = (feat_dims[idx] if self.dynamic_efeats == 'prev'
                             else 2 if self.dynamic_efeats == 'orig'
                             else feat_dims[0])
            in_efeat_dim = self.static_efeats
            # `forward` supplies the dynamic features to every layer, except in 'first'
            # mode where only layer 0 gets them. The layer width must mirror that exactly;
            # previously layer 0 in 'first' mode was sized without them and mismatched.
            if self.dynamic_efeats != 'first' or idx == 0:
                in_efeat_dim += self.euclidian + self.direction * direction_dim + self.n_weights + 3 * self.rich_efeats
            # Still built and passed on for compatibility; GNNLayer accepts `edge_net`
            # but does not use it.
            edge_net = nn.Sequential(*chain.from_iterable(
                [nn.Linear(idim, odim),
                 nn.BatchNorm1d(odim),
                 act_cls()]
                for idim, odim, act_cls in zip([in_efeat_dim] + efeat_hid_dims,
                                               efeat_hid_dims + [in_feat * out_feat],
                                               [efeat_hid_act] * len(efeat_hid_dims) + [efeat_out_act])
            ))
            self.gnn.append(GNNLayer(nfeat_dims=(in_feat, out_feat),
                                     efeat_dim=in_efeat_dim,
                                     edge_net=edge_net,
                                     bn=bn,
                                     act=act,
                                     dp=dp,
                                     aggr=aggr,
                                     root_weight=root_weight,
                                     skip=False))

    def _get_edge_feat(self, pos, data, rich_efeats=False, euclidian=False, direction=False, weights=None):
        """Build edge features: static columns followed by the requested dynamic ones."""
        e = data.edge_attr[:, :self.static_efeats]
        if euclidian or direction:
            start_pos, end_pos = get_edges(pos, data)
            # Presumably (unit direction, norm) of each edge vector — confirm against l2_normalize.
            v, u = l2_normalize(end_pos - start_pos, return_norm=True)
            if euclidian:
                e = torch.cat([e, u], dim=1)
            if direction:
                e = torch.cat([e, v], dim=1)
            # Rich features need `u`, so they are only available together with
            # euclidian/direction; `d` is the first static column.
            if rich_efeats:
                d = e[:, :1]
                d2 = d ** 2
                u2 = u ** 2
                ud = u * d
                e = torch.cat([e, d2, u2, ud], dim=1)
        if weights is not None:
            # Broadcast the same weight row onto every edge.
            w = weights.repeat(len(e), 1)
            e = torch.cat([e, w], dim=1)
        return e

    def _get_dynamic_edge_feat(self, pos, data, rich_efeats=False, euclidian=False, direction=False, weights=None):
        """Build edge features from positions only (used when `static_efeats == 0`)."""
        e = None
        if euclidian or direction:
            start_pos, end_pos = get_edges(pos, data)
            # Presumably (unit direction, norm) of each edge vector — confirm against l2_normalize.
            d, u = l2_normalize(end_pos - start_pos, return_norm=True)
            if euclidian and direction:
                e = torch.cat([u, d], dim=1)
            elif euclidian:
                e = u
            else:
                e = d
        if e is None:
            # No dynamic feature requested (e.g. later layers in 'first' mode): use a
            # zero-width feature matrix instead of failing on an unbound local.
            e = data.edge_attr[:, :0]
        if weights is not None:
            w = weights.repeat(len(e), 1)
            e = torch.cat([e, w], dim=1)
        return e

    def forward(self, v, data, weights=None):
        """Run every layer in order, rebuilding the edge features before each one."""
        vres = v
        for layer in range(self.n_layers):
            # Node features the dynamic edge features are derived from.
            vsrc = (v if self.dynamic_efeats == 'prev'
                    else data.pos if self.dynamic_efeats == 'orig'
                    else vres)
            # In 'first' mode only layer 0 receives the dynamic extras.
            get_extra = not (self.dynamic_efeats == 'first' and layer != 0)
            efeat_fn = self._get_dynamic_edge_feat if self.static_efeats == 0 else self._get_edge_feat

            e = efeat_fn(vsrc, data,
                         rich_efeats=self.rich_efeats and get_extra,
                         euclidian=self.euclidian and get_extra,
                         direction=self.direction and get_extra,
                         weights=weights if get_extra and self.n_weights > 0 else None)
            v = self.gnn[layer](v, e, data)
        return v + vres if self.residual else v

In [12]:
class Generator(nn.Module):
    """Layout generator: one input block, `num_blocks` hidden blocks and two output blocks.

    Every block is a `GNNBlock` using two static edge features. The hidden blocks
    additionally use the dynamic edge features selected by `dynamic_efeats`,
    `euclidian`, `direction` and `n_weights`. When `normalize` is given it is
    applied to the input positions and to the final output.
    """
    def __init__(self, 
                 num_blocks=9, 
                 num_layers=3,
                 num_enet_layers=2,
                 layer_dims=None,
                 n_weights=0, 
                 dynamic_efeats='skip',
                 euclidian=True,
                 direction=True,
                 residual=True,
                 normalize=None):
        super().__init__()

        # Input stage: 2-D positions -> hidden width.
        entry_width = 8 if layer_dims is None else layer_dims[0]
        self.in_blocks = nn.ModuleList([
            GNNBlock(feat_dims=[2, 8, entry_width], bn=True, dp=0.2, static_efeats=2)
        ])

        # Hidden stage: identical blocks built one after another.
        hidden = []
        for _ in range(num_blocks):
            hidden.append(
                GNNBlock(feat_dims=layer_dims or ([8] + [8] * num_layers),
                         efeat_hid_dims=[16] * (num_enet_layers - 1),
                         bn=True,
                         act=True,
                         dp=0.2,
                         static_efeats=2,
                         dynamic_efeats=dynamic_efeats,
                         euclidian=euclidian,
                         direction=direction,
                         n_weights=n_weights,
                         residual=residual)
            )
        self.hid_blocks = nn.ModuleList(hidden)

        # Output stage: hidden width -> 8 -> 2-D positions (no activation on the last block).
        exit_width = 8 if layer_dims is None else layer_dims[-1]
        self.out_blocks = nn.ModuleList([
            GNNBlock(feat_dims=[exit_width, 8], bn=True, static_efeats=2),
            GNNBlock(feat_dims=[8, 2], act=False, static_efeats=2)
        ])
        self.normalize = normalize

    def forward(self, data, weights=None, output_hidden=False, numpy=False):
        """Produce a layout for `data`.

        Args:
            data: batch providing `pos` (or `x` when `pos` is None) and the edge data.
            weights: optional weights forwarded to every block.
            output_hidden: return the list of every block's output instead of the
                final layout (`normalize` is not applied to these).
            numpy: convert returned tensors to detached CPU numpy arrays.
        """
        if data.pos is not None:
            v = data.pos
        else:
            # No initial positions: start from random ones on the same device as data.x.
            v = generate_rand_pos(len(data.x)).to(data.x.device)
        if self.normalize is not None:
            v = self.normalize(v, data)

        def export(t):
            return t.detach().cpu().numpy() if numpy else t

        hidden = []
        for block in chain(self.in_blocks, self.hid_blocks, self.out_blocks):
            v = block(v, data, weights)
            if output_hidden:
                hidden.append(export(v))

        if output_hidden:
            return hidden

        vout = export(v)
        if self.normalize is not None:
            vout = self.normalize(vout, data)
        return vout

In [13]:
class StressDiscriminator(nn.Module):
    """Non-learned critic that scores a layout by its negated stress.

    Extra keyword arguments are accepted and ignored so the class can be built
    from the same params dict as a learned discriminator.
    """
    def __init__(self, normalize=CanonicalizationByStress(), **kwargs):
        super().__init__()
        # Placeholder parameter; it is not used in forward().
        self.dummy = nn.Parameter(torch.zeros(1))
        self.normalize = normalize
        self.stress = Stress(reduce=None)

    def forward(self, batch):
        """Return minus the stress of the normalised positions in `batch`."""
        canonical_pos = self.normalize(batch.pos, batch)
        stress_value = self.stress(canonical_pos, batch)
        return -stress_value

In [14]:
def get_ckpt_epoch(folder, epoch):
    """Resolve a possibly negative epoch index against the checkpoints in `folder`.

    The folder (including missing parents) is created when absent. A non-negative
    `epoch` is returned unchanged. A negative `epoch` counts back from the highest
    epoch number found in file names of the form `..._epoch_<N>.<ext>`: -1 is the
    latest, -2 the one before, and so on. Files without such a number (editor
    backups, `.DS_Store`, ...) are ignored. With no matching files the latest
    epoch is taken to be 0.

    Args:
        folder: checkpoint directory.
        epoch: absolute epoch (>= 0) or offset from the latest one (< 0).

    Returns:
        The resolved epoch number as an int.
    """
    os.makedirs(folder, exist_ok=True)
    if epoch >= 0:
        return epoch
    found_epochs = []
    for file_name in os.listdir(folder):
        match = re.search(r'(?<=epoch_)(\d+)(?=\.)', file_name)
        if match is not None:
            found_epochs.append(int(match.group(1)))
    last_epoch = max(found_epochs, default=0)
    return last_epoch + epoch + 1

def start_epoch():
    """Resolve `config.epoch.start` against this run's checkpoint folder.

    Every call goes through get_ckpt_epoch, which lists the folder again.
    """
    return get_ckpt_epoch(paths.checkpoints(), config.epoch.start)

In [15]:
# Resolve the resume epoch once; every start_epoch() call re-lists the checkpoint folder.
resume_epoch = start_epoch()

generator = Generator(**config.generator.params[...]).to(device)
# The learning rate starts where the exponential schedule would be at the resume epoch.
generator_optimizer = config.generator.optim(generator.parameters(), lr=config.generator.lr.initial * config.generator.lr.decay ** resume_epoch)
generator_scheduler = torch.optim.lr_scheduler.ExponentialLR(generator_optimizer, gamma=config.generator.lr.decay)

# Choose which checkpoint to restore and from which folder:
#   1. this run's own latest checkpoint when resuming,
#   2. otherwise the configured pretrained run (its epoch is resolved in its own folder),
#   3. otherwise nothing.
if resume_epoch != 0:
    gen_ckpt_dir = paths.checkpoints()
    gen_ckpt_epoch = resume_epoch
elif config.generator.pretrained.name is not None and config.generator.pretrained.epoch != 0:
    # The pretrained files live in the pretrained run's folder, not in this run's
    # folder; the original looked up the epoch there but loaded from here.
    gen_ckpt_dir = paths.gen_pretrain()
    gen_pretrained_epoch = get_ckpt_epoch(gen_ckpt_dir, config.generator.pretrained.epoch)
    gen_ckpt_epoch = gen_pretrained_epoch
else:
    gen_ckpt_dir = None
    gen_ckpt_epoch = None

if gen_ckpt_epoch is not None:
    # Load generator
    gen_ckpt_file = f"{gen_ckpt_dir}/gen_epoch_{gen_ckpt_epoch}.pt"
    print(f"Loading from {gen_ckpt_file}...")
    generator.load_state_dict(torch.load(gen_ckpt_file, map_location=torch.device(device)))
    # Load generator optimizer
    gen_optim_ckpt_file = f"{gen_ckpt_dir}/gen_optim_epoch_{gen_ckpt_epoch}.pt"
    print(f"Loading from {gen_optim_ckpt_file}...")
    generator_optimizer.load_state_dict(torch.load(gen_optim_ckpt_file, map_location=torch.device(device)))

Loading from artifacts/checkpoints/deepgd(conv=gat_simple(self_loop=false,heads=8),lr=1e-2)/gen_epoch_1208.pt...
Loading from artifacts/checkpoints/deepgd(conv=gat_simple(self_loop=false,heads=8),lr=1e-2)/gen_optim_epoch_1208.pt...


In [16]:
# The critic is the fixed StressDiscriminator: it only reads `normalize` from the params
# and has nothing to train, so no optimizer, scheduler or checkpoint is created for it.
discriminator = StressDiscriminator(**config.discriminator.params[...]).to(device)
# discriminator_optimizer = config.discriminator.optim(discriminator.parameters(), lr=config.discriminator.lr.initial * config.discriminator.lr.decay ** start_epoch())
# discriminator_scheduler = torch.optim.lr_scheduler.ExponentialLR(discriminator_optimizer, gamma=config.discriminator.lr.decay)
# if start_epoch() != 0:
#     dis_ckpt_epoch = start_epoch()
# elif config.discriminator.pretrained.name is not None and config.discriminator.pretrained.epoch != 0:
#     dis_pretrained_epoch = get_ckpt_epoch(paths.dis_pretrain(), config.discriminator.pretrained.epoch)
#     dis_ckpt_epoch = dis_pretrained_epoch # f"{paths.dis_pretrain()}/dis_epoch_{dis_pretrained_epoch}.pt"
# else:
#     dis_ckpt_epoch = None
# if dis_ckpt_epoch is not None:
#     # Load discriminator
#     dis_ckpt_file = f"{paths.checkpoints()}/dis_epoch_{dis_ckpt_epoch}.pt"
#     print(f"Loading from {dis_ckpt_file}...")
#     discriminator.load_state_dict(torch.load(dis_ckpt_file, map_location=torch.device(device)))
#     # Load discriminator optimizer
#     dis_optim_ckpt_file = f"{paths.checkpoints()}/dis_optim_epoch_{dis_ckpt_epoch}.pt"
#     print(f"Loading from {dis_optim_ckpt_file}...")
#     discriminator_optimizer.load_state_dict(torch.load(dis_optim_ckpt_file, map_location=torch.device(device)))

# Train

In [17]:
# Reporting critic with the default normalisation; its output is the negated stress.
stress_criterion = StressDiscriminator().to(device)
# Validation criteria, both built with reduce=None.
val_criterion = Stress(reduce=None)
xing_criterion = Xing(reduce=None)
# Builds the discriminator input from a batch and, optionally, generated positions.
dis_convert = DiscriminatorDataConverter(complete_graph=config.discriminator.complete, normalize=config.discriminator.params.normalize)
tensorboard = SummaryWriter(log_dir=paths.tensorboard())
# Training continues with the epoch after the resolved start (latest checkpoint) epoch.
epoch = start_epoch() + 1

In [18]:
def gradient_penalty(interpolated, discriminator, weight=10):
    """Gradient penalty on the discriminator evaluated at `interpolated`.

    The per-node squared gradients of the discriminator output with respect to
    `interpolated.pos` are summed per graph, and the resulting norm is penalised
    for deviating from 1.

    Args:
        interpolated: batch whose `pos` the gradient is taken with respect to; its
            own `batch` vector assigns nodes to graphs.
        discriminator: callable scoring the batch.
        weight: multiplier applied to the mean penalty.

    Returns:
        A scalar tensor: `weight * mean((||grad|| - 1) ** 2)`.
    """
    interpolated.pos.requires_grad_()
    prob_interpolated = discriminator(interpolated)
    gradients = autograd.grad(outputs=prob_interpolated,
                              inputs=interpolated.pos,
                              grad_outputs=torch.ones_like(prob_interpolated),
                              create_graph=True,
                              retain_graph=True,
                              allow_unused=True)[0]
    # Pool with the node-to-graph assignment of the batch being differentiated. The
    # original read a global named `batch`, which is undefined or stale inside this function.
    node_sq_norm = gradients.square().sum(dim=1)
    # The epsilon keeps the square root differentiable when the gradient is zero.
    gradients_norm = torch.sqrt(gnn.global_add_pool(node_sq_norm, interpolated.batch) + 1e-8)
    return weight * ((gradients_norm - 1) ** 2).mean()

In [19]:
def get_gp_loss(batch, fake_pos, weight):
    """Gradient-penalty loss at a random interpolation point, or 0 when disabled.

    Returns the plain integer 0 unless `weight` is greater than zero.
    """
    if not weight > 0:
        return 0
    # Random interpolation coefficient drawn from [0, 1).
    interp = dis_convert(batch, fake_pos, random.random())
    penalty = gradient_penalty(interp, discriminator, weight)
    return penalty.mean()

In [20]:
def get_sgan_loss(batch, fake_pos, mode='discriminator'):
    """Standard GAN loss as a two-way classification between real and fake layouts.

    The discriminator scores a merged (real, fake) batch; the scores are reshaped
    so each row holds the two scores of one graph. Cross-entropy then targets
    column 0 in 'discriminator' mode and column 1 in 'generator' mode.

    Args:
        batch: input batch.
        fake_pos: generated positions for `batch`.
        mode: 'discriminator' or 'generator'.

    Raises:
        ValueError: for any other `mode` (a subclass of the `Exception` raised before).
    """
    real = dis_convert(batch)
    fake = dis_convert(batch, fake_pos)
    pred = discriminator(merge_batch(real, fake)).view(2, -1).T
    # Create the labels on the same device as the predictions; the CPU-only labels
    # used before break CrossEntropyLoss as soon as `pred` lives on a GPU.
    if mode == 'discriminator':
        label = torch.zeros(pred.shape[0], dtype=torch.long, device=pred.device)
    elif mode == 'generator':
        label = torch.ones(pred.shape[0], dtype=torch.long, device=pred.device)
    else:
        raise ValueError(f"Unknown mode: {mode!r}")
    return nn.CrossEntropyLoss()(pred, label)

In [21]:
def get_rgan_loss(batch, fake_pos, mode='discriminator'):
    """Relativistic GAN loss: -logsigmoid of the paired score difference.

    The difference is real - fake in 'discriminator' mode and fake - real in
    'generator' mode; any other mode raises `Exception` after scoring.
    """
    merged = merge_batch(dis_convert(batch), dis_convert(batch, fake_pos))
    scores = discriminator(merged).view(2, -1).T
    real_score = scores[:, 0]
    fake_score = scores[:, 1]
    if mode not in ('discriminator', 'generator'):
        raise Exception
    if mode == 'discriminator':
        margin = real_score - fake_score
    else:
        margin = fake_score - real_score
    return (- F.logsigmoid(margin)).mean()

In [22]:
def get_wgan_loss(batch, fake_pos, mode='discriminator'):
    """Wasserstein GAN loss: mean difference of the paired scores.

    The difference is fake - real in 'discriminator' mode and real - fake in
    'generator' mode; any other mode raises `Exception` after scoring.
    """
    merged = merge_batch(dis_convert(batch), dis_convert(batch, fake_pos))
    scores = discriminator(merged).view(2, -1).T
    real_score = scores[:, 0]
    fake_score = scores[:, 1]
    if mode not in ('discriminator', 'generator'):
        raise Exception
    if mode == 'discriminator':
        gap = fake_score - real_score
    else:
        gap = real_score - fake_score
    return gap.mean()

In [23]:
def get_ragan_loss(batch, fake_pos, mode='discriminator'):
    """Relativistic average GAN loss.

    Each score is compared with the mean score of the opposite group. The roles
    of real and fake are swapped between 'discriminator' and 'generator' mode;
    any other mode raises `Exception` after scoring.
    """
    merged = merge_batch(dis_convert(batch), dis_convert(batch, fake_pos))
    scores = discriminator(merged).view(2, -1).T
    real_score = scores[:, 0]
    fake_score = scores[:, 1]
    if mode not in ('discriminator', 'generator'):
        raise Exception
    if mode == 'discriminator':
        per_graph = - F.logsigmoid(real_score - fake_score.mean()) - F.logsigmoid(real_score.mean() - fake_score)
    else:
        per_graph = - F.logsigmoid(fake_score - real_score.mean()) - F.logsigmoid(fake_score.mean() - real_score)
    return per_graph.mean()

In [24]:
def get_dgdv2_loss(batch, fake_pos, mode='discriminator'):
    """Loss for the 'dgdv2' flavor, computed on the generated layout only.

    'generator' mode returns minus the total of the discriminator output.
    'discriminator' mode returns minus `criterion(pred, gt)`.
    NOTE(review): `get_gt` and `criterion` are not defined in this notebook —
    presumably provided by the star import; confirm before using that mode.
    Any other mode raises `Exception` after scoring.
    """
    pred = discriminator(dis_convert(batch, fake_pos))
    if mode == 'generator':
        return -pred.sum(dim=0).sum()
    if mode == 'discriminator':
        gt = get_gt(batch, fake_pos)
        return -criterion(pred, gt)
    raise Exception

In [25]:
def get_gan_loss(batch, fake_pos, mode='discriminator'):
    """Dispatch to the loss selected by `config.gan_flavor`.

    An unknown flavor raises `KeyError`.
    """
    loss_by_flavor = {
        "sgan": get_sgan_loss,
        "wgan": get_wgan_loss,
        "rgan": get_rgan_loss,
        "ragan": get_ragan_loss,
        "dgdv2": get_dgdv2_loss,
    }
    loss_fn = loss_by_flavor[config.gan_flavor]
    return loss_fn(batch, fake_pos, mode)

In [26]:
# def train_dis(batch, epoch):
#     generator.requires_grad_(False)
#     discriminator.zero_grad()
#     generator_output = generator(batch)
#     if config.discriminator.noise.std > 0:
#         generator_output = generator_output + torch.randn_like(generator_output) * config.discriminator.noise.std * config.discriminator.noise.decay ** epoch
#     discriminator_loss = get_gan_loss(batch, generator_output, mode='discriminator')
    
#     # train discriminator
#     discriminator_loss.backward()
#     discriminator_optimizer.step()

#     # gradient penalty
#     if config.gp_weight > 0:
#         discriminator.zero_grad()
#         gp_loss = get_gp_loss(batch, generator_output, config.gp_weight)
#         gp_loss.backward()
#         discriminator_optimizer.step()

#     hud['dis_loss'] = format(discriminator_loss.item(), '.2e')
#     pbar().update()

def train_gen(batch, epoch):
    """Run one generator optimisation step on `batch` and refresh the status widgets.

    Optional Gaussian noise (std * decay ** epoch) is added to the generated layout
    before the loss is computed. After the step, the stress criterion value and
    the critic value of that layout are shown in the HUD.
    """
    generator.requires_grad_(True)
    generator.zero_grad()
    discriminator.zero_grad()

    layout = generator(batch)
    noise_cfg = config.discriminator.noise
    if noise_cfg.std > 0:
        layout = layout + torch.randn_like(layout) * noise_cfg.std * noise_cfg.decay ** epoch
    loss = get_gan_loss(batch, layout, mode='generator')

    # Optimise the generator.
    loss.backward()
    generator_optimizer.step()

    # Monitoring only: no gradients are needed here.
    with torch.no_grad():
        dis_batch = dis_convert(batch, layout)
        stress = stress_criterion(dis_batch).mean()
        critic = discriminator(dis_batch).mean()

    status = {
        'gen_loss': format(loss.item(), '.2e'),
        'stress': format(stress.item(), '.2e'),
        'critic': format(critic.item(), '.2e'),
    }
    hud.append(status)
    pbar().update()

def cuda_memsafe_map(fn, *iterables, summary=False):
    """Lazily map `fn` over zipped `iterables`, emptying the CUDA cache after each item.

    Behaves like `map(fn, *iterables)` as a generator. After every yielded result
    (once the consumer asks for the next one) `torch.cuda.empty_cache()` is called.

    Only exhaustion of the inputs ends the iteration. A `StopIteration` escaping
    from `fn` itself is no longer mistaken for the end of the data (it used to
    truncate the loop silently); it now surfaces as an error.

    Args:
        fn: callable applied to each tuple of items.
        *iterables: input iterables, consumed in lockstep.
        summary: print a one-line summary once the inputs are exhausted.

    Yields:
        `fn(*items)` for each tuple of items.
    """
    # `failed` stays 0: skipping failed batches is currently disabled.
    total, failed = 0, 0
    for items in zip(*iterables):
        yield fn(*items)
        torch.cuda.empty_cache()
        total += 1
    if summary:
        print(f'Iteration finished. {failed} out of {total} failed!')
    
def validate(model, data_loader, criterion=val_criterion):
    """Evaluate `model` on every batch of `data_loader`.

    For each batch the prediction and the ground-truth layout are canonicalised,
    scored with `criterion`, and compared through a symmetric relative change:
    (loss - gt_loss) / max(loss, gt_loss, 1e-5).

    Returns:
        (mean of per-batch mean losses, mean of per-batch mean relative changes).
    """
    def val_one_batch(batch):
        batch = preprocess_batch(model, batch)
        pred_pos = CanonicalizationByStress()(model(batch), batch)
        gt_pos = CanonicalizationByStress()(batch.gt_pos, batch)
        pred_loss = criterion(pred_pos, batch)
        gt_loss = criterion(gt_pos, batch)
        # Floor the denominator so near-zero losses cannot blow up the ratio.
        floor = torch.ones_like(pred_loss)*1e-5
        denom = torch.maximum(torch.maximum(pred_loss, gt_loss), floor)
        spc = (pred_loss - gt_loss) / denom
        return pred_loss.mean().item(), spc.mean().item()

    per_batch = cuda_memsafe_map(val_one_batch, data_loader)
    loss_all, spc_all = zip(*per_batch)
    return np.mean(loss_all), np.mean(spc_all)

def log(msg):
    """Append `msg`, prefixed with the current epoch, to the log file and the log tab.

    The log file is opened in a `with` block so its handle is flushed and closed
    on every call instead of being left for the garbage collector.
    """
    msg = f"[{epoch:03}] {msg}"
    with open(paths.log(), "a") as log_file:
        print(msg, file=log_file)
    with log_out:
        print(msg)

# Write a separator line for this session to the log file; the handle is closed right away.
with open(paths.log(), "a") as log_file:
    print(f"{'='*10} {config.link} {'='*10}", file=log_file)

# Status widgets: progress bar + HUD, a plot pane and a log pane, shown as three tabs.
hud = Hud()
pbar = Wrapper(tqdm, total=len(train_loader), smoothing=0)
plot_out = Output()
log_out = Output()
tabs = {"status": VBox([pbar, hud]), 
        "plot": HBox([plot_out], layout=Layout(height='500px', overflow_y='auto')),
        "log": HBox([log_out], layout=Layout(height='500px', overflow_y='auto'))}
tab_bar = Tab(children=list(tabs.values()))
# Title each tab with its dict key (a plain loop: the calls are made for their side effect only).
for i, name in enumerate(tabs):
    tab_bar.set_title(i, name)
display(tab_bar)
# Main training loop. It only stops when `epoch == config.epoch.end`; with end=None (as
# configured above) that never happens, so the loop runs until the kernel is interrupted.
while True:
    # Every `log_interval` epochs, BEFORE training this epoch: checkpoint, validate, plot, log.
    if epoch % config.log_interval == 0:
        # The checkpoint numbered `epoch` therefore holds the weights as they are at the
        # start of that epoch.
        torch.save(generator.state_dict(), f"{paths.checkpoints()}/gen_epoch_{epoch}.pt")
        torch.save(generator_optimizer.state_dict(), f"{paths.checkpoints()}/gen_optim_epoch_{epoch}.pt")
#         torch.save(discriminator.state_dict(), f"{paths.checkpoints()}/dis_epoch_{epoch}.pt")
#         torch.save(discriminator_optimizer.state_dict(), f"{paths.checkpoints()}/dis_optim_epoch_{epoch}.pt")
        generator.eval()
        with torch.no_grad():
            # Two passes over the validation set: default (stress) criterion, then crossings.
            val_stress, val_stress_spc = validate(model=generator, data_loader=val_loader)
            val_xing, val_xing_spc = validate(model=generator, data_loader=val_loader, criterion=xing_criterion)
            with plot_out:
                # Draw the generated layout of one fixed sample graph (index 11100).
                fig = plt.figure()
                graph_vis(G_list[11100], generator(make_batch(data_list[11100]).to(device)).cpu())
                plt.show()
#         for i, fn in enumerate(loss_fns):
#             tensorboard.add_scalars(type(fn).__name__, {'train': train_loss_comp[i].item(), 
#                                                   'validation': val_loss_comp[i].item()}, epoch)
        hud.append({
            'val_stress': format(val_stress, '.2f'),
            'val_stress_spc': format(val_stress_spc, '.2%'),
            'val_xing': format(val_xing, '.2f'),
            'val_xing_spc': format(val_xing_spc, '.2%'),
        })
        # The same message goes to the log (file + log tab) and, below, to TensorBoard.
        log(msg := f"stress={hud.data['val_stress']}({hud.data['val_stress_spc']}) xing={hud.data['val_xing']}({hud.data['val_xing_spc']})")
    
        tensorboard.add_scalars('Stress', {'Value': val_stress, 'SPC': val_stress_spc}, epoch)
        tensorboard.add_scalars('Xing', {'Value': val_xing, 'SPC': val_xing_spc}, epoch)
        tensorboard.add_figure('Rome/11100', fig, epoch)
        tensorboard.add_text('Log', msg, epoch)
        
    # handle.update(tab_bar)
    # Reset the progress widgets for the new epoch and switch back to training mode.
    pbar().reset()
    pbar().set_description(desc=f"[epoch {epoch}/{config.epoch.end}]")
    hud(title=f"epoch {epoch}")
    generator.train()
#     discriminator.train()
    # proper: proper layout
#     for _ in range(config.discriminator.repeat):
#         for batch in train_loader:
#             train_dis(batch, epoch)
#             if config.alternate == 'iteration':
#                 train_gen(batch, epoch)

    # One generator step per training batch.
    if config.alternate == 'epoch':
        for batch in train_loader:
            train_gen(batch, epoch)

#     discriminator_scheduler.step()
    # Decay the generator learning rate once per epoch.
    generator_scheduler.step()

    if epoch == config.epoch.end:
        break
    epoch += 1

Tab(children=(VBox(children=(Wrapper(), Hud())), HBox(children=(Output(),), layout=Layout(height='500px', over…

KeyboardInterrupt: 

In [None]:
# Scratch: empty Data object used by the two cells below.
data = Data()

In [None]:
# Scratch: attach a 0-dimensional tensor attribute to the Data object.
data.n = torch.tensor(1)

In [None]:
# Scratch: round-trip the Data object through a Batch and back to a list.
Batch.from_data_list([data]).to_data_list()

In [None]:
import torch_geometric as pyg

In [None]:
# Scratch: is the object returned by Batch.from_data_list an instance of Batch?
isinstance(pyg.data.Batch.from_data_list([pyg.data.Data()]), pyg.data.Batch)

In [None]:
# Scratch: is its exact type Batch (as opposed to a subclass)?
type(pyg.data.Batch.from_data_list([pyg.data.Data()])) is pyg.data.Batch

In [None]:
# Scratch: inspect what preprocess_batch returns for a single data object.
preprocess_batch(generator, data_list[0])

# Test

In [None]:
# Epoch to evaluate; -1 resolves to the latest checkpoint in this run's folder.
# NOTE(review): config.test.epoch holds the same value but is not read here — confirm which is intended.
test_epoch = -1

# Build a fresh generator and restore the selected checkpoint into it.
test_generator = Generator(**config.generator.params[...]).to(device)
test_ckpt_epoch = get_ckpt_epoch(paths.checkpoints(), test_epoch)
test_ckpt_file = f"{paths.checkpoints()}/gen_epoch_{test_ckpt_epoch}.pt"
print(f"Loading from {test_ckpt_file}...")
test_generator.load_state_dict(torch.load(test_ckpt_file, map_location=torch.device(device)))

In [None]:
rotate = RotateByPrincipalComponents()


def test_callback(*, idx, pred, metrics):
    """Show the predicted and the ground-truth layout of one test graph.

    All arguments are keyword-only.

    idx:     position of the graph in ``G_list`` / ``data_list``.
    pred:    predicted layout; it is wrapped with ``torch.tensor`` before
             being passed to ``rotate``.
    metrics: mapping read for the keys 'stress', 'stress_spc', 'xing' and
             'xing_spc' when building the plot title.

    Draws two figures with ``graph_vis`` (prediction first, then the ground
    truth in orange) and returns None.
    """
    # Alternative that writes the prediction plot to disk instead of showing it:
    # graph_vis(G_list[idx], pred, file_name=f"{paths.visualization()}/{idx}_{metrics['stress']:.2f}_{metrics['resolution_score']:.2f}.png")
    sample = data_list[idx]
    graph = G_list[idx]

    # Predicted layout, passed through `rotate` together with the sample.
    pred_layout = rotate(torch.tensor(pred), sample)
    graph_vis(graph, pred_layout)
    pred_title = f"[pred] idx: {idx}, stress: {metrics['stress']:.2f}({metrics['stress_spc']:.2%}), xing: {metrics['xing']:.2f}({metrics['xing_spc']:.2%})"
    plt.title(pred_title)
    plt.show()

    # Ground-truth layout, rotated the same way and drawn in orange.
    gt_layout = rotate(sample.gt_pos, sample)
    graph_vis(graph, gt_layout, node_color='orange')
    gt_title = f"[gt] idx: {idx}, stress: {load_ground_truth(idx, 'stress')}, xing: {load_ground_truth(idx, 'xing')}"
    plt.title(gt_title)
    plt.show()
    
# Evaluate the restored generator on dataset indices 10000..10999, calling
# test_callback for every graph, and keep the returned metrics.
test_metrics = test(model=test_generator,
                    criteria_list=[],
                    dataset=data_list,
                    idx_range=range(10000, 11000),
                    callback=test_callback,
                    gt_file='gt.csv')
# Persist the metrics. The context manager guarantees the file is flushed and
# closed even if pickling fails (the bare open() used before leaked the handle).
with open(paths.metrics("test"), "wb") as metrics_file:
    pickle.dump(test_metrics, metrics_file)

In [None]:
metrics = test_metrics

# Print the mean of every reported metric, one per line. Each pair is
# (printed label, key in `metrics`); they differ only for the resolution score.
for label, key in (
    ('stress:', 'stress'),
    ('stress_spc:', 'stress_spc'),
    ('xing:', 'xing'),
    ('xing_spc:', 'xing_spc'),
    ('l1_angle:', 'l1_angle'),
    ('l1_angle_spc:', 'l1_angle_spc'),
    ('edge:', 'edge'),
    ('edge_spc:', 'edge_spc'),
    ('ring:', 'ring'),
    ('ring_spc:', 'ring_spc'),
    ('tsne:', 'tsne'),
    ('tsne_spc:', 'tsne_spc'),
    ('reso_score:', 'resolution_score'),
    ('min_angle:', 'min_angle'),
):
    print(label, metrics[key].mean())

# One-row table of the means (4 decimals) for every metric except the last key
# of the mapping; as the final expression it is rendered as the cell output.
pd.DataFrame([f"{metrics[m].mean().item():.4f}" for m in list(metrics.keys())[:-1]]).T

# Large Graph

In [None]:
# Table describing the large "scalability" graphs, indexed by its "index"
# column. The loop in the "Large Graph" section reads the 'file', 'name' and
# 'n' columns of each row.
# NOTE(review): this is an absolute path, unlike the relative "artifacts" root
# used in `paths` — confirm it is intended.
scalability = pd.read_csv("/__artifacts__/data/scalability.csv", index_col="index")
scalability

In [None]:
# Helpers used by the large-graph evaluation loop below:
# `rescale` and `rotate` are applied to both the predicted and the ground-truth
# positions before scoring; `stressfn` computes the stress of a layout.
rescale = CanonicalizationByStress()
stressfn = Stress()
# Note: rebinds the global `rotate` that test_callback also uses.
rotate = RotateByPrincipalComponents()

In [None]:
# Evaluate the generator on every graph listed in `scalability`: compute the
# stress of the predicted layout and of the reference layout, record the
# relative difference (spc), save the predicted positions and render both
# layouts to PNG files.
stress_list = []
spc_list = []
# NOTE(review): allow_pickle=True unpickles arbitrary objects — only load
# layout files that come from a trusted source.
pmds_list = np.load("layouts/new_large_graph/pmds.npy", allow_pickle=True)
gviz_list = np.load("layouts/new_large_graph/gviz.npy", allow_pickle=True)
with torch.no_grad():
    # `col` is one row of the table (iterrows yields (index, row) pairs).
    for idx, col in tqdm(scalability.iterrows(), total=len(scalability)):
        # if idx not in [406, 516]: continue
        torch.cuda.empty_cache()
        G = load_mtx(col['file'])
        # Materialise the self-loop edges first so the graph is not mutated
        # while the edge generator is still being consumed.
        G.remove_edges_from(list(nx.selfloop_edges(G)))
        data = generate_data_list(
            G,
            sparse=data_config.sparse,
            pivot_mode=data_config.pivot,
            init_mode=data_config.init,
            edge_index=data_config.edge.index,
            edge_attr=data_config.edge.attr,
            pmds_list=pmds_list[idx],
            gviz_list=gviz_list[idx],
            device=device,
        )
        batch = Batch.from_data_list([data]).to(device)
        # Optional train-mode forward pass; eval() is kept inside the loop so
        # it still takes effect when the two lines below are re-enabled.
        # generator.train()
        # generator(batch)
        generator.eval()
        pred = generator(batch)
        pos = rotate(rescale(pred, batch), batch)
        gt = rotate(rescale(batch.gt_pos, batch), batch)
        stress = stressfn(pos, batch).item()
        gt_stress = stressfn(gt, batch).item()
        # Signed difference relative to the larger of the two stress values.
        spc = (stress - gt_stress) / np.maximum(stress, gt_stress)
        stress_list.append(stress)
        spc_list.append(spc)

        np.save(f"/__artifacts__/gan_result/data/scalability_{idx}.npy", pos.cpu().numpy())
        graph_attr = dict(
            node_size=1,
            with_labels=False,
            # Integer node ids are kept; any other id drops its first character.
            labels={n: n if type(n) is int else n[1:] for n in G.nodes},
            font_color="white",
            font_weight="bold",
            font_size=12,
            width=0.1,
        )

        # gt_pos = pickle.load(open(f"/__artifacts__/data/scalability_{idx}_gt.pkl", "rb"))

        # Reference layout.
        plt.figure(figsize=[12, 9])
        nx.draw(G, pos=gt.cpu().numpy(), node_color='orange', **graph_attr)
        plt.title(f"neato: large_{idx}")
        plt.axis("equal")
        plt.savefig(f"/__artifacts__/gan_result/output/{idx}_{col['name']}_{col['n']}_{spc}_nx.png", dpi=300)
        plt.show()
        # Release the figure explicitly; two figures are created per graph and
        # would otherwise pile up on backends that do not close them on show().
        plt.close()

        # Predicted layout.
        plt.figure(figsize=[12, 9])
        graph_vis(G, pos.cpu().numpy(), **graph_attr)
        plt.title(f"dgd: large_{idx}, spc={spc:.2%}")
        plt.axis("equal")
        plt.savefig(f"/__artifacts__/gan_result/output/{idx}_{col['name']}_{col['n']}_{spc}_dgd.png", dpi=300)
        plt.show()
        plt.close()