In [1]:
from PIL import Image
from transformers import ViTImageProcessor, ViTModel, CLIPImageProcessor, CLIPModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn

class Hypernetwork(nn.Module):
    def __init__(self, input_size, num_layers, layer_size, chunk_size, chunk_emb_size, num_chunks):
        super().__init__()
        self.num_chunks = num_chunks
        self.chunk_embeddings = self._generate_chunk_embeddings(chunk_emb_size, num_chunks) 
        
        input_size = input_size + chunk_emb_size
        hypernet_layers = self._prepare_layers(num_layers, layer_size, input_size, chunk_size)
        self.hypernet = nn.Sequential(*hypernet_layers)

    def forward(self, x):
        fast_weights = []
        for chunk_emb in self.chunk_embeddings:
            cat_ = torch.cat((x, chunk_emb), dim=1)
            fast_weight_chunk = self.hypernet(cat_)
            fast_weights.append(fast_weight_chunk)

        fast_weights = self._merge_layers(fast_weights)
        
        return fast_weights

    def _generate_chunk_embeddings(self, chunk_emb_size, num_chunks):
        # chunk_embs = [torch.rand((1, chunk_emb_size)) for _ in range(num_chunks)]
        chunk_embs = []
        for _ in range(num_chunks):
            chunk_emb = torch.rand((1, chunk_emb_size))
            chunk_emb.requires_grad = True
            chunk_embs.append(chunk_emb)
        return chunk_embs
        
    def _prepare_layers(self, num_layers, layer_size, input_size, chunk_size):
        input_layer = nn.Linear(in_features=input_size, out_features=layer_size)
        layers = [input_layer, nn.ReLU()]
        for _ in range(num_layers-1):
            layer = nn.Linear(layer_size, layer_size)
            layers.append(layer)
            layers.append(nn.ReLU())

        layers.append(nn.Linear(in_features=layer_size, out_features=chunk_size))
        layers.append(nn.Sigmoid())
        return layers

    def _merge_layers(self, fast_weights):
        merged_params = []
        for i in range(2):
            weights = torch.cat(fast_weights[128*i:128*(i+1)])
            bias = fast_weights[128*(i+1)]
            merged_params.append(weights)
            merged_params.append(bias)
        return merged_params

In [3]:
t = torch.rand(10, 10)

t.requires_grad = True
t

