## This code extracts POS (part-of-speech) information from the summary JSON file and writes the result to a new JSON file

In [None]:
import os
os.chdir('..')
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
nltk.download('stopwords')
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt 
import torch
import ast
import re
from collections import defaultdict
from vllm import LLM, SamplingParams
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
import spacy
nlp = spacy.load('en_core_web_lg')

In [None]:
# Words that are never kept as adjectives: quantifiers, ordinals, generic
# descriptors and stray tokens such as '*' or '2-3'. Built once at definition
# time rather than on every call; matching is exact and case-sensitive.
_DUMMY_ADJECTIVES = frozenset([
    'other', 'most', 'some', 'any', 'several', 'many', 'few', 'all', 'each',
    'every', 'another', 'both', 'either', 'neither', 'such', 'more', 'less',
    'a few', 'a lot', 'much', 'little', 'none', 'no one', 'somebody',
    'someone', 'something', 'somewhere', 'used', 'audio', 'best', 'due',
    'recorded', 'various', 'video', 'meant', 'easy', '737-800', 'personal',
    'external', 'overall', 'sound', 'mobile', 'designed', 'well-defined',
    'detailed', 'suitable', 'small', 'third', 'second', 'fourth', 'fifth',
    'first', 'related', 'different', 'actual', 'kitchen', '*', '2-3',
    'everyday', 'common',
])


def remove_dummy(adjs):
    """Return the entries of ``adjs`` that are not in the dummy-word list.

    Args:
        adjs: A set, or any other iterable, of adjective strings.

    Returns:
        The adjectives that are not dummy words. A ``set`` or ``frozenset``
        input yields a result of the same type; any other iterable is
        deduplicated and returned as a ``set``.
    """
    candidates = adjs if isinstance(adjs, (set, frozenset)) else set(adjs)
    return candidates.difference(_DUMMY_ADJECTIVES)

## set up parameters and in/out file paths

In [None]:
# Number of most frequent adjectives shown in the bar chart at the end.
topk = 10
# Run identifier; it selects the summary file to read and names the outputs.
target_name = 'ast-esc50'
# target_name = 'beats-esc50'
# target_name = 'beats-esc50-unfreeze'

# Both paths are relative to the current working directory.
in_json = f'summaries/calibration_{target_name}_esc50_esc50_top5.json'
out_json = f'summaries/calibration_{target_name}_esc50_esc50_top5_processed.json'

# Read with an explicit encoding so the result does not depend on the
# platform's locale default.
with open(in_json, 'r', encoding='utf-8') as f:
    data = json.load(f)

## extract all adjectives

In [None]:
# For every item: tag its 'highly' texts, keep the non-negated adjective-like
# tokens, and store them under 'adj_after_rbf'. `all_adjs` collects the
# adjectives of all items (duplicates across items are kept here).
# Note that `item` is annotated in place, so the entries of `data` change too.
all_adjs, new_data = [], {}
for k, item in tqdm(data.items()):
    # Merge all 'highly' entries of this item into one text for tagging.
    sentence = ' '.join(item['highly'])

    words = word_tokenize(sentence.lower())
    pos_tags = nltk.pos_tag(words)

    # Rule-based filtering: keep tokens tagged JJ/JJR/JJS/VBN unless one of
    # the (up to) four preceding tokens is 'no' or 'not'.
    adjs = set()
    for i, (token, tag) in enumerate(pos_tags):
        if tag not in ('JJ', 'JJR', 'JJS', 'VBN'):
            continue
        # Clamp the window start at 0. A negative start index would wrap
        # around to the end of the list, leaving the window empty or wrong
        # for the first few tokens and letting negated adjectives through.
        window = pos_tags[max(0, i - 4):i]
        previous = [prev_token for prev_token, _ in window]
        if 'no' not in previous and 'not' not in previous:
            adjs.add(token)

    adjs = list(remove_dummy(adjs))
    item['adj_after_rbf'] = adjs
    new_data[k] = item
    all_adjs.extend(adjs)

## apply the LLM to determine which are acoustic adjectives

In [None]:
# Ask the LLM once per distinct adjective whether it can describe sound.
all_adjs = list(set(all_adjs))
prompt_template = "<s>[INST] <<SYS>>\n\
<</SYS>>\n\n \
Can the adjective '{}' be used to describe the tone, emotion, or acoustic features of audio, music, or any other form of sound?\n \
Answer(yes or no):\n\
Reason:\
[/INST]"

dataset = defaultdict(list)
for word in all_adjs:
    dataset["word"].append(word)
    dataset["text"].append(prompt_template.format(word))
dataset = Dataset.from_dict(dataset)
# NOTE(review): this loader is not used by the code below (generation goes
# through vLLM directly); kept in case other cells rely on it.
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=False, pin_memory=True, num_workers=16)

# load a hf LLM. Change to your path.
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
# NOTE(review): temperature=1 samples stochastically, so verdicts may differ
# between runs -- confirm this is intended for a yes/no decision.
sampling_params = SamplingParams(top_p=1, temperature=1, max_tokens=128)

