In [1]:
# Adapted from
# https://github.com/curiousily/Getting-Things-Done-with-Pytorch/blob/master/11.multi-label-text-classification-with-bert.ipynb

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch import nn
from collections import defaultdict
import lightning.pytorch as pl
import torch
import numpy as np

In [3]:
cards = pd.read_json("data/cards.json")
cards

Unnamed: 0,name,rules_text,colors,color_identity,flavour_text,type_line,power,toughness,set
0,Static Orb,"As long as CARDNAME is untapped, players can't...","[0, 0, 0, 0, 0, 1]","[0, 0, 0, 0, 0, 1]",,Artifact,,,7ed
1,Sensory Deprivation,Enchant creature\nEnchanted creature gets -3/-0.,"[0, 1, 0, 0, 0, 0]","[0, 1, 0, 0, 0, 0]",,Enchantment — Aura,,,m14
2,Road of Return,Choose one —\n• Return target permanent card f...,"[0, 0, 0, 0, 1, 0]","[0, 0, 0, 0, 1, 0]",,Sorcery,,,c19
3,Storm Crow,Flying (This creature can't be blocked except ...,"[0, 1, 0, 0, 0, 0]","[0, 1, 0, 0, 0, 0]",,Creature — Bird,1,2,9ed
4,Walking Sponge,tap: Target creature loses your choice of fly...,"[0, 1, 0, 0, 0, 0]","[0, 1, 0, 0, 0, 0]",,Creature — Sponge,1,1,ulg
...,...,...,...,...,...,...,...,...,...
24444,Quarry Beetle,"When CARDNAME enters the battlefield, you may ...","[0, 0, 0, 0, 1, 0]","[0, 0, 0, 0, 1, 0]",,Creature — Insect,4,5,hou
24445,Devoted Hero,,"[1, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0]",,Creature — Elf Soldier,1,2,s99
24446,Without Weakness,Target creature you control gains indestructib...,"[0, 0, 1, 0, 0, 0]","[0, 0, 1, 0, 0, 0]",,Instant,,,hou
24447,Firesong and Sunspeaker,Red instant and sorcery spells you control hav...,"[1, 0, 0, 1, 0, 0]","[1, 0, 0, 1, 0, 0]",,Legendary Creature — Minotaur Cleric,4,6,2x2


In [4]:
import spacy

en = spacy.load('en_core_web_sm')
stopwords = en.Defaults.stop_words

def text_preprocess(input_text):
    # remove all stop words
    input_text = ' '.join([word for word in input_text.split() if word not in stopwords])

    input_text = ''.join([char for char in input_text if char.isalnum() or char == '/' or char == ' '])

    return input_text

X = []
Y = []

# input_text = type_line + rules_text (if not None) + power / toughness (if not None)
for index, card in cards.iterrows():

    input_text = card['type_line']
    if card['rules_text'] is not None:
        input_text += '\n' + card['rules_text']
    if card['power'] is not None:
        input_text += '\n' + card['power'] + '/' + card['toughness']

    input_text = text_preprocess(input_text)
    
    X.append(input_text)
    Y.append(card["color_identity"])

In [5]:
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.20)
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5)

nr_of_targets= 5

# Remove colorless from array and convert from int to float32
y_train = [np.asarray(y)[0:nr_of_targets].astype('float32').ravel() for y in y_train] 
y_test = [np.asarray(y)[0:nr_of_targets].astype('float32').ravel() for y in y_test]
y_val = [np.asarray(y)[0:nr_of_targets].astype('float32').ravel() for y in y_val]

In [6]:
maxlen = 50

# Testing Bert tokenizer on the dataset
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

inputs = tokenizer.batch_encode_plus(
            x_train,
            None,
            max_length= maxlen,
            padding = 'max_length',
            return_token_type_ids= False, 
            return_attention_mask= True,#diff normal/pad tokens
            truncation= True,# Truncate data beyond max length
            return_tensors = 'pt' # PyTorch Tensor format
          )

