In [9]:
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
from Clip.configs import CFG
from torch import nn
import torch
from dotenv import load_dotenv
# Run python-dotenv's loader before building the config; presumably CFG reads
# environment variables populated from a .env file — TODO confirm against Clip.configs.
load_dotenv()
# Shared configuration object; this cell reads its `text_encoder_model`,
# `pretrained` and `trainable` attributes.
config_info = CFG().get_config()


class TextEncoder(nn.Module):
    """DistilBERT backbone that exposes the per-token hidden states.

    Args:
        model_name: Checkpoint identifier handed to
            ``DistilBertModel.from_pretrained`` when ``pretrained`` is truthy.
        pretrained: When truthy, load the checkpoint weights; otherwise build
            a model from a default ``DistilBertConfig()``.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter.
    """

    def __init__(self, model_name=config_info.text_encoder_model, pretrained=config_info.pretrained, trainable=config_info.trainable):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )

        # Freeze or unfreeze the whole backbone according to `trainable`.
        for param in self.model.parameters():
            param.requires_grad = trainable

        # Index of the first token (the CLS position). NOTE: forward() does not
        # currently use this index; it returns the hidden state of every token.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Run the backbone and return its ``last_hidden_state``.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Attention mask forwarded to the backbone.

        Returns:
            The backbone's ``last_hidden_state`` for the whole sequence (no
            pooling or CLS selection is applied here).
        """
        return self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
    
# config_info (built above via CFG().get_config()) supplies the model name and flags used below


# Smoke test: push one sentence through the encoder and report tensor shapes.

# Build the encoder with the defaults taken from config_info.
text_encoder = TextEncoder()

# The tokenizer is loaded from the same checkpoint name the encoder defaults to.
tokenizer = DistilBertTokenizer.from_pretrained(config_info.text_encoder_model)
text = "This is a sample text to encode."

# Calling the tokenizer returns the token ids together with the matching
# attention mask, so the mask no longer has to be hand-built with torch.ones
# (which produced a float mask and would be wrong for padded batches).
encoded = tokenizer(text, return_tensors="pt")
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]

# Obtain the per-token hidden states from the TextEncoder.
output = text_encoder(input_ids=input_ids, attention_mask=attention_mask)

print("Input IDs:", input_ids.shape)
print("Attention Mask:", attention_mask.shape)
print("Output Shape:", output.shape)


Input IDs: torch.Size([1, 11])
Attention Mask: torch.Size([1, 11])
Output Shape: torch.Size([1, 11, 768])


In [10]:
import timm


ModuleNotFoundError: No module named 'timm'