In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import os
import itertools

import typing

import tqdm
import numpy
import joblib
import gtda.images
import torchvision
import skimage.color

import cvtda.utils
import cvtda.topology

def make_image(image, channel: int) -> numpy.ndarray:
    image = numpy.array(image)
    match channel:
        case 'red':
            image = image[:, :, 0]
        case 'green':
            image = image[:, :, 1]
        case 'blue':
            image = image[:, :, 2]
        case 'gray':
            image = skimage.color.rgb2gray(image)
        case _:
            raise NotImplementedError
    assert image.shape == (32, 32)
    return image

def make_diagrams(
    channel: int,
    binarizer,
    filtration,
    n_jobs: int = 1
) -> typing.Tuple[numpy.ndarray, numpy.ndarray]:
    dir = f"E:/4/{channel}/None/diagrams/{str(filtration or 'None')}"
    if os.path.exists(f"{dir}/test.npy"):
        return 1, 2
    os.makedirs(dir, exist_ok = True)
    
    train = numpy.array(
        joblib.Parallel(n_jobs = n_jobs)(
            joblib.delayed(make_image)(item[0], channel)
            for item in torchvision.datasets.CIFAR10('cifar-10', train = True, download = False)
        )
    )
    test = numpy.array(
        joblib.Parallel(n_jobs = n_jobs)(
            joblib.delayed(make_image)(item[0], channel)
            for item in torchvision.datasets.CIFAR10('cifar-10', train = False, download = False)
        )
    )

    if binarizer is not None:
        train = binarizer.fit_transform(train)
        test = binarizer.transform(test)
        
    if filtration is not None:
        train = filtration.fit_transform(train)
        test = filtration.transform(test)
    
    filtrations_to_diagrams = cvtda.topology.FiltrationsToDiagrams(verbose = False, n_jobs = n_jobs)
    train = filtrations_to_diagrams.fit_transform(train)
    test = filtrations_to_diagrams.transform(test)

    numpy.save(f"{dir}/train.npy", train)
    numpy.save(f"{dir}/test.npy", test)
    return train, test

def process(channel: int, binarizer_threshold: float) -> typing.Tuple[numpy.ndarray, numpy.ndarray]:
    centers = [ 5, 12, 18, 25 ]
    greyscale_to_filtrations = cvtda.topology.GreyscaleToFiltrations(
        n_jobs = 1,
        radial_filtration_centers = list(itertools.product(centers, centers))
    )
    diagrams = joblib.Parallel(return_as = 'generator', n_jobs = -1)(
        joblib.delayed(make_diagrams)(
            channel,
            binarizer = gtda.images.Binarizer(threshold = binarizer_threshold, n_jobs = 1),
            filtration = filtration,
            n_jobs = 1
        )
        for filtration in greyscale_to_filtrations.filtrations_
    )
    for train, test in tqdm.tqdm(diagrams, total = len(greyscale_to_filtrations.filtrations_)):
        pass

    make_diagrams(
        channel,
        binarizer = None,
        filtration = None,
        n_jobs = -1
    )

In [36]:
for channel in [ 'gray' ]:
    print(f'>>> Calculating channel {channel}')
    process(channel, 0.4)

>>> Calculating channel gray


100%|██████████| 24/24 [06:11<00:00, 15.49s/it]  
100%|██████████| 50000/50000 [00:07<00:00, 6645.74it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8253.59it/s]


In [95]:
import tqdm
import joblib
import torch
import gtda.diagrams
import torchph.nn.slayer

def process():
    def transform(diagram, dim):
        dim_filter = (diagram[:, 2] == dim)
        non_degenerate_filter = (diagram[:, 0] != diagram[:, 1])
        rotation = torchph.nn.slayer.UpperDiagonalThresholdedLogTransform(0.05)
        return rotation(diagram[dim_filter & non_degenerate_filter][:, 0:2])

    train_data = [ ]
    test_data = [ ]
    for filtration in tqdm.tqdm(os.listdir(f"E:/4/red/None/diagrams")):
        for dim in [ 0, 1 ]:
            for channel in [ 'red', "green", "blue" ]:
                dir = f"E:/4/{channel}/None/diagrams/{filtration}"
                train_diagrams = numpy.load(f"{dir}/train.npy")
                test_diagrams = numpy.load(f"{dir}/test.npy")

                scaler = gtda.diagrams.Scaler()
                train_diagrams = torch.tensor(scaler.fit_transform(train_diagrams), dtype = torch.float32)
                test_diagrams = torch.tensor(scaler.transform(test_diagrams), dtype = torch.float32)

                diagrams_train = joblib.Parallel(n_jobs = 1)(joblib.delayed(transform)(diagram, dim) for diagram in train_diagrams)
                diagrams, non_dummy_points, _, _ = torchph.nn.slayer.prepare_batch(diagrams_train)
                train_data.append(diagrams)
                train_data.append(non_dummy_points)
                
                diagrams_test = joblib.Parallel(n_jobs = 1)(joblib.delayed(transform)(diagram, dim) for diagram in test_diagrams)
                diagrams, non_dummy_points, _, _ = torchph.nn.slayer.prepare_batch(diagrams_test)
                test_data.append(diagrams)
                test_data.append(non_dummy_points)

    return train_data, test_data

train_diagrams, test_diagrams = process()
 
print(len(train_diagrams), train_diagrams[0].shape, train_diagrams[1].shape)
print(len(test_diagrams), test_diagrams[0].shape, test_diagrams[1].shape)

100%|██████████| 25/25 [18:03<00:00, 43.34s/it]

