In [1]:
import os
import pathlib
import numpy as np
import torch
import torchvision.transforms as TF
from PIL import Image
from scipy import linalg
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.nn.functional import adaptive_avg_pool2d
from torch.utils.data import DataLoader, Dataset

In [2]:
import sys

# Expose the project folder on the import path so the notebook can import
# the local helper modules that live there.
sys.path.append(
    '/content/drive/MyDrive/Colab Notebooks/transformer_generate'
)

In [3]:
from inception import InceptionV3
from utils import top_k_logits, sample
from transformer_model import Autoregressive_GPT
from vqvae import VQVAE
from fid_utils import compute_fid_by_path

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# DataLoader settings reused by the cells below.
batch_size = 64
num_workers = 1

# Inception network handed to compute_fid_by_path. block_idx is passed as the
# single requested block (its exact meaning is defined in inception.py).
block_idx = 3
inception_model_path = '/content/drive/MyDrive/Colab Notebooks/transformer_generate/pt_inception-2015-12-05-6726825d.pth'
inception_model = InceptionV3([block_idx], inception_model_path).to(device)

# NOTE: both paths point at the same folder, so the comparison below is a
# folder-against-itself sanity check.
path1 = '/content/drive/MyDrive/data/cifar_raw/airplane'
path2 = '/content/drive/MyDrive/data/cifar_raw/airplane'

In [5]:
# Compare the two image folders; the returned value is shown as the cell
# output. path1 and path2 are identical here, so a result close to zero is
# expected if this is a standard FID (TODO confirm against fid_utils).
compute_fid_by_path(
    inception_model,
    path1,
    path2,
    batch_size,
    device,
    block_idx,
)



  0%|          | 0/1 [00:00<?, ?it/s]



  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


  0%|          | 0/1 [00:00<?, ?it/s]

-0.00016024861014329872

In [6]:
class Cifar10_code_Dataset(Dataset):
    """Dataset of pre-extracted code tensors stored as ``.pt`` files.

    For the selected split, three tensors are loaded onto the CPU from
    ``save_path``: ``<split>_bottom_result.pt``, ``<split>_top_result.pt``
    and ``<split>_label.pt``, where ``<split>`` is ``train`` or ``valid``.

    Args:
        save_path: Directory containing the ``.pt`` files.
        train: Load the training split when truthy, otherwise the
            validation split.
    """

    def __init__(self, save_path, train=True):
        split = 'train' if train else 'valid'
        cpu = torch.device('cpu')
        # All three tensors must come from the same split so that index i
        # refers to the same image in each of them.
        self.bottom_feature = torch.load(f'{save_path}/{split}_bottom_result.pt', map_location=cpu)
        self.top_feature = torch.load(f'{save_path}/{split}_top_result.pt', map_location=cpu)
        self.label = torch.load(f'{save_path}/{split}_label.pt', map_location=cpu)

    def __len__(self):
        """Number of items, taken from the first dimension of the labels."""
        return self.label.shape[0]

    def __getitem__(self, index):
        """Return ``(bottom_feature, top_feature, label)`` for ``index``."""
        return self.bottom_feature[index, :], self.top_feature[index, :], self.label[index]

In [7]:
# Directory holding the pre-extracted code tensors (*.pt files).
save_path = '/content/drive/MyDrive/Colab Notebooks/vq-vae2/extracted_code_cifar10'

train_set = Cifar10_code_Dataset(save_path, train=True)
valid_set = Cifar10_code_Dataset(save_path, train=False)

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=num_workers,
)
valid_loader = DataLoader(
    valid_set,
    batch_size=batch_size,
    num_workers=num_workers,
)

# Keep only the top codes of the first training batch; they are reused
# later as a fixed set of top features.
_, top_feature_sample, _ = next(iter(train_loader))

In [8]:
class Cifar10_code_Dataset_autoregressive(Dataset):
    """Next-token dataset built from the bottom code tensor only.

    Each item is the flattened bottom code of one sample, split into an
    input sequence (all tokens but the last) and a target sequence (all
    tokens but the first).

    Args:
        save_path: Directory containing ``train_bottom_result.pt`` and
            ``valid_bottom_result.pt``.
        train: Load the training file when truthy, otherwise the
            validation file.
    """

    def __init__(self, save_path, train=True):
        file_name = '/train_bottom_result.pt' if train else '/valid_bottom_result.pt'
        self.bottom_feature = torch.load(save_path + file_name, map_location=torch.device('cpu'))

    def __len__(self):
        """Number of samples, i.e. the first dimension of the code tensor."""
        return self.bottom_feature.shape[0]

    def __getitem__(self, index):
        """Return ``(tokens[:-1], tokens[1:])`` for the flattened code."""
        tokens = self.bottom_feature[index].flatten()
        return tokens[:-1], tokens[1:]

