In [1]:
from collections import defaultdict

In [6]:
!ls

RecSys	clean  demo_data  main.py  new


In [7]:
import sys
sys.path.append('../code')
sys.path.append('RecSys/code')  # for lightning ai

In [9]:
from torch.utils.data import Dataset
from datasets import load_dataset
import random
from prompt_templates import create_prompt_titles
from tqdm import tqdm

In [10]:
from dataclasses import dataclass

@dataclass
class Args:
    dataset = "demo"
    batch_size = 1
    num_workers = 4
    T = 4
    datafraction = 1.0
args = Args()

In [11]:
from enum import Enum

class Type(Enum):
    IDS = 0
    CATEGORIES = 1
    PUBLISHED_TIME = 2
    LAST_MODIFIED_TIME = 3

In [15]:
class EkstraBladetDataset(Dataset):

    def __init__(self, args, create_prompt, type, split="train"):
        
        # Download the dataset from huggingface
        if split == "test":
            self.behaviors = load_dataset(f'Wouter01/testbehaviors', cache_dir=f"../../testbehaviors_data")["train"]
            self.articles = load_dataset(f'Wouter01/testarticles', cache_dir=f"../../testarticles_data")["train"].to_pandas()
            self.history = load_dataset(f'Wouter01/testhistory', cache_dir=f"../../testhistory_data")["train"].to_pandas()
        else:
            self.behaviors = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'behaviors', cache_dir=f"../{args.dataset}_data")[split]
            # self.articles = load_dataset(f'Wouter01/testdemoarticles', cache_dir=f"../../{args.dataset}_data2")["train"].to_pandas()  #only use this for last modified time
            self.articles = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'articles', cache_dir=f"../{args.dataset}_data")["train"].to_pandas()
            self.history = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'history', cache_dir=f"../{args.dataset}_data")[split].to_pandas()

        # Set fast lookup for identifier keys
        self.history.set_index("user_id", inplace=True)
        self.articles.set_index("article_id", inplace=True)

        self.T = args.T  # Number of previous clicked articles to consider
        self.create_prompt = create_prompt  # Function to create a prompt from the data
        self.type = type
        self.datafraction = args.datafraction
        self.split = split

    def __len__(self):
        return int(self.datafraction*len(self.behaviors))

    def __getitem__(self, idx):
        # Every item consits of a positive and negative sample
        behavior = self.behaviors[idx]

        # Get the inview articles
        inview_articles = behavior["article_ids_inview"]

        # Get the history of the user
        user_id = behavior["user_id"]
        history = self.history.loc[user_id]

        # Get the T latest clicked articles by the user
        old_clicks = history["article_id_fixed"]
        old_clicks = old_clicks[-min(len(old_clicks), self.T):]

        # Pick random positive and negative samples
        clicked_articles = behavior["article_ids_clicked"]
        unclicked_articles = [article for article in inview_articles if article not in clicked_articles]
        pos_sample = random.choice(clicked_articles)

        # If all documents clicked just treat one of them as a negative sample
        if len(unclicked_articles) == 0:
            unclicked_articles = clicked_articles
        neg_sample = random.choice(unclicked_articles)

        if self.type == Type.IDS:
            old_clicks = old_clicks.tolist()
        if self.type == Type.CATEGORIES:
            old_clicks = [self.articles.loc[c]["category_str"] for c in old_clicks]
            pos_sample = self.articles.loc[pos_sample]["category_str"]
            neg_sample = self.articles.loc[neg_sample]["category_str"]
        elif self.type == Type.PUBLISHED_TIME:
            old_clicks = [self.articles.loc[c]["published_time"] for c in old_clicks]
            pos_sample = self.articles.loc[pos_sample]["published_time"]
            neg_sample = self.articles.loc[neg_sample]["published_time"]
        elif self.type == Type.LAST_MODIFIED_TIME:
            # make sure to comment out the line with the articles dataset above
            try:
                old_clicks = [self.articles.loc[c]["last_modified_time"] for c in old_clicks]
                pos_sample = self.articles.loc[pos_sample]["last_modified_time"]
                neg_sample = self.articles.loc[neg_sample]["last_modified_time"]
            except:
                raise ValueError("make sure to comment out the line with the articles dataset above")

        return old_clicks, pos_sample, neg_sample

