In [1]:
!pip3 install transformers>=4.32.0 optimum>=1.12.0
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/  # Use cu117 if on CUDA 11.7

Looking in indexes: https://pypi.org/simple, https://huggingface.github.io/autogptq-index/whl/cu118/


In [2]:
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import os
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

## Read Data

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
BASE_DIR = Path("/content/drive/MyDrive/ITMO/RW'23")
DATA_DIR = BASE_DIR.joinpath("data")

FILE_NAME = "data_irrelative_topics.csv"

FILE = DATA_DIR.joinpath(FILE_NAME)

In [5]:
data = pd.read_csv(FILE)
data.head(10)

Unnamed: 0.1,Unnamed: 0,wordset,all_topics,unsuitable_topics,suitable_topics
0,0,islamic muslim ohio kent sandvik cheer islam a...,"['islamic', 'muslim', 'ohio', 'kent', 'sandvik...",[],"['islam', 'tube', 'sandvik', 'ohio', 'upenn', ..."
1,1,chip blend roast com favorite www http href da...,"['chip', 'blend', 'roast', 'com', 'favorite', ...","['k-cups', 'blend', 'chip']","['href', 'pod', 'com', 'favorite', 'roast', 'd..."
2,2,window display color screen manager server app...,"['window', 'display', 'color', 'screen', 'mana...","['client', 'colors', 'colormap', 'manager']","['size', 'application', 'memory', 'screen', 'm..."
3,3,nasa mission launch surface shuttle radar dete...,"['nasa', 'mission', 'launch', 'surface', 'shut...","['april', 'three']","['orbiter', 'probe', 'shuttle', 'field', 'nasa..."
4,4,church bible christ catholic faith matthew lor...,"['church', 'bible', 'christ', 'catholic', 'fai...","['matthew', 'word', 'holy']","['father', 'catholic', 'christ', 'scripture', ..."
5,5,bike firearm knife criminal carry death weapon...,"['bike', 'firearm', 'knife', 'criminal', 'carr...","['scsi', 'bike', 'smith', 'tape', 'child', 'mo...","['criminal', 'firearm', 'carry', 'handgun', 'k..."
6,6,company available care awesome search sent lic...,"['company', 'available', 'care', 'awesome', 's...","['awesome', 'care']","['licorice', 'call', 'company', 'business', 's..."
7,7,firearm bill weapon crime vote handgun public ...,"['firearm', 'bill', 'weapon', 'crime', 'vote',...","['vote', 'washington', 'news', 'kent']","['laws', 'criminal', 'federal', 'firearm', 'bi..."
8,8,clinton mike batf class bill within proper wac...,"['clinton', 'mike', 'batf', 'class', 'bill', '...",[],"['unknown', 'clinton', 'proper', 'bill', 'batf..."
9,9,almond body smooth blue magnesium cause diamon...,"['almond', 'body', 'smooth', 'blue', 'magnesiu...",[],"['blood', 'magnesium', 'rda', 'diamond', 'smoo..."