In [9]:
# Datasets yielding (input_tokens, target_tokens) pairs for next-token
# prediction over the flattened bottom codes.
train_set_autoregressive = Cifar10_code_Dataset_autoregressive(save_path)
valid_set_autoregressive = Cifar10_code_Dataset_autoregressive(save_path, False)

# The loaders must wrap the autoregressive datasets built above; wrapping
# train_set / valid_set would yield (bottom, top, label) triples instead of
# (input, target) pairs.
train_loader_autoregressive = DataLoader(train_set_autoregressive, batch_size=batch_size, num_workers=num_workers)
valid_loader_autoregressive = DataLoader(valid_set_autoregressive, batch_size=batch_size, num_workers=num_workers)

In [10]:
# Autoregressive prior over the 8x8 bottom code grid: 512 possible code
# values, 63 input positions (the last token is only ever a target).
transformer_model = Autoregressive_GPT(vocab_size=512, seq_len=8*8-1, attn_pdrop = 0.2, resid_pdrop = 0.2, embd_pdrop = 0.2).to(device)
# The model is only used for sampling in this notebook, so switch it to
# inference mode; otherwise the dropout configured above stays active unless
# the sampling helper disables it itself.
transformer_model.eval()
PATH = '/content/drive/MyDrive/Colab Notebooks/vq-vae2/models/30_generating.pt'
# Kept as the last expression so the key-matching report is displayed.
transformer_model.load_state_dict(torch.load(PATH, map_location=device))

<All keys matched successfully>

In [11]:
# VQ-VAE used to decode (top, bottom) code pairs back into images.
vqvae_model = VQVAE().to(device)
# Inference only: put the model in eval mode so any train-mode-dependent
# behaviour inside VQVAE (not visible from this notebook) is disabled.
vqvae_model.eval()
PATH = '/content/drive/MyDrive/Colab Notebooks/transformer_generate/vqvae_300_model.pt'
# Kept as the last expression so the key-matching report is displayed.
vqvae_model.load_state_dict(torch.load(PATH, map_location=device))

<All keys matched successfully>

In [12]:
def generate_bottom_feature(transformer_model, train_set, n_samples, n_estimate=5000):
    """Sample bottom code grids from the autoregressive model.

    The first token of every generated sequence is drawn from the empirical
    distribution of first tokens in ``train_set`` (with add-one smoothing);
    the remaining tokens are produced by ``sample``.

    Args:
        transformer_model: Model handed to ``sample``.
        train_set: Indexable dataset whose items are ``(input, target)``
            pairs; only ``input[0]`` is read.
        n_samples: Number of sequences to generate.
        n_estimate: Maximum number of dataset items used to estimate the
            first-token distribution. Clamped to ``len(train_set)``.

    Returns:
        The output of ``sample`` reshaped to ``(-1, 8, 8)``.

    Note:
        Relies on the notebook-level ``device`` and ``sample`` names.
    """
    vocab_size = 512
    grid_tokens = 8 * 8

    # Start every count at 1 rather than 0 (add-one smoothing) so that no
    # code value ends up with zero probability. float64 keeps the normalised
    # vector within np.random.choice's sum-to-one tolerance, which float32
    # rounding can occasionally violate.
    counts = torch.ones(vocab_size, dtype=torch.float64)

    # Never ask for more items than the dataset holds.
    n_estimate = min(n_estimate, len(train_set))
    chosen = torch.randperm(len(train_set))[:n_estimate]
    for idx in chosen.tolist():
        inputs, _ = train_set[idx]
        counts[int(inputs[0].item())] += 1  # first token of the sequence

    prob = (counts / counts.sum()).numpy()
    start_pixel = np.random.choice(np.arange(vocab_size), size=(n_samples, 1), replace=True, p=prob)
    start_pixel = torch.from_numpy(start_pixel).to(device)

    # Pure inference: no autograd graph is needed while sampling.
    with torch.no_grad():
        pixels = sample(transformer_model, start_pixel, grid_tokens - 1, temperature=1.0, sample=True, top_k=100)
    return torch.reshape(pixels, (-1, 8, 8))

In [13]:
def reconstruct(vqvae_model, top_features, bottom_features, save_path=None):
    """Decode a pair of code tensors with the given model.

    Args:
        vqvae_model: Object exposing ``decode_code(top, bottom)``.
        top_features: First argument forwarded to ``decode_code``.
        bottom_features: Second argument forwarded to ``decode_code``.
        save_path: Accepted but not used.

    Returns:
        Whatever ``vqvae_model.decode_code`` returns.
    """
    decoded = vqvae_model.decode_code(top_features, bottom_features)
    return decoded

