In [83]:
from PIL import Image
from transformers import ViTImageProcessor, ViTModel, CLIPImageProcessor, CLIPModel

In [106]:
import torch
import torch.nn as nn

class Hypernetwork(nn.Module):
    """Chunked hypernetwork that emits multiplicative updates for a target MLP.

    One learnable embedding is kept per output chunk. For every chunk the
    input is concatenated with that chunk's embedding and passed through a
    shared MLP ending in a sigmoid, producing a ``(1, chunk_size)`` tensor.
    The chunks are then grouped into per-layer ``[weights, bias, ...]``.

    Args:
        input_size: Number of features in the conditioning input ``x``.
        num_layers: Number of ``Linear``+``ReLU`` pairs before the output layer.
        layer_size: Width of the hidden layers of the shared MLP.
        chunk_size: Number of values produced per chunk.
        chunk_emb_size: Size of each learnable chunk embedding.
        num_chunks: Total number of chunks (and chunk embeddings).
    """

    def __init__(self, input_size, num_layers, layer_size, chunk_size, chunk_emb_size, num_chunks):
        super().__init__()
        self.num_chunks = num_chunks
        self.chunk_embeddings = self._generate_chunk_embeddings(chunk_emb_size, num_chunks)

        # The shared MLP sees the input concatenated with one chunk embedding.
        input_size = input_size + chunk_emb_size
        hypernet_layers = self._prepare_layers(num_layers, layer_size, input_size, chunk_size)
        self.hypernet = nn.Sequential(*hypernet_layers)

    def forward(self, x):
        """Generate all chunks for ``x`` and merge them into per-layer tensors.

        ``x`` must have shape ``(1, input_size)``: every chunk embedding has a
        leading dimension of 1 and is concatenated with ``x`` along dim 1.

        Returns:
            The list produced by ``_merge_layers``:
            ``[weights_0, bias_0, weights_1, bias_1, ...]``.
        """
        fast_weights = [
            self.hypernet(torch.cat((x, chunk_emb), dim=1))
            for chunk_emb in self.chunk_embeddings
        ]
        return self._merge_layers(fast_weights)

    def _generate_chunk_embeddings(self, chunk_emb_size, num_chunks):
        """Create ``num_chunks`` learnable embeddings of shape ``(1, chunk_emb_size)``.

        They are wrapped in ``nn.ParameterList`` so they are registered on the
        module: a plain list of tensors is invisible to ``parameters()``,
        ``state_dict()`` and ``.to(device)``, so an optimizer would never
        train them.
        """
        return nn.ParameterList(
            [nn.Parameter(torch.rand((1, chunk_emb_size))) for _ in range(num_chunks)]
        )

    def _prepare_layers(self, num_layers, layer_size, input_size, chunk_size):
        """Build the shared MLP: ``num_layers`` Linear+ReLU pairs, then Linear+Sigmoid."""
        layers = [nn.Linear(in_features=input_size, out_features=layer_size), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(layer_size, layer_size))
            layers.append(nn.ReLU())

        layers.append(nn.Linear(in_features=layer_size, out_features=chunk_size))
        layers.append(nn.Sigmoid())
        return layers

    def _merge_layers(self, fast_weights, rows_per_layer=128, num_target_layers=2):
        """Group flat chunks into ``[weights, bias]`` pairs, one pair per target layer.

        Each target layer consumes ``rows_per_layer`` consecutive chunks that
        are stacked into a ``(rows_per_layer, chunk_size)`` weight tensor,
        followed by one chunk used as the bias (kept as ``(1, chunk_size)``).
        Layers are laid out back to back, so layer ``i`` starts at chunk
        ``i * (rows_per_layer + 1)``; a stride of only ``rows_per_layer`` would
        make a layer's weights overlap the previous layer's bias chunk.

        Args:
            fast_weights: List of per-chunk tensors in generation order.
            rows_per_layer: Weight chunks per target layer (default 128, the
                value that used to be hard-coded).
            num_target_layers: Number of target layers (default 2, the value
                that used to be hard-coded).

        Raises:
            ValueError: If fewer chunks are supplied than the layout needs.
        """
        chunks_per_layer = rows_per_layer + 1
        required = chunks_per_layer * num_target_layers
        if len(fast_weights) < required:
            raise ValueError(
                f"expected at least {required} chunks "
                f"({num_target_layers} layers x ({rows_per_layer} weights + 1 bias)), "
                f"got {len(fast_weights)}"
            )

        merged_params = []
        for i in range(num_target_layers):
            start = i * chunks_per_layer
            bias_idx = start + rows_per_layer
            merged_params.append(torch.cat(fast_weights[start:bias_idx]))
            merged_params.append(fast_weights[bias_idx])
        return merged_params