tensor([[0.2982, 0.4490, 0.8614, 0.7074, 0.0074, 0.5340, 0.6689, 0.1930, 0.9553,
         0.7903],
        [0.3822, 0.3196, 0.3025, 0.0499, 0.3313, 0.4167, 0.4933, 0.5373, 0.4805,
         0.0244],
        [0.3341, 0.8582, 0.7997, 0.8608, 0.8872, 0.4590, 0.6981, 0.8417, 0.1348,
         0.9034],
        [0.7349, 0.0891, 0.6473, 0.3789, 0.1330, 0.9239, 0.6139, 0.2815, 0.1439,
         0.4143],
        [0.8715, 0.0706, 0.8491, 0.0453, 0.8825, 0.2681, 0.6389, 0.8834, 0.2082,
         0.7312],
        [0.7164, 0.2481, 0.0909, 0.3928, 0.1614, 0.2715, 0.8836, 0.7845, 0.4650,
         0.5475],
        [0.3595, 0.0097, 0.7094, 0.8885, 0.3732, 0.9038, 0.3597, 0.7077, 0.0653,
         0.0441],
        [0.2502, 0.5409, 0.2824, 0.0737, 0.4117, 0.8894, 0.4722, 0.4253, 0.0520,
         0.4317],
        [0.8255, 0.4520, 0.6172, 0.5317, 0.1815, 0.3860, 0.3008, 0.3991, 0.1640,
         0.4964],
        [0.5076, 0.5384, 0.5128, 0.2812, 0.7130, 0.8440, 0.5340, 0.7121, 0.9194,
         0.7673]], requires_

In [4]:
import torch.nn.functional as F

class Linear_fw(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.weight.fast = self.weight
        self.bias.fast = self.bias

    def forward(self, x):
        if self.weight.fast is not None and self.bias.fast is not None:
            out = F.linear(x, self.weight.fast, self.bias.fast)
        else:
            out = F.linear(x, self.weight, self.bias)
        return out

In [5]:
class MLP_FW(nn.Module):
    def __init__(self, input_size, num_layers, layer_size, num_classes):
        super().__init__()
        layers = self._generate_layers(input_size, num_layers, layer_size, num_classes)
        
        self.net = nn.Sequential(*layers)

    def _generate_layers(self, input_size, num_layers, layer_size, num_classes):
        layers = [Linear_fw(input_size, layer_size), nn.ReLU()]
        for _ in range(num_layers-2):
            layers.append(Linear_fw(layer_size, layer_size))
            layers.append(nn.ReLU())

        layers.append(Linear_fw(layer_size, num_classes))
        layers.append(nn.ReLU())

        return layers
    
    def _update_weight(self, weight, update_value):
            weight.fast = weight * update_value

In [6]:
mlp = MLP_FW(768, 4, 128, 5)

In [7]:
input_embedding = torch.rand((1,5*768))

In [8]:
hypernet = Hypernetwork(input_size=5*768, num_layers=4, layer_size=500, chunk_size=128, chunk_emb_size=8, num_chunks=129*2)

In [9]:
updates = hypernet(input_embedding)

In [10]:
for k, weight in enumerate(list(mlp.parameters())[2:-2]):
    update_value = updates[k]
    mlp._update_weight(weight, update_value)

In [11]:
from transformers import ViTForImageClassification

In [12]:
vit_image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

img = Image.open("/home/lukasz/binary-hyper-maml/filelists/emnist/emnist/13/11.png").convert('RGB')

In [13]:
from torchvision.transforms import ToTensor
trans = ToTensor()

In [14]:
img_t = trans(img)

In [15]:
inputs_t = vit_image_processor(images=img_t, return_tensors="pt", do_rescale=False)
pixel_values_t = inputs_t.pixel_values

In [16]:
inputs = vit_image_processor(images=img, return_tensors="pt")
pixel_values = inputs.pixel_values

In [17]:
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

In [18]:
outptut = model(pixel_values)

In [19]:
outptut[0].shape

torch.Size([1, 1000])

In [20]:
feature_extractor = ViTModel.from_pretrained('google/vit-base-patch16-224')

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
outputs = feature_extractor(pixel_values)
sequence_output = outputs[0]

sequence_output[:, 0, :].shape

torch.Size([1, 768])

In [22]:
import json
import os
import torchvision.transforms as transforms
from PIL import Image

class TransformLoader:
    def __init__(self, image_size, 
                 normalize_param    = dict(mean= [0.485, 0.456, 0.406] , std=[0.229, 0.224, 0.225]),
                 jitter_param       = dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
        self.image_size = image_size
        self.normalize_param = normalize_param
        self.jitter_param = jitter_param
    
    def parse_transform(self, transform_type):
        method = getattr(transforms, transform_type)
        if transform_type=='RandomResizedCrop':
            return method(self.image_size) 
        elif transform_type=='CenterCrop':
            return method(self.image_size) 
        elif transform_type=='Resize':
            return method([int(self.image_size*1.15), int(self.image_size*1.15)])
        elif transform_type=='Normalize':
            return method(**self.normalize_param )
        else:
            return method()

    def get_composed_transform(self, aug = False):
        if aug:
            transform_list = ['RandomResizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
        else:
            transform_list = ['Resize','CenterCrop', 'ToTensor', 'Normalize']

        transform_funcs = [ self.parse_transform(x) for x in transform_list]
        transform = transforms.Compose(transform_funcs)
        return transform

identity = lambda x:x

class SimpleDataset:
    def __init__(self, data_file, transform, target_transform=identity):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)
        self.transform = transform
        self.target_transform = target_transform


    def __getitem__(self,i):
        image_path = os.path.join(self.meta['image_names'][i])
        img = Image.open(image_path).convert('RGB')
        # img = self.transform(img)
        target = self.target_transform(self.meta['image_labels'][i])
        return img, target

    def __len__(self):
        return len(self.meta['image_names'])

In [38]:
path = "/home/lukasz/binary-hyper-maml/filelists/miniImagenet/base.json"

In [39]:
transform_loader = TransformLoader(224)
transform = transform_loader.get_composed_transform()

In [40]:
dataset = SimpleDataset(path, transform)

In [41]:
from tqdm import tqdm

In [42]:
for data in tqdm(dataset):
    img_t, target = data
    inputs_t = vit_image_processor(images=img_t, return_tensors="pt", do_rescale=False)
    pixel_values_t = inputs_t.pixel_values
    outputs = feature_extractor(pixel_values)
    sequence_output = outputs[0]

    sequence_output[:, 0, :]

  0%|          | 22/38400 [00:02<1:15:41,  8.45it/s]