In [99]:
def train(data, backoff=False, generalize=False, penalize_negative=0.0):
    model = defaultdict(lambda: defaultdict(int))
    random.seed(42)
    for old_clicks, pos_sample, neg_sample in tqdm(data):
        history_length = len(old_clicks)
        for length in range(history_length, 0, -1):
            if generalize:
                # create a pattern eg:
                #  [1,2,4,2,1,4] --> [1,2,3,2,1,3]
                #  [7,5,4,5,7,5] --> [1,2,3,2,1,3]
                p2i = {}
                i = 1
                for p in old_clicks[-length:]:
                    if p not in p2i:
                        p2i[p] = i
                        i += 1
                old_clicks = tuple([p2i[p] for p in old_clicks[-length:]])
            else:
                old_clicks = tuple(old_clicks[-length:])
            model[old_clicks][pos_sample] += 1
            model[old_clicks][neg_sample] -= penalize_negative
            if not backoff:
                break
    return model

def evaluate(model, data, backoff=False, generalize=False):
    random.seed(42) 
    correct = tie = lose = total = 0
    for old_clicks, pos_sample, neg_sample in tqdm(data):
        history_length = len(old_clicks)
        for length in range(history_length, 0, -1):
            if generalize:
                p2i = {}
                i = 1
                for p in old_clicks[-length:]:
                    if p not in p2i:
                        p2i[p] = i
                        i += 1
                old_clicks = tuple([p2i[p] for p in old_clicks[-length:]])
            else:
                old_clicks = tuple(old_clicks[-length:])
            if model[old_clicks][pos_sample] > model[old_clicks][neg_sample]:
                correct += 1
            elif model[old_clicks][pos_sample] < model[old_clicks][neg_sample]:
                lose += 1
            else:  # Tie
                if backoff:
                    if length == 1:
                        tie += 1
                    continue
                tie += 1
                break
            break
        total += 1
    return correct, tie, lose, total

In [104]:
for type in (Type.IDS, Type.CATEGORIES, Type.PUBLISHED_TIME):
    train_data = EkstraBladetDataset(args, create_prompt_titles, type, split="train")
    val_data = EkstraBladetDataset(args, create_prompt_titles, type, split="validation")
    for backoff in (False, True):
        for generalize in (False, True):
            for penalize_negative in (0.0, 0.1, 0.5):
                print(f"Type: {type}, Backoff: {backoff}, Generalize: {generalize}, Penalize negative: {penalize_negative}")
                model = train(train_data, backoff, generalize, penalize_negative)
                correct, tie, lose, total = evaluate(model, val_data, backoff, generalize)
                print(f"Validation: Correct: {correct}, Tie: {tie}, Lose: {lose}, Total: {total}, Accuracy: {(correct+0.5*tie)/total}")

Type: Type.IDS, Backoff: False, Generalize: False, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:03<00:00, 6345.59it/s]
100%|██████████| 25356/25356 [00:03<00:00, 6469.58it/s]


Validation: Correct: 0, Tie: 25356, Lose: 0, Total: 25356, Accuracy: 0.5
Type: Type.IDS, Backoff: False, Generalize: False, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:03<00:00, 6547.83it/s]
100%|██████████| 25356/25356 [00:03<00:00, 6392.66it/s]


Validation: Correct: 0, Tie: 25356, Lose: 0, Total: 25356, Accuracy: 0.5
Type: Type.IDS, Backoff: False, Generalize: False, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:03<00:00, 6567.72it/s]
100%|██████████| 25356/25356 [00:03<00:00, 6386.70it/s]


Validation: Correct: 0, Tie: 25356, Lose: 0, Total: 25356, Accuracy: 0.5
Type: Type.IDS, Backoff: False, Generalize: True, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:03<00:00, 6510.16it/s]
100%|██████████| 25356/25356 [00:03<00:00, 6411.45it/s]


Validation: Correct: 749, Tie: 22556, Lose: 2051, Total: 25356, Accuracy: 0.47432560340747754
Type: Type.IDS, Backoff: False, Generalize: True, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:03<00:00, 6542.35it/s]
100%|██████████| 25356/25356 [00:04<00:00, 6256.89it/s]


