In [1]:
import os
import yaml
import pandas as pd
import numpy as np
import clip
import torch
# Project root. Override with the STDD_ROOT environment variable so the notebook
# runs outside the original author's home directory; defaults to the old path.
root = os.environ.get('STDD_ROOT', '/home/yuyating/workspace/STDD')
root

'/home/yuyating/workspace/STDD'

In [2]:
def get_templates(dataset_name, root_dir=None):
    """Load the zero-shot label YAML for ``dataset_name``.

    Reads ``<root_dir>/zs_label_db/classes_label_<dataset_name>.yml`` and returns
    the parsed YAML unchanged. Callers in this notebook index the result with
    'templates' (strings such as 'a video of a person {}.'); the original author
    also noted 'classes' and 'obj_templates' keys.

    Args:
        dataset_name: dataset suffix used in the file name, e.g. 'ucf101'.
        root_dir: project root directory; ``None`` (default) uses the
            notebook-level ``root``.

    Returns:
        The object parsed from the YAML file.
    """
    if root_dir is None:
        root_dir = root
    label_file = os.path.join(root_dir, "zs_label_db", f"classes_label_{dataset_name}.yml")
    # Explicit encoding: label text must not depend on the machine's locale.
    # safe_load: the label DB is plain data and needs no Python-object tags.
    with open(label_file, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
# Parsed label YAML for UCF101. NOTE(review): this is the raw YAML mapping, not the
# object `text_prompt` takes as its `data` argument (that one must expose `.classes`).
data = get_templates(dataset_name='ucf101')

In [3]:
def get_xprompt(dataset_name, root_dir=None):
    """Load the per-class extra-description ("xprompt") YAML for ``dataset_name``.

    Reads ``data/<dataset_name>/classes_xprompt_<dataset_name>.yml``. With the
    default ``root_dir=None`` that path is relative to the current working
    directory (the original behavior); pass ``root_dir`` to anchor it explicitly.

    Each top-level value is read by ``text_prompt`` via the keys 'xprompt_oo',
    'xprompt_ao' and 'xprompt_aa'.

    Args:
        dataset_name: dataset name used in both the directory and file name.
        root_dir: optional directory that ``data/...`` is resolved against.

    Returns:
        The object parsed from the YAML file.
    """
    xprompt_file = f"data/{dataset_name}/classes_xprompt_{dataset_name}.yml"
    if root_dir is not None:
        xprompt_file = os.path.join(root_dir, xprompt_file)
    # Explicit encoding so descriptions load identically on any locale;
    # safe_load because the file is plain data.
    with open(xprompt_file, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


In [4]:
def expand_cls_tokenized(cls_tokenized, num_prompt):
    """Resize every class's prompt tensor to exactly ``num_prompt`` rows and stack.

    Tensors with fewer rows are tiled cyclically along dim 0 (rows 0, 1, ..., k-1,
    0, 1, ...); tensors with at least ``num_prompt`` rows are truncated.

    Args:
        cls_tokenized: iterable of tensors, one per class (presumably
            (n_prompts_i, 77) CLIP token ids -- TODO confirm).
        num_prompt: target number of rows per class.

    Returns:
        Tensor stacked along a new dim 0: (num_classes, num_prompt, ...).

    Raises:
        ZeroDivisionError: a class tensor has 0 rows while ``num_prompt`` > 0.
    """
    def _fit_rows(prompts):
        n_rows = prompts.size(0)
        if n_rows >= num_prompt:
            return prompts[:num_prompt]
        full_copies, leftover = divmod(num_prompt, n_rows)
        pieces = [prompts] * full_copies
        if leftover > 0:
            pieces.append(prompts[:leftover])
        return torch.cat(pieces, dim=0)

    return torch.stack([_fit_rows(prompts) for prompts in cls_tokenized], dim=0)


In [5]:

def expand_cls_text(cls_text_list):
    """Cyclically pad every class's prompt list to the length of the longest one.

    A list ``[p0, p1]`` padded to 5 becomes ``[p0, p1, p0, p1, p0]``. Inputs are
    not mutated; new lists are returned.

    Args:
        cls_text_list: sequence of per-class prompt lists.

    Returns:
        (padded_lists, shortest_len, longest_len), where the lengths are measured
        before padding.
    """
    shortest = min(len(prompts) for prompts in cls_text_list)
    longest = max(len(prompts) for prompts in cls_text_list)

    padded = []
    for prompts in cls_text_list:
        n_have = len(prompts)
        filler = [prompts[k % n_have] for k in range(longest - n_have)]
        padded.append(prompts + filler)
    return padded, shortest, longest


In [6]:
def _class_prompt_lists(template, classes, extras_by_class):
    """Per class: the bare template prompt followed by one prompt per extra description."""
    lists = []
    for i, c in enumerate(classes):
        prompts = [template.format(c)]
        # extras_by_class is keyed by class *position*, not class name.
        prompts.extend(template.format(f"{c}, {extra}") for extra in extras_by_class[i])
        lists.append(prompts)
    return lists


def _select_class_prompts(per_kind, cls_prompt_type):
    """Pick per-class prompt lists for cls_prompt_type; unknown types fall back to 'a'."""
    if cls_prompt_type in ('xoo', 'xao', 'xaa'):
        return per_kind[cls_prompt_type]
    if cls_prompt_type == 'xmix':
        # Union of the three lists without duplicates (the bare prompt appears in
        # all three). dict.fromkeys keeps first-seen order; the previous set()
        # ordered str by hash, which is randomized per interpreter process, so
        # prompt order (and hence padding and token-row order) changed between runs.
        return [
            list(dict.fromkeys(oo + ao + aa))
            for oo, ao, aa in zip(per_kind['xoo'], per_kind['xao'], per_kind['xaa'])
        ]
    return per_kind['a']


def text_prompt(data, dataset: str, num_templates: int, cls_prompt_type: str):
    """Build CLIP-tokenized class prompts for every template.

    For each of the first ``num_templates`` templates, every class gets a list of
    prompt strings chosen by ``cls_prompt_type``:

    * 'xoo' / 'xao' / 'xaa': the bare prompt ``template.format(c)`` plus one
      prompt ``template.format(f"{c}, {extra}")`` per entry of the class's
      ``xprompt_oo`` / ``xprompt_ao`` / ``xprompt_aa`` list;
    * 'xmix': the three lists merged, de-duplicated, in first-appearance order;
    * anything else: only the bare prompt.

    The per-class lists are cyclically padded to equal length with
    ``expand_cls_text`` and tokenized with ``clip.tokenize``.

    Args:
        data: object whose ``.classes`` yields sequences holding the class name
            at index 1.
        dataset: dataset key used to locate the template and xprompt YAML files.
        num_templates: maximum number of templates to use; clamped to the number
            available.
        cls_prompt_type: 'xoo', 'xao', 'xaa' or 'xmix'; other values use the bare
            prompt only.

    Returns:
        cls_tokenized: every template's token tensor concatenated along dim 0,
            rows ordered template -> prompt index -> class.
        cls_text_dict: {template_idx: padded per-class prompt lists}.
        tokenized_dict: {template_idx: that template's token tensor}.
        num_templates: the clamped template count.
        n_prompts: [min, max] prompts per class before padding, as computed for
            the last template processed.
    """
    xprompt_db = get_xprompt(dataset)
    templates = get_templates(dataset)['templates']
    classes = [i[1] for i in data.classes]  # ["c1", "c2", ...]
    n_prompts = [0, 0]

    num_templates = min(num_templates, len(templates))
    templates = templates[:num_templates]  # e.g. "a video of a person {}."

    # NOTE(review): xprompt entries are paired with classes by position, so the
    # xprompt YAML must list classes in the same order as data.classes.
    extras = {'xoo': {}, 'xao': {}, 'xaa': {}}
    for i, entry in enumerate(xprompt_db.values()):
        extras['xoo'][i] = entry['xprompt_oo']
        extras['xao'][i] = entry['xprompt_ao']
        extras['xaa'][i] = entry['xprompt_aa']

    cls_text_dict = {}
    tokenized_dict = {}
    for ii, template in enumerate(templates):
        per_kind = {'a': [[template.format(c)] for c in classes]}
        for kind in ('xoo', 'xao', 'xaa'):
            per_kind[kind] = _class_prompt_lists(template, classes, extras[kind])

        cls_text = _select_class_prompts(per_kind, cls_prompt_type)
        cls_text_dict[ii], n_prompts[0], n_prompts[1] = expand_cls_text(cls_text)

        # Prompt index n is the outer loop and class the inner one, so rows are
        # grouped as [all classes' prompt 0, all classes' prompt 1, ...].
        tokenized_dict[ii] = torch.cat([
            torch.cat([clip.tokenize(prompts[n]) for prompts in cls_text_dict[ii]])
            for n in range(n_prompts[1])
        ])

    # (num_templates * max_prompt * num_cls, 77) token ids
    cls_tokenized = torch.cat(list(tokenized_dict.values()))
    return cls_tokenized, cls_text_dict, tokenized_dict, num_templates, n_prompts


In [7]:
def get_en_labels(en_list, en_dict):
    """Map each class to the indices (into ``en_list``) of the entities it contains.

    Args:
        en_list: ordered sequence of entities; an entity's position is its label.
        en_dict: mapping class key -> collection of that class's entities.

    Returns:
        (label_dict, label_list): ``label_dict[cls]`` is the ascending list of
        indices ``i`` with ``en_list[i]`` in ``en_dict[cls]``; ``label_list`` holds
        the same list objects in ``en_dict`` iteration order.
    """
    label_dict, label_list = {}, []
    for cls_key, class_entities in en_dict.items():
        hits = [idx for idx, entity in enumerate(en_list) if entity in class_entities]
        label_dict[cls_key] = hits
        label_list.append(hits)
    return label_dict, label_list
