In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm
from fastbm25 import fastbm25
import re
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

## Load data

Load the H&M data and generate item descriptors as a new article feature.

In [3]:
def article_id_str_to_int(series):
    return series.astype('int32')

In [4]:
BASE_PATH = '../../../data/'

transactions_original = pd.read_parquet(BASE_PATH + 'parquet/transactions_train.parquet')
customers_original = pd.read_parquet(BASE_PATH + 'parquet/customers.parquet')
articles_original = pd.read_csv(BASE_PATH + 'original/articles.csv')

transactions = transactions_original.copy()
customers = customers_original.copy()
articles = articles_original.copy()
articles.article_id = article_id_str_to_int(articles.article_id)

In [5]:
# add descriptor feature
articles['descriptor'] = articles.prod_name + ' (' + articles.colour_group_name + ' ' + articles.graphical_appearance_name + ' ' + articles.product_type_name + ')'
articles.descriptor

0                          Strap top (Black Solid Vest top)
1                          Strap top (White Solid Vest top)
2                 Strap top (1) (Off White Stripe Vest top)
3                       OP T-shirt (Idro) (Black Solid Bra)
4                       OP T-shirt (Idro) (White Solid Bra)
                                ...                        
105537    5pk regular Placement1 (Black Placement print ...
105538             SPORT Malaga tank (Black Solid Vest top)
105539                  Cartwheel dress (Black Solid Dress)
105540             CLAIRE HAIR CLAW (Black Solid Hair clip)
105541                 Lounge dress (Off White Solid Dress)
Name: descriptor, Length: 105542, dtype: object

# Fine tuning

In [8]:
def fine_tune_gpt2(train_file, output_dir):
    # Load GPT-2 model and tokenizer
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.to(mps_device)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Load training dataset
    train_dataset = TextDataset(
        tokenizer=tokenizer, file_path=train_file, block_size=128
    )
    # Create data collator for language modeling
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # Set training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=5,
        per_device_train_batch_size=4,
        save_steps=10_000,
    )
    # Train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    trainer.train()
    
    # Save the fine-tuned model
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

## Create examples for fine tuning

Get 1000 random users with at least 3 purchases. For each, generate a prompt and expected response based on the item descriptors.

In [464]:
a = transactions.drop_duplicates(['customer_id', 'article_id']).groupby('customer_id').article_id.count()
b = a[a >= 2].index
sample_customers = np.random.choice(b, 1000, replace=False)
sample_transactions = transactions[transactions["customer_id"].isin(sample_customers)]

In [499]:
def generate_prompt(history):
    last_item = history[-1]
    previous_items = history[-6:-1]
    return f"Previously, a customer has bought the following items: <{'>, <'.join(previous_items)}>. In the future, this customer will want to buy <{last_item}><|endoftext|>"

a = pd.merge(sample_transactions, articles[['article_id', 'descriptor']])
prompts = a.groupby('customer_id').descriptor.apply(list).apply(generate_prompt)
with open(BASE_PATH + 'LLM/prompts5.txt','w') as file:
    for prompt in prompts.values:
        file.write(prompt + "\n" )

In [None]:
fine_tune_gpt2(BASE_PATH + "LLM/prompts5.txt", BASE_PATH + "LLM/output5")

# Evaluation

## Create testing data

Get 1000 random users with at least 3 purchases and who were not used for fine tuning. For each, generate a prompt and record the expected response in an array.

In [501]:
a = transactions.drop_duplicates(['customer_id', 'article_id']).groupby('customer_id').article_id.count()
b = a[a>2].index
test_customers = np.random.choice(list(set(b) - set(sample_customers)), 1000, replace=False)
test_transactions = transactions[transactions["customer_id"].isin(test_customers)]

In [522]:
def generate_tests_prompt(history):
    previous_items = history[-6:-1]
    return f"Previously, a customer has bought the following items: <{'>, <'.join(previous_items)}>. In the future, this customer will want to buy <"

a = pd.merge(test_transactions, articles[['article_id', 'descriptor']])
test_prompts = a.groupby('customer_id').descriptor.apply(list).apply(generate_tests_prompt).values
test_history = a.groupby('customer_id').descriptor.apply(list).values

test_truth = a.groupby('customer_id').article_id.last().values

## Inference

In [510]:
MODEL_PATH = BASE_PATH + "LLM/output5"

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained(MODEL_PATH, pad_token_id=tokenizer.eos_token_id)

In [515]:
def generate(p):
    prompt = tokenizer.encode(p, return_tensors="pt", truncation=True, max_length=999)
    results = model.generate(prompt, max_new_tokens=100, num_return_sequences=3, num_beams=5)
    for q in results:
        query = tokenizer.decode(q[len(prompt[0]):], skip_special_tokens=True)[:-1]
        yield query

In [539]:
def tokenize(string):
    return re.sub(r"[^0-9a-zA-Z ]","",string.lower()).split(' ')
    
corpus = list(articles.descriptor.apply(tokenize).values)
bm25 = fastbm25(corpus)

In [541]:
success = set()
for i, (p, t) in enumerate(zip(test_history, test_truth)):
    print(i)
    for q in p[:-1]:
        a = bm25.top_k_sentence(tokenize(q), k=10)
        if t in {articles.iloc[idx].article_id for (_, idx, _) in a}:
            success.add(i)
            print('YESS', i)

0
1
2
3
4
YESS 4
5
6
7
8
YESS 8
9
10
11
12
13
YESS 13
14
15
16
17
18
19
YESS 19
20
YESS 20
21
22
23
YESS 23
24
YESS 24
25
YESS 25
26
27
28
YESS 28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
YESS 42
43
44
45
46
47
48
49
YESS 49
50
51
52
YESS 52
YESS 52
YESS 52
53
54
55
56
57
58
59
60
YESS 60
61
62
63
64
65
66
67
68
69
YESS 69
70
71
72
73
74
75
YESS 75
YESS 75
76
YESS 76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
YESS 95
96
97
98
99
100
101
102
103
104
105
106
107
108
YESS 108
YESS 108
109
110
YESS 110
111
112
113
114
115
116
YESS 116
YESS 116
117
118
119
120
YESS 120
121
122
123
YESS 123
124
YESS 124
YESS 124
YESS 124
125
126
YESS 126
YESS 126
YESS 126
YESS 126
127
128
129
YESS 129
130
131
132
133
134
135
YESS 135
136
YESS 136
137
138
YESS 138
139
140
141
142
YESS 142
YESS 142
143
144
YESS 144
145
YESS 145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
YESS 165
YESS 165
166
167
168
169
170
171
172
YESS 172
173
YESS 173
YESS 173
YESS 173
174


In [11]:
# How often was the last purchased item retrieved?
len(success)/1000