## Test Observation Wrapper

In [2]:
from gym.wrappers import TransformObservation
from nclustRL.utils.helper import transform_obs
import nclustenv

In [3]:
env = nclustenv.make('BiclusterEnv-v0')

In [4]:
env2 = TransformObservation(env, transform_obs)
obs = env.reset()['state']
obs_flat = env2.reset()['state']

In [None]:
obs.ndata

In [None]:
obs_flat.ndata

In [5]:
import dgl

def transform_obs(n, obs):

    nclusters = n

    state = obs.clone()
    ntypes = state.ntypes

    for n, axis in enumerate(ntypes):
        for i in range(nclusters):
            state.nodes[axis].data[i] = torch.randint(0, 2, (len(state.nodes(axis)),), dtype=torch.bool).to('cuda:0')

    keys = sorted(list(state.nodes[ntypes[0]].data.keys()))
    ndata = {}

    for ntype in ntypes:
        ndata[ntype] = torch.vstack(
            [state.ndata[key][ntype].float() for key in keys]
        ).transpose(0, 1).to('cuda:0')

        state.nodes[ntype].data.clear()
    state.ndata['feat'] = ndata

    return state

## Test Embedings

### Explicit

In [4]:
import dgl
import dgl.nn.pytorch as dglnn
import torch.nn as nn
from torch.nn import functional as F
from nclustRL.utils.helper import pairwise

class HeteroRelu(nn.ReLU):

    def __init__(self, inplace:bool = False):
        super(HeteroRelu, self).__init__(inplace=inplace)

    def forward(self, inputs):
        
        return {k: super(HeteroRelu, self).forward(v) for k, v in inputs.items()}

class GraphSequential(nn.Sequential):

    def __init__(self, *args):
        super(GraphSequential, self).__init__(*args)

    def forward(self, graph, feat, edge_weight=None):
        for module in self:

            if isinstance(module, dglnn.HeteroGraphConv):

                rel_names = zip(module.mods.keys(), graph.canonical_etypes)
                feat = module(
                    g=graph, 
                    inputs=feat, 
                    mod_kwargs={
                        rel: dict(edge_weight=graph.edges[canonical].data[edge_weight]) 
                        for rel, canonical in rel_names})

            else:
                feat = module(inputs=feat)

        return feat


class RGCN(nn.Module):
    def __init__(self, layers, rel_names):
        super().__init__()

        _layers = []

        for in_feats, out_feats in pairwise(layers): 

            _layers.append(dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(in_feats, out_feats)
                for rel in rel_names}, aggregate='sum'))

            _layers.append(HeteroRelu())

        self._hidden_layers = GraphSequential(*_layers)

            

    def forward(self, graph, feat, edge_weight=None):

        return self._hidden_layers(graph, feat, edge_weight)


class GraphEncoder(nn.Module):
    def __init__(self, n, conv_feats, n_classes, rel_names):
        super().__init__()

        conv_feats.insert(0, n)
        self.rgcn = RGCN(conv_feats, rel_names)

    def forward(self, g):
        h = g.ndata['feat']
        h = self.rgcn(g, h, 'w')
        with g.local_scope():
            g.ndata['h'] = h
            hg = 0
            for ntype in h.keys():
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            
            return hg


class HeteroClassifier(nn.Module):
    def __init__(self, n, conv_feats, n_classes, rel_names):
        super().__init__()

        conv_feats.insert(0, n)

        self.rgcn = RGCN(conv_feats, rel_names)
        self.classify = nn.Linear(conv_feats[-1], n_classes)

    def forward(self, g):
        h = g.ndata['feat']
        h = self.rgcn(g, h, 'w')
        print(h['col'])
        with g.local_scope():
            g.ndata['h'] = h
            hg = 0
            for ntype in h.keys():
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)

            return self.classify(hg)

In [5]:
import torch
from torch.nn import functional as F
from tqdm import tqdm
from dgl.dataloading import GraphDataLoader

dgl.seed(5)

def test_embedings(graphs):

    batch_size=1
    shuffle=True
    nclasses = 5
    n = 5

    # dataloader = GraphDataLoader(
    #     base,
    #     batch_size=batch_size,
    #     drop_last=False,
    #     shuffle=shuffle)

    etypes = graphs[0].etypes

    model = HeteroClassifier(n, [n*2], nclasses, etypes)
    model = model.cuda()
    opt = torch.optim.Adam(model.parameters())


    for epoch in range(20):
        with tqdm(graphs, unit="batch") as tepoch:
            for batched_graph in tepoch:

                tepoch.set_description(f"Epoch {epoch}")

                # batched_graph = transform_obs(n, batched_graph)
                labels = torch.randint(0, 4, (batch_size,)).to('cuda:0')

                logits = model(batched_graph)
                loss = F.cross_entropy(logits, labels)

                predictions = logits.argmax(dim=1, keepdim=True).squeeze()
                correct = (logits == labels).sum().item()

                opt.zero_grad()
                loss.backward()
                opt.step()

                accuracy = correct / batch_size
                tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)

In [22]:
test_embedings(graphs)

Epoch 0: 100%|███████| 10/10 [00:00<00:00, 104.07batch/s, accuracy=0, loss=1.68]


