In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchph
import torchph.nn.slayer
import gtda.diagrams

In [3]:
import numpy
import torchvision

train = torchvision.datasets.MNIST('mnist', train = True, download = False)
test = torchvision.datasets.MNIST('mnist', train = False, download = False)

train_labels = numpy.array([ item[1] for item in train ])
test_labels = numpy.array([ item[1] for item in test ])

train_labels.shape, test_labels.shape

((60000,), (10000,))

In [4]:
import numpy

train_diagrams = numpy.load("8/train/filtration_diagrams.npy")
test_diagrams = numpy.load("8/test/filtration_diagrams.npy")

train_diagrams.shape, test_diagrams.shape

((2040000, 80, 3), (340000, 72, 3))

In [5]:
scaler = gtda.diagrams.Scaler()
train_diagrams = scaler.fit_transform(train_diagrams)
test_diagrams = scaler.transform(test_diagrams)

train_diagrams.shape, test_diagrams.shape

((2040000, 80, 3), (340000, 72, 3))

In [6]:
train_diagrams = torch.tensor(train_diagrams, dtype = torch.float32)
test_diagrams = torch.tensor(test_diagrams, dtype = torch.float32)

train_diagrams.shape, test_diagrams.shape

(torch.Size([2040000, 80, 3]), torch.Size([340000, 72, 3]))

In [7]:
train_diagrams = train_diagrams.reshape((train_labels.shape[0], -1, *train_diagrams.shape[1:]))
test_diagrams = test_diagrams.reshape((test_labels.shape[0], -1, *test_diagrams.shape[1:]))

train_diagrams.shape, test_diagrams.shape

(torch.Size([60000, 34, 80, 3]), torch.Size([10000, 34, 72, 3]))

In [8]:
import tqdm
import joblib

def process(data_diagrams):
    def transform(diagram, dim):
        dim_filter = (diagram[:, 2] == dim)
        non_degenerate_filter = (diagram[:, 0] != diagram[:, 1])
        rotation = torchph.nn.slayer.UpperDiagonalThresholdedLogTransform(0.05)
        return rotation(diagram[dim_filter & non_degenerate_filter][:, 0:2])

    data = [ ]
    for i in tqdm.trange(data_diagrams.shape[1]):
        for dim in [ 0, 1 ]:
            diagrams = joblib.Parallel(n_jobs = 1)(
                joblib.delayed(transform)(item[i], dim)
                for item in data_diagrams
            )
            diagrams, non_dummy_points, _, _ = torchph.nn.slayer.prepare_batch(diagrams)
            data.append(diagrams)
            data.append(non_dummy_points)
    return data

train_diagrams = process(train_diagrams)
test_diagrams = process(test_diagrams)
 
print(len(train_diagrams), train_diagrams[0].shape, train_diagrams[1].shape)
print(len(test_diagrams), test_diagrams[0].shape, test_diagrams[1].shape)

100%|██████████| 34/34 [07:14<00:00, 12.79s/it]
100%|██████████| 34/34 [01:13<00:00,  2.16s/it]

136 torch.Size([60000, 28, 2]) torch.Size([60000, 28])
136 torch.Size([10000, 22, 2]) torch.Size([10000, 22])





