## Import Modules

In [1]:
import json
import numpy as np
import os
# NOTE(review): the three environment variables below are set before
# `import torch` further down; keep that order when editing this cell.
# The GPU index and the proxy address are specific to the original
# author's machine/network -- adjust or remove them for your setup.
os.environ['CUDA_VISIBLE_DEVICES'] = "5" 
os.environ['http_proxy'] = '10.106.130.4:3128'
os.environ['https_proxy'] = '10.106.130.4:3128'
import time
import socket
from pathlib import Path
import cv2
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import MultivariateNormal
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Project-local modules: the MAGE model factory and the ImageNet
# class-index -> label lookup used when printing class names.
from pixel_generator.mage import models_mage
from PIL import Image
from imagenet_clstolabel import IMGNET_CLASS2LABEL
from IPython.display import display
import lpips

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def viz_torchimage(image):
    """Show one image tensor inline in the notebook.

    The tensor is clamped to [0, 1], detached and copied to the CPU, its
    first axis is moved to the last position, the values are scaled by 255
    and cast to ``uint8``, and the resulting PIL image is passed to
    IPython's ``display``.

    Args:
        image: a 3-D tensor; the axis reorder below assumes axis 0 is the
            channel axis (presumably (C, H, W) -- verify against callers).

    Returns:
        None. The image is rendered as a side effect.
    """
    clamped = image.clamp(0, 1)
    # Move axis 0 to the end so the array is laid out the way PIL expects.
    channels_last = clamped.detach().cpu().permute(1, 2, 0).numpy()
    pixels = np.uint8(channels_last * 255)
    display(Image.fromarray(pixels))

<-- ----------------------------------------------------------------------- -->

# Large version inference

## 1. Load pre-trained encoder, RDM and MAGE

In [3]:
# Initialize RCG-L: choose the RDM checkpoint/config, assemble the
# minority-guidance options, build the MAGE-L model and load its weights.
class_cond = True

# (checkpoint, config) pair for the RDM, picked by `class_cond`.
rdm_ckpt_path, rdm_cfg = (
    ('final_ckpts/rdm-mocov3vitl-clscond.pth',
     'config/rdm/mocov3vitl_simplemlp_l12_w1536_classcond.yaml')
    if class_cond else
    ('final_ckpts/rdm-mocov3vitl.pth',
     'config/rdm/mocov3vitl_simplemlp_l12_w1536.yaml')
)

# minority guidance args (forwarded to the model factory as `mg_kwargs`)
mg_kwargs = dict(
    use_ms_grad=True,
    norm_for_mg=2.0,
    t_mid=-1.0,  # -1.0 for no early stop
    mg_scale=0.15,
    p_ratio=0.5,
    num_mc_samples=1,
    mg_scale_type='var',
    use_normed_grad=True,
    use_lpips=False,
    inter_rate=1,
)
# The LPIPS network is only instantiated when the flag above asks for it.
if mg_kwargs['use_lpips']:
    loss_lpips = lpips.LPIPS(net='alex').cuda()
    mg_kwargs['loss_lpips'] = loss_lpips

# Keyword arguments for the model factory, grouped by what they configure.
masking_kwargs = dict(mask_ratio_mu=0.75, mask_ratio_std=0.25,
                      mask_ratio_min=0.5, mask_ratio_max=1.0)
conditioning_kwargs = dict(use_rep=True, rep_dim=256, rep_drop_prob=0.1,
                           use_class_label=False)
encoder_kwargs = dict(pretrained_enc_arch='mocov3_vit_large',
                      pretrained_enc_path='pretrained_enc_ckpts/mocov3/vitl.pth.tar',
                      pretrained_enc_proj_dim=256,
                      pretrained_enc_withproj=True)
model = models_mage.mage_vit_large_patch16(
    vqgan_ckpt_path='vqgan-ckpts/vqgan_jax_strongaug.ckpt',
    pretrained_rdm_ckpt=rdm_ckpt_path,
    pretrained_rdm_cfg=rdm_cfg,
    mg_kwargs=mg_kwargs,
    **masking_kwargs,
    **conditioning_kwargs,
    **encoder_kwargs,
)

