In [1]:
import logging

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision

from cycler import cycler
from PIL import Image
from torchvision import datasets, transforms

logging.getLogger().setLevel(logging.INFO)

In [None]:
class MLP(nn.Module):
    """Sequence of ``nn.Linear`` layers, some of them preceded by a ReLU.

    ``layer_sizes[0]`` is the dimension of the input and ``layer_sizes[-1]``
    is the dimension of the output; every consecutive pair of sizes yields one
    ``nn.Linear``. Sizes are converted with ``int`` first.

    A ``nn.ReLU`` is placed *in front of* the linear layer with index ``i``
    whenever ``i`` is below a cutoff: the number of linear layers when
    ``final_relu`` is true, one less otherwise.

    Attributes:
        net: the ``nn.Sequential`` holding all layers.
        last_linear: the last module of ``net``.
    """

    def __init__(self, layer_sizes, final_relu=False):
        super().__init__()
        sizes = [int(size) for size in layer_sizes]
        num_layers = len(sizes) - 1
        # Linear layers whose index is below this cutoff get a ReLU before them.
        relu_cutoff = num_layers - (0 if final_relu else 1)

        modules = []
        for idx, (in_dim, out_dim) in enumerate(zip(sizes, sizes[1:])):
            if idx < relu_cutoff:
                modules.append(nn.ReLU(inplace=False))
            modules.append(nn.Linear(in_dim, out_dim))

        self.net = nn.Sequential(*modules)
        self.last_linear = self.net[-1]

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.net(x)

In [None]:
from copy import copy
import os
import re
from torchvision.datasets import CelebA
from torch.utils.data import default_collate

class CelebAPositives(CelebA):
    """CelebA dataset whose items are lists, optionally holding augmented views.

    Besides the target types "attr", "identity", "bbox" and "landmarks", a
    target type containing the text ``'augmented positives'`` together with a
    number (e.g. ``'augmented positives 2'``) is recognized: the sample index
    is used as the label and ``self.transform`` is applied that many times to
    the same image.
    """

    def __getitem__(self, index: int):
        """Load image ``index`` and return ``(images, targets)`` as two lists.

        Returns:
            If an entry of ``self.target_type`` contains
            ``'augmented positives'``: ``(X_list, target_list)`` with one
            transformed copy of the image and one shallow copy of the target
            per requested augmentation. Otherwise ``([X], [target])`` with the
            (optionally transformed) image and its target; the target is
            ``None`` when no target type is configured.

        Raises:
            ValueError: If a target type is not recognized.
            AssertionError: If augmented positives are requested but no
                transform is set, or the target type contains no number.
        """
        X = Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            elif 'augmented positives' in t:
                # The dataset index itself serves as the label for this image.
                target.append(index)
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError(f'Target type "{t}" is not recognized.')

        if target:
            # A single target is returned bare, several targets as a tuple.
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        for t in self.target_type:
            if 'augmented positives' in t:
                assert self.transform is not None, (
                    "augmented positives require a transform"
                )
                # Raw string: '\d' in a plain string literal is an invalid
                # escape sequence and triggers a warning on recent Pythons.
                num_augs = re.search(r'\d+', t)
                assert num_augs, (
                    "augmented positives target type must contain the number of views"
                )
                num_augs = int(num_augs.group(0))
                # The transform is re-applied for every view; each view gets
                # its own shallow copy of the target.
                X_list = []
                target_list = []
                for _ in range(num_augs):
                    X_list.append(self.transform(X))
                    target_list.append(copy(target))
                return X_list, target_list

        if self.transform is not None:
            X = self.transform(X)
        return [X], [target]

def unwrap_collate(lst):
    """Collate a batch of ``(X_list, target_list)`` items into one flat batch.

    Each item of ``lst`` is a pair of equally long lists, as returned by
    ``CelebAPositives.__getitem__``. The lists are unwrapped so that every
    ``(x, target)`` pair becomes its own sample, and the flattened samples are
    passed to ``default_collate``.

    The previous ``sum(lst)`` started from the integer 0 and therefore raised
    ``TypeError`` for any batch of tuples.

    Args:
        lst: list of ``(X_list, target_list)`` pairs.

    Returns:
        Whatever ``default_collate`` produces for the flattened list of
        ``(x, target)`` samples.
    """
    flat_samples = [
        (x, target)
        for X_list, target_list in lst
        for x, target in zip(X_list, target_list)
    ]
    return default_collate(flat_samples)

