In [None]:
import sys

import pyrootutils
import torch
import yaml
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
)

# import mplfonts

# Resolve the project root from this notebook's location. `__vsc_ipynb_file__`
# is looked up as a global here; NOTE(review): it appears to be injected by the
# VS Code notebook host, so this cell likely fails under plain Jupyter - confirm.
root = pyrootutils.setup_root(__vsc_ipynb_file__, dotenv=True, pythonpath=True)

# Make the project sources importable. Guarded so that re-running the cell
# does not keep appending the same entry to sys.path.
_project_src_dir = "/data/deeplearning-all-in-one/src"
if _project_src_dir not in sys.path:
    sys.path.append(_project_src_dir)

In [None]:
# mplfonts.use_font()

In [None]:
# Read the class names: one entry per line, keeping only the text before the
# first space. The list position of a name is used later as its class index.
# The file is opened in a `with` block so the handle is closed deterministically,
# and decoded explicitly as UTF-8 instead of relying on the locale default.
with open("/data/clean_raw_text/district_label_map.txt", "r", encoding="utf-8") as label_file:
    labels = [line.strip("\n").split(" ")[0] for line in label_file]

In [None]:
# Input checkpoint of the training run, passed to the DeepSpeed conversion below.
model_path = "/data/deeplearning-all-in-one/logs/train/runs/2022-08-03/18-17-12/checkpoints/epoch_009.ckpt"

# Destination of the converted checkpoint; this is the file the model is loaded from.
save_path = "/data/deeplearning-all-in-one/logs/train/runs/2022-08-03/18-17-12/checkpoints/model.ckpt"

In [None]:
# Pick the inference device. GPU index 4 is preferred, but only when the host
# actually exposes that many GPUs; otherwise fall back to the first GPU, or to
# the CPU when CUDA is unavailable. Previously any CUDA host got "cuda:4",
# which fails later at `.to(device)` on machines with fewer than five GPUs.
_preferred_gpu_index = 4
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() > _preferred_gpu_index:
    device = f"cuda:{_preferred_gpu_index}"
else:
    device = "cuda:0"

In [None]:
# Convert the DeepSpeed ZeRO checkpoint at `model_path` and write the result to
# `save_path` (helper imported from pytorch_lightning.utilities.deepspeed).
# NOTE(review): this rewrites `save_path` on every run of the cell.
convert_zero_checkpoint_to_fp32_state_dict(model_path, save_path)

In [None]:
from src.models.dist_classify import DistClassify

In [None]:
# Load the Hydra config that was saved with the training run. The file is read
# inside a `with` block so the handle is closed (the previous inline `open()`
# leaked it). `safe_load` is used, so no arbitrary YAML tags are executed.
with open(
    "/data/deeplearning-all-in-one/logs/train/runs/2022-08-03/18-17-12/.hydra/config.yaml",
    "r",
    encoding="utf-8",
) as hydra_config_file:
    hydra_config = yaml.safe_load(hydra_config_file)

# Restore the classifier from the converted checkpoint, forwarding the parsed
# config as the `hparams` keyword argument.
model = DistClassify.load_from_checkpoint(save_path, hparams=hydra_config)

In [None]:
# Move the model to the selected device, then switch it to evaluation mode.
# `model.eval()` is the cell's last expression, so its return value is displayed.
model = model.to(device)
model.eval()

In [None]:
import os

from transformers import BertTokenizer

# Load the tokenizer for the pretrained checkpoint name below.
# The cache directory is expanded explicitly: a raw "~/.cache" string is not
# tilde-expanded when passed as `cache_dir`, so it would be created as a literal
# "~" directory relative to the current working directory.
# NOTE(review): files cached earlier under that literal "~" directory will not
# be picked up from the expanded location.
tokenizer = BertTokenizer.from_pretrained(
    "hfl/chinese-roberta-wwm-ext", cache_dir=os.path.expanduser("~/.cache")
)

In [None]:
import pandas as pd

