# Import

In [None]:
import numpy as np
import pandas as pd
import sklearn as sk
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from tqdm import tqdm


from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


In [None]:
!pip install transformers
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m44.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m80.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.12.1 tokenizers-0.13.2 transformers-4.26.1


In [None]:
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data

## https://www.kaggle.com/datasets/nikdavis/steam-store-games

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
folder = "/content/drive/MyDrive/ML472Project"
data_folder = "/content/drive/MyDrive/ML472Project/Data/"


Mounted at /content/drive


In [None]:
!ls "/content/drive/MyDrive/ML472Project/Data/"

steam.csv		    steam_media_data.csv	 steamspy_tag_data.csv
steam_description_data.csv  steam_requirements_data.csv  steam_support_info.csv


In [None]:
steam_media_df = pd.read_csv(data_folder+"steam_media_data.csv")
steam_media_df.head()

Unnamed: 0,steam_appid,header_image,screenshots,background,movies
0,10,https://steamcdn-a.akamaihd.net/steam/apps/10/...,"[{'id': 0, 'path_thumbnail': 'https://steamcdn...",https://steamcdn-a.akamaihd.net/steam/apps/10/...,
1,20,https://steamcdn-a.akamaihd.net/steam/apps/20/...,"[{'id': 0, 'path_thumbnail': 'https://steamcdn...",https://steamcdn-a.akamaihd.net/steam/apps/20/...,
2,30,https://steamcdn-a.akamaihd.net/steam/apps/30/...,"[{'id': 0, 'path_thumbnail': 'https://steamcdn...",https://steamcdn-a.akamaihd.net/steam/apps/30/...,
3,40,https://steamcdn-a.akamaihd.net/steam/apps/40/...,"[{'id': 0, 'path_thumbnail': 'https://steamcdn...",https://steamcdn-a.akamaihd.net/steam/apps/40/...,
4,50,https://steamcdn-a.akamaihd.net/steam/apps/50/...,"[{'id': 0, 'path_thumbnail': 'https://steamcdn...",https://steamcdn-a.akamaihd.net/steam/apps/50/...,


In [None]:
steam_media_df.loc[97,['header_image']].values

array(['https://steamcdn-a.akamaihd.net/steam/apps/3520/header.jpg?t=1447351348'],
      dtype=object)

## Metadata

In [None]:
steam_df = pd.read_csv(data_folder +"steam.csv")
steam_df.head()

Unnamed: 0,appid,name,release_date,english,developer,publisher,platforms,required_age,categories,genres,steamspy_tags,achievements,positive_ratings,negative_ratings,average_playtime,median_playtime,owners,price
0,10,Counter-Strike,2000-11-01,1,Valve,Valve,windows;mac;linux,0,Multi-player;Online Multi-Player;Local Multi-P...,Action,Action;FPS;Multiplayer,0,124534,3339,17612,317,10000000-20000000,7.19
1,20,Team Fortress Classic,1999-04-01,1,Valve,Valve,windows;mac;linux,0,Multi-player;Online Multi-Player;Local Multi-P...,Action,Action;FPS;Multiplayer,0,3318,633,277,62,5000000-10000000,3.99
2,30,Day of Defeat,2003-05-01,1,Valve,Valve,windows;mac;linux,0,Multi-player;Valve Anti-Cheat enabled,Action,FPS;World War II;Multiplayer,0,3416,398,187,34,5000000-10000000,3.99
3,40,Deathmatch Classic,2001-06-01,1,Valve,Valve,windows;mac;linux,0,Multi-player;Online Multi-Player;Local Multi-P...,Action,Action;FPS;Multiplayer,0,1273,267,258,184,5000000-10000000,3.99
4,50,Half-Life: Opposing Force,1999-11-01,1,Gearbox Software,Valve,windows;mac;linux,0,Single-player;Multi-player;Valve Anti-Cheat en...,Action,FPS;Action;Sci-fi,0,5250,288,624,415,5000000-10000000,3.99


In [None]:
def unique(list1):
    unique_list = []

    for x in list1:
        if x not in unique_list:
            unique_list.append(x)
    return unique_list

genres = []

for idx in steam_df.index:
  genres.extend(steam_df.loc[idx,['genres']].values[0].split(';'))
  genres = unique(genres)

print(genres)
print(len(genres))

['Action', 'Free to Play', 'Strategy', 'Adventure', 'Indie', 'RPG', 'Animation & Modeling', 'Video Production', 'Casual', 'Simulation', 'Racing', 'Violent', 'Massively Multiplayer', 'Nudity', 'Sports', 'Early Access', 'Gore', 'Utilities', 'Design & Illustration', 'Web Publishing', 'Education', 'Software Training', 'Sexual Content', 'Audio Production', 'Game Development', 'Photo Editing', 'Accounting', 'Documentary', 'Tutorial']
29