In [9]:
device = torch.device('cuda')

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n_elements = 128
        self.num_diagrams = 68
        
        self.slayers = [
            torchph.nn.slayer.SLayerExponential(self.n_elements).to(device)
            for _ in range(self.num_diagrams)
        ]

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(0.4), torch.nn.Linear(self.num_diagrams * self.n_elements, 256), torch.nn.BatchNorm1d(256), torch.nn.ReLU(),
            torch.nn.Dropout(0.3), torch.nn.Linear(256, 128), torch.nn.BatchNorm1d(128), torch.nn.ReLU(),
            torch.nn.Dropout(0.2), torch.nn.Linear(128, 64), torch.nn.BatchNorm1d(64), torch.nn.ReLU(),
            torch.nn.Dropout(0.1), torch.nn.Linear(64, 32), torch.nn.BatchNorm1d(32), torch.nn.ReLU(),
            torch.nn.Linear(32, 10), torch.nn.Softmax(dim = 1)
        ).to(device)
    
    def forward(self, *args):
        features = [ ]
        for i in range(0, len(args), 2):
            slayer_args = (args[i].to(device), args[i + 1].to(device), args[i].shape[1], len(args[i]))
            features.append(self.slayers[i // 2](slayer_args))
        return self.classifier(torch.cat(features, dim = 1))

In [10]:
import sklearn.metrics
import cvtda.utils
cvtda.utils.set_random_seed(42)

train_dl = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(train_labels, device = device, dtype = torch.long),
        *train_diagrams
    ),
    batch_size = 128,
    shuffle = True
)

test_dl = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(test_labels, device = device, dtype = torch.long),
        *test_diagrams
    ),
    batch_size = 128,
    shuffle = False
)

classifier = Net()

optimizer = torch.optim.AdamW(
    params = classifier.parameters(),
    lr = 1e-4
)

for _ in range(100):
    sum_loss = 0

    classifier.train()
    for (y, *args) in tqdm.tqdm(train_dl):
        optimizer.zero_grad()
        pred = classifier(*args)

        loss = torch.nn.functional.cross_entropy(pred, y, reduction = 'sum')
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
    postfix = { 'loss': sum_loss }

    classifier.eval()
    with torch.no_grad():
        preds = [ ]
        real = [ ]
        for (y, *args) in tqdm.tqdm(test_dl):
            proba = classifier(*args)
            preds.extend(torch.argmax(proba, axis = 1).cpu().detach())
            real.extend(y.cpu())

        postfix['val_acc'] = sklearn.metrics.accuracy_score(preds, real)

    print(postfix)

100%|██████████| 469/469 [01:47<00:00,  4.34it/s]
100%|██████████| 79/79 [00:05<00:00, 13.94it/s]


{'loss': 124345.12350463867, 'val_acc': 0.7756}


100%|██████████| 469/469 [01:47<00:00,  4.36it/s]
100%|██████████| 79/79 [00:05<00:00, 14.05it/s]


{'loss': 110337.82974243164, 'val_acc': 0.825}


100%|██████████| 469/469 [01:48<00:00,  4.34it/s]
100%|██████████| 79/79 [00:05<00:00, 14.11it/s]


{'loss': 103256.84013366699, 'val_acc': 0.9256}


100%|██████████| 469/469 [01:46<00:00,  4.41it/s]
100%|██████████| 79/79 [00:05<00:00, 13.89it/s]


{'loss': 98267.06907653809, 'val_acc': 0.9384}


100%|██████████| 469/469 [01:46<00:00,  4.39it/s]
100%|██████████| 79/79 [00:05<00:00, 13.59it/s]


{'loss': 95823.8133392334, 'val_acc': 0.9328}


100%|██████████| 469/469 [01:43<00:00,  4.52it/s]
100%|██████████| 79/79 [00:05<00:00, 14.52it/s]


{'loss': 94702.98522949219, 'val_acc': 0.9433}


100%|██████████| 469/469 [01:43<00:00,  4.54it/s]
100%|██████████| 79/79 [00:05<00:00, 14.40it/s]


{'loss': 93953.74647521973, 'val_acc': 0.9464}


100%|██████████| 469/469 [01:41<00:00,  4.61it/s]
100%|██████████| 79/79 [00:05<00:00, 14.72it/s]


{'loss': 93445.61747741699, 'val_acc': 0.9468}


100%|██████████| 469/469 [01:43<00:00,  4.52it/s]
100%|██████████| 79/79 [00:05<00:00, 14.21it/s]


{'loss': 93155.09976196289, 'val_acc': 0.9477}


 98%|█████████▊| 460/469 [01:43<00:02,  4.44it/s]


KeyboardInterrupt: 