# EMNIST

In [1]:
%load_ext autoreload
%autoreload 2
# Import dependencies
import os
import importlib
import numpy as np
import matplotlib.pyplot as plt
import time
import torchvision
import seaborn as sns
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch
from tqdm import tqdm
from datetime import datetime
import yaml
import h5py
from copy import deepcopy
import ambiguous.models.cvae
from ambiguous.models.cvae import *
from ambiguous.dataset.dataset import partition_dataset
import wandb
from pytorch_lightning.loggers import WandbLogger

In [2]:
# Workflow switches read by the model cell further down:
#   TRAIN_CVAE - when True a new CVAE is trained, otherwise a checkpoint is loaded.
#   LOCAL_CKPT - when loading, use a local checkpoint file instead of the W&B artifact.
TRAIN_CVAE = False
LOCAL_CKPT = False

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed for reproducibility (torch and NumPy global generators).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
print(device)
# Directory handed to torchvision as the dataset root. The original cluster
# path stays the default; set the EMNIST_ROOT environment variable to run the
# notebook on another machine without editing this cell.
root = os.environ.get('EMNIST_ROOT', '/home/mila/n/nizar.islah/expectation-clamp/')

cuda


In [32]:
# Per-sample preprocessing: convert to a tensor, then rotate by 90 degrees and
# flip. On a (C, H, W) tensor this rotate+flip pair amounts to swapping the two
# spatial axes - presumably to display EMNIST glyphs upright (confirm visually).
transform = transforms.Compose([
    transforms.ToTensor(),
    lambda x: x.rot90(1,[2,1]).flip(2)
])
dataset = datasets.EMNIST(root=root, download=False, split='byclass', train=True, transform=transform)
# Restrict the training data to class indices 10..35 (26 classes).
# NOTE(review): partition_dataset is project code not shown here; whether it
# also remaps the labels to 0..25 should be confirmed against its definition.
new_dataset = partition_dataset(dataset, range(10, 36))
# 80/20 train/validation split. The validation size is the remainder so the
# two lengths always add up to len(new_dataset), and the split uses its own
# seeded generator so re-running this cell reproduces the same partition
# regardless of how much global RNG state earlier cells consumed.
n_train = round(0.8 * len(new_dataset))
n_val = len(new_dataset) - n_train
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = torch.utils.data.random_split(new_dataset, [n_train, n_val], generator=split_generator)
# NOTE(review): unlike train/val, the test split is NOT passed through
# partition_dataset, so it still contains every 'byclass' class - confirm this
# is intended before evaluating a 26-class model on it.
test_set = datasets.EMNIST(root=root, download=False, split='byclass', train=False, transform=transform)
# Dataloaders
batch_size = 64
train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=2, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=2)

In [6]:
# Draw one shuffled training batch (images x, labels t) and write the images
# to disk as an 8-column grid for a quick visual sanity check.
x, t = next(iter(train_loader))
save_image(x, 'image.pdf', nrow=8)

['h', 'f', 'm', 'x', 'm', 'c', 'c', 'e', 'a', 't', 'r', 'p', 'k', 'c', 'o', 'a', 'c', 'i', 'a', 'n', 'r', 'n', 'o', 'w', 'i', 't', 'o', 'c', 's', 'd', 'o', 'c', 'b', 'p', 'z', 'o', 'f', 'i', 'i', 'p', 'u', 'e', 'h', 's', 'x', 'm', 'z', 'a', 'o', 'd', 'm', 'w', 'h', 'u', 'o', 'd', 's', 't', 'h', 'w', 'y', 'r', 's', 's']


In [19]:
# Layer widths of the encoder / decoder MLPs and the size of the latent code.
enc_layers = [28*28, 1024, 512]
dec_layers = [512, 1024, 28*28]
latent_dim = 20

# Build the network once, up front, so that every branch below finishes with
# `model` defined (previously the LOCAL_CKPT path never created it and the
# final model.eval() raised NameError).
model = EMNIST_CVAE(latent_dim, enc_layers, dec_layers, n_classes=26, conditional=True).to(device)

