In [2]:
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn as nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1)
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)
        self.l1 = nn.Linear(512*20*20, 64)

        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        h = F.relu(self.c0(x))
        features = F.relu(self.b1(self.c1(h)))
        h = F.relu(self.b2(self.c2(features)))
        h = F.relu(self.b3(self.c3(h)))
        encoded = self.l1(h.view(x.shape[0], -1))
        return encoded, features

class GlobalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(128, 64, kernel_size=3)
        self.c1 = nn.Conv2d(64, 32, kernel_size=3)
        self.l0 = nn.Linear(32 * 22 * 22 + 64, 512)
        self.l1 = nn.Linear(512, 512)
        self.l2 = nn.Linear(512, 1)

    def forward(self, y, M):
        h = F.relu(self.c0(M))
        h = self.c1(h)
        h = h.view(y.shape[0], -1)
        h = torch.cat((y, h), dim=1)
        h = F.relu(self.l0(h))
        h = F.relu(self.l1(h))
        return self.l2(h)

class LocalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(192, 512, kernel_size=1)
        self.c1 = nn.Conv2d(512, 512, kernel_size=1)
        self.c2 = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x):
        h = F.relu(self.c0(x))
        h = F.relu(self.c1(h))
        return self.c2(h)

class PriorDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.l0 = nn.Linear(64, 1000)
        self.l1 = nn.Linear(1000, 200)
        self.l2 = nn.Linear(200, 1)

    def forward(self, x):
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return torch.sigmoid(self.l2(h))

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(64, 15)
        self.bn1 = nn.BatchNorm1d(15)
        self.l2 = nn.Linear(15, 10)
        self.bn2 = nn.BatchNorm1d(10)
        self.l3 = nn.Linear(10, 10)
        self.bn3 = nn.BatchNorm1d(10)

    def forward(self, x):
        encoded, _ = x[0], x[1]
        clazz = F.relu(self.bn1(self.l1(encoded)))
        clazz = F.relu(self.bn2(self.l2(clazz)))
        clazz = F.softmax(self.bn3(self.l3(clazz)), dim=1)
        return clazz

class DeepInfoAsLatent(nn.Module):
    def __init__(self, run, epoch):
        super().__init__()
        model_path = Path(r'c:/data/deepinfomax/models') / Path(str(run)) / Path('encoder' + str(epoch) + '.wgt')
        self.encoder = Encoder()
        self.encoder.load_state_dict(torch.load(str(model_path)))
        self.classifier = Classifier()

    def forward(self, x):
        z, features = self.encoder(x)
        z = z.detach()
        return self.classifier((z, features))

In [None]:
import torch
from models import Encoder, GlobalDiscriminator, LocalDiscriminator, PriorDiscriminator
from torchvision.datasets.cifar import CIFAR10
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm
from pathlib import Path
import statistics as stats
import argparse


class DeepInfoMaxLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=1.0, gamma=0.1):
        super().__init__()
        self.global_d = GlobalDiscriminator()
        self.local_d = LocalDiscriminator()
        self.prior_d = PriorDiscriminator()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def forward(self, y, M, M_prime):

        # see appendix 1A of https://arxiv.org/pdf/1808.06670.pdf

        y_exp = y.unsqueeze(-1).unsqueeze(-1)
        y_exp = y_exp.expand(-1, -1, 26, 26)

        y_M = torch.cat((M, y_exp), dim=1)
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)

        Ej = -F.softplus(-self.local_d(y_M)).mean()
        Em = F.softplus(self.local_d(y_M_prime)).mean()
        LOCAL = (Em - Ej) * self.beta

        Ej = -F.softplus(-self.global_d(y, M)).mean()
        Em = F.softplus(self.global_d(y, M_prime)).mean()
        GLOBAL = (Em - Ej) * self.alpha

        prior = torch.rand_like(y)

        term_a = torch.log(self.prior_d(prior)).mean()
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        PRIOR = - (term_a + term_b) * self.gamma

        return LOCAL + GLOBAL + PRIOR

if __name__ == '__main__':

    #parser = argparse.ArgumentParser(description='DeepInfomax pytorch')
    #parser.add_argument('--batch_size', default=64, type=int, help='batch_size')
    #args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 64

    # image size 3, 32, 32 batch size must be an even numbershuffle must be True
    cifar_10_train_dt = CIFAR10(r'c:\data\tv',  download=True, transform=ToTensor())
    cifar_10_train_l  = DataLoader(cifar_10_train_dt, batch_size=batch_size, shuffle=True, drop_last=True,pin_memory=torch.cuda.is_available())

    encoder = Encoder().to(device)
    loss_fn = DeepInfoMaxLoss().to(device)
    optim = Adam(encoder.parameters(), lr=1e-4)
    loss_optim = Adam(loss_fn.parameters(), lr=1e-4)

    epoch_restart = 860
    root = Path(r'c:\data\deepinfomax\models\run5')

#     if epoch_restart is not None and root is not None:
#         enc_file = root / Path('encoder' + str(epoch_restart) + '.wgt')
#         loss_file = root / Path('loss' + str(epoch_restart) + '.wgt')
#         encoder.load_state_dict(torch.load(str(enc_file)))
#         loss_fn.load_state_dict(torch.load(str(loss_file)))

    for epoch in range(epoch_restart + 1, 1000):
        batch = tqdm(cifar_10_train_l, total=len(cifar_10_train_dt) // batch_size)
        train_loss = []
        for x, target in batch:
            x = x.to(device)

            optim.zero_grad()
            loss_optim.zero_grad()
            y, M = encoder(x)
            # rotate images to create pairs for comparison
            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            loss = loss_fn(y, M, M_prime)
            train_loss.append(loss.item())
            batch.set_description(str(epoch) + ' Loss: ' + str(stats.mean(train_loss[-20:])))
            loss.backward()
            optim.step()
            loss_optim.step()

        if epoch % 10 == 0:
            root = Path(r'c:\data\deepinfomax\models\run5')
            enc_file = root / Path('encoder' + str(epoch) + '.wgt')
            loss_file = root / Path('loss' + str(epoch) + '.wgt')
            enc_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(encoder.state_dict(), str(enc_file))
            torch.save(loss_fn.state_dict(), str(loss_file))

Files already downloaded and verified


861 Loss: 1.079176139831543: 100%|█████████████████████████████████████████████████████████| 781/781 [02:30<00:00,  5.92it/s]
862 Loss: 0.9797142416238784: 100%|████████████████████████████████████████████████████████| 781/781 [02:12<00:00,  5.83it/s]
863 Loss: 0.9199020564556122: 100%|████████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.96it/s]
864 Loss: 0.8502464711666107: 100%|████████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.88it/s]
865 Loss: 0.7808394730091095: 100%|████████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.92it/s]
866 Loss: 0.7401991724967957: 100%|████████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.92it/s]
867 Loss: 0.7200962513685226: 100%|████████████████████████████████████████████████████████| 781/781 [02:12<00:00,  5.95it/s]
868 Loss: 0.652477577328682: 100%|█████████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.

926 Loss: 0.21468398347496986: 100%|███████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.92it/s]
927 Loss: 0.20089171305298806: 100%|███████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.95it/s]
928 Loss: 0.22822452411055566: 100%|███████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.89it/s]
929 Loss: 0.2273866802453995: 100%|████████████████████████████████████████████████████████| 781/781 [02:11<00:00,  5.93it/s]
930 Loss: 0.23043190017342569:  64%|███████████████████████████████████▏                   | 500/781 [01:24<00:47,  5.88it/s]