In [124]:
# Scratch check: a freshly sampled tensor can be flagged for gradient tracking.
t = torch.rand(10, 10).requires_grad_(True)
t

tensor([[0.2701, 0.4831, 0.0678, 0.8158, 0.6586, 0.2374, 0.9372, 0.5849, 0.3332,
         0.5475],
        [0.4155, 0.8660, 0.7901, 0.0117, 0.3512, 0.1540, 0.2608, 0.2336, 0.9584,
         0.8533],
        [0.7711, 0.0750, 0.7570, 0.2116, 0.8449, 0.3310, 0.5233, 0.4935, 0.5692,
         0.8548],
        [0.2410, 0.4033, 0.0346, 0.5791, 0.8742, 0.3350, 0.2784, 0.8638, 0.5333,
         0.1336],
        [0.7549, 0.4227, 0.3205, 0.7453, 0.3750, 0.1564, 0.5172, 0.9376, 0.4762,
         0.6465],
        [0.2722, 0.3496, 0.5941, 0.3617, 0.7450, 0.5976, 0.6280, 0.6096, 0.0921,
         0.8017],
        [0.0909, 0.3052, 0.9453, 0.5750, 0.3455, 0.1266, 0.5642, 0.8431, 0.8943,
         0.2658],
        [0.8484, 0.7801, 0.4151, 0.6903, 0.1282, 0.4697, 0.8956, 0.7785, 0.4286,
         0.6014],
        [0.7970, 0.4235, 0.7062, 0.7118, 0.9265, 0.7303, 0.1134, 0.7578, 0.4059,
         0.0086],
        [0.0134, 0.3952, 0.3379, 0.1137, 0.3951, 0.2595, 0.6568, 0.4998, 0.4328,
         0.0871]], requires_grad=True)

In [107]:
import torch.nn.functional as F

class Linear_fw(nn.Linear):
    """``nn.Linear`` whose forward pass can use substitute ("fast") parameters.

    Each parameter carries a ``fast`` attribute. When both the weight's and the
    bias's ``fast`` are set (not ``None``) they are used in place of the stored
    parameters; otherwise the regular parameters are used.
    """

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # Initially every fast slot simply points at the parameter itself.
        for param in (self.weight, self.bias):
            param.fast = param

    def forward(self, x):
        """Apply the affine map using the fast parameters when both are available."""
        use_fast = self.weight.fast is not None and self.bias.fast is not None
        if use_fast:
            return F.linear(x, self.weight.fast, self.bias.fast)
        return F.linear(x, self.weight, self.bias)

In [108]:
class MLP_FW(nn.Module):
    """MLP built from ``Linear_fw`` layers so its parameters can take fast weights.

    Args:
        input_size: Number of input features.
        num_layers: Total number of ``Linear_fw`` layers (input, hidden, output).
        layer_size: Width of the hidden layers.
        num_classes: Number of output units.
    """

    def __init__(self, input_size, num_layers, layer_size, num_classes):
        super().__init__()
        layers = self._generate_layers(input_size, num_layers, layer_size, num_classes)

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack.

        Without this method, calling the module falls through to
        ``nn.Module``'s default ``forward``, which raises ``NotImplementedError``.
        """
        return self.net(x)

    def _generate_layers(self, input_size, num_layers, layer_size, num_classes):
        """Return the list of layers: input, ``num_layers - 2`` hidden, output."""
        layers = [Linear_fw(input_size, layer_size), nn.ReLU()]
        for _ in range(num_layers - 2):
            layers.append(Linear_fw(layer_size, layer_size))
            layers.append(nn.ReLU())

        layers.append(Linear_fw(layer_size, num_classes))
        # NOTE(review): a ReLU after the output layer clamps negative outputs to
        # zero; kept as-is, but confirm this is intended if these are logits.
        layers.append(nn.ReLU())

        return layers

    def _update_weight(self, weight, update_value):
        """Store ``weight * update_value`` as the parameter's fast value."""
        weight.fast = weight * update_value

