## Setups

In [9]:
import os
import pandas as pd
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from transformers import AutoTokenizer, AutoModel, AutoConfig

from scipy.spatial.distance import cosine
from sklearn.neighbors import NearestNeighbors

In [10]:
INPUT_DIR = '/root/autodl-tmp/data/k12/cv_split_new/train/fold_3'
OUTPUT_PATH = os.path.join(INPUT_DIR, 'sample')

TOPIC_DIR = os.path.join(INPUT_DIR, 'topics.csv')
CONTENT_DIR = os.path.join(INPUT_DIR, 'content.csv')
# CORR_DIR = os.path.join(INPUT_DIR, 'sample_submission.csv')
CORR_DIR = os.path.join(INPUT_DIR, 'correlations.csv')

MODEL_DIR = '/root/autodl-nas/model/f3r2/checkpoint-9360'
TOKENIZER_DIR = '/root/autodl-nas/model/sentence-transformers/all-MiniLM-L6-v2_new_r1.1'

N_NEIGHBOR = 50

## Data Preparation

In [11]:
def get_topic_field(d):
    title = list(filter(lambda x: pd.notna(x), d['title_level']))
    title = ' of '.join(title[-1::-1])
    title = 'No information' if title=='' else title
    title = '[TITLE] ' + title + '. '
    description = d['description'] if pd.notna(d['description']) else 'No information'
    description = '[DESCRIPTION]' + description + '. '
    field = title + description
    return field

def get_content_field(d):
    title = d['title']
    title = 'No information' if pd.isna(title) else title
    title = '[TITLE] ' + title + '. '
    description = d['description'] if pd.notna(d['description']) else 'No information'
    description = '[DESCRIPTION]' + description + '. '
    kind = '[' + d['kind'] + '] '
    field = kind + title + description
    return field

In [12]:
class DataPreparation:
    
    def __init__(self, topic_path, content_path, submission_path):
        self.topic = pd.read_csv(topic_path)
        self.content = pd.read_csv(content_path)
        self.corr = pd.read_csv(submission_path)
        self.match_dict = None
    
    def prepare_topic(self):
        df_level = self._get_level_features(self.topic)
        self.topic = self.topic.merge(df_level, on='id', how='inner')
        self.topic['field'] = self.topic.apply(lambda x: get_topic_field(x), axis=1)
        return self.topic
    
    def prepare_content(self):
        self.content['field'] = self.content.apply(lambda x: get_content_field(x), axis=1)
        return self.content
    
    def prepare_language_match(self):
        topic = self.topic[['id', 'language']].merge(self.corr, left_on='id', right_on='topic_id', how='right')[['id', 'language']]
        match_dict = {}
        for language in topic['language'].unique():
            match_dict[language] = (topic.query('language==@language')[['id']], self.content.query('language==@language')[['id']])
        self.match_dict = match_dict
        return match_dict
    
    
    def _get_level_features(self, df_topic, level_cols=['title']):
        cols = list(set(level_cols + ['id', 'parent', 'level', 'has_content']))
        df_hier = df_topic[cols]
        
        highest_level = df_hier['level'].max()
        print(f'Highest Level: {highest_level}')

        df_level = df_hier.query('level == 0').copy(deep=True)
        level_list = list()
        for col in level_cols:
            df_level[f'{col}_level'] = df_level[f'{col}'].apply(lambda x: [x])

        for i in tqdm(range(highest_level + 1)):
            level_list.append(df_level[df_level['has_content']])
            df_level_high = df_hier.query('level == @i+1')
            df_level = df_level_high.merge(df_level, left_on='parent', right_on='id', suffixes=['', '_parent'], how='inner')
            for col in level_cols:
                df_level[f'{col}_level'] = df_level[f'{col}_level'] + df_level[f'{col}'].apply(lambda x: [x])
            for col in df_level.columns:
                if col.endswith('_parent'):
                    df_level.drop(columns=col, inplace=True)
        df = pd.concat(level_list).reset_index(drop=True)
        return df[set(['id'] + [f'{col}_level' for col in level_cols])]
    
    def prepare(self):
        self.prepare_topic()
        self.prepare_content()
        self.prepare_language_match()

In [13]:
%%time
dp = DataPreparation(TOPIC_DIR, CONTENT_DIR, CORR_DIR)
dp.prepare()

Highest Level: 10


  0%|          | 0/11 [00:00<?, ?it/s]

  return df[set(['id'] + [f'{col}_level' for col in level_cols])]


CPU times: user 13.6 s, sys: 1.31 s, total: 14.9 s
Wall time: 14.9 s