In [None]:
# # one hot encoding

def one_hot(df,column,attibutes):
  df[attibutes] = 0
  for idx in tqdm(df.index):
    for att in attibutes:
      own_attributes = df.loc[idx,[column]].values[0].split(';')
      for own_att in own_attributes:
        if(own_att == att):
          df.loc[idx,[att]] = 1

  return df

In [None]:
df = one_hot(steam_df,'genres',genres)

100%|██████████| 27075/27075 [10:04<00:00, 44.79it/s]


In [None]:
df.loc[70,:]

appid                                                                2720
name                                                ThreadSpace: Hyperbol
release_date                                                   2007-07-12
english                                                                 1
developer                                                 Iocaine Studios
publisher                                                           Atari
platforms                                                         windows
required_age                                                            0
categories               Single-player;Multi-player;Includes level editor
genres                                              Action;Indie;Strategy
steamspy_tags                                       Strategy;Action;Indie
achievements                                                            0
positive_ratings                                                       31
negative_ratings                      

In [None]:
steam_csv_float_metrics_cols = [
    # "name", # is text handle
    "appid",
    "required_age",
    "achievements",
    "positive_ratings",
    "negative_ratings",
    "average_playtime",
    "median_playtime",
    "price",
]


steam_csv_float_metrics_df = steam_df[steam_csv_float_metrics_cols]
steam_csv_float_metrics_df = steam_csv_float_metrics_df.rename(columns={"appid":"steam_appid"})
steam_csv_float_metrics_df

Unnamed: 0,steam_appid,required_age,achievements,positive_ratings,negative_ratings,average_playtime,median_playtime,price
0,10,0,0,124534,3339,17612,317,7.19
1,20,0,0,3318,633,277,62,3.99
2,30,0,0,3416,398,187,34,3.99
3,40,0,0,1273,267,258,184,3.99
4,50,0,0,5250,288,624,415,3.99
...,...,...,...,...,...,...,...,...
27070,1065230,0,7,3,0,0,0,2.09
27071,1065570,0,0,8,1,0,0,1.69
27072,1065650,0,24,0,1,0,0,3.99
27073,1066700,0,0,2,0,0,0,5.19


# Game desc

In [None]:
steam_description_data_df = pd.read_csv(data_folder +"steam_description_data.csv")
steam_description_data_df.head()

Unnamed: 0,steam_appid,detailed_description,about_the_game,short_description
0,10,Play the world's number 1 online action game. ...,Play the world's number 1 online action game. ...,Play the world's number 1 online action game. ...
1,20,One of the most popular online action games of...,One of the most popular online action games of...,One of the most popular online action games of...
2,30,Enlist in an intense brand of Axis vs. Allied ...,Enlist in an intense brand of Axis vs. Allied ...,Enlist in an intense brand of Axis vs. Allied ...
3,40,Enjoy fast-paced multiplayer gaming with Death...,Enjoy fast-paced multiplayer gaming with Death...,Enjoy fast-paced multiplayer gaming with Death...
4,50,Return to the Black Mesa Research Facility as ...,Return to the Black Mesa Research Facility as ...,Return to the Black Mesa Research Facility as ...


In [None]:
detailed_desc_len = 0
for detailed_desc in steam_description_data_df["detailed_description"]:
    detailed_desc_len += len(str(detailed_desc))

avg_detailed_desc_len = detailed_desc_len / len(steam_description_data_df)
avg_detailed_desc_len

1634.7170556815688

In [None]:
detailed_desc_len = 0
for detailed_desc in steam_description_data_df["about_the_game"]:
    detailed_desc_len += len(str(detailed_desc))

avg_detailed_desc_len = detailed_desc_len / len(steam_description_data_df)
avg_detailed_desc_len

1568.1582278481012

In [None]:
detailed_desc_len = 0
for detailed_desc in steam_description_data_df["short_description"]:
    detailed_desc_len += len(str(detailed_desc))

avg_detailed_desc_len = detailed_desc_len / len(steam_description_data_df)
avg_detailed_desc_len

202.82468720275116

In [None]:
avg_tag_count = 0
for detailed_desc in steam_df["steamspy_tags"]:
    avg_tag_count += len(str(detailed_desc).split(";"))

avg_tag_count = avg_tag_count / len(steam_df)
avg_tag_count

2.880960295475531

In [None]:
steamspy_tag_data_df = pd.read_csv(data_folder +"steamspy_tag_data.csv")
steamspy_tag_data_df.head()

