In [5]:
from model import ClipResnet
import torch
from torch import nn
import torch.nn.functional as F


model = ClipResnet()

# Forward two independent random (1, 3, 224, 224) batches through the two-input model;
# `ex` holds the multi-scale feature maps whose shapes are listed in the next cell.
# NOTE(review): only the inputs are moved to GPU 0 here -- presumably ClipResnet places
# itself on cuda:0 in its constructor (the recorded outputs show the call succeeds); confirm.
ex = model(
    torch.randn(1, 3, 224, 224).to(0),
    torch.randn(1, 3, 224, 224).to(0),
)


In [4]:
# Print the shape of each feature map returned by the backbone, one per line.
for feat in ex:
    print(feat.size())

torch.Size([1, 256, 56, 56])
torch.Size([1, 512, 28, 28])
torch.Size([1, 1024, 14, 14])
torch.Size([1, 2048, 7, 7])


In [14]:
def bilinear_upsample(x, target_shape):
    """Bilinearly resize a 4-D feature map to a new spatial size.

    Args:
        x: tensor of shape (N, C, H, W) (bilinear interpolation requires 4-D input).
        target_shape: output spatial size -- an int for a square (S, S) output, or an
            (H_out, W_out) pair; anything ``F.interpolate``'s ``size`` argument accepts.

    Returns:
        Tensor of shape (N, C, H_out, W_out).
    """
    # align_corners=False is PyTorch's default; stating it explicitly pins the sampling
    # convention and avoids the UserWarning older releases emit when it is omitted
    # for bilinear mode.
    return F.interpolate(x, size=target_shape, mode="bilinear", align_corners=False)

# Resize every stage to the spatial size of the first (highest-resolution) stage and
# stack them along the channel axis. Passing the full (H, W) of ex[0] -- rather than
# only its width -- keeps this correct for non-square feature maps; for the square
# maps here the result is unchanged.
upsampled = torch.cat(
    [bilinear_upsample(stage, tuple(ex[0].shape[-2:])) for stage in ex],
    dim=1,
)

In [15]:
# Concatenated channels of all stages at the first stage's resolution.
upsampled.size()

torch.Size([1, 3840, 56, 56])