In [7]:
# Example of tokens
inputs['input_ids'][0]

tensor([  140, 11811,  5332,  2896,  2087,  9240,   122,  2448, 24930,  1181,
         1299,  1161,  2942,   124,   120,   123,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])

In [8]:
# Converting tokens back to text
print(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].squeeze()))

['C', '##rea', '##ture', 'El', '##f', 'Scout', '1', 'green', 'Ad', '##d', 'man', '##a', 'color', '3', '/', '2', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']


In [9]:
# Wrapping the data into a Torch Dataset. This can be used for training in torch
class CardDataset (Dataset):
    def __init__(self, cards, colors, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.cards = cards
        self.labels = colors
        self.max_len = max_len
        
    def __len__(self):
        return len(self.cards)
    
    def __getitem__(self, item_idx):
        text = self.cards[item_idx]
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length= self.max_len,
            padding = 'max_length',
            return_token_type_ids= False,
            return_attention_mask= True,
            truncation=True,
            return_tensors = 'pt'
          )
        
        input_ids = inputs['input_ids'].flatten()
        attn_mask = inputs['attention_mask'].flatten()
               
        return {
          'text': text,
          'input_ids': input_ids ,
          'attention_mask': attn_mask,
          'labels':torch.tensor(self.labels[item_idx],dtype= torch.float)
        }

In [10]:
# Creating dataset using training data
train_dataset = CardDataset(
  x_train, y_train,
  tokenizer,
  max_len=50
)

sample_item = train_dataset[0]
sample_item.keys()

dict_keys(['text', 'input_ids', 'attention_mask', 'labels'])

In [11]:
sample_item["text"]

'Creature  Elf Scout 1 green Add mana color 3/2'

In [12]:
sample_item["labels"]

tensor([0., 0., 0., 0., 1.])

In [13]:
# The training, validation and test datasets are warped into a lightning.pytorch.LightningDataModule
class CardDataModule (pl.LightningDataModule):
    
    def __init__(self,x_train,y_train,x_val,y_val,x_test,y_test,tokenizer, batch_size=32,max_token_len=50):
        super().__init__()
        self.train_text = x_train
        self.train_label = y_train
        self.val_text = x_val
        self.val_label = y_val
        self.test_text = x_test
        self.test_label = y_test
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_token_len = max_token_len

    # Setup the datasets used during training/testing
    def setup(self, stage):
        self.train_dataset = CardDataset(cards=self.train_text, colors=self.train_label, tokenizer=self.tokenizer,max_len= self.max_token_len)
        self.val_dataset= CardDataset(cards=self.val_text, colors=self.val_label,tokenizer=self.tokenizer,max_len = self.max_token_len)
        self.test_dataset = CardDataset(cards=self.test_text, colors=self.test_label,tokenizer=self.tokenizer,max_len = self.max_token_len)

    # Create Torch DataLoaders for the 3 datasets 
    def train_dataloader(self):
         return DataLoader(self.train_dataset,batch_size= self.batch_size, shuffle = True)
    def val_dataloader(self):
         return DataLoader (self.val_dataset,batch_size= self.batch_size)
    def test_dataloader(self):
         return DataLoader (self.test_dataset,batch_size= self.batch_size)

In [14]:
# Define parameters
N_EPOCHS = 4
BATCH_SIZE = 12
MAX_LEN = 50

# Create the DataModule
card_data_module = CardDataModule(x_train, y_train, x_val, y_val, x_test, y_test, tokenizer, BATCH_SIZE, MAX_LEN)

In [15]:
# The lightning.pytorch.LightningModule is the model that will be trained
# Includes the pretrained Bert module as well as an additional Linear layer for our classification
class CardClassifier(pl.LightningModule):
    # Set up the classifier        
    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased', return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        self.criterion = nn.BCELoss()
            
    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)    
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output
    
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):

        optimizer = AdamW(self.parameters(), lr=2e-5)

        scheduler = get_linear_schedule_with_warmup(
          optimizer,
          num_warmup_steps=self.n_warmup_steps,
          num_training_steps=self.n_training_steps
        )

        return dict(
          optimizer=optimizer,
          lr_scheduler=dict(
            scheduler=scheduler,
            interval='step'
          )
        )