Unnamed: 0,appid,1980s,1990s,2.5d,2d,2d_fighter,360_video,3d,3d_platformer,3d_vision,...,warhammer_40k,web_publishing,werewolves,western,word_game,world_war_i,world_war_ii,wrestling,zombies,e_sports
0,10,144,564,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,550
1,20,0,71,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,30,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,5,122,0,0,0
3,40,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,50,0,77,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [None]:
steamspy_tag_data_df.columns

Index(['appid', '1980s', '1990s', '2.5d', '2d', '2d_fighter', '360_video',
       '3d', '3d_platformer', '3d_vision',
       ...
       'warhammer_40k', 'web_publishing', 'werewolves', 'western', 'word_game',
       'world_war_i', 'world_war_ii', 'wrestling', 'zombies', 'e_sports'],
      dtype='object', length=372)

In [None]:
steamspy_tag_data_df = steamspy_tag_data_df.fillna(0)

In [None]:
total_tag_count = steamspy_tag_data_df.astype(bool).sum(axis=0)
total_tag_count

appid           29022
1980s             130
1990s             176
2.5d              141
2d               3276
                ...  
world_war_i        52
world_war_ii      263
wrestling          16
zombies           720
e_sports           93
Length: 372, dtype: int64

In [None]:
total_tag_weight = steamspy_tag_data_df.sum(axis=0)
total_tag_weight

appid           17275568821
1980s                  5335
1990s                  4847
2.5d                   3996
2d                   105843
                   ...     
world_war_i            2671
world_war_ii          22106
wrestling               478
zombies               96190
e_sports              16671
Length: 372, dtype: int64

In [None]:
steamspy_tag_data_df_sum_sorted = steamspy_tag_data_df.reindex(total_tag_weight.sort_values(ascending=False).index, axis=1)
steamspy_tag_data_df_sum_sorted

Unnamed: 0,appid,action,indie,adventure,multiplayer,singleplayer,casual,rpg,strategy,open_world,...,steam_machine,snowboarding,cycling,bmx,atv,skiing,foreign,hardware,skating,feature_film
0,10,2681,0,0,1659,0,0,0,329,0,...,0,0,0,0,0,0,0,0,0,0
1,20,208,0,15,172,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,30,99,0,0,115,16,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,40,85,0,0,58,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,50,211,0,87,0,148,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29017,1065230,0,21,21,0,0,21,0,0,0,...,0,0,0,0,0,0,0,0,0,0
29018,1065570,21,21,20,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
29019,1065650,21,21,0,0,0,21,0,0,0,...,0,0,0,0,0,0,0,0,0,0
29020,1066700,0,21,20,0,0,21,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [None]:
 steamspy_tag_data_df_count_sorted = steamspy_tag_data_df.reindex(total_tag_count.sort_values(ascending=False).index, axis=1)
 steamspy_tag_data_df_count_sorted

Unnamed: 0,appid,indie,action,adventure,casual,singleplayer,strategy,simulation,rpg,early_access,...,cycling,bmx,snowboarding,atv,foreign,skating,skiing,jet,hardware,feature_film
0,10,0,2681,0,0,0,329,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,20,0,208,15,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,30,0,99,0,0,16,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,40,0,85,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,50,0,211,87,0,148,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29017,1065230,21,0,21,21,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
29018,1065570,21,21,20,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
29019,1065650,21,21,0,21,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
29020,1066700,21,0,20,21,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [None]:
steamspy_tag_data_df_sum_sorted.columns[:80]

