In [8]:
!pip install import-ipynb
import import_ipynb

from CVAE_ImbalanceGenerator_MNIST import *



In [9]:
!pip install matplotlib
!pip install sklearn



In [10]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import DataLoader, Subset
import torch.utils.data as data_utils
from sklearn.model_selection import train_test_split

In [11]:
device = torch.device("cpu")
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 32
ZDIM = 20
NUM_EPOCHS = 40

# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)   
torch.cuda.manual_seed(SEED)

In [12]:
trainloader_collection = [trainloader_mnist_1, trainloader_mnist_2, trainloader_mnist_3]
trainset_collection = [trainset_mnist_1, trainset_mnist_2, trainset_mnist_3]
settings = [0, 1, 2] #we have three different settings 

In [13]:
class CVAE(nn.Module):
    def __init__(self, zdim):
        super().__init__()
        self._zdim = zdim
        self._in_units = 28 * 28
        hidden_units = 512
        self._encoder = nn.Sequential(
            nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
        )
        self._to_mean = nn.Linear(hidden_units, zdim)
        self._to_lnvar = nn.Linear(hidden_units, zdim)
        self._decoder = nn.Sequential(
            nn.Linear(zdim + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, self._in_units),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=device)
        in_[:, :self._in_units] = x
        in_[:, self._in_units:] = labels
        h = self._encoder(in_)
        mean = self._to_mean(h)
        lnvar = self._to_lnvar(h)
        return mean, lnvar

    def decode(self, z, labels):
        in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=device)
        in_[:, :self._zdim] = z
        in_[:, self._zdim:] = labels
        return self._decoder(in_)


def to_onehot(label):
    return torch.eye(CLASS_SIZE, device=device, dtype=torch.float32)[label]

In [14]:
model_1 = CVAE(ZDIM).to(device)
model_2 = CVAE(ZDIM).to(device)
model_3 = CVAE(ZDIM).to(device)

optimizer1 = optim.Adam(model_1.parameters(), lr=1e-3)
optimizer2 = optim.Adam(model_2.parameters(), lr=1e-3)
optimizer3 = optim.Adam(model_3.parameters(), lr=1e-3)

model_collection = [model_1, model_2, model_3]
opt_collection = [optimizer1, optimizer2, optimizer3]

for setting in zip(settings):
    setting = int(''.join(map(str, setting)))

    print('-------------------')
    if setting == 0:
        print('SETTING: Half-Split Imbalance')
    elif setting == 1:
        print('SETTING: MultiMajority')
    elif setting == 2:
        print('SETTING: MultiMinority')
    print('-------------------')

    model_collection[setting].train()
    for e in range(NUM_EPOCHS):
        train_loss = 0

        for i, (images, labels) in enumerate(trainloader_collection[setting]):
            labels = to_onehot(labels)
            # Reconstruction images
            # Encode images
            x = images.view(-1, 28*28*1).to(device)
            mean, lnvar = model_collection[setting].encode(x, labels)
            std = lnvar.exp().sqrt()
            epsilon = torch.randn(ZDIM, device=device)
        
            # Decode latent variables
            z = mean + std * epsilon
            y = model_collection[setting].decode(z, labels)
        
            # Compute loss
            kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
            bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
            loss = (-1 * kld + bce).mean()

            # Update model
            opt_collection[setting].zero_grad()
            loss.backward()
            opt_collection[setting].step()
            train_loss += loss.item() * x.shape[0]

        print(f'epoch: {e + 1} epoch_loss: {train_loss/len(trainset_collection[setting])}')
    print(f'Finished training for SETTING:{setting}')

-------------------
SETTING: Half-Split Imbalance
-------------------
epoch: 1 epoch_loss: 166.58161136881512
Finished training for SETTING:0
-------------------
SETTING: MultiMajority
-------------------
epoch: 1 epoch_loss: 156.54348042429513
Finished training for SETTING:1
-------------------
SETTING: MultiMinority
-------------------
epoch: 1 epoch_loss: 178.36226130059836
Finished training for SETTING:2


In [15]:
def image_generator(NUM, label_name, setting):
    model_collection[setting].eval()
    output_container = torch.tensor((), device=device)

    for i in range(NUM):
        z = torch.randn(ZDIM, device=device).unsqueeze(dim=0)
        label = torch.tensor([label_name], device=device)
        with torch.no_grad():
            y = model_collection[setting].decode(z, to_onehot(label))
            y = y.reshape(1, 1, 28, 28)

        output_container = torch.cat((output_container, y), 0)

    return output_container

In [16]:
def image_plotter(label_name, setting):
    image = image_generator(100, label_name, setting).cpu().detach().numpy()
    n = np.random.randint(1, 100)
    image = image[n].reshape(28, 28)
    fig1, (ax1)= plt.subplots(1, sharex = True, sharey = False)
    ax1.title.set_text(f'CVAE Reconstruction of the Class: {label_name} -- Setting: {setting}')
    ax1.imshow(image, interpolation ='none', aspect = 'auto')

In [29]:
## Setting One (Half-Split Imbalance)
CVAE_trainset_setting_one = torch.cat((image_generator(4000, 0, 0), image_generator(4000, 1, 0), image_generator(4000, 2, 0), 
                                       image_generator(4000, 3, 0), image_generator(4000, 4, 0)), 0)

CVAE_train_labels_setting_one = torch.cat((torch.tensor([0]*4000), torch.tensor([1]*4000), torch.tensor([2]*4000), torch.tensor([3]*4000), torch.tensor([4]*4000)), 0)
CVAE_setting_one_dataset = TensorDataset(CVAE_trainset_setting_one, CVAE_train_labels_setting_one)

CVAE_S1_MNIST_trainloader = DataLoader(CVAE_setting_one_dataset, batch_size=16, shuffle=True) 

In [18]:
## Setting Two (Multimajority)
CVAE_setting_two_dataset = TensorDataset(image_generator(5742, 9, 1), torch.tensor([9]*5742))
CVAE_S2_MNIST_trainloader = DataLoader(CVAE_setting_two_dataset, batch_size=16, shuffle=True)

In [19]:
## Setting Three (Multiminority)
CVAE_trainset_setting_three = torch.cat((image_generator(5742, 0, 2), image_generator(5742, 1, 2), image_generator(5742, 2, 2), 
                                       image_generator(5742, 3, 2), image_generator(5742, 4, 2), image_generator(5742, 5, 2), 
                                       image_generator(5742, 6, 2), image_generator(5742, 7, 2), image_generator(5742, 8, 2)),  0)

CVAE_train_labels_setting_three = torch.cat((torch.tensor([0]*5742), torch.tensor([1]*5742), torch.tensor([2]*5742), torch.tensor([3]*5742), torch.tensor([4]*5742),
                                             torch.tensor([5]*5742), torch.tensor([6]*5742), torch.tensor([7]*5742), torch.tensor([8]*5742)), 0)

CVAE_setting_three_dataset = TensorDataset(CVAE_trainset_setting_three, CVAE_train_labels_setting_three)

CVAE_S3_MNIST_trainloader = DataLoader(CVAE_setting_three_dataset, batch_size=16, shuffle=True)