In [25]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np

## config

In [26]:
debug = True
image_path = "" # to set
captions_path = ""# to set
batch_size = 8
num_workers = 0
lr = 1e-3
weight_decay = 1e-3
patience = 2
factor = 0.5
epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'resnet50'
image_embedding = 2048
text_encoder_model = "distilbert-base-uncased"
text_embedding = 768
text_tokenizer = "distilbert-base-uncased"
max_length = 200

pretrained = False # for both image encoder and text encoder
trainable = False # for both image encoder and text encoder
temperature = 1.0

# image size
size = 224

# for projection head; used for both image and text encoders
num_projection_layers = 1
projection_dim = 256 
dropout = 0.1

## TextEncoder

In [48]:
from transformers import DistilBertModel, DistilBertConfig

class TextEncoder(nn.Module):
    def __init__(self, model_name=text_encoder_model, pretrained=pretrained, trainable=trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for param in self.model.parameters():
            param.requires_grad = trainable # False

        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

projection head implementation

In [29]:
class QuickGELU(nn.Module): # instead of GELU
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
        

class QuickLayerNorm(nn.LayerNorm): # instead of LayerNorm
    '''for fp16'''
    def forward(self, x: torch.Tensor):
        original_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(original_type)
    

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim=projection_dim, dropout=dropout):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = QuickGELU() #nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = QuickLayerNorm(projection_dim)#nn.LayerNorm(projection_dim)

    def forward(self, x: torch.Tensor):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

## ImageEncoder

In [32]:
class ImageEncoder(nn.Module):
    def __init__(self, model_name=model_name, pretrained=pretrained, trainable=trainable):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained, num_classes=0, global_pool="avg")
        for param in self.model.parameters():
            param.requires_grad = trainable # False

    def forward(self, x: torch.Tensor):
        return self.model(x)


In [34]:
im_m = ImageEncoder()
x  = torch.randn(1, 3, 224, 224)
im_m(x).shape

torch.Size([1, 2048])

In [54]:
text_m = TextEncoder()
y_ids = torch.randint(5, 300, size=(8, 25))
attention_mask = torch.ones(8, 25)
text_m(y_ids, attention_mask).shape

torch.Size([8, 25, 768])

# CBL

## ClassificationLayer

In [75]:
class ConceptClassifier(nn.Module):
    def __init__(self,   embedding_dim, projection_dim, num_classes: int):
        super().__init__()
        self.text_encoder_model = DistilBertModel.from_pretrained(text_encoder_model)
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.classifier_layer = nn.Linear(projection_dim, num_classes) # , sparse=True
    
    def forward(self, input_ids, attention_mask):
        hidden_states = self.text_encoder_model(input_ids=input_ids, attention_mask=attention_mask)[0]
        pooled_output = hidden_states[:, 0]
        projection_output = self.projection(pooled_output)
        logits = self.classifier_layer(projection_output)
        return logits

In [77]:
'''test'''
cls_model = ConceptClassifier(768, 256, num_classes=10)
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
logits = cls_model(input_ids, attention_mask)
print(logits.shape) # should output torch.Size([2, 10])

torch.Size([2, 10])


## CLIP

initialize our loss functions

In [95]:
def cross_entropy_loss(preds, targets, reduction="none"): # take care of loss normalization by ourselves
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

In [120]:
class CLIP_model(nn.Module):
    def __init__(self, temperature=temperature, image_embedding=image_embedding, text_embedding=text_embedding):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature
        #self.classify = ConceptClassifier() #

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) #

    def forward(self, batch):
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        # at this step normalization is not necessary because of layer norm in ProjectionHead
        # this part will be like in a CLIP trustworthy
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * (image_embeddings @ text_embeddings.T) / self.temperature
        logits_per_text = logits_per_image.T
        #images_similarity = image_embeddings @ image_embeddings.T
        #texts_similarity = text_embeddings @ text_embeddings.T
        #targets = 
        return logits_per_image, logits_per_text

In [117]:
# можно пока его поробовать
class ConceptClassifierTest(nn.Module):
    def __init__(self, embedding_dim, projection_dim, num_classes: int):
        super().__init__()
        self.text_encoder_model = DistilBertModel.from_pretrained(text_encoder_model)
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.classifier_layer = nn.Embedding(num_classes, projection_dim, sparse=True)
    
    def forward(self, input_ids, attention_mask):
        hidden_states = self.text_encoder_model(input_ids=input_ids, attention_mask=attention_mask)[0]
        pooled_output = hidden_states[:, 0]
        projection_output = self.projection(pooled_output)
        logits = self.classifier_layer.weight.matmul(projection_output.T)
        return logits
    
'''test'''
cls_model = ConceptClassifierTest(768, 256, num_classes=10)
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
logits = cls_model(input_ids, attention_mask)
print(logits.shape) # should output torch.Size([2, 10])

torch.Size([10, 2])


## CBM

In [None]:
def load():

def get_target_model(target_name, device):
    target_name = target_name[5:]
    model, preprocess = clip.load(target_name, device=device)
    target_model = lambda x: model.image_encoder(x).float()

In [None]:
# https://github.com/Trustworthy-ML-Lab/Label-free-CBM/blob/main/cbm.py#L37
class CBM(nn.Module):
    def __init__(self, backbone_name, W_c, W_g, b_g, proj_mean, proj_std, device="cuda"):
        super().__init__()
        model, _ = get_target_model(backbone_name, device) # дописать
        self.proj_layer = nn.Linear(in_features=W_c.shape[1], out_features=W_c.shape[0], bias=False).to(device)
        self.proj_layer.load_state_dict({"weight":W_c})

        self.proj_mean = proj_mean
        self.proj_std = proj_std

        self.final = nn.Linear(in_features=W_g.shape[1], out_features=W_g.shape[0], bias=False).to(device)
        self.final.load_state_dict({"weight":W_g, "bias":b_g})
        self.concepts = None

    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        x = self.proj_layer(x)
        proj_c = (x - self.proj_mean) / self.proj_std
        x = self.final(proj_c)
        return x, proj_c
    # разобрать load_state_dict, разобрать и написать get_target model и load и как их реализовать в моем случае