if TRAIN_CVAE:
    wandb.init(project="EMNIST_CVAE", entity="team-nizar")
    wandb_logger = WandbLogger(project="EMNIST_CVAE", log_model="all")
    # Request a GPU only when one is present, matching the `device` fallback.
    trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), logger=wandb_logger, max_epochs=100, auto_lr_find=True)
    # Loaders are passed positionally because the keyword for the training
    # loader was renamed between Lightning releases.
    trainer.fit(model, train_loader, val_loader)
    torch.save(model.state_dict(), "emnist_cvae.pth")
else:
    if LOCAL_CKPT:
        # Fill in the path of a local checkpoint, e.g. the "emnist_cvae.pth"
        # file written by the training branch above.
        ckpt_path = ""
        if not os.path.isfile(ckpt_path):
            raise FileNotFoundError(
                f"LOCAL_CKPT is True but ckpt_path={ckpt_path!r} is not an existing file"
            )
    else:
        run = wandb.init()
        artifact = run.use_artifact('team-nizar/EMNIST_CVAE/model-pkia9rhl:v29', type='model')
        artifact_dir = artifact.download()
        ckpt_path = os.path.join(artifact_dir, 'model.ckpt')
    # SECURITY: torch.load unpickles the file - only load checkpoints you trust.
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    ckpt = torch.load(ckpt_path, map_location=device)
    # Lightning checkpoints nest the weights under 'state_dict'; the file saved
    # by the training branch above is the bare state dict. Accept both.
    state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
    model.load_state_dict(state_dict)

# Make sure the model lives on `device` (the trainer may have moved it), then
# switch to inference mode. Both calls return the module, so the cell still
# displays the architecture.
model.to(device).eval()

EMNIST_CVAE(
  (encoder): EMNIST_Encoder(
    (MLP): Sequential(
      (L0): Linear(in_features=810, out_features=1024, bias=True)
      (A0): ReLU()
      (L1): Linear(in_features=1024, out_features=512, bias=True)
      (A1): ReLU()
    )
    (fc_mu): Linear(in_features=512, out_features=20, bias=True)
    (fc_logvar): Linear(in_features=512, out_features=20, bias=True)
  )
  (decoder): EMNIST_Decoder(
    (MLP): Sequential(
      (L0): Linear(in_features=46, out_features=512, bias=True)
      (A0): ReLU()
      (L1): Linear(in_features=512, out_features=1024, bias=True)
      (A1): ReLU()
      (L2): Linear(in_features=1024, out_features=784, bias=True)
      (sigmoid): Sigmoid()
    )
  )
)

In [20]:
def gen_imgs(model, n_samples, n_latent, n_classes=26, path="reconstruction.pdf", nrow=8):
    """Decode random latents under random class conditions and save an image grid.

    Args:
        model: Object exposing ``parameters()`` and a ``decoder(z, c)`` callable
            whose output can be reshaped to ``(-1, 1, 28, 28)``.
        n_samples: Number of images to generate.
        n_latent: Dimensionality of the latent vector ``z``.
        n_classes: Number of classes; labels are drawn uniformly from
            ``[0, n_classes)`` and one-hot encoded.
        path: Destination file for the image grid.
        nrow: Number of images per grid row.

    Returns:
        None. The only effect is the file written to ``path``.
    """
    # Place the inputs on whatever device the model's weights are on, instead
    # of relying on a notebook-level global that may not match the model.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # Labels and latents are sampled on the CPU generator and then moved, which
    # keeps the random stream tied to torch.manual_seed as before.
    targets = torch.randint(0, n_classes, (n_samples,))
    c = torch.nn.functional.one_hot(targets, num_classes=n_classes).float().to(model_device)
    z = torch.randn(n_samples, n_latent).to(model_device)

    # Pure inference: no autograd graph is needed for sampling.
    with torch.no_grad():
        rec = model.decoder(z, c).view(-1, 1, 28, 28)
    torchvision.utils.save_image(rec, path, nrow=nrow)
    return

gen_imgs(model, 64, latent_dim)