In [1]:
!pip install torch



In [2]:
import math
import torch
from torch import nn
import numpy as np

def torch_apply_along_axis(function, x, axis: int = 0):
    """
    Torch equivalent of numpy apply along axis. This function is slow and should be avoided
    https://discuss.pytorch.org/t/apply-a-function-along-an-axis/130440
    """
    return torch.stack([
        function(x_i) for x_i in torch.unbind(x, dim=axis)
    ], dim=axis)

def input_to_rfs_torch(xw, AB_fun, ab_fun, xis, num_rfs, dim):
    ab_coeffs = torch_apply_along_axis(ab_fun, xis, 0)
    AB_coeffs = torch_apply_along_axis(AB_fun, xis, 0)
    torch.manual_seed(0)
    gs = torch.rand(size=(num_rfs, dim))
    renorm_gs = (ab_coeffs * gs.t()).t()
    dot_products = torch.einsum('ij,j->i', renorm_gs, xw)
    squared_xw = torch.sum(xw * xw)
    correction_vector = (squared_xw / 2) * ab_coeffs * ab_coeffs
    diff_vector = dot_products - correction_vector
    return (1.0 / math.sqrt(num_rfs)) * AB_coeffs * torch.exp(diff_vector)

def input_to_rfs_torch_vectorized(xw, AB_fun, ab_fun, xis, num_rfs, dim):
    ab_coeffs = torch_apply_along_axis(ab_fun, xis, 0)
    AB_coeffs = torch_apply_along_axis(AB_fun, xis, 0)
    torch.manual_seed(0)
    gs = torch.rand(size=(num_rfs, dim))
    renorm_gs = (ab_coeffs * gs.t()).t()
    dot_products = torch.einsum('ij,jk->ik', xw, renorm_gs.t())
    squared_xw = torch.sum(torch.mul(xw, xw), dim=1)
    correction_vector = torch.outer(squared_xw / 2, torch.mul(ab_coeffs, ab_coeffs))
    diff_vector = dot_products - correction_vector
    return (1.0 / math.sqrt(num_rfs)) * AB_coeffs * torch.exp(diff_vector)

# class mynetwork(nn.Module):
#     def __init__(self, w):
#         super().__init__()
#         self.w = w
#         self.weights = input_to_rfs_torch(self.w, A_fun, a_fun, xis, num_rfs, dim)
#         self.weights = nn.Parameter(self.weights)
#         # self.bias = nn.Parameter(torch.zeros(10))

#     def forward(self, x):
#         xb = input_to_rfs_torch(x, A_fun, a_fun, xis, num_rfs, dim)
#         return xb @ self.weights.t()

In [3]:
###################### TEST
# dim = 5
# x = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0]).float()
# w = torch.Tensor([5.0, 4.0, 3.0, 2.0, 1.0]).float()
x = torch.rand((3,256))

# bias = torch.Tensor([0.0])
# groundtruth_value = torch.cos(torch.dot(x, w)+bias)
num_rfs = 64
# a_fun= lambda xi: 2.0 * math.pi * 1j * xi
# b_fun= lambda x: 1
# A_fun= lambda x: np.exp(bias)
# B_fun= lambda x: 1
a_fun = lambda x: np.sin(x)
b_fun = lambda x: np.cos(x)
A_fun = lambda x: np.sin(x)
B_fun = lambda x: np.cos(x)

xis_creator = lambda x: 1.0 / (2.0 * math.pi) * (x > 0.5) - 1.0 / (2.0 * math.pi) * (x < 0.5)
random_tosses = torch.rand(num_rfs)
xis = xis_creator(random_tosses)

