In [2]:
import os
import gc
import psutil
from pathlib import Path

import pandas as pd
import numpy as np
import random
pd.set_option('display.max_rows', 100)
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_cosine_schedule_with_warmup, DataCollatorWithPadding

# Always bind a torch.device (never a bare string) so `device` has the same
# type on GPU and CPU machines and attributes such as `device.type` work.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class CFG:
    """Run-level settings shared by the cells below."""
    # Index of the data fold processed by this run; it selects the
    # `fold_<n>` input directory and is embedded in the output file names.
    fold = 0


class PATH:
    """Filesystem locations read and written by this notebook.

    Despite the ``_dir`` suffix, ``content_dir``, ``correlation_dir``,
    ``submission_dir`` and ``topic_dir`` are paths to CSV *files* inside
    ``input_dir``; the names are kept so existing references keep working.
    """
    output_dir = '/root/autodl-nas/data/k12/out'
    cv_dir = '/root/autodl-nas/data/k12/cv_data'
    # Derived from cv_dir and CFG.fold so the fold is configured in exactly one
    # place (previously 'fold_0' was hard-coded here independently of CFG.fold).
    input_dir = os.path.join(cv_dir, f'fold_{CFG.fold}')
    content_dir = os.path.join(input_dir, 'content.csv')
    correlation_dir = os.path.join(input_dir, 'correlations.csv')
    submission_dir = os.path.join(input_dir, 'sample_submission.csv')
    topic_dir = os.path.join(input_dir, 'topics.csv')

## Text Field Features

In [4]:
# Load the content and topic text-field tables from the output directory.
# Both parquet files are read the same way, so read them in one pass.
df_content, df_topic = (
    pd.read_parquet(os.path.join(PATH.output_dir, filename))
    for filename in ('content_field.pqt', 'topic_field.pqt')
)

In [5]:
# Display the content text-field table (columns `id` and `field`) for a visual check.
df_content

