# Deep InfoMax representation learning for images

In [None]:
import sys
# Extend the module search path with "../python" (resolved against the current
# working directory at import time). Presumably this is where the project
# packages imported below live -- confirm against the repository layout.
sys.path.append("../python")

In [None]:
import torch
import torchkld
import torchvision

In [None]:
import infomax

In [None]:
# Pick the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
#device = "cpu"  # uncomment to force CPU execution

# Environment diagnostics (the CUDA version is None on CPU-only builds).
print("Device: " + device)
print(f"Devices count: {torch.cuda.device_count()}")
print(f"CUDA version: {torch.version.cuda}")

In [None]:
from misc.modules import *
from misc.plots import *
from misc.training import *

In [None]:
import os
from pathlib import Path

# Absolute location of the shared data directory, two levels above the
# current working directory.
path = Path("../../data/")
path = path.resolve()

# Destination directory handed to save_results at the end of the notebook.
experiments_path = path.joinpath("embeddings/CIFAR10/")
#models_path = experiments_path / "models/"
#results_path = experiments_path / "results/"

In [None]:
# Single dictionary that accumulates every experiment setting assigned in the
# cells below; it is passed to save_results together with the model and history.
config = {}

## Data

In [None]:
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, ImageNet
from torchvision.models import resnet18, resnet50

In [None]:
# Preprocessing applied to every image when it is loaded from the dataset.
# Only tensor conversion is active; per-channel normalization is left disabled.
_transform_steps = [
    torchvision.transforms.ToTensor(),
    #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
image_transform = torchvision.transforms.Compose(_transform_steps)

In [None]:
# Dataset selection: the name is looked up as an attribute of
# torchvision.datasets, so it must match a dataset class there.
#config["dataset"] = "MNIST"
config["dataset"] = "CIFAR10"
config["n_classes"] = 10

_dataset_cls = getattr(torchvision.datasets, config["dataset"])
_dataset_kwargs = dict(root="./.cache", download=True, transform=image_transform)

# The first call passes no `train` argument, so the dataset class's own
# default split applies; the second explicitly requests train=False.
train_dataset = _dataset_cls(**_dataset_kwargs)
test_dataset = _dataset_cls(train=False, **_dataset_kwargs)

In [None]:
config.update(batch_size_train=512, batch_size_test=1024)

# Training batches are reshuffled on every pass; test batches keep dataset order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config["batch_size_train"],
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config["batch_size_test"],
    shuffle=False,
)

# Evaluation currently reuses the test loader. Disabled alternative (an
# unshuffled pass over the training set):
#eval_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config["batch_size_train"], shuffle=False)
eval_dataloader = test_dataloader

## Model

In [None]:
# Target distribution of the embeddings: "normal" or "uniform".
config["distribution"] = "normal"
#config["distribution"] = "uniform"

config["embedding_dim"] = 2

# Final layer applied on top of the backbone output. "normal" selects a
# non-affine batch norm; any other value selects a sigmoid.
if config["distribution"] == "normal":
    normalization_layer = torch.nn.BatchNorm1d(config["embedding_dim"], affine=False)
else:
    normalization_layer = torch.nn.Sigmoid()

In [None]:
config["backbone"] = "resnet18"

_backbone_name = config["backbone"]
if _backbone_name != "convnet":
    # Any other name is resolved as a torchvision.models constructor; the
    # classifier head size is set to the embedding dimension.
    _backbone_ctor = getattr(torchvision.models, _backbone_name)
    backbone = _backbone_ctor(num_classes=config["embedding_dim"]).train()

    if config["dataset"] in ("CIFAR10", "CIFAR100"):
        # Replace the stem with a 3x3 stride-1 convolution and drop the
        # max-pool.
        # NOTE(review): padding=2 with a 3x3 kernel enlarges an HxW input to
        # (H+2)x(W+2); the usual CIFAR stem uses padding=1 -- confirm that
        # this is intended before changing it (it affects reproducibility).
        backbone.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
        backbone.maxpool = torch.nn.Identity()
else:
    # Project-provided convolutional embedder.
    backbone = Conv2dEmbedder(embedding_dim=config["embedding_dim"])

In [None]:
# Full embedder: backbone followed by the distribution-specific output layer,
# moved to the selected device.
embedder_network = torch.nn.Sequential(backbone, normalization_layer)
embedder_network = embedder_network.to(device)

# Record the embedding size as an attribute on the module itself.
embedder_network.embedding_dim = config["embedding_dim"]

In [None]:
config["discriminator_network"] = "DenseT"
config["discriminator_network_inner_dim"] = 256
config["discriminator_network_output_dim"] = 256


def _build_separable_t():
    """Build a SeparableT discriminator from the current config, on `device`."""
    dim = config["embedding_dim"]
    network = SeparableT(
        dim,
        dim,
        inner_dim=config["discriminator_network_inner_dim"],
        output_dim=config["discriminator_network_output_dim"],
    )
    return network.to(device)


def _build_dense_t():
    """Build a DenseT discriminator from the current config, on `device`."""
    dim = config["embedding_dim"]
    network = DenseT(
        dim,
        dim,
        inner_dim=config["discriminator_network_inner_dim"],
    )
    return network.to(device)


def _build_additive_gaussian_t():
    """Build an AdditiveGaussainT discriminator with p=0.99, on `device`."""
    return AdditiveGaussainT(p=0.99).to(device)