class mynetwork(nn.Module):
    def __init__(self, inp_dim):
        super().__init__()
        self.layer1 = nn.Linear(inp_dim, inp_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(inp_dim, 256)
        # print(self.linear2)
        # print(self.linear2.weight.shape)
        # self.w = self.linear2.weight
        # print("W IS", w.shape)
        self.weights = input_to_rfs_torch_vectorized(self.linear2.weight, A_fun, a_fun, xis, num_rfs, 256)
        self.weights = nn.Parameter(self.weights)
        # self.bias = nn.Parameter(torch.zeros(10))

    def forward(self, x):
        xb = input_to_rfs_torch_vectorized(x, A_fun, a_fun, xis, num_rfs, 256)
        return xb @ self.weights.t()

In [5]:
# testing vectorized version, before we run with Vit
net = mynetwork(256)
net(x)

tensor([[159.5214, 138.9836, 125.9633, 132.2636, 149.1356, 145.5340, 147.3583,
         151.3705, 148.8756, 149.8760, 142.3045, 151.4638, 138.9649, 149.8407,
         161.5794, 136.1583, 149.5099, 142.7977, 139.4826, 141.2225, 140.4835,
         135.7491, 156.8195, 143.1146, 146.2220, 147.1548, 138.6097, 139.5836,
         133.6076, 141.5825, 153.4449, 147.8783, 137.5003, 129.9559, 152.5504,
         145.4684, 142.0189, 142.1979, 146.1419, 143.6837, 150.1651, 147.3919,
         147.2217, 147.5441, 152.2361, 142.5456, 144.4930, 149.1394, 146.2385,
         140.4700, 153.3432, 146.4711, 141.0195, 142.8022, 147.6641, 137.6978,
         135.1386, 152.4140, 136.3570, 139.5619, 136.7042, 159.1589, 137.4422,
         152.3584, 142.0882, 149.2403, 145.5336, 148.4319, 139.8356, 134.5932,
         152.3497, 132.1793, 162.0865, 138.4233, 141.4406, 146.3412, 135.2060,
         145.7321, 138.4471, 146.4900, 152.2778, 156.8766, 147.7393, 134.8655,
         143.7523, 137.9660, 138.4806, 143.2652, 144

In [6]:
for n, p in net.named_parameters():
  print(n)
  if 'weights' in n:
    p.requires_grad = True
  else:
    p.requires_grad = False

weights
layer1.weight
layer1.bias
linear2.weight
linear2.bias


In [7]:
net.train()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
# real stupid test
for i in range(50):
    optimizer.zero_grad()
    l = net(x).mean()
    if i%5==0:
      print(l)
    l.backward() #quite slow
    optimizer.step()
# loss indeed goes down

tensor(117.9060, grad_fn=<MeanBackward0>)
tensor(114.9361, grad_fn=<MeanBackward0>)
tensor(111.9663, grad_fn=<MeanBackward0>)
tensor(108.9964, grad_fn=<MeanBackward0>)
tensor(106.0266, grad_fn=<MeanBackward0>)
tensor(103.0567, grad_fn=<MeanBackward0>)
tensor(100.0869, grad_fn=<MeanBackward0>)
tensor(97.1170, grad_fn=<MeanBackward0>)
tensor(94.1472, grad_fn=<MeanBackward0>)
tensor(91.1773, grad_fn=<MeanBackward0>)


In [None]:
# # test vectorized version.
# # the output does match the non-vectorized version
# x_vec = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [5.0, 4.0, 3.0, 2.0, 1.0]]).float()
# x_rfs_vec = input_to_rfs_torch_vectorized(x_vec, A_fun, a_fun, xis, num_rfs, dim)
# print('vectorized output: ', x_rfs_vec)
# x_rfs = input_to_rfs_torch(x, A_fun, a_fun, xis, num_rfs, dim)
# print('non-vectorized output: ', x_rfs)
# x_2 = torch.Tensor([5.0, 4.0, 3.0, 2.0, 1.0]).float()
# x_rfs_2 = input_to_rfs_torch(x_2, A_fun, a_fun, xis, num_rfs, dim)
# print('non-vectorized output (2nd vector): ', x_rfs_2)

In [8]:
! pip install transformers datasets evaluate wandb accelerate



In [9]:
from collections import OrderedDict
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    MaskedImageModelingOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers import ViTConfig, ViTPreTrainedModel, ViTModel
from transformers.models.vit.modeling_vit import ViTPooler
# from nnk import *

logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "ViTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"


VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/vit-base-patch16-224",
    # See all ViT models at https://huggingface.co/models?filter=vit
]

VIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

In [10]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class LinearViTForImageClassification(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig, A_fun: callable, a_fun: callable, xis: callable, num_rfs: int) -> None:
        super().__init__(config, A_fun, a_fun, xis, num_rfs)

        self.num_labels = config.num_labels
        self.A_fun = A_fun
        self.a_fun = a_fun
        self.xis = xis
        self.num_rfs = num_rfs
        # self.vit = ViTModel(config, add_pooling_layer=False) #the og image classification model do not use the pooling layer
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit.pooler = Identity()
        self.pooler = ViTPooler(config) #add pooling layer
        # self.pooler = self.pooler.load_state_dict(dict1, strict=False)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        self.post_init()
        # Initialize weights and apply final processing
        # get weights and bias to linearize
        self.w = self.pooler.dense.weight
        # the bias in the pooler layer is the 0 vector so going to ignore the first pass
        self.output_rfs = input_to_rfs_torch_vectorized(self.w, A_fun, a_fun, xis, num_rfs, self.w.shape[1])
        # TO CHECK: Might be issues with some gradient hooks
        self.output_rfs = nn.Parameter(self.output_rfs)
        # linearize the pooler layer


    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    # TODO : Precompute the x_rfs
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
        print(outputs)
        first_token_tensor = outputs['last_hidden_state'][:, 0]
        x_rfs = input_to_rfs_torch_vectorized(first_token_tensor, self.A_fun, self.a_fun, self.xis, self.num_rfs, first_token_tensor.shape[1])
        sequence_output = x_rfs @ self.output_rfs.t()
        # compute the linearized pooling layer
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [11]:
config = ViTConfig.from_pretrained('google/vit-base-patch16-224-in21k')
model = LinearViTForImageClassification(config,A_fun=A_fun, a_fun=a_fun, xis=xis, num_rfs=64)

In [12]:
# model1 = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
# with torch.no_grad():
#     model.pooler.dense.weight.copy_(model1.pooler.dense.weight)

In [13]:
dict1 = torch.load('pooler_weights.pkl')
with torch.no_grad():
    model.pooler.dense.weight.copy_(dict1['pooler.dense.weight'])

In [14]:
model

LinearViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_feat

In [15]:
model.pooler.dense.weight == dict1['pooler.dense.weight']

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [16]:
from datasets import load_dataset, load_metric
from transformers import ViTImageProcessor
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
inputs = image_processor(image, return_tensors="pt")


In [17]:
model(**inputs)

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.1559,  0.0914,  0.1518,  ..., -0.3180, -0.0859, -0.0903],
         [-0.2254,  0.0864,  0.4752,  ..., -0.1781,  0.1726,  0.1334],
         [ 0.0444,  0.0677,  0.4199,  ..., -0.2576,  0.1191,  0.0130],
         ...,
         [-0.0153, -0.0396,  0.1684,  ..., -0.1672,  0.1869,  0.1025],
         [ 0.0249, -0.0382,  0.2046,  ...,  0.0517,  0.1489,  0.1320],
         [-0.1748, -0.0254,  0.2523,  ..., -0.1474,  0.1627,  0.1325]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[[ 0.1559,  0.0914,  0.1518,  ..., -0.3180, -0.0859, -0.0903],
         [-0.2254,  0.0864,  0.4752,  ..., -0.1781,  0.1726,  0.1334],
         [ 0.0444,  0.0677,  0.4199,  ..., -0.2576,  0.1191,  0.0130],
         ...,
         [-0.0153, -0.0396,  0.1684,  ..., -0.1672,  0.1869,  0.1025],
         [ 0.0249, -0.0382,  0.2046,  ...,  0.0517,  0.1489,  0.1320],
         [-0.1748, -0.0254,  0.2523,  ..., -0.1474,  0.1627,  0.1325]]],
       grad_f

ImageClassifierOutput(loss=None, logits=tensor([[-0.0259,  0.0105]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [18]:
for n, p in model.named_parameters():
    if ('output_rfs'in n) or ('classifier' in n):
        p.requires_grad = True
    else :
        p.requires_grad = False

In [None]:
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# real stupid test
for i in range(10):
    optimizer.zero_grad()
    l = model(inputs['pixel_values'], labels=torch.tensor([1]).long()).loss
    if i%1 == 0:
      print(l)
    l.backward() #quite slow
    optimizer.step()

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.1559,  0.0914,  0.1518,  ..., -0.3180, -0.0859, -0.0903],
         [-0.2254,  0.0864,  0.4752,  ..., -0.1781,  0.1726,  0.1334],
         [ 0.0444,  0.0677,  0.4199,  ..., -0.2576,  0.1191,  0.0130],
         ...,
         [-0.0153, -0.0396,  0.1684,  ..., -0.1672,  0.1869,  0.1025],
         [ 0.0249, -0.0382,  0.2046,  ...,  0.0517,  0.1489,  0.1320],
         [-0.1748, -0.0254,  0.2523,  ..., -0.1474,  0.1627,  0.1325]]]), pooler_output=tensor([[[ 0.1559,  0.0914,  0.1518,  ..., -0.3180, -0.0859, -0.0903],
         [-0.2254,  0.0864,  0.4752,  ..., -0.1781,  0.1726,  0.1334],
         [ 0.0444,  0.0677,  0.4199,  ..., -0.2576,  0.1191,  0.0130],
         ...,
         [-0.0153, -0.0396,  0.1684,  ..., -0.1672,  0.1869,  0.1025],
         [ 0.0249, -0.0382,  0.2046,  ...,  0.0517,  0.1489,  0.1320],
         [-0.1748, -0.0254,  0.2523,  ..., -0.1474,  0.1627,  0.1325]]]), hidden_states=None, attentions=None)
tensor(0.6751, gra

In [None]:
# Load the pooler weights.
# pytorch_model.bin comes from:
# https://huggingface.co/google/vit-base-patch16-224-in21k/tree/main
# need to throw the bin file into the left side panel

# weights = torch.load('/content/pytorch_model.bin')

# dict1 = OrderedDict()
# dict1['pooler.dense.weight'] = weights['pooler.dense.weight']
# dict1['pooler.dense.bias'] = weights['pooler.dense.bias']

# torch.save(dict1, '/content/pooler_weights.pkl')