prompts = [prompt_template.format(w) for w in all_adjs]
outputs = llm.generate(prompts, sampling_params)

# Match 'yes' only as a whole word; a plain substring test would also fire on
# words such as 'eyes' or 'yesterday' appearing in the model's explanation.
yes_pattern = re.compile(r'\byes\b')

# results: raw responses per word; ref_dict: word -> True if the response
# contains an affirmative 'yes'.
# `prompts` was built from `all_adjs` in order and the outputs are expected in
# prompt order, so the two sequences are paired positionally.
results, ref_dict = [], {}
for word, output in zip(all_adjs, outputs):
    generated_text = output.outputs[0].text
    results.append({"word": word, "response": generated_text})
    ref_dict[word] = yes_pattern.search(generated_text.lower()) is not None

## combine filtered results to data

In [None]:
# Keep, per item, only the rule-filtered adjectives the LLM approved and
# store them (quotes stripped, no duplicates) under 'adj_after_rbf_llmf'.
new_new_data = {}
for k, item in tqdm(new_data.items()):
    valid_adjs = []
    for adj in item['adj_after_rbf']:
        cleaned = adj.strip('\'')
        if not cleaned:
            # The token consisted only of quote characters.
            continue

        # Prefer the verdict for the cleaned form. ref_dict is keyed by the
        # token exactly as it was sent to the LLM, so when the cleaned form
        # was never judged on its own, fall back to the raw token instead of
        # silently dropping the adjective.
        verdict = ref_dict.get(cleaned)
        if verdict is None:
            verdict = ref_dict.get(adj, False)

        # A quoted and an unquoted variant collapse to the same word; keep one.
        if verdict and cleaned not in valid_adjs:
            valid_adjs.append(cleaned)

    item['adj_after_rbf_llmf'] = valid_adjs
    new_new_data[k] = item

In [None]:
# Build the stopword set once: calling stopwords.words() for every token is
# wasteful, and membership tests on the returned list are linear scans.
english_stopwords = set(stopwords.words('english'))

# For every item: record its distinct verbs, prepositions and nouns plus the
# spaCy token count of the merged 'highly' text.
new_new_new_data = {}
for k, item in tqdm(new_new_data.items()):
    sentence = ' '.join(item['highly'])
    doc = nlp(sentence.strip())
    caption_len = len(doc)  # number of spaCy tokens

    words = word_tokenize(sentence.lower())
    pos_tags = nltk.pos_tag(words)

    verbs, preps, nouns = set(), set(), set()
    for token, tag in pos_tags:
        if tag in ('VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'):  # verb
            if token not in english_stopwords:
                verbs.add(token)
        elif tag == 'IN':  # preposition. Do not use stopwords to filter
            preps.add(token)
        elif tag in ('NN', 'NNS', 'NNP', 'NNPS'):  # noun
            if token not in english_stopwords:
                nouns.add(token)

    item['verbs'] = list(verbs)
    item['preps'] = list(preps)
    item['nouns'] = list(nouns)
    item['caption_len'] = caption_len
    new_new_new_data[k] = item



## dump processed data to out file

In [None]:
# Write the fully annotated records to the output path as 4-space-indented JSON.
with open(out_json, 'w') as out_file:
    json.dump(new_new_new_data, out_file, indent=4)

## plot adjective distribution

In [None]:
# Count, for each LLM-approved adjective, the number of items that list it.
counter = defaultdict(int)
for line in new_new_new_data.values():
    for w in line['adj_after_rbf_llmf']:
        counter[w] += 1

# fix empirical error: fold the bare 'high' / 'low' counts into
# 'high-quality' / 'low-quality'. pop() avoids creating zero-count entries
# when a word is absent.
for short, full in (('high', 'high-quality'), ('low', 'low-quality')):
    merged = counter.pop(short, 0)
    if merged:
        counter[full] += merged

# Top-k adjectives by count. The list is reversed so that the most frequent
# adjective ends up as the top bar of the horizontal chart. A separate name is
# used here so the dataset loaded into `data` is not overwritten.
ranked = sorted(counter.items(), key=lambda kv: kv[1], reverse=True)[:topk]
ranked.reverse()
x = [word for word, _ in ranked]
y = [count for _, count in ranked]

colors = plt.cm.viridis(np.linspace(0, 1, len(y)))
plt.figure(figsize=(25, 32))
plt.barh(x, y, color=colors, edgecolor='none')
plt.xticks(fontsize=50, rotation=30, ha='right')
plt.yticks(fontsize=60)
plt.subplots_adjust(left=0.25, right=0.95, top=0.98)
# Save before showing: once show() returns the figure may already have been
# released, in which case savefig() would write an empty image.
# NOTE(review): 25x32 inches at dpi=1000 is a 25000x32000-pixel image --
# confirm this resolution is really needed.
plt.savefig(f'adjective_count-{target_name}.jpg', format='jpg', dpi=1000)
plt.show()