Validation: Correct: 2345, Tie: 21134, Lose: 1877, Total: 25356, Accuracy: 0.5092285849503077
Type: Type.IDS, Backoff: False, Generalize: True, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:03<00:00, 6450.67it/s]
100%|██████████| 25356/25356 [00:04<00:00, 6237.15it/s]


Validation: Correct: 2762, Tie: 21267, Lose: 1327, Total: 25356, Accuracy: 0.5282970500078876
Type: Type.IDS, Backoff: True, Generalize: False, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:03<00:00, 6425.49it/s]
100%|██████████| 25356/25356 [00:04<00:00, 6254.36it/s]


Validation: Correct: 5, Tie: 25343, Lose: 8, Total: 25356, Accuracy: 0.4999408424041647
Type: Type.IDS, Backoff: True, Generalize: False, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:03<00:00, 6420.64it/s]
100%|██████████| 25356/25356 [00:03<00:00, 6408.39it/s]


Validation: Correct: 9, Tie: 25338, Lose: 9, Total: 25356, Accuracy: 0.5
Type: Type.IDS, Backoff: True, Generalize: False, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:03<00:00, 6529.35it/s]
100%|██████████| 25356/25356 [00:04<00:00, 5815.95it/s]


Validation: Correct: 9, Tie: 25338, Lose: 9, Total: 25356, Accuracy: 0.5
Type: Type.IDS, Backoff: True, Generalize: True, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:03<00:00, 6342.13it/s]
100%|██████████| 25356/25356 [00:03<00:00, 6354.40it/s]


Validation: Correct: 867, Tie: 21890, Lose: 2599, Total: 25356, Accuracy: 0.4658463480044171
Type: Type.IDS, Backoff: True, Generalize: True, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:03<00:00, 6315.92it/s]
100%|██████████| 25356/25356 [00:03<00:00, 6398.59it/s]


Validation: Correct: 2758, Tie: 20494, Lose: 2104, Total: 25356, Accuracy: 0.5128963558920966
Type: Type.IDS, Backoff: True, Generalize: True, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:03<00:00, 6452.77it/s]
100%|██████████| 25356/25356 [00:04<00:00, 6192.07it/s]


Validation: Correct: 3245, Tie: 20601, Lose: 1510, Total: 25356, Accuracy: 0.5342128095914183
Type: Type.CATEGORIES, Backoff: False, Generalize: False, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:13<00:00, 1890.98it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1835.22it/s]


Validation: Correct: 5797, Tie: 14551, Lose: 5008, Total: 25356, Accuracy: 0.5155584477046853
Type: Type.CATEGORIES, Backoff: False, Generalize: False, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:13<00:00, 1885.19it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1822.32it/s]


Validation: Correct: 6201, Tie: 13801, Lose: 5354, Total: 25356, Accuracy: 0.5167021612241679
Type: Type.CATEGORIES, Backoff: False, Generalize: False, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:13<00:00, 1857.76it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1838.41it/s]


Validation: Correct: 5993, Tie: 14221, Lose: 5142, Total: 25356, Accuracy: 0.516781038018615
Type: Type.CATEGORIES, Backoff: False, Generalize: True, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:13<00:00, 1773.56it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1845.78it/s]


Validation: Correct: 10528, Tie: 4922, Lose: 9906, Total: 25356, Accuracy: 0.5122653415365199
Type: Type.CATEGORIES, Backoff: False, Generalize: True, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:13<00:00, 1826.39it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1864.29it/s]


Validation: Correct: 10529, Tie: 4921, Lose: 9906, Total: 25356, Accuracy: 0.5122850607351317
Type: Type.CATEGORIES, Backoff: False, Generalize: True, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:13<00:00, 1853.72it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1908.70it/s]


Validation: Correct: 10524, Tie: 5196, Lose: 9636, Total: 25356, Accuracy: 0.5175106483672504
Type: Type.CATEGORIES, Backoff: True, Generalize: False, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:13<00:00, 1851.94it/s]
100%|██████████| 25356/25356 [00:14<00:00, 1809.03it/s]


