In [1]:
from sklearn.datasets import fetch_20newsgroups
# Newsgroup categories used throughout the notebook.
# NOTE: label indices follow `newsgroups_train.target_names` (printed further
# down), which is not in the same order as this list.
cats = [
    'alt.atheism',
    'comp.windows.x',
    'misc.forsale',
    'rec.autos',
    'sci.med',
    'rec.sport.hockey',
    'sci.space',
    'soc.religion.christian',
    'talk.politics.guns',
]

# Fetch the train split, then the test split, with identical settings:
# same categories, headers and footers removed.
newsgroups_train, newsgroups_test = (
    fetch_20newsgroups(subset=split_name, categories=cats, remove=("headers", "footers"))
    for split_name in ('train', 'test')
)

In [2]:
# Number of training documents fetched for the selected categories.
len(newsgroups_train.data)

5184

In [4]:
import json
from pydantic import BaseModel, Field
from typing import Optional, List
import enum

class EducationLevels(str, enum.Enum):
    """Closed set of education levels, each backed by a lowercase string value.

    The `str` mix-in makes every member an instance of `str` that compares
    equal to its value (e.g. ``EducationLevels.PHD == "phd"``).
    """
    HIGH_SCHOOL = "high_school"
    BACHELORS = "bachelors"
    MASTERS = "masters"
    PHD = "phd"
    # Explicit "no education level" member (distinct from Python's None).
    NONE = "none"

class Location(BaseModel):
    """Location described by three required string fields."""
    city: str
    state_or_province: str
    country: str

class FakeProfile(BaseModel):
    """A single profile record: who the person is, what they do, where they are."""

    name: str
    occupation: str
    industry: str
    job_description: str
    education: EducationLevels
    # Optional: stays None when the input does not provide a major.
    major: Optional[str] = None
    location: Location

    @classmethod
    def from_json(cls, data: str):
        """Build an instance from a JSON string.

        The string must decode to a JSON object whose keys are passed to the
        model constructor as keyword arguments.
        """
        decoded = json.loads(data)
        return cls(**decoded)


class FakeProfiles(BaseModel):
    """Container model wrapping a list of `FakeProfile` records."""

    profiles: List[FakeProfile]

    @classmethod
    def from_json(cls, data: str):
        """Build an instance from a JSON string.

        The string must decode to a JSON object whose keys are passed to the
        model constructor as keyword arguments.
        """
        decoded = json.loads(data)
        return cls(**decoded)

In [5]:
# Load the stored profiles for every newsgroup category, keyed by category name.
# File names mirror the category with dots replaced by underscores,
# e.g. "sci.med" -> "../fake_profiles/sci_med.json".
profiles_data = {}
for news_group in cats:
    profile_path = f"../fake_profiles/{news_group.replace('.', '_')}.json"
    # Read as UTF-8 explicitly: without `encoding`, open() falls back to the
    # platform's locale encoding, which can mis-decode non-ASCII text in the JSON.
    with open(profile_path, "r", encoding="utf-8") as f:
        profiles_data[news_group] = FakeProfiles.from_json(f.read())

In [6]:
# Map each integer label to its newsgroup name, using the order the dataset
# itself reports in `target_names`.
print(newsgroups_train.target_names)
target_to_name = dict(enumerate(newsgroups_train.target_names))
target_to_name

['alt.atheism', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.sport.hockey', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns']


{0: 'alt.atheism',
 1: 'comp.windows.x',
 2: 'misc.forsale',
 3: 'rec.autos',
 4: 'rec.sport.hockey',
 5: 'sci.med',
 6: 'sci.space',
 7: 'soc.religion.christian',
 8: 'talk.politics.guns'}

In [7]:
# Peek at the first raw training document (before any preprocessing).
newsgroups_train.data[0]

"Now available: xvertext 4.0 \n--------------\n\nSummary                                  \n-------\nxvertext provides you with four functions to draw strings at any angle in   \nan X window (previous versions were limited to vertical text). Rotation  \nis still achieved using XImages, but the notion of rotating a whole font\nfirst has been dropped.\n\nWhat's new?\n-----------\nI've added a cache which keeps a copy of previously rotated strings - thus\nspeeding up redraws.\n\nWhere can I get it? \n-------------------\ncomp.sources.x (soon...)\nexport.lcs.mit.edu : contrib/xvertext.4.0.shar.Z  (now)\n"

In [9]:
import re