In [None]:
# Load the cleaned corpus; `lines=True` reads the file as one JSON record per line.
# Later cells read the `src` and `text` columns of this frame.
data = pd.read_json("/data/clean_raw_text/all_data_cleaned.json", lines=True)

In [None]:
# Rows whose `src` equals "政务交流". An explicit copy is taken because later
# cells add columns to this frame; writing into a filtered selection of `data`
# triggers pandas' SettingWithCopyWarning and leaves it ambiguous which frame
# is being modified.
zwjl = data[data.src.eq("政务交流")].copy()

In [None]:
# Number of rows selected into `zwjl`.
len(zwjl)

In [None]:
# Spot check: run the model on a small random sample and show the softmax
# output. The sample size is capped at the frame length so that small subsets
# do not make `sample` raise, and the forward pass runs under `no_grad` so no
# autograd graph is built for a pure inspection call.
spot_check_texts = zwjl.sample(min(32, len(zwjl))).text.values.tolist()
with torch.no_grad():
    spot_check_probs = model(
        tokenizer(
            spot_check_texts,
            max_length=200,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)
    ).softmax(dim=-1)

# Bare last expression so the notebook still renders the tensor.
spot_check_probs

In [None]:
from tqdm import tqdm

# Run the classifier over `zwjl` in slices of 100 rows and collect the softmax
# output of each slice as a NumPy array. Despite its name, `all_logits` holds
# post-softmax values.
all_logits = []
batch_starts = range(0, len(zwjl), 100)
with torch.no_grad():
    for start in tqdm(batch_starts, desc="generating pseudo-labels"):
        batch_texts = zwjl.iloc[start : start + 100].text.values.tolist()
        # Every text is padded/truncated to 200 tokens before being moved to the device.
        encoded = tokenizer(
            batch_texts,
            max_length=200,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)
        batch_probs = model(encoded).softmax(dim=-1)
        all_logits.append(batch_probs.cpu().numpy())

In [None]:
import numpy as np

In [None]:
# Join the per-batch arrays along axis 0 into a single array.
# The values were produced by `.softmax(dim=-1)` in the loop above, so despite
# the name they are softmax outputs rather than raw logits.
logits = np.concatenate(all_logits, axis=0)

In [None]:
# Store each row of the array as a Python list in a new "logits" column.
# Relies on `logits` still holding the output of the `zwjl` inference loop
# (the same variable name is reused for the `jx` subset further down).
zwjl["logits"] = logits.tolist()

In [None]:
# Index of the largest entry in each row's "logits" list, used as the predicted class.
zwjl["pseudo_label"] = zwjl["logits"].map(lambda row_probs: np.array(row_probs).argmax())

In [None]:
# Pie chart of predicted-class counts for `zwjl`.
# The counts are reindexed onto every class index (missing classes become 0)
# so that wedge i always corresponds to labels[i]. Without this, a class that
# received no predictions is absent from the grouped result and every label
# after it is attached to the wrong wedge.
zwjl_class_counts = (
    zwjl.groupby(by="pseudo_label").size().reindex(range(len(labels)), fill_value=0)
)
zwjl_class_counts.plot.pie(labels=labels, figsize=(6, 6))

In [None]:
# Rows whose `src` equals "江西". An explicit copy is taken because later cells
# add columns to this frame; writing into a filtered selection of `data`
# triggers pandas' SettingWithCopyWarning and leaves it ambiguous which frame
# is being modified.
jx = data[data.src.eq("江西")].copy()

In [None]:
# Number of rows selected into `jx`.
len(jx)

In [None]:
# Run the classifier over `jx` in slices of 100 rows and collect the softmax
# output of each slice as a NumPy array. `all_logits` is reset here, so the
# values gathered for the previous subset are discarded.
all_logits = []
batch_starts = range(0, len(jx), 100)
with torch.no_grad():
    for start in tqdm(batch_starts, desc="generating pseudo-labels"):
        batch_texts = jx.iloc[start : start + 100].text.values.tolist()
        # Every text is padded/truncated to 200 tokens before being moved to the device.
        encoded = tokenizer(
            batch_texts,
            max_length=200,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)
        batch_probs = model(encoded).softmax(dim=-1)
        all_logits.append(batch_probs.cpu().numpy())