Unnamed: 0,id,field
0,c_00002381196d,[video] [TITLE] Sumar números de varios dígito...
1,c_000087304a9e,[video] [TITLE] Trovare i fattori di un numero...
2,c_0000ad142ddb,[video] [TITLE] Sumar curvas de demanda. [DESC...
3,c_0000c03adc8d,[document] [TITLE] Nado de aproximação. [DESCR...
4,c_00016694ea2a,[document] [TITLE] geometry-m3-topic-a-overvie...
...,...,...
154042,c_fffcbdd4de8b,[html5] [TITLE] 2. 12: Diffusion. [DESCRIPTION...
154043,c_fffe15a2d069,[video] [TITLE] Sommare facendo gruppi da 10. ...
154044,c_fffed7b0d13a,[video] [TITLE] Introdução à subtração. [DESCR...
154045,c_ffff04ba7ac7,[video] [TITLE] SA of a Cone. [DESCRIPTION]No ...


In [6]:
# Read the topic-to-content correlations CSV located under PATH.input_dir.
correlations = pd.read_csv(PATH.correlation_dir)

In [7]:
# Expand the whitespace-separated `content_ids` string into one row per
# (topic_id, content_id) pair.
# The guard keeps this cell safe to re-run: after the first run `content_ids`
# has been dropped, so repeating the transform would raise a KeyError.
if 'content_id' not in correlations.columns:
    correlations = (
        correlations
        .assign(content_id=correlations['content_ids'].str.split())
        .explode('content_id')
        .drop(columns='content_ids')
        .reset_index(drop=True)
    )

In [8]:
# Display the correlations table: one row per (topic_id, content_id) pair.
correlations

Unnamed: 0,topic_id,content_id
0,t_0008a1bd84ba,c_7ff92a954a3d
1,t_0008a1bd84ba,c_8790b074383e
2,t_000d1fb3f2f5,c_07f1d0eec4b2
3,t_000d1fb3f2f5,c_15a6fb858696
4,t_000d1fb3f2f5,c_175e9db3fc44
...,...,...
246471,t_fff830472691,c_61fb63326e5d
246472,t_fff830472691,c_8f224e321c87
246473,t_fffbe1d5d43c,c_46f852a49c08
246474,t_fffbe1d5d43c,c_6659207b25d5


In [9]:
# Build the training pairs: attach the topic text and then the content text to
# every (topic_id, content_id) row. Each merge brings in the lookup table's own
# `id` key, which duplicates the join column and is dropped right away. The two
# `field` columns collide on the second merge and are disambiguated by suffix.
df_train = (
    correlations
    .merge(df_topic, left_on='topic_id', right_on='id')
    .drop(columns='id')
    .merge(
        df_content,
        left_on='content_id',
        right_on='id',
        suffixes=['_topic', '_content'],
    )
    .drop(columns='id')
)
df_train

Unnamed: 0,topic_id,content_id,field_topic,field_content
0,t_0008a1bd84ba,c_7ff92a954a3d,[TITLE] 12. 20: Bird Reproduction of 12: Verte...,[html5] [TITLE] 12. 20: Bird Reproduction. [DE...
1,t_0008a1bd84ba,c_8790b074383e,[TITLE] 12. 20: Bird Reproduction of 12: Verte...,[video] [TITLE] Astounding Mating Dance Birds ...
2,t_000d1fb3f2f5,c_07f1d0eec4b2,[TITLE] 2.1.2 - Logarithms of 2.1 - Exponents ...,[video] [TITLE] Proof of the logarithm change ...
3,t_b1b5bcc80a6a,c_07f1d0eec4b2,[TITLE] Change of base formula for logarithms ...,[video] [TITLE] Proof of the logarithm change ...
4,t_b6cd7dbc622c,c_07f1d0eec4b2,[TITLE] The change of base formula for logarit...,[video] [TITLE] Proof of the logarithm change ...
...,...,...,...,...
246471,t_fff1047917af,c_e6b95de6962f,[TITLE] ਅਧਿਆਪਕਾਂ ਲਈ of ਦੇਖੋ ਅਤੇ ਕਰੋ of 3-6 yea...,[video] [TITLE] ਰੋਲ ਪਲੇ (ਨਾਟਕ). [DESCRIPTION]s...
246472,t_fff1047917af,c_f59987cf8a75,[TITLE] ਅਧਿਆਪਕਾਂ ਲਈ of ਦੇਖੋ ਅਤੇ ਕਰੋ of 3-6 yea...,[video] [TITLE] ਤਾੜੀ ਵਜਾਉਣਾ. [DESCRIPTION]sour...
246473,t_fff1047917af,c_fc1eca95e2f3,[TITLE] ਅਧਿਆਪਕਾਂ ਲਈ of ਦੇਖੋ ਅਤੇ ਕਰੋ of 3-6 yea...,[video] [TITLE] ਹੋਮ ਵਿਜ਼ਿਟ. [DESCRIPTION]source...
246474,t_fff5d93d4dc2,c_79903740e1e8,[TITLE] Discriminação de preços of Decisões de...,[video] [TITLE] Discriminação de preços. [DESC...


In [10]:
# Persist the full training-pair table (ids + text fields) for this fold.
retrieval_dir = os.path.join(PATH.output_dir, 'retrieval')
# to_parquet does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(retrieval_dir, exist_ok=True)
df_train.to_parquet(os.path.join(retrieval_dir, f'retrieval_{CFG.fold}.pqt'))

In [14]:
# Export only the paired text columns, as both parquet and CSV.
retrieval_dir = os.path.join(PATH.output_dir, 'retrieval')
# Neither writer creates missing parent directories; ensure the folder exists.
os.makedirs(retrieval_dir, exist_ok=True)
retrieval_fields = df_train[['field_topic', 'field_content']]
retrieval_fields.to_parquet(os.path.join(retrieval_dir, f'retrieval_field_{CFG.fold}.pqt'))
# index=False is the documented way to omit the row index from the CSV
# (the previous index=None only worked because None is falsy).
retrieval_fields.to_csv(os.path.join(retrieval_dir, f'retrieval_field_{CFG.fold}.csv'), index=False)


In [15]:
# Read the exported CSV back from disk and display it to check what was written.
pd.read_csv(os.path.join(PATH.output_dir, 'retrieval', f'retrieval_field_{CFG.fold}.csv'))

Unnamed: 0,field_topic,field_content
0,[TITLE] 12. 20: Bird Reproduction of 12: Verte...,[html5] [TITLE] 12. 20: Bird Reproduction. [DE...
1,[TITLE] 12. 20: Bird Reproduction of 12: Verte...,[video] [TITLE] Astounding Mating Dance Birds ...
2,[TITLE] 2.1.2 - Logarithms of 2.1 - Exponents ...,[video] [TITLE] Proof of the logarithm change ...
3,[TITLE] Change of base formula for logarithms ...,[video] [TITLE] Proof of the logarithm change ...
4,[TITLE] The change of base formula for logarit...,[video] [TITLE] Proof of the logarithm change ...
...,...,...
246471,[TITLE] ਅਧਿਆਪਕਾਂ ਲਈ of ਦੇਖੋ ਅਤੇ ਕਰੋ of 3-6 yea...,[video] [TITLE] ਰੋਲ ਪਲੇ (ਨਾਟਕ). [DESCRIPTION]s...
246472,[TITLE] ਅਧਿਆਪਕਾਂ ਲਈ of ਦੇਖੋ ਅਤੇ ਕਰੋ of 3-6 yea...,[video] [TITLE] ਤਾੜੀ ਵਜਾਉਣਾ. [DESCRIPTION]sour...
246473,[TITLE] ਅਧਿਆਪਕਾਂ ਲਈ of ਦੇਖੋ ਅਤੇ ਕਰੋ of 3-6 yea...,[video] [TITLE] ਹੋਮ ਵਿਜ਼ਿਟ. [DESCRIPTION]source...
246474,[TITLE] Discriminação de preços of Decisões de...,[video] [TITLE] Discriminação de preços. [DESC...