Index(['appid', 'action', 'indie', 'adventure', 'multiplayer', 'singleplayer',
       'casual', 'rpg', 'strategy', 'open_world', 'simulation', 'free_to_play',
       'survival', 'first_person', 'fps', 'co_op', 'shooter', 'atmospheric',
       'great_soundtrack', 'sandbox', 'story_rich', 'early_access', 'horror',
       'sci_fi', 'funny', 'third_person', '2d', 'difficult', 'zombies',
       'online_co_op', 'puzzle', 'anime', 'massively_multiplayer', 'crafting',
       'fantasy', 'exploration', 'comedy', 'building', 'gore',
       'pixel_graphics', 'platformer', 'female_protagonist', 'tactical',
       'nudity', 'racing', 'violent', 'stealth', 'space', 'pvp', 'moddable',
       'team_based', 'third_person_shooter', 'post_apocalyptic', 'classic',
       'survival_horror', 'sexual_content', 'visual_novel', 'sports', 'mature',
       'war', 'realistic', 'turn_based', 'cute', 'competitive', 'action_rpg',
       'psychological_horror', 'vr', 'point_&_click', 'retro', 'replay_value',
       'd

In [None]:
steamspy_tag_data_df_count_sorted.columns[:80]

Index(['appid', 'indie', 'action', 'adventure', 'casual', 'singleplayer',
       'strategy', 'simulation', 'rpg', 'early_access', 'puzzle', '2d',
       'great_soundtrack', 'multiplayer', 'atmospheric', 'vr', 'difficult',
       'story_rich', 'free_to_play', 'anime', 'horror', 'platformer',
       'pixel_graphics', 'violent', 'female_protagonist', 'shooter', 'sci_fi',
       'funny', 'gore', 'first_person', 'fantasy', 'open_world', 'retro',
       'arcade', 'co_op', 'sports', 'fps', 'survival', 'nudity',
       'visual_novel', 'family_friendly', 'comedy', 'point_&_click', 'racing',
       'cute', 'sandbox', 'sexual_content', 'classic', 'exploration', 'space',
       'turn_based', 'massively_multiplayer', 'psychological_horror',
       'relaxing', 'third_person', 'replay_value', 'local_multiplayer',
       'shoot_em_up', 'rpgmaker', 'controller', 'zombies', 'colorful',
       'fast_paced', 'rogue_like', 'local_co_op', 'mystery', 'hidden_object',
       'tactical', 'side_scroller', 'meme

In [None]:
# there are a lot of tags that are not sort of overlap
# I took 32 manually from top 80
# some of these have low amount of samples
# SO maybe we can use only top 32 tags if that boosts perf
picked_tags = [
    "indie",
    "action",
    'adventure',
    'casual',
    'simulation',
    "rpg",
    'puzzle',
    '2d',
    'anime',
    'horror',
    'platformer',
    'pixel_graphics',
    'violent',
    'open_world',
    'retro',
    'point_&_click',
    'turn_based',
    'massively_multiplayer',
    'puzzle_platformer',
    'turn_based_strategy',
    'rogue_like',
    'hack_and_slash',
    'fps',
    'survival',
    'local_co_op',
    'tactical',
    'building',
    'survival_horror',
    'stealth',
    'pvp',
    'co_op',
    'sports'
]
len(picked_tags)

32

In [None]:
picked_tags = [tag for tag in steamspy_tag_data_df_count_sorted.columns[:20]]
len(picked_tags), picked_tags

(20,
 ['appid',
  'indie',
  'action',
  'adventure',
  'casual',
  'singleplayer',
  'strategy',
  'simulation',
  'rpg',
  'early_access',
  'puzzle',
  '2d',
  'great_soundtrack',
  'multiplayer',
  'atmospheric',
  'vr',
  'difficult',
  'story_rich',
  'free_to_play',
  'anime'])

In [None]:
steamspy_picked_tags = steamspy_tag_data_df[["appid",*picked_tags]]
steamspy_picked_tags

Unnamed: 0,appid,appid.1,indie,action,adventure,casual,singleplayer,strategy,simulation,rpg,...,puzzle,2d,great_soundtrack,multiplayer,atmospheric,vr,difficult,story_rich,free_to_play,anime
0,10,10,0,2681,0,0,0,329,0,0,...,0,0,0,1659,0,0,0,0,0,0
1,20,20,0,208,15,0,0,0,0,0,...,0,0,0,172,0,0,0,0,0,0
2,30,30,0,99,0,0,16,0,0,0,...,0,0,0,115,0,0,0,0,0,0
3,40,40,0,85,0,0,0,0,0,0,...,0,0,0,58,0,0,0,0,0,0
4,50,50,0,211,87,0,148,0,0,0,...,18,0,25,0,73,0,0,40,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29017,1065230,1065230,21,0,21,21,0,0,0,0,...,12,0,0,0,0,0,0,0,0,0
29018,1065570,1065570,21,21,20,0,0,0,0,0,...,0,0,0,0,0,0,10,0,0,0
29019,1065650,1065650,21,21,0,21,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
29020,1066700,1066700,21,0,20,21,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [None]:
len(steam_description_data_df), len(steamspy_picked_tags)

(27334, 29022)

In [None]:
steam_description_data_df.head()

Unnamed: 0,steam_appid,detailed_description,about_the_game,short_description
0,10,Play the world's number 1 online action game. ...,Play the world's number 1 online action game. ...,Play the world's number 1 online action game. ...
1,20,One of the most popular online action games of...,One of the most popular online action games of...,One of the most popular online action games of...
2,30,Enlist in an intense brand of Axis vs. Allied ...,Enlist in an intense brand of Axis vs. Allied ...,Enlist in an intense brand of Axis vs. Allied ...
3,40,Enjoy fast-paced multiplayer gaming with Death...,Enjoy fast-paced multiplayer gaming with Death...,Enjoy fast-paced multiplayer gaming with Death...
4,50,Return to the Black Mesa Research Facility as ...,Return to the Black Mesa Research Facility as ...,Return to the Black Mesa Research Facility as ...


In [None]:
steamspy_picked_tags = steamspy_picked_tags.rename(columns={"appid":"steam_appid"})
detailed_description_df = pd.merge(steam_description_data_df, steamspy_picked_tags,on="steam_appid")
detailed_description_df

Unnamed: 0,steam_appid,detailed_description,about_the_game,short_description,indie,action,adventure,casual,simulation,rpg,...,fps,survival,local_co_op,tactical,building,survival_horror,stealth,pvp,co_op,sports
0,10,Play the world's number 1 online action game. ...,Play the world's number 1 online action game. ...,Play the world's number 1 online action game. ...,0,2681,0,0,0,0,...,2048,192,0,734,0,0,0,480,0,0
1,20,One of the most popular online action games of...,One of the most popular online action games of...,One of the most popular online action games of...,0,208,15,0,0,0,...,188,0,0,0,0,0,0,0,62,0
2,30,Enlist in an intense brand of Axis vs. Allied ...,Enlist in an intense brand of Axis vs. Allied ...,Enlist in an intense brand of Axis vs. Allied ...,0,99,0,0,0,0,...,138,0,0,14,0,0,0,0,12,0
3,40,Enjoy fast-paced multiplayer gaming with Death...,Enjoy fast-paced multiplayer gaming with Death...,Enjoy fast-paced multiplayer gaming with Death...,0,85,0,0,0,0,...,71,0,0,0,0,0,0,0,0,0
4,50,Return to the Black Mesa Research Facility as ...,Return to the Black Mesa Research Facility as ...,Return to the Black Mesa Research Facility as ...,0,211,87,0,0,0,...,235,0,0,0,0,0,0,0,27,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
27324,1065230,"<img src=""https://steamcdn-a.akamaihd.net/stea...","<img src=""https://steamcdn-a.akamaihd.net/stea...",The Room of Pandora is a third-person interact...,21,0,21,21,0,0,...,0,0,0,0,0,0,0,0,0,0
27325,1065570,Have you ever been so lonely that no one but y...,Have you ever been so lonely that no one but y...,Cyber Gun is a hardcore first-person shooter w...,21,21,20,0,0,0,...,12,0,0,0,0,0,0,0,0,0
27326,1065650,<strong>Super Star Blast </strong>is a space b...,<strong>Super Star Blast </strong>is a space b...,Super Star Blast is a space based game with ch...,21,21,0,21,0,0,...,0,0,0,0,0,0,0,0,0,0
27327,1066700,Pursue a snow-white deer through an enchanted ...,Pursue a snow-white deer through an enchanted ...,Pursue a snow-white deer through an enchanted ...,21,0,20,21,0,0,...,0,0,0,0,0,0,0,0,0,0


In [None]:
X = detailed_description_df
y = detailed_description_df[picked_tags].astype(bool).astype(int)
y.head()

Unnamed: 0,indie,action,adventure,casual,simulation,rpg,puzzle,2d,anime,horror,...,fps,survival,local_co_op,tactical,building,survival_horror,stealth,pvp,co_op,sports
0,0,1,0,0,0,0,0,0,0,0,...,1,1,0,1,0,0,0,1,0,0
1,0,1,1,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,1,0
2,0,1,0,0,0,0,0,0,0,0,...,1,0,0,1,0,0,0,0,1,0
3,0,1,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
4,0,1,1,0,0,0,1,0,0,0,...,1,0,0,0,0,0,0,0,1,0


In [None]:
y.sum(axis=0) # count labels

indie                    20289
action                   12681
adventure                11232
casual                   11145
simulation                5708
rpg                       4705
puzzle                    3225
2d                        3216
anime                     1992
horror                    1965
platformer                1952
pixel_graphics            1944
violent                   1790
open_world                1481
retro                     1525
point_&_click             1115
turn_based                 945
massively_multiplayer      868
puzzle_platformer          556
turn_based_strategy        547
rogue_like                 672
hack_and_slash             404
fps                       1271
survival                  1256
local_co_op                675
tactical                   634
building                   532
survival_horror            516
stealth                    385
pvp                        338
co_op                     1433
sports                    1407
dtype: i

## split

In [None]:
from sklearn.model_selection import train_test_split
test_split = .3
# stratify this
# apparrently sklearn can't stratify with multilabel
# https://stackoverflow.com/questions/53378970/how-to-perform-multilabel-stratified-sampling
# X_train, X_rest, y_train, y_rest = train_test_split(X, y,stratify =y, test_size=test_split, random_state=RANDOM_SEED, shuffle=True)
# X_val, X_test, y_val, y_test = train_test_split(X_rest, y_rest,stratify =y, test_size=0.5, random_state=RANDOM_SEED, shuffle=True)


X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=test_split, random_state=RANDOM_SEED, shuffle=True)
X_val, X_test, y_val, y_test = train_test_split(X_rest, y_rest, test_size=0.5, random_state=RANDOM_SEED, shuffle=True)


In [None]:
y_train.sum(axis=0)

indie                    14208
action                    8880
adventure                 7871
casual                    7823
simulation                4017
rpg                       3281
puzzle                    2235
2d                        2255
anime                     1417
horror                    1351
platformer                1356
pixel_graphics            1382
violent                   1239
open_world                1013
retro                     1061
point_&_click              792
turn_based                 654
massively_multiplayer      612
puzzle_platformer          384
turn_based_strategy        395
rogue_like                 480
hack_and_slash             280
fps                        886
survival                   869
local_co_op                476
tactical                   449
building                   366
survival_horror            366
stealth                    270
pvp                        233
co_op                     1001
sports                     982
dtype: i

In [None]:
y_val.sum(axis=0)

indie                    3043
action                   1891
adventure                1691
casual                   1674
simulation                831
rpg                       720
puzzle                    487
2d                        479
anime                     271
horror                    308
platformer                302
pixel_graphics            296
violent                   289
open_world                243
retro                     250
point_&_click             165
turn_based                147
massively_multiplayer     112
puzzle_platformer          88
turn_based_strategy        78
rogue_like                 95
hack_and_slash             56
fps                       176
survival                  180
local_co_op                99
tactical                   83
building                   87
survival_horror            80
stealth                    60
pvp                        50
co_op                     201
sports                    209
dtype: int64

In [None]:
y_test.sum(axis=0)

indie                    3038
action                   1910
adventure                1670
casual                   1648
simulation                860
rpg                       704
puzzle                    503
2d                        482
anime                     304
horror                    306
platformer                294
pixel_graphics            266
violent                   262
open_world                225
retro                     214
point_&_click             158
turn_based                144
massively_multiplayer     144
puzzle_platformer          84
turn_based_strategy        74
rogue_like                 97
hack_and_slash             68
fps                       209
survival                  207
local_co_op               100
tactical                  102
building                   79
survival_horror            70
stealth                    55
pvp                        55
co_op                     231
sports                    216
dtype: int64

## Dataset class

In [None]:
max_tokens = 512
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
bert_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)


Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:

class SteamGameTagDataset(Dataset):

    def __init__(self, steam_game_desc, targets, tokenizer, max_len):
        self.steam_game_desc = steam_game_desc
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.steam_game_desc)
    def encode(self, text):
        return self.tokenizer.encode_plus(
            text,
            max_length=max_tokens,
            truncation=True,
            padding="max_length",
            add_special_tokens=True, # Add '[CLS]' and '[SEP]'
            return_token_type_ids=False,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',  # Return PyTorch tensors
        )


    def __getitem__(self, item):
        desc = str(self.steam_game_desc[item])
        target = self.targets[item]

        encoding = self.encode(desc)
        return {
            'desc': desc,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target)
        }

