In [14]:
from torch_geometric.loader import DataLoader, ImbalancedSampler
from torch_geometric.transforms import FaceToEdge, OneHotDegree
import torchvision.transforms as transforms
import torch.nn.functional as F

from mantra.simplicial import SimplicialDataset
from mantra.transforms import (
    TriangulationToFaceTransform,
    OrientableToClassTransform,
    DegreeTransform,
)
from validation.validate_homology import validate_betti_numbers

import torch

In [21]:
class NameToClass:
    """Transform that sets ``data.y`` to a one-hot encoding of ``data.name``.

    Each graph's ``name`` is looked up in ``class_dict`` and the resulting
    class index is one-hot encoded. The number of classes is derived from the
    mapping (largest index + 1), so it can never drift out of sync with it.

    Args:
        class_dict: Optional mapping from manifold name to integer class index.
            Defaults to the five names used in this notebook. Note that the
            empty string is a valid key (mapped to class 1).
    """

    DEFAULT_CLASS_DICT = {
        "Klein bottle": 0,
        "": 1,
        "RP^2": 2,
        "T^2": 3,
        "S^2": 4,
    }

    def __init__(self, class_dict=None):
        # Copy so instances never share (or mutate) the class-level default.
        self.class_dict = dict(
            self.DEFAULT_CLASS_DICT if class_dict is None else class_dict
        )
        # Derived rather than hard-coded; 5 for the default mapping.
        self.num_classes = max(self.class_dict.values(), default=-1) + 1

    def __call__(self, data):
        """Set ``data.y`` to the one-hot class of ``data.name`` and return ``data``.

        Raises:
            KeyError: if ``data.name`` is not a key of ``class_dict``.
        """
        try:
            label = self.class_dict[data.name]
        except KeyError:
            raise KeyError(
                f"Unknown manifold name {data.name!r}; expected one of "
                f"{sorted(self.class_dict)}"
            ) from None
        # NOTE(review): this yields a 1-D tensor of shape (num_classes,). PyG
        # batching concatenates attributes along dim 0, which would give
        # (batch_size * num_classes,) rather than (batch_size, num_classes).
        # Confirm that is what the downstream loss expects.
        data.y = F.one_hot(torch.tensor(label), num_classes=self.num_classes)
        return data


# Preprocessing steps applied to each graph when it is accessed.
# NOTE(review): NameToClass runs last and assigns data.y, so any label that
# OrientableToClassTransform may set is overwritten. Confirm this is intended.
pipeline_steps = [
    TriangulationToFaceTransform(),
    FaceToEdge(remove_faces=False),
    DegreeTransform(),
    OrientableToClassTransform(),
    NameToClass(),
]
tr = transforms.Compose(pipeline_steps)

dataset = SimplicialDataset(root="./data", transform=tr)

# Class balance. Assumes dataset.orientable holds one truthy flag per
# orientable triangulation (TODO confirm); both shares are fractions of 1.
n_graphs = len(dataset)
n_orientable = sum(dataset.orientable)
share_orientable = n_orientable / n_graphs
share_non_orientable = (n_graphs - n_orientable) / n_graphs
print(f"Percentage: {share_orientable:.2f}, {share_non_orientable:.2f}")

Percentage: 0.27, 0.73


In [23]:
# Bind the first graph to `data` and display its encoded label.
(data := dataset[0]).y



tensor([0, 0, 0, 0, 1])

In [19]:
from collections import Counter

# Count how many graphs carry each manifold name, i.e. the class balance of
# the NameToClass labels. Each item comes from `dataset`, which was built with
# `transform=tr`, so the whole pipeline presumably runs once per graph here.
# A generator is used instead of `for data in dataset` so the `data` sample
# inspected in the previous cell is not silently overwritten.
# NOTE(review): the saved AttributeError output means an item came back as
# None. A transform in the pipeline that does not `return data` would cause
# that. Re-run from the cell that builds `dataset`.
cnt = Counter(graph.name for graph in dataset)
cnt

AttributeError: 'NoneType' object has no attribute 'name'

In [58]:
# Hold out a fixed-size test set. PyG's Dataset.shuffle draws its permutation
# from torch's global RNG, so seeding first makes the split reproducible under
# Restart & Run All. Re-running only this cell shuffles the already-shuffled
# dataset again; re-run from the cell that builds `dataset` instead.
SEED = 42
N_TEST = 150

torch.manual_seed(SEED)
dataset = dataset.shuffle()

assert len(dataset) > N_TEST, f"need more than {N_TEST} graphs to hold out a test set"
train_dataset = dataset[:-N_TEST]
test_dataset = dataset[-N_TEST:]
assert len(train_dataset) + len(test_dataset) == len(dataset)

print(f"Number of training graphs: {len(train_dataset)}")
print(f"Number of test graphs: {len(test_dataset)}")

Number of training graphs: 562
Number of test graphs: 150


In [59]:
BATCH_SIZE = 10

# Reshuffle the training set every epoch. The test loader keeps a fixed order
# so evaluation is deterministic.
# TODO: to counter the orientable/non-orientable imbalance reported earlier,
# consider `sampler=ImbalancedSampler(train_dataset)` in place of shuffle=True
# (DataLoader does not allow both). NOTE(review): ImbalancedSampler reads class
# indices from `y`, which the one-hot labels here may not satisfy. Confirm
# before enabling it.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Pull a single batch to sanity-check the labels.
batch = next(iter(train_loader))
batch.y

tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 1])