In [129]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import random
import polars as pl
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [161]:
class EkstraBladetDataset(Dataset):

    def __init__(self, create_prompt, tokenizer, split="train", T=5):
        
        # Download the dataset from huggingface
        self.behaviors = load_dataset('Wouter01/RecSys_demo', 'behaviors', cache_dir="demo_data")[split]
        self.articles = load_dataset('Wouter01/RecSys_demo', 'articles', cache_dir="demo_data")[split].to_pandas()
        self.history = load_dataset('Wouter01/RecSys_demo', 'history', cache_dir="demo_data")[split].to_pandas()

        # Set fast lookup for identifier keys
        self.history.set_index("user_id", inplace=True)
        self.articles.set_index("article_id", inplace=True)

        self.T = T  # Number of previous clicked articles to consider
        self.create_prompt = create_prompt  # Function to create a prompt from the data
        self.tokenize = tokenizer

    def __len__(self):
        return len(self.behaviors)

    def __getitem__(self, idx):
        # Every item consits of a positive and negative sample
        behavior = self.behaviors[idx]

        # Pick random positive and negative samples
        clicked_articles = behavior["article_ids_clicked"]
        unclicked_articles = [article for article in behavior["article_ids_inview"] if article not in clicked_articles]
        pos_sample = random.choice(clicked_articles)
        neg_sample = random.choice(unclicked_articles)

        # Get the history of the user
        user_id = behavior["user_id"]
        history = self.history.loc[user_id]

        # Get the T latest clicked articles by the user
        old_clicks = history["article_id_fixed"]
        old_clicks = old_clicks[-min(len(old_clicks), self.T):]

        # Get the article information
        titles, subtitles = [], []  # last two are pos and neg samples
        for article_id in old_clicks.tolist() + [pos_sample, neg_sample]:
            article = self.articles.loc[article_id]
            titles.append(article["title"])
            subtitles.append(article["subtitle"])

        assert len(titles) == self.T + 2 and len(titles) == len(subtitles)
        
        title_pos, title_neg = titles[-2], titles[-1]
        subtitle_pos, subtitle_neg = subtitles[-2], subtitles[-1]
        titles, subtitles = titles[:-2], subtitles[:-2]

        # Create the prompts
        pos_prompt = self.create_prompt(titles, subtitles, title_pos, subtitle_pos)
        neg_prompt = self.create_prompt(titles, subtitles, title_neg, subtitle_neg)

        return (self.tokenize(pos_prompt, padding='max_length', max_length=2048, truncation=True, return_tensors='pt'), 
                self.tokenize(neg_prompt, padding='max_length', max_length=2048, truncation=True, return_tensors='pt'))
    
def create_prompt(titles, subtitles, title, subtitle):
    prompt = f"Given the following titles and subtitles of previously read articles:\n"
    for i, (t, s) in enumerate(zip(titles, subtitles)):
        prompt += f"Article {i+1}:\nTitle: {t}\nSubtitle: {s}\n\n"
    prompt += f"Is the user likely to click on an articles with title {title} and subtitle {subtitle}? (yes/no)\n"
    return prompt

    
tokenizer = AutoTokenizer.from_pretrained("google/mt5-base", model_max_length=2048)

data = EkstraBladetDataset(create_prompt, tokenizer)





In [170]:
def collate_fn(batch):
    tokenized_pos = [prompt[0] for prompt in batch]
    tokenized_neg = [prompt[1] for prompt in batch]

    pos_input_ids = torch.cat([item['input_ids'] for item in tokenized_pos], dim=0)
    pos_attention_mask = torch.cat([item['attention_mask'] for item in tokenized_pos], dim=0)

    neg_input_ids = torch.cat([item['input_ids'] for item in tokenized_neg], dim=0)
    neg_attention_mask = torch.cat([item['attention_mask'] for item in tokenized_neg], dim=0)

    return {
        'pos_input_ids': pos_input_ids,
        'pos_attention_mask': pos_attention_mask,
        'neg_input_ids': neg_input_ids,
        'neg_attention_mask': neg_attention_mask
    }

In [172]:
data_loader = DataLoader(data, batch_size=64, collate_fn=collate_fn, shuffle=True)

In [174]:
for batch in data_loader:
    print(batch['pos_input_ids'].shape, batch['neg_input_ids'].shape)
    break