In [16]:
# Define number of training steps
steps_per_epoch=len(x_train) // BATCH_SIZE
total_training_steps = steps_per_epoch * N_EPOCHS

warmup_steps = total_training_steps // 5
warmup_steps, total_training_steps

(1303, 6516)

In [17]:
# Verify that CUDA is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [18]:
# Create model
model = CardClassifier(
  n_classes=nr_of_targets,
  n_warmup_steps=warmup_steps,
  n_training_steps=total_training_steps 
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [19]:
# Callback to save the best model to checkpoints/best-checkpoint.ckpt
# Checkpoint is overwritten if the val_loss of the new Epoch is better than the current best
checkpoint_callback = pl.callbacks.ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)

In [20]:
# lightning.pytorch.Trainer is a high-level abstraction that manages most of the training
trainer = pl.Trainer(
  callbacks=[checkpoint_callback],
  max_epochs=N_EPOCHS,
  enable_progress_bar=True
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [21]:
# Fit the model using the trainer
# The warning that num_workers is set too low and might be a bottleneck is displayed
# Changeing the value of num_workers does however lead to different errors
# This seems to be a problem with jupyter notebook:
# https://stackoverflow.com/questions/32489352/multiprocessing-program-has-attributeerror-in-anaconda-notebook
trainer.fit(model, card_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type      | Params
-----------------------------------------
0 | bert       | BertModel | 108 M 
1 | classifier | Linear    | 3.8 K 
2 | criterion  | BCELoss   | 0     
-----------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
433.256   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 1630: 'val_loss' reached 0.27665 (best 0.27665), saving model to 'C:\\Users\\Peter Helf\\Desktop\\IT-SEC\\Master\\AI\\semester 2\\NLP\\nlp-magic-color-identifier\\checkpoints\\best-checkpoint.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 3260: 'val_loss' reached 0.24652 (best 0.24652), saving model to 'C:\\Users\\Peter Helf\\Desktop\\IT-SEC\\Master\\AI\\semester 2\\NLP\\nlp-magic-color-identifier\\checkpoints\\best-checkpoint.ckpt' as top 1


In [22]:
# Display val_loss of final epoch
trainer.test()

In [21]:
# Load the final checkpoint from file
trained_model = CardClassifier.load_from_checkpoint(
  'checkpoints/best-checkpoint.ckpt',
  n_classes=nr_of_targets
)
trained_model = trained_model.to(device)
trained_model.eval()
trained_model.freeze()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
predictions = []
labels = []

# Create predictions for our test data. This might take 1-2 minutes
for i in range(len(x_test)):
    encoding = tokenizer.encode_plus(
      x_test[i],
      add_special_tokens=True,
      max_length=50,
      return_token_type_ids=False,
      padding="max_length",
      return_attention_mask=True,
      return_tensors='pt',
    )
    
    _, prediction = trained_model(encoding["input_ids"].to(device), encoding["attention_mask"].to(device))
    
    predictions.append(prediction.flatten())
    labels.append(torch.from_numpy(y_test[i]))
    
predictions = torch.stack(predictions).detach().cpu()
labels = torch.stack(labels).detach().cpu()

In [23]:
labels

tensor([[1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0.],
        ...,
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 1.]])

In [24]:
predictions

tensor([[0.9420, 0.1281, 0.2823, 0.0226, 0.1364],
        [0.0051, 0.0369, 0.0119, 0.0105, 0.0054],
        [0.0468, 0.0018, 0.0152, 0.8669, 0.0812],
        ...,
        [0.0107, 0.4624, 0.9094, 0.0765, 0.0080],
        [0.0175, 0.0060, 0.1797, 0.9400, 0.0063],
        [0.0059, 0.0049, 0.7363, 0.0204, 0.7023]])

In [25]:
from torchmetrics import Accuracy, F1Score, AUROC

# Calculate Accuracy
accuracy = Accuracy(task='multilabel', num_labels=nr_of_targets, threshold=0.5)
print("Accuracy: ", accuracy(predictions, labels))

Accuracy:  tensor(0.9442)


In [26]:
from sklearn.metrics import classification_report

LABELS = ["w", "u", "b", "r", "g"]

y_pred = predictions.numpy()
y_true = labels.numpy()

upper, lower = 1, 0

y_pred = np.where(y_pred > 0.5, upper, lower)

print(classification_report(
  y_true, 
  y_pred, 
  target_names=LABELS, 
  zero_division=0
))

              precision    recall  f1-score   support

           w       0.88      0.81      0.84       516
           u       0.89      0.83      0.86       498
           b       0.91      0.83      0.87       534
           r       0.93      0.83      0.88       538
           g       0.89      0.84      0.86       513

   micro avg       0.90      0.83      0.86      2599
   macro avg       0.90      0.83      0.86      2599
weighted avg       0.90      0.83      0.86      2599
 samples avg       0.78      0.76      0.76      2599



In [27]:
from sklearn.metrics import multilabel_confusion_matrix
confusion_matrix = multilabel_confusion_matrix(y_true, y_pred)
print(confusion_matrix)

[[[1870   59]
  [  96  420]]

 [[1898   49]
  [  83  415]]

 [[1869   42]
  [  92  442]]

 [[1875   32]
  [  93  445]]

 [[1880   52]
  [  84  429]]]


In [28]:
# Show some prediction samples
import random

for i in random.sample(range(len(x_test)), 10):
    test_text = x_test[i]
    
    encoding = tokenizer.encode_plus(
      test_text,
      add_special_tokens=True,
      max_length=50,
      return_token_type_ids=False,
      padding="max_length",
      return_attention_mask=True,
      return_tensors='pt',
    )

    _, test_prediction = trained_model(encoding["input_ids"].to(device), encoding["attention_mask"].to(device))
    test_prediction = test_prediction.cpu().flatten().numpy()
    
    
    formatted_y_test = ["%.3f"%item for item in y_test[i]]
    formatted_prediction = ["%.3f"%item for item in test_prediction]
    
    print("Card text: ", test_text)
    print("Colors:           [   w   ,    u   ,    b   ,    r   ,    g   ]")
    print("Actual colors:   ", formatted_y_test)
    print("Predicted colors:", formatted_prediction)
    print()

Card text:  Instant You exile red card hand pay spells mana cost CARDNAME deals 4 damage divided choose number target creatures
Colors:           [   w   ,    u   ,    b   ,    r   ,    g   ]
Actual colors:    ['0.000', '0.000', '0.000', '1.000', '0.000']
Predicted colors: ['0.015', '0.014', '0.014', '0.988', '0.016']

Card text:  Creature  Human Warrior 3/1
Colors:           [   w   ,    u   ,    b   ,    r   ,    g   ]
Actual colors:    ['1.000', '0.000', '0.000', '0.000', '0.000']
Predicted colors: ['0.071', '0.004', '0.207', '0.428', '0.110']

Card text:  Sorcery Destroy target creature You gain life equal toughness You search library and/or graveyard card named Vraska Regal Gorgon reveal it hand If search library way shuffle
Colors:           [   w   ,    u   ,    b   ,    r   ,    g   ]
Actual colors:    ['0.000', '0.000', '1.000', '0.000', '1.000']
Predicted colors: ['0.098', '0.012', '0.936', '0.016', '0.119']

Card text:  Sorcery Return target permanent card mana value 3 grave