In [1]:
import json
import torch
import pandas as pd
import numpy as np
from copy import deepcopy
from tqdm import tqdm
from collections import defaultdict

In [33]:
def _load_json(path):
    """Parse the JSON file at `path` and close the file handle before returning."""
    with open(path) as fp:
        return json.load(fp)


# key -> index tables for hierarchy levels 1-4 (inverted into idx -> key further down).
lv1_key2idx = _load_json("./lv1_key2idx.json")
lv2_key2idx = _load_json("./lv2_key2idx.json")
lv3_key2idx = _load_json("./lv3_key2idx.json")
lv4_key2idx = _load_json("./lv4_key2idx.json")
lv1_label_idx = _load_json("./lv1_label_idx.json")
unscp_lv1 = _load_json("./new_unscp_lv1.json")
unscp_lv1idx2lv1idx = _load_json("./unscp_lv1idx2lv1idx.json")

# Lv1 <-> Lv5 tables. The lv5_idx2lv*_idx tables are keyed by the Lv5 index as a
# string (JSON object keys) and hold the parent index at the named level.
lv1idx2lv5idx = _load_json("./lv1idx2lv5idx.json")
lv5_label_idx = _load_json("./lv5_label_idx.json")
lv5_idx2lv1_idx = _load_json("./lv5_idx2lv1_idx.json")
lv5_idx2lv2_idx = _load_json("lv5_idx2lv2_idx.json")
lv5_idx2lv3_idx = _load_json("lv5_idx2lv3_idx.json")
lv5_idx2lv4_idx = _load_json("lv5_idx2lv4_idx.json")

In [34]:
# Invert every key -> index table so an index can be mapped back to its key.
# If two keys shared an index, the key seen last would win.
lv1_idx2key = {idx: key for key, idx in lv1_key2idx.items()}

lv2_idx2key = {idx: key for key, idx in lv2_key2idx.items()}

lv3_idx2key = {idx: key for key, idx in lv3_key2idx.items()}

lv4_idx2key = {idx: key for key, idx in lv4_key2idx.items()}

In [8]:
# Lv5 category embeddings for the four (max length, method) variants, loaded onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
cat_emb_2048_eos, cat_emb_512_eos, cat_emb_2048_masked, cat_emb_512_masked = (
    torch.load(f"../../../kisti_output/emb_zodal_cat_max_leng_{variant}.pt", map_location='cpu')
    for variant in ("2048_eos", "512_eos", "2048_masked", "512_masked")
)

In [9]:
# Row count of each Lv5 embedding variant (all four are expected to agree).
tuple(
    len(emb)
    for emb in (cat_emb_2048_eos, cat_emb_2048_masked, cat_emb_512_eos, cat_emb_512_masked)
)

(21829, 21829, 21829, 21829)

In [10]:
# Entry count of each Lv5 -> parent table.
tuple(len(table) for table in (lv5_idx2lv2_idx, lv5_idx2lv3_idx, lv5_idx2lv4_idx))

(21829, 21829, 21829)

In [11]:
# For each higher-level category, collect the Lv5 categories that belong to it as a list.
def higher2lv5list(lv5idx2idx):
    """Group Lv5 indices by their parent index.

    Args:
        lv5idx2idx: mapping whose keys are Lv5 indices (anything `int()` accepts)
            and whose values are the parent index at some higher level.

    Returns:
        A plain dict, ordered by ascending parent index, mapping each parent
        index to the list of its Lv5 indices (as ints, in input order).
    """
    grouped = {}
    for lv5_key, parent_idx in lv5idx2idx.items():
        grouped.setdefault(parent_idx, []).append(int(lv5_key))
    return {parent_idx: grouped[parent_idx] for parent_idx in sorted(grouped)}

In [12]:
# parent index -> [Lv5 indices] tables, one per hierarchy level (1 to 4).
(
    lv1_idx2lv5_idx_list,
    lv2_idx2lv5_idx_list,
    lv3_idx2lv5_idx_list,
    lv4_idx2lv5_idx_list,
) = (
    higher2lv5list(lv5_to_parent)
    for lv5_to_parent in (lv5_idx2lv1_idx, lv5_idx2lv2_idx, lv5_idx2lv3_idx, lv5_idx2lv4_idx)
)