In [14]:
class PlainDataset(Dataset):

    def __init__(self, df, tokenizer, label_name="") -> None:
        super().__init__()
        self.data = df[label_name].tolist()
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        text = self.data[index]
        inputs = self.tokenizer(
                text, 
                add_special_tokens = True,
                truncation='longest_first',
                max_length = 192,
                padding = 'max_length',
                return_attention_mask = True,
                return_tensors = 'pt',
        )
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        return inputs

    def __len__(self):
        return len(self.data)

## Retrival

In [15]:
class Retrieval():
    
    def __init__(self, model_path, tokenizer_path, dp):
        self.model = AutoModel.from_pretrained(model_path).cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        
        self.topic = dp.topic
        self.content = dp.content
        self.corr = dp.corr
        self.topic_content_match = dp.match_dict
    
    def convert2embed(self, df, label_name='field'):
        embed: list = []
        dataset = PlainDataset(df, tokenizer=self.tokenizer, label_name=label_name)
        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=32)
        for batch in dataloader:
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.no_grad():
                embeddings = self.model(**batch, output_hidden_states=True, return_dict=True).pooler_output
                embed.append(embeddings.cpu().clone().detach().numpy())
        embed = np.concatenate(embed, axis=0)
        return embed
    
    def get_embed(self):
        for lang in self.topic_content_match.keys():
            topic_, content_ = self.topic_content_match[lang]
            topic_ = topic_[['id']].merge(self.topic[['id', 'field']], on='id', how='left')
            content_ = content_[['id']].merge(self.content[['id', 'field']], on='id', how='left')
            
            topic_path = os.path.join(OUTPUT_PATH, f"topic_{lang}.pqt")
            content_path = os.path.join(OUTPUT_PATH, f"content_{lang}.pqt")
            topic_.to_parquet(topic_path)
            content_.to_parquet(content_path)
            
            for t in ["content", "topic"]:
                path = os.path.join(OUTPUT_PATH, f"{t}_{lang}.pqt")
                df = pd.read_parquet(path)
                embed = self.convert2embed(df, label_name="field")
                np.save(path.replace(".pqt", ".npy"), embed)
                
    def inference(self):
        recall_amount = 0
        recall_amount_total = 0
        recall_num = 0
        recall_total = {}
        
        df_pred_list = []
        for lang in self.topic_content_match.keys():
            # global df_pred, df_correlations
            content_path = os.path.join(OUTPUT_PATH, f"content_{lang}.npy")
            topics_path = os.path.join(OUTPUT_PATH, f"topic_{lang}.npy")
            content_array = np.load(content_path)
            topics_array = np.load(topics_path)
            
            model = NearestNeighbors(n_neighbors=N_NEIGHBOR, metric="cosine")
            model.fit(content_array)
            d, r = model.kneighbors(topics_array)
            df_content = pd.read_parquet(content_path.replace(".npy", ".pqt"))
            df_topics = pd.read_parquet(topics_path.replace(".npy", ".pqt"))
            df_correlations = self.corr

            pred = {"topic_id": [], "content_ids": []}
            for i in range(len(df_topics)):
                r_t = r[i]
                tmp = []
                for c in r_t:
                    content_id = df_content.iloc[c]["id"]
                    tmp.append(content_id)
                topics_id = df_topics.iloc[i]["id"]
                pred["topic_id"].append(topics_id)
                pred["content_ids"].append(tmp)

            df_pred = pd.DataFrame(pred).astype({"topic_id": str})
            df_pred_list.append(df_pred)
        df_pred = pd.concat(df_pred_list)
        self.df_pred = df_pred
        self.df_pred['content_ids'] = self.df_pred.apply(lambda x: ' '.join(x['content_ids']), axis=1)
        
    def save_pred(self, path='submission.csv'):
        self.df_pred.to_csv(path, index=None)
                
    def retrieval(self):
        self.convert2embed(self.topic[['id', 'field']])
        self.get_embed()
        self.inference()
        self.save_pred()
        
    def sample_negative(self):
        self.convert2embed(self.topic[['id', 'field']])
        self.get_embed()
        self.inference()
        self.save_pred(os.path.join(OUTPUT_PATH, f'f2r2_top{N_NEIGHBOR}.csv'))

In [16]:
%%time
s1 = Retrieval(MODEL_DIR, TOKENIZER_DIR, dp)
s1.sample_negative()

Some weights of the model checkpoint at /root/autodl-nas/model/f3r2/checkpoint-9360 were not used when initializing BertModel: ['lm_head.transform.dense.bias', 'mlp.dense.bias', 'mlp.dense.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.transform.LayerNorm.bias', 'lm_head.transform.LayerNorm.weight', 'lm_head.transform.dense.weight', 'lm_head.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at /root/autodl-nas/model/f3r2/checkpoint-9360 and are newly initialized: ['bert.pooler.dense.bias', '

KeyboardInterrupt: 