# Setup

In [2]:
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig

# Load the Models

In [11]:
class CFG1:
    """Settings bundle for the all-MiniLM-L6-v2 checkpoint (fold 0)."""

    # --- scalar options -------------------------------------------------
    pooling = "mean"
    batch_size = 120
    gradient_checkpointing = False
    add_with_best_prob = False

    # --- artefact locations ---------------------------------------------
    # Base directory; the code in this file appends '/tokenizer', '/config'
    # and '/model' to it when loading.
    sup_model = "../../input/stage-1-all-minilm-l6-v2/all-MiniLM-L6-v2-exp_fold0_epochs10"
    # Path to a .pth file; presumably the fine-tuned weights -- confirm
    # against the code that loads it.
    sup_model_tuned = "../../input/sentence-transformers-all-minilm-l6-v2-fold0-42/sentence-transformers-all-MiniLM-L6-v2_fold0_42.pth"

    # NOTE: evaluated while the class body executes, so the tokenizer is
    # read from disk as soon as this class is defined.
    sup_tokenizer = AutoTokenizer.from_pretrained(f"{sup_model}/tokenizer")
    
class CFG2:
    """Settings bundle for the paraphrase-multilingual-mpnet-base-v2 checkpoint."""

    # --- scalar options -------------------------------------------------
    pooling = "mean"
    batch_size = 120
    gradient_checkpointing = False
    add_with_best_prob = True

    # --- artefact locations ---------------------------------------------
    # Base directory; the code in this file appends '/tokenizer', '/config'
    # and '/model' to it when loading.
    sup_model = "../../input/paraphrasemultilingualmpnetbasev2-origin2/paraphrasemultilingualmpnetbasev2-origin"
    # Path to a .pth file; presumably the fine-tuned weights -- confirm
    # against the code that loads it.
    sup_model_tuned = "../../input/paraphrase-multilingual-mpnet-base-v2-reranker/model-paraphrase-multilingual-mpnet-base-v2-tuned_0.4747.pth"

    # NOTE: evaluated while the class body executes, so the tokenizer is
    # read from disk as soon as this class is defined.
    # NOTE(review): the recorded cell output shows an HFValidationError for
    # this '<sup_model>/tokenizer' path -- it looks like the directory was
    # not present when the cell ran; confirm the input dataset is attached.
    sup_tokenizer = AutoTokenizer.from_pretrained(f"{sup_model}/tokenizer")

HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '../../input/paraphrasemultilingualmpnetbasev2-origin2/paraphrasemultilingualmpnetbasev2-origin/tokenizer'. Use `repo_type` argument if needed.

In [3]:
class custom_model(nn.Module):
    """Transformer encoder followed by a pooling layer and a one-output linear head.

    The pooling strategy is selected by ``cfg.pooling``; supported values are
    ``'mean'``, ``'max'``, ``'min'``, ``'attention'``, ``'WLP'`` and
    ``'ConcatPool'``.

    Args:
        cfg: Configuration object exposing ``sup_model`` (base directory that
            contains ``config`` and ``model`` sub-directories), ``pooling``
            and ``gradient_checkpointing``.

    Raises:
        ValueError: If ``cfg.pooling`` is not one of the supported names.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.config = AutoConfig.from_pretrained(cfg.sup_model + '/config', output_hidden_states=True)
        # Zero the dropout-related config fields. Several spellings are set;
        # presumably because architectures name these fields differently.
        self.config.hidden_dropout = 0.0
        self.config.hidden_dropout_prob = 0.0
        self.config.attention_dropout = 0.0
        self.config.attention_probs_dropout_prob = 0.0
        self.model = AutoModel.from_pretrained(cfg.sup_model + '/model', config=self.config)
        if self.cfg.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        # Read the pooling mode from the config handed to this instance, not
        # from a module-level ``CFG`` global, so that models built from
        # different config classes can coexist.
        pooling = self.cfg.pooling
        if pooling == 'mean' or pooling == "ConcatPool":
            # ConcatPool mean-pools each of the last four layers separately.
            self.pool = MeanPooling()
        elif pooling == 'max':
            self.pool = MaxPooling()
        elif pooling == 'min':
            self.pool = MinPooling()
        elif pooling == 'attention':
            self.pool = AttentionPooling(self.config.hidden_size)
        elif pooling == "WLP":
            self.pool = WeightedLayerPooling(self.config.num_hidden_layers, layer_start=6)
        else:
            # Fail at construction instead of with an AttributeError on
            # ``self.pool`` during the first forward pass.
            raise ValueError(f"Unsupported pooling mode: {pooling!r}")

        # ConcatPool concatenates four pooled layers, hence the 4x input width.
        head_width = self.config.hidden_size * 4 if pooling == "ConcatPool" else self.config.hidden_size
        self.fc = nn.Linear(head_width, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        """Initialise ``module`` in place according to its layer type.

        Linear and Embedding weights are drawn from a normal distribution with
        std ``self.config.initializer_range``; Linear biases and the embedding
        row at ``padding_idx`` are zeroed. LayerNorm is reset to weight 1 and
        bias 0. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        """Run the encoder once and return the pooled feature tensor.

        Args:
            inputs: Mapping of encoder inputs; must contain ``'input_ids'``
                and ``'attention_mask'``.

        Returns:
            The output of the configured pooling strategy. For ``ConcatPool``
            the pooled last four hidden-state layers are concatenated along
            the final dimension.
        """
        pooling = self.cfg.pooling
        attention_mask = inputs['attention_mask']

        # Each branch performs exactly one encoder pass. The WLP/ConcatPool
        # branches deliberately pass only input_ids and attention_mask, as the
        # original code did for the outputs it actually consumed.
        if pooling == "WLP":
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=attention_mask)
            pooled = self.pool({'all_layer_embeddings': outputs.hidden_states})
            # Keep only the first token position of the pooled token embeddings.
            return pooled['token_embeddings'][:, 0]

        if pooling == "ConcatPool":
            hidden_states = self.model(input_ids=inputs['input_ids'], attention_mask=attention_mask).hidden_states
            # Index the per-layer tuple directly; stacking every layer into
            # one tensor just to pick four of them is unnecessary copying.
            per_layer = [self.pool(hidden_states[idx], attention_mask) for idx in (-1, -2, -3, -4)]
            return torch.cat(per_layer, -1)

        outputs = self.model(**inputs)
        return self.pool(outputs.last_hidden_state, attention_mask)

    def forward(self, inputs):
        """Return the linear head's output for the pooled features of ``inputs``."""
        feature = self.feature(inputs)
        output = self.fc(feature)
        return output

# Inference

# Save the Embeddings