### Lv 5 Embedding 을 상위 (Lv 1) 카테고리별로 군집한 후 평균 계산
---
- 상위 레벨 (Lv1)에 공통으로 속하는 Lv 5 embedding 을 평균낸 임베딩을 해당 상위 레벨을 표현하는 임베딩으로 사용.
- 상위 레벨 (Lv1)에 공통으로 속하는 Lv 5 embedding 을 tensor 로 묶어서 따로 저장.

In [17]:
def _mean_lv5_per_lv1(lv5_emb):
    """Average the Lv5 embedding rows of every Lv1 category.

    Args:
        lv5_emb: 2-D tensor indexed by Lv5 category index along dim 0.

    Returns:
        float16 tensor with one row per entry of `lv1_idx2lv5_idx_list`, in that
        dict's iteration order (ascending Lv1 index).

    NOTE(review): row i equals Lv1 index i only if the Lv1 indices are contiguous
    from 0 — the length check in the next cell suggests so; confirm.
    """
    # Stack the per-category means directly instead of converting every row to a
    # Python list and rebuilding a tensor from the nested lists.
    rows = [
        lv5_emb[torch.tensor(lv5_ids)].mean(dim=0)
        for lv5_ids in lv1_idx2lv5_idx_list.values()
    ]
    return torch.stack(rows).to(dtype=torch.float16)


cat_emb_lv1_2048_eos = _mean_lv5_per_lv1(cat_emb_2048_eos)
cat_emb_lv1_2048_masked = _mean_lv5_per_lv1(cat_emb_2048_masked)
cat_emb_lv1_512_eos = _mean_lv5_per_lv1(cat_emb_512_eos)
cat_emb_lv1_512_masked = _mean_lv5_per_lv1(cat_emb_512_masked)

In [18]:
# Number of Lv1 keys, followed by the row count of each averaged Lv1 tensor.
(
    len(lv1_idx2key),
    *(
        len(emb)
        for emb in (cat_emb_lv1_2048_eos, cat_emb_lv1_2048_masked, cat_emb_lv1_512_eos, cat_emb_lv1_512_masked)
    ),
)

(54, 54, 54, 54, 54)

In [19]:
# UNSPSC Lv1 embeddings for the four (max length, method) variants, loaded onto the CPU.
(
    emb_unspsc_lv1_max_leng_512_eos,
    emb_unspsc_lv1_max_leng_512_masked,
    emb_unspsc_lv1_max_leng_2048_eos,
    emb_unspsc_lv1_max_leng_2048_masked,
) = (
    torch.load(f"../../../kisti_output/emb_unspsc_lv1_max_leng_{variant}.pt", map_location='cpu')
    for variant in ("512_eos", "512_masked", "2048_eos", "2048_masked")
)

In [20]:
# Zodal (procurement) Lv1 category embeddings for the four variants, loaded onto the CPU.
(
    emb_zodal_cat_lv1_max_leng_512_eos,
    emb_zodal_cat_lv1_max_leng_512_masked,
    emb_zodal_cat_lv1_max_leng_2048_eos,
    emb_zodal_cat_lv1_max_leng_2048_masked,
) = (
    torch.load(f"../../../kisti_output/emb_zodal_cat_lv1_max_leng_{variant}.pt", map_location='cpu')
    for variant in ("512_eos", "512_masked", "2048_eos", "2048_masked")
)

In [21]:
# Index maps consumed by cal_sim_un / cal_sim_zo: each key is a row index (stored
# as a string) into the UNSPSC / zodal Lv1 embeddings, each value indexes the
# averaged Lv1 embeddings. `with` closes the file handles, which the bare
# open() calls used before never did.
with open("unscp_lv1idx2lv1idx.json") as fp:
    unscp_lv1idx2lv1idx = json.load(fp)
with open("zodal_lv1idx2lv1idx.json") as fp:
    zodal_lv1idx2lv1idx = json.load(fp)

In [22]:
def cal_sim_un(a, b, idx_map=None, device=None):
    """Cosine similarity between mapped rows of `a` and rows of `b`.

    For every `(k, v)` pair of the index map, row `a[int(k)]` is compared with
    `b[v]`. `v` may be a single index (0-d result) or a list of indices
    (1-D result with one similarity per index).

    Args:
        a: 2-D tensor; rows are selected with `int(k)`.
        b: 2-D tensor; rows are selected with `v`.
        idx_map: mapping to iterate; defaults to the notebook-level
            `unscp_lv1idx2lv1idx`.
        device: device used for the computation; defaults to CUDA when
            available and to the CPU otherwise (previously always CUDA).

    Returns:
        list of CPU tensors, one per map entry, in map iteration order.
        A zero-norm row yields NaN because the norms are divided without an epsilon.
    """
    if idx_map is None:
        idx_map = unscp_lv1idx2lv1idx
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    a = a.to(device)
    b = b.to(device)
    a_norm = torch.norm(a, dim=-1)
    b_norm = torch.norm(b, dim=-1)

    sim_list = []
    for k, v in idx_map.items():
        row = int(k)
        # b[v] @ a[row] equals a[row] @ b[v].T for both 1-D and 2-D b[v], without
        # calling `.T` on a 1-D tensor (deprecated in PyTorch).
        sim = torch.matmul(b[v], a[row]) / (a_norm[row] * b_norm[v])
        sim_list.append(sim.detach().cpu())
    return sim_list

