In [1]:
import clip
import torch
from torch import nn

In [7]:
# Hyperparameters shared by the cells below.

# Prompt pattern; zeroshot_classifier fills the first slot with an object word
# and the second slot with a depth word.
depth_templates = ["This {} is {}"]

# Object words substituted into the first slot.
obj_classes = ["object"]

# Depth words substituted into the second slot (one text embedding per word).
depth_classes = [
    "giant",
    "extremely close",
    "close",
    "not in distance",
    "a little remote",
    "far",
    "unseen",
]

# Presumably one depth value per entry of depth_classes (both lists hold seven
# items); not referenced by the cells visible here -- TODO confirm usage.
bin_list = [1.0, 1.5, 2.0, 2.25, 2.5, 2.75, 3.0]

# Not referenced by the cells visible here -- TODO confirm usage.
temperature = 0.1

# Name of the CLIP model variant.
clip_vis = "RN50"

In [14]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def zeroshot_classifier(depth_classes,obj_classes, templates, model):
    """Build one unit-length text embedding per (depth word, object word) pair.

    Every template is formatted as ``template.format(obj, depth)``, tokenized
    with ``clip.tokenize`` and encoded with ``model.encode_text``.  The
    per-template embeddings are L2-normalised, averaged, and the average is
    L2-normalised again.

    Args:
        depth_classes: iterable of depth words (outer loop).
        obj_classes: iterable of object words (inner loop).
        templates: iterable of format strings with two ``{}`` slots.
        model: object exposing ``encode_text(tokens)``.

    Returns:
        Tensor with one column per pair, i.e. shape
        ``(embedding_dim, len(depth_classes) * len(obj_classes))``, placed on
        the module-level ``device``.  Columns are ordered depth-major.
    """
    columns = []
    # Pure inference: no autograd graph is recorded for the text encoder.
    with torch.no_grad():
        for depth_word in depth_classes:
            for obj_word in obj_classes:
                prompts = [pattern.format(obj_word, depth_word) for pattern in templates]
                tokens = clip.tokenize(prompts).to(device)
                per_prompt = model.encode_text(tokens)
                # Normalise each prompt embedding before averaging ...
                per_prompt = per_prompt / per_prompt.norm(dim=-1, keepdim=True)
                pooled = per_prompt.mean(dim=0)
                # ... and re-normalise the average so every column has unit length.
                pooled = pooled / pooled.norm()
                columns.append(pooled)
        weights = torch.stack(columns, dim=1).to(device)
    return weights


In [15]:
class CrossAttention(nn.Module):
    """Single-head cross-attention from feature-map locations to text embeddings.

    Each spatial location of the image feature map is a query; the columns of
    the text matrix are the keys and values (projected from 1024 dimensions).

    Args:
        batch_size: nominal batch size.  Stored for reference only; ``forward``
            reads the actual batch size from its input.
        channel: number of feature-map channels (the attention embed dim).
        high: nominal feature-map height, used when the input is not 4-D.
        width: nominal feature-map width, used when the input is not 4-D.
    """
    def __init__(self,batch_size,channel,high,width):
        super().__init__()
        self.batch_size = batch_size
        self.channel = channel
        self.high = high
        self.width = width
        self.layer = nn.MultiheadAttention(embed_dim=self.channel,num_heads=1,kdim=1024,vdim=1024,batch_first=True).to(device)
    def forward(self,x,txt):
        """Attend from every location of ``x`` to the text embeddings.

        Args:
            x: feature map; normally ``(batch, channel, height, width)``.
            txt: text embeddings of shape ``(1024, num_texts)``.

        Returns:
            Tensor of shape ``(batch, channel, height, width)``.
        """
        # Take the batch and spatial sizes from the input instead of the
        # constructor arguments, so one module serves any batch size.
        batch = x.shape[0]
        if x.dim() == 4:
            high, width = x.shape[2], x.shape[3]
        else:
            # Non-4-D input: fall back to the configured grid for the output shape.
            high, width = self.high, self.width
        # (B, C, H, W) -> (B, H*W, C): one query token per spatial location.
        queries = x.reshape(batch, self.channel, high * width).permute(0, 2, 1)
        # (1024, T) -> (B, T, 1024); expand() broadcasts without copying the text.
        keys = txt.permute(1, 0).unsqueeze(0).expand(batch, -1, -1)
        att_out = self.layer(queries, keys, keys)[0]
        # (B, H*W, C) -> (B, C, H, W)
        att_out = att_out.permute(0, 2, 1).reshape(batch, self.channel, high, width)
        return att_out