torch.Size([64, 4096]) torch.Size([64, 4096])


In [None]:
from transformers import AutoTokenizer

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained('your-model-name')

class EkstraBladetDataset(Dataset):
    def __init__(self, create_prompt, tokenizer, split="train", T=5, n_pairs=1):
        self.behaviors = load_dataset('Wouter01/RecSys_demo', 'behaviors', cache_dir="demo_data")[split]
        self.articles = load_dataset('Wouter01/RecSys_demo', 'articles', cache_dir="demo_data")[split].to_pandas()
        self.history = load_dataset('Wouter01/RecSys_demo', 'history', cache_dir="demo_data")[split].to_pandas()
        self.history.set_index("user_id", inplace=True)
        self.articles.set_index("article_id", inplace=True)
        self.T = T
        self.n_pairs = n_pairs
        self.create_prompt = create_prompt
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.behaviors)

    def __getitem__(self, idx):
        behavior = self.behaviors[idx]
        clicked_articles = behavior["article_ids_clicked"]
        unclicked_articles = [article for article in behavior["article_ids_inview"] if article not in clicked_articles]
        user_id = behavior["user_id"]
        history = self.history.loc[user_id]
        old_clicks = history["article_id_fixed"]
        old_clicks = old_clicks[-min(len(old_clicks), self.T):]

        pos_prompts, neg_prompts = [], []

        for _ in range(self.n_pairs):
            pos_sample = random.choice(clicked_articles)
            neg_sample = random.choice(unclicked_articles)
            titles, subtitles = [], []

            for article_id in old_clicks.tolist() + [pos_sample, neg_sample]:
                article = self.articles.loc[article_id]
                titles.append(article["title"])
                subtitles.append(article["subtitle"])

            assert len(titles) == self.T + 2 and len(titles) == len(subtitles)

            title_pos, title_neg = titles[-2], titles[-1]
            subtitle_pos, subtitle_neg = subtitles[-2], subtitles[-1]
            titles, subtitles = titles[:-2], subtitles[:-2]

            pos_prompt = self.create_prompt(titles, subtitles, title_pos, subtitle_pos)
            neg_prompt = self.create_prompt(titles, subtitles, title_neg, subtitle_neg)

            pos_prompts.append(pos_prompt)
            neg_prompts.append(neg_prompt)

        return pos_prompts, neg_prompts

    def tokenize_prompt(self, prompt):
        return self.tokenizer(prompt, padding='max_length', truncation=True, return_tensors='pt')

def custom_collate_fn(batch):
    pos_prompts = [item[0] for item in batch]
    neg_prompts = [item[1] for item in batch]

    pos_prompts = [prompt for sublist in pos_prompts for prompt in sublist]
    neg_prompts = [prompt for sublist in neg_prompts for prompt in sublist]

    tokenized_pos = [tokenizer(prompt, padding='max_length', truncation=True, return_tensors='pt') for prompt in pos_prompts]
    tokenized_neg = [tokenizer(prompt, padding='max_length', truncation=True, return_tensors='pt') for prompt in neg_prompts]

    pos_input_ids = torch.cat([item['input_ids'] for item in tokenized_pos], dim=0)
    pos_attention_mask = torch.cat([item['attention_mask'] for item in tokenized_pos], dim=0)

    neg_input_ids = torch.cat([item['input_ids'] for item in tokenized_neg], dim=0)
    neg_attention_mask = torch.cat([item['attention_mask'] for item in tokenized_neg], dim=0)

    return {
        'pos_input_ids': pos_input_ids,
        'pos_attention_mask': pos_attention_mask,
        'neg_input_ids': neg_input_ids,
        'neg_attention_mask': neg_attention_mask
    }

data = EkstraBladetDataset(create_prompt, tokenizer, n_pairs=5)
data_loader = DataLoader(data, batch_size=100, shuffle=True, collate_fn=custom_collate_fn)

# Example usage
for batch in data_loader:
    print(batch['pos_input_ids'].shape, batch['neg_input_ids'].shape)


In [126]:
for batch in data_loader:
    print(batch)
    break