In [6]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2363 entries, 0 to 2362
Data columns (total 5 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   Unnamed: 0         2363 non-null   int64 
 1   wordset            2363 non-null   object
 2   all_topics         2363 non-null   object
 3   unsuitable_topics  2363 non-null   object
 4   suitable_topics    2363 non-null   object
dtypes: int64(1), object(4)
memory usage: 92.4+ KB


##Llama-13B

In [7]:
!!nvidia-smi

['Sat Jan 13 02:22:45 2024       ',
 '+---------------------------------------------------------------------------------------+',
 '| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |',
 '|-----------------------------------------+----------------------+----------------------+',
 '| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |',
 '| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |',
 '|                                         |                      |               MIG M. |',
 '|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |',
 '| N/A   62C    P8              10W /  70W |      0MiB / 15360MiB |      0%      Default |',
 '|                                         |                      |                  N/A |',
 '+-----------------------------------------+----------------------+----------------------+',
 '                      

In [8]:
model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
# To use a different branch, change revision
# For example: revision="main"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


##Inference

In [9]:
def generate_answer(prompt, model, device, temp, n_token):
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    model_input = encoded
    model_input = model_input.to(device)
    generated_ids = model.generate(
        **model_input, do_sample=True,
        max_new_tokens=n_token,
        temperature=temp,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        num_beams=4
        )
    decoded = tokenizer.batch_decode(generated_ids)

    return decoded

In [10]:
def extract_substring(input_string):
    index = input_string.find("[/INST]")
    if index != -1:
        return input_string[index + len("[/INST]"):]
    else:
        return ""

In [11]:
def few_shot_pipeline(model, system_prompt='', instruction='', sample=('', ''), device='cpu', temp=0.1, n_token=100):
    # Construct prompt
    prompt = f"""<s>[INST]<<SYS>> {system_prompt} <</SYS>> {instruction} TEXT: {sample[0]} TOPICS: {sample[1]} ANSWER: [/INST]"""
    # Generate answer

    answer = generate_answer(prompt, model, device, temp, n_token)
    # Extract the result
    # print(answer[0])
    result = extract_substring(answer[0]).strip()
    return result.replace("\n", "")


####Type I

In [12]:
prompt = """<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic. Provide only topics and nothing more
TEXT: nasa mission launch surface shuttle radar detector atmosphere field orbiter probe april ozone three pasadena
TOPICS: nasa, mission, launch, surface, shuttle, radar, detector, atmosphere, field, orbiter, probe, april, ozone, three, pasadena
ANSWER:[/INST]""".strip()


answer = generate_answer(prompt, model, 'cuda', temp=0.1, n_token=100)

print(answer[0].strip())




<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic. Provide only topics and nothing more
TEXT: nasa mission launch surface shuttle radar detector atmosphere field orbiter probe april ozone three pasadena
TOPICS: nasa, mission, launch, surface, shuttle, radar, detector, atmosphere, field, orbiter, probe, april, ozone, three, pasadena
ANSWER:[/INST]  Sure, I'd be happy to help!

Main Topic: NASA Mission

Words that do not relate to the main topic:

* three
* pasadena

The remaining words all relate to the main topic of the NASA mission:

* launch
* surface
* shuttle
* radar
* detector
* atmosphere
* field
* orbiter
* probe
* ozone</s>


In [37]:
prompt = """<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic. Provide only topics and nothing more
TEXT: window display color screen manager server application visual mode memory size client colors colormap default
TOPICS: window, display, color, screen, manager, server, application, visual, mode, memory, size, client, colors, colormap, default
ANSWER:[/INST]""".strip()

answer = generate_answer(prompt, model, 'cuda', temp=0.1, n_token=200)

print(answer[0].strip())


<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic. Provide only topics and nothing more
TEXT: window display color screen manager server application visual mode memory size client colors colormap default
TOPICS: window, display, color, screen, manager, server, application, visual, mode, memory, size, client, colors, colormap, default
ANSWER:[/INST]  Sure, I'd be happy to help!

Main Topic: Display and Color Management

Words that do not relate to the main topic:

* manager
* server
* application
* visual
* mode
* memory
* size
* client

The main topic of the text is display and color management, specifically related to the color screen and colormap. The other words in the list do not directly relate to this topic.</s>


In [29]:
prompt = """<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic. Provide only topics and nothing more
TEXT: islamic muslim ohio kent sandvik cheer islam apple newton upenn magnus tube activity rice private
TOPICS: islamic, muslim, ohio, kent, sandvik, cheer, islam, apple, newton, upenn, magnus, tube, activity, rice, private
ANSWER:[/INST]""".strip()

answer = generate_answer(prompt, model, 'cuda', temp=0.1, n_token=100)

print(answer[0].strip())


<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic. Provide only topics and nothing more
TEXT: islamic muslim ohio kent sandvik cheer islam apple newton upenn magnus tube activity rice private
TOPICS: islamic, muslim, ohio, kent, sandvik, cheer, islam, apple, newton, upenn, magnus, tube, activity, rice, private
ANSWER:[/INST]  Sure, I'd be happy to help!

Main Topic: Islam

Words that do not relate to the main topic:

* cheer
* apple
* newton
* upenn
* magnus
* tube
* activity
* rice
* private

Note that these words do not have any direct connection to the topic of Islam.</s>


In [27]:
prompt = """<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic.
TEXT: chip blend roast com favorite www http href dark keurig k-cups brew bold rich pod
TOPICS: chip, blend, roast, com, favorite, www, http, href, dark, keurig, k-cups, brew, bold, rich, pod
ANSWER:[/INST]""".strip()

answer = generate_answer(prompt, model, 'cuda', temp=0.1, n_token=200)

print(answer[0].strip())

<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic.
TEXT: chip blend roast com favorite www http href dark keurig k-cups brew bold rich pod
TOPICS: chip, blend, roast, com, favorite, www, http, href, dark, keurig, k-cups, brew, bold, rich, pod
ANSWER:[/INST]  Sure, I'd be happy to help!

After analyzing the text, the main topic appears to be "coffee" or "brewing coffee." The text mentions "chip blend roast," "favorite," "www," "href," "dark," "keurig," "k-cups," "brew," "bold," and "rich," all of which are related to coffee.

The words that do not relate to the main topic are:

* "com" (short for "commercial")
* "www" (World Wide Web)
* "href" (Hypertext Reference)

These words are not relevant to the topic of coffee or brewing coffee, and can be removed from the text.</s>


In [31]:
prompt = """<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic.
TEXT: popcorn baby pill he's movie pump kernel puff gerber popper earth's fussy theater hour pocket
TOPICS: popcorn, baby, pill, he's, movie, pump, kernel, puff, gerber, popper, earths, fussy, theater, hour, pocket
ANSWER:[/INST]""".strip()

answer = generate_answer(prompt, model, 'cuda', temp=0.1, n_token=200)

print(answer[0].strip(), end="")

<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>
You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic.
TEXT: popcorn baby pill he's movie pump kernel puff gerber popper earth's fussy theater hour pocket
TOPICS: popcorn, baby, pill, he's, movie, pump, kernel, puff, gerber, popper, earths, fussy, theater, hour, pocket
ANSWER:[/INST]  Sure, I'd be happy to help!

After analyzing the text, the main topic is clearly "popcorn." The text mentions "popcorn" multiple times and provides details about different types of popcorn, such as "baby pill" and "kernel puff."

The words that do not relate to the main topic of "popcorn" are:

* "baby"
* "pill"
* "he's"
* "movie"
* "pump"
* "kernel"
* "puff"
* "gerber"
* "popper"
* "earths"
* "fussy"
* "theater"
* "hour"
* "pocket"

These words do not contribute to the main topic of "popcorn" and can be removed from the 

In [38]:
answer[0].replace("\n", "")


"<s>[INST] <<SYS>> You are the assistant for topic modeling <<SYS>>You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic. Provide only topics and nothing moreTEXT: window display color screen manager server application visual mode memory size client colors colormap defaultTOPICS: window, display, color, screen, manager, server, application, visual, mode, memory, size, client, colors, colormap, defaultANSWER:[/INST]  Sure, I'd be happy to help!Main Topic: Display and Color ManagementWords that do not relate to the main topic:* manager* server* application* visual* mode* memory* size* clientThe main topic of the text is display and color management, specifically related to the color screen and colormap. The other words in the list do not directly relate to this topic.</s>"

In [13]:
data["all_topics"] = [", ".join(data.iloc[i]["all_topics"][2:-2].split("', '")) for i in range(len(data))]

In [14]:
test_data = data.sample(frac=1, random_state=42)

In [15]:
test_data.head(10)

Unnamed: 0.1,Unnamed: 0,wordset,all_topics,unsuitable_topics,suitable_topics
1138,1138,scsi disk controller card floppy port fast spe...,"scsi, disk, controller, card, floppy, port, fa...","['jumper', 'board', 'floppy', 'mode', 'bios']","['disk', 'scsi', 'port', 'transfer', 'interfac..."
1628,1628,size baby quickly daughter thanks night start ...,"size, baby, quickly, daughter, thanks, night, ...",['since'],"['month', 'size', 'move', 'baby', 'daughter', ..."
1606,1606,snack bar cooky cereal butter fruit peanut nut...,"snack, bar, cooky, cereal, butter, fruit, pean...",['butter'],"['cracker', 'fruit', 'bar', 'cereal', 'strawbe..."
1977,1977,crime criminal kill death self public murder c...,"crime, criminal, kill, death, self, public, mu...","['situation', 'indiv']","['criminal', 'firearm', 'self', 'public', 'com..."
1526,1526,sweet bar crunchy granola yogurt raisin flake ...,"sweet, bar, crunchy, granola, yogurt, raisin, ...",['cluster'],"['raisin', 'wafer', 'bar', 'flake', 'granola',..."
218,218,organic natural real expensive high syrup whea...,"organic, natural, real, expensive, high, syrup...","['pleased', 'corn', 'wheat', 'maple', 'organic...",[]
1398,1398,ground wire connect circuit panel outlet wirin...,"ground, wire, connect, circuit, panel, outlet,...","['math', 'swap', 'electrical', 'duke', 'screw'...","['outlet', 'circuit', 'neutral', 'panel', 'cab..."
252,252,bible church christ faith matthew catholic scr...,"truth, spirit, word, faith, christ, church, sc...","['scripture', 'word']","['father', 'matthew', 'holy', 'heaven', 'catho..."
1922,1922,wire ground outlet wiring neutral math panel c...,"wire, ground, outlet, wiring, neutral, math, p...","['outlet', 'math', 'neutral', 'circuit', 'pane...",['ground']
643,643,review smell bad know i'll since people disapp...,"review, smell, bad, know', ""i'll"", 'since, peo...",['seal'],"['open', 'smell', 'disappointed', 'agree', 'si..."


In [16]:
torch.cuda.empty_cache()

In [17]:
SYSTEM_PROMPT = "You are an assistant for topic modeling."
INSTRUCTION = "You will receive a TEXT and TOPICS. You must analyze the TEXT, identify the main topic of the TEXT, and select all words from the TOPICS that do not relate to the main topic."

SAMPLE = (test_data.iloc[2]["wordset"], test_data.iloc[2]["all_topics"])
print(SAMPLE)

res = few_shot_pipeline(model, SYSTEM_PROMPT, INSTRUCTION, SAMPLE, device="cuda", n_token=200)

res.strip()

('snack bar cooky cereal butter fruit peanut nut cookie cracker almond oatmeal granola strawberry crunch', 'snack, bar, cooky, cereal, butter, fruit, peanut, nut, cookie, cracker, almond, oatmeal, granola, strawberry, crunch')


'Sure, I\'d be happy to help!After analyzing the text, the main topic is clearly "snacks." The text mentions several types of snacks, including cookies, cereal, and granola.Here are the words from the topics that do not relate to the main topic of "snacks":* bar (not related to snacks)* butter (not related to snacks)* fruit (not related to snacks)* peanut (not related to snacks)* nut (not related to snacks)* almond (not related to snacks)* oatmeal (not related to snacks)The only word from the topics that is related to the main topic of "snacks" is "cookie."Therefore, the answer is:[MAIN TOPIC] = snacks[RELATED WOR'

In [19]:
test_data.iloc[482]["wordset"]

'window display image server application graphics code user screen color motif mouse running format font'

In [None]:
# Create file
RES_FILE_NAME = "data_irrelative_topics_llama_13b.csv"
RES_FILE = DATA_DIR.joinpath(RES_FILE_NAME)
# Remove old file
# if os.path.exists(RES_FILE):
#     os.remove(RES_FILE)

for i in tqdm(range(483, len(test_data))):
# for i in tqdm(range(6)):
    torch.cuda.empty_cache()
    SAMPLE = (test_data.iloc[i]["wordset"], test_data.iloc[i]["all_topics"])
    res_true = test_data.iloc[i]["unsuitable_topics"]
    res_model = few_shot_pipeline(
        model,
        INSTRUCTION,
        SAMPLE,
        device="cuda",
        temp=0.3,
        n_token=100
    )
    with open(RES_FILE, 'a') as file:
      file.write(f"{SAMPLE[0]};{SAMPLE[1]};{res_true};{res_model}\n")


 15%|█▌        | 289/1880 [1:13:31<6:45:16, 15.28s/it]