In [33]:
class ConvBlock(nn.Module):
    """Two 3x3 convolutions (padding 1) with a ReLU in between.

    The first convolution maps ``in_channel`` to ``out_channel`` channels, the
    second keeps ``out_channel``; spatial size is preserved.
    """
    def __init__(self,in_channel,out_channel) -> None:
        super().__init__()
        reduce_conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        activation = nn.ReLU()
        refine_conv = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        # Moving the container moves every layer it holds to the module-level device.
        self.layer = nn.Sequential(reduce_conv, activation, refine_conv).to(device)
    def forward(self,x):
        """Apply conv -> ReLU -> conv to ``x``."""
        return self.layer(x)


In [34]:

class TestModel(nn.Module):
    """Depth predictor: frozen CLIP visual encoder plus a text-conditioned decoder.

    The CLIP visual network produces four feature maps.  At each scale the map
    cross-attends to fixed text embeddings of the depth prompts, and the
    results are fused coarse-to-fine with transposed convolutions and
    ``ConvBlock`` layers.  The final head outputs one channel squashed into
    ``(0, max_depth)`` by a sigmoid.

    Args:
        clip_vis: model name passed to ``clip.load``.

    NOTE(review): the channel widths (2048/1024/512/256) and attention grids
    (15x20 ... 120x160) are hard-coded; they match the 1x3x480x640 "RN50" run
    shown in this notebook.  The attention blocks are built with batch_size=1;
    whether other batch sizes or resolutions work depends on
    ``CrossAttention.forward``.
    """
    def __init__(self,clip_vis):
        super().__init__()
        # NOTE(review): clip.load is called without an explicit device; this
        # assumes its default placement matches the module-level `device` used
        # by the decoder layers -- confirm.
        clip_model,_ = clip.load(clip_vis)
        self.model = clip_model.visual

        # Fixed text features for the depth prompts.  Registering them as a
        # buffer lets .to()/.cuda()/.half() and DataParallel carry them along
        # with the module; persistent=False keeps them out of state_dict so
        # existing checkpoints still load unchanged.
        text_f = zeroshot_classifier(depth_classes, obj_classes, depth_templates, clip_model).to(torch.float32) # init text feature
        self.register_buffer("text_f", text_f, persistent=False)

        # Encoder channel width and attention grid per decoder stage, deepest first.
        stage_channels = (2048, 1024, 512, 256)
        stage_grids = ((15, 20), (30, 40), (60, 80), (120, 160))

        # Construction order is kept (attention, upsampling, fusion, head) so
        # seeded initialisation gives the same weights as before.
        self.cross_att_list = nn.ModuleList([CrossAttention(1, channels, rows, cols)
                                             for channels, (rows, cols) in zip(stage_channels, stage_grids)])
        # Each upsampling layer doubles the resolution and halves the channels.
        self.up_layer = nn.ModuleList([nn.ConvTranspose2d(channels, channels // 2, kernel_size=2, stride=2).to(device)
                                       for channels in stage_channels])
        # Fusion blocks take [attended skip, upsampled deeper map] concatenated
        # on the channel axis (2 * channels // 2 == channels inputs).
        self.conv_block = nn.ModuleList([ConvBlock(channels, channels // 2)
                                         for channels in stage_channels[:-1]])

        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(128,64,kernel_size=3,stride=1,padding=1).to(device),
            nn.ReLU().to(device),
            nn.ConvTranspose2d(64,64,kernel_size=2,stride=2).to(device),
            nn.ReLU().to(device),
            nn.Conv2d(64,1,kernel_size=3,stride=1,padding=1).to(device)
        )

        # Upper bound of the predicted value (the sigmoid output is scaled by it).
        self.max_depth = 10.
        self.text_f.requires_grad = False
        # Freeze the CLIP encoder: only the decoder layers stay trainable.
        # NOTE(review): freezing parameters does not pin the encoder's
        # normalization layers (bn1..bn3 in the stem) to eval mode if .train()
        # is called on this module -- confirm the intended behaviour.
        for param in self.model.parameters():
            param.requires_grad = False

    def compute_feature_map(self,x):
        """Run the CLIP visual stem and residual stages.

        Args:
            x: image batch; cast to the dtype of the encoder's first conv weight.

        Returns:
            Tuple ``(layer4, layer3, layer2, layer1)`` outputs, deepest first,
            each cast to float32.
        """
        def stem(x):
            x = self.model.relu1(self.model.bn1(self.model.conv1(x)))
            x = self.model.relu2(self.model.bn2(self.model.conv2(x)))
            x = self.model.relu3(self.model.bn3(self.model.conv3(x)))
            x = self.model.avgpool(x)
            return x

        x = x.type(self.model.conv1.weight.dtype)
        x = stem(x)

        feature_map1 = self.model.layer1(x)
        feature_map2 = self.model.layer2(feature_map1)
        feature_map3 = self.model.layer3(feature_map2)
        feature_map4 = self.model.layer4(feature_map3)
        return feature_map4.to(torch.float32),feature_map3.to(torch.float32),feature_map2.to(torch.float32),feature_map1.to(torch.float32)

    def forward(self,x):
        """Predict a one-channel map with values in ``(0, max_depth)``.

        Args:
            x: image batch accepted by ``compute_feature_map``.

        Returns:
            Tensor with a single channel; for the 1x3x480x640 input used in
            this notebook the output is 1x1x480x640.
        """
        feature_maps = self.compute_feature_map(x)  # deepest first

        # Deepest stage: attend to the text features, then upsample.
        decoded = self.up_layer[0](self.cross_att_list[0](feature_maps[0], self.text_f))

        # Remaining stages: attend, concatenate with the upsampled deeper map,
        # fuse with a ConvBlock, upsample again.
        for stage in range(1, len(self.cross_att_list)):
            attended = self.cross_att_list[stage](feature_maps[stage], self.text_f)
            fused = self.conv_block[stage - 1](torch.cat((attended, decoded), dim=1))
            decoded = self.up_layer[stage](fused)

        output = self.last_layer_depth(decoded)
        output = torch.sigmoid(output)*self.max_depth
        return output


In [35]:
# Build the network from the backbone selected in the hyperparameter cell
# (clip_vis) instead of repeating the model name as a literal.
model = TestModel(clip_vis)

In [37]:
# Count the parameters that will receive gradients; TestModel sets
# requires_grad=False on its CLIP encoder, so those are excluded.
trainable_count = 0
for name,param in model.named_parameters():
    if param.requires_grad:
        trainable_count += param.numel()
print(f"trainable parameters: {trainable_count:,}")

In [39]:
# Smoke test: push one random 480x640 RGB image through the model and show
# the shape of the prediction.
x = model(torch.rand(1,3,480,640).to(device))
x.shape

torch.Size([1, 1, 480, 640])

In [5]:
import torch
import torch.nn as nn
# Shape check for bilinear resizing.  The target size equals the input size
# here, so the printed shapes are the same before and after.
pred = torch.randn(1, 1, 480, 640)
print("before interpolation: ", pred.shape)
pred = nn.functional.interpolate(
    pred,
    size=(480, 640),
    mode="bilinear",
    align_corners=True,
)
print("after interpolation: ", pred.shape)

before interpolation:  torch.Size([1, 1, 480, 640])
after interpolation:  torch.Size([1, 1, 480, 640])