In [None]:
def create_data_loader(X, y, tokenizer, max_len, batch_size):
  ds = SteamGameTagDataset(
    X.to_numpy(),
    y.to_numpy(dtype=np.float),
    tokenizer=tokenizer,
    max_len=max_len
  )

  return DataLoader(
    ds,
    batch_size=batch_size,
    num_workers=    2
  )

## Dataloader

In [None]:
BATCH_SIZE = 16

train_data_loader = create_data_loader(X_train["detailed_description"], y_train, tokenizer, max_tokens, BATCH_SIZE)
val_data_loader = create_data_loader(X_val["detailed_description"], y_val, tokenizer, max_tokens, BATCH_SIZE)
test_data_loader = create_data_loader(X_test["detailed_description"], y_test, tokenizer, max_tokens, BATCH_SIZE)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  y.to_numpy(dtype=np.float),


# Model

In [None]:
def process_bert_output(output):
  last_hidden_state, pooled_output = output['last_hidden_state'], output['pooler_output']
  return last_hidden_state, pooled_output

class SteamGameTagClassifier(nn.Module):

  def __init__(self, n_classes):
    super(SteamGameTagClassifier, self).__init__()
    self.bert = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
    self.drop = nn.Dropout(p=0.3)
    self.linear = nn.Linear(self.bert.config.hidden_size, n_classes)

  def forward(self, input_ids, attention_mask):
    # same as above
    output = self.bert(
      input_ids=input_ids,
      attention_mask=attention_mask
    )
    _, pooled_output = process_bert_output(output)

    output = self.drop(pooled_output)
    output = self.linear(pooled_output)
    output = torch.sigmoid(output)

    return output