Validation: Correct: 10869, Tie: 5024, Lose: 9463, Total: 25356, Accuracy: 0.5277251932481464
Type: Type.CATEGORIES, Backoff: True, Generalize: False, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:13<00:00, 1864.06it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1820.16it/s]


Validation: Correct: 10839, Tie: 5016, Lose: 9501, Total: 25356, Accuracy: 0.5263842877425462
Type: Type.CATEGORIES, Backoff: True, Generalize: False, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:13<00:00, 1887.02it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1823.59it/s]


Validation: Correct: 10806, Tie: 5103, Lose: 9447, Total: 25356, Accuracy: 0.5267983909133933
Type: Type.CATEGORIES, Backoff: True, Generalize: True, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:14<00:00, 1703.25it/s]
100%|██████████| 25356/25356 [00:18<00:00, 1361.21it/s]


Validation: Correct: 10530, Tie: 4919, Lose: 9907, Total: 25356, Accuracy: 0.5122850607351317
Type: Type.CATEGORIES, Backoff: True, Generalize: True, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:12<00:00, 1967.17it/s]
100%|██████████| 25356/25356 [00:10<00:00, 2348.37it/s]


Validation: Correct: 10529, Tie: 4919, Lose: 9908, Total: 25356, Accuracy: 0.5122456223379082
Type: Type.CATEGORIES, Backoff: True, Generalize: True, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:10<00:00, 2272.77it/s]
100%|██████████| 25356/25356 [00:10<00:00, 2364.01it/s]


Validation: Correct: 10665, Tie: 4919, Lose: 9772, Total: 25356, Accuracy: 0.5176092443603092
Type: Type.PUBLISHED_TIME, Backoff: False, Generalize: False, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:12<00:00, 1999.63it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1918.12it/s]


Validation: Correct: 0, Tie: 25356, Lose: 0, Total: 25356, Accuracy: 0.5
Type: Type.PUBLISHED_TIME, Backoff: False, Generalize: False, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:12<00:00, 2003.60it/s]
100%|██████████| 25356/25356 [00:12<00:00, 1997.94it/s]


Validation: Correct: 0, Tie: 25356, Lose: 0, Total: 25356, Accuracy: 0.5
Type: Type.PUBLISHED_TIME, Backoff: False, Generalize: False, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:11<00:00, 2209.70it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1891.80it/s]


Validation: Correct: 0, Tie: 25356, Lose: 0, Total: 25356, Accuracy: 0.5
Type: Type.PUBLISHED_TIME, Backoff: False, Generalize: True, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:11<00:00, 2173.08it/s]
100%|██████████| 25356/25356 [00:12<00:00, 2080.43it/s]


Validation: Correct: 749, Tie: 22556, Lose: 2051, Total: 25356, Accuracy: 0.47432560340747754
Type: Type.PUBLISHED_TIME, Backoff: False, Generalize: True, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:11<00:00, 2075.33it/s]
100%|██████████| 25356/25356 [00:11<00:00, 2182.14it/s]


Validation: Correct: 2345, Tie: 21134, Lose: 1877, Total: 25356, Accuracy: 0.5092285849503077
Type: Type.PUBLISHED_TIME, Backoff: False, Generalize: True, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:11<00:00, 2103.81it/s]
100%|██████████| 25356/25356 [00:12<00:00, 1990.07it/s]


Validation: Correct: 2762, Tie: 21267, Lose: 1327, Total: 25356, Accuracy: 0.5282970500078876
Type: Type.PUBLISHED_TIME, Backoff: True, Generalize: False, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:11<00:00, 2176.37it/s]
100%|██████████| 25356/25356 [00:12<00:00, 2088.95it/s]


Validation: Correct: 5, Tie: 25343, Lose: 8, Total: 25356, Accuracy: 0.4999408424041647
Type: Type.PUBLISHED_TIME, Backoff: True, Generalize: False, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:11<00:00, 2207.94it/s]
100%|██████████| 25356/25356 [00:12<00:00, 2094.42it/s]


Validation: Correct: 9, Tie: 25338, Lose: 9, Total: 25356, Accuracy: 0.5
Type: Type.PUBLISHED_TIME, Backoff: True, Generalize: False, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:11<00:00, 2107.19it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1912.14it/s]