In [2]:
# Install pytorch-metric-learning (imported by the next code cell) from the
# conda-forge / pytorch channels; -y answers the confirmation prompt.
!mamba install pytorch-metric-learning -c conda-forge -c pytorch -y


                  __    __    __    __
                 /  \  /  \  /  \  /  \
                /    \/    \/    \/    \
███████████████/  /██/  /██/  /██/  /████████████████████████
              /  / \   / \   / \   / \  \____
             /  /   \_/   \_/   \_/   \    o \__,
            / _/                       \_____/  `
            |/
        ███╗   ███╗ █████╗ ███╗   ███╗██████╗  █████╗
        ████╗ ████║██╔══██╗████╗ ████║██╔══██╗██╔══██╗
        ██╔████╔██║███████║██╔████╔██║██████╔╝███████║
        ██║╚██╔╝██║██╔══██║██║╚██╔╝██║██╔══██╗██╔══██║
        ██║ ╚═╝ ██║██║  ██║██║ ╚═╝ ██║██████╔╝██║  ██║
        ╚═╝     ╚═╝╚═╝  ╚═╝╚═╝     ╚═╝╚═════╝ ╚═╝  ╚═╝

        mamba (0.25.0) supported by @QuantStack

        GitHub:  https://github.com/mamba-org/mamba
        Twitter: https://twitter.com/QuantStack

█████████████████████████████████████████████████████████████


Looking for: ['pytorch-metric-learning']

conda-forge/osx-arm64                       

In [None]:
import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as logging_presets
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
import umap

In [None]:
# Record keeper and hook container from the pytorch-metric-learning logging presets.
# The two strings are passed to get_record_keeper unchanged
# (presumably the log folder and the TensorBoard folder — confirm against the library docs).
record_keeper, _, _ = logging_presets.get_record_keeper(
    "contrastive_resnet_logs", "example_tensorboard"
)
hooks = logging_presets.get_hook_container(record_keeper)
# Validation data: the CelebA 'valid' split, converted to tensors and normalized only
# (no random augmentation). 'augmented positives 1' requests one transformed view per
# image, labelled with the sample index.
# NOTE(review): `root` is a hard-coded absolute path on one machine — consider making it configurable.
dataset_dict = {"val": CelebAPositives(
    root='/home/anspiridonov/TopologicalRegularization/data/celeba',
    split='valid',
    target_type=['augmented positives 1'],
    transform=transforms.Compose([
        transforms.ToTensor(),
        # presumably the ImageNet channel statistics matching the pretrained trunk — confirm
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
    ]),
)}
# Folder name handed to the end-of-epoch hook further down in this cell.
model_folder = "example_saved_models"


def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname, *args):
    """Scatter-plot 2D embeddings, drawing every distinct label in its own color.

    Args:
        umapper: unused here.
        umap_embeddings: array indexed as ``[mask, 0]`` / ``[mask, 1]`` for the
            x and y coordinates of the plotted points.
        labels: array compared element-wise against each unique label to
            select the points of that label.
        split_name: name of the data split, used in the log message.
        keyname: name of the label set, used in the log message.
        *args: ignored.
    """
    message = "UMAP plot for the {} split and label set {}".format(split_name, keyname)
    logging.info(message)

    unique_labels = np.unique(labels)
    n_labels = len(unique_labels)

    plt.figure(figsize=(20, 15))
    # One color per label, taken at evenly spaced positions of the colormap.
    palette = [plt.cm.nipy_spectral(pos) for pos in np.linspace(0, 0.9, n_labels)]
    plt.gca().set_prop_cycle(cycler("color", palette))

    for label in unique_labels:
        mask = labels == label
        plt.plot(umap_embeddings[mask, 0], umap_embeddings[mask, 1], ".", markersize=1)
    plt.show()


# Create the tester: it receives a UMAP instance as visualizer and visualizer_hook
# (defined above) as the plotting callback, plus the hook container's end-of-testing hook.
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook,
    visualizer=umap.UMAP(),
    visualizer_hook=visualizer_hook,
    dataloader_num_workers=2,
    accuracy_calculator=AccuracyCalculator(k="max_bin_count"),
)

# End-of-epoch hook built from the tester, the validation dataset dict and the model folder.
# presumably test_interval=1 tests after every epoch and patience=1 stops after one
# epoch without improvement — confirm against the pytorch-metric-learning docs.
end_of_epoch_hook = hooks.end_of_epoch_hook(
    tester, dataset_dict, model_folder, test_interval=1, patience=1
)

### just metric learning

#### augmentation-based metric learning

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set trunk model: a ResNet-18 whose final fully connected layer (`fc`) is replaced
# with an identity function, so the trunk outputs the features that used to feed `fc`.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor of `weights=` — confirm against the installed version.
trunk = torchvision.models.resnet18(pretrained=True)
trunk_output_size = trunk.fc.in_features
trunk.fc = nn.Identity()
trunk = trunk.to(device)

# Set embedder model. This takes in the output of the trunk and outputs 64 dimensional embeddings
embedder = MLP([trunk_output_size, 64]).to(device)

# Set optimizers: one AdamW per model; the trunk uses a 10x smaller learning rate than the embedder.
trunk_optimizer = torch.optim.AdamW(trunk.parameters(), lr=0.00001, weight_decay=0.0001)
embedder_optimizer = torch.optim.AdamW(
    embedder.parameters(), lr=0.0001, weight_decay=0.0001
)

### metric learning with classifier

In [None]:
# NOTE(review): this cell contains the same statements as the preceding setup cell;
# running it re-creates `trunk`, `embedder` and both optimizers from scratch.
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set trunk model: a ResNet-18 whose final fully connected layer (`fc`) is replaced
# with an identity function, so the trunk outputs the features that used to feed `fc`.
trunk = torchvision.models.resnet18(pretrained=True)
trunk_output_size = trunk.fc.in_features
trunk.fc = nn.Identity()
trunk = trunk.to(device)

# Set embedder model. This takes in the output of the trunk and outputs 64 dimensional embeddings
embedder = MLP([trunk_output_size, 64]).to(device)

# Set optimizers: one AdamW per model; the trunk uses a 10x smaller learning rate than the embedder.
trunk_optimizer = torch.optim.AdamW(trunk.parameters(), lr=0.00001, weight_decay=0.0001)
embedder_optimizer = torch.optim.AdamW(
    embedder.parameters(), lr=0.0001, weight_decay=0.0001
)

In [None]:
# Bundle the metric loss, the networks and their optimizers into the
# name -> object dictionaries that are handed to the trainer.
loss = losses.TripletMarginLoss(margin=0.1)
loss_funcs = dict(metric_loss=loss)
models = dict(trunk=trunk, embedder=embedder)
optimizers = dict(
    trunk_optimizer=trunk_optimizer,
    embedder_optimizer=embedder_optimizer,
)

In [None]:
# Training data: one randomly augmented view per image ('augmented positives 1'),
# labelled with the sample index.
# NOTE(review): this reads split='valid', the same split used for dataset_dict["val"] — confirm
# that training on the validation split is intended.
# NOTE(review): `root` is a hard-coded absolute path on one machine — consider making it configurable.
train_dataset = CelebAPositives(
    root='/home/anspiridonov/TopologicalRegularization/data/celeba',
    split='valid',
    target_type=['augmented positives 1'],
    transform=transforms.Compose([
        # Convert to a tensor first; the random augmentations below then operate on tensors.
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ColorJitter(brightness=0.5, hue=0.3, saturation=0.1),
        transforms.GaussianBlur(kernel_size=5, sigma=(0.01, 0.5)),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.0)),
        # presumably the ImageNet channel statistics matching the pretrained trunk — confirm
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
    ])
)

# Trainer that optimizes the metric loss only.
# 64 is presumably the batch size (third positional argument) — confirm against the
# installed pytorch-metric-learning version.
# NOTE(review): dataset items are list-wrapped (X_list, target_list) pairs and no collate_fn
# is passed here — confirm the default collation is what the trainer should receive.
trainer = trainers.MetricLossOnly(
    models,
    optimizers,
    64,
    loss_funcs,
    train_dataset,
    dataloader_num_workers=2,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook,
)

### just classifier

### autoencoder-like pretraining