tensor([[0.0000, 0.5451, 0.0000, 1.2836, 0.0000, 0.0000, 0.0000, 0.0000, 1.7132,
         0.0000],
        [0.0903, 0.7741, 0.0000, 0.7793, 0.0000, 1.3730, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.5813, 0.0000, 3.5214, 0.9626, 0.0000, 1.2651, 1.6148, 2.7820, 0.0000,
         1.0708],
        [0.0000, 0.4411, 0.0000, 0.4172, 0.1244, 0.0000, 0.5478, 0.0000, 0.5815,
         0.0630],
        [0.9421, 0.0000, 1.8847, 0.0000, 0.0793, 0.2830, 0.8410, 2.1577, 0.0000,
         0.6675],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.2982, 0.0000, 0.0000, 0.2387, 0.4370,
         0.0000],
        [0.6426, 0.0000, 1.9349, 2.3578, 0.0000, 0.9353, 2.2055, 0.8362, 0.0000,
         0.9592],
        [0.0000, 0.6578, 0.0000, 0.0000, 0.4252, 0.0000, 0.0000, 0.0000, 0.4283,
         0.0000],
        [0.0000, 0.2671, 0.0000, 0.0000, 0.3966, 0.3184, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.1446, 0.0000, 0.0000, 0.4191, 0.0000, 0.0000, 0.0000, 0.4445,
         0.0000],
        [0

Epoch 1: 100%|███████| 10/10 [00:00<00:00, 107.98batch/s, accuracy=0, loss=1.48]


tensor([[0.0000, 0.5748, 0.0000, 1.3011, 0.0000, 0.0000, 0.0000, 0.0000, 1.7433,
         0.0000],
        [0.0868, 0.8066, 0.0000, 0.7856, 0.0000, 1.3835, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.5821, 0.0000, 3.4950, 0.9522, 0.0000, 1.2890, 1.5819, 2.8010, 0.0000,
         1.0632],
        [0.0000, 0.4471, 0.0000, 0.4168, 0.1298, 0.0000, 0.5394, 0.0000, 0.5896,
         0.0594],
        [0.9431, 0.0000, 1.8668, 0.0000, 0.0750, 0.2934, 0.8191, 2.1710, 0.0000,
         0.6642],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3069, 0.0000, 0.0000, 0.2423, 0.4420,
         0.0000],
        [0.6439, 0.0000, 1.9162, 2.3539, 0.0000, 0.9449, 2.1766, 0.8506, 0.0000,
         0.9477],
        [0.0000, 0.6748, 0.0000, 0.0000, 0.4399, 0.0000, 0.0000, 0.0000, 0.4450,
         0.0000],
        [0.0000, 0.2897, 0.0000, 0.0000, 0.4171, 0.3299, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.1689, 0.0000, 0.0000, 0.4357, 0.0000, 0.0000, 0.0000, 0.4647,
         0.0000],
        [0

Epoch 2:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.97]

tensor([[0.0000e+00, 5.7571e-01, 0.0000e+00, 1.2938e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.7697e+00, 0.0000e+00],
        [9.2844e-02, 8.0906e-01, 0.0000e+00, 7.8053e-01, 0.0000e+00, 1.3600e+00,
         6.1334e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.5419e+00, 0.0000e+00, 3.4605e+00, 9.5395e-01, 0.0000e+00, 1.3121e+00,
         1.5441e+00, 2.8246e+00, 0.0000e+00, 1.0304e+00],
        [0.0000e+00, 4.4499e-01, 0.0000e+00, 4.0728e-01, 1.3510e-01, 0.0000e+00,
         5.3436e-01, 0.0000e+00, 5.9596e-01, 5.3222e-02],
        [9.1542e-01, 0.0000e+00, 1.8430e+00, 0.0000e+00, 7.9743e-02, 3.1101e-01,
         7.9347e-01, 2.1876e+00, 0.0000e+00, 6.4217e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.1328e-01, 0.0000e+00,
         0.0000e+00, 2.4623e-01, 4.4702e-01, 0.0000e+00],
        [6.1698e-01, 0.0000e+00, 1.8934e+00, 2.3446e+00, 0.0000e+00, 9.4926e-01,
         2.1495e+00, 8.6643e-01, 0.0000e+00, 9.2179e-01],
        [0.0000e+00, 6.7618

Epoch 2: 100%|███████| 10/10 [00:00<00:00, 123.56batch/s, accuracy=0, loss=1.86]


tensor([[0.4544, 0.0000, 1.2457, 0.0000, 0.4017, 0.0000, 0.0000, 1.2276, 0.0218,
         0.0000],
        [1.2215, 0.8133, 0.2727, 0.2253, 0.3151, 0.1347, 0.9934, 0.0000, 0.0000,
         0.3006],
        [1.4112, 0.0000, 2.0099, 1.4137, 0.0000, 1.8318, 1.7613, 1.0789, 0.0000,
         1.0062],
        [0.0000, 0.0000, 0.0000, 0.4612, 0.0000, 0.6829, 0.0000, 0.0000, 0.7291,
         0.0000],
        [0.8372, 0.0000, 2.3556, 0.0000, 0.0000, 1.0642, 0.0000, 1.6608, 0.0000,
         0.2464],
        [0.0000, 0.5772, 0.0000, 0.0000, 0.0000, 0.3792, 0.0000, 0.0000, 0.3955,
         0.0000],
        [1.0190, 0.2439, 0.6875, 0.0000, 0.0443, 0.8716, 0.1489, 0.1917, 0.0000,
         0.1578],
        [0.8277, 0.0000, 1.5838, 0.7144, 0.0000, 0.0918, 1.2197, 1.3900, 0.0000,
         0.5630],
        [0.5388, 0.5781, 0.2006, 0.6306, 0.0000, 1.1697, 0.4644, 0.0000, 0.0000,
         0.2307],
        [1.5553, 0.4895, 0.3625, 0.0000, 1.5434, 0.0000, 0.0000, 1.0115, 0.0000,
         0.0000]], device='c

Epoch 3:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.73]

tensor([[0.0000, 0.5718, 0.0000, 1.2738, 0.0000, 0.0000, 0.0000, 0.0000, 1.7750,
         0.0000],
        [0.1022, 0.8057, 0.0000, 0.7643, 0.0000, 1.3330, 0.0040, 0.0000, 0.0000,
         0.0000],
        [1.4848, 0.0000, 3.4505, 0.9486, 0.0000, 1.3173, 1.5344, 2.8055, 0.0000,
         1.0023],
        [0.0000, 0.4438, 0.0000, 0.3952, 0.1361, 0.0000, 0.5347, 0.0000, 0.5988,
         0.0473],
        [0.8759, 0.0000, 1.8351, 0.0000, 0.0854, 0.3185, 0.7866, 2.1741, 0.0000,
         0.6209],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3149, 0.0000, 0.0000, 0.2428, 0.4480,
         0.0000],
        [0.5798, 0.0000, 1.8870, 2.3236, 0.0000, 0.9460, 2.1449, 0.8550, 0.0000,
         0.9018],
        [0.0000, 0.6749, 0.0000, 0.0000, 0.4427, 0.0000, 0.0000, 0.0000, 0.4609,
         0.0000],
        [0.0000, 0.2912, 0.0000, 0.0000, 0.4191, 0.3026, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.1652, 0.0000, 0.0000, 0.4394, 0.0000, 0.0000, 0.0000, 0.4855,
         0.0000],
        [0

Epoch 3: 100%|███████| 10/10 [00:00<00:00, 115.09batch/s, accuracy=0, loss=1.82]


tensor([[0.5386, 0.0000, 1.7454, 1.6114, 0.0000, 1.2699, 1.7705, 1.3878, 0.0000,
         1.0186],
        [0.0000, 1.0312, 0.0000, 0.0000, 0.3330, 0.0975, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.5571, 0.0000, 0.0000, 0.0000, 0.6310, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.6917, 0.0000, 0.0000, 0.2464, 0.0000, 0.0000, 0.0000, 1.4040,
         0.0000],
        [0.0000, 2.5499, 0.0000, 0.0213, 0.0000, 0.1432, 0.0000, 0.0000, 0.6889,
         0.0000],
        [0.0000, 0.0000, 1.0480, 0.6964, 0.0000, 1.3297, 0.0000, 0.1129, 0.0000,
         0.0992],
        [0.0000, 1.0382, 0.0000, 0.0000, 0.5386, 0.0000, 0.0000, 0.0000, 0.6239,
         0.0000],
        [0.0000, 1.5921, 0.0000, 0.2284, 0.1168, 0.0000, 0.7524, 0.0000, 0.0737,
         0.2757],
        [0.5786, 0.3371, 0.3125, 0.2544, 0.0000, 1.1834, 0.2502, 0.0000, 0.0000,
         0.2361],
        [0.0000, 0.9730, 0.0000, 0.7731, 0.0000, 0.4135, 0.0000, 0.0000, 1.1384,
         0.0000],
        [0

Epoch 4:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.74]

tensor([[0.0000, 0.5859, 0.0000, 1.2626, 0.0000, 0.0000, 0.0000, 0.0000, 1.7885,
         0.0000],
        [0.1070, 0.8137, 0.0000, 0.7540, 0.0000, 1.3221, 0.0101, 0.0000, 0.0000,
         0.0000],
        [1.4377, 0.0000, 3.4162, 0.9282, 0.0000, 1.3054, 1.5179, 2.7937, 0.0000,
         0.9603],
        [0.0000, 0.4496, 0.0000, 0.3878, 0.1394, 0.0000, 0.5349, 0.0000, 0.6005,
         0.0437],
        [0.8440, 0.0000, 1.8113, 0.0000, 0.0862, 0.3132, 0.7737, 2.1657, 0.0000,
         0.5921],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3190, 0.0000, 0.0000, 0.2401, 0.4471,
         0.0000],
        [0.5493, 0.0000, 1.8654, 2.2974, 0.0000, 0.9407, 2.1379, 0.8473, 0.0000,
         0.8744],
        [0.0000, 0.6817, 0.0000, 0.0000, 0.4485, 0.0000, 0.0000, 0.0000, 0.4680,
         0.0000],
        [0.0000, 0.2983, 0.0000, 0.0000, 0.4246, 0.2926, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.1776, 0.0000, 0.0000, 0.4448, 0.0000, 0.0000, 0.0000, 0.4999,
         0.0000],
        [0

Epoch 4:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.57]

tensor([[0.5145, 0.0000, 1.7280, 1.5904, 0.0000, 1.2603, 1.7609, 1.3820, 0.0000,
         0.9977],
        [0.0000, 1.0522, 0.0000, 0.0000, 0.3422, 0.0972, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.5827, 0.0000, 0.0000, 0.0000, 0.6312, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.7213, 0.0000, 0.0000, 0.2584, 0.0000, 0.0000, 0.0000, 1.4267,
         0.0000],
        [0.0000, 2.5979, 0.0000, 0.0344, 0.0000, 0.1568, 0.0000, 0.0000, 0.7310,
         0.0000],
        [0.0000, 0.0000, 1.0432, 0.6856, 0.0000, 1.3189, 0.0000, 0.1121, 0.0000,
         0.0927],
        [0.0000, 1.0579, 0.0000, 0.0000, 0.5482, 0.0000, 0.0000, 0.0000, 0.6391,
         0.0000],
        [0.0000, 1.6154, 0.0000, 0.2303, 0.1216, 0.0000, 0.7554, 0.0000, 0.0904,
         0.2818],
        [0.5752, 0.3445, 0.3120, 0.2495, 0.0000, 1.1766, 0.2499, 0.0000, 0.0000,
         0.2325],
        [0.0000, 1.0072, 0.0000, 0.7756, 0.0000, 0.4198, 0.0000, 0.0000, 1.1677,
         0.0000],
        [0

Epoch 4: 100%|████████| 10/10 [00:00<00:00, 124.61batch/s, accuracy=0, loss=1.8]


tensor([[0.0000, 1.8372, 0.0000, 0.0000, 0.2505, 0.0000, 0.0000, 0.0000, 1.4325,
         0.0000],
        [0.5811, 1.1713, 0.0000, 0.0000, 0.8909, 0.2045, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2553, 0.0000, 0.4716, 1.5197, 0.0000, 0.0677, 1.5311, 0.0000, 0.0981,
         0.4494],
        [0.0000, 0.1220, 0.0000, 0.4109, 0.0000, 0.0000, 0.0000, 0.0000, 0.7276,
         0.0000],
        [0.0000, 1.1594, 0.0000, 1.0596, 0.0000, 0.1496, 0.0000, 0.0000, 0.9830,
         0.0000],
        [1.5868, 0.0000, 1.6465, 0.1716, 0.3859, 0.1234, 1.4652, 1.6777, 0.0000,
         0.6668],
        [0.0000, 1.6765, 0.0000, 0.0000, 0.6797, 0.0000, 0.0000, 0.0000, 1.4519,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.3214, 0.0000, 0.0000, 0.8253, 0.9818,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0621, 0.0000, 0.0000, 0.0565, 0.2197,
         0.0000],
        [0.0000, 0.6842, 0.0000, 0.0000, 0.4551, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0

Epoch 5:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.77]

tensor([[0.0000, 0.6116, 0.0000, 1.2515, 0.0000, 0.0000, 0.0000, 0.0000, 1.7878,
         0.0000],
        [0.1107, 0.8314, 0.0000, 0.7449, 0.0000, 1.3159, 0.0111, 0.0000, 0.0000,
         0.0000],
        [1.4010, 0.0000, 3.4024, 0.9088, 0.0000, 1.2870, 1.5167, 2.7605, 0.0000,
         0.9384],
        [0.0000, 0.4594, 0.0000, 0.3804, 0.1427, 0.0000, 0.5373, 0.0000, 0.5976,
         0.0429],
        [0.8192, 0.0000, 1.8016, 0.0000, 0.0882, 0.3034, 0.7726, 2.1433, 0.0000,
         0.5776],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3240, 0.0000, 0.0000, 0.2351, 0.4452,
         0.0000],
        [0.5256, 0.0000, 1.8567, 2.2726, 0.0000, 0.9347, 2.1406, 0.8256, 0.0000,
         0.8610],
        [0.0000, 0.6950, 0.0000, 0.0000, 0.4546, 0.0000, 0.0000, 0.0000, 0.4679,
         0.0000],
        [0.0000, 0.3136, 0.0000, 0.0000, 0.4313, 0.2836, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.2009, 0.0000, 0.0000, 0.4507, 0.0000, 0.0000, 0.0000, 0.5024,
         0.0000],
        [0

Epoch 5:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.76]

tensor([[0.0000, 0.7868, 0.0000, 0.2185, 0.0000, 0.0000, 0.0000, 0.0000, 1.2939,
         0.0000],
        [0.0000, 3.3562, 0.0000, 0.9143, 0.0000, 0.4680, 0.0000, 0.0000, 1.0766,
         0.0000],
        [2.1106, 0.0000, 3.0871, 0.0000, 1.5160, 0.0000, 0.1870, 3.2593, 0.0000,
         0.0830],
        [0.5846, 0.0000, 3.8837, 0.7469, 0.0000, 0.0000, 0.9152, 3.3150, 0.0000,
         0.3776],
        [0.0000, 1.7798, 0.0000, 0.8575, 0.0000, 0.0000, 1.0642, 0.0000, 0.5510,
         0.3204],
        [1.2584, 0.1922, 0.3006, 0.0000, 0.7586, 0.0000, 1.1966, 0.9348, 0.0000,
         0.5072],
        [0.3583, 0.0000, 1.5398, 0.5820, 0.0000, 0.0000, 0.9210, 1.4415, 0.0357,
         0.3138],
        [0.0000, 0.1838, 0.0000, 1.2686, 0.0000, 0.1687, 1.7118, 0.0000, 0.0000,
         0.7791],
        [0.0000, 0.3262, 0.0000, 0.0000, 0.1020, 0.0000, 0.0000, 0.0000, 0.1251,
         0.0000],
        [1.1166, 0.3242, 0.8419, 0.0000, 1.0829, 0.0000, 0.0000, 0.4329, 0.0000,
         0.0000]], device='c

Epoch 5: 100%|███████| 10/10 [00:00<00:00, 108.23batch/s, accuracy=0, loss=1.63]


tensor([[1.4767, 0.0000, 1.2297, 0.2689, 0.3215, 0.1591, 1.3508, 0.9686, 0.0000,
         0.5094],
        [0.0000, 0.5010, 0.0000, 0.0000, 0.0698, 0.1370, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6719, 0.6325, 0.0000, 0.0000, 0.8319, 0.0334, 0.0000, 0.1096, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0250, 0.1591, 0.2012, 0.0000, 0.3494, 0.1383, 0.3354,
         0.0000],
        [0.0000, 0.5969, 0.0000, 0.1054, 0.0000, 0.0000, 0.0000, 0.0000, 0.8250,
         0.0000],
        [0.0000, 0.0000, 1.0700, 0.0820, 0.0000, 0.3376, 0.0000, 0.7296, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5033, 0.0000, 0.0000, 0.0592, 0.0000, 0.7700, 0.1181,
         0.0000],
        [0.0000, 0.0000, 0.1493, 0.4055, 0.0000, 0.1724, 0.0000, 0.1595, 0.4193,
         0.0000],
        [0.0000, 1.9553, 0.0000, 0.0000, 0.0000, 0.1046, 0.0000, 0.0000, 0.6717,
         0.0000],
        [0.2022, 0.4247, 0.0000, 0.0000, 0.8975, 0.0000, 0.0000, 0.2324, 0.0000,
         0.0000],
        [0

Epoch 6:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.84]

tensor([[0.0000, 0.6321, 0.0000, 1.2352, 0.0000, 0.0000, 0.0000, 0.0000, 1.7865,
         0.0000],
        [0.1164, 0.8481, 0.0000, 0.7318, 0.0000, 1.3143, 0.0099, 0.0000, 0.0000,
         0.0000],
        [1.3591, 0.0000, 3.3940, 0.8933, 0.0000, 1.2594, 1.5289, 2.7186, 0.0000,
         0.9169],
        [0.0000, 0.4672, 0.0000, 0.3707, 0.1445, 0.0000, 0.5418, 0.0000, 0.5976,
         0.0393],
        [0.7906, 0.0000, 1.7947, 0.0000, 0.0910, 0.2878, 0.7813, 2.1149, 0.0000,
         0.5622],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3275, 0.0000, 0.0000, 0.2292, 0.4467,
         0.0000],
        [0.4979, 0.0000, 1.8511, 2.2470, 0.0000, 0.9221, 2.1522, 0.7994, 0.0000,
         0.8460],
        [0.0000, 0.7057, 0.0000, 0.0000, 0.4572, 0.0000, 0.0000, 0.0000, 0.4682,
         0.0000],
        [0.0000, 0.3270, 0.0000, 0.0000, 0.4348, 0.2805, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.2206, 0.0000, 0.0000, 0.4533, 0.0000, 0.0000, 0.0000, 0.5016,
         0.0000],
        [0

Epoch 6:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.74]

tensor([[0.4722, 0.0000, 1.7123, 1.5518, 0.0000, 1.2451, 1.7689, 1.3451, 0.0000,
         0.9731],
        [0.0000, 1.1024, 0.0000, 0.0000, 0.3488, 0.1022, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.6452, 0.0000, 0.0000, 0.0000, 0.6390, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.7881, 0.0000, 0.0000, 0.2652, 0.0000, 0.0000, 0.0000, 1.4270,
         0.0000],
        [0.0000, 2.7077, 0.0000, 0.0149, 0.0000, 0.1968, 0.0000, 0.0000, 0.7252,
         0.0000],
        [0.0000, 0.0000, 1.0423, 0.6633, 0.0000, 1.3027, 0.0000, 0.1049, 0.0000,
         0.0883],
        [0.0000, 1.1019, 0.0000, 0.0000, 0.5573, 0.0000, 0.0000, 0.0000, 0.6399,
         0.0000],
        [0.0000, 1.6687, 0.0000, 0.2115, 0.1214, 0.0000, 0.7581, 0.0000, 0.0857,
         0.2927],
        [0.5720, 0.3667, 0.3163, 0.2337, 0.0000, 1.1690, 0.2469, 0.0000, 0.0000,
         0.2338],
        [0.0000, 1.0851, 0.0000, 0.7485, 0.0000, 0.4428, 0.0000, 0.0000, 1.1649,
         0.0000],
        [0

Epoch 6:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.75]

tensor([[1.4626, 0.0000, 1.2154, 0.2676, 0.3314, 0.1525, 1.3435, 0.9661, 0.0000,
         0.4950],
        [0.0000, 0.5163, 0.0000, 0.0000, 0.0727, 0.1401, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6730, 0.6461, 0.0000, 0.0000, 0.8399, 0.0332, 0.0000, 0.1135, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0177, 0.1550, 0.2077, 0.0000, 0.3485, 0.1366, 0.3385,
         0.0000],
        [0.0000, 0.6167, 0.0000, 0.1018, 0.0000, 0.0000, 0.0000, 0.0000, 0.8278,
         0.0000],
        [0.0000, 0.0000, 1.0601, 0.0785, 0.0000, 0.3334, 0.0000, 0.7285, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.4966, 0.0000, 0.0000, 0.0571, 0.0000, 0.7688, 0.1215,
         0.0000],
        [0.0000, 0.0000, 0.1449, 0.3963, 0.0000, 0.1739, 0.0000, 0.1586, 0.4222,
         0.0000],
        [0.0000, 2.0039, 0.0000, 0.0000, 0.0000, 0.1217, 0.0000, 0.0000, 0.6702,
         0.0000],
        [0.2008, 0.4326, 0.0000, 0.0000, 0.9055, 0.0000, 0.0000, 0.2330, 0.0000,
         0.0000],
        [0

Epoch 6: 100%|███████| 10/10 [00:00<00:00, 115.31batch/s, accuracy=0, loss=1.76]


tensor([[0.0000e+00, 4.2598e-01, 0.0000e+00, 1.5584e+00, 0.0000e+00, 2.4381e-01,
         1.3644e-01, 0.0000e+00, 1.1912e+00, 1.8471e-02],
        [0.0000e+00, 0.0000e+00, 1.3451e+00, 3.1920e-01, 1.5676e-03, 0.0000e+00,
         5.6861e-02, 9.5596e-01, 4.7990e-01, 0.0000e+00],
        [4.5057e-01, 5.8694e-01, 0.0000e+00, 1.2387e+00, 0.0000e+00, 2.9319e-01,
         1.7717e+00, 0.0000e+00, 0.0000e+00, 6.6650e-01],
        [0.0000e+00, 7.7033e-01, 0.0000e+00, 0.0000e+00, 1.2721e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 3.3631e-01, 0.0000e+00],
        [0.0000e+00, 4.5323e-01, 0.0000e+00, 0.0000e+00, 7.3623e-02, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 3.6932e-01, 0.0000e+00],
        [8.7978e-01, 0.0000e+00, 3.1597e+00, 2.5249e-01, 4.7471e-01, 0.0000e+00,
         1.5936e+00, 3.8477e+00, 0.0000e+00, 7.1318e-01],
        [0.0000e+00, 0.0000e+00, 1.2793e+00, 1.2751e+00, 0.0000e+00, 2.5485e-01,
         7.3113e-01, 9.2291e-01, 2.8259e-01, 3.1752e-01],
        [0.0000e+00, 9.8409

Epoch 7:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.73]

tensor([[0.0000e+00, 6.6013e-01, 0.0000e+00, 1.2263e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.7904e+00, 0.0000e+00],
        [1.2408e-01, 8.7317e-01, 0.0000e+00, 7.2763e-01, 0.0000e+00, 1.3187e+00,
         2.6610e-03, 0.0000e+00, 0.0000e+00, 8.4285e-04],
        [1.3322e+00, 0.0000e+00, 3.3628e+00, 8.9019e-01, 0.0000e+00, 1.2500e+00,
         1.5165e+00, 2.7219e+00, 0.0000e+00, 8.8684e-01],
        [0.0000e+00, 4.7272e-01, 0.0000e+00, 3.6337e-01, 1.4772e-01, 0.0000e+00,
         5.4209e-01, 0.0000e+00, 6.0017e-01, 3.6130e-02],
        [7.7122e-01, 0.0000e+00, 1.7734e+00, 0.0000e+00, 1.0194e-01, 2.8163e-01,
         7.7697e-01, 2.1154e+00, 0.0000e+00, 5.4108e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.3350e-01, 0.0000e+00,
         0.0000e+00, 2.2869e-01, 4.5014e-01, 0.0000e+00],
        [4.8151e-01, 0.0000e+00, 1.8316e+00, 2.2376e+00, 0.0000e+00, 9.1805e-01,
         2.1408e+00, 8.0178e-01, 0.0000e+00, 8.2695e-01],
        [0.0000e+00, 7.2057

Epoch 7:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.63]

tensor([[0.0000, 0.8241, 0.0000, 0.1962, 0.0000, 0.0000, 0.0000, 0.0000, 1.2949,
         0.0000],
        [0.0000, 3.4805, 0.0000, 0.8760, 0.0000, 0.5132, 0.0000, 0.0000, 1.0599,
         0.0000],
        [2.0325, 0.0000, 3.0443, 0.0000, 1.5512, 0.0000, 0.1842, 3.2271, 0.0000,
         0.0221],
        [0.4937, 0.0000, 3.8316, 0.7401, 0.0000, 0.0000, 0.9186, 3.2717, 0.0000,
         0.3055],
        [0.0000, 1.8304, 0.0000, 0.8247, 0.0000, 0.0000, 1.0702, 0.0000, 0.5452,
         0.3376],
        [1.2294, 0.1796, 0.2879, 0.0000, 0.7708, 0.0000, 1.1980, 0.9234, 0.0000,
         0.4862],
        [0.3103, 0.0000, 1.5131, 0.5693, 0.0045, 0.0000, 0.9239, 1.4193, 0.0467,
         0.2759],
        [0.0000, 0.1951, 0.0000, 1.2311, 0.0000, 0.1744, 1.7196, 0.0000, 0.0000,
         0.7741],
        [0.0000, 0.3462, 0.0000, 0.0000, 0.1069, 0.0000, 0.0000, 0.0000, 0.1261,
         0.0000],
        [1.0884, 0.3110, 0.8256, 0.0000, 1.1038, 0.0000, 0.0000, 0.4268, 0.0000,
         0.0000]], device='c

Epoch 7:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.59]

tensor([[1.2938, 0.0000, 0.9894, 0.2284, 0.2204, 0.3350, 1.2219, 0.8242, 0.0000,
         0.5070],
        [0.0000, 0.0000, 0.0000, 1.7193, 0.0000, 0.7717, 0.1323, 0.0000, 0.7987,
         0.0791],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1967, 1.0264,
         0.0000],
        [0.0000, 0.2095, 0.0000, 0.6970, 0.0000, 0.0000, 0.8147, 0.0000, 0.3582,
         0.1489],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.1612, 0.0000, 0.0000, 0.0000, 0.8345,
         0.0000],
        [0.3391, 0.2352, 0.0242, 0.2398, 0.0423, 0.0348, 0.6048, 0.0264, 0.0000,
         0.2054],
        [0.4004, 0.0000, 2.3042, 0.5258, 0.0000, 0.7773, 0.2492, 1.4051, 0.0000,
         0.1748],
        [0.0000, 0.0000, 0.0000, 0.7524, 0.0000, 0.4126, 0.0000, 0.0000, 0.4245,
         0.0000],
        [0.0000, 0.0000, 0.9383, 0.0000, 0.0000, 0.0000, 0.0000, 0.5665, 0.3474,
         0.0000],
        [0.8399, 0.0000, 1.2855, 0.4838, 0.1263, 0.0000, 0.7819, 0.5377, 0.0000,
         0.0848],
        [0

Epoch 7: 100%|███████| 10/10 [00:00<00:00, 111.85batch/s, accuracy=0, loss=1.59]
Epoch 8:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.45]

tensor([[0.0000, 0.6816, 0.0000, 1.2112, 0.0000, 0.0000, 0.0000, 0.0000, 1.7797,
         0.0000],
        [0.1322, 0.8907, 0.0000, 0.7173, 0.0000, 1.3160, 0.0000, 0.0000, 0.0000,
         0.0092],
        [1.2954, 0.0000, 3.3784, 0.8780, 0.0000, 1.2213, 1.5276, 2.6978, 0.0000,
         0.8821],
        [0.0000, 0.4771, 0.0000, 0.3532, 0.1476, 0.0000, 0.5457, 0.0000, 0.5988,
         0.0356],
        [0.7452, 0.0000, 1.7838, 0.0000, 0.1029, 0.2652, 0.7867, 2.0977, 0.0000,
         0.5359],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.2249, 0.4495,
         0.0000],
        [0.4580, 0.0000, 1.8419, 2.2155, 0.0000, 0.8958, 2.1488, 0.7867, 0.0000,
         0.8269],
        [0.0000, 0.7320, 0.0000, 0.0000, 0.4634, 0.0000, 0.0000, 0.0000, 0.4662,
         0.0000],
        [0.0000, 0.3679, 0.0000, 0.0000, 0.4393, 0.2895, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.2645, 0.0000, 0.0000, 0.4546, 0.0000, 0.0000, 0.0000, 0.4943,
         0.0000],
        [0

Epoch 8:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.44]

tensor([[0.0000e+00, 8.4577e-01, 0.0000e+00, 1.8527e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.2826e+00, 0.0000e+00],
        [0.0000e+00, 3.5480e+00, 0.0000e+00, 8.5986e-01, 0.0000e+00, 5.4083e-01,
         0.0000e+00, 0.0000e+00, 1.0289e+00, 0.0000e+00],
        [1.9951e+00, 0.0000e+00, 3.0678e+00, 0.0000e+00, 1.5520e+00, 0.0000e+00,
         1.9904e-01, 3.1943e+00, 0.0000e+00, 1.7786e-02],
        [4.5036e-01, 0.0000e+00, 3.8609e+00, 7.2929e-01, 0.0000e+00, 0.0000e+00,
         9.4000e-01, 3.2299e+00, 0.0000e+00, 3.0442e-01],
        [0.0000e+00, 1.8569e+00, 0.0000e+00, 8.0813e-01, 0.0000e+00, 0.0000e+00,
         1.0706e+00, 0.0000e+00, 5.3389e-01, 3.4203e-01],
        [1.2153e+00, 1.7116e-01, 2.9724e-01, 0.0000e+00, 7.7160e-01, 0.0000e+00,
         1.2064e+00, 9.1105e-01, 0.0000e+00, 4.8702e-01],
        [2.8720e-01, 0.0000e+00, 1.5285e+00, 5.5844e-01, 3.4593e-03, 0.0000e+00,
         9.3662e-01, 1.3975e+00, 5.7000e-02, 2.7662e-01],
        [0.0000e+00, 2.0065

Epoch 8: 100%|███████| 10/10 [00:00<00:00, 114.64batch/s, accuracy=0, loss=1.58]

tensor([[1.4204, 0.0000, 1.2307, 0.2544, 0.3358, 0.1297, 1.3546, 0.9450, 0.0000,
         0.4880],
        [0.0000, 0.5386, 0.0000, 0.0000, 0.0736, 0.1408, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6742, 0.6671, 0.0000, 0.0000, 0.8431, 0.0399, 0.0000, 0.1219, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0239, 0.1400, 0.2073, 0.0000, 0.3570, 0.1262, 0.3395,
         0.0000],
        [0.0000, 0.6462, 0.0000, 0.0894, 0.0000, 0.0000, 0.0000, 0.0000, 0.8115,
         0.0000],
        [0.0000, 0.0000, 1.0690, 0.0652, 0.0000, 0.3173, 0.0000, 0.7155, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5024, 0.0000, 0.0000, 0.0500, 0.0000, 0.7591, 0.1198,
         0.0000],
        [0.0000, 0.0000, 0.1496, 0.3734, 0.0000, 0.1646, 0.0000, 0.1515, 0.4174,
         0.0000],
        [0.0000, 2.0815, 0.0000, 0.0000, 0.0000, 0.1463, 0.0000, 0.0000, 0.6356,
         0.0000],
        [0.1963, 0.4460, 0.0000, 0.0000, 0.9059, 0.0000, 0.0000, 0.2333, 0.0000,
         0.0000],
        [0


Epoch 9:   0%|                  | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.6]

tensor([[0.0000, 0.7055, 0.0000, 1.1946, 0.0000, 0.0000, 0.0000, 0.0000, 1.7551,
         0.0000],
        [0.1388, 0.9091, 0.0000, 0.7065, 0.0000, 1.3133, 0.0000, 0.0000, 0.0000,
         0.0156],
        [1.2568, 0.0000, 3.4161, 0.8627, 0.0000, 1.1987, 1.5492, 2.6614, 0.0000,
         0.8832],
        [0.0000, 0.4807, 0.0000, 0.3415, 0.1445, 0.0000, 0.5518, 0.0000, 0.5928,
         0.0409],
        [0.7189, 0.0000, 1.8103, 0.0000, 0.0985, 0.2532, 0.8030, 2.0717, 0.0000,
         0.5353],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3292, 0.0000, 0.0000, 0.2188, 0.4445,
         0.0000],
        [0.4332, 0.0000, 1.8670, 2.1886, 0.0000, 0.8737, 2.1664, 0.7632, 0.0000,
         0.8361],
        [0.0000, 0.7443, 0.0000, 0.0000, 0.4614, 0.0000, 0.0000, 0.0000, 0.4544,
         0.0000],
        [0.0000, 0.3905, 0.0000, 0.0000, 0.4396, 0.2998, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.2863, 0.0000, 0.0000, 0.4538, 0.0000, 0.0000, 0.0000, 0.4781,
         0.0000],
        [0

Epoch 9:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.43]

tensor([[0.3127, 0.0000, 1.2187, 0.0000, 0.4250, 0.0000, 0.0000, 1.1395, 0.0203,
         0.0000],
        [1.1563, 0.8249, 0.2668, 0.1814, 0.3290, 0.1089, 0.9945, 0.0000, 0.0000,
         0.2761],
        [1.2586, 0.0000, 1.9987, 1.2977, 0.0000, 1.7406, 1.7657, 1.0030, 0.0000,
         0.9446],
        [0.0000, 0.0000, 0.0000, 0.3934, 0.0000, 0.6507, 0.0000, 0.0000, 0.7149,
         0.0000],
        [0.6775, 0.0000, 2.3336, 0.0000, 0.0000, 0.9753, 0.0000, 1.5717, 0.0000,
         0.1674],
        [0.0000, 0.7212, 0.0000, 0.0000, 0.0000, 0.3759, 0.0000, 0.0000, 0.3855,
         0.0000],
        [0.9725, 0.2649, 0.6867, 0.0000, 0.0563, 0.8232, 0.1404, 0.1830, 0.0000,
         0.1447],
        [0.6482, 0.0000, 1.5566, 0.6378, 0.0000, 0.0399, 1.2352, 1.2810, 0.0000,
         0.4709],
        [0.5365, 0.6423, 0.2093, 0.5596, 0.0000, 1.1263, 0.4601, 0.0000, 0.0000,
         0.2477],
        [1.4816, 0.4814, 0.3498, 0.0000, 1.5792, 0.0000, 0.0000, 0.9801, 0.0000,
         0.0000]], device='c

Epoch 9:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.48]

tensor([[0.1927, 2.0537, 0.0000, 0.0847, 0.2955, 0.0000, 0.8266, 0.0000, 0.0000,
         0.2888],
        [0.8605, 0.4403, 0.6581, 0.0000, 0.6032, 0.0000, 0.3661, 0.0045, 0.0179,
         0.0000],
        [0.0000, 0.9320, 0.0000, 1.6970, 0.0000, 0.0000, 1.5479, 0.0000, 0.6781,
         0.5109],
        [0.0000, 0.2570, 0.0000, 0.3988, 0.0000, 0.0000, 0.0000, 0.0000, 1.0288,
         0.0000],
        [0.0000, 2.7433, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0133,
         0.0000],
        [0.0000, 0.0000, 0.2853, 1.1973, 0.0000, 0.4244, 0.2928, 0.0000, 0.4631,
         0.1249],
        [0.4806, 0.0000, 2.4912, 0.0000, 0.7510, 0.0000, 0.0000, 2.4471, 0.0542,
         0.0000],
        [0.2940, 0.0000, 1.9701, 0.9586, 0.0000, 0.0000, 0.9334, 1.1217, 0.0116,
         0.2142],
        [0.0000, 0.0000, 0.6412, 0.3801, 0.0000, 0.0000, 0.1854, 0.2229, 0.2633,
         0.0000],
        [0.0000, 0.0000, 0.4649, 0.0000, 0.1957, 0.0000, 0.0000, 0.5554, 0.1353,
         0.0000],
        [0

Epoch 9:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.61]

tensor([[1.2648, 0.0000, 1.0200, 0.2109, 0.2201, 0.3190, 1.2400, 0.7963, 0.0000,
         0.5176],
        [0.0000, 0.0000, 0.0000, 1.6750, 0.0000, 0.7617, 0.1371, 0.0000, 0.7686,
         0.0970],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1881, 1.0036,
         0.0000],
        [0.0000, 0.2084, 0.0000, 0.6713, 0.0000, 0.0000, 0.8283, 0.0000, 0.3542,
         0.1614],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.1544, 0.0000, 0.0000, 0.0000, 0.8176,
         0.0000],
        [0.3278, 0.2384, 0.0359, 0.2216, 0.0409, 0.0265, 0.6153, 0.0161, 0.0000,
         0.2152],
        [0.3621, 0.0000, 2.3496, 0.5047, 0.0000, 0.7556, 0.2685, 1.3629, 0.0000,
         0.1816],
        [0.0000, 0.0000, 0.0000, 0.7261, 0.0000, 0.4055, 0.0000, 0.0000, 0.4037,
         0.0000],
        [0.0000, 0.0000, 0.9592, 0.0000, 0.0000, 0.0000, 0.0000, 0.5474, 0.3394,
         0.0000],
        [0.8067, 0.0000, 1.3197, 0.4653, 0.1232, 0.0000, 0.7994, 0.5079, 0.0000,
         0.0949],
        [0

Epoch 9: 100%|███████| 10/10 [00:00<00:00, 113.14batch/s, accuracy=0, loss=1.42]


tensor([[0.0000, 0.4992, 0.0000, 1.4947, 0.0000, 0.2387, 0.1413, 0.0000, 1.1457,
         0.0477],
        [0.0000, 0.0000, 1.3870, 0.2960, 0.0000, 0.0000, 0.0772, 0.9103, 0.4931,
         0.0000],
        [0.4337, 0.5917, 0.0000, 1.1871, 0.0000, 0.2711, 1.7912, 0.0000, 0.0000,
         0.6887],
        [0.0000, 0.8313, 0.0000, 0.0000, 0.1258, 0.0000, 0.0000, 0.0000, 0.2989,
         0.0000],
        [0.0000, 0.4958, 0.0000, 0.0000, 0.0724, 0.0000, 0.0000, 0.0000, 0.3388,
         0.0000],
        [0.7430, 0.0000, 3.2696, 0.2226, 0.4690, 0.0000, 1.6494, 3.7214, 0.0000,
         0.6856],
        [0.0000, 0.0000, 1.3234, 1.2226, 0.0000, 0.2237, 0.7583, 0.8737, 0.2972,
         0.3222],
        [0.0000, 1.0222, 0.0000, 0.0000, 0.5922, 0.0000, 0.0000, 0.0000, 0.7099,
         0.0000],
        [0.0000, 0.5840, 0.0000, 0.5039, 0.0000, 0.0000, 0.4367, 0.0000, 0.5759,
         0.0666],
        [0.1154, 1.1495, 0.0000, 0.9774, 0.0000, 0.4570, 1.0580, 0.0000, 0.0000,
         0.3800],
        [0

Epoch 10:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.85]

tensor([[0.0000, 0.7330, 0.0000, 1.1816, 0.0000, 0.0000, 0.0000, 0.0000, 1.7306,
         0.0000],
        [0.1405, 0.9317, 0.0000, 0.6952, 0.0000, 1.3020, 0.0000, 0.0000, 0.0000,
         0.0254],
        [1.2388, 0.0000, 3.4527, 0.8469, 0.0000, 1.2009, 1.5541, 2.6315, 0.0000,
         0.8904],
        [0.0000, 0.4856, 0.0000, 0.3344, 0.1432, 0.0000, 0.5549, 0.0000, 0.5891,
         0.0462],
        [0.7069, 0.0000, 1.8349, 0.0000, 0.1029, 0.2585, 0.8048, 2.0512, 0.0000,
         0.5383],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3299, 0.0000, 0.0000, 0.2149, 0.4413,
         0.0000],
        [0.4212, 0.0000, 1.8915, 2.1655, 0.0000, 0.8682, 2.1754, 0.7431, 0.0000,
         0.8496],
        [0.0000, 0.7590, 0.0000, 0.0000, 0.4612, 0.0000, 0.0000, 0.0000, 0.4419,
         0.0000],
        [0.0000, 0.4146, 0.0000, 0.0000, 0.4411, 0.2991, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.3103, 0.0000, 0.0000, 0.4522, 0.0000, 0.0000, 0.0000, 0.4633,
         0.0000],
        [0

Epoch 10:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.57]

tensor([[0.1956, 2.0773, 0.0000, 0.0816, 0.2907, 0.0000, 0.8340, 0.0000, 0.0000,
         0.2983],
        [0.8536, 0.4349, 0.6662, 0.0000, 0.6070, 0.0000, 0.3648, 0.0032, 0.0147,
         0.0000],
        [0.0000, 0.9469, 0.0000, 1.6838, 0.0000, 0.0000, 1.5565, 0.0000, 0.6752,
         0.5236],
        [0.0000, 0.2725, 0.0000, 0.3925, 0.0000, 0.0000, 0.0000, 0.0000, 1.0179,
         0.0000],
        [0.0000, 2.7946, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9910,
         0.0000],
        [0.0000, 0.0000, 0.2917, 1.1844, 0.0000, 0.4206, 0.2961, 0.0000, 0.4598,
         0.1312],
        [0.4675, 0.0000, 2.5152, 0.0000, 0.7644, 0.0000, 0.0000, 2.4348, 0.0641,
         0.0000],
        [0.2834, 0.0000, 1.9895, 0.9467, 0.0000, 0.0000, 0.9293, 1.1109, 0.0174,
         0.2169],
        [0.0000, 0.0000, 0.6500, 0.3727, 0.0000, 0.0000, 0.1847, 0.2186, 0.2616,
         0.0000],
        [0.0000, 0.0000, 0.4721, 0.0000, 0.2002, 0.0000, 0.0000, 0.5519, 0.1338,
         0.0000],
        [0

Epoch 10:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.56]

tensor([[0.4050, 0.0000, 1.7616, 1.4774, 0.0000, 1.2017, 1.7934, 1.2887, 0.0000,
         0.9807],
        [0.0000, 1.2101, 0.0000, 0.0000, 0.3487, 0.1200, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.7722, 0.0000, 0.0000, 0.0000, 0.6463, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.9213, 0.0000, 0.0000, 0.2564, 0.0000, 0.0000, 0.0000, 1.3562,
         0.0000],
        [0.0000, 2.9188, 0.0000, 0.0000, 0.0000, 0.2293, 0.0000, 0.0000, 0.6387,
         0.0000],
        [0.0000, 0.0000, 1.0552, 0.6220, 0.0000, 1.2765, 0.0000, 0.0965, 0.0000,
         0.1008],
        [0.0000, 1.1862, 0.0000, 0.0000, 0.5541, 0.0000, 0.0000, 0.0000, 0.6001,
         0.0000],
        [0.0000, 1.7540, 0.0000, 0.1736, 0.1135, 0.0000, 0.7700, 0.0000, 0.0583,
         0.3232],
        [0.5705, 0.4104, 0.3183, 0.2040, 0.0000, 1.1522, 0.2486, 0.0000, 0.0000,
         0.2533],
        [0.0000, 1.2469, 0.0000, 0.6970, 0.0000, 0.4544, 0.0000, 0.0000, 1.0873,
         0.0000],
        [0

Epoch 10: 100%|███████| 10/10 [00:00<00:00, 98.73batch/s, accuracy=0, loss=1.73]


tensor([[0.3911, 0.0000, 0.8203, 0.7361, 0.1624, 0.0000, 1.9418, 1.4888, 0.0000,
         0.7757],
        [0.0000, 0.0000, 0.0000, 0.4885, 0.0000, 1.2865, 0.0000, 0.0000, 0.1212,
         0.0000],
        [0.0000, 0.5615, 0.0000, 0.3560, 0.0000, 0.1167, 0.0000, 0.0000, 1.0353,
         0.0000],
        [0.2251, 0.0000, 1.2695, 0.8411, 0.0000, 0.2928, 0.7333, 0.5010, 0.0000,
         0.2152],
        [0.0000, 0.2635, 0.0000, 0.0000, 0.1018, 0.0000, 0.0000, 0.0000, 0.2855,
         0.0000],
        [0.8543, 0.0000, 3.8023, 0.1930, 0.0064, 0.1078, 0.7708, 3.2493, 0.0000,
         0.4023],
        [1.0804, 0.6990, 0.0000, 0.0061, 0.1161, 0.7210, 1.0543, 0.1659, 0.0000,
         0.6093],
        [0.0000, 0.3982, 0.0000, 0.0000, 0.0000, 1.2658, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.0107, 0.0000, 0.0000, 0.0288, 0.0000, 0.0000, 0.0000, 0.7847,
         0.0000],
        [0.0000, 0.6177, 0.0000, 0.7222, 0.0000, 0.6672, 0.3719, 0.0000, 0.0000,
         0.2516],
        [0

Epoch 11:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.58]

tensor([[0.0000, 0.7556, 0.0000, 1.1812, 0.0000, 0.0000, 0.0000, 0.0000, 1.7393,
         0.0000],
        [0.1376, 0.9500, 0.0000, 0.6926, 0.0000, 1.2967, 0.0116, 0.0000, 0.0000,
         0.0279],
        [1.2300, 0.0000, 3.4414, 0.8354, 0.0000, 1.1958, 1.5158, 2.6522, 0.0000,
         0.8783],
        [0.0000, 0.4910, 0.0000, 0.3325, 0.1473, 0.0000, 0.5537, 0.0000, 0.5932,
         0.0465],
        [0.7019, 0.0000, 1.8276, 0.0000, 0.1172, 0.2571, 0.7767, 2.0648, 0.0000,
         0.5309],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3386, 0.0000, 0.0000, 0.2178, 0.4444,
         0.0000],
        [0.4142, 0.0000, 1.8849, 2.1545, 0.0000, 0.8650, 2.1556, 0.7555, 0.0000,
         0.8432],
        [0.0000, 0.7704, 0.0000, 0.0000, 0.4683, 0.0000, 0.0000, 0.0000, 0.4462,
         0.0000],
        [0.0000, 0.4316, 0.0000, 0.0000, 0.4478, 0.2953, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.3295, 0.0000, 0.0000, 0.4547, 0.0000, 0.0000, 0.0000, 0.4706,
         0.0000],
        [0

Epoch 11:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.52]

tensor([[0.1946, 2.0980, 0.0000, 0.0830, 0.2868, 0.0000, 0.8469, 0.0000, 0.0000,
         0.3034],
        [0.8474, 0.4325, 0.6580, 0.0000, 0.6219, 0.0000, 0.3552, 0.0136, 0.0118,
         0.0000],
        [0.0000, 0.9612, 0.0000, 1.6808, 0.0000, 0.0000, 1.5655, 0.0000, 0.6834,
         0.5279],
        [0.0000, 0.2849, 0.0000, 0.3918, 0.0000, 0.0000, 0.0000, 0.0000, 1.0258,
         0.0000],
        [0.0000, 2.8357, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0106,
         0.0000],
        [0.0000, 0.0000, 0.2910, 1.1795, 0.0000, 0.4200, 0.2942, 0.0000, 0.4628,
         0.1312],
        [0.4618, 0.0000, 2.4980, 0.0000, 0.7932, 0.0000, 0.0000, 2.4557, 0.0541,
         0.0000],
        [0.2773, 0.0000, 1.9773, 0.9375, 0.0000, 0.0000, 0.9058, 1.1260, 0.0098,
         0.2058],
        [0.0000, 0.0000, 0.6446, 0.3682, 0.0000, 0.0000, 0.1756, 0.2255, 0.2610,
         0.0000],
        [0.0000, 0.0000, 0.4678, 0.0000, 0.2109, 0.0000, 0.0000, 0.5573, 0.1347,
         0.0000],
        [0

Epoch 11:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.61]

tensor([[0.0000e+00, 9.0339e-01, 0.0000e+00, 1.7322e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.2682e+00, 0.0000e+00],
        [0.0000e+00, 3.7367e+00, 0.0000e+00, 8.4633e-01, 0.0000e+00, 5.3655e-01,
         0.0000e+00, 0.0000e+00, 9.8851e-01, 0.0000e+00],
        [1.9283e+00, 0.0000e+00, 3.1084e+00, 0.0000e+00, 1.5988e+00, 0.0000e+00,
         1.5494e-01, 3.1713e+00, 0.0000e+00, 0.0000e+00],
        [3.7800e-01, 0.0000e+00, 3.9195e+00, 6.8948e-01, 0.0000e+00, 0.0000e+00,
         9.0044e-01, 3.1888e+00, 0.0000e+00, 2.7637e-01],
        [0.0000e+00, 1.9274e+00, 0.0000e+00, 7.8602e-01, 0.0000e+00, 0.0000e+00,
         1.0966e+00, 0.0000e+00, 5.2683e-01, 3.7092e-01],
        [1.1887e+00, 1.4542e-01, 3.1267e-01, 0.0000e+00, 7.8632e-01, 0.0000e+00,
         1.1992e+00, 9.0090e-01, 0.0000e+00, 4.8592e-01],
        [2.4788e-01, 0.0000e+00, 1.5593e+00, 5.2907e-01, 2.1235e-02, 0.0000e+00,
         9.2124e-01, 1.3760e+00, 6.9572e-02, 2.6879e-01],
        [0.0000e+00, 2.1175

Epoch 11:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.55]

tensor([[1.3907, 0.0000, 1.2388, 0.2338, 0.3595, 0.1261, 1.3376, 0.9429, 0.0000,
         0.4819],
        [0.0000, 0.5778, 0.0000, 0.0000, 0.0781, 0.1382, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6716, 0.7032, 0.0000, 0.0000, 0.8594, 0.0406, 0.0000, 0.1284, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0281, 0.1260, 0.2218, 0.0000, 0.3501, 0.1252, 0.3410,
         0.0000],
        [0.0000, 0.6982, 0.0000, 0.0815, 0.0000, 0.0000, 0.0000, 0.0000, 0.7999,
         0.0000],
        [0.0000, 0.0000, 1.0773, 0.0477, 0.0000, 0.3148, 0.0000, 0.7097, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5093, 0.0000, 0.0000, 0.0515, 0.0000, 0.7522, 0.1190,
         0.0000],
        [0.0000, 0.0000, 0.1568, 0.3543, 0.0000, 0.1633, 0.0000, 0.1435, 0.4163,
         0.0000],
        [0.0000, 2.2054, 0.0000, 0.0000, 0.0000, 0.1440, 0.0000, 0.0000, 0.6218,
         0.0000],
        [0.1929, 0.4669, 0.0000, 0.0000, 0.9228, 0.0000, 0.0000, 0.2359, 0.0000,
         0.0000],
        [0

Epoch 11: 100%|██████| 10/10 [00:00<00:00, 106.74batch/s, accuracy=0, loss=1.54]


tensor([[3.8132e-01, 0.0000e+00, 8.0412e-01, 7.2892e-01, 1.8128e-01, 0.0000e+00,
         1.9258e+00, 1.4942e+00, 1.1993e-03, 7.5903e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.8359e-01, 0.0000e+00, 1.2923e+00,
         0.0000e+00, 0.0000e+00, 1.2934e-01, 0.0000e+00],
        [0.0000e+00, 5.8712e-01, 0.0000e+00, 3.5452e-01, 0.0000e+00, 1.2146e-01,
         0.0000e+00, 0.0000e+00, 1.0454e+00, 0.0000e+00],
        [2.1989e-01, 0.0000e+00, 1.2586e+00, 8.3553e-01, 0.0000e+00, 2.9689e-01,
         7.2436e-01, 5.0612e-01, 0.0000e+00, 2.0702e-01],
        [0.0000e+00, 2.7557e-01, 0.0000e+00, 0.0000e+00, 1.0642e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 2.8899e-01, 0.0000e+00],
        [8.3969e-01, 0.0000e+00, 3.7721e+00, 1.8477e-01, 3.9269e-02, 1.0565e-01,
         7.4257e-01, 3.2616e+00, 0.0000e+00, 3.7735e-01],
        [1.0780e+00, 7.0683e-01, 0.0000e+00, 2.6236e-03, 1.1791e-01, 7.2571e-01,
         1.0553e+00, 1.6523e-01, 0.0000e+00, 6.0805e-01],
        [0.0000e+00, 4.2339

Epoch 12:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.56]

tensor([[0.0000, 0.7834, 0.0000, 1.1787, 0.0000, 0.0000, 0.0000, 0.0000, 1.7485,
         0.0000],
        [0.1402, 0.9770, 0.0000, 0.6894, 0.0000, 1.3086, 0.0193, 0.0000, 0.0000,
         0.0353],
        [1.2181, 0.0000, 3.4183, 0.8259, 0.0000, 1.1998, 1.4982, 2.6572, 0.0000,
         0.8604],
        [0.0000, 0.4965, 0.0000, 0.3300, 0.1537, 0.0000, 0.5512, 0.0000, 0.5988,
         0.0433],
        [0.6933, 0.0000, 1.8117, 0.0000, 0.1362, 0.2567, 0.7639, 2.0678, 0.0000,
         0.5173],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3482, 0.0000, 0.0000, 0.2195, 0.4493,
         0.0000],
        [0.4065, 0.0000, 1.8702, 2.1445, 0.0000, 0.8724, 2.1449, 0.7591, 0.0000,
         0.8316],
        [0.0000, 0.7860, 0.0000, 0.0000, 0.4733, 0.0000, 0.0000, 0.0000, 0.4502,
         0.0000],
        [0.0000, 0.4524, 0.0000, 0.0000, 0.4470, 0.2987, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.3521, 0.0000, 0.0000, 0.4537, 0.0000, 0.0000, 0.0000, 0.4770,
         0.0000],
        [0

Epoch 12:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.7]

tensor([[0.2942, 0.0000, 1.2120, 0.0000, 0.4673, 0.0000, 0.0000, 1.1483, 0.0213,
         0.0000],
        [1.1425, 0.8359, 0.2631, 0.1700, 0.3448, 0.1080, 0.9891, 0.0000, 0.0000,
         0.2753],
        [1.2360, 0.0000, 1.9994, 1.2640, 0.0000, 1.7412, 1.7485, 1.0013, 0.0000,
         0.9434],
        [0.0000, 0.0000, 0.0000, 0.3798, 0.0000, 0.6489, 0.0000, 0.0000, 0.7106,
         0.0000],
        [0.6567, 0.0000, 2.3306, 0.0000, 0.0000, 0.9742, 0.0000, 1.5755, 0.0000,
         0.1528],
        [0.0000, 0.8053, 0.0000, 0.0000, 0.0000, 0.3746, 0.0000, 0.0000, 0.3816,
         0.0000],
        [0.9623, 0.2844, 0.6861, 0.0000, 0.0681, 0.8219, 0.1389, 0.1869, 0.0000,
         0.1469],
        [0.6245, 0.0000, 1.5516, 0.6146, 0.0000, 0.0411, 1.1991, 1.2856, 0.0000,
         0.4547],
        [0.5317, 0.6857, 0.2124, 0.5430, 0.0000, 1.1249, 0.4749, 0.0000, 0.0000,
         0.2607],
        [1.4684, 0.4791, 0.3425, 0.0000, 1.6210, 0.0000, 0.0000, 0.9921, 0.0000,
         0.0000]], device='c

Epoch 12:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.53]

tensor([[0.3957, 0.0000, 1.7395, 1.4598, 0.0000, 1.2118, 1.7675, 1.2963, 0.0000,
         0.9644],
        [0.0000, 1.2621, 0.0000, 0.0000, 0.3441, 0.1259, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.8400, 0.0000, 0.0000, 0.0000, 0.6590, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.9889, 0.0000, 0.0000, 0.2495, 0.0000, 0.0000, 0.0000, 1.3807,
         0.0000],
        [0.0000, 3.0183, 0.0000, 0.0000, 0.0000, 0.2417, 0.0000, 0.0000, 0.6696,
         0.0000],
        [0.0000, 0.0000, 1.0505, 0.6138, 0.0000, 1.2895, 0.0000, 0.0977, 0.0000,
         0.1010],
        [0.0000, 1.2254, 0.0000, 0.0000, 0.5576, 0.0000, 0.0000, 0.0000, 0.6179,
         0.0000],
        [0.0000, 1.7968, 0.0000, 0.1693, 0.1046, 0.0000, 0.7893, 0.0000, 0.0734,
         0.3320],
        [0.5687, 0.4385, 0.3192, 0.1982, 0.0000, 1.1653, 0.2544, 0.0000, 0.0000,
         0.2567],
        [0.0000, 1.3233, 0.0000, 0.6948, 0.0000, 0.4679, 0.0000, 0.0000, 1.1137,
         0.0000],
        [0

Epoch 12: 100%|██████| 10/10 [00:00<00:00, 123.84batch/s, accuracy=0, loss=1.52]


tensor([[1.2483, 0.0000, 1.0041, 0.1969, 0.2475, 0.3251, 1.2134, 0.8083, 0.0000,
         0.5034],
        [0.0000, 0.0000, 0.0000, 1.6551, 0.0000, 0.7717, 0.1604, 0.0000, 0.7822,
         0.1137],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1894, 1.0179,
         0.0000],
        [0.0000, 0.2149, 0.0000, 0.6586, 0.0000, 0.0000, 0.8193, 0.0000, 0.3615,
         0.1559],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.1694, 0.0000, 0.0000, 0.0000, 0.8358,
         0.0000],
        [0.3205, 0.2496, 0.0304, 0.2112, 0.0552, 0.0314, 0.6082, 0.0205, 0.0000,
         0.2117],
        [0.3445, 0.0000, 2.3278, 0.4868, 0.0000, 0.7620, 0.2267, 1.3782, 0.0000,
         0.1612],
        [0.0000, 0.0000, 0.0000, 0.7139, 0.0000, 0.4123, 0.0000, 0.0000, 0.4118,
         0.0000],
        [0.0000, 0.0000, 0.9483, 0.0000, 0.0000, 0.0000, 0.0000, 0.5567, 0.3425,
         0.0000],
        [0.7885, 0.0000, 1.2986, 0.4534, 0.1593, 0.0000, 0.7662, 0.5274, 0.0000,
         0.0749],
        [0

Epoch 13:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.82]

tensor([[0.0000, 0.8084, 0.0000, 1.1776, 0.0000, 0.0000, 0.0000, 0.0000, 1.7557,
         0.0000],
        [0.1435, 1.0030, 0.0000, 0.6893, 0.0000, 1.3191, 0.0200, 0.0000, 0.0000,
         0.0358],
        [1.2132, 0.0000, 3.4152, 0.8189, 0.0000, 1.2185, 1.4908, 2.6491, 0.0000,
         0.8621],
        [0.0000, 0.5007, 0.0000, 0.3241, 0.1590, 0.0000, 0.5527, 0.0000, 0.6016,
         0.0421],
        [0.6891, 0.0000, 1.8102, 0.0000, 0.1397, 0.2656, 0.7595, 2.0611, 0.0000,
         0.5181],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3541, 0.0000, 0.0000, 0.2177, 0.4513,
         0.0000],
        [0.4031, 0.0000, 1.8674, 2.1348, 0.0000, 0.8870, 2.1415, 0.7558, 0.0000,
         0.8315],
        [0.0000, 0.8008, 0.0000, 0.0000, 0.4800, 0.0000, 0.0000, 0.0000, 0.4547,
         0.0000],
        [0.0000, 0.4748, 0.0000, 0.0000, 0.4529, 0.3018, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.3736, 0.0000, 0.0000, 0.4602, 0.0000, 0.0000, 0.0000, 0.4807,
         0.0000],
        [0

Epoch 13:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.52]

tensor([[0.1986, 2.1409, 0.0000, 0.0744, 0.2884, 0.0000, 0.8568, 0.0000, 0.0000,
         0.3038],
        [0.8390, 0.4354, 0.6497, 0.0000, 0.6430, 0.0000, 0.3471, 0.0183, 0.0172,
         0.0000],
        [0.0000, 0.9858, 0.0000, 1.6621, 0.0000, 0.0000, 1.5719, 0.0000, 0.6909,
         0.5247],
        [0.0000, 0.3123, 0.0000, 0.3871, 0.0000, 0.0000, 0.0000, 0.0000, 1.0358,
         0.0000],
        [0.0000, 2.9265, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0235,
         0.0000],
        [0.0000, 0.0000, 0.2880, 1.1671, 0.0000, 0.4329, 0.2926, 0.0000, 0.4672,
         0.1291],
        [0.4417, 0.0000, 2.4798, 0.0000, 0.8281, 0.0000, 0.0000, 2.4528, 0.0555,
         0.0000],
        [0.2631, 0.0000, 1.9630, 0.9255, 0.0000, 0.0000, 0.8912, 1.1252, 0.0114,
         0.1971],
        [0.0000, 0.0000, 0.6380, 0.3613, 0.0000, 0.0000, 0.1691, 0.2260, 0.2656,
         0.0000],
        [0.0000, 0.0000, 0.4630, 0.0000, 0.2252, 0.0000, 0.0000, 0.5558, 0.1384,
         0.0000],
        [0

Epoch 13:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.75]

tensor([[0.3905, 0.0000, 1.7442, 1.4468, 0.0000, 1.2209, 1.7688, 1.2898, 0.0000,
         0.9685],
        [0.0000, 1.2841, 0.0000, 0.0000, 0.3495, 0.1238, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.8670, 0.0000, 0.0000, 0.0000, 0.6570, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 2.0157, 0.0000, 0.0000, 0.2582, 0.0000, 0.0000, 0.0000, 1.3875,
         0.0000],
        [0.0000, 3.0579, 0.0000, 0.0000, 0.0000, 0.2329, 0.0000, 0.0000, 0.6713,
         0.0000],
        [0.0000, 0.0000, 1.0503, 0.6107, 0.0000, 1.2968, 0.0000, 0.0967, 0.0000,
         0.1018],
        [0.0000, 1.2416, 0.0000, 0.0000, 0.5640, 0.0000, 0.0000, 0.0000, 0.6198,
         0.0000],
        [0.0000, 1.8131, 0.0000, 0.1666, 0.1102, 0.0000, 0.7890, 0.0000, 0.0735,
         0.3264],
        [0.5689, 0.4498, 0.3180, 0.1966, 0.0000, 1.1688, 0.2542, 0.0000, 0.0000,
         0.2559],
        [0.0000, 1.3541, 0.0000, 0.6980, 0.0000, 0.4653, 0.0000, 0.0000, 1.1162,
         0.0000],
        [0

Epoch 13:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.44]

tensor([[0.0000, 2.2374, 0.0000, 0.0000, 0.2292, 0.0000, 0.0000, 0.0000, 1.3804,
         0.0000],
        [0.6270, 1.3617, 0.0000, 0.0000, 0.9155, 0.2265, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1792, 0.0000, 0.4740, 1.4074, 0.0000, 0.0469, 1.5411, 0.0000, 0.1097,
         0.4367],
        [0.0000, 0.2245, 0.0000, 0.3402, 0.0000, 0.0000, 0.0178, 0.0000, 0.7201,
         0.0000],
        [0.0000, 1.4366, 0.0000, 0.9667, 0.0000, 0.1877, 0.0000, 0.0000, 0.9490,
         0.0000],
        [1.4351, 0.0000, 1.6489, 0.1121, 0.4484, 0.0824, 1.4500, 1.5952, 0.0000,
         0.6077],
        [0.0000, 1.9325, 0.0000, 0.0000, 0.6839, 0.0000, 0.0000, 0.0000, 1.4280,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.3991, 0.0000, 0.0000, 0.7862, 1.0027,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0886, 0.0000, 0.0000, 0.0452, 0.2185,
         0.0000],
        [0.0000, 0.8339, 0.0000, 0.0000, 0.4782, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0

Epoch 13:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.41]

tensor([[1.2431, 0.0000, 1.0066, 0.1895, 0.2522, 0.3273, 1.2135, 0.8056, 0.0000,
         0.5066],
        [0.0000, 0.0000, 0.0000, 1.6492, 0.0000, 0.7730, 0.1609, 0.0000, 0.7840,
         0.1083],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1877, 1.0174,
         0.0000],
        [0.0000, 0.2164, 0.0000, 0.6520, 0.0000, 0.0000, 0.8190, 0.0000, 0.3633,
         0.1557],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.1726, 0.0000, 0.0000, 0.0000, 0.8346,
         0.0000],
        [0.3180, 0.2523, 0.0313, 0.2064, 0.0589, 0.0317, 0.6082, 0.0192, 0.0000,
         0.2116],
        [0.3394, 0.0000, 2.3314, 0.4773, 0.0000, 0.7697, 0.2269, 1.3746, 0.0000,
         0.1684],
        [0.0000, 0.0000, 0.0000, 0.7097, 0.0000, 0.4135, 0.0000, 0.0000, 0.4138,
         0.0000],
        [0.0000, 0.0000, 0.9498, 0.0000, 0.0000, 0.0000, 0.0000, 0.5547, 0.3425,
         0.0000],
        [0.7832, 0.0000, 1.3010, 0.4446, 0.1660, 0.0000, 0.7662, 0.5258, 0.0000,
         0.0801],
        [0

Epoch 13: 100%|███████| 10/10 [00:00<00:00, 132.57batch/s, accuracy=0, loss=1.7]


tensor([[0.0000, 0.5837, 0.0000, 1.4768, 0.0000, 0.2466, 0.1769, 0.0000, 1.1747,
         0.0594],
        [0.0000, 0.0000, 1.3651, 0.2754, 0.0493, 0.0000, 0.0336, 0.9287, 0.4926,
         0.0000],
        [0.4224, 0.6087, 0.0000, 1.1609, 0.0000, 0.2798, 1.7881, 0.0000, 0.0000,
         0.6829],
        [0.0000, 0.9074, 0.0000, 0.0000, 0.1317, 0.0000, 0.0000, 0.0000, 0.3176,
         0.0000],
        [0.0000, 0.5538, 0.0000, 0.0000, 0.0869, 0.0000, 0.0000, 0.0000, 0.3540,
         0.0000],
        [0.6966, 0.0000, 3.2243, 0.1788, 0.5657, 0.0000, 1.5391, 3.7497, 0.0000,
         0.6472],
        [0.0000, 0.0000, 1.3089, 1.1887, 0.0000, 0.2362, 0.7238, 0.8808, 0.2955,
         0.3107],
        [0.0000, 1.0726, 0.0000, 0.0000, 0.6164, 0.0000, 0.0000, 0.0000, 0.7309,
         0.0000],
        [0.0000, 0.6227, 0.0000, 0.4912, 0.0000, 0.0000, 0.4434, 0.0000, 0.5936,
         0.0662],
        [0.1131, 1.2099, 0.0000, 0.9601, 0.0000, 0.4648, 1.0751, 0.0000, 0.0000,
         0.3820],
        [0

Epoch 14:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.59]

tensor([[0.0000, 0.8202, 0.0000, 1.1747, 0.0000, 0.0000, 0.0000, 0.0000, 1.7718,
         0.0000],
        [0.1423, 1.0122, 0.0000, 0.6845, 0.0000, 1.3123, 0.0253, 0.0000, 0.0000,
         0.0313],
        [1.1998, 0.0000, 3.4144, 0.8021, 0.0000, 1.2212, 1.4841, 2.6558, 0.0000,
         0.8653],
        [0.0000, 0.5040, 0.0000, 0.3223, 0.1650, 0.0000, 0.5492, 0.0000, 0.6059,
         0.0388],
        [0.6804, 0.0000, 1.8100, 0.0000, 0.1483, 0.2689, 0.7536, 2.0651, 0.0000,
         0.5200],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3610, 0.0000, 0.0000, 0.2187, 0.4534,
         0.0000],
        [0.3936, 0.0000, 1.8664, 2.1189, 0.0000, 0.8864, 2.1359, 0.7608, 0.0000,
         0.8295],
        [0.0000, 0.8076, 0.0000, 0.0000, 0.4855, 0.0000, 0.0000, 0.0000, 0.4626,
         0.0000],
        [0.0000, 0.4848, 0.0000, 0.0000, 0.4530, 0.2991, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.3844, 0.0000, 0.0000, 0.4624, 0.0000, 0.0000, 0.0000, 0.4911,
         0.0000],
        [0

Epoch 14:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.36]

tensor([[1.9875e-01, 2.1502e+00, 0.0000e+00, 7.8934e-02, 2.8697e-01, 0.0000e+00,
         8.5924e-01, 0.0000e+00, 1.0094e-02, 2.9700e-01],
        [8.3562e-01, 4.3486e-01, 6.4292e-01, 0.0000e+00, 6.5792e-01, 0.0000e+00,
         3.3929e-01, 2.8149e-02, 1.8011e-02, 0.0000e+00],
        [0.0000e+00, 9.9180e-01, 0.0000e+00, 1.6614e+00, 0.0000e+00, 0.0000e+00,
         1.5700e+00, 0.0000e+00, 7.0131e-01, 5.1568e-01],
        [0.0000e+00, 3.1825e-01, 0.0000e+00, 3.8593e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.0488e+00, 0.0000e+00],
        [0.0000e+00, 2.9449e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.0532e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.8787e-01, 1.1608e+00, 0.0000e+00, 4.3118e-01,
         2.8992e-01, 0.0000e+00, 4.7201e-01, 1.2507e-01],
        [4.3401e-01, 0.0000e+00, 2.4684e+00, 0.0000e+00, 8.5248e-01, 0.0000e+00,
         0.0000e+00, 2.4740e+00, 4.0144e-02, 0.0000e+00],
        [2.5593e-01, 0.0000

Epoch 14:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.56]

tensor([[0.0000, 2.2525, 0.0000, 0.0000, 0.2313, 0.0000, 0.0000, 0.0000, 1.4130,
         0.0000],
        [0.6299, 1.3694, 0.0000, 0.0000, 0.9214, 0.2216, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1739, 0.0000, 0.4704, 1.4021, 0.0000, 0.0450, 1.5323, 0.0000, 0.1073,
         0.4303],
        [0.0000, 0.2281, 0.0000, 0.3376, 0.0000, 0.0000, 0.0190, 0.0000, 0.7298,
         0.0000],
        [0.0000, 1.4460, 0.0000, 0.9655, 0.0000, 0.1772, 0.0000, 0.0000, 0.9699,
         0.0000],
        [1.4250, 0.0000, 1.6426, 0.1071, 0.4575, 0.0862, 1.4353, 1.6119, 0.0000,
         0.6019],
        [0.0000, 1.9423, 0.0000, 0.0000, 0.6900, 0.0000, 0.0000, 0.0000, 1.4493,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4157, 0.0000, 0.0000, 0.7939, 1.0040,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0944, 0.0000, 0.0000, 0.0471, 0.2235,
         0.0000],
        [0.0000, 0.8401, 0.0000, 0.0000, 0.4842, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0

Epoch 14:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.42]

tensor([[1.2362, 0.0000, 1.0030, 0.1862, 0.2577, 0.3274, 1.2041, 0.8154, 0.0000,
         0.5015],
        [0.0000, 0.0000, 0.0000, 1.6423, 0.0000, 0.7650, 0.1642, 0.0000, 0.7964,
         0.1036],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1884, 1.0267,
         0.0000],
        [0.0000, 0.2167, 0.0000, 0.6483, 0.0000, 0.0000, 0.8146, 0.0000, 0.3648,
         0.1516],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.1772, 0.0000, 0.0000, 0.0000, 0.8406,
         0.0000],
        [0.3149, 0.2533, 0.0298, 0.2036, 0.0630, 0.0295, 0.6043, 0.0226, 0.0000,
         0.2078],
        [0.3313, 0.0000, 2.3290, 0.4706, 0.0000, 0.7720, 0.2157, 1.3881, 0.0000,
         0.1641],
        [0.0000, 0.0000, 0.0000, 0.7053, 0.0000, 0.4087, 0.0000, 0.0000, 0.4226,
         0.0000],
        [0.0000, 0.0000, 0.9501, 0.0000, 0.0000, 0.0000, 0.0000, 0.5604, 0.3466,
         0.0000],
        [0.7771, 0.0000, 1.2977, 0.4413, 0.1743, 0.0000, 0.7557, 0.5374, 0.0000,
         0.0763],
        [0

Epoch 14: 100%|██████| 10/10 [00:00<00:00, 137.50batch/s, accuracy=0, loss=1.52]


tensor([[1.3731, 0.0000, 1.2315, 0.2195, 0.3824, 0.1342, 1.3190, 0.9534, 0.0000,
         0.4773],
        [0.0000, 0.6086, 0.0000, 0.0000, 0.0884, 0.1390, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6729, 0.7337, 0.0000, 0.0000, 0.8788, 0.0416, 0.0000, 0.1258, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0257, 0.1152, 0.2418, 0.0000, 0.3419, 0.1298, 0.3464,
         0.0000],
        [0.0000, 0.7375, 0.0000, 0.0790, 0.0000, 0.0000, 0.0000, 0.0000, 0.8229,
         0.0000],
        [0.0000, 0.0000, 1.0749, 0.0364, 0.0000, 0.3259, 0.0000, 0.7133, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5098, 0.0000, 0.0000, 0.0592, 0.0000, 0.7511, 0.1228,
         0.0000],
        [0.0000, 0.0000, 0.1572, 0.3405, 0.0000, 0.1693, 0.0000, 0.1422, 0.4236,
         0.0000],
        [0.0000, 2.2934, 0.0000, 0.0000, 0.0000, 0.1326, 0.0000, 0.0000, 0.6531,
         0.0000],
        [0.1902, 0.4834, 0.0000, 0.0000, 0.9455, 0.0000, 0.0000, 0.2340, 0.0000,
         0.0000],
        [0

Epoch 15:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.47]

tensor([[0.0000, 0.8341, 0.0000, 1.1694, 0.0000, 0.0000, 0.0000, 0.0000, 1.7816,
         0.0000],
        [0.1481, 1.0236, 0.0000, 0.6837, 0.0000, 1.3073, 0.0233, 0.0000, 0.0000,
         0.0272],
        [1.1860, 0.0000, 3.4239, 0.7962, 0.0000, 1.2254, 1.4690, 2.6711, 0.0000,
         0.8640],
        [0.0000, 0.5059, 0.0000, 0.3169, 0.1674, 0.0000, 0.5504, 0.0000, 0.6056,
         0.0397],
        [0.6694, 0.0000, 1.8176, 0.0000, 0.1479, 0.2720, 0.7453, 2.0747, 0.0000,
         0.5199],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3647, 0.0000, 0.0000, 0.2196, 0.4543,
         0.0000],
        [0.3847, 0.0000, 1.8714, 2.1103, 0.0000, 0.8837, 2.1261, 0.7710, 0.0000,
         0.8284],
        [0.0000, 0.8160, 0.0000, 0.0000, 0.4929, 0.0000, 0.0000, 0.0000, 0.4676,
         0.0000],
        [0.0000, 0.4988, 0.0000, 0.0000, 0.4625, 0.3000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.3973, 0.0000, 0.0000, 0.4696, 0.0000, 0.0000, 0.0000, 0.4961,
         0.0000],
        [0

Epoch 15:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.44]

tensor([[0.2023, 2.1634, 0.0000, 0.0735, 0.2920, 0.0000, 0.8624, 0.0000, 0.0059,
         0.2977],
        [0.8319, 0.4309, 0.6502, 0.0000, 0.6619, 0.0000, 0.3365, 0.0293, 0.0198,
         0.0000],
        [0.0000, 0.9984, 0.0000, 1.6481, 0.0000, 0.0000, 1.5743, 0.0000, 0.6966,
         0.5185],
        [0.0000, 0.3273, 0.0000, 0.3799, 0.0000, 0.0000, 0.0000, 0.0000, 1.0502,
         0.0000],
        [0.0000, 2.9759, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0531,
         0.0000],
        [0.0000, 0.0000, 0.2904, 1.1511, 0.0000, 0.4270, 0.2907, 0.0000, 0.4721,
         0.1257],
        [0.4188, 0.0000, 2.4924, 0.0000, 0.8500, 0.0000, 0.0000, 2.4758, 0.0423,
         0.0000],
        [0.2453, 0.0000, 1.9698, 0.9085, 0.0000, 0.0000, 0.8744, 1.1434, 0.0035,
         0.1981],
        [0.0000, 0.0000, 0.6409, 0.3513, 0.0000, 0.0000, 0.1619, 0.2335, 0.2676,
         0.0000],
        [0.0000, 0.0000, 0.4667, 0.0000, 0.2358, 0.0000, 0.0000, 0.5605, 0.1404,
         0.0000],
        [0

Epoch 15:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.63]

tensor([[0.0000, 2.2933, 0.0000, 0.0000, 0.2403, 0.0000, 0.0000, 0.0000, 1.4107,
         0.0000],
        [0.6334, 1.3891, 0.0000, 0.0000, 0.9296, 0.2271, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1695, 0.0000, 0.4769, 1.3911, 0.0000, 0.0394, 1.5351, 0.0000, 0.1061,
         0.4352],
        [0.0000, 0.2379, 0.0000, 0.3307, 0.0000, 0.0000, 0.0223, 0.0000, 0.7272,
         0.0000],
        [0.0000, 1.4718, 0.0000, 0.9579, 0.0000, 0.1735, 0.0000, 0.0000, 0.9668,
         0.0000],
        [1.4162, 0.0000, 1.6569, 0.1009, 0.4563, 0.0840, 1.4343, 1.6094, 0.0000,
         0.6063],
        [0.0000, 1.9680, 0.0000, 0.0000, 0.6959, 0.0000, 0.0000, 0.0000, 1.4461,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4175, 0.0000, 0.0000, 0.7899, 1.0026,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0962, 0.0000, 0.0000, 0.0460, 0.2226,
         0.0000],
        [0.0000, 0.8554, 0.0000, 0.0000, 0.4904, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0

Epoch 15:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.68]

tensor([[1.3685, 0.0000, 1.2398, 0.2121, 0.3840, 0.1322, 1.3193, 0.9511, 0.0000,
         0.4804],
        [0.0000, 0.6205, 0.0000, 0.0000, 0.0928, 0.1414, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6737, 0.7456, 0.0000, 0.0000, 0.8862, 0.0474, 0.0000, 0.1257, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0309, 0.1091, 0.2433, 0.0000, 0.3437, 0.1277, 0.3449,
         0.0000],
        [0.0000, 0.7529, 0.0000, 0.0758, 0.0000, 0.0000, 0.0000, 0.0000, 0.8223,
         0.0000],
        [0.0000, 0.0000, 1.0808, 0.0301, 0.0000, 0.3271, 0.0000, 0.7113, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5143, 0.0000, 0.0000, 0.0610, 0.0000, 0.7490, 0.1206,
         0.0000],
        [0.0000, 0.0000, 0.1600, 0.3340, 0.0000, 0.1687, 0.0000, 0.1408, 0.4213,
         0.0000],
        [0.0000, 2.3320, 0.0000, 0.0000, 0.0000, 0.1356, 0.0000, 0.0000, 0.6541,
         0.0000],
        [0.1897, 0.4913, 0.0000, 0.0000, 0.9504, 0.0000, 0.0000, 0.2326, 0.0000,
         0.0000],
        [0

Epoch 15:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.68]

tensor([[0.0000, 0.6163, 0.0000, 1.4621, 0.0000, 0.2344, 0.1895, 0.0000, 1.1846,
         0.0584],
        [0.0000, 0.0000, 1.3793, 0.2600, 0.0548, 0.0000, 0.0265, 0.9346, 0.4874,
         0.0000],
        [0.4163, 0.6109, 0.0000, 1.1465, 0.0000, 0.2702, 1.7883, 0.0000, 0.0000,
         0.6828],
        [0.0000, 0.9376, 0.0000, 0.0000, 0.1429, 0.0000, 0.0000, 0.0000, 0.3282,
         0.0000],
        [0.0000, 0.5756, 0.0000, 0.0000, 0.0985, 0.0000, 0.0000, 0.0000, 0.3629,
         0.0000],
        [0.6614, 0.0000, 3.2564, 0.1516, 0.5625, 0.0000, 1.5157, 3.7660, 0.0000,
         0.6526],
        [0.0000, 0.0000, 1.3202, 1.1664, 0.0000, 0.2318, 0.7179, 0.8867, 0.2890,
         0.3107],
        [0.0000, 1.0927, 0.0000, 0.0000, 0.6293, 0.0000, 0.0000, 0.0000, 0.7373,
         0.0000],
        [0.0000, 0.6370, 0.0000, 0.4816, 0.0000, 0.0000, 0.4482, 0.0000, 0.5967,
         0.0675],
        [0.1142, 1.2290, 0.0000, 0.9508, 0.0000, 0.4567, 1.0796, 0.0000, 0.0000,
         0.3803],
        [0

Epoch 15:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.7]

tensor([[0.3495, 0.0000, 0.8168, 0.6879, 0.1940, 0.0000, 1.9067, 1.4980, 0.0000,
         0.7622],
        [0.0000, 0.0000, 0.0000, 0.4718, 0.0000, 1.3000, 0.0000, 0.0000, 0.1435,
         0.0000],
        [0.0000, 0.6689, 0.0000, 0.3492, 0.0000, 0.1202, 0.0000, 0.0000, 1.0739,
         0.0000],
        [0.2035, 0.0000, 1.2628, 0.8060, 0.0000, 0.3052, 0.7129, 0.5118, 0.0000,
         0.2071],
        [0.0000, 0.3112, 0.0000, 0.0000, 0.1263, 0.0000, 0.0000, 0.0000, 0.3039,
         0.0000],
        [0.7949, 0.0000, 3.7925, 0.1417, 0.0594, 0.1300, 0.7008, 3.2724, 0.0000,
         0.3868],
        [1.0707, 0.7284, 0.0000, 0.0000, 0.1267, 0.7266, 1.0538, 0.1643, 0.0000,
         0.6017],
        [0.0000, 0.5073, 0.0000, 0.0000, 0.0000, 1.2794, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.1143, 0.0000, 0.0000, 0.0392, 0.0000, 0.0000, 0.0000, 0.8205,
         0.0000],
        [0.0000, 0.6882, 0.0000, 0.7022, 0.0000, 0.6710, 0.3890, 0.0000, 0.0111,
         0.2485],
        [0

Epoch 15: 100%|████████| 10/10 [00:00<00:00, 93.80batch/s, accuracy=0, loss=1.5]
Epoch 16:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.53]

tensor([[0.0000, 0.8572, 0.0000, 1.1634, 0.0000, 0.0000, 0.0000, 0.0000, 1.7830,
         0.0000],
        [0.1488, 1.0423, 0.0000, 0.6803, 0.0000, 1.3144, 0.0289, 0.0000, 0.0000,
         0.0252],
        [1.1800, 0.0000, 3.4319, 0.7793, 0.0000, 1.2233, 1.4660, 2.6667, 0.0000,
         0.8642],
        [0.0000, 0.5114, 0.0000, 0.3119, 0.1699, 0.0000, 0.5532, 0.0000, 0.6055,
         0.0427],
        [0.6650, 0.0000, 1.8238, 0.0000, 0.1467, 0.2694, 0.7413, 2.0709, 0.0000,
         0.5210],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3673, 0.0000, 0.0000, 0.2176, 0.4524,
         0.0000],
        [0.3804, 0.0000, 1.8764, 2.0914, 0.0000, 0.8797, 2.1300, 0.7684, 0.0000,
         0.8305],
        [0.0000, 0.8294, 0.0000, 0.0000, 0.4993, 0.0000, 0.0000, 0.0000, 0.4684,
         0.0000],
        [0.0000, 0.5199, 0.0000, 0.0000, 0.4690, 0.3095, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.4196, 0.0000, 0.0000, 0.4755, 0.0000, 0.0000, 0.0000, 0.4993,
         0.0000],
        [0

Epoch 16:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.42]

tensor([[0.2749, 0.0000, 1.2236, 0.0000, 0.4907, 0.0000, 0.0000, 1.1516, 0.0119,
         0.0000],
        [1.1329, 0.8443, 0.2622, 0.1550, 0.3659, 0.1072, 0.9860, 0.0000, 0.0000,
         0.2735],
        [1.2154, 0.0000, 2.0004, 1.2265, 0.0000, 1.7547, 1.7348, 1.0080, 0.0000,
         0.9384],
        [0.0000, 0.0000, 0.0000, 0.3700, 0.0000, 0.6581, 0.0000, 0.0000, 0.7312,
         0.0000],
        [0.6377, 0.0000, 2.3383, 0.0000, 0.0000, 0.9968, 0.0000, 1.5800, 0.0000,
         0.1553],
        [0.0000, 0.8901, 0.0000, 0.0000, 0.0000, 0.3771, 0.0000, 0.0000, 0.4098,
         0.0000],
        [0.9574, 0.3039, 0.6829, 0.0000, 0.0835, 0.8310, 0.1352, 0.1881, 0.0000,
         0.1428],
        [0.5985, 0.0000, 1.5627, 0.5788, 0.0000, 0.0484, 1.1822, 1.2916, 0.0000,
         0.4592],
        [0.5323, 0.7256, 0.2047, 0.5302, 0.0000, 1.1293, 0.4784, 0.0000, 0.0000,
         0.2520],
        [1.4587, 0.4831, 0.3467, 0.0000, 1.6600, 0.0000, 0.0000, 0.9906, 0.0000,
         0.0000]], device='c

Epoch 16:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.72]

tensor([[0.0000, 0.9760, 0.0000, 0.1617, 0.0000, 0.0000, 0.0000, 0.0000, 1.2973,
         0.0000],
        [0.0000, 3.9857, 0.0000, 0.8549, 0.0000, 0.5222, 0.0000, 0.0000, 1.0580,
         0.0000],
        [1.8776, 0.0000, 3.1066, 0.0000, 1.6758, 0.0000, 0.0956, 3.1904, 0.0000,
         0.0000],
        [0.3180, 0.0000, 3.9195, 0.6275, 0.0000, 0.0000, 0.8354, 3.2114, 0.0000,
         0.2736],
        [0.0000, 2.0183, 0.0000, 0.7677, 0.0000, 0.0000, 1.1215, 0.0000, 0.5535,
         0.3636],
        [1.1651, 0.1153, 0.3103, 0.0000, 0.8209, 0.0000, 1.1826, 0.9052, 0.0000,
         0.4806],
        [0.2141, 0.0000, 1.5589, 0.4877, 0.0536, 0.0000, 0.8903, 1.3872, 0.0513,
         0.2648],
        [0.0000, 0.2206, 0.0000, 1.1355, 0.0000, 0.1579, 1.7424, 0.0000, 0.0040,
         0.7896],
        [0.0000, 0.4352, 0.0000, 0.0000, 0.1334, 0.0000, 0.0000, 0.0000, 0.1249,
         0.0000],
        [1.0278, 0.2608, 0.8378, 0.0000, 1.1905, 0.0000, 0.0000, 0.4333, 0.0000,
         0.0000]], device='c

Epoch 16:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.42]

tensor([[0.0000, 2.3383, 0.0000, 0.0000, 0.2511, 0.0000, 0.0000, 0.0000, 1.4226,
         0.0000],
        [0.6334, 1.4117, 0.0000, 0.0000, 0.9390, 0.2315, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1689, 0.0000, 0.4794, 1.3800, 0.0000, 0.0396, 1.5396, 0.0000, 0.1040,
         0.4375],
        [0.0000, 0.2491, 0.0000, 0.3287, 0.0000, 0.0000, 0.0239, 0.0000, 0.7289,
         0.0000],
        [0.0000, 1.5004, 0.0000, 0.9587, 0.0000, 0.1730, 0.0000, 0.0000, 0.9758,
         0.0000],
        [1.4158, 0.0000, 1.6612, 0.0906, 0.4568, 0.0865, 1.4328, 1.6054, 0.0000,
         0.6112],
        [0.0000, 1.9969, 0.0000, 0.0000, 0.7047, 0.0000, 0.0000, 0.0000, 1.4536,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4238, 0.0000, 0.0000, 0.7856, 0.9965,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0998, 0.0000, 0.0000, 0.0446, 0.2217,
         0.0000],
        [0.0000, 0.8729, 0.0000, 0.0000, 0.4979, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0

Epoch 16:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.49]

tensor([[1.3676, 0.0000, 1.2455, 0.2043, 0.3836, 0.1375, 1.3214, 0.9458, 0.0000,
         0.4851],
        [0.0000, 0.6314, 0.0000, 0.0000, 0.0985, 0.1433, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6741, 0.7573, 0.0000, 0.0000, 0.8950, 0.0512, 0.0000, 0.1235, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0345, 0.1037, 0.2457, 0.0000, 0.3459, 0.1249, 0.3425,
         0.0000],
        [0.0000, 0.7666, 0.0000, 0.0743, 0.0000, 0.0000, 0.0000, 0.0000, 0.8267,
         0.0000],
        [0.0000, 0.0000, 1.0833, 0.0255, 0.0000, 0.3350, 0.0000, 0.7089, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5153, 0.0000, 0.0000, 0.0670, 0.0000, 0.7478, 0.1173,
         0.0000],
        [0.0000, 0.0000, 0.1599, 0.3298, 0.0000, 0.1725, 0.0000, 0.1409, 0.4196,
         0.0000],
        [0.0000, 2.3639, 0.0000, 0.0000, 0.0000, 0.1297, 0.0000, 0.0000, 0.6692,
         0.0000],
        [0.1892, 0.4989, 0.0000, 0.0000, 0.9579, 0.0000, 0.0000, 0.2306, 0.0000,
         0.0000],
        [0

Epoch 16: 100%|██████| 10/10 [00:00<00:00, 100.75batch/s, accuracy=0, loss=1.67]
Epoch 17:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.53]

tensor([[0.0000, 0.8735, 0.0000, 1.1593, 0.0000, 0.0000, 0.0000, 0.0000, 1.7892,
         0.0000],
        [0.1508, 1.0573, 0.0000, 0.6779, 0.0000, 1.3162, 0.0311, 0.0000, 0.0000,
         0.0189],
        [1.1781, 0.0000, 3.4487, 0.7697, 0.0000, 1.2458, 1.4662, 2.6546, 0.0000,
         0.8750],
        [0.0000, 0.5149, 0.0000, 0.3064, 0.1732, 0.0000, 0.5570, 0.0000, 0.6053,
         0.0458],
        [0.6629, 0.0000, 1.8360, 0.0000, 0.1415, 0.2826, 0.7409, 2.0626, 0.0000,
         0.5295],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3713, 0.0000, 0.0000, 0.2149, 0.4509,
         0.0000],
        [0.3792, 0.0000, 1.8864, 2.0790, 0.0000, 0.8924, 2.1346, 0.7616, 0.0000,
         0.8381],
        [0.0000, 0.8400, 0.0000, 0.0000, 0.5074, 0.0000, 0.0000, 0.0000, 0.4719,
         0.0000],
        [0.0000, 0.5362, 0.0000, 0.0000, 0.4810, 0.3101, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.4350, 0.0000, 0.0000, 0.4865, 0.0000, 0.0000, 0.0000, 0.5060,
         0.0000],
        [0

Epoch 17:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.55]

tensor([[0.2022, 2.2002, 0.0000, 0.0719, 0.3051, 0.0000, 0.8707, 0.0000, 0.0210,
         0.2961],
        [0.8318, 0.4275, 0.6633, 0.0000, 0.6699, 0.0000, 0.3456, 0.0181, 0.0142,
         0.0000],
        [0.0000, 1.0197, 0.0000, 1.6331, 0.0000, 0.0000, 1.5856, 0.0000, 0.7030,
         0.5208],
        [0.0000, 0.3506, 0.0000, 0.3735, 0.0000, 0.0000, 0.0000, 0.0000, 1.0547,
         0.0000],
        [0.0000, 3.0521, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0838,
         0.0000],
        [0.0000, 0.0000, 0.2920, 1.1364, 0.0000, 0.4330, 0.2945, 0.0000, 0.4681,
         0.1267],
        [0.4141, 0.0000, 2.5224, 0.0000, 0.8477, 0.0000, 0.0000, 2.4564, 0.0127,
         0.0000],
        [0.2419, 0.0000, 1.9892, 0.8860, 0.0000, 0.0000, 0.8777, 1.1309, 0.0000,
         0.2097],
        [0.0000, 0.0000, 0.6504, 0.3387, 0.0000, 0.0000, 0.1661, 0.2266, 0.2600,
         0.0000],
        [0.0000, 0.0000, 0.4747, 0.0000, 0.2419, 0.0000, 0.0000, 0.5544, 0.1343,
         0.0000],
        [0

Epoch 17:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.61]

tensor([[0.0000, 2.3612, 0.0000, 0.0000, 0.2707, 0.0000, 0.0000, 0.0000, 1.4414,
         0.0000],
        [0.6342, 1.4230, 0.0000, 0.0000, 0.9507, 0.2230, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1667, 0.0000, 0.4848, 1.3730, 0.0000, 0.0450, 1.5434, 0.0000, 0.0983,
         0.4424],
        [0.0000, 0.2555, 0.0000, 0.3239, 0.0000, 0.0000, 0.0263, 0.0000, 0.7313,
         0.0000],
        [0.0000, 1.5160, 0.0000, 0.9531, 0.0000, 0.1622, 0.0000, 0.0000, 0.9854,
         0.0000],
        [1.4111, 0.0000, 1.6735, 0.0858, 0.4524, 0.0985, 1.4315, 1.6003, 0.0000,
         0.6155],
        [0.0000, 2.0121, 0.0000, 0.0000, 0.7197, 0.0000, 0.0000, 0.0000, 1.4646,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4323, 0.0000, 0.0000, 0.7809, 0.9937,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.1042, 0.0000, 0.0000, 0.0429, 0.2220,
         0.0000],
        [0.0000, 0.8824, 0.0000, 0.0000, 0.5072, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0

Epoch 17:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.4]

tensor([[1.3641, 0.0000, 1.2531, 0.2006, 0.3835, 0.1452, 1.3203, 0.9437, 0.0000,
         0.4871],
        [0.0000, 0.6390, 0.0000, 0.0000, 0.1023, 0.1389, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6733, 0.7642, 0.0000, 0.0000, 0.9021, 0.0454, 0.0000, 0.1208, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0394, 0.1008, 0.2489, 0.0000, 0.3471, 0.1231, 0.3403,
         0.0000],
        [0.0000, 0.7769, 0.0000, 0.0720, 0.0000, 0.0000, 0.0000, 0.0000, 0.8294,
         0.0000],
        [0.0000, 0.0000, 1.0895, 0.0221, 0.0000, 0.3393, 0.0000, 0.7066, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5200, 0.0000, 0.0000, 0.0694, 0.0000, 0.7461, 0.1165,
         0.0000],
        [0.0000, 0.0000, 0.1626, 0.3252, 0.0000, 0.1744, 0.0000, 0.1403, 0.4189,
         0.0000],
        [0.0000, 2.3862, 0.0000, 0.0000, 0.0000, 0.1130, 0.0000, 0.0000, 0.6779,
         0.0000],
        [0.1884, 0.5039, 0.0000, 0.0000, 0.9657, 0.0000, 0.0000, 0.2286, 0.0000,
         0.0000],
        [0

Epoch 17: 100%|██████| 10/10 [00:00<00:00, 108.98batch/s, accuracy=0, loss=1.67]


tensor([[3.4370e-01, 0.0000e+00, 8.3660e-01, 6.7272e-01, 1.9311e-01, 0.0000e+00,
         1.9029e+00, 1.4933e+00, 0.0000e+00, 7.7759e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.6263e-01, 0.0000e+00, 1.2943e+00,
         0.0000e+00, 0.0000e+00, 1.4965e-01, 0.0000e+00],
        [0.0000e+00, 7.0299e-01, 0.0000e+00, 3.4351e-01, 0.0000e+00, 1.0524e-01,
         0.0000e+00, 0.0000e+00, 1.0823e+00, 0.0000e+00],
        [2.0060e-01, 0.0000e+00, 1.2769e+00, 7.9384e-01, 0.0000e+00, 3.1983e-01,
         7.1395e-01, 5.0556e-01, 0.0000e+00, 2.1288e-01],
        [0.0000e+00, 3.2679e-01, 0.0000e+00, 0.0000e+00, 1.3593e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 3.0593e-01, 0.0000e+00],
        [7.8822e-01, 0.0000e+00, 3.8323e+00, 1.2977e-01, 5.6906e-02, 1.7674e-01,
         6.8147e-01, 3.2568e+00, 0.0000e+00, 3.9857e-01],
        [1.0679e+00, 7.3748e-01, 0.0000e+00, 0.0000e+00, 1.2985e-01, 7.2732e-01,
         1.0570e+00, 1.6435e-01, 0.0000e+00, 6.0154e-01],
        [0.0000e+00, 5.4098

Epoch 18:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.47]

tensor([[0.0000, 0.8877, 0.0000, 1.1541, 0.0000, 0.0000, 0.0000, 0.0000, 1.7881,
         0.0000],
        [0.1492, 1.0672, 0.0000, 0.6708, 0.0000, 1.3042, 0.0400, 0.0000, 0.0000,
         0.0197],
        [1.1723, 0.0000, 3.4623, 0.7637, 0.0000, 1.2649, 1.4504, 2.6572, 0.0000,
         0.8692],
        [0.0000, 0.5196, 0.0000, 0.3035, 0.1755, 0.0000, 0.5606, 0.0000, 0.6039,
         0.0511],
        [0.6591, 0.0000, 1.8458, 0.0000, 0.1495, 0.2965, 0.7285, 2.0639, 0.0000,
         0.5260],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3768, 0.0000, 0.0000, 0.2137, 0.4499,
         0.0000],
        [0.3751, 0.0000, 1.8947, 2.0704, 0.0000, 0.9041, 2.1315, 0.7644, 0.0000,
         0.8403],
        [0.0000, 0.8480, 0.0000, 0.0000, 0.5110, 0.0000, 0.0000, 0.0000, 0.4715,
         0.0000],
        [0.0000, 0.5464, 0.0000, 0.0000, 0.4864, 0.2980, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.4475, 0.0000, 0.0000, 0.4889, 0.0000, 0.0000, 0.0000, 0.5060,
         0.0000],
        [0

Epoch 18:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.42]

tensor([[0.2012, 2.2127, 0.0000, 0.0683, 0.3014, 0.0000, 0.8813, 0.0000, 0.0183,
         0.3041],
        [0.8302, 0.4267, 0.6696, 0.0000, 0.6739, 0.0000, 0.3481, 0.0162, 0.0122,
         0.0000],
        [0.0000, 1.0284, 0.0000, 1.6247, 0.0000, 0.0000, 1.5960, 0.0000, 0.7002,
         0.5321],
        [0.0000, 0.3585, 0.0000, 0.3690, 0.0000, 0.0000, 0.0000, 0.0000, 1.0506,
         0.0000],
        [0.0000, 3.0759, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0771,
         0.0000],
        [0.0000, 0.0000, 0.2975, 1.1282, 0.0000, 0.4319, 0.2962, 0.0000, 0.4661,
         0.1303],
        [0.4113, 0.0000, 2.5431, 0.0000, 0.8631, 0.0000, 0.0000, 2.4530, 0.0143,
         0.0000],
        [0.2386, 0.0000, 2.0043, 0.8801, 0.0000, 0.0000, 0.8725, 1.1293, 0.0000,
         0.2110],
        [0.0000, 0.0000, 0.6579, 0.3344, 0.0000, 0.0000, 0.1657, 0.2248, 0.2585,
         0.0000],
        [0.0000, 0.0000, 0.4815, 0.0000, 0.2475, 0.0000, 0.0000, 0.5520, 0.1332,
         0.0000],
        [0

Epoch 18:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.7]

tensor([[0.0000e+00, 9.9649e-01, 0.0000e+00, 1.5379e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.2972e+00, 0.0000e+00],
        [0.0000e+00, 4.0461e+00, 0.0000e+00, 8.4148e-01, 0.0000e+00, 4.6394e-01,
         0.0000e+00, 0.0000e+00, 1.0693e+00, 0.0000e+00],
        [1.8654e+00, 0.0000e+00, 3.1606e+00, 0.0000e+00, 1.6846e+00, 0.0000e+00,
         7.9752e-02, 3.1661e+00, 0.0000e+00, 0.0000e+00],
        [3.0287e-01, 0.0000e+00, 3.9843e+00, 6.1538e-01, 0.0000e+00, 0.0000e+00,
         8.1960e-01, 3.1877e+00, 0.0000e+00, 2.8625e-01],
        [0.0000e+00, 2.0429e+00, 0.0000e+00, 7.5428e-01, 0.0000e+00, 0.0000e+00,
         1.1380e+00, 0.0000e+00, 5.5497e-01, 3.7359e-01],
        [1.1584e+00, 1.1078e-01, 3.2567e-01, 0.0000e+00, 8.2555e-01, 0.0000e+00,
         1.1839e+00, 8.9924e-01, 0.0000e+00, 4.8956e-01],
        [2.0529e-01, 0.0000e+00, 1.5918e+00, 4.7679e-01, 5.5050e-02, 0.0000e+00,
         8.8714e-01, 1.3753e+00, 4.1462e-02, 2.7610e-01],
        [0.0000e+00, 2.2533

Epoch 18:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.4]

tensor([[0.0000, 2.3777, 0.0000, 0.0000, 0.2628, 0.0000, 0.0000, 0.0000, 1.4316,
         0.0000],
        [0.6369, 1.4316, 0.0000, 0.0000, 0.9500, 0.2129, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1615, 0.0000, 0.4920, 1.3627, 0.0000, 0.0475, 1.5505, 0.0000, 0.0985,
         0.4493],
        [0.0000, 0.2601, 0.0000, 0.3178, 0.0000, 0.0000, 0.0291, 0.0000, 0.7283,
         0.0000],
        [0.0000, 1.5266, 0.0000, 0.9429, 0.0000, 0.1455, 0.0000, 0.0000, 0.9796,
         0.0000],
        [1.4016, 0.0000, 1.6876, 0.0817, 0.4601, 0.1097, 1.4313, 1.5959, 0.0000,
         0.6166],
        [0.0000, 2.0235, 0.0000, 0.0000, 0.7156, 0.0000, 0.0000, 0.0000, 1.4583,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4389, 0.0000, 0.0000, 0.7772, 0.9919,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.1059, 0.0000, 0.0000, 0.0404, 0.2198,
         0.0000],
        [0.0000, 0.8893, 0.0000, 0.0000, 0.5066, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0

Epoch 18:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.39]

tensor([[1.3538, 0.0000, 1.2639, 0.1952, 0.3862, 0.1520, 1.3255, 0.9393, 0.0000,
         0.4900],
        [0.0000, 0.6426, 0.0000, 0.0000, 0.1025, 0.1328, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6750, 0.7684, 0.0000, 0.0000, 0.9041, 0.0416, 0.0000, 0.1187, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0452, 0.0961, 0.2505, 0.0000, 0.3502, 0.1207, 0.3390,
         0.0000],
        [0.0000, 0.7810, 0.0000, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000, 0.8267,
         0.0000],
        [0.0000, 0.0000, 1.1003, 0.0158, 0.0000, 0.3406, 0.0000, 0.7013, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5288, 0.0000, 0.0000, 0.0690, 0.0000, 0.7414, 0.1139,
         0.0000],
        [0.0000, 0.0000, 0.1694, 0.3173, 0.0000, 0.1709, 0.0000, 0.1372, 0.4165,
         0.0000],
        [0.0000, 2.3941, 0.0000, 0.0000, 0.0000, 0.0916, 0.0000, 0.0000, 0.6779,
         0.0000],
        [0.1876, 0.5067, 0.0000, 0.0000, 0.9682, 0.0000, 0.0000, 0.2269, 0.0000,
         0.0000],
        [0

Epoch 18:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.43]

tensor([[0.0000, 0.6480, 0.0000, 1.4347, 0.0000, 0.2058, 0.2049, 0.0000, 1.1941,
         0.0614],
        [0.0000, 0.0000, 1.4156, 0.2441, 0.0594, 0.0000, 0.0267, 0.9183, 0.4687,
         0.0000],
        [0.4086, 0.6157, 0.0000, 1.1205, 0.0000, 0.2676, 1.8049, 0.0000, 0.0000,
         0.6934],
        [0.0000, 0.9695, 0.0000, 0.0000, 0.1568, 0.0000, 0.0000, 0.0000, 0.3370,
         0.0000],
        [0.0000, 0.6002, 0.0000, 0.0000, 0.1104, 0.0000, 0.0000, 0.0000, 0.3668,
         0.0000],
        [0.6175, 0.0000, 3.3346, 0.1291, 0.5639, 0.0000, 1.4954, 3.7394, 0.0000,
         0.6692],
        [0.0000, 0.0000, 1.3506, 1.1375, 0.0000, 0.2538, 0.7176, 0.8771, 0.2720,
         0.3193],
        [0.0000, 1.1153, 0.0000, 0.0000, 0.6437, 0.0000, 0.0000, 0.0000, 0.7420,
         0.0000],
        [0.0000, 0.6529, 0.0000, 0.4651, 0.0000, 0.0000, 0.4601, 0.0000, 0.5995,
         0.0745],
        [0.1151, 1.2501, 0.0000, 0.9289, 0.0000, 0.4408, 1.0991, 0.0000, 0.0000,
         0.3869],
        [0

Epoch 18: 100%|██████| 10/10 [00:00<00:00, 105.86batch/s, accuracy=0, loss=1.61]
Epoch 19:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.45]

tensor([[0.0000, 0.8884, 0.0000, 1.1419, 0.0000, 0.0000, 0.0000, 0.0000, 1.7885,
         0.0000],
        [0.1567, 1.0699, 0.0000, 0.6609, 0.0000, 1.2914, 0.0464, 0.0000, 0.0000,
         0.0269],
        [1.1457, 0.0000, 3.4931, 0.7528, 0.0000, 1.2755, 1.4554, 2.6455, 0.0000,
         0.8715],
        [0.0000, 0.5193, 0.0000, 0.2970, 0.1763, 0.0000, 0.5645, 0.0000, 0.6046,
         0.0536],
        [0.6398, 0.0000, 1.8674, 0.0000, 0.1500, 0.3041, 0.7298, 2.0561, 0.0000,
         0.5247],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3782, 0.0000, 0.0000, 0.2115, 0.4486,
         0.0000],
        [0.3580, 0.0000, 1.9133, 2.0526, 0.0000, 0.9055, 2.1414, 0.7581, 0.0000,
         0.8480],
        [0.0000, 0.8500, 0.0000, 0.0000, 0.5126, 0.0000, 0.0000, 0.0000, 0.4720,
         0.0000],
        [0.0000, 0.5503, 0.0000, 0.0000, 0.4889, 0.2910, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.4490, 0.0000, 0.0000, 0.4911, 0.0000, 0.0000, 0.0000, 0.5097,
         0.0000],
        [0

Epoch 19:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.49]

tensor([[0.2090, 2.2110, 0.0000, 0.0621, 0.3043, 0.0000, 0.8854, 0.0000, 0.0287,
         0.3055],
        [0.8225, 0.4287, 0.6757, 0.0000, 0.6729, 0.0000, 0.3549, 0.0128, 0.0085,
         0.0000],
        [0.0000, 1.0253, 0.0000, 1.6094, 0.0000, 0.0000, 1.6024, 0.0000, 0.7061,
         0.5365],
        [0.0000, 0.3585, 0.0000, 0.3614, 0.0000, 0.0000, 0.0000, 0.0000, 1.0519,
         0.0000],
        [0.0000, 3.0745, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0924,
         0.0000],
        [0.0000, 0.0000, 0.3048, 1.1160, 0.0000, 0.4247, 0.2997, 0.0000, 0.4637,
         0.1343],
        [0.3852, 0.0000, 2.5718, 0.0000, 0.8624, 0.0000, 0.0000, 2.4431, 0.0000,
         0.0000],
        [0.2198, 0.0000, 2.0250, 0.8707, 0.0000, 0.0000, 0.8784, 1.1221, 0.0000,
         0.2173],
        [0.0000, 0.0000, 0.6671, 0.3276, 0.0000, 0.0000, 0.1696, 0.2211, 0.2541,
         0.0000],
        [0.0000, 0.0000, 0.4903, 0.0000, 0.2484, 0.0000, 0.0000, 0.5482, 0.1297,
         0.0000],
        [0

Epoch 19:   0%|                 | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.5]

tensor([[0.3462, 0.0000, 1.7927, 1.3794, 0.0000, 1.2369, 1.7558, 1.2914, 0.0000,
         0.9769],
        [0.0000, 1.3708, 0.0000, 0.0000, 0.3797, 0.0947, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.9674, 0.0000, 0.0000, 0.0000, 0.6143, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 2.1222, 0.0000, 0.0000, 0.2919, 0.0000, 0.0000, 0.0000, 1.4362,
         0.0000],
        [0.0000, 3.2133, 0.0000, 0.0000, 0.0000, 0.1451, 0.0000, 0.0000, 0.7483,
         0.0000],
        [0.0000, 0.0000, 1.0635, 0.5713, 0.0000, 1.2949, 0.0000, 0.0899, 0.0000,
         0.0976],
        [0.0000, 1.3104, 0.0000, 0.0000, 0.5967, 0.0000, 0.0000, 0.0000, 0.6487,
         0.0000],
        [0.0000, 1.8787, 0.0000, 0.1493, 0.1237, 0.0000, 0.8145, 0.0000, 0.1035,
         0.3280],
        [0.5683, 0.4871, 0.3103, 0.1714, 0.0000, 1.1549, 0.2654, 0.0000, 0.0000,
         0.2504],
        [0.0000, 1.4722, 0.0000, 0.6730, 0.0000, 0.4001, 0.0000, 0.0000, 1.1706,
         0.0000],
        [0

Epoch 19:   0%|                | 0/10 [00:00<?, ?batch/s, accuracy=0, loss=1.48]

tensor([[1.3496, 0.0000, 1.2782, 0.1906, 0.3893, 0.1524, 1.3287, 0.9351, 0.0000,
         0.5006],
        [0.0000, 0.6486, 0.0000, 0.0000, 0.1027, 0.1304, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6764, 0.7752, 0.0000, 0.0000, 0.9048, 0.0432, 0.0000, 0.1167, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0517, 0.0925, 0.2540, 0.0000, 0.3510, 0.1193, 0.3374,
         0.0000],
        [0.0000, 0.7879, 0.0000, 0.0620, 0.0000, 0.0000, 0.0000, 0.0000, 0.8274,
         0.0000],
        [0.0000, 0.0000, 1.1129, 0.0127, 0.0000, 0.3391, 0.0000, 0.6962, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.5389, 0.0000, 0.0000, 0.0671, 0.0000, 0.7373, 0.1101,
         0.0000],
        [0.0000, 0.0000, 0.1766, 0.3128, 0.0000, 0.1659, 0.0000, 0.1347, 0.4134,
         0.0000],
        [0.0000, 2.4093, 0.0000, 0.0000, 0.0000, 0.0832, 0.0000, 0.0000, 0.6847,
         0.0000],
        [0.1875, 0.5106, 0.0000, 0.0000, 0.9717, 0.0000, 0.0000, 0.2261, 0.0000,
         0.0000],
        [0

Epoch 19: 100%|██████| 10/10 [00:00<00:00, 114.33batch/s, accuracy=0, loss=1.66]

tensor([[0.0000, 0.6578, 0.0000, 1.4239, 0.0000, 0.1956, 0.2065, 0.0000, 1.1939,
         0.0629],
        [0.0000, 0.0000, 1.4322, 0.2422, 0.0656, 0.0000, 0.0253, 0.9136, 0.4639,
         0.0000],
        [0.4079, 0.6183, 0.0000, 1.1113, 0.0000, 0.2614, 1.8082, 0.0000, 0.0000,
         0.7027],
        [0.0000, 0.9801, 0.0000, 0.0000, 0.1551, 0.0000, 0.0000, 0.0000, 0.3378,
         0.0000],
        [0.0000, 0.6085, 0.0000, 0.0000, 0.1091, 0.0000, 0.0000, 0.0000, 0.3667,
         0.0000],
        [0.6088, 0.0000, 3.3774, 0.1321, 0.5819, 0.0000, 1.4907, 3.7292, 0.0000,
         0.6818],
        [0.0000, 0.0000, 1.3686, 1.1314, 0.0000, 0.2486, 0.7172, 0.8724, 0.2669,
         0.3284],
        [0.0000, 1.1226, 0.0000, 0.0000, 0.6457, 0.0000, 0.0000, 0.0000, 0.7428,
         0.0000],
        [0.0000, 0.6580, 0.0000, 0.4594, 0.0000, 0.0000, 0.4610, 0.0000, 0.5997,
         0.0778],
        [0.1161, 1.2584, 0.0000, 0.9194, 0.0000, 0.4353, 1.1030, 0.0000, 0.0000,
         0.3936],
        [0




### Built-In

In [8]:
import torch
from torch.nn import functional as F
from tqdm import tqdm
from dgl.dataloading import GraphDataLoader
from nclustRL.models import HeteroClassifier

dgl.seed(5)

def test_embedings(graphs):

    batch_size=1
    shuffle=True
    nclasses = 5
    n = 5

    # dataloader = GraphDataLoader(
    #     base,
    #     batch_size=batch_size,
    #     drop_last=False,
    #     shuffle=shuffle)

    etypes = graphs[0].etypes

    model = HeteroClassifier(n, [n*2], nclasses, etypes)
    model = model.cuda()
    opt = torch.optim.Adam(model.parameters())


    for epoch in range(20):
        with tqdm(graphs, unit="batch") as tepoch:
            for batched_graph in tepoch:

                tepoch.set_description(f"Epoch {epoch}")

                # batched_graph = transform_obs(n, batched_graph)
                labels = torch.randint(0, 4, (batch_size,)).to('cuda:0')

                logits = model(batched_graph)
                loss = F.cross_entropy(logits, labels)

                predictions = logits.argmax(dim=1, keepdim=True).squeeze()
                correct = (logits == labels).sum().item()

                opt.zero_grad()
                loss.backward()
                opt.step()

                accuracy = correct / batch_size
                tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)

In [9]:
test_embedings(graphs)

Epoch 0: 100%|███████| 10/10 [00:00<00:00, 155.50batch/s, accuracy=0, loss=1.48]
Epoch 1: 100%|███████| 10/10 [00:00<00:00, 192.21batch/s, accuracy=0, loss=1.48]
Epoch 2: 100%|███████| 10/10 [00:00<00:00, 203.02batch/s, accuracy=0, loss=1.47]
Epoch 3: 100%|███████| 10/10 [00:00<00:00, 160.07batch/s, accuracy=0, loss=1.47]
Epoch 4: 100%|███████| 10/10 [00:00<00:00, 169.52batch/s, accuracy=0, loss=1.97]
Epoch 5: 100%|███████| 10/10 [00:00<00:00, 195.97batch/s, accuracy=0, loss=1.96]
Epoch 6: 100%|███████| 10/10 [00:00<00:00, 227.20batch/s, accuracy=0, loss=1.85]
Epoch 7: 100%|███████| 10/10 [00:00<00:00, 229.85batch/s, accuracy=0, loss=1.95]
Epoch 8: 100%|███████| 10/10 [00:00<00:00, 233.45batch/s, accuracy=0, loss=1.94]
Epoch 9: 100%|███████| 10/10 [00:00<00:00, 202.54batch/s, accuracy=0, loss=1.46]
Epoch 10: 100%|██████| 10/10 [00:00<00:00, 208.05batch/s, accuracy=0, loss=1.84]
Epoch 11: 100%|██████| 10/10 [00:00<00:00, 183.09batch/s, accuracy=0, loss=1.46]
Epoch 12: 100%|██████| 10/10

## Generate Graphs

### Explicit

In [23]:
    import torch as th
    import dgl
    def loader(cls, module=None):

        return getattr(module, cls) if isinstance(cls, str) else cls
    
    def dense_to_dgl(x, device, cuda=0, nclusters=1, clust_init='zeros', duplicate=True):

        # set (u,v)
        clust_init = loader(th, clust_init)

        tensor = th.tensor([[i, j, elem] for i, row in enumerate(x) for j, elem in enumerate(row)]).T

        if duplicate:

            graph_data = {
                ('row', 'elem', 'col'): (tensor[0].int(), tensor[1].int()),
                ('col', 'elem', 'row'): (tensor[1].int().detach().clone(), tensor[2].int().detach().clone()),
                }

            # create graph
            G = dgl.heterograph(graph_data)

            # set weights
            G.edges[('row', 'elem', 'col')].data['w'] = tensor[2].to(th.float64)
            G.edges[('col', 'elem', 'row')].data['w'] = tensor[2].to(th.float64)

        else:

            graph_data = {
                ('row', 'elem', 'col'): (tensor[0].int(), tensor[1].int()),
                }

            # create graph
            G = dgl.heterograph(graph_data)

            # set weights
            # G.edges[('row', 'elem', 'col')].data['w'] = tensor[2].float()

            G.edata['w'] = tensor[2].to(th.float64)

        # set cluster members

        for n, axis in enumerate(['row', 'col']):
            for i in range(nclusters):
                G.nodes[axis].data[i] = th.randint(0, 2, (x.shape[n],), dtype=torch.bool)

        ndata = {}
        ntypes = G.ntypes
        keys = sorted(list(G.nodes[ntypes[0]].data.keys()))

        for ntype in ntypes:
            ndata[ntype] = torch.vstack(
                [G.ndata[key][ntype].to(th.float64) for key in keys]
            ).transpose(0, 1)

            G.nodes[ntype].data.clear()

        G.ndata['feat'] = ndata

        if device == 'gpu':
            G = G.to('cuda:{}'.format(cuda))

        return G

In [24]:
import nclustenv
import torch
env = nclustenv.make('BiclusterEnv-v0', **dict(shape=[[100, 10], [110, 15]], clusters=[5,5]))

graphs_dup = []
graphs2 = []
for i in range(10):
    env.reset()
    X = env.state._generator.X
    graphs_dup.append(dense_to_dgl(X, device='gpu', nclusters=5))
    graphs2.append(dense_to_dgl(X, device='gpu', nclusters=5, duplicate=False))

In [25]:
graphs_dup[0]

Graph(num_nodes={'col': 13, 'row': 102},
      num_edges={('col', 'elem', 'row'): 1326, ('row', 'elem', 'col'): 1326},
      metagraph=[('col', 'row', 'elem'), ('row', 'col', 'elem')])

In [26]:
graphs2[0]

Graph(num_nodes={'col': 13, 'row': 102},
      num_edges={('row', 'elem', 'col'): 1326},
      metagraph=[('row', 'col', 'elem')])

In [27]:
graphs2[0].edges[('row', 'elem', 'col')].data['w']

tensor([ 5.1700,  4.7800,  2.2900,  ...,  4.4100,  8.2300, -9.9300],
       device='cuda:0', dtype=torch.float64)

### Bult-In

In [19]:
import nclustenv
import torch
from gym.wrappers import TransformObservation
from nclustRL.utils.helper import transform_obs

def randint(size, dtype):
    return torch.randint(low=0, high=2, size=[size], dtype=dtype)

env = nclustenv.make('BiclusterEnv-v0', **dict(n=5, shape=[[100, 10], [110, 15]], clusters=[5,5], clust_init=randint))
env = TransformObservation(env, transform_obs)

graphs = [env.reset()['state'] for _ in range(10)]

In [20]:
graphs[0].edges[('row', 'elem', 'col')].data['w']

tensor([-1.3200, -9.7900, -3.0600,  ...,  8.3300, -2.6600,  5.3300],
       device='cuda:0')

In [21]:
graphs[0].ndata['feat']

{'col': tensor([[0., 1., 1., 0., 0.],
         [0., 1., 0., 0., 0.],
         [1., 0., 1., 0., 1.],
         [1., 0., 1., 1., 1.],
         [1., 0., 0., 0., 0.],
         [1., 0., 1., 1., 1.],
         [0., 1., 0., 1., 0.],
         [1., 1., 0., 0., 1.],
         [1., 1., 1., 1., 0.],
         [1., 1., 0., 0., 1.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 0., 1.]], device='cuda:0'),
 'row': tensor([[0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.],
         [1., 0., 1., 1., 1.],
         [1., 0., 0., 1., 1.],
         [0., 1., 0., 0., 0.],
         [1., 0., 1., 0., 0.],
         [1., 0., 0., 1., 0.],
         [0., 1., 0., 1., 1.],
         [0., 0., 0., 0., 1.],
         [0., 1., 0., 1., 1.],
         [0., 0., 0., 1., 0.],
         [1., 1., 0., 1., 1.],
         [1., 0., 0., 1., 0.],
         [0., 1., 1., 0., 1.],
         [0., 0., 1., 0., 1.],
         [1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1.],
         [1., 0., 0., 1., 1.],
      

In [11]:
env.observation_space['state'].n

5

In [12]:
env.action_space[0].n

4

In [25]:
np.prod(env.action_space[1].shape)

12

In [24]:
import numpy as np