In [None]:
from ray.rllib.agents import ppo
import ray
import torch as th
from ray.tune.logger import pretty_print
import nclustRL
from nclustRL.utils.helper import transform_obs
from ray.tune import tune, grid_search
from gym.wrappers import TransformObservation
import nclustenv

In [None]:
try:
    ray.init()
except RuntimeError:
    passgrid_searchdd

In [None]:
env = nclustenv.make('BiclusterEnv-v0')

In [None]:
env2 = TransformObservation(env, transform_obs)
obs = env.reset()['state']
obs_flat = env2.reset()['state']

In [None]:
obs.ndata

In [None]:
obs_flat.ndata

In [5]:
import dgl

def transform_obs(n, obs):

    nclusters = n

    state = obs.clone()
    ntypes = state.ntypes

    for n, axis in enumerate(ntypes):
        for i in range(nclusters):
            state.nodes[axis].data[i] = torch.randint(0, 2, (len(state.nodes(axis)),), dtype=torch.bool).to('cuda:0')

    keys = sorted(list(state.nodes[ntypes[0]].data.keys()))
    ndata = {}

    for ntype in ntypes:
        ndata[ntype] = torch.vstack(
            [state.ndata[key][ntype].float() for key in keys]
        ).transpose(0, 1).to('cuda:0')

        state.nodes[ntype].data.clear()
    state.ndata['feat'] = ndata

    return state

In [134]:
import dgl
import dgl.nn.pytorch as dglnn
import torch.nn as nn
from torch.nn import functional as F
from nclustRL.utils.helper import pairwise

class HeteroRelu(nn.ReLU):

    def __init__(self, inplace:bool = False):
        super(HeteroRelu, self).__init__(inplace=inplace)

    def forward(self, inputs):
        
        return {k: super(HeteroRelu, self).forward(v) for k, v in inputs.items()}

class GraphSequential(nn.Sequential):

    def __init__(self, *args):
        super(GraphSequential, self).__init__(*args)

    def forward(self, graph, feat, edge_weight=None):
        for module in self:

            if isinstance(module, dglnn.HeteroGraphConv):

                rel_names = zip(module.mods.keys(), graph.canonical_etypes)
                feat = module(
                    g=graph, 
                    inputs=feat, 
                    mod_kwargs={
                        rel: dict(edge_weight=graph.edges[canonical].data[edge_weight]) 
                        for rel, canonical in rel_names})

            else:
                feat = module(inputs=feat)

        return feat


class RGCN(nn.Module):
    def __init__(self, layers, rel_names):
        super().__init__()

        _layers = []

        for in_feats, out_feats in pairwise(layers): 

            _layers.append(dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(in_feats, out_feats)
                for rel in rel_names}, aggregate='sum'))

            _layers.append(HeteroRelu())

        self._hidden_layers = GraphSequential(*_layers)

            

    def forward(self, graph, feat, edge_weight=None):

        return self._hidden_layers(graph, feat, edge_weight)


class GraphEncoder(nn.Module):
    def __init__(self, n, conv_feats, n_classes, rel_names):
        super().__init__()

        conv_feats.insert(0, n)
        self.rgcn = RGCN(conv_feats, rel_names)

    def forward(self, g):
        h = g.ndata['feat']
        h = self.rgcn(g, h, 'w')
        with g.local_scope():
            g.ndata['h'] = h
            hg = 0
            for ntype in h.keys():
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            
            return hg


class HeteroClassifier(nn.Module):
    def __init__(self, n, conv_feats, n_classes, rel_names):
        super().__init__()

        conv_feats.insert(0, n)

        self.rgcn = RGCN(conv_feats, rel_names)
        self.classify = nn.Linear(conv_feats[-1], n_classes)

    def forward(self, g):
        h = g.ndata['feat']
        h = self.rgcn(g, h, 'w')
        with g.local_scope():
            g.ndata['h'] = h
            hg = 0
            for ntype in h.keys():
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)

            return self.classify(hg)

In [135]:
import torch
from torch.nn import functional as F
from tqdm import tqdm
from dgl.dataloading import GraphDataLoader

dgl.seed(5)

def test_embedings(graphs):

    batch_size=1
    shuffle=True
    nclasses = 5
    n = 5

    # dataloader = GraphDataLoader(
    #     base,
    #     batch_size=batch_size,
    #     drop_last=False,
    #     shuffle=shuffle)

    etypes = graphs[0].etypes

    model = HeteroClassifier(n, [n*2], nclasses, etypes)
    model = model.cuda()
    opt = torch.optim.Adam(model.parameters())


    for epoch in range(20):
        with tqdm(graphs, unit="batch") as tepoch:
            for batched_graph in tepoch:

                tepoch.set_description(f"Epoch {epoch}")

                # batched_graph = transform_obs(n, batched_graph)
                labels = torch.randint(0, 4, (batch_size,)).to('cuda:0')

                logits = model(batched_graph)
                loss = F.cross_entropy(logits, labels)

                predictions = logits.argmax(dim=1, keepdim=True).squeeze()
                correct = (logits == labels).sum().item()

                opt.zero_grad()
                loss.backward()
                opt.step()

                accuracy = correct / batch_size
                tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)

