In [None]:
!pip install pytorch-lightning-bolts

Collecting pytorch-lightning-bolts
[?25l  Downloading https://files.pythonhosted.org/packages/78/0a/be7648b4cc7c2197b108c630175a15d73ef522e6ed3b3c1ad436e4ac7f0c/pytorch_lightning_bolts-0.3.2-py3-none-any.whl (253kB)
[K     |████████████████████████████████| 256kB 19.0MB/s 
[?25hCollecting pytorch-lightning>=1.1.1
[?25l  Downloading https://files.pythonhosted.org/packages/07/0c/e2d52147ac12a77ee4e7fd7deb4b5f334cfb335af9133a0f2780c8bb9a2c/pytorch_lightning-1.2.10-py3-none-any.whl (841kB)
[K     |████████████████████████████████| 849kB 50.6MB/s 
Collecting torchmetrics>=0.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/14/99/dc59248df9a50349d537ffb3403c1bdc1fa69077109d46feaa0843488001/torchmetrics-0.3.1-py3-none-any.whl (271kB)
[K     |████████████████████████████████| 276kB 54.3MB/s 
Collecting fsspec[http]>=0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/e9/91/2ef649137816850fa4f4c97c6f2eabb1a79bf0aa2c8ed198e387e373455e/fsspec-2021.4.0-py3-none-an

In [None]:
import cv2
import os
import numpy as np
import random
#import cPickle as pickle
import pickle
import warnings
import argparse

# Seed both generators: the scene layout draws from NumPy's RNG while the
# shape / question choices draw from the stdlib RNG.
np.random.seed(1234)
random.seed(1234)

# Number of generated scenes per split.
train_size, test_size = 9800, 200

# Scene geometry, in pixels.
img_size = 75  # images are img_size x img_size
size = 5       # half-extent of a square / radius of a circle

# Question vector layout:
#   2 x (6 for one-hot vector of color), 3 for question type, 3 for question subtype
question_size = 18
q_type_idx = 12       # first slot of the question-type one-hot
sub_q_type_idx = 15   # first slot of the question-subtype one-hot

# Answer : [yes, no, rectangle, circle, r, g, b, o, k, y]

nb_questions = 10  # questions generated per question type and scene
dirs = './data'

# One drawing colour per object; the r/g/b labels imply B, G, R channel order.
colors = [
    (0, 0, 255),      # r
    (0, 255, 0),      # g
    (255, 0, 0),      # b
    (0, 156, 255),    # o
    (128, 128, 128),  # k
    (0, 255, 255),    # y
]


# Create the output directory. Only "already exists" is an expected outcome;
# any other OSError (permissions, bad path, ...) should surface instead of
# being reported as an existing directory.
try:
    os.makedirs(dirs)
except FileExistsError:
    print('directory {} already exists'.format(dirs))

def center_generate(objects):
    """Rejection-sample a centre that keeps clear of every placed object.

    Args:
        objects: iterable of 3-tuples whose second element is the centre
            (a length-2 integer array) of an already placed object.

    Returns:
        A length-2 integer array drawn uniformly from
        ``[size, img_size - size)`` on both axes whose squared distance to
        every existing centre is at least ``(2 * size) ** 2``.
    """
    min_sq_dist = (size * 2) ** 2
    while True:
        candidate = np.random.randint(0 + size, img_size - size, 2)
        too_close = any(
            ((candidate - placed) ** 2).sum() < min_sq_dist
            for _, placed, _ in objects
        )
        if not too_close:
            return candidate



def build_dataset():
    """Generate one scene together with its questions and answers.

    One object per entry of ``colors`` is drawn, as a filled square or a
    filled circle chosen at random, at a centre returned by
    ``center_generate``. ``nb_questions`` questions are then generated for
    each of the three question families.

    Question vector (length ``question_size``):
        * ``[0, 6)``   one-hot colour of the (first) queried object
        * ``[6, 12)``  one-hot colour of the second object (ternary only)
        * ``q_type_idx + {0, 1, 2}``  non-relational / binary / ternary
        * ``sub_q_type_idx + subtype`` question subtype

    Answer index: ``[yes, no, rectangle, circle, r, g, b, o, k, y]``.
    Counting questions store ``count + 4``, i.e. they reuse slots 4 and up.

    Returns:
        tuple: ``(img, ternary_relations, binary_relations, norelations)``.
        ``img`` is a float array of shape ``(img_size, img_size, 3)`` scaled
        to ``[0, 1]``; each ``*_relations`` entry is ``(questions, answers)``
        with ``nb_questions`` items each.
    """
    # ---- Scene -----------------------------------------------------------
    objects = []  # one (color_id, center, shape) per colour; shape is 'r' or 'c'
    img = np.ones((img_size, img_size, 3)) * 255
    for color_id, color in enumerate(colors):
        center = center_generate(objects)
        # Defensive: hand OpenCV plain Python ints rather than NumPy scalars,
        # which not every OpenCV build is guaranteed to accept as coordinates.
        cx, cy = int(center[0]), int(center[1])
        if random.random() < 0.5:
            cv2.rectangle(img, (cx - size, cy - size), (cx + size, cy + size), color, -1)
            objects.append((color_id, center, 'r'))
        else:
            cv2.circle(img, (cx, cy), size, color, -1)
            objects.append((color_id, center, 'c'))

    ternary_questions = []
    binary_questions = []
    norel_questions = []
    ternary_answers = []
    binary_answers = []
    norel_answers = []

    # ---- Non-relational questions ----------------------------------------
    for _ in range(nb_questions):
        question = np.zeros((question_size))
        color = random.randint(0, 5)
        question[color] = 1
        question[q_type_idx] = 1
        subtype = random.randint(0, 2)
        question[subtype + sub_q_type_idx] = 1
        norel_questions.append(question)

        if subtype == 0:
            # query shape -> rectangle / circle
            answer = 2 if objects[color][2] == 'r' else 3
        elif subtype == 1:
            # query horizontal position -> yes / no
            answer = 0 if objects[color][1][0] < img_size / 2 else 1
        elif subtype == 2:
            # query vertical position -> yes / no
            answer = 0 if objects[color][1][1] < img_size / 2 else 1
        norel_answers.append(answer)

    # ---- Binary relational questions -------------------------------------
    for _ in range(nb_questions):
        question = np.zeros((question_size))
        color = random.randint(0, 5)
        question[color] = 1
        question[q_type_idx + 1] = 1
        subtype = random.randint(0, 2)
        question[subtype + sub_q_type_idx] = 1
        binary_questions.append(question)

        if subtype == 0:
            # closest-to -> rectangle / circle
            my_obj = objects[color][1]
            dist_list = [((my_obj - obj[1]) ** 2).sum() for obj in objects]
            # Exclude the queried object itself. A finite sentinel such as 999
            # is smaller than many squared pixel distances in a 75x75 image,
            # which let an isolated object be reported as its own neighbour.
            dist_list[color] = float('inf')
            closest = dist_list.index(min(dist_list))
            answer = 2 if objects[closest][2] == 'r' else 3

        elif subtype == 1:
            # furthest-from -> rectangle / circle
            my_obj = objects[color][1]
            dist_list = [((my_obj - obj[1]) ** 2).sum() for obj in objects]
            furthest = dist_list.index(max(dist_list))
            answer = 2 if objects[furthest][2] == 'r' else 3

        elif subtype == 2:
            # count -> number of *other* objects sharing the queried shape
            my_shape = objects[color][2]
            count = sum(1 for obj in objects if obj[2] == my_shape) - 1
            answer = count + 4

        binary_answers.append(answer)

    # ---- Ternary relational questions ------------------------------------
    for _ in range(nb_questions):
        question = np.zeros((question_size))
        # NOTE(review): only colour ids 0-4 are sampled, so the last colour is
        # never queried by a ternary question. Left unchanged to keep the
        # seeded random stream; confirm whether np.arange(6) was intended.
        rnd_colors = np.random.permutation(np.arange(5))
        # 1st object
        color1 = rnd_colors[0]
        question[color1] = 1
        # 2nd object
        color2 = rnd_colors[1]
        question[6 + color2] = 1

        question[q_type_idx + 2] = 1

        subtype = random.randint(0, 2)

        question[subtype + sub_q_type_idx] = 1
        ternary_questions.append(question)

        # coordinates of the two objects named in the question
        A = objects[color1][1]
        B = objects[color2][1]
        # every object that is neither A nor B
        others = [obj for obj in objects if obj[0] != color1 and obj[0] != color2]

        if subtype == 0:
            # between -> number of centres inside the axis-aligned box
            # spanned by A and B (borders included)
            x_lo, x_hi = min(A[0], B[0]), max(A[0], B[0])
            y_lo, y_hi = min(A[1], B[1]), max(A[1], B[1])
            between_count = sum(
                1 for obj in others
                if x_lo <= obj[1][0] <= x_hi and y_lo <= obj[1][1] <= y_hi
            )
            answer = between_count + 4

        elif subtype == 1:
            # is-on-band -> yes / no
            grace_threshold = 12  # allowed vertical offset from the line, in pixels
            epsilon = 1e-10       # keeps the slope finite when A and B share an x
            slope = (B[1] - A[1]) / ((B[0] - A[0]) + epsilon)
            intercept = A[1] - (slope * A[0])

            answer = 1  # default answer is 'no'

            # 'yes' as soon as one other centre lies within the vertical band
            # around the line y = slope * x + intercept through A and B
            for obj in others:
                pos = obj[1]
                y_on_line = (slope * pos[0]) + intercept
                if (y_on_line - grace_threshold) <= pos[1] <= (y_on_line + grace_threshold):
                    answer = 0

        elif subtype == 2:
            # count-obtuse-triangles -> triangles (A, B, C) whose largest
            # angle lies in [90, 180) degrees
            obtuse_count = 0

            # The angle computation may emit NumPy warnings when the points
            # are (nearly) collinear. Silence them for this section only;
            # catch_warnings restores the caller's filters on exit instead of
            # leaving the global filter list modified.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                for obj in others:
                    C = obj[1]
                    # edge lengths
                    a = np.linalg.norm(B - C)
                    b = np.linalg.norm(C - A)
                    c = np.linalg.norm(A - B)
                    # angles by law of cosines
                    alpha = np.rad2deg(np.arccos((b ** 2 + c ** 2 - a ** 2) / (2 * b * c)))
                    beta = np.rad2deg(np.arccos((a ** 2 + c ** 2 - b ** 2) / (2 * a * c)))
                    gamma = np.rad2deg(np.arccos((a ** 2 + b ** 2 - c ** 2) / (2 * a * b)))
                    max_angle = max(alpha, beta, gamma)
                    if max_angle >= 90 and max_angle < 180:
                        obtuse_count += 1

            answer = obtuse_count + 4

        ternary_answers.append(answer)

    ternary_relations = (ternary_questions, ternary_answers)
    binary_relations = (binary_questions, binary_answers)
    norelations = (norel_questions, norel_answers)

    img = img / 255.
    dataset = (img, ternary_relations, binary_relations, norelations)
    return dataset


# Generate the scenes. The test split is built before the train split, and
# both draw from the RNGs seeded above, so this order determines which scenes
# land in which split.
print('building test datasets...')
test_datasets = [build_dataset() for _ in range(test_size)]
print('building train datasets...')
train_datasets = [build_dataset() for _ in range(train_size)]


# Optional debugging aid: write the first training image, upscaled, as a PNG.
#img_count = 0
#cv2.imwrite(os.path.join(dirs,'{}.png'.format(img_count)), cv2.resize(train_datasets[0][0]*255, (512,512)))


# Persist both splits as a single pickled (train, test) tuple under `dirs`.
print('saving datasets...')
filename = os.path.join(dirs,'sort-of-clevr.pickle')
with  open(filename, 'wb') as f:
    pickle.dump((train_datasets, test_datasets), f)
print('datasets saved at {}'.format(filename))

building test datasets...
building train datasets...
saving datasets...
datasets saved at ./data/sort-of-clevr.pickle


In [None]:
import os

import numpy as np
import torch
import pickle

# Load the (train, test) tuple written by the generation cell.
# NOTE: pickle.load can execute arbitrary code; only open files you created.
with open("data/sort-of-clevr.pickle", "rb") as fp:
    train_data, test_data = pickle.load(fp)

# The first element of every sample is its (H, W, C) image. Stack to
# (N, H, W, C) in NumPy first (converting a Python list of arrays with
# torch.tensor is very slow), then reorder to channels-first (N, C, H, W).
# permute keeps H and W in place; transpose(3, 1) would swap them and hand
# back spatially transposed images (unnoticed here because H == W).
train_images = torch.from_numpy(np.stack([x[0] for x in train_data])).float()
train_images = train_images.permute(0, 3, 1, 2)
test_images = torch.from_numpy(np.stack([x[0] for x in test_data])).float()
test_images = test_images.permute(0, 3, 1, 2)

In [None]:
train_images[0].shape

torch.Size([3, 75, 75])

input_height (int) – height of the images

enc_type (str) – option between resnet18 or resnet50

first_conv (bool) – use standard kernel_size 7, stride 2 at start or replace it with kernel_size 3, stride 1 conv

maxpool1 (bool) – use standard maxpool to reduce spatial dim of feat by a factor of 2

enc_out_dim (int) – set according to the out_channel count of encoder used (512 for resnet18, 2048 for resnet50)

kl_coeff (float) – coefficient for kl term of the loss

latent_dim (int) – dim of latent space

lr (float) – learning rate for Adam

In [None]:
from pl_bolts.models.autoencoders import VAE
import pytorch_lightning as pl

from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule

# CIFAR-10 datamodule rooted at the current directory. In the visible cells it
# is referenced only by the commented-out reference run below.
cifar_10 = CIFAR10DataModule('.')

# Reference VAE training run on CIFAR-10 from the bolts documentation (disabled):
#https://pytorch-lightning-bolts.readthedocs.io/en/latest/autoencoders.html
# model = VAE(enc_out_dim=512, latent_dim=256, input_height=32)
# trainer = pl.Trainer(gpus=1, max_epochs=30, progress_bar_refresh_rate=10)
# trainer.fit(model, cifar_10)

In [None]:
from matplotlib.pyplot import imshow, figure
import numpy as np
from torchvision.utils import make_grid
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
# figure(figsize=(8, 3), dpi=300)

# # Z COMES FROM NORMAL(0, 1)
# num_preds = 16
# p = torch.distributions.Normal(torch.zeros(256), torch.ones(256))
# z = p.rsample((num_preds,))

# # SAMPLE IMAGES
# with torch.no_grad():
#     pred = model.decoder(z.to(model.device)).cpu()

# # UNDO DATA NORMALIZATION
# normalize = cifar10_normalization()
# mean, std = np.array(normalize.mean), np.array(normalize.std)
# img = make_grid(pred).permute(1, 2, 0).numpy() * std + mean

# # PLOT IMAGES
# imshow(img);

In [None]:
train_images.shape

torch.Size([9800, 3, 75, 75])

In [None]:
# cifar_10.setup()

In [None]:
# type(next(iter(cifar_10.train_dataloader())))

In [None]:
# x,y = next(iter(cifar_10.train_dataloader()))

In [None]:
# x[0].size()

In [None]:
# y[0]

In [None]:
# next(iter(train_loader)).size()

In [None]:
from torch.utils.data import Dataset, DataLoader
class ClevrDataset(Dataset):
    """Training images of a Sort-of-CLEVR pickle as a map-style dataset.

    Every item is ``[image, label]`` where ``image`` is a channels-first
    float tensor and ``label`` is a constant dummy ``torch.tensor(1)``.
    All state lives on the instance, so the class does not depend on any
    notebook-level variable having been defined by an earlier cell.
    """

    def __init__(self, pickle_file, root_dir, transform=None):
        """
        Args:
            pickle_file (string): Path to the pickle file holding the
                ``(train_datasets, test_datasets)`` tuple; only the training
                part is kept.
            root_dir (string): Stored on the instance; not otherwise used by
                this class.
            transform (callable, optional): Optional transform to be applied
                on a sample (it receives the ``[image, label]`` list).
        """
        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(pickle_file, "rb") as fp:
            self.train_data, _ = pickle.load(fp)

        # Stack the (H, W, C) images into (N, H, W, C), then reorder to
        # (N, C, H, W). permute keeps H and W in place, whereas
        # transpose(3, 1) would swap them.
        images = torch.from_numpy(np.stack([x[0] for x in self.train_data])).float()
        self.train_images = images.permute(0, 3, 1, 2)

        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of training images."""
        return len(self.train_images)

    def __getitem__(self, idx):
        """Return ``[image, label]`` for ``idx`` (an int or an index tensor)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = [self.train_images[idx], torch.tensor(1)]

        if self.transform:
            sample = self.transform(sample)

        return sample

In [None]:
clevr_dataset = ClevrDataset(pickle_file='data/sort-of-clevr.pickle',
                                    root_dir='.')

In [None]:
len(clevr_dataset)

9800

In [None]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils


# NOTE(review): nothing is drawn on this figure in the visible cells; it is
# kept only so that `fig` stays defined for any later cell that may use it.
fig = plt.figure()

# Spot-check a handful of samples; printing thousands of identical shape
# lines only floods the cell output.
for i in range(min(4, len(clevr_dataset))):
    x_c, y_c = clevr_dataset[i]

    print(i, x_c.shape, y_c.shape)


  return f(*args, **kwds)


0 torch.Size([3, 75, 75]) torch.Size([])
1 torch.Size([3, 75, 75]) torch.Size([])
2 torch.Size([3, 75, 75]) torch.Size([])
3 torch.Size([3, 75, 75]) torch.Size([])
4 torch.Size([3, 75, 75]) torch.Size([])
5 torch.Size([3, 75, 75]) torch.Size([])
6 torch.Size([3, 75, 75]) torch.Size([])
7 torch.Size([3, 75, 75]) torch.Size([])
8 torch.Size([3, 75, 75]) torch.Size([])
9 torch.Size([3, 75, 75]) torch.Size([])
10 torch.Size([3, 75, 75]) torch.Size([])
11 torch.Size([3, 75, 75]) torch.Size([])
12 torch.Size([3, 75, 75]) torch.Size([])
13 torch.Size([3, 75, 75]) torch.Size([])
14 torch.Size([3, 75, 75]) torch.Size([])
15 torch.Size([3, 75, 75]) torch.Size([])
16 torch.Size([3, 75, 75]) torch.Size([])
17 torch.Size([3, 75, 75]) torch.Size([])
18 torch.Size([3, 75, 75]) torch.Size([])
19 torch.Size([3, 75, 75]) torch.Size([])
20 torch.Size([3, 75, 75]) torch.Size([])
21 torch.Size([3, 75, 75]) torch.Size([])
22 torch.Size([3, 75, 75]) torch.Size([])
23 torch.Size([3, 75, 75]) torch.Size([])
24

<Figure size 432x288 with 0 Axes>

In [None]:
# Shuffled mini-batches of 4 samples, loaded in the main process.
clevr_dataloader = DataLoader(
    dataset=clevr_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

In [None]:
x_n, y_n = next(iter(clevr_dataloader))
print(x_n.shape)
print(y_n.shape)

torch.Size([4, 3, 75, 75])
torch.Size([4])


In [None]:
# for i_batch, sample_batched in enumerate(clevr_dataloader):
#     print(i_batch, sample_batched.size())

In [None]:
# class SortOfClevrDataModule(pl.LightningDataModule):
#     def __init__(self):
#         super().__init__()
#     def train_dataloader(self):
#         return train_loader
#     def test_dataloader(self):
#         return test_loader

In [None]:
from pl_bolts.models.autoencoders import VAE
import pytorch_lightning as pl

#https://pytorch-lightning-bolts.readthedocs.io/en/latest/autoencoders.html
# Train the bolts VAE on the Sort-of-CLEVR images (75x75 inputs, 64-d latent).
# NOTE(review): the recorded run of this cell ends in a RuntimeError right
# after a warning on `F.mse_loss(x_hat, x)`. Presumably the reconstruction
# `x_hat` does not come out at the 75x75 input size - confirm by comparing
# x_hat.shape with x.shape, and resize the images or adjust input_height.
model = VAE(latent_dim=64, input_height=75, first_conv=False)
trainer = pl.Trainer(gpus=1, max_epochs=30, progress_bar_refresh_rate=10)
trainer.fit(model, clevr_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  return f(*args, **kwds)
  return f(*args, **kwds)

  | Name    | Type          | Params
------------------------------------------
0 | encoder | ResNetEncoder | 11.2 M
1 | decoder | ResNetDecoder | 7.1 M 
2 | fc_mu   | Linear        | 32.8 K
3 | fc_var  | Linear        | 32.8 K
------------------------------------------
18.3 M    Trainable params
0         Non-trainable params
18.3 M    Total params
73.149    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

  recon_loss = F.mse_loss(x_hat, x, reduction='mean')





RuntimeError: ignored