In [23]:
# sim_avg_un = cal_sim_un(emb_unspsc_lv1_max_leng_512_eos, cat_emb_lv1_512_eos)
# sim_avg_un = torch.tensor(sim_avg_un)

In [24]:
def cal_sim_zo(a, b, idx_map=None, device=None):
    """Cosine similarity between mapped rows of `a` and rows of `b`.

    For every `(k, v)` pair of the index map, row `a[int(k)]` is compared with
    `b[v]`. `v` may be a single index (0-d result) or a list of indices
    (1-D result with one similarity per index).

    Args:
        a: 2-D tensor; rows are selected with `int(k)`.
        b: 2-D tensor; rows are selected with `v`.
        idx_map: mapping to iterate; defaults to the notebook-level
            `zodal_lv1idx2lv1idx`.
        device: device used for the computation; defaults to CUDA when
            available and to the CPU otherwise (previously always CUDA).

    Returns:
        list of CPU tensors, one per map entry, in map iteration order.
        A zero-norm row yields NaN because the norms are divided without an epsilon.
    """
    if idx_map is None:
        idx_map = zodal_lv1idx2lv1idx
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    a = a.to(device)
    b = b.to(device)
    a_norm = torch.norm(a, dim=-1)
    b_norm = torch.norm(b, dim=-1)

    sim_list = []
    for k, v in idx_map.items():
        row = int(k)
        # b[v] @ a[row] equals a[row] @ b[v].T for both 1-D and 2-D b[v], without
        # calling `.T` on a 1-D tensor (deprecated in PyTorch).
        sim = torch.matmul(b[v], a[row]) / (a_norm[row] * b_norm[v])
        sim_list.append(sim.detach().cpu())
    return sim_list

In [26]:
# sim_avg_zo = cal_sim_zo(emb_zodal_cat_lv1_max_leng_512_eos, cat_emb_lv1_512_eos)
# sim_avg_zo = torch.tensor(sim_avg_zo)

In [25]:
# Product (task 3) embeddings for the four (method, max length) variants.
# map_location='cpu' matches every other torch.load in this notebook and lets the
# files load on a machine that lacks the device they were saved from;
# pro_cat_sim_mat moves its inputs to the compute device itself.
pro_emb_masked_2048 = torch.load("../../../kisti_output/emb_task3_pro_max_leng_2048_masked.pt", map_location='cpu')
pro_emb_masked_512 = torch.load("../../../kisti_output/emb_task3_pro_max_leng_512_masked.pt", map_location='cpu')
pro_emb_eos_2048 = torch.load("../../../kisti_output/emb_task3_pro_max_leng_2048_eos.pt", map_location='cpu')
pro_emb_eos_512 = torch.load("../../../kisti_output/emb_task3_pro_max_leng_512_eos.pt", map_location='cpu')