300 torch.Size([50000, 40, 2]) torch.Size([50000, 40])
300 torch.Size([10000, 38, 2]) torch.Size([10000, 38])





In [96]:
import numpy
import torchvision
import torchvision.transforms.v2

train = torchvision.datasets.CIFAR10('cifar-10', train = True, download = False)
test = torchvision.datasets.CIFAR10('cifar-10', train = False, download = False)

train_labels = numpy.array([ item[1] for item in train ])
test_labels = numpy.array([ item[1] for item in test ])

train_labels.shape, test_labels.shape

((50000,), (10000,))

In [112]:
import torch

device = torch.device('cuda')

class SLayerMultiChannel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n_elements = 128
        self.n_channels = 3
        self.slayers = [
            torchph.nn.slayer.SLayerExponential(self.n_elements).to(device)
            for _ in range(self.n_channels)
        ]
        
    def forward(self, args):
        features = [ ]
        for i in range(0, len(args), 2):
            slayer_args = (args[i].to(device), args[i + 1].to(device), args[i].shape[1], len(args[i]))
            features.append(self.slayers[i // 2](slayer_args))
        return torch.stack(features, dim = 1)
    


class FiltrationBranch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            SLayerMultiChannel(), torch.nn.BatchNorm1d(3), torch.nn.ReLU(),
            torch.nn.Conv1d(3, 64, kernel_size = 1, stride = 1), torch.nn.BatchNorm1d(64), torch.nn.ReLU(),
            torch.nn.Conv1d(64, 32, kernel_size = 1, stride = 1), torch.nn.BatchNorm1d(32), torch.nn.ReLU(),
            torch.nn.Conv1d(32, 16, kernel_size = 1, stride = 1), torch.nn.BatchNorm1d(16), torch.nn.ReLU(),
            torch.nn.Conv1d(16, 8, kernel_size = 1, stride = 1), torch.nn.BatchNorm1d(8), torch.nn.ReLU(),
            torch.nn.Conv1d(8, 4, kernel_size = 1, stride = 1), torch.nn.BatchNorm1d(4), torch.nn.ReLU(),
            torch.nn.Flatten(), torch.nn.Linear(128 * 4, 128), torch.nn.BatchNorm1d(128), torch.nn.ReLU()
        ).to(device)

        
        self.model = torch.nn.Sequential(
            SLayerMultiChannel(), torch.nn.Flatten()
        ).to(device)
        
    def forward(self, args):
        return self.model(args)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_branches = 50
        self.branches = [
            FiltrationBranch().to(device)
            for _ in range(self.num_branches)
        ]

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(0.4), torch.nn.Linear(self.num_branches * 128 * 3, 256), torch.nn.BatchNorm1d(256), torch.nn.ReLU(),
            torch.nn.Dropout(0.3), torch.nn.Linear(256, 128), torch.nn.BatchNorm1d(128), torch.nn.ReLU(),
            torch.nn.Dropout(0.2), torch.nn.Linear(128, 64), torch.nn.BatchNorm1d(64), torch.nn.ReLU(),
            torch.nn.Dropout(0.1), torch.nn.Linear(64, 32), torch.nn.BatchNorm1d(32), torch.nn.ReLU(),
            torch.nn.Linear(32, 10), torch.nn.Softmax(dim = 1)
        ).to(device)
    
    def forward(self, args):
        features = [ ]
        for i in range(0, len(args), 6):
            features.append(self.branches[i // 6](args[i:i + 6]))
        return self.classifier(torch.cat(features, dim = 1))

In [113]:
import sklearn.metrics
import cvtda.utils
cvtda.utils.set_random_seed(42)

train_dl = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(train_labels, device = device, dtype = torch.long),
        *train_diagrams
    ),
    batch_size = 64,
    shuffle = True
)

test_dl = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(test_labels, device = device, dtype = torch.long),
        *test_diagrams
    ),
    batch_size = 64,
    shuffle = False
)

classifier = Net()

optimizer = torch.optim.AdamW(
    params = classifier.parameters(),
    lr = 1e-3
)

for _ in range(100):
    sum_loss = 0

    classifier.train()
    for (y, *args) in tqdm.tqdm(train_dl):
        optimizer.zero_grad()
        pred = classifier(args)

        loss = torch.nn.functional.cross_entropy(pred, y, reduction = 'sum')
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
    postfix = { 'loss': sum_loss }

    classifier.eval()
    with torch.no_grad():
        preds = [ ]
        real = [ ]
        for (y, *args) in tqdm.tqdm(test_dl):
            proba = classifier(args)
            preds.extend(torch.argmax(proba, axis = 1).cpu().detach())
            real.extend(y.cpu())

        postfix['val_acc'] = sklearn.metrics.accuracy_score(preds, real)

    print(postfix)

100%|██████████| 782/782 [06:11<00:00,  2.10it/s]
100%|██████████| 157/157 [00:22<00:00,  7.04it/s]


{'loss': 106884.3097820282, 'val_acc': 0.3993}


100%|██████████| 782/782 [06:13<00:00,  2.09it/s]
100%|██████████| 157/157 [00:23<00:00,  6.72it/s]


{'loss': 103433.46039962769, 'val_acc': 0.4275}


100%|██████████| 782/782 [06:08<00:00,  2.12it/s]
100%|██████████| 157/157 [00:21<00:00,  7.21it/s]


{'loss': 102691.93363571167, 'val_acc': 0.4386}


  6%|▌         | 46/782 [00:22<06:01,  2.04it/s]


KeyboardInterrupt: 