# Name -> zero-argument builder. Only the selected builder is invoked, so the
# unused discriminators are never constructed.
_discriminator_network_factory = {
    "SeparableT": _build_separable_t,
    "DenseT": _build_dense_t,
    "AdditiveGaussainT": _build_additive_gaussian_t,
}

# An unknown name raises KeyError here.
_build_discriminator = _discriminator_network_factory[config["discriminator_network"]]
discriminator_network = _build_discriminator()

In [None]:
# Parameters handed to the input and output channels, respectively.
config["input_p"]  = 2.0e-1
config["output_p"] = 1.0e-1

# Channel applied on the input side. A disabled alternative (an image
# augmentation pipeline) is kept below for reference:
#torchvision.transforms.Compose([
#    torchvision.transforms.RandomResizedCrop((32, 32), scale=(0.2, 1.)),
#    torchvision.transforms.RandomHorizontalFlip(),
#    torchvision.transforms.RandomApply([
#        torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
#        #torchvision.transforms.ColorJitter(0.5, 0.5, 0.5, 0.5)  # strengthened
#    ], p=0.8),
#    torchvision.transforms.RandomGrayscale(p=0.2),
#    #infomax.channels.BoundedVarianceGaussianChannel(config["input_p"]).to(device)
#    #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
#]),
_input_channel = infomax.channels.BoundedVarianceGaussianChannel(config["input_p"])

# Channel applied on the output side: Gaussian for the "normal" target
# distribution, bounded-support uniform for anything else.
if config["distribution"] == "normal":
    _output_channel = infomax.channels.BoundedVarianceGaussianChannel(config["output_p"])
else:
    _output_channel = infomax.channels.BoundedSupportUniformChannel(config["output_p"])

model = infomax.embeddings.Embedder(
    embedder_network,
    discriminator_network,
    _input_channel,
    _output_channel,
)
model = model.to(device)

In [None]:
import math

# Total capacity is the output channel's capacity multiplied by the number of
# embedding dimensions. The comparison value is the natural logarithm of the
# class count (presumably both are in nats -- confirm the channel's units).
config.update({
    "capacity": config["embedding_dim"] * model.output_channel.capacity,
    "min_capacity_for_classification": math.log(config["n_classes"]),
})

print(f"Capacity: {config['capacity']:.2f}")
print(f"Min capacity required for class preservation: {config['min_capacity_for_classification']:.2f}")

In [None]:
# Optimization schedule and learning rates for the two networks, followed by
# the objective: "loss" is looked up by name in torchkld.loss, and
# "marginalize" is forwarded to the training routine as-is.
config.update({
    "n_epochs": 2001,
    "embedder_network_lr": 1.0e-3,
    "discriminator_network_lr": 1.0e-3,
    "loss": "InfoNCELoss",
    "marginalize": "product",
})

In [None]:
def _classification_callback(history, epoch, step, infomax_embedder, train_dataloader, test_dataloader, device):
    """Forward the callback arguments to `classification_callback` unchanged."""
    return classification_callback(
        history, epoch, step, infomax_embedder, train_dataloader, test_dataloader, device,
        # Optional overrides, currently disabled:
        #period=20,
        #distribution_tests={},
        #clustering_metrics={},
        #classifiers={
        #    "logistic_regression": lambda: DenseClassifier(config["embedding_dim"], config["n_classes"], device).to(device),
        #    #"mlp": lambda: DenseClassifier(config["embedding_dim"], config["n_classes"], device, n_layers=3).to(device),
        #    #"knn": lambda: KNeighborsClassifier(metric='cosine'),
        #    #"mlp": lambda: MLPClassifier(alpha=1.0, max_iter=1000),
        #},
    )


def _make_embedder_optimizer(params):
    """Create the Adam optimizer for the embedder network parameters."""
    return torch.optim.Adam(params, lr=config["embedder_network_lr"])


def _make_discriminator_optimizer(params):
    """Create the Adam optimizer for the discriminator network parameters."""
    return torch.optim.Adam(params, lr=config["discriminator_network_lr"])


# Resolve the loss class by name; it is instantiated with default arguments.
_loss_cls = getattr(torchkld.loss, config["loss"])

history = train_infomax_embedder(
    model,
    train_dataloader,
    test_dataloader,
    device,
    callback=_classification_callback,
    optimizer_embedder_network=_make_embedder_optimizer,
    optimizer_discriminator_network=_make_discriminator_optimizer,
    loss=_loss_cls(),
    marginalize=config["marginalize"],
    distribution=config["distribution"],
    n_epochs=config["n_epochs"]
)

In [None]:
# Embed the training set with the trained network, then plot the result with
# both axes fixed to [-3, 3].
_train_embeddings = convert_to_embeddings(embedder_network, train_dataloader, device)
plot_embeddings(*_train_embeddings, x_lim=(-3.0, 3.0), y_lim=(-3.0, 3.0))

In [None]:
# Take the first element of the first training batch, pass it through the
# model's input channel, and display element [0][0] of the result
# (assumes a (batch, channel, height, width) layout -- TODO confirm).
_first_batch = next(iter(train_dataloader))
_channel_output = model.input_channel(_first_batch[0].to(device))
plt.imshow(_channel_output.cpu().numpy()[0][0])

In [None]:
from pathlib import Path

# Hand the trained model, the accumulated config, and the training history to
# save_results, with experiments_path as the destination argument (presumably
# everything is written under that directory -- confirm in the helper).
save_results(model, config, history, experiments_path)