In [27]:
def pro_cat_sim_mat(a, b, device=None):
    """Full cosine-similarity matrix between the rows of `a` and the rows of `b`.

    Args:
        a: 2-D tensor of shape (n, d).
        b: 2-D tensor of shape (m, d).
        device: device used for the computation; defaults to CUDA when
            available and to the CPU otherwise (previously always CUDA).

    Returns:
        CPU tensor of shape (n, m) where entry (i, j) is the cosine similarity
        of `a[i]` and `b[j]`. Its mean, std, max and min are printed as a side
        effect. Rows with zero norm produce NaN entries (no epsilon is added).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    a = a.to(device)
    b = b.to(device)
    a_norm = torch.norm(a, dim=-1)
    b_norm = torch.norm(b, dim=-1)

    # One matrix product replaces the n*m Python-level dot products, each of
    # which also forced its own device-to-host copy.
    sim_mat = torch.matmul(a, b.T) / (a_norm[:, None] * b_norm[None, :])
    sim_mat = sim_mat.detach().cpu()

    print(sim_mat.mean())
    print(sim_mat.std())
    print(sim_mat.max())
    print(sim_mat.min())
    return sim_mat

In [28]:
# Every list is ordered (masked 2048, masked 512, eos 2048, eos 512), which is the
# order of `emb_method` / `lenght`, so the same position selects the same variant.
pro_emb_list = [
    pro_emb_masked_2048,
    pro_emb_masked_512,
    pro_emb_eos_2048,
    pro_emb_eos_512,
]
zod_emb_list = [
    emb_zodal_cat_lv1_max_leng_2048_masked,
    emb_zodal_cat_lv1_max_leng_512_masked,
    emb_zodal_cat_lv1_max_leng_2048_eos,
    emb_zodal_cat_lv1_max_leng_512_eos,
]
uni_emb_list = [
    emb_unspsc_lv1_max_leng_2048_masked,
    emb_unspsc_lv1_max_leng_512_masked,
    emb_unspsc_lv1_max_leng_2048_eos,
    emb_unspsc_lv1_max_leng_512_eos,
]

In [29]:
# Averaged Lv1 embeddings, ordered (masked 2048, masked 512, eos 2048, eos 512).
avg_lv1_emb_list = [
    cat_emb_lv1_2048_masked,
    cat_emb_lv1_512_masked,
    cat_emb_lv1_2048_eos,
    cat_emb_lv1_512_eos,
]

### Lv 5 Embedding 을 상위 (Lv 2) 카테고리별로 군집한 후 평균 계산
---
- 상위 레벨 (Lv2)에 공통으로 속하는 Lv 5 embedding 을 평균낸 임베딩을 해당 상위 레벨을 표현하는 임베딩으로 사용.
- 상위 레벨 (Lv2)에 공통으로 속하는 Lv 5 embedding 을 tensor 로 묶어서 따로 저장.

In [30]:
def _mean_lv5_per_lv2(lv5_emb):
    """Average the Lv5 embedding rows of every Lv2 category.

    Args:
        lv5_emb: 2-D tensor indexed by Lv5 category index along dim 0.

    Returns:
        float16 tensor with one row per entry of `lv2_idx2lv5_idx_list`, in that
        dict's iteration order (ascending Lv2 index).

    NOTE(review): row i equals Lv2 index i only if the Lv2 indices are contiguous
    from 0 — the length check in the next cell suggests so; confirm.
    """
    # Stack the per-category means directly instead of converting every row to a
    # Python list and rebuilding a tensor from the nested lists.
    rows = [
        lv5_emb[torch.tensor(lv5_ids)].mean(dim=0)
        for lv5_ids in lv2_idx2lv5_idx_list.values()
    ]
    return torch.stack(rows).to(dtype=torch.float16)


cat_emb_lv2_2048_eos = _mean_lv5_per_lv2(cat_emb_2048_eos)
cat_emb_lv2_2048_masked = _mean_lv5_per_lv2(cat_emb_2048_masked)
cat_emb_lv2_512_eos = _mean_lv5_per_lv2(cat_emb_512_eos)
cat_emb_lv2_512_masked = _mean_lv5_per_lv2(cat_emb_512_masked)

In [35]:
# Number of Lv2 keys, followed by the row count of each averaged Lv2 tensor.
(
    len(lv2_idx2key),
    *(
        len(emb)
        for emb in (cat_emb_lv2_2048_eos, cat_emb_lv2_2048_masked, cat_emb_lv2_512_eos, cat_emb_lv2_512_masked)
    ),
)

(387, 387, 387, 387, 387)

In [36]:
# Averaged Lv2 embeddings, ordered (masked 2048, masked 512, eos 2048, eos 512).
avg_lv2_emb_list = [
    cat_emb_lv2_2048_masked,
    cat_emb_lv2_512_masked,
    cat_emb_lv2_2048_eos,
    cat_emb_lv2_512_eos,
]

### Lv 5 Embedding 을 상위 (Lv 3) 카테고리별로 군집한 후 평균 계산
---
- 상위 레벨 (Lv3)에 공통으로 속하는 Lv 5 embedding 을 평균낸 임베딩을 해당 상위 레벨을 표현하는 임베딩으로 사용.
- 상위 레벨 (Lv3)에 공통으로 속하는 Lv 5 embedding 을 tensor 로 묶어서 따로 저장.

In [37]:
def _mean_lv5_per_lv3(lv5_emb):
    """Average the Lv5 embedding rows of every Lv3 category.

    Row i belongs to Lv3 index i, for i in `range(len(lv3_idx2key))`. An index
    with no Lv5 members in `lv3_idx2lv5_idx_list` gets an all-zero row.

    Args:
        lv5_emb: 2-D tensor indexed by Lv5 category index along dim 0.

    Returns:
        float16 tensor of shape (len(lv3_idx2key), lv5_emb.shape[-1]).
    """
    emb_dim = lv5_emb.shape[-1]  # taken from the data instead of a hard-coded 4096
    rows = []
    for lv3_idx in range(len(lv3_idx2key)):
        lv5_ids = lv3_idx2lv5_idx_list.get(lv3_idx)
        if lv5_ids:
            rows.append(lv5_emb[torch.tensor(lv5_ids)].mean(dim=0))
        else:
            # Only the "no Lv5 member" case falls back to zeros; any other failure
            # (e.g. an out-of-range Lv5 index) now surfaces instead of being
            # hidden by a bare `except`.
            rows.append(torch.zeros(emb_dim, dtype=lv5_emb.dtype))
    return torch.stack(rows).to(dtype=torch.float16)


cat_emb_lv3_2048_eos = _mean_lv5_per_lv3(cat_emb_2048_eos)
cat_emb_lv3_2048_masked = _mean_lv5_per_lv3(cat_emb_2048_masked)
cat_emb_lv3_512_eos = _mean_lv5_per_lv3(cat_emb_512_eos)
cat_emb_lv3_512_masked = _mean_lv5_per_lv3(cat_emb_512_masked)

In [38]:
# Number of Lv3 keys, followed by the row count of each averaged Lv3 tensor.
(
    len(lv3_idx2key),
    *(
        len(emb)
        for emb in (cat_emb_lv3_2048_eos, cat_emb_lv3_2048_masked, cat_emb_lv3_512_eos, cat_emb_lv3_512_masked)
    ),
)

(1891, 1891, 1891, 1891, 1891)

In [39]:
# Averaged Lv3 embeddings, ordered (masked 2048, masked 512, eos 2048, eos 512).
avg_lv3_emb_list = [
    cat_emb_lv3_2048_masked,
    cat_emb_lv3_512_masked,
    cat_emb_lv3_2048_eos,
    cat_emb_lv3_512_eos,
]

### Lv 5 Embedding 을 상위 (Lv 4) 카테고리별로 군집한 후 평균 계산
---
- 상위 레벨 (Lv4)에 공통으로 속하는 Lv 5 embedding 을 평균낸 임베딩을 해당 상위 레벨을 표현하는 임베딩으로 사용.
- 상위 레벨 (Lv4)에 공통으로 속하는 Lv 5 embedding 을 tensor 로 묶어서 따로 저장.

In [40]:
def _mean_lv5_per_lv4(lv5_emb):
    """Average the Lv5 embedding rows of every Lv4 category.

    Row i belongs to Lv4 index i, for i in `range(len(lv4_idx2key))`. An index
    with no Lv5 members in `lv4_idx2lv5_idx_list` gets an all-zero row.

    Args:
        lv5_emb: 2-D tensor indexed by Lv5 category index along dim 0.

    Returns:
        float16 tensor of shape (len(lv4_idx2key), lv5_emb.shape[-1]).
    """
    emb_dim = lv5_emb.shape[-1]  # taken from the data instead of a hard-coded 4096
    rows = []
    for lv4_idx in range(len(lv4_idx2key)):
        lv5_ids = lv4_idx2lv5_idx_list.get(lv4_idx)
        if lv5_ids:
            rows.append(lv5_emb[torch.tensor(lv5_ids)].mean(dim=0))
        else:
            # Only the "no Lv5 member" case falls back to zeros; any other failure
            # (e.g. an out-of-range Lv5 index) now surfaces instead of being
            # hidden by a bare `except`.
            rows.append(torch.zeros(emb_dim, dtype=lv5_emb.dtype))
    return torch.stack(rows).to(dtype=torch.float16)


cat_emb_lv4_2048_eos = _mean_lv5_per_lv4(cat_emb_2048_eos)
cat_emb_lv4_2048_masked = _mean_lv5_per_lv4(cat_emb_2048_masked)
cat_emb_lv4_512_eos = _mean_lv5_per_lv4(cat_emb_512_eos)
cat_emb_lv4_512_masked = _mean_lv5_per_lv4(cat_emb_512_masked)

In [41]:
# Number of Lv4 keys, followed by the row count of each averaged Lv4 tensor.
(
    len(lv4_idx2key),
    *(
        len(emb)
        for emb in (cat_emb_lv4_2048_eos, cat_emb_lv4_2048_masked, cat_emb_lv4_512_eos, cat_emb_lv4_512_masked)
    ),
)

(10792, 10792, 10792, 10792, 10792)

In [42]:
# Averaged Lv4 embeddings, ordered (masked 2048, masked 512, eos 2048, eos 512).
avg_lv4_emb_list = [
    cat_emb_lv4_2048_masked,
    cat_emb_lv4_512_masked,
    cat_emb_lv4_2048_eos,
    cat_emb_lv4_512_eos,
]

In [46]:
# Lv1 index -> Lv2 indices reached through the Lv5 members of that Lv1 category.
lv1idx2lv2idx = defaultdict(set)
for lv1_key, lv5_members in lv1idx2lv5idx.items():
    for lv5_idx in lv5_members:
        lv2_idx = lv5_idx2lv2_idx[str(lv5_idx)]  # the Lv5 -> Lv2 table is keyed by str
        lv1idx2lv2idx[lv1_key].add(lv2_idx)

# Turn the de-duplicating sets into lists.
lv1idx2lv2idx = {lv1_key: list(lv2_set) for lv1_key, lv2_set in lv1idx2lv2idx.items()}

# Flat views of the table: all Lv1 keys and all Lv2 indices, in table order.
key_dump = list(lv1idx2lv2idx)
val_dump = [lv2_idx for lv2_list in lv1idx2lv2idx.values() for lv2_idx in lv2_list]

In [54]:
# Lv1 index -> Lv3 indices reached through the Lv5 members of that Lv1 category.
lv1idx2lv3idx = defaultdict(set)
for lv1_key, lv5_members in lv1idx2lv5idx.items():
    for lv5_idx in lv5_members:
        lv3_idx = lv5_idx2lv3_idx[str(lv5_idx)]  # the Lv5 -> Lv3 table is keyed by str
        lv1idx2lv3idx[lv1_key].add(lv3_idx)

# Turn the de-duplicating sets into lists.
lv1idx2lv3idx = {lv1_key: list(lv3_set) for lv1_key, lv3_set in lv1idx2lv3idx.items()}

# Flat views of the table: all Lv1 keys and all Lv3 indices, in table order.
key_dump = list(lv1idx2lv3idx)
val_dump = [lv3_idx for lv3_list in lv1idx2lv3idx.values() for lv3_idx in lv3_list]

In [56]:
# Lv1 index -> Lv4 indices reached through the Lv5 members of that Lv1 category.
lv1idx2lv4idx = defaultdict(set)
for lv1_key, lv5_members in lv1idx2lv5idx.items():
    for lv5_idx in lv5_members:
        lv4_idx = lv5_idx2lv4_idx[str(lv5_idx)]  # the Lv5 -> Lv4 table is keyed by str
        lv1idx2lv4idx[lv1_key].add(lv4_idx)

# Turn the de-duplicating sets into lists.
lv1idx2lv4idx = {lv1_key: list(lv4_set) for lv1_key, lv4_set in lv1idx2lv4idx.items()}

# Flat views of the table: all Lv1 keys and all Lv4 indices, in table order.
key_dump = list(lv1idx2lv4idx)
val_dump = [lv4_idx for lv4_list in lv1idx2lv4idx.values() for lv4_idx in lv4_list]

In [58]:
# Child index -> Lv1 index (as int) for levels 2, 3 and 4. If a child index is
# listed under several Lv1 keys, the Lv1 key seen last wins.
lv2idx2lv1idx = {
    lv2_idx: int(lv1_key)
    for lv1_key, lv2_list in lv1idx2lv2idx.items()
    for lv2_idx in lv2_list
}

lv3idx2lv1idx = {
    lv3_idx: int(lv1_key)
    for lv1_key, lv3_list in lv1idx2lv3idx.items()
    for lv3_idx in lv3_list
}

lv4idx2lv1idx = {
    lv4_idx: int(lv1_key)
    for lv1_key, lv4_list in lv1idx2lv4idx.items()
    for lv4_idx in lv4_list
}

In [59]:
def emb_mixup_lv1(avg_cat, emb_lv1, lam):
    """Blend two embedding tensors element-wise: (1 - lam) * avg_cat + lam * emb_lv1.

    Args:
        avg_cat: tensor of averaged embeddings.
        emb_lv1: tensor mixed in with weight `lam`.
        lam: mixing weight given to `emb_lv1`.

    Returns:
        The blended tensor.

    Raises:
        AssertionError: if the result's shape differs from `avg_cat.shape`.
    """
    keep = 1 - lam
    blended = keep * avg_cat + lam * emb_lv1
    assert blended.shape == avg_cat.shape
    return blended

def emb_mixup_lv234(avg_cat, emb_lv1, lam, lvidx2lv1idx):
    """Blend each row of `avg_cat` with the embedding of its Lv1 parent.

    Row i becomes `(1 - lam) * avg_cat[i] + lam * emb_lv1[lvidx2lv1idx[i]]`.
    A row whose parent cannot be looked up (index i missing from
    `lvidx2lv1idx`, or the parent index outside `emb_lv1`) is copied unmixed.

    Args:
        avg_cat: 2-D tensor, one row per category of the level being mixed.
        emb_lv1: 2-D tensor of Lv1 embeddings, indexed by Lv1 index.
        lam: mixing weight given to the Lv1 embedding.
        lvidx2lv1idx: mapping from row index of `avg_cat` to Lv1 index.

    Returns:
        CPU tensor in torch's default dtype with the same shape as `avg_cat`.

    Raises:
        AssertionError: if the result's shape differs from `avg_cat.shape`.
        RuntimeError: if a row and its parent embedding cannot be combined
            (e.g. different embedding sizes); this used to be swallowed.
    """
    # The previous implementation rebuilt the result from nested Python lists,
    # which always produced a CPU tensor in the default dtype; keep that contract.
    out_dtype = torch.get_default_dtype()

    mixed_rows = []
    for i in range(len(avg_cat)):
        try:
            parent_emb = emb_lv1[lvidx2lv1idx[i]]
        except (KeyError, IndexError):
            # No usable Lv1 parent for this row: keep the plain average.
            row = avg_cat[i]
        else:
            row = (1 - lam) * avg_cat[i] + lam * parent_emb
        mixed_rows.append(row.detach().to(device="cpu", dtype=out_dtype))

    if mixed_rows:
        mixed_emb = torch.stack(mixed_rows)
    else:
        # torch.stack rejects an empty list; return an empty tensor of matching shape.
        mixed_emb = torch.empty(avg_cat.shape, dtype=out_dtype)
    assert mixed_emb.shape == avg_cat.shape
    return mixed_emb

In [60]:
# Method tag ("masked" / "eos") of each embedding variant; position i matches
# position i of the *_emb_list lists and is used to build output file names.
emb_method = ["masked", "masked", "eos", "eos"]
# Max-length tag of each variant, in the same order. NOTE: `lenght` is a typo of
# "length"; the name is kept because later cells reference it.
lenght = [2048, 512, 2048, 512]

### 각 상위 레벨 (Lv 2, 3, 4) 평균 embedding과 국문 조달청 Lv 1 Embedding Mixup

In [69]:
# Mix the averaged Lv2 embeddings with the zodal Lv1 embeddings for every variant
# and both mixing weights, then save each result as float16.
# NOTE(review): the Lv1 embedding is indexed directly by Lv1 index, whereas
# cal_sim_zo translates rows through zodal_lv1idx2lv1idx — confirm the row order agrees.
lv = 2
for i in range(4):
    avg_emb, lv1_emb = avg_lv2_emb_list[i], zod_emb_list[i]
    for lam in (0.1, 0.5):
        mixed_emb_name = f"mix_w_zod_{int(lam*100)}_emb_lv_{lv}_max_leng_{lenght[i]}_{emb_method[i]}"
        mixed_embs = emb_mixup_lv234(avg_emb, lv1_emb, lam, lv2idx2lv1idx).to(dtype=torch.float16)
        torch.save(mixed_embs, f'../../../kisti_output/{mixed_emb_name}.pt')

In [70]:
# Mix the averaged Lv3 embeddings with the zodal Lv1 embeddings for every variant
# and both mixing weights, then save each result as float16.
lv = 3
for i in range(4):
    avg_emb, lv1_emb = avg_lv3_emb_list[i], zod_emb_list[i]
    for lam in (0.1, 0.5):
        mixed_emb_name = f"mix_w_zod_{int(lam*100)}_emb_lv_{lv}_max_leng_{lenght[i]}_{emb_method[i]}"
        mixed_embs = emb_mixup_lv234(avg_emb, lv1_emb, lam, lv3idx2lv1idx).to(dtype=torch.float16)
        torch.save(mixed_embs, f'../../../kisti_output/{mixed_emb_name}.pt')

In [71]:
# Mix the averaged Lv4 embeddings with the zodal Lv1 embeddings for every variant
# and both mixing weights, then save each result as float16.
lv = 4
for i in range(4):
    avg_emb, lv1_emb = avg_lv4_emb_list[i], zod_emb_list[i]
    for lam in (0.1, 0.5):
        mixed_emb_name = f"mix_w_zod_{int(lam*100)}_emb_lv_{lv}_max_leng_{lenght[i]}_{emb_method[i]}"
        # emb_mixup_lv234 returns float32; store the result as float16.
        mixed_embs = emb_mixup_lv234(avg_emb, lv1_emb, lam, lv4idx2lv1idx).to(dtype=torch.float16)
        torch.save(mixed_embs, f'../../../kisti_output/{mixed_emb_name}.pt')

torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32


### 각 상위 레벨 (Lv 2, 3, 4) 평균 embedding과 UNSPSC Lv 1 Embedding Mixup

In [72]:
# Mix the averaged Lv2 / Lv3 / Lv4 embeddings with the UNSPSC Lv1 embeddings for
# every variant and both mixing weights, then save each result as float16.
# NOTE(review): the UNSPSC embedding is indexed directly by Lv1 index, whereas
# cal_sim_un translates rows through unscp_lv1idx2lv1idx — confirm the row order agrees.
for lv, avg_emb_list, child2lv1 in (
    (2, avg_lv2_emb_list, lv2idx2lv1idx),
    (3, avg_lv3_emb_list, lv3idx2lv1idx),
    (4, avg_lv4_emb_list, lv4idx2lv1idx),
):
    for i in range(4):
        for lam in (0.1, 0.5):
            mixed_emb_name = f"mix_w_uns_{int(lam*100)}_emb_lv_{lv}_max_leng_{lenght[i]}_{emb_method[i]}"
            mixed_embs = emb_mixup_lv234(avg_emb_list[i], uni_emb_list[i], lam, child2lv1)
            mixed_embs = mixed_embs.to(dtype=torch.float16)
            torch.save(mixed_embs, f'../../../kisti_output/{mixed_emb_name}.pt')

In [73]:
# Save the four averaged Lv2 embedding tensors, one file per variant.
for variant, avg_emb in (
    ("2048_eos", cat_emb_lv2_2048_eos),
    ("2048_masked", cat_emb_lv2_2048_masked),
    ("512_eos", cat_emb_lv2_512_eos),
    ("512_masked", cat_emb_lv2_512_masked),
):
    torch.save(avg_emb, f'../../../kisti_output/avg_emb_zodal_cat_lv_2_max_leng_{variant}.pt')

In [74]:
# Save the four averaged Lv3 embedding tensors, one file per variant.
for variant, avg_emb in (
    ("2048_eos", cat_emb_lv3_2048_eos),
    ("2048_masked", cat_emb_lv3_2048_masked),
    ("512_eos", cat_emb_lv3_512_eos),
    ("512_masked", cat_emb_lv3_512_masked),
):
    torch.save(avg_emb, f'../../../kisti_output/avg_emb_zodal_cat_lv_3_max_leng_{variant}.pt')

In [75]:
# Save the four averaged Lv4 embedding tensors, one file per variant.
for variant, avg_emb in (
    ("2048_eos", cat_emb_lv4_2048_eos),
    ("2048_masked", cat_emb_lv4_2048_masked),
    ("512_eos", cat_emb_lv4_512_eos),
    ("512_masked", cat_emb_lv4_512_masked),
):
    torch.save(avg_emb, f'../../../kisti_output/avg_emb_zodal_cat_lv_4_max_leng_{variant}.pt')

In [76]:
# Write each Lv1 -> Lv{2,3,4} index table to its own JSON file.
# The Lv3 and Lv4 files previously received the Lv2 table by mistake.
with open("./lv1idx2lv2idx.json", "w") as f:
    json.dump(lv1idx2lv2idx, f)

with open("./lv1idx2lv3idx.json", "w") as f:
    json.dump(lv1idx2lv3idx, f)

with open("./lv1idx2lv4idx.json", "w") as f:
    json.dump(lv1idx2lv4idx, f)