In [57]:
class Decoder(nn.Module):
    """Transposed-convolution decoder from fused multi-scale features to a dense map.

    With the defaults, a (B, 3840, 56, 56) input -- the four ResNet stages
    (256 + 512 + 1024 + 2048 channels) upsampled to 56x56 and concatenated -- becomes
    a (B, 1, 224, 224) map. The spatial size evolves 56 -> 111 -> 221 -> 221 -> 224.

    Args:
        in_channels: channels of the input feature map. Defaults to 3840, the value
            previously hard-coded.
        out_channels: channels of the output map. Defaults to 1.
        final_activation: if True (default, the original behaviour) the last block is
            ConvTranspose2d -> BatchNorm2d -> ReLU like the others, so the output is
            clamped to be non-negative. Pass False to get the raw ConvTranspose2d
            output instead (e.g. logits for ``BCEWithLogitsLoss``).
    """

    def __init__(self, in_channels=3840, out_channels=1, final_activation=True):
        super().__init__()
        # Blocks are built in the original order so, with the defaults, parameter
        # initialisation consumes the RNG exactly as before and state_dict keys match.
        blocks = [
            self._make_block(in_channels, 2048, 3, padding=1, stride=2),  # 56 -> 111
            self._make_block(2048, 1024, 3, padding=1, stride=2),         # 111 -> 221
            self._make_block(1024, 512, 3, padding=1),                    # 221 -> 221
        ]
        if final_activation:
            blocks.append(self._make_block(512, out_channels, 6, padding=1))  # 221 -> 224
        else:
            # Keep the Sequential wrapper so the conv's state_dict key ("decoder.3.0.*")
            # is the same as in the default configuration.
            blocks.append(
                nn.Sequential(
                    nn.ConvTranspose2d(512, out_channels, kernel_size=6, padding=1)
                )
            )
        self.decoder = nn.Sequential(*blocks)

    def _make_block(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        """Build ConvTranspose2d -> BatchNorm2d -> ReLU.

        Output side = (in - 1) * stride - 2 * padding + kernel_size.
        NOTE(review): the conv bias is redundant in front of BatchNorm; it is kept so
        existing checkpoints still load.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Decode a (B, in_channels, H, W) feature map through the block stack."""
        return self.decoder(x)


decoder = Decoder().to(0)
# Decode the fused 56x56 features. They are cast to float32 to match the decoder's
# default-dtype parameters -- presumably `upsampled` comes out of the backbone in fp16.
decoder(upsampled.float()).shape

torch.Size([1, 1, 224, 224])

In [1]:
import clip
from PIL import Image
import torch

# Load CLIP ViT-B/16 and its matching preprocessing transform on the first GPU.
model, preprocess = clip.load('ViT-B/16', device='cuda:0')

# Open the image in a context manager so the file handle is closed once the pixels
# have been preprocessed into a batch of one on GPU 0.
with Image.open('test_image.png') as pil_img:
    img = preprocess(pil_img).unsqueeze(0).to(0)


In [2]:
# Shape of the CLIP image embedding for the test image.
model.encode_image(img).size()

torch.Size([1, 512])

In [3]:
# Call the visual tower directly; the input is cast to fp16, presumably to match the
# CUDA-loaded CLIP weights (encode_image does its own cast).
model.visual(img.half()).size()

torch.Size([1, 512])

In [4]:
from torchsummary import summary

import copy

# Work on a float32 *copy* of the visual tower. nn.Module.to() converts in place and
# returns the same module, so `mod = model.visual.to(...)` made `mod` the very same
# object as `model.visual`: casting it and clearing `proj` silently modified the shared
# CLIP model for every later cell in this kernel. Deep-copying keeps `model` intact.
mod = copy.deepcopy(model.visual).to(torch.float32)
# Drop the output projection so the tower exposes its un-projected width (768 in the
# summary below) -- presumably CLIP's VisionTransformer skips `proj` when it is None; confirm.
mod.proj = None
summary(mod, (3, 224, 224));

Layer (type:depth-idx)                        Output Shape              Param #
├─Conv2d: 1-1                                 [-1, 768, 14, 14]         589,824
├─LayerNorm: 1-2                              [-1, 197, 768]            1,536
├─Transformer: 1-3                            [-1, 2, 768]              --
|    └─Sequential: 2-1                        [-1, 2, 768]              --
|    |    └─ResidualAttentionBlock: 3-1       [-1, 2, 768]              7,087,872
|    |    └─ResidualAttentionBlock: 3-2       [-1, 2, 768]              7,087,872
|    |    └─ResidualAttentionBlock: 3-3       [-1, 2, 768]              7,087,872
|    |    └─ResidualAttentionBlock: 3-4       [-1, 2, 768]              7,087,872
|    |    └─ResidualAttentionBlock: 3-5       [-1, 2, 768]              7,087,872
|    |    └─ResidualAttentionBlock: 3-6       [-1, 2, 768]              7,087,872
|    |    └─ResidualAttentionBlock: 3-7       [-1, 2, 768]              7,087,872
|    |    └─ResidualAttentionBlock: 3-

In [5]:

# Capture the raw output of the ViT's transformer stack with a forward hook.
feats = []


def hook(module, module_inputs, output):
    """Forward hook: append the hooked module's output to the global `feats` list."""
    feats.append(output)


handle = mod.transformer.register_forward_hook(hook)
try:
    # Inference only: no_grad keeps the captured activation from holding an autograd graph.
    with torch.no_grad():
        mod(torch.randn(2, 3, 224, 224).to(0))
finally:
    # Always detach the hook, even if the forward raises, so re-running this cell
    # never leaves several hooks stacked on the same module.
    handle.remove()

In [12]:
# feats[0] is the transformer output captured by the hook. The transpose(0, 1) below
# implies it is sequence-first, i.e. (tokens, batch, width) -- presumably 197 tokens here
# (1 CLS + 14*14 patches, per the summary above).
out = feats[0]
n_tokens, batch, width = out.shape
side = int(round((n_tokens - 1) ** 0.5))  # patch-grid side once the CLS token is dropped

# (tokens, B, W) -> (B, tokens, W) -> drop CLS -> (B, W, H*W) -> (B, W, H, W)
patch_map = out.transpose(0, 1)[:, 1:].transpose(1, 2).reshape(batch, width, side, side)
patch_map.shape

torch.Size([2, 768, 14, 14])

In [12]:
import clip 
import torch

model, preprocess = clip.load('ViT-B/16', device='cuda:0')

# Random fp16 batch pushed straight through the visual tower (bypassing encode_image).
outs = model.visual(torch.rand(1, 3, 224, 224).to(0).half())

In [9]:
# NOTE(review): the recorded output shows a plain `__main__.forward` function rather than
# a bound method, i.e. `forward` was reassigned on this instance by a cell that is no
# longer in the notebook (hidden state). After a fresh clip.load it is the normal bound method.
getattr(model.visual, "forward")

<function __main__.forward(self, x: torch.Tensor)>

In [8]:
# Call `forward` directly on a random fp16 batch. The recorded TypeError came from the
# instance-level `forward` function visible in the previous cell's output: a plain
# function stored on the instance is not bound, so it never receives `self`. With a
# freshly loaded model this is the regular bound method and should succeed.
outs = model.visual.forward(torch.rand(1, 3, 224, 224).to(0).half())
outs.size()

TypeError: forward() missing 1 required positional argument: 'x'

In [3]:
# Image-embedding shape for a random fp16 batch of one.
model.encode_image(torch.rand(1, 3, 224, 224).to(0).half()).size()

torch.Size([1, 512])

In [3]:
# Tokenize three example prompts (padded to CLIP's fixed context length) onto GPU 0.
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(0)
text.size()

torch.Size([3, 77])

In [4]:
# Text-embedding shape: one row per prompt.
model.encode_text(
    clip.tokenize(["a diagram", "a dog", "a cat"]).to(0)
).size()

torch.Size([3, 512])

In [5]:
# Collect one image embedding and the three prompt embeddings side by side
# (image first, then text -- same evaluation order as before).
feats = [
    model.encode_image(torch.rand(1, 3, 224, 224).to(0).half()),
    model.encode_text(clip.tokenize(["a diagram", "a dog", "a cat"]).to(0)),
]

### Inspecting the forward pass of ZegCLIP

In [1]:
from models import CLIPViT
import torch
import math 

# Run CLIPViT's feature extractor on a random fp16 batch, requesting layer index 11
# (presumably the last block of the 12-layer ViT-B -- confirm against CLIPViT).
outs = CLIPViT().get_features(
    torch.rand(1, 3, 224, 224).to(0).half(),
    layers=[11],
)

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([1, 512])
torch.Size([1, 512, 14, 14])
1


In [2]:
import torch

def d3_to_d4(t):
    """Reshape a (N, L, C) token sequence into a (N, C, H, W) square feature map.

    If L is a perfect square the tokens are reshaped as-is. Otherwise the sequence is
    assumed to carry one leading CLS token (L = 1 + H*W), which is dropped first.

    Raises:
        ValueError: if neither L nor L - 1 is a perfect square.
    """
    n, hw, c = t.size()
    side = int(round(math.sqrt(hw)))
    if side * side != hw:
        # A parity test (hw % 2) is not a reliable CLS detector: an odd square grid
        # such as 15x15 = 225 has no CLS token, while 1 + 7x7 = 50 is even. Test for a
        # perfect square instead, and size the grid from the tokens that remain.
        t = t[:, 1:]
        hw -= 1
        side = int(round(math.sqrt(hw)))
        if side * side != hw:
            raise ValueError(
                f"cannot reshape {hw + 1} tokens into a square grid "
                "(with or without a leading CLS token)"
            )
    return t.transpose(1, 2).reshape(n, c, side, side)

def d4_to_d3(t):
    """Flatten a (N, C, H, W) feature map into a (N, H*W, C) token sequence."""
    tokens = t.flatten(start_dim=-2)  # (N, C, H*W)
    return tokens.transpose(-2, -1)   # (N, H*W, C)

# Pair the CLIPViT features with a random (3, 512) stand-in for three text embeddings.
inputs_both = (
    outs,
    torch.rand(3, 512).to(0),
)

In [3]:
# Unpack the nested inputs: inputs_both[0][0] is the sequence of image feature maps,
# inputs_both[0][1] the image CLS embedding, inputs_both[1] the text embeddings.
inputs, cls_token = inputs_both[0][0], inputs_both[0][1]
text_token = inputs_both[1]

(len(inputs), cls_token.shape, text_token.shape)

(2, torch.Size([1, 512]), torch.Size([3, 512]))

In [5]:
# Shape of the second image feature map.
inputs[1].size()

torch.Size([1, 512, 14, 14])

In [13]:
# Flatten the first stage only (inputs[:1]) from (B, C, H, W) maps into (B, H*W, C)
# token sequences; tensors that are already 3-D pass through unchanged. The list is then
# reversed (a no-op for a single stage).
x = [d4_to_d3(stage_) if stage_.dim() > 3 else stage_ for stage_ in inputs[:1]][::-1]

In [14]:
def get_qs(q, cls):
    """Build per-image class queries conditioned on each image's CLS embedding.

    Args:
        q: (C, D) tensor of class/text embeddings.
        cls: (B, D) tensor of image CLS embeddings.

    Returns:
        (B, C, 2*D) tensor whose [b, c] row is the concatenation of the element-wise
        product cls[b] * q[c] and q[c] itself.
    """
    num_classes, dim = q.shape  # unpacking enforces a 2-D q
    batch, _ = cls.shape        # unpacking enforces a 2-D cls
    per_image_q = q.expand(batch, -1, -1)  # (B, C, D) view, no copy
    modulated = torch.einsum("bd,bcd->bcd", cls, per_image_q)  # cls[b] * q[c], element-wise
    return torch.cat((modulated, per_image_q), dim=-1)

# Queries for the stand-in text embeddings conditioned on the image CLS token.
get_qs(text_token, cls_token).size()

torch.Size([1, 3, 1024])

In [6]:
# Text embeddings broadcast to a leading batch dimension of 1.
text_token.expand(1, -1, -1).size()

torch.Size([1, 3, 512])

In [26]:
# Shape of the second feature map returned by CLIPViT.get_features.
outs[0][1].size()

torch.Size([1, 512, 14, 14])

In [1]:
from models import SegDecoder
import torch 

# Dummy inputs with the same nesting used above:
#   [[(feature map A, feature map B), image CLS embedding], text embeddings]
inputs_both = [
    [
        (
            torch.rand(1, 768, 14, 14).to(0),
            torch.rand(1, 512, 14, 14).to(0),
        ),
        torch.rand(1, 512).to(0),
    ],
    torch.rand(3, 512).to(0),
]

# NOTE(review): the meaning of SegDecoder's positional args (224, 768, [0, 1], [0, 1])
# is defined in `models` -- presumably image size, input width and stage indices; confirm.
decoder = SegDecoder(224, 768, [0, 1], [0, 1]).to(0)

decoder(inputs_both)

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([1, 3, 224, 224])