In [109]:
# Target network: 768 input features, 4 Linear_fw layers of width 128, 5 outputs.
mlp = MLP_FW(input_size=768, num_layers=4, layer_size=128, num_classes=5)

In [110]:
# Random stand-in for the hypernetwork input: one row of 5 * 768 features.
input_embedding = torch.rand(1, 5 * 768)

In [111]:
# Hypernetwork conditioned on the 5*768-feature input; each chunk has 128 values.
# NOTE(review): num_chunks=129*2 presumably means 2 target layers x (128 weight
# chunks + 1 bias chunk) -- confirm against Hypernetwork._merge_layers.
hypernet = Hypernetwork(input_size=5*768, num_layers=4, layer_size=500, chunk_size=128, chunk_emb_size=8, num_chunks=129*2)

In [112]:
# Forward pass of the hypernetwork; the result is the merged per-layer list
# returned by Hypernetwork.forward (weights, bias, weights, bias).
updates = hypernet(input_embedding)

In [120]:
# Apply the generated updates to every parameter except the first two and the
# last two (i.e. skip the input and output layers' weight/bias pairs).
hidden_params = list(mlp.parameters())[2:-2]
for k in range(len(hidden_params)):
    mlp._update_weight(hidden_params[k], updates[k])

In [62]:
from transformers import ViTForImageClassification

In [10]:
# Pretrained preprocessing pipelines for the ViT and CLIP checkpoints.
vit_image_processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224"
)
clip_image_processor = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-base-patch32"
)

# NOTE(review): absolute, machine-specific path -- this cell only runs where
# this file exists; consider a configurable data directory.
img = Image.open(
    "/home/lukasz/binary-hyper-maml/filelists/emnist/emnist/13/11.png"
).convert('RGB')

In [125]:
from torchvision.transforms import ToTensor
# Reusable torchvision transform used below to turn the PIL image into a tensor.
trans = ToTensor()

In [126]:
# Tensor version of the loaded image, produced by the ToTensor transform.
img_t = trans(img)

In [133]:
# Preprocess the tensor image with the ViT processor, with rescaling disabled.
# NOTE(review): presumably disabled because ToTensor already scaled the pixels
# to [0, 1] -- confirm.
inputs_t = vit_image_processor(images=img_t, return_tensors="pt", do_rescale=False)
pixel_values_t = inputs_t.pixel_values

In [130]:
# Preprocess the PIL image directly with the ViT processor's default settings
# and keep the resulting pixel tensor for the models below.
inputs = vit_image_processor(images=img, return_tensors="pt")
pixel_values = inputs.pixel_values

In [14]:
# Load the pretrained ViT image-classification model from the same checkpoint
# name that the image processor was loaded from.
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

In [15]:
# Run the classifier on the preprocessed image.
# NOTE(review): "outptut" is a typo for "output"; the name is kept because the
# following cell reads it.
outptut = model(pixel_values)

In [16]:
# Shape of the first element of the classifier output (recorded below as [1, 1000]).
outptut[0].shape

torch.Size([1, 1000])

In [17]:
# Load the bare ViT backbone from the classification checkpoint. The recorded
# warning below reports that the pooler weights are newly initialized.
feature_extractor = ViTModel.from_pretrained('google/vit-base-patch16-224')

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
# Run the backbone and take the first element of its output.
outputs = feature_extractor(pixel_values)
sequence_output = outputs[0]

# Select index 0 along dim 1 (presumably the [CLS] token -- confirm); the
# recorded shape below is [1, 768].
sequence_output[:, 0, :].shape

torch.Size([1, 768])