In [None]:
# Join the per-batch arrays for `jx` along axis 0 into a single array.
# This rebinds `logits`, replacing the array computed for the earlier subset.
logits = np.concatenate(all_logits, axis=0)

In [None]:
# Store each row of the array as a Python list in a new "logits" column.
# Relies on `logits` holding the output of the `jx` inference loop directly above.
jx["logits"] = logits.tolist()

In [None]:
# Index of the largest entry in each row's "logits" list, used as the predicted class.
jx["pseudo_label"] = jx["logits"].map(lambda row_probs: np.array(row_probs).argmax())

In [None]:
# Pie chart of predicted-class counts for `jx`.
# The counts are reindexed onto every class index (missing classes become 0)
# so that wedge i always corresponds to labels[i]. Without this, a class that
# received no predictions is absent from the grouped result and every label
# after it is attached to the wrong wedge.
jx_class_counts = (
    jx.groupby(by="pseudo_label").size().reindex(range(len(labels)), fill_value=0)
)
jx_class_counts.plot.pie(labels=labels, figsize=(6, 6))

In [None]:
# Translate each predicted class index into its name from `labels`.
zwjl["label"] = zwjl["pseudo_label"].map(lambda class_index: labels[class_index])

In [None]:
# Translate each predicted class index into its name from `labels`.
jx["label"] = jx["pseudo_label"].map(lambda class_index: labels[class_index])

In [None]:
# Eyeball ten random rows of text with their predicted label.
# No random_state is passed, so the rows shown differ between runs.
jx[["text", "label"]].sample(10)

In [None]:
# Load the street test set (one JSON record per line).
# NOTE(review): within the visible notebook this frame is only used by the
# `img` count in the next cell and does not feed the exported dataset.
street_data = pd.read_json("/data/clean_raw_text/dataset/street_test.json", lines=True)

In [None]:
# Count the rows whose `img` value has non-zero length.
street_data.img.apply(lambda x: len(x) > 0).sum()

In [None]:
# Stack the two pseudo-labelled subsets row-wise into one frame.
# The original row indices of both inputs are kept (no ignore_index).
all_agu_data = pd.concat([zwjl, jx], axis=0)

In [None]:
# Eyeball ten random rows of the combined pseudo-labelled data (unseeded sample).
# At this point "label" still holds the class-name strings; the next cell overwrites it.
all_agu_data[["text", "label"]].sample(10)

In [None]:
# Replace the class-name strings in "label" with the per-row "logits" lists
# (the softmax outputs stored earlier), so the exported label is a list of
# per-class scores rather than a single class name.
# NOTE(review): presumably intentional soft labels for training - confirm that
# the consumer of the exported file accepts this format alongside the labels
# already present in the original training set.
all_agu_data["label"] = all_agu_data.logits

In [None]:
# Load the existing district training set (one JSON record per line).
train_data = pd.read_json("/data/clean_raw_text/dataset/district_train.json", lines=True)

In [None]:
# Append the pseudo-labelled rows to the original training rows.
# The export cell below selects "text" and "label" from this frame.
# NOTE(review): assumes `train_data` has both of those columns - confirm.
mix_data = pd.concat([train_data, all_agu_data], axis=0)

In [None]:
# Display the combined frame for a visual check before exporting.
mix_data

In [None]:
# Write only the text/label pairs as JSON Lines: one record per line, with
# non-ASCII characters written as-is rather than escaped.
export_frame = mix_data[["text", "label"]]
export_frame.to_json(
    "/data/clean_raw_text/dataset/district_train_with_pseudo_label.json",
    orient="records",
    lines=True,
    force_ascii=False,
)