In [None]:
# Inference dataset (change)
# inference_dataset: CSV file to run the model on.
# text_field: name of the CSV column that holds the text to classify.
# output_file: where the input rows plus the predicted label columns are written.
inference_dataset = 'data/inference/tv2022_fedgov_asr_0904.csv'
text_field = 'google_asr_text'
output_file = 'data/inference/tv2022_fedgov_asr_0904_output.csv'

# Variable label file (don't change)
# Plain-text file with one label name per line.
label_file = 'data/issue_labels_65.txt'

# Model files (usually don't change)
# Weights and config of the trained classifier.
model_pytorch_model = 'models/multilabel_trf_v1/pytorch_model.bin'
model_config = 'models/multilabel_trf_v1/config.json'

# Connect to GDrive
# NOTE(review): all paths above are relative, while Drive is mounted at
# /content/drive and no change of working directory is visible in this
# notebook — confirm the working directory the paths are resolved against.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.21.3-py3-none-any.whl (4.7 MB)
[K     |████████████████████████████████| 4.7 MB 13.9 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 70.8 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.9.1-py3-none-any.whl (120 kB)
[K     |████████████████████████████████| 120 kB 72.7 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.9.1 tokenizers-0.12.1 transformers-4.21.3


In [None]:
from tqdm import tqdm
import os
import shutil

import numpy as np
import pandas as pd
import torch
import transformers

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

#----
# Stage the model files and the inference data (here test set) at fixed
# local paths. exist_ok=True makes the directory creation a no-op when
# 'models' already exists.
os.makedirs('models', exist_ok=True)
for source_path, staged_path in (
  (model_pytorch_model, 'models/pytorch_model.bin'),
  (model_config, 'models/config.json'),
  (inference_dataset, './inference_dataset.csv'),
):
  shutil.copyfile(source_path, staged_path)

#----
# Load the trained model and tokenizer
# The model is read from the local 'models' directory, which holds the
# pytorch_model.bin and config.json copied there earlier in this cell.
model = AutoModelForSequenceClassification.from_pretrained('models')
# The tokenizer is loaded by name rather than from the local directory.
# NOTE(review): presumably the classifier was fine-tuned from
# distilbert-base-uncased, so this tokenizer matches it — confirm.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

#----
# Load the inference dataset
# Read the staged CSV and keep only rows with usable text, as one chain:
# rows with a missing text value and rows whose text is the literal marker
# '_error' are dropped, then the index is renumbered from 0 so it lines up
# with the prediction rows that are concatenated later.
df = (
  pd.read_csv('./inference_dataset.csv')
  .dropna(subset=[text_field])                        # remove NAs
  .loc[lambda frame: frame[text_field] != '_error']   # remove errors
  .reset_index(drop=True)
)

# Load the variable labels
# Read from the configured `label_file` path instead of a hard-coded copy of
# it, so that changing the setting at the top of the notebook takes effect.
with open(label_file, 'r') as reader:
  labels = reader.read().split('\n')
# NOTE: when the file ends with a newline, the split leaves an empty string
# as the last element of `labels`; code that uses `labels` as column names
# has to skip it.
# They were created like this:
# df = pd.read_csv('data/issues_tv_fb_18_20.csv')
# with open('data/issue_labels_65.txt', 'w') as writer:
#   for i in df.columns[2:].tolist():
#     writer.write(i + '\n')

#----
# Inference

# Keep only the non-empty label names. Splitting the label file on '\n'
# leaves a trailing '' when the file ends with a newline; filtering (instead
# of blindly dropping the last element) also gives the right result when that
# final newline is missing.
label_names = [name for name in labels if name.strip()]

# Batch the text Series (batch size 16)
# astype(str) guards against cells that pandas parsed as a non-string type,
# which the tokenizer would reject; it leaves string values unchanged.
texts = df[text_field].astype(str).to_list()
batch_size = 16
list_df = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]

# Batch inference
# Each batch is tokenized, run through the model and thresholded in one pass,
# so only the small 0/1 prediction arrays are kept in memory instead of every
# encoded batch and every raw model output.
# For inference, calculating the gradients is unnecessary
# with torch.no_grad(): turns them off, which is faster (seems about 10x faster on CPU, and 2x faster on GPU or so)
preds_l = []
with torch.no_grad():
  for text_chunk in tqdm(list_df):
    # Calling the tokenizer directly replaces the deprecated batch_encode_plus;
    # padding=True pads to the longest text in the batch, truncation=True cuts
    # texts that exceed the model's maximum length.
    encoded_chunk = tokenizer(text_chunk, truncation=True, padding=True, return_tensors="pt")
    logits = model(**encoded_chunk).logits
    # Convert to 1s and 0s: a label is predicted when its sigmoid score is above 0.5
    preds = logits.sigmoid().cpu().numpy() > 0.5
    preds_l.append(preds.astype(int))

# np.vstack raises on an empty list, so fall back to a 0-row array when no
# texts survived the filtering.
if preds_l:
  outputs = np.vstack(preds_l)
else:
  outputs = np.empty((0, len(label_names)), dtype=int)

# Fail with a clear message if the label file does not match the model output
if outputs.shape[1] != len(label_names):
  raise ValueError(
    f"Model produced {outputs.shape[1]} outputs per text, "
    f"but the label file contains {len(label_names)} labels"
  )

# Convert to pd DataFrame and save
# df was re-indexed from 0, so the column-wise concat lines up row by row.
df_preds = pd.DataFrame(outputs, columns=label_names)
df_results = pd.concat([df, df_preds], axis = 1)
df_results.to_csv(output_file, index = False)

Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/455k [00:00<?, ?B/s]

100%|██████████| 269/269 [06:59<00:00,  1.56s/it]
