I wanted a comparison between a machine learning model specifically trained on the dataset (see file tensorflow_keras_categorization) and labeling based on a general language model that is not fine-tuned for memes at all. After many fruitless attempts with TensorFlow itself, I decided to use HuggingFace as a proxy, because it has some support for [zero-shot classification](https://huggingface.co/course/chapter1/3?fw=pt#zeroshot-classification).

I encountered a significant problem: classification speed depends on the number of labels. With 300 labels (meme templates), a single prediction takes 12 seconds on my CPU and more than 8 seconds even on a powerful GPU in Google Colab.
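The slowdown follows from how NLI-based zero-shot classification works. Each candidate label is turned into a hypothesis sentence and paired with the text, and every pair is a separate input sequence for the model, so the work per text grows roughly linearly with the number of labels. The timing comments in the model cell agree: for roberta-large-mnli, 20 labels take about 5.5 s and 300 labels about 87 s, roughly 16 times longer for 15 times as many labels. The sketch below only builds the pairs. `build_nli_pairs` is a hypothetical helper, and the template is assumed to match the pipeline's default `hypothesis_template`.

```python
# Sketch: how an NLI-based zero-shot classifier expands one text into
# one (premise, hypothesis) pair per candidate label.
# Assumption: the template mirrors the default hypothesis_template
# of the Hugging Face zero-shot pipeline.
from typing import List, Tuple

DEFAULT_TEMPLATE = "This example is {}."


def build_nli_pairs(text: str, candidate_labels: List[str],
                    template: str = DEFAULT_TEMPLATE) -> List[Tuple[str, str]]:
    """Return one (premise, hypothesis) pair per label.

    Each pair is a separate sequence for the NLI model, so the cost of
    classifying one text grows linearly with len(candidate_labels).
    """
    return [(text, template.format(label)) for label in candidate_labels]


if __name__ == "__main__":
    pairs = build_nli_pairs("TEAMMATES <sep> Y U NO REVIVE ME?",
                            ["Y U No", "Success Kid", "Trollface"])
    for premise, hypothesis in pairs:
        print(premise, "=>", hypothesis)
```

With 300 meme templates, each text becomes 300 such pairs. This explains why even a GPU only partially hides the cost.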

I am evaluating the results in a separate notebook (huggingface_zero_shot_eval_results), but in general the accuracy depends on the language model used. The correct meme template is identified in 7-8% of cases. If we look not only at the candidate with the highest probability but at the top 5 candidates of the prediction, the correct template is among them in 10-13% of cases.

- roberta-large-mnli
    - top 1: 8.11%
    - top 5: 10.81%

- typeform_mobilebert-uncased-mnli
    - top 1: 7.25%
    - top 5: 12.75%
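The top-1 and top-5 figures above are top-k accuracies. Below is a minimal sketch of that metric. It assumes `results` is the dictionary saved at the end of this notebook, which maps each input text to the pipeline output, with `labels` sorted by descending score. `gold` is a hypothetical mapping from the same text to its true `MemeName`. The evaluation notebook itself may compute the numbers differently.

```python
# Sketch: top-k accuracy over saved zero-shot predictions.
# Assumptions (hypothetical, not taken from the evaluation notebook):
#   - results maps each input text to the pipeline output dict,
#     whose "labels" list is sorted by descending score;
#   - gold maps the same text to its true MemeName.
from typing import Dict


def top_k_accuracy(results: Dict[str, dict], gold: Dict[str, str], k: int) -> float:
    """Fraction of texts whose true label is among the first k predicted labels."""
    evaluated = [text for text in results if text in gold]
    if not evaluated:
        return 0.0
    hits = sum(1 for text in evaluated if gold[text] in results[text]["labels"][:k])
    return hits / len(evaluated)


# Toy data in the same shape as the pipeline output (made up for illustration)
toy_results = {
    "GIRLS <sep> Y U SO COMPLICATED ??!": {
        "labels": ["Trollface", "Y U No", "Success Kid"], "scores": [0.5, 0.3, 0.2]},
    "TEAMMATES <sep> Y U NO REVIVE ME?": {
        "labels": ["Y U No", "Trollface", "Success Kid"], "scores": [0.6, 0.3, 0.1]},
}
toy_gold = {text: "Y U No" for text in toy_results}

print(top_k_accuracy(toy_results, toy_gold, 1))  # 0.5
print(top_k_accuracy(toy_results, toy_gold, 2))  # 1.0
```

On the real data, `gold` could be built with `dict(zip(dataset.Text, dataset.MemeName))`. However, identical texts that appear under different templates would then keep only the last template.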


In [2]:
# !pip install tqdm pandas transformers torch
# !nvidia-smi

In [3]:
# from google.colab import files

# uploaded = files.upload()
# uploaded

In [2]:
import pandas as pd
import json
import random


In [3]:
dataset = pd.read_csv('all.csv')
dataset = dataset[["MemeName", "Text"]]
dataset.head(5)

,MemeName,Text
0,Y U No,Forever alone guy <sep> y u no get cat
1,Y U No,TEAMMATES <sep> Y U NO REVIVE ME?
2,Y U No,GIRLS <sep> Y U SO COMPLICATED ??!
3,Y U No,I 'like' all your pics <sep> Y U No have sex w...
4,Y U No,girls <sep> y u no stop making duck faces?!


In [4]:
labels = list(set(dataset.MemeName.tolist()))
labels

['skeptical black kid',
 "you don't say meme",
 'High Expectations Asian Fath',
 'Sudden Realization Ralph',
 'Awkward Seal',
 'Challenge Accepted 2',
 'Art Student Owl',
 'journalist',
 'Kermit the frog',
 'Douchebag Roommate',
 'Bear Grylls Loneliness',
 'Retail Robin',
 'teacher',
 'Blank Black',
 'Chuck Norris ',
 'Grumpy cat good',
 'I can haz',
 'kill yourself guy',
 "Don't you, Squidward?",
 'Mafia Baby',
 'skyrim stan',
 'Bad Factman',
 'Funes 20 aos',
 'african children dancing',
 'One Does Not Simply',
 'Really Stoned Guy',
 'Oprah You get a',
 'That would be great',
 'your country needs you',
 'Black Kid',
 'Trollface',
 "Timmy turner's dad IF I HAD ",
 'say what one more time',
 'bender blackjack and hookers',
 'Laughing Girls ',
 'Anchorman Birthday',
 'Angry School Boy',
 'Advice Dog',
 'In Soviet Russia',
 'burning house girl',
 'Roleplaying Rabbit',
 'Success Kid',
 'obama beer',
 'So I got that going on for m',
 'Matrix Morpheus',
 'Uncle Dolan',
 'mean girls',
 'kill 

In [10]:
import torch
from transformers import pipeline
from tqdm import tqdm

# Note: the number of candidate labels greatly affects the speed

# model_name = "roberta-large-mnli"  # CPU: 20 labels ~ 1it/5.5s; 300 labels ~ 1it/87s
model_name = "typeform/mobilebert-uncased-mnli"  # 300 labels: CPU ~ 1it/12s, GPU ~ 1it/8s; 20 labels on CPU: ~ 1it/0.8s

# Select the device explicitly; without `device` the pipeline may run on CPU even when a GPU is available
device = 0 if torch.cuda.is_available() else -1
classifier = pipeline("zero-shot-classification", model=model_name, device=device)

In [13]:
# import torch
# torch.cuda.is_available()

In [14]:
results = {}

In [18]:

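# Classify a random sample of 200 meme texts
texts = dataset.Text.tolist()
random.shuffle(texts)

for text in tqdm(texts[:200]):
    # The models accept at most 512 tokens; the word count is only a rough proxy for that limit
    if len(text.split()) > 500:
        print("Too many words: ", text) # https://stackoverflow.com/questions/65023526/runtimeerror-the-size-of-tensor-a-4000-must-match-the-size-of-tensor-b-512
        continue

    # Score the text against every meme template name
    prediction = classifier(
        text,
        candidate_labels=labels,
        max_length=200
    )
    # Key by input text; prediction["labels"] is sorted by descending score
    results[prediction["sequence"]] = prediction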

100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [41:24<00:00, 12.42s/it]


In [16]:
# results

In [19]:

filename = model_name.replace("/", "_") + ".out.json"
# filename = "mobilebert-uncased-mnli.out5.json"

with open(filename, "w", encoding="utf8") as f:
    json.dump(results, f)