In [1]:
import tensorflow as tf
tf.config.list_physical_devices()

2022-12-06 10:23:54.842496: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [2]:
import pandas as pd
import json
import requests
import matplotlib.pyplot as plt
import re
import numpy as np
import datetime as dt
from sklearn.linear_model import LinearRegression


In [3]:
full_scryfall_df = pd.DataFrame(json.load(open("scryfall.json")))

In [4]:
df = full_scryfall_df[['name', 'mana_cost', 'cmc', 'type_line', 'oracle_text', 'power', 'toughness', 'colors', 'color_identity', 'keywords', 'set', 'released_at', 'rarity', 'games', 'legalities']]

In [5]:
# remove cards with the same name
df = df.sort_values(by=['released_at', 'name'])
df = df.drop_duplicates(subset=['name'])
df = df[df['games'].apply(lambda i: 'paper' in i)]

def legal(legalities):
    v = legalities.values()
    if len(set(v)) == 1 and "not_legal" in v:
        return False
    return True

df = df[df['legalities'].apply(legal)]

In [6]:
unsets = ['unglued', 'unhinged', 'unstable', 'unsanctioned', 'unfinity']
sets = json.loads(requests.get("https://api.scryfall.com/sets").text)
for s in sets["data"]:
    if s['name'].lower() in unsets:
        df = df[~df["set"].str.contains(s['code'])]

In [7]:
df.head()


Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,set,released_at,rarity,games,legalities
31965,Air Elemental,{3}{U}{U},5.0,Creature — Elemental,Flying,4.0,4.0,[U],[U],[Flying],lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."
34069,Ancestral Recall,{U},1.0,Instant,Target player draws three cards.,,,[U],[U],[],lea,1993-08-05,rare,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal..."
30915,Animate Artifact,{3}{U},4.0,Enchantment — Aura,Enchant artifact\nAs long as enchanted artifac...,,,[U],[U],[Enchant],lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."
43342,Animate Dead,{1}{B},2.0,Enchantment — Aura,Enchant creature card in a graveyard\nWhen Ani...,,,[B],[B],[Enchant],lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."
64661,Animate Wall,{W},1.0,Enchantment — Aura,Enchant Wall\nEnchanted Wall can attack as tho...,,,[W],[W],[Enchant],lea,1993-08-05,rare,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."


In [8]:
df = df[~df["type_line"].str.contains("Token", na=False)] # remove tokens