In [None]:
model = SteamGameTagClassifier(len(picked_tags))
model = model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# preds_all, targets_all = predict_model(model, test_data_loader, y_test, device)

# print(classification_report(preds_all, targets_all, target_names =picked_tags))

                       precision    recall  f1-score   support

                indie       1.00      0.74      0.85      4099
               action       1.00      0.47      0.64      4100
            adventure       0.17      0.43      0.24       654
               casual       0.99      0.40      0.57      4075
           simulation       1.00      0.21      0.35      4100
                  rpg       0.00      0.00      0.00         0
               puzzle       1.00      0.12      0.22      4096
                   2d       0.97      0.12      0.21      3991
                anime       0.12      0.06      0.08       637
               horror       0.37      0.08      0.13      1467
           platformer       0.01      0.07      0.01        27
       pixel_graphics       0.56      0.07      0.12      2249
              violent       0.29      0.06      0.11      1201
           open_world       1.00      0.05      0.10      4100
                retro       1.00      0.05      0.10  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Train

In [None]:
def train_epoch(
  model,
  data_loader,
  loss_fn,
  optimizer,
  device,
  scheduler,
  n_examples
):
    model = model.train()

    losses = []
    correct_predictions = 0
    with tqdm(data_loader, unit="batch") as tepoch:

        for d in tepoch:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["targets"].to(device)
            # print(targets.dtype)

            outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
            )
            # print(preds.shape)
            # print(targets.shape)
            # _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, targets)
            preds = torch.round(outputs)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)