In [136]:
test_embedings(graphs)

Epoch 0: 100%|██████████| 10/10 [00:00<00:00, 210.81batch/s, accuracy=0, loss=1.65]
Epoch 1: 100%|██████████| 10/10 [00:00<00:00, 219.78batch/s, accuracy=0, loss=1.36]
Epoch 2: 100%|██████████| 10/10 [00:00<00:00, 152.42batch/s, accuracy=0, loss=1.49]
Epoch 3: 100%|██████████| 10/10 [00:00<00:00, 181.30batch/s, accuracy=0, loss=1.57]
Epoch 4: 100%|██████████| 10/10 [00:00<00:00, 114.72batch/s, accuracy=0, loss=1.48]
Epoch 5: 100%|██████████| 10/10 [00:00<00:00, 131.34batch/s, accuracy=0, loss=1.32]
Epoch 6: 100%|██████████| 10/10 [00:00<00:00, 115.82batch/s, accuracy=0, loss=1.74]
Epoch 7: 100%|██████████| 10/10 [00:00<00:00, 142.62batch/s, accuracy=0, loss=1.69]
Epoch 8: 100%|██████████| 10/10 [00:00<00:00, 198.99batch/s, accuracy=0, loss=1.33]
Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 186.54batch/s, accuracy=0, loss=1.63]
Epoch 10: 100%|██████████| 10/10 [00:00<00:00, 192.31batch/s, accuracy=0, loss=1.62]
Epoch 11: 100%|██████████| 10/10 [00:00<00:00, 220.83batch/s, accuracy=0, l

In [None]:
tensor([[0.0000, 1.4752, 3.0812, 1.3285, 0.4708, 0.0000, 0.9958, 0.0000, 2.8396,
         2.7786]], device='cuda:0', grad_fn=<AddBackward0>)

In [97]:
    import torch as th
    import dgl
    def loader(cls, module=None):

        return getattr(module, cls) if isinstance(cls, str) else cls
    
    def dense_to_dgl(x, device, cuda=0, nclusters=1, clust_init='zeros', duplicate=True):

        # set (u,v)
        clust_init = loader(th, clust_init)

        tensor = th.tensor([[i, j, elem] for i, row in enumerate(x) for j, elem in enumerate(row)]).T

        if duplicate:

            graph_data = {
                ('row', 'elem', 'col'): (tensor[0].int(), tensor[1].int()),
                ('col', 'elem', 'row'): (tensor[1].int().detach().clone(), tensor[2].int().detach().clone()),
                }

            # create graph
            G = dgl.heterograph(graph_data)

            # set weights
            G.edges[('row', 'elem', 'col')].data['w'] = tensor[2].float()
            G.edges[('col', 'elem', 'row')].data['w'] = tensor[2].float()

        else:

            graph_data = {
                ('row', 'elem', 'col'): (tensor[0].int(), tensor[1].int()),
                }

            # create graph
            G = dgl.heterograph(graph_data)

            # set weights
            G.edges[('row', 'elem', 'col')].data['w'] = tensor[2].float()

        # set cluster members

        for n, axis in enumerate(['row', 'col']):
            for i in range(nclusters):
                G.nodes[axis].data[i] = th.randint(0, 2, (x.shape[n],), dtype=torch.bool)

        ndata = {}
        ntypes = G.ntypes
        keys = sorted(list(G.nodes[ntypes[0]].data.keys()))

        for ntype in ntypes:
            ndata[ntype] = torch.vstack(
                [G.ndata[key][ntype].float() for key in keys]
            ).transpose(0, 1)

            G.nodes[ntype].data.clear()

        G.ndata['feat'] = ndata

        if device == 'gpu':
            G = G.to('cuda:{}'.format(cuda))

        return G

In [130]:
import nclustenv
import torch
env = nclustenv.make('BiclusterEnv-v0', **dict(shape=[[100, 10], [110, 15]], clusters=[5,5]))

graphs_dup = []
graphs = []
for i in range(10):
    env.reset()
    X = env.state._generator.X
    graphs_dup.append(dense_to_dgl(X, device='gpu', nclusters=5))
    graphs.append(dense_to_dgl(X, device='gpu', nclusters=5, duplicate=False))

In [132]:
graphs_dup[0]

Graph(num_nodes={'col': 14, 'row': 104},
      num_edges={('col', 'elem', 'row'): 1456, ('row', 'elem', 'col'): 1456},
      metagraph=[('col', 'row', 'elem'), ('row', 'col', 'elem')])

In [133]:
graphs[0]

Graph(num_nodes={'col': 14, 'row': 104},
      num_edges={('row', 'elem', 'col'): 1456},
      metagraph=[('row', 'col', 'elem')])