In [9]:
# just vanilla creatures
vanilla_df = df[df["oracle_text"] == ""]
vanilla_df = vanilla_df[vanilla_df["type_line"].str.contains("Creature")]
vanilla_df


Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,set,released_at,rarity,games,legalities
58015,Craw Wurm,{4}{G}{G},6.0,Creature — Wurm,,6,4,[G],[G],[],lea,1993-08-05,common,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."
53877,Earth Elemental,{3}{R}{R},5.0,Creature — Elemental,,4,5,[R],[R],[],lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."
65999,Fire Elemental,{3}{R}{R},5.0,Creature — Elemental,,5,4,[R],[R],[],lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."
34904,Gray Ogre,{2}{R},3.0,Creature — Ogre,,2,2,[R],[R],[],lea,1993-08-05,common,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."
62372,Grizzly Bears,{1}{G},2.0,Creature — Bear,,2,2,[G],[G],[],lea,1993-08-05,common,[paper],"{'standard': 'not_legal', 'future': 'not_legal..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11135,Highborn Vampire,{3}{B},4.0,Creature — Vampire Warrior,,4,3,[B],[B],[],znr,2020-09-25,common,"[arena, paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal..."
72508,Murasa Brute,{2}{G},3.0,Creature — Troll Warrior,,3,3,[G],[G],[],znr,2020-09-25,common,"[arena, paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal..."
22337,Grizzled Outrider,{4}{G},5.0,Creature — Elf Warrior,,5,5,[G],[G],[],khm,2021-02-05,common,"[arena, paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal..."
22403,Ageless Guardian,{1}{W},2.0,Creature — Spirit Soldier,,1,4,[W],[W],[],stx,2021-04-23,common,"[arena, paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal..."


In [10]:
vanilla_df['cmc'].value_counts()


2.0    68
3.0    65
4.0    63
5.0    59
1.0    26
6.0    23
7.0    13
0.0     5
8.0     2
Name: cmc, dtype: int64

In [11]:
x = set()
for i, row in df.iterrows():
    if row['set'] not in ['afr', '40k', 'clb', 'sld']:
        x = x.union(row['keywords'])
    else:
        # ignore multimodal keywords
        x = x.union([i for i in row['keywords'] if ' ' not in i])
keyword_soup = ' '.join(x).lower()
print(keyword_soup)

plainswalk wizardcycling jump-start cohort dredge martyrdom transfigure radiance council's dilemma double agenda lieutenant nightbound swampwalk spectacle reach polymorphine undying deathtouch hadoken rebound splice swampcycling tempting offer bushido soulshift grav-cannon intimidate champion flanking strive mountainwalk shroud annihilator phaeron coven bestow phasing changeling forecast forestcycling adamant aura swap assist undaunted menace outlast bloodthirst rally haunt gravestorm improvise crew legendary landwalk skulk fear kicker sweep ingest retrace shieldwall embalm convoke metalcraft partner unearth ripple vigilance bolster squad protection desertwalk threshold grandeur converge chroma battle cry ravenous amass reconfigure banding more than meets the eye ninjutsu living metal alliance islandwalk emerge clash raid recover first strike encore sunburst demonstrate graft epic meld companion devour hexproof from hellbent cleave cumulative upkeep multikicker landcycling amplify scav

In [12]:
removes = [r'\(.*?\)', r'\{.*?\}', r'—[^ ][^\n]*', r'(P|p)rotection(?! F)[^\n]*', r'\d*', r'Prototype[^\n]*']
#r'—[^{][^T][^\n]*', 
# todo protection, ward, a lot of stuff actually
def is_french_vanilla(row):
    text = row['oracle_text']
    
    if text == '': 
        return True # is just vanilla
    
    if text is np.nan:
        return False # is not valid
    
    for r in removes:
        text = re.sub(r, '', text)
    text = text.replace(',', '').replace(';', '')
    text = text.lower()
    text = text.strip()

    for i in text.split():
        if i not in keyword_soup:
            return False

    return True

french_vanilla_df = df[df["type_line"].str.contains("Creature", na=False)] # for now just creatures
french_vanilla_df = french_vanilla_df.sort_values(by=['name'])
french_vanilla_df['is_french_vanilla'] = french_vanilla_df.apply(is_french_vanilla, axis=1)
french_vanilla_df = french_vanilla_df[french_vanilla_df['is_french_vanilla']]

french_vanilla_df = french_vanilla_df.sort_values(by=['name'])

french_vanilla_df


Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,set,released_at,rarity,games,legalities,is_french_vanilla
34300,Abbey Gargoyles,{2}{W}{W}{W},5.0,Creature — Gargoyle,"Flying, protection from red",3,4,[W],[W],"[Flying, Protection]",hml,1995-10-01,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",True
57889,Abbey Griffin,{3}{W},4.0,Creature — Griffin,"Flying, vigilance",2,2,[W],[W],"[Flying, Vigilance]",isd,2011-09-30,common,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
42340,Aboroth,{4}{G}{G},6.0,Creature — Elemental,Cumulative upkeep—Put a -1/-1 counter on Aboro...,9,9,[G],[G],[Cumulative upkeep],wth,1997-06-09,rare,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
8391,Abzan Guide,{3}{W}{B}{G},6.0,Creature — Human Warrior,Lifelink (Damage dealt by this creature also c...,4,4,"[B, G, W]","[B, G, W]","[Lifelink, Morph]",ktk,2014-09-26,common,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
29596,Accomplished Automaton,{7},7.0,Artifact Creature — Construct,Fabricate 1 (When this creature enters the bat...,5,7,[],[],[Fabricate],kld,2016-09-30,common,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51689,Zodiac Tiger,{2}{G}{G},4.0,Creature — Cat,Forestwalk (This creature can't be blocked as ...,3,4,[G],[G],"[Landwalk, Forestwalk]",ptk,1999-07-06,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",True
54222,Zombie Brute,{6}{B},7.0,Creature — Zombie,Amplify 1 (As this creature enters the battlef...,5,4,[B],[B],"[Amplify, Trample]",lgn,2003-02-03,uncommon,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
76754,Zombie Cutthroat,{3}{B}{B},5.0,Creature — Zombie,Morph—Pay 5 life. (You may cast this card face...,3,4,[B],[B],[Morph],scg,2003-05-26,common,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
61737,Zombie Goliath,{4}{B},5.0,Creature — Zombie Giant,,4,3,[B],[B],[],m10,2009-07-17,common,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True


In [13]:
french_vanilla_df[french_vanilla_df['rarity'] == 'mythic']

Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,set,released_at,rarity,games,legalities,is_french_vanilla
43282,Apex Devastator,{8}{G}{G},10.0,Creature — Chimera Hydra,"Cascade, cascade, cascade, cascade (When you c...",10,10,[G],[G],[Cascade],cmr,2020-11-20,mythic,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
31256,Baneslayer Angel,{3}{W}{W},5.0,Creature — Angel,"Flying, first strike, lifelink, protection fro...",5,5,[W],[W],"[Flying, Lifelink, First strike, Protection]",m10,2009-07-17,mythic,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
54901,Furyborn Hellkite,{4}{R}{R}{R},7.0,Creature — Dragon,Bloodthirst 6 (If an opponent was dealt damage...,6,6,[R],[R],"[Flying, Bloodthirst]",m12,2011-07-15,mythic,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
74281,Impervious Greatwurm,{7}{G}{G}{G},10.0,Creature — Wurm,Convoke (Your creatures can help cast this spe...,16,16,[G],[G],"[Indestructible, Convoke]",grn,2018-10-05,mythic,"[arena, paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
16939,Phyrexian Fleshgorger,{7},7.0,Artifact Creature — Phyrexian Wurm,Prototype {1}{B}{B} — 3/3 (You may cast this s...,7,5,[],[B],"[Lifelink, Menace, Ward, Prototype]",bro,2022-11-18,mythic,"[paper, arena, mtgo]","{'standard': 'legal', 'future': 'legal', 'hist...",True
60540,Sphinx of the Steel Wind,{5}{W}{U}{B},8.0,Artifact Creature — Sphinx,"Flying, first strike, vigilance, lifelink, pro...",6,6,"[B, U, W]","[B, U, W]","[Flying, Lifelink, Vigilance, First strike, Pr...",arb,2009-04-30,mythic,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True
5913,Vorapede,{2}{G}{G}{G},5.0,Creature — Insect,"Vigilance, trample\nUndying (When this creatur...",5,4,[G],[G],"[Undying, Vigilance, Trample]",dka,2012-02-03,mythic,"[paper, mtgo]","{'standard': 'not_legal', 'future': 'not_legal...",True


In [14]:
# todo organize code
alpha_release_date = dt.datetime(1993, 8, 5)
# ignore type line for now
data_df = french_vanilla_df.drop(columns=['name', 'mana_cost', 'type_line', 'oracle_text', 'color_identity', 'set', 'is_french_vanilla', 'games', 'legalities'])
data_df['cmc'] = data_df['cmc'].apply(int)
data_df['power'] = data_df['power'].apply(int)
data_df['toughness'] = data_df['toughness'].apply(int)
data_df['released_at'] = data_df['released_at'].apply(lambda i: (dt.datetime.strptime(i, '%Y-%m-%d').year - alpha_release_date.year))
data_df


Unnamed: 0,cmc,power,toughness,colors,keywords,released_at,rarity
34300,5,3,4,[W],"[Flying, Protection]",2,uncommon
57889,4,2,2,[W],"[Flying, Vigilance]",18,common
42340,6,9,9,[G],[Cumulative upkeep],4,rare
8391,6,4,4,"[B, G, W]","[Lifelink, Morph]",21,common
29596,7,5,7,[],[Fabricate],23,common
...,...,...,...,...,...,...,...
51689,4,3,4,[G],"[Landwalk, Forestwalk]",6,uncommon
54222,7,5,4,[B],"[Amplify, Trample]",10,uncommon
76754,5,3,4,[B],[Morph],10,common
61737,5,4,3,[B],[],16,common


In [15]:
def dummy_list(data_df, one_hot_df, column):
    x = set(data_df.explode(column)[column].values)
    x.remove(np.nan)
    
    for i in x:
        one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
    

one_hot_df = pd.get_dummies(data_df.drop(columns=['colors', 'keywords']))
dummy_list(data_df, one_hot_df, 'colors')
dummy_list(data_df, one_hot_df, 'keywords')
one_hot_df

  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))
  one_hot_df[f'{column}_{i}'] = data_df[column].apply(lambda j: int(i in j))

Unnamed: 0,cmc,power,toughness,released_at,rarity_common,rarity_mythic,rarity_rare,rarity_uncommon,colors_G,colors_R,...,keywords_Delve,keywords_Hexproof,keywords_Surge,keywords_Prowess,keywords_Transmute,keywords_Mill,keywords_Trample,keywords_Suspend,keywords_Dash,keywords_Provoke
34300,5,3,4,2,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
57889,4,2,2,18,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
42340,6,9,9,4,0,0,1,0,1,0,...,0,0,0,0,0,0,0,0,0,0
8391,6,4,4,21,1,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
29596,7,5,7,23,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51689,4,3,4,6,0,0,0,1,1,0,...,0,0,0,0,0,0,0,0,0,0
54222,7,5,4,10,0,0,0,1,0,0,...,0,0,0,0,0,0,1,0,0,0
76754,5,3,4,10,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
61737,5,4,3,16,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [16]:
# make new dummy matrices for interaction terms
X = one_hot_df.drop(columns=['cmc', 'rarity_common'])
y = one_hot_df['cmc']

# make new linear regression
reg2 = LinearRegression().fit(X, y)

# print all coefficients
for name, coef in zip(X.columns, reg2.coef_):
    print(f"{name} Coefficient: {coef}")
print("Intercept:", reg2.intercept_)


power Coefficient: 0.6300050885468829
toughness Coefficient: 0.4826803038488169
released_at Coefficient: -0.025297956183267523
rarity_mythic Coefficient: -2.7038942725159134
rarity_rare Coefficient: -0.6447795562543783
rarity_uncommon Coefficient: -0.11527265003920362
colors_G Coefficient: -0.6139517747558176
colors_R Coefficient: -0.22653991134295204
colors_U Coefficient: -0.20107422390029384
colors_W Coefficient: -0.346862576552905
colors_B Coefficient: -0.1597405681651599
keywords_Plainswalk Coefficient: 1.1391086896843137
keywords_Plainscycling Coefficient: 0.9447987841733585
keywords_Dredge Coefficient: 0.43058394226749147
keywords_Riot Coefficient: 1.2672086005684373
keywords_Transfigure Coefficient: 0.5728252403490872
keywords_Defender Coefficient: -0.8748716082629187
keywords_Swampwalk Coefficient: 0.25013683982145096
keywords_Spectacle Coefficient: 0.05059753444182724
keywords_Reach Coefficient: 0.5254264214126412
keywords_Typecycling Coefficient: -0.3980275813491587
keywords_

In [17]:
# import spacy
# nlp = spacy.load("en_core_web_sm")
creature_df = df[df["type_line"].str.contains("Creature", na=False)]
creature_df["mana_cost_spaced"] = creature_df["mana_cost"].apply(lambda i: i.replace("}{", "} {") if isinstance(i, str) else i)

tokens = {'': 0}
num_tokens = 1

replaces = [
      [",", " {comma} "],
      [":", " {colon} "],
      [".", " {period} "],
      [";", " {semicolon} "],
    ["\n", " {newline} "],
      ["}{", "} {"]]

def get_data(row):
    t = row["oracle_text"]
    if t is not np.nan:
        t = t.replace(row["name"], "{CARDNAME}")
        for i, j in replaces:
            t = t.replace(i, j)
    return [t]
#     return [row["name"], row["mana_cost_spaced"], row["type_line"], row["power"], row["toughness"], row["oracle_text"]]

def process(text):
    # remove reminder text & ,.
    return re.sub(removes[0], '', text).lower().replace(',', '').replace('.', '')

for i, row in creature_df.iterrows():
    for data in get_data(row):
#         doc = row["oracle_text"] # nlp().tokens took too long but im still considering it
        if not isinstance(data, str):
            continue # split cards have no orcale text
        data = process(data)

        for token in data.split():
            if token not in tokens:
                tokens[token] = num_tokens
                num_tokens += 1

creature_df.head()


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  creature_df["mana_cost_spaced"] = creature_df["mana_cost"].apply(lambda i: i.replace("}{", "} {") if isinstance(i, str) else i)


Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,set,released_at,rarity,games,legalities,mana_cost_spaced
31965,Air Elemental,{3}{U}{U},5.0,Creature — Elemental,Flying,4,4,[U],[U],[Flying],lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{3} {U} {U}
5357,Benalish Hero,{W},1.0,Creature — Human Soldier,"Banding (Any creatures with banding, and up to...",1,1,[W],[W],[Banding],lea,1993-08-05,common,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{W}
25933,Birds of Paradise,{G},1.0,Creature — Bird,Flying\n{T}: Add one mana of any color.,0,1,[G],[G],[Flying],lea,1993-08-05,rare,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{G}
58469,Black Knight,{B}{B},2.0,Creature — Human Knight,First strike (This creature deals combat damag...,2,2,[B],[B],"[First strike, Protection]",lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{B} {B}
31122,Bog Wraith,{3}{B},4.0,Creature — Wraith,Swampwalk (This creature can't be blocked as l...,3,3,[B],[B],"[Landwalk, Swampwalk]",lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{3} {B}


In [18]:
print(num_tokens)
tokens_list = list(tokens.keys())

2376


In [19]:
def encode(text):
    if text is np.nan:
        return []
    t = process(text).split()
    return [tokens[i] for i in t]

def encode_row(row):
    data = get_data(row)
    # todo pad mana cost & type line
    out = []
    for i in data:
        out.extend(encode(i))

    return out

def decode(encoded):
    s = " ".join(tokens_list[i] for i in encoded)
    for i, j in replaces[:-1]:
        s = s.replace(j, i + " ")
        s = s.replace(j.strip(), i + " ")
    return s

name = "Pixie Illusionist"
t = creature_df[creature_df["name"] == name].iloc[0]
print(t)
print(encode_row(t))
print(decode(encode_row(t)))

name                                                Pixie Illusionist
mana_cost                                                         {U}
cmc                                                               1.0
type_line                                    Creature — Faerie Wizard
oracle_text         Kicker {3}{G} (You may pay an additional {3}{G...
power                                                               1
toughness                                                           1
colors                                                            [U]
color_identity                                                 [G, U]
keywords                                             [Kicker, Flying]
set                                                               dmu
released_at                                                2022-09-09
rarity                                                         common
games                                            [paper, arena, mtgo]
legalities          

In [20]:
creature_df["encoded"] = creature_df.apply(encode_row, axis=1)
creature_df["encoded_length"] = creature_df["encoded"].apply(len)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  creature_df["encoded"] = creature_df.apply(encode_row, axis=1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  creature_df["encoded_length"] = creature_df["encoded"].apply(len)


In [38]:
creature_df = creature_df[creature_df["encoded_length"] < 20]

In [39]:
total_words = max(creature_df["encoded_length"])
largest_encode = max(creature_df["encoded_length"])
creature_df[creature_df["encoded_length"] == largest_encode]

Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,set,released_at,rarity,games,legalities,mana_cost_spaced,encoded,encoded_length
70512,Thicket Basilisk,{3}{G}{G},5.0,Creature — Basilisk,Whenever Thicket Basilisk blocks or becomes bl...,2,4,[G],[G],[],lea,1993-08-05,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{3} {G} {G},"[66, 19, 67, 35, 68, 36, 69, 39, 70, 65, 32, 7...",19
66352,Aladdin,{2}{R}{R},4.0,Creature — Human Rogue,"{1}{R}{R}, {T}: Gain control of target artifac...",1,1,[R],[R],[],arn,1993-12-17,rare,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{2} {R} {R},"[144, 85, 85, 32, 4, 5, 217, 119, 9, 73, 218, ...",19
37066,Sindbad,{1}{U},2.0,Creature — Human,{T}: Draw a card and reveal it. If it isn't a ...,1,1,[U],[U],[],arn,1993-12-17,uncommon,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{1} {U},"[4, 5, 201, 39, 134, 80, 259, 28, 12, 33, 28, ...",19
61051,Priest of Yawgmoth,{1}{B},2.0,Creature — Phyrexian Human Cleric,"{T}, Sacrifice an artifact: Add an amount of {...",1,2,[B],[B],[],atq,1994-03-04,common,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{1} {B},"[4, 32, 81, 82, 218, 5, 6, 82, 277, 9, 78, 117...",19
1151,Ayesha Tanaka,{W}{W}{U}{U},4.0,Legendary Creature — Human Artificer,"Banding (Any creatures with banding, and up to...",2,2,"[U, W]","[U, W]",[Banding],leg,1994-06-01,rare,[paper],"{'standard': 'not_legal', 'future': 'not_legal...",{W} {W} {U} {U},"[2, 3, 4, 5, 40, 73, 91, 46, 16, 82, 218, 306,...",19
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
76137,Clay Revenant,{1},1.0,Artifact Creature — Golem,Clay Revenant enters the battlefield tapped.\n...,1,2,[],[B],[],bro,2022-11-18,common,"[paper, arena, mtgo]","{'standard': 'legal', 'future': 'legal', 'hist...",{1},"[19, 20, 21, 22, 190, 12, 3, 243, 78, 5, 266, ...",19
35267,Combat Thresher,{7},7.0,Artifact Creature — Construct,Prototype {2}{W} — 1/1 (You may cast this spel...,3,3,[],[W],"[Double strike, Prototype]",bro,2022-11-18,uncommon,"[paper, arena, mtgo]","{'standard': 'legal', 'future': 'legal', 'hist...",{7},"[2351, 243, 167, 657, 292, 3, 759, 14, 3, 176,...",19
43927,Scrapwork Cohort,{4},4.0,Artifact Creature — Soldier,"When Scrapwork Cohort enters the battlefield, ...",3,1,[],[W],[Unearth],bro,2022-11-18,common,"[paper, mtgo, arena]","{'standard': 'legal', 'future': 'legal', 'hist...",{4},"[176, 19, 20, 21, 22, 32, 253, 39, 292, 293, 4...",19
29315,Thopter Architect,{3}{W},4.0,Creature — Human Artificer,Whenever an artifact enters the battlefield un...,2,3,[W],[W],[],bro,2022-11-18,uncommon,"[paper, arena, mtgo]","{'standard': 'legal', 'future': 'legal', 'hist...",{3} {W},"[66, 82, 218, 20, 21, 22, 228, 57, 119, 32, 73...",19


In [41]:
input_sequences = []
for seq in creature_df["encoded"]:
#     input_sequences.append(seq)
    for i in range(1, len(seq)):
        input_sequences.append(seq[:i + 1])

In [43]:
import tensorflow.keras.utils as ku
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np

# pad sequences
max_sequence_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))

# create predictors and label
total_words = num_tokens
predictors, label = input_sequences[:,:-1],input_sequences[:,-1]
label = ku.to_categorical(label, num_classes=total_words)

In [44]:
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Bidirectional
from tensorflow.keras.models import Sequential
from tensorflow.keras import regularizers
import tensorflow as tf

# https://github.com/nicknochnack/GANBasics/blob/main/FashionGAN-Tutorial.ipynb
# https://towardsdatascience.com/training-neural-networks-to-create-text-like-a-human-23bfdc23c28

model = Sequential()
model.add(Embedding(total_words, 240, input_length=max_sequence_len-1))
model.add(Bidirectional(LSTM(150, return_sequences = True)))
model.add(Dropout(0.2))
model.add(LSTM(100))
model.add(Dense(total_words/2, activation='relu', kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(total_words, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])


In [45]:
max(input_sequences[:,-1])

2352

In [46]:
len(label[0])
total_words

2998

In [49]:
import os
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_3/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq=1963)

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        if(logs.get('accuracy')>0.93):
            print("\nReached 93% accuracy so cancelling training!")
            self.model.stop_training = True

callbacks = myCallback()
history = model.fit(predictors, label, epochs=300, verbose=1, callbacks=[cp_callback, callbacks])
 

Epoch 1/300
Epoch 1: saving model to training_3/cp-0001.ckpt
Epoch 2/300
Epoch 2: saving model to training_3/cp-0002.ckpt
Epoch 3/300
Epoch 3: saving model to training_3/cp-0003.ckpt
Epoch 4/300
 243/1963 [==>...........................] - ETA: 3:30 - loss: 2.1406 - accuracy: 0.5624

KeyboardInterrupt: 

In [50]:
model.load_weights("training_3/cp-0003.ckpt")

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x1e8655930>

In [33]:
len(input_sequences)

13076

In [55]:
thingy = [tokens["{g}"]]
for i in range(40):
    X = pad_sequences([thingy], maxlen=max_sequence_len - 1, padding='pre')
    thingy.append(np.argmax(model.predict(X)))

print(thingy)
print(decode(thingy))

[106, 5, 19, 86, 107, 87, 30, 9, 88, 12, 54, 55, 16, 39, 65, 12, 3, 1, 32, 81, 19, 12, 3, 19, 86, 107, 63, 110, 63, 59, 119, 39, 65, 12, 54, 55, 101, 44, 198, 12, 54]
{g}: {cardname} gets +1/+1 until end of turn. activate only from a creature.
 flying, sacrifice {cardname}.
 {cardname} gets +1/+1 as long as you control a creature. activate only less to cast. activate


In [None]:
import torch
from torch import nn
import math

In [None]:
eighteen_df = creature_df[creature_df["encoded_length"] == 18]
largest_encode = 18

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(largest_encode, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        output = self.model(x)
        return output

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(largest_encode, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, largest_encode),
        )

    def forward(self, x):
        output = self.model(x)
        return output

In [None]:
train_data_length = len(eighteen_df.index) - (len(eighteen_df.index) % 32)
train_data = torch.tensor(pad_sequences(eighteen_df["encoded"].values[:train_data_length], maxlen=largest_encode, padding="post"))
train_labels = torch.zeros(train_data_length)
train_set = [
    (train_data[i], train_labels[i]) for i in range(train_data_length)
]
print(train_set[300])


In [None]:
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)

In [None]:
lr = 0.001
num_epochs = 300
loss_function = nn.BCELoss()

In [None]:
discriminator = Discriminator()
generator = Generator()
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

In [None]:
import tqdm
for epoch in tqdm.trange(num_epochs):
    for n, (real_samples, _) in list(enumerate(train_loader)):
        # Data for training the discriminator
        real_samples_labels = torch.ones((batch_size, 1))
        latent_space_samples = torch.randn((batch_size, largest_encode))
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1))
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, largest_encode))

        # Training the generator
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()
    
        # Show loss
#         if epoch % 10 == 0 and n == batch_size - 1:
#         print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
#         print(f"Epoch: {epoch} Loss G.: {loss_generator}")
        
#             latent_space_samples = torch.randn(100, 2)
#             generated_samples = generator(latent_space_samples)
#             generated_samples = generated_samples.detach()
#             plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")
#             plt.show()
#             input()

In [None]:
latent_space_samples = torch.randn(1, largest_encode)
generated_samples = generator(latent_space_samples)
generated_samples = generated_samples.detach()

print(decode(list(map(int, list(generated_samples[0])))))
generated_samples

In [None]:
plt.hist(creature_df["encoded_length"])