{'pos_prompt': ["Given the following titles and subtitles of previously read articles:\nArticle 1:\nTitle: Kvinde dræbt og kørt i vandet: Nu starter retssag mod kæresten\nSubtitle: Januar sidste år blev en 24-årig kvinde fundet dræbt i en bil, der var kørt i vandet ved Amager Strandpark. Onsdag begynder retssagen mod drabsofferets 25-årige kæreste\n\nArticle 2:\nTitle: Megan Fox indrømmer: Har aldrig elsket sin krop\nSubtitle: Skuespilleren viser sin krop frem som badetøjsmodel for Sports Illustrated, men hun har ikke et ukompliceret forhold til den\n\nArticle 3:\nTitle: Slår op med kæresten\nSubtitle: Billie Eilish er nu single igen\n\nArticle 4:\nTitle: Bortført pige fundet efter seks år: Genkendt i Netflix-serie\nSubtitle: En seer af Netflix-serien 'Unsolved Mysteries' har nu medvirket til, at en forsvunden pige er blevet genforenet med sin far\n\nArticle 5:\nTitle: Modtaget til stående klapsalver i Cannes\nSubtitle: Da Johnny Depp gjorde comeback ved international filmfestival, var

In [114]:
stop = 10000
for i, j in data:
    # print(i)
    # print('------------------------')
    # print(j)
    stop -= 1
    if stop == 0:
        break

In [115]:
10000 / 6

1666.6666666666667

In [37]:
a = [1,2,3,4]
random.choice(a)

3

In [48]:
next(iter(data))

{'impression_id': 48401,
 'impression_time': datetime.datetime(2023, 5, 21, 21, 6, 50),
 'read_time': 21.0,
 'article_ids_inview': [9774516,
  9771051,
  9770028,
  9775402,
  9774461,
  9759544,
  9773947,
  9142581,
  9775331,
  9775371,
  9759966],
 'article_ids_clicked': [9759966],
 'user_id': 22779,
 'session_id': 21}

In [None]:
# pandas brrrrrrrrrrr

In [45]:
history = load_dataset('Wouter01/RecSys_demo', 'history', cache_dir="demo_data")["train"]

In [50]:
pl_history = pl.DataFrame(history.to_pandas())

In [62]:
for _ in range(100000):
    a = pl_history.filter(pl.col("user_id") == 22779)

In [52]:
a

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,list[datetime[μs]],list[f32],list[i32],list[f32]
22779,"[2023-05-17 15:50:15, 2023-05-17 15:51:08, … 2023-05-18 06:26:39]","[46.0, 100.0, … 15.0]","[9770333, 9769641, … 9770541]","[48.0, 148.0, … 7.0]"


In [55]:
pd_history = history.to_pandas()

In [56]:
pd_history.set_index("user_id", inplace=True)

In [92]:
for _ in range(1000):
    a = pd_history.loc[22779]

In [105]:
for i in a["article_id_fixed"][-5:].tolist() + [0]:
    print(i)

9769641
9769641
9770989
9770541
9770541
0


In [74]:
a["article_id_fixed"]  # duplicates in clicked articles

array([9770333, 9769641, 9769888, 9769641, 9770328, 9769641, 9769641,
       9770989, 9770541, 9770541], dtype=int32)

In [76]:
a

impression_time_fixed      [2023-05-17T15:50:15.000000, 2023-05-17T15:51:...
scroll_percentage_fixed    [46.0, 100.0, 92.0, 100.0, 100.0, 100.0, 100.0...
article_id_fixed           [9770333, 9769641, 9769888, 9769641, 9770328, ...
read_time_fixed            [48.0, 148.0, 18.0, 5.0, 83.0, 8.0, 15.0, 10.0...
Name: 22779, dtype: object

In [77]:
len(a["article_id_fixed"] )

10

In [102]:
a["article_id_fixed"].tolist() + [0]

[9770333,
 9769641,
 9769888,
 9769641,
 9770328,
 9769641,
 9769641,
 9770989,
 9770541,
 9770541,
 0]

In [95]:
for i in a["article_id_fixed"] + [0]:
    print(i)

9770333
9769641
9769888
9769641
9770328
9769641
9769641
9770989
9770541
9770541


In [82]:
articles = load_dataset('Wouter01/RecSys_demo', 'articles', cache_dir="demo_data")["train"].to_pandas()

In [83]:
articles.set_index("article_id", inplace=True)

In [89]:
for _ in range(10000):
    a= articles.loc[9770989]