# Deep InfoMax representation learning for images

In [None]:
import sys
# Add ../python to the module search path. The entry is relative, so it is
# resolved against the current working directory at import time.
# NOTE(review): presumably this is where the project-local packages imported
# below live -- confirm against the repository layout.
sys.path.append("../python")

In [None]:
import torch
import torchkld
import torchvision

In [None]:
import infomax

In [None]:
# Pick the compute device as a plain string (it is concatenated into the log
# line below). The second GPU (index 1) is preferred, but the index is clamped
# to the number of visible devices: a hard-coded "cuda:1" is an invalid device
# ordinal on single-GPU hosts. Without CUDA, fall back to the CPU.
if torch.cuda.is_available():
    device = f"cuda:{min(1, torch.cuda.device_count() - 1)}"
else:
    device = "cpu"
# To force CPU execution, override here with: device = "cpu"
print("Device: " + device)
print(f"Devices count: {torch.cuda.device_count()}")
print(f"CUDA version: {torch.version.cuda}")

In [None]:
from misc.modules import *
from misc.plots import *
from misc.training import *

In [None]:
import os
from pathlib import Path

# Absolute location of the shared data directory, two levels above the
# current working directory.
path = Path("..", "..", "data").resolve()

# All artefacts of this experiment are kept under <data>/embeddings/CIFAR10.
experiments_path = path.joinpath("embeddings", "CIFAR10")

In [None]:
# Experiment settings are accumulated here by the cells below.
config = dict()

## Data

In [None]:
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, ImageNet
from torchvision.models import resnet18, resnet50

In [None]:
# Pre-processing pipeline: tensor conversion only. No normalisation step is
# applied (a Normalize((0.1307,), (0.3081,)) stage was left disabled).
image_transform = torchvision.transforms.Compose(
    transforms=[torchvision.transforms.ToTensor()],
)

In [None]:
# CIFAR10 train and test splits, stored under ./.cache and downloaded on
# demand. Both splits share the same image transform.
train_dataset = CIFAR10(
    root="./.cache",
    train=True,
    transform=image_transform,
    download=True,
)
test_dataset = CIFAR10(
    root="./.cache",
    train=False,
    transform=image_transform,
    download=True,
)

In [None]:
# Number of target classes; used further down to compute the minimum
# capacity needed for classification.
config.update(n_classes=10)

In [None]:
# Mini-batch sizes for the train and test data loaders.
config.update(
    batch_size_train=256,
    batch_size_test=256,
)

In [None]:
# Training batches are reshuffled each epoch; test batches keep dataset order.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=config["batch_size_train"],
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=config["batch_size_test"],
    shuffle=False,
)

# Evaluation reuses the test loader (same object). The alternative would be
# an unshuffled loader over train_dataset.
eval_dataloader = test_dataloader

## Model

In [None]:
config.update({
    # Size of the embedding vector produced by the embedder network.
    "embedding_dim": 256,
    # Hidden width passed to the discriminator network.
    "discriminator_network_inner_dim": 256,
    # Embedding distribution: "normal" (used here) or "uniform". The cells
    # below branch on whether this equals "normal".
    "distribution": "normal",
})

In [None]:
# Embedder: a ResNet-18 whose final layer emits `embedding_dim` values,
# followed by an output head chosen by the configured distribution:
#   "normal" -> non-affine batch normalisation,
#   otherwise -> sigmoid.
_embedder_backbone = resnet18(num_classes=config["embedding_dim"])
if config["distribution"] == "normal":
    _embedder_head = torch.nn.BatchNorm1d(config["embedding_dim"], affine=False)
else:
    _embedder_head = torch.nn.Sigmoid()

embedder_network = torch.nn.Sequential(_embedder_backbone, _embedder_head).to(device)
# Record the embedding size on the module itself.
embedder_network.embedding_dim = config["embedding_dim"]

# Discriminator built from two embedding-sized inputs.
discriminator_network = BasicDenseT(
    config["embedding_dim"],
    config["embedding_dim"],
    inner_dim=config["discriminator_network_inner_dim"],
).to(device)

In [None]:
# Parameters handed to the input and output channels of the model.
config.update(
    input_p=1.0e-1,
    output_p=5.0e-1,
)

In [None]:
# The input channel is always a bounded-variance Gaussian channel.
_input_channel = infomax.channels.BoundedVarianceGaussianChannel(config["input_p"])

# The output channel follows the embedding distribution: bounded variance for
# "normal", bounded support otherwise.
if config["distribution"] == "normal":
    _output_channel = infomax.channels.BoundedVarianceGaussianChannel(config["output_p"])
else:
    _output_channel = infomax.channels.BoundedSupportGaussianChannel(config["output_p"])

model = infomax.embeddings.Embedder(
    embedder_network,
    discriminator_network,
    _input_channel,
    _output_channel,
).to(device)

In [None]:
import math

# Total capacity: the output channel's capacity multiplied by the number of
# embedding dimensions.
_total_capacity = config["embedding_dim"] * model.output_channel.capacity
# Natural log of the class count, i.e. the entropy (in nats) of a uniform
# label distribution over the classes.
_required_capacity = math.log(config["n_classes"])

config.update(
    capacity=_total_capacity,
    min_capacity_for_classification=_required_capacity,
)

print(f"Capacity: {config['capacity']:.2f}")
print(f"Min capacity required for class preservation: {config['min_capacity_for_classification']:.2f}")

In [None]:
# Training schedule and learning rates for the two networks.
config.update(
    n_epochs=2001,
    embedder_network_lr=1.0e-3,
    discriminator_network_lr=1.0e-3,
)

In [None]:
def _evaluate_with_classifiers(history, epoch, step, infomax_embedder, train_dataloader, test_dataloader, device):
    """Forward the trainer's callback arguments to `classification_callback`.

    Distribution tests and clustering metrics are disabled (empty dicts);
    only a logistic-regression classifier is requested. Fresh dicts are built
    on every invocation.
    """
    return classification_callback(
        history, epoch, step, infomax_embedder, train_dataloader, test_dataloader, device,
        distribution_tests={},
        clustering_metrics={},
        classifiers={"logistic_regression": LogisticRegression},
    )


# Train the embedder; the returned history is saved at the end of the notebook.
history = train_infomax_embedder(
    model,
    train_dataloader,
    test_dataloader,
    device,
    callback=_evaluate_with_classifiers,
    embedder_network_lr=config["embedder_network_lr"],
    discriminator_network_lr=config["discriminator_network_lr"],
    distribution=config["distribution"],
    n_epochs=config["n_epochs"],
)

In [None]:
# Embed the training set with the embedder network and plot the result
# inside a fixed [-3, 3] x [-3, 3] window.
_train_embeddings = convert_to_embeddings(embedder_network, train_dataloader, device)
plot_embeddings(*_train_embeddings, x_lim=(-3.0, 3.0), y_lim=(-3.0, 3.0))

In [None]:
# Take the first element (the images) of the first training batch, pass it
# through the model's input channel, and display entry [0][0] of the result.
# NOTE(review): assuming the usual (batch, channel, H, W) layout, this is the
# first channel of the first image -- confirm.
_first_images = next(iter(train_dataloader))[0]
_channel_output = model.input_channel(_first_images.to(device))
plt.imshow(_channel_output.cpu().numpy()[0][0])

In [None]:
from pathlib import Path

# Hand the trained model, the accumulated config and the training history to
# save_results, with the experiment directory as destination.
# NOTE(review): presumably this writes files under experiments_path -- the
# exact layout is defined by save_results, which is not visible in this notebook.
save_results(model, config, history, experiments_path)