In [None]:

last_preds =None
last_targets =None
def eval_model(model, data_loader, loss_fn, device, n_examples):
  model = model.eval()

  losses = []
  correct_predictions = 0

  with torch.no_grad():
    with tqdm(data_loader, unit="batch") as tepoch:
        for d in tepoch:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["targets"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            #   _, preds = torch.max(outputs, dim=1)
            #   print(preds, targets)
            loss = loss_fn(outputs, targets)
            preds = torch.round(outputs)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

  return correct_predictions.double() / n_examples, np.mean(losses)

## Train loop

In [None]:

EPOCHS = 8

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
total_steps = len(train_data_loader) * EPOCHS

scheduler = get_linear_schedule_with_warmup(
  optimizer,
  num_warmup_steps=0,
  num_training_steps=total_steps
)

loss_fn = nn.CrossEntropyLoss().to(device)

history = {}
history['train_acc']= []
history['train_loss']= []
history['val_acc']= []
history['val_loss'] = []

best_accuracy = 0

for epoch in range(EPOCHS):

  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)

  train_acc, train_loss = train_epoch(
    model,
    train_data_loader,
    loss_fn,
    optimizer,
    device,
    scheduler,
    len(X_train)
  )


  print(f'Train loss {train_loss} accuracy {train_acc}')

  val_acc, val_loss = eval_model(
    model,
    val_data_loader,
    loss_fn,
    device,
    len(X_val)
  )

  print(f'Val   loss {val_loss} accuracy {val_acc}')
  print()

  history['train_acc'].append(train_acc)
  history['train_loss'].append(train_loss)
  history['val_acc'].append(val_acc)
  history['val_loss'].append(val_loss)

  if val_acc > best_accuracy:
    # torch.save(model.state_dict(), 'best_model_state.bin')
    best_accuracy = val_acc

Epoch 1/8
----------


100%|██████████| 1196/1196 [31:02<00:00,  1.56s/batch]

Train loss 10.977613014959571 accuracy 27.591688447464715





Val   loss 10.770991353112747 accuracy 28.015369602342034

Epoch 2/8
----------


100%|██████████| 1196/1196 [31:09<00:00,  1.56s/batch]

Train loss 10.724148678405827 accuracy 28.158180867746996





Val   loss 10.665694602809838 accuracy 28.100024396194193

Epoch 3/8
----------


100%|██████████| 1196/1196 [31:08<00:00,  1.56s/batch]

Train loss 10.609500484430171 accuracy 28.311918452692108





Val   loss 10.620475426838091 accuracy 28.239326665040252

Epoch 4/8
----------


100%|██████████| 1196/1196 [31:11<00:00,  1.56s/batch]