In [14]:
def generate_and_compute(transformer_model, vqvae_model, save_path, original_path, inception_model_path, batch_size, device, block_idx=3):
    """Generate a batch of images from sampled bottom codes.

    Top codes come from the notebook-level ``top_feature_sample`` batch;
    bottom codes are sampled with ``generate_bottom_feature`` using the
    notebook-level ``train_set_autoregressive``. Both are decoded with
    ``vqvae_model.decode_code``.

    Args:
        transformer_model: Autoregressive model used to sample bottom codes.
        vqvae_model: Model exposing ``decode_code(top, bottom)``.
        save_path: Unused.
        original_path: Unused.
        inception_model_path: Unused.
        batch_size: Number of bottom code grids to sample.
        device: Device the code tensors are moved to before decoding.
        block_idx: Unused.

    Returns:
        The decoded images. Despite the function name, no FID is computed
        here.
    """
    with torch.no_grad():  # inference only; avoids building an autograd graph
        # Take at most batch_size top codes so they line up with the
        # batch_size sampled bottom codes.
        top_features = top_feature_sample[:batch_size].to(torch.int).to(device)
        bottom_features = generate_bottom_feature(transformer_model, train_set_autoregressive, batch_size).to(torch.int).to(device)
        # Decode directly instead of going through the global name
        # `reconstruct`: a later cell redefines `reconstruct` with a different
        # signature, which would break this function once that cell has run.
        images = vqvae_model.decode_code(top_features, bottom_features)
    return images

In [15]:
# Generate one batch of images. The path / inception arguments are required
# by the signature but play no part in producing the images.
images = generate_and_compute(
    transformer_model,
    vqvae_model,
    None,
    path1,
    inception_model_path,
    batch_size,
    device,
    block_idx=3,
)
print(images.shape)

torch.Size([64, 3, 32, 32])


In [28]:
# NOTE: this rebinds the notebook-level `save_path`, which earlier cells use
# as the code-tensor directory; re-running those cells after this one will
# read from this folder instead.
save_path = '/content/drive/MyDrive/Colab Notebooks/transformer_generate'
# save_image does not create missing directories.
os.makedirs(save_path + "/images", exist_ok=True)
torchvision.utils.save_image(
    images,
    save_path + "/images/sample.png",
    nrow=8,
    normalize=True,
    # `value_range` is the current name of the deprecated `range` argument.
    value_range=(-1, 1),
)



In [None]:
def cifar_all_loader(root, num_workers=2, batch_size=64):
    """Build DataLoaders for the CIFAR-10 train and test splits.

    Images are converted to tensors and normalised per channel with mean 0.5
    and std 0.5. The dataset is read from ``root`` and is never downloaded.

    Args:
        root: Directory handed to ``torchvision.datasets.CIFAR10``.
        num_workers: Worker count for both loaders.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def _loader_for(is_train):
        # One loader per split; both share the transform and loader settings.
        split = torchvision.datasets.CIFAR10(
            root=root, train=is_train, download=False, transform=transform
        )
        return DataLoader(split, batch_size=batch_size, num_workers=num_workers)

    return _loader_for(True), _loader_for(False)

In [None]:
# Settings for the raw CIFAR-10 image loaders.
root = '/content/drive/MyDrive/data/cifar10'
batch_size = 64
test_batch_size = 4
num_workers = 2
sample_size = 16  # number of images taken from each batch when reconstructing

# NOTE: this rebinds `train_loader`, which previously referred to the
# code-tensor loader.
train_loader, test_loader = cifar_all_loader(
    root,
    num_workers=num_workers,
    batch_size=batch_size,
)

In [None]:
def reconstruct(test_loader, model, sample_size, device, save_path):
    """Save side-by-side grids of input images and their reconstructions.

    For each of the first 26 batches of ``test_loader``, the first
    ``sample_size`` images are passed through ``model`` and a grid of the
    inputs followed by the outputs is written to
    ``<save_path>/images/_<batch index, zero-padded to 5 digits>.png``.

    NOTE: this definition shares its name with the earlier
    ``reconstruct(vqvae_model, top_features, bottom_features, save_path=None)``
    in this notebook and replaces it once this cell has run.

    Args:
        test_loader: Iterable yielding ``(images, labels)`` batches.
        model: Callable returning a 2-tuple whose first element is the
            reconstructed batch.
        sample_size: Number of images taken from each batch; also the
            number of images per grid row.
        device: Device the images are moved to before the forward pass.
        save_path: Base directory; grids are written to its ``images``
            sub-directory, which is created if missing.
    """
    # os.path.join gives the same result whether or not save_path ends with a
    # separator; plain concatenation produced "<save_path>images/..." when it
    # did not.
    image_dir = os.path.join(save_path, "images")
    os.makedirs(image_dir, exist_ok=True)

    for i, (img, _) in enumerate(test_loader):
        # Named `originals` rather than `sample` to avoid shadowing the
        # `sample` helper imported from utils.
        originals = img[:sample_size].to(device)
        with torch.no_grad():
            out, _ = model(originals)
            torchvision.utils.save_image(
                torch.cat([originals, out], 0),
                os.path.join(image_dir, f"_{str(i).zfill(5)}.png"),
                nrow=sample_size,
                normalize=True,
                # `value_range` is the current name of the deprecated `range`.
                value_range=(-1, 1),
            )
        if i >= 25:
            break