Validation: Correct: 9, Tie: 25338, Lose: 9, Total: 25356, Accuracy: 0.5
Type: Type.PUBLISHED_TIME, Backoff: True, Generalize: True, Penalize negative: 0.0


100%|██████████| 24724/24724 [00:12<00:00, 1905.91it/s]
100%|██████████| 25356/25356 [00:13<00:00, 1849.05it/s]


Validation: Correct: 867, Tie: 21890, Lose: 2599, Total: 25356, Accuracy: 0.4658463480044171
Type: Type.PUBLISHED_TIME, Backoff: True, Generalize: True, Penalize negative: 0.1


100%|██████████| 24724/24724 [00:12<00:00, 1959.98it/s]
100%|██████████| 25356/25356 [00:11<00:00, 2218.76it/s]


Validation: Correct: 2758, Tie: 20494, Lose: 2104, Total: 25356, Accuracy: 0.5128963558920966
Type: Type.PUBLISHED_TIME, Backoff: True, Generalize: True, Penalize negative: 0.5


100%|██████████| 24724/24724 [00:11<00:00, 2192.94it/s]
100%|██████████| 25356/25356 [00:11<00:00, 2134.11it/s]

Validation: Correct: 3245, Tie: 20601, Lose: 1510, Total: 25356, Accuracy: 0.5342128095914183





In [119]:
# investigate just predicting the latest
val_data = EkstraBladetDataset(args, create_prompt_titles, Type.PUBLISHED_TIME, split="validation")
total = win = tie = lose = 0
for old_clicks, pos_sample, neg_sample in val_data:

    if pos_sample == None or neg_sample == None:
        tie += 1
        continue

    if pos_sample > neg_sample:
        win += 1
    elif pos_sample == neg_sample:
        tie += 1  # very unlikely
    else:
        lose += 1
    total += 1
win/total, tie/total, lose/total

(0.49767313456381135, 0.00015775358889414735, 0.5021691118472945)

In [125]:
# closest publish time to latest clicked article
val_data = EkstraBladetDataset(args, create_prompt_titles, Type.PUBLISHED_TIME, split="validation")
total = win = tie = lose = 0
for old_clicks, pos_sample, neg_sample in val_data:
    latest = old_clicks[-1]
    if latest == None or pos_sample == None or neg_sample == None:
        tie += 1
        continue

    if abs(latest - pos_sample) < abs(latest - neg_sample):
        win += 1
    else:
        lose += 1
    total += 1
win/total, tie/total, lose/total

(0.6505758005994636, 0.0, 0.3494241994005364)

In [151]:
# last modified time negatively correlated with click
val_data = EkstraBladetDataset(args, create_prompt_titles, Type.LAST_MODIFIED_TIME, split="train")
total = win = tie = lose = 0
for old_clicks, pos_sample, neg_sample in val_data:
    latest = old_clicks[-1]
    if latest == None or pos_sample == None or neg_sample == None:
        tie += 1
        continue
    
    if pos_sample > neg_sample:
        win += 1
    # if abs(latest - pos_sample) < abs(latest - neg_sample):
    #     win += 1
    else:
        lose += 1
    total += 1
win/total, tie/total, lose/total

(0.3686701181038667, 0.0, 0.6313298818961333)

In [None]:
# another simple policy is to recommend the most clicked article
# count the clicks for each article in the dataset and use that

In [130]:
# predict the article that is most inview
class Args:
    dataset = "demo"
    batch_size = 1
    num_workers = 4
    T = 100000
    datafraction = 1.0
args = Args()