Train loss 10.52481991372579 accuracy 28.56957658128594





Val   loss 10.600674212732667 accuracy 28.303976579653575

Epoch 5/8
----------


100%|██████████| 1196/1196 [31:11<00:00,  1.57s/batch]

Train loss 10.463779645106666 accuracy 28.774803972817566





Val   loss 10.580447592971854 accuracy 28.45084166869968

Epoch 6/8
----------


100%|██████████| 1196/1196 [31:11<00:00,  1.56s/batch]

Train loss 10.417040462602541 accuracy 28.953894406691063





Val   loss 10.570593724145654 accuracy 28.488655769699925

Epoch 7/8
----------


100%|██████████| 1196/1196 [31:10<00:00,  1.56s/batch]

Train loss 10.383244518320199 accuracy 29.096184004181914





Val   loss 10.563016031386788 accuracy 28.58453281288119

Epoch 8/8
----------


100%|██████████| 1196/1196 [31:10<00:00,  1.56s/batch]

Train loss 10.361717556962699 accuracy 29.193465760585468





Val   loss 10.561169142949597 accuracy 28.639180287875092

CPU times: user 4h 24min 14s, sys: 55.8 s, total: 4h 25min 10s
Wall time: 4h 28min 40s


In [None]:
from sklearn.metrics import classification_report

def predict_model(model, data_loader, y,  device):
    model = model.eval()

    preds_all = []
    targets_all = []

    with torch.no_grad():
        with tqdm(data_loader, unit="batch") as tepoch:
            for d in tepoch:
                input_ids = d["input_ids"].to(device)
                attention_mask = d["attention_mask"].to(device)
                targets = d["targets"].to(device)

                outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
                )
                preds = torch.round(outputs)
                preds_all.append(preds)
                targets_all.append(targets)
    preds_all, targets_all = torch.cat(preds_all), torch.cat(targets_all)
    preds_all, targets_all = preds_all.cpu().detach().numpy(), targets_all.cpu().detach().numpy()
    return preds_all, targets_all



In [None]:
preds_all, targets_all = predict_model(model, test_data_loader, y_test, device)

print(classification_report(preds_all, targets_all, target_names = picked_tags))

              precision    recall  f1-score   support

           0       1.00      0.74      0.85      4100
           1       0.96      0.59      0.73      3111
           2       0.98      0.46      0.63      3511
           3       0.95      0.47      0.63      3318
           4       0.81      0.44      0.57      1603
           5       0.74      0.56      0.64       926
           6       0.76      0.46      0.57       838
           7       0.52      0.33      0.41       767
           8       0.71      0.57      0.63       377
           9       0.60      0.53      0.56       350
          10       0.69      0.49      0.57       413
          11       0.42      0.33      0.37       334
          12       0.47      0.25      0.32       495
          13       0.43      0.37      0.40       260
          14       0.37      0.36      0.37       220
          15       0.70      0.59      0.64       186
          16       0.59      0.42      0.49       203
          17       0.12    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
from sklearn.metrics import  multilabel_confusion_matrix

cm = multilabel_confusion_matrix(targets_all, preds_all )
print(cm)

[[[   0 1062]
  [   0 3038]]

 [[ 922 1268]
  [  67 1843]]

 [[ 549 1881]
  [  40 1630]]

 [[ 696 1756]
  [  86 1562]]

 [[2335  905]
  [ 162  698]]

 [[2992  404]
  [ 182  522]]

 [[3142  455]
  [ 120  383]]

 [[3104  514]
  [ 229  253]]

 [[3635  161]
  [  88  216]]

 [[3629  165]
  [ 121  185]]

 [[3595  211]
  [  92  202]]

 [[3611  223]
  [ 155  111]]

 [[3466  372]
  [ 139  123]]

 [[3712  163]
  [ 128   97]]

 [[3746  140]
  [ 134   80]]

 [[3866   76]
  [  48  110]]

 [[3838  118]
  [  59   85]]

 [[3940   16]
  [ 127   17]]

 [[3978   38]
  [  48   36]]

 [[3926  100]
  [  26   48]]

 [[3959   44]
  [  57   40]]

 [[4032    0]
  [  68    0]]

 [[3713  178]
  [  88  121]]

 [[3775  118]
  [ 132   75]]

 [[4000    0]
  [ 100    0]]

 [[3875  123]
  [  43   59]]

 [[3991   30]
  [  67   12]]

 [[3959   71]
  [  41   29]]

 [[4045    0]
  [  55    0]]

 [[4045    0]
  [  55    0]]

 [[3753  116]
  [ 163   68]]

 [[3770  114]
  [  56  160]]]
