In [1]:
import numpy as np
import pickle as pkl
from PIL import Image as img
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
with open('Data/Bin/test', 'rb') as f:
    test = pkl.load(f, encoding='bytes')
with open('Data/Bin/train', 'rb') as f:
    train = pkl.load(f, encoding='bytes')
# train

In [3]:
train

{b'filenames': [b'bos_taurus_s_000507.png',
  b'stegosaurus_s_000125.png',
  b'mcintosh_s_000643.png',
  b'altar_boy_s_001435.png',
  b'cichlid_s_000031.png',
  b'phone_s_002161.png',
  b'car_train_s_000043.png',
  b'beaker_s_000604.png',
  b'fog_s_000397.png',
  b'rogue_elephant_s_000421.png',
  b'computer_keyboard_s_000757.png',
  b'willow_tree_s_000645.png',
  b'sunflower_s_000549.png',
  b'palace_s_000759.png',
  b'adriatic_s_001782.png',
  b'computer_keyboard_s_001277.png',
  b'bike_s_000682.png',
  b'wolf_pup_s_001323.png',
  b'squirrel_s_002467.png',
  b'sea_s_000678.png',
  b'shrew_s_002233.png',
  b'pine_tree_s_000087.png',
  b'rose_s_000373.png',
  b'surveillance_system_s_000769.png',
  b'pine_s_001533.png',
  b'table_s_000897.png',
  b'opossum_s_001237.png',
  b'quercus_alba_s_000257.png',
  b'leopard_s_000414.png',
  b'possum_s_002195.png',
  b'bike_s_000127.png',
  b'balmoral_castle_s_000361.png',
  b'acer_saccharinum_s_000646.png',
  b'lapin_s_000916.png',
  b'chimp_s_001

In [3]:
scaler = transforms.Resize((224, 224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()

In [21]:
def pack(np_image):
    np_image = np_image[:, :, :, np.newaxis]
    tmp = np.concatenate((np_image[0], np_image[1]), axis=2)
    return np.concatenate((tmp, np_image[2]), axis=2)

def prepr(np_image):
    tmp_im = img.fromarray(pack(np_image))
    return normalize(to_tensor(scaler(tmp_im))).to(device)

In [5]:
vgg = torchvision.models.vgg16(pretrained=False)
vgg_extractor = nn.Sequential(
    vgg.features,
    vgg.avgpool,
    nn.Flatten(),
    nn.Linear(512 * 49, 2048),
    nn.ReLU(),
    nn.Linear(2048, 2048),
    nn.ReLU(),
    nn.Linear(2048, 80)
).to(device)

In [6]:
vgg_extractor.load_state_dict(torch.load('../VGG16Extractor', map_location=device))

<All keys matched successfully>

In [7]:
vgg_extractor = nn.Sequential(*list(vgg_extractor.children())[:-2]).to(device)
vgg = 0
torch.cuda.empty_cache()

In [3]:
vgg_extractor = torch.load('../VGG16Backbone', map_location=device) # on gpu
# torch.save(vgg_extractor, 'VGG16Backbone')



In [6]:
# imgs_torch = torch.zeros((600, 3, 224, 224), dtype=torch.float32)
features = torch.zeros((600, 2048), dtype=torch.float32)
for class_num in range(100):
    torch.cuda.empty_cache()
    idx = [train[b'fine_labels'][i] == class_num for i in range(len(train[b'fine_labels']))]
    imgs = train[b'data'][idx]    
    
    idx = [test[b'fine_labels'][i] == class_num for i in range(len(test[b'fine_labels']))]
    imgs = np.concatenate((imgs, test[b'data'][idx]), axis=0)
    imgs = imgs.reshape(-1, 3, 32, 32)
    i = 0
    for im in imgs:
        features[i] = vgg_extractor(prepr(im).reshape(1, 3, 224, 224)).detach().cpu()
        i += 1  

    with open('Data/PickledClasses/' + str(class_num), 'wb') as f:
        pkl.dump(features, f)
    print('Class {} processed.'.format(class_num))
features



Class 0 processed.
Class 1 processed.
Class 2 processed.
Class 3 processed.
Class 4 processed.
Class 5 processed.
Class 6 processed.
Class 7 processed.
Class 8 processed.
Class 9 processed.
Class 10 processed.
Class 11 processed.
Class 12 processed.
Class 13 processed.
Class 14 processed.
Class 15 processed.
Class 16 processed.
Class 17 processed.
Class 18 processed.
Class 19 processed.
Class 20 processed.
Class 21 processed.
Class 22 processed.
Class 23 processed.
Class 24 processed.
Class 25 processed.
Class 26 processed.
Class 27 processed.
Class 28 processed.
Class 29 processed.
Class 30 processed.
Class 31 processed.
Class 32 processed.
Class 33 processed.
Class 34 processed.
Class 35 processed.
Class 36 processed.
Class 37 processed.
Class 38 processed.
Class 39 processed.
Class 40 processed.
Class 41 processed.
Class 42 processed.
Class 43 processed.
Class 44 processed.
Class 45 processed.
Class 46 processed.
Class 47 processed.
Class 48 processed.
Class 49 processed.
Class 50 p

tensor([[ 6.0646,  3.5600,  0.4621,  ...,  4.9076, -3.2856,  1.7092],
        [-2.5282,  4.3624,  5.6451,  ...,  4.9651, -2.3543, -0.0673],
        [ 3.9683,  5.0149, -1.8166,  ...,  1.0168, -2.9394, -0.3736],
        ...,
        [12.1863,  1.9852, -0.7208,  ...,  3.8403, -3.9264, -0.0975],
        [-0.4379, -0.3765,  5.6470,  ...,  4.7252, -4.5871,  3.0658],
        [ 3.7452,  5.7615, -1.7176,  ..., -1.0261, -2.7995,  2.6570]])

In [10]:
# imgs_torch = torch.zeros((600, 3, 224, 224), dtype=torch.float32)
for class_num in range(100):
    idx = [train[b'fine_labels'][i] == class_num for i in range(len(train[b'fine_labels']))]
    imgs = train[b'data'][idx]    
    
    idx = [test[b'fine_labels'][i] == class_num for i in range(len(test[b'fine_labels']))]
    imgs = np.concatenate((imgs, test[b'data'][idx]), axis=0)
    imgs = imgs.reshape(-1, 3, 32, 32)  

    with open('Data/PickledIms/' + str(class_num), 'wb') as f:
        pkl.dump(imgs, f)
    print('Class {} processed.'.format(class_num))
features

Class 0 processed.
Class 1 processed.
Class 2 processed.
Class 3 processed.
Class 4 processed.
Class 5 processed.
Class 6 processed.
Class 7 processed.
Class 8 processed.
Class 9 processed.
Class 10 processed.
Class 11 processed.
Class 12 processed.
Class 13 processed.
Class 14 processed.
Class 15 processed.
Class 16 processed.
Class 17 processed.
Class 18 processed.
Class 19 processed.
Class 20 processed.
Class 21 processed.
Class 22 processed.
Class 23 processed.
Class 24 processed.
Class 25 processed.
Class 26 processed.
Class 27 processed.
Class 28 processed.
Class 29 processed.
Class 30 processed.
Class 31 processed.
Class 32 processed.
Class 33 processed.
Class 34 processed.
Class 35 processed.
Class 36 processed.
Class 37 processed.
Class 38 processed.
Class 39 processed.
Class 40 processed.
Class 41 processed.
Class 42 processed.
Class 43 processed.
Class 44 processed.
Class 45 processed.
Class 46 processed.
Class 47 processed.
Class 48 processed.
Class 49 processed.
Class 50 p

NameError: name 'features' is not defined

In [4]:
class AugGenerator():
    def __init__(self, num_samples):
        self.num_samples = num_samples
        self.idx = np.random.permutation(num_samples) % 6
        self.rotate = transforms.RandomRotation(30)
        self.flip = transforms.RandomHorizontalFlip(1)
        self.noise = transforms.GaussianBlur(3)
        self.perspective = transforms.RandomPerspective(p=1)
        self.affine = transforms.RandomAffine(20, (0.2, 0.2))
        self.jitter = transforms.ColorJitter((0.8, 1), (0.8, 1), (0.8, 1))
        self.resize = transforms.Resize((224, 224))
        
    def reshuffle(self):
        self.idx = np.random.permutation(self.num_samples) % 6
        
    def aug(self, image, i):
        im_torch = self.resize(torch.tensor(image))
        if self.idx[i] == 0:
            return self.rotate(im_torch)
        elif self.idx[i] == 1:
            return self.flip(im_torch)
        elif self.idx[i] == 2:
            return self.noise(im_torch)
        elif self.idx[i] == 3:
            return self.perspective(im_torch)
        elif self.idx[i] == 4:
            return self.affine(im_torch)
        elif self.idx[i] == 5:
            return self.jitter(im_torch)

In [7]:
def create_episode(num_shots):
    episode = np.zeros((1, 5, num_shots, 3, 32, 32), dtype=np.uint8)

    for i in range(5):
        with open('Data/PickledIms/' + str(95 + i), 'rb') as f:
            data = pkl.load(f)

        shot_numbers = np.random.randint(0, 600, size=num_shots)
        episode[0][i][:num_shots] = data[shot_numbers]
        
    return episode

episode = create_episode(5)
episode.shape

(1, 5, 5, 3, 32, 32)

In [None]:
num_samples = 1024
num_shots = 5
batch_size = 16
G3 = AugGenerator(num_samples)

X_aug_im = torch.zeros((5, num_samples, 3, 224, 224), device=device, dtype=torch.float32)
X_aug = torch.zeros((5, num_samples, 2048), device=device, dtype=torch.float32)
for class_num in range(5):
    for i in range(num_samples):
        X_aug_im[class_num][i] = G3.aug(episode[0][class_num][i % num_shots], i)
    for j in range(num_samples // batch_size):
        X_aug[class_num, j * batch_size:(j + 1) * batch_size] = vgg_extractor(X_aug_im[class_num,
                                                                                       j * batch_size:(j + 1) * batch_size])
    print('Class {} processed.'.format(class_num))

In [5]:
num_samples = 1024
num_shots = 5
batch_size = 16
G3 = AugGenerator(num_samples)

X_aug_im = torch.zeros((5, num_samples, 3, 224, 224), device=device, dtype=torch.uint8)
X_aug = torch.zeros((5, num_samples, 2048), device=device, dtype=torch.float32)

In [16]:
G3.rotate(torch.tensor(episode[0][0][0], dtype=torch.uint8))

tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8)

In [None]:
X_aug = vgg_extractor(X_aug_im[0, :100].reshape(-1, 3, 224, 224)) # detach !
X_aug

In [36]:
X_aug.shape

torch.Size([10, 2048])

In [12]:
episode

array([[[[[[110, 110, 108, ..., 111, 109, 111],
           [112, 112, 112, ..., 122, 122, 118],
           [114, 113, 115, ..., 126, 127, 125],
           ...,
           [114, 115, 114, ..., 104, 110, 113],
           [112, 103, 100, ..., 116, 117, 116],
           [113, 112, 117, ..., 114, 115, 112]],

          [[143, 143, 143, ..., 146, 144, 147],
           [146, 146, 147, ..., 156, 155, 154],
           [150, 147, 150, ..., 160, 158, 159],
           ...,
           [143, 143, 145, ..., 133, 139, 145],
           [140, 130, 125, ..., 147, 147, 146],
           [143, 142, 146, ..., 146, 145, 143]],

          [[163, 163, 163, ..., 169, 165, 169],
           [167, 167, 169, ..., 178, 177, 177],
           [172, 168, 174, ..., 182, 179, 180],
           ...,
           [160, 158, 162, ..., 150, 159, 165],
           [156, 142, 134, ..., 166, 165, 165],
           [158, 159, 162, ..., 169, 165, 162]]],


         [[[ 74,  77,  86, ...,  62,  55,  57],
           [ 79,  87,  89, ..., 