In [131]:
class Count(Dataset):

    def __init__(self, args, create_prompt, type, split="train"):
        
        # Download the dataset from huggingface
        if split == "test":
            self.behaviors = load_dataset(f'Wouter01/testbehaviors', cache_dir=f"../../testbehaviors_data")["train"]
            self.articles = load_dataset(f'Wouter01/testarticles', cache_dir=f"../../testarticles_data")["train"].to_pandas()
            self.history = load_dataset(f'Wouter01/testhistory', cache_dir=f"../../testhistory_data")["train"].to_pandas()
        else:
            self.behaviors = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'behaviors', cache_dir=f"../../{args.dataset}_data")[split]
            self.articles = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'articles', cache_dir=f"../../{args.dataset}_data")["train"].to_pandas()
            self.history = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'history', cache_dir=f"../../{args.dataset}_data")[split].to_pandas()

        # Set fast lookup for identifier keys
        self.history.set_index("user_id", inplace=True)
        self.articles.set_index("article_id", inplace=True)

        self.T = args.T  # Number of previous clicked articles to consider
        self.create_prompt = create_prompt  # Function to create a prompt from the data
        self.type = type
        self.datafraction = args.datafraction
        self.split = split

    def __len__(self):
        return int(self.datafraction*len(self.behaviors))

    def __getitem__(self, idx):
        # Every item consits of a positive and negative sample
        behavior = self.behaviors[idx]

        # Get the inview articles
        inview_articles = behavior["article_ids_inview"]

        return inview_articles

In [141]:
model = defaultdict(int)

train_data = Count(args, create_prompt_titles, Type.IDS, split="validation")
for inview_articles in tqdm(train_data):
    for article in inview_articles:
        model[article] += 1
sum(model.values())

100%|██████████| 25356/25356 [00:01<00:00, 14054.75it/s]


304915

In [142]:
# same for last modified time, have to upload this first
val_data = EkstraBladetDataset(args, create_prompt_titles, Type.IDS, split="validation")
total = win = tie = lose = 0
for _, pos_sample, neg_sample in val_data:
    if model[pos_sample] > model[neg_sample]:
        win += 1
    elif model[pos_sample] == model[neg_sample]:
        tie += 1
    else:
        lose += 1
    total += 1
win/total, tie/total, lose/total

(0.5298548666982174, 0.0020507966556239155, 0.4680943366461587)

In [143]:
class Count(Dataset):

    def __init__(self, args, create_prompt, type, split="train"):
        
        # Download the dataset from huggingface
        if split == "test":
            self.behaviors = load_dataset(f'Wouter01/testbehaviors', cache_dir=f"../../testbehaviors_data")["train"]
            self.articles = load_dataset(f'Wouter01/testarticles', cache_dir=f"../../testarticles_data")["train"].to_pandas()
            self.history = load_dataset(f'Wouter01/testhistory', cache_dir=f"../../testhistory_data")["train"].to_pandas()
        else:
            self.behaviors = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'behaviors', cache_dir=f"../../{args.dataset}_data")[split]
            self.articles = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'articles', cache_dir=f"../../{args.dataset}_data")["train"].to_pandas()
            self.history = load_dataset(f'Wouter01/RecSys_{args.dataset}', 'history', cache_dir=f"../../{args.dataset}_data")[split].to_pandas()

        # Set fast lookup for identifier keys
        self.history.set_index("user_id", inplace=True)
        self.articles.set_index("article_id", inplace=True)

        self.T = args.T  # Number of previous clicked articles to consider
        self.create_prompt = create_prompt  # Function to create a prompt from the data
        self.type = type
        self.datafraction = args.datafraction
        self.split = split

    def __len__(self):
        return int(self.datafraction*len(self.behaviors))

    def __getitem__(self, idx):
        # Every item consits of a positive and negative sample
        behavior = self.behaviors[idx]

        # Get the clicked articles
        article_ids_clicked = behavior["article_ids_clicked"]

        return article_ids_clicked

In [144]:
model = defaultdict(int)

train_data = Count(args, create_prompt_titles, Type.IDS, split="validation")
for inview_articles in tqdm(train_data):
    for article in inview_articles:
        model[article] += 1
sum(model.values())

100%|██████████| 25356/25356 [00:02<00:00, 10813.43it/s]


25505

In [145]:
# same for last modified time, have to upload this first
val_data = EkstraBladetDataset(args, create_prompt_titles, Type.IDS, split="validation")
total = win = tie = lose = 0
for _, pos_sample, neg_sample in val_data:
    if model[pos_sample] > model[neg_sample]:
        win += 1
    elif model[pos_sample] == model[neg_sample]:
        tie += 1
    else:
        lose += 1
    total += 1
win/total, tie/total, lose/total

(0.6527843508439817, 0.017392333175579745, 0.32982331598043857)