# Load the MAGE-L weights on the CPU first, then move the model to the GPU.
checkpoint = torch.load('final_ckpts/mage-l.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=True)
model.cuda()
# Bind the return value so the cell does not echo the module repr.
_ = model.eval()

Use representation as condition!
Loading model from final_ckpts/rdm-mocov3vitl-clscond.pth
RDM: Running in x0-prediction mode
DiffusionWrapper has 72.18 M params.
Keeping EMAs of 156.
Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Strict load
Restored from vqgan-ckpts/vqgan_jax_strongaug.ckpt


## 2. Image Generation

In [None]:
# Uncomment to fix the torch / numpy RNG seeds before sampling.
# torch.manual_seed(7)
# np.random.seed(7)

# Sampling settings
n_image_to_gen = 2   # images generated per class (or in total when unconditional)
rdm_steps = 250
rdm_eta = 1.0
mage_temp = 11.0
mage_steps = 20
cfg = 6.0

# Keyword arguments shared by every `gen_image` call below.
sampling_kwargs = dict(
    num_iter=mage_steps,
    choice_temperature=mage_temp,
    sampled_rep=None,
    rdm_steps=rdm_steps,
    eta=rdm_eta,
    cfg=cfg,
)

if not class_cond:
    # Unconditional sampling: no class label is passed.
    for i in range(n_image_to_gen):
        gen_images, _ = model.gen_image(1, class_label=None, **sampling_kwargs)
        viz_torchimage(gen_images[0])
else:
    for class_label in [1, 323, 985]:
        print(f"{class_label}: {IMGNET_CLASS2LABEL[class_label]}")
        # Wrap the integer class index as a 1-element int64 CUDA tensor.
        class_label = torch.full((1,), class_label, dtype=torch.long).cuda()
        for i in range(n_image_to_gen):
            gen_images, _ = model.gen_image(1, class_label=class_label, **sampling_kwargs)
            viz_torchimage(gen_images[0])

In [8]:
# Build the ImageNet training dataset (only the dataset is created in this
# cell, no DataLoader). Seeds are fixed first.
torch.manual_seed(0)
np.random.seed(0)
bsz = 1
transform = transforms.Compose([
    # Integer interpolation codes are deprecated in torchvision; use the
    # enum instead. The original code `3` is PIL's bicubic filter, so
    # BICUBIC keeps the same resampling.
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

# NOTE(review): absolute, machine-specific path -- point this at your own
# ImageNet training folder (one sub-directory per class).
dataset = datasets.ImageFolder('/root/datasets/imagenet/imagenet1k_train/', transform=transform)
print('num_data:', len(dataset))
print('num_classes:', len(dataset.classes))

num_data: 1281167
num_classes: 1000


In [10]:
# 100 class folder names to look up in the dataset's class index
# (presumably the ImageNet-100 subset, judging by the variable name below).
f1 = [
    'n01498041', 'n01514859', 'n01582220', 'n01608432', 'n01616318',
    'n01687978', 'n01776313', 'n01806567', 'n01833805', 'n01882714',
    'n01910747', 'n01944390', 'n01985128', 'n02007558', 'n02071294',
    'n02085620', 'n02114855', 'n02123045', 'n02128385', 'n02129165',
    'n02129604', 'n02165456', 'n02190166', 'n02219486', 'n02226429',
    'n02279972', 'n02317335', 'n02326432', 'n02342885', 'n02363005',
    'n02391049', 'n02395406', 'n02403003', 'n02422699', 'n02442845',
    'n02444819', 'n02480855', 'n02510455', 'n02640242', 'n02672831',
    'n02687172', 'n02701002', 'n02730930', 'n02769748', 'n02782093',
    'n02787622', 'n02793495', 'n02799071', 'n02802426', 'n02814860',
    'n02840245', 'n02906734', 'n02948072', 'n02980441', 'n02999410',
    'n03014705', 'n03028079', 'n03032252', 'n03125729', 'n03160309',
    'n03179701', 'n03220513', 'n03249569', 'n03291819', 'n03384352',
    'n03388043', 'n03450230', 'n03481172', 'n03594734', 'n03594945',
    'n03627232', 'n03642806', 'n03649909', 'n03661043', 'n03676483',
    'n03724870', 'n03733281', 'n03759954', 'n03761084', 'n03773504',
    'n03804744', 'n03916031', 'n03938244', 'n04004767', 'n04026417',
    'n04090263', 'n04133789', 'n04153751', 'n04296562', 'n04330267',
    'n04371774', 'n04404412', 'n04465501', 'n04485082', 'n04507155',
    'n04536866', 'n04579432', 'n04606251', 'n07714990', 'n07745940',
]

# Mapping from class folder name to integer class index, as built by ImageFolder.
label_dict = dataset.class_to_idx

# Integer class indices of the selected classes, in the same order as `f1`.
imgnet100_label = [label_dict[folder_name] for folder_name in f1]
print(imgnet100_label)

[6, 8, 18, 21, 23, 42, 78, 85, 94, 105, 107, 113, 124, 130, 148, 151, 272, 281, 288, 291, 292, 301, 308, 310, 311, 323, 327, 331, 333, 337, 340, 341, 345, 352, 357, 360, 366, 388, 394, 401, 403, 407, 411, 414, 417, 420, 425, 429, 430, 437, 446, 462, 470, 483, 488, 492, 497, 498, 516, 525, 526, 538, 541, 549, 561, 562, 578, 587, 608, 609, 616, 620, 621, 624, 629, 643, 646, 650, 651, 657, 677, 711, 721, 742, 748, 764, 774, 783, 819, 827, 843, 851, 866, 872, 879, 889, 902, 913, 937, 949]