def preprocess(x: str) -> str:
    """Reduce raw text to lowercase ASCII letters, digits and single spaces.

    Steps, applied in order:
      1. newlines, tabs and carriage returns become spaces;
      2. any whitespace-delimited token containing "@" is removed (together
         with one trailing whitespace character);
      3. every character other than a-z, A-Z, 0-9 and space is deleted
         (not replaced by a space, so "foo-bar" becomes "foobar");
      4. runs of spaces collapse to a single space;
      5. the result is lowercased and stripped.
    """
    # Step 1: one translation table handles all three control characters.
    text = x.translate(str.maketrans("\n\t\r", "   "))

    # Steps 2-4: (pattern, replacement) pairs, applied strictly in this order.
    substitutions = (
        (r"\S*@\S*\s?", ""),
        (r"[^a-zA-Z0-9 ]", ""),
        (" +", " "),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)

    # Step 5.
    return text.lower().strip()

In [10]:
import random
# Seed the global RNG so the sampled pairs (and any later use of `random`,
# such as shuffling) are repeatable under Restart & Run All.
RANDOM_SEED = 42
# Articles whose preprocessed text is longer than this many characters are skipped.
MAX_ARTICLE_CHARS = 1000

random.seed(RANDOM_SEED)

# Preprocess and length-filter every article exactly once. The result depends
# only on the article, so redoing it for every profile would repeat the same
# regex work many times over.
short_articles = []
for idx, raw_article in enumerate(newsgroups_train.data):
    cleaned_article = preprocess(raw_article)
    if len(cleaned_article) <= MAX_ARTICLE_CHARS:
        article_group = target_to_name[newsgroups_train.target[idx]]
        short_articles.append((article_group, cleaned_article))

# Build (profile, article, label) triples: label 1 when the article comes from
# the profile's own newsgroup, 0 otherwise.
profile_article_pairs = []
for profile_group, profiles in profiles_data.items():
    positive_cases = []
    negative_cases = []
    for profile in profiles.profiles:
        for article_group, article in short_articles:
            if profile_group == article_group:
                positive_cases.append((profile, article, 1))
            else:
                negative_cases.append((profile, article, 0))
    # Keep a random half of the positives, then draw the same number of
    # negatives so each profile group contributes a balanced set.
    positive_cases = random.sample(positive_cases, len(positive_cases) // 2)
    negative_cases = random.sample(negative_cases, len(positive_cases))
    profile_article_pairs.extend(positive_cases + negative_cases)


In [11]:
# Total number of (profile, article, label) triples after balancing.
len(profile_article_pairs)

65482

In [12]:
def get_profile_text(profile: "FakeProfile") -> str:
    """Render a profile as a short natural-language description.

    Args:
        profile: Object exposing `name`, `occupation`, `industry`,
            `job_description`, `education`, `major` and a `location` with
            `city`, `state_or_province` and `country`.

    Returns:
        Three sentences: occupation/industry, the job description, and the
        education level (with the major when one is set) plus location.
    """
    # Use the enum's underlying value ("bachelors") rather than formatting the
    # member itself: how a str-mixin Enum renders inside an f-string differs
    # between Python versions (newer ones produce "EducationLevels.BACHELORS").
    education = getattr(profile.education, "value", profile.education)
    # `major` is optional; drop the clause instead of emitting "in None".
    major_clause = f" in {profile.major}" if profile.major else ""
    location = profile.location
    return (
        f"{profile.name} is a {profile.occupation} in the {profile.industry} industry. "
        f"{profile.job_description} "
        f"{profile.name} has a {education} degree{major_clause} "
        f"from {location.city}, {location.state_or_province}, {location.country}."
    )

In [13]:
# Shuffle the pairs in place, then cut them into roughly 80% train,
# 10% validation and 10% test.
# NOTE(review): the cut is made over (profile, article) pairs, so the same
# article text can end up in more than one split — confirm that is acceptable.
random.shuffle(profile_article_pairs)
total_pairs = len(profile_article_pairs)
train_cutoff = int(total_pairs * 0.8)
validation_cutoff = int(total_pairs * 0.9)
train_pairs, validation_pairs, test_pairs = (
    profile_article_pairs[:train_cutoff],
    profile_article_pairs[train_cutoff:validation_cutoff],
    profile_article_pairs[validation_cutoff:],
)

In [14]:
import json

# Pair every split with its output file; each file is JSON Lines with one
# {"text": ...} object per (profile, article, label) triple.
data_sets = list(zip(
    (train_pairs, validation_pairs, test_pairs),
    ('train.jsonl', 'valid.jsonl', 'test.jsonl'),
))

for data_set, file_name in data_sets:
    with open(file_name, 'w') as f:
        for profile, article, relevance in data_set:
            # The label is written as the literal text "True" or "False".
            is_relevant = relevance == 1
            text = f"User Profile: {get_profile_text(profile)}. Article: {article}. Relevant: {is_relevant}"
            f.write(json.dumps({"text": text}) + '\n')

