# 6.864 ShaderGenerator Final Project

This notebook trains the Hearthstone baseline model for the ShaderGenerator final project. We first clone our GitHub repo, which contains the datasets we use, then install dependencies and set up the environment and imports.

In [None]:
%%bash
# Colab guard: `stat` succeeds only when the google.colab package directory
# exists. The leading `!` negates that status, so outside Colab the script
# exits here and nothing below (including the rm -rf) is executed.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
# Remove any previous checkout so that re-running this cell starts clean.
rm -rf ShaderGenerator
# Minimal clone: --depth 1 keeps a single commit of history, --filter=blob:none
# defers downloading file contents, and --no-checkout leaves the working tree
# empty until the explicit checkout below.
git clone \
    --depth 1 \
    --filter=blob:none \
    --no-checkout \
    https://github.com/RichardMuri/ShaderGenerator.git
cd ShaderGenerator
# Populate the working tree with only these two paths from the main branch.
git checkout main -- torchASN/ card2code/third_party/hearthstone

In [None]:
!pip install torch

In [None]:
import sys
sys.path.append("/content/ShaderGenerator/")
sys.path.append("/content/ShaderGenerator/torchASN/")
import numpy as np
import random
import torch
from torch import cuda
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from dataclasses import dataclass

from common.config import *
from components.dataset import *

from grammar.grammar import Grammar

from grammar.python3.python3_transition_system import Python3TransitionSystem
from models.ASN import ASNParser
from models import nn_utils

from torch import optim
import os
import time

# Seed every random number generator used by this notebook so that a fresh
# "Restart & Run All" starts from the same random state.
seed = 0
random.seed(seed)
np.random.seed(seed)
# torch keeps its own default generator, separate from `random` and numpy;
# without this call torch's CPU-side random ops start from an arbitrary seed.
torch.manual_seed(seed)

# Record which device is available; `device` is always set to 'cuda' or 'cpu'.
if cuda.is_available():
  device = 'cuda'
  # Seed the generators of all visible GPUs as well.
  torch.cuda.manual_seed_all(seed)
else:
  print('WARNING: you are running this project on a cpu!')
  device = 'cpu'

# Pre-processing Hearthstone Data

In [None]:
%%bash
# Run the repository's make_dataset.py from the repo root, with torchASN/
# appended to PYTHONPATH for that one command.
cd ShaderGenerator/
PYTHONPATH="${PYTHONPATH}:torchASN/" python torchASN/datasets/hearthstone/make_dataset.py

# Training the Model

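We now train the ASN parser on the pre-processed Hearthstone training set. After the first few epochs (`run_val_after`), each epoch is followed by an evaluation on the dev set, and the parser is checkpointed whenever the dev accuracy matches or beats the best value seen so far.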

In [None]:
# ---- Input files (all under the cloned repo's hearthstone data directory) ----
asdl_file = "/content/ShaderGenerator/torchASN/data/hearthstone/python_3_7_12_asdl.txt"
vocab = "/content/ShaderGenerator/torchASN/data/hearthstone/vocab.bin"
train_file = "/content/ShaderGenerator/torchASN/data/hearthstone/train.bin"
dev_file = "/content/ShaderGenerator/torchASN/data/hearthstone/dev.bin"

# ---- Model sizes and regularisation ----
# The encoder hidden size and both embedding sizes share one value.
enc_hid_size = src_emb_size = field_emb_size = 100
dropout = 0.3

# ---- Optimisation ----
lr = 0.003
batch_size = 32
max_epoch = 100
clip_grad = 5.0

# ---- Logging, validation and decoding limits ----
log_every = 50
run_val_after = 5
max_decode_step = 70
max_naive_parse_depth = 18

# ---- Checkpoint location: the file name encodes the settings above ----
model_file = f"model.hearthstone.enc{enc_hid_size}.src{src_emb_size}.field{field_emb_size}.drop{dropout}.max_ep{max_epoch}.batch{batch_size}.lr{lr}.clip_grad{clip_grad}.bin"
save_to = f"/content/ShaderGenerator/torchASN/checkpoints/hearthstone/{model_file}"

@dataclass
class Arguments:
    """Typed bundle of the run settings that this notebook passes around as ``args``.

    The training code in this cell reads the path, optimisation and schedule
    fields directly. The whole object is also handed to ``ASNParser``, which
    presumably consumes the model-size and decoding fields; that code is not
    visible here, so confirm against the parser implementation.
    """
    # Path to the ASDL grammar text; read and parsed with Grammar.from_text.
    asdl_file: str
    # Path to the pickled vocabulary object.
    vocab: str
    # Path loaded with Dataset.from_bin_file as the training set.
    train_file: str
    # Path of the dev set; a falsy value makes the notebook use an empty Dataset.
    dev_file: str
    # Not read directly in this notebook; presumably model settings used by
    # ASNParser -- TODO confirm.
    dropout: float
    enc_hid_size: int
    src_emb_size: int
    field_emb_size: int
    # Number of training epochs (the loop runs epochs 1..max_epoch).
    max_epoch: int
    # Maximum gradient norm passed to clip_grad_norm_.
    clip_grad: float
    # Batch size passed to the training set's batch_iter.
    batch_size: int
    # Learning rate given to the Adam optimizer.
    lr: float
    # Checkpoint file name; the training loop itself uses save_to instead.
    model_file: str
    # Full path handed to parser.save when a checkpoint is written.
    save_to: str
    # Number of training steps between running-loss printouts.
    log_every: int
    # Dev evaluation only runs for epochs strictly greater than this value.
    run_val_after: int
    # Not read directly in this notebook; presumably decoding limits used by
    # the parser -- TODO confirm.
    max_decode_step: int
    max_naive_parse_depth: int

# Collect the settings defined above into one Arguments instance.
# Every field is passed by keyword, one per line.
args = Arguments(
    asdl_file=asdl_file,
    vocab=vocab,
    train_file=train_file,
    dev_file=dev_file,
    dropout=dropout,
    enc_hid_size=enc_hid_size,
    src_emb_size=src_emb_size,
    field_emb_size=field_emb_size,
    max_epoch=max_epoch,
    clip_grad=clip_grad,
    batch_size=batch_size,
    lr=lr,
    model_file=model_file,
    save_to=save_to,
    log_every=log_every,
    run_val_after=run_val_after,
    max_decode_step=max_decode_step,
    max_naive_parse_depth=max_naive_parse_depth,
)


# ---- Load data, vocabulary and grammar ----
train_set = Dataset.from_bin_file(args.train_file)
if args.dev_file:
    dev_set = Dataset.from_bin_file(args.dev_file)
else:
    # No dev file configured: evaluate against an empty dataset.
    dev_set = Dataset(examples=[])

# NOTE(security): pickle.load can execute arbitrary code; only point
# args.vocab at a file produced by this project's own pre-processing.
# `vocab` is rebound here from the path string to the loaded object.
with open(args.vocab, 'rb') as vocab_fh:
    vocab = pickle.load(vocab_fh)
with open(args.asdl_file) as asdl_fh:
    grammar = Grammar.from_text(asdl_fh.read())
transition_system = Python3TransitionSystem(grammar)

# ---- Build the parser and optimizer ----
parser = ASNParser(args, transition_system, vocab)
nn_utils.glorot_init(parser.parameters())

optimizer = optim.Adam(parser.parameters(), lr=args.lr)
best_acc = 0.0
log_every = args.log_every

train_begin = time.time()

# ---- Training loop: one pass over train_set per epoch ----
for epoch in tqdm(range(1, args.max_epoch + 1)):
    train_iter = 0
    loss_val = 0.    # summed loss since the last log line
    epoch_loss = 0.  # summed loss over the whole epoch

    parser.train()

    epoch_begin = time.time()
    # Batches are drawn in a fixed order (shuffle=False).
    for batch_example in train_set.batch_iter(batch_size=args.batch_size, shuffle=False):
        optimizer.zero_grad()
        loss = parser.score(batch_example)
        # The summed loss feeds the logging counters; the mean is what gets
        # back-propagated.
        batch_loss_sum = torch.sum(loss).item()
        loss_val += batch_loss_sum
        epoch_loss += batch_loss_sum
        loss = torch.mean(loss)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(parser.parameters(), args.clip_grad)

        optimizer.step()
        train_iter += 1
        if train_iter % log_every == 0:
            # Normalised by log_every * batch_size, i.e. assuming full batches.
            print("[epoch {}, step {}] loss: {:.3f}".format(
                epoch, train_iter, loss_val / (log_every * args.batch_size)))
            loss_val = 0.

    print('[epoch {}] train loss {:.3f}, epoch time {:.0f}, total time {:.0f}'.format(
        epoch, epoch_loss / len(train_set), time.time() - epoch_begin, time.time() - train_begin))

    # ---- Dev evaluation, skipped for the first run_val_after epochs ----
    if epoch > args.run_val_after:
        eval_begin = time.time()
        parser.eval()
        with torch.no_grad():
            parse_results = [parser.naive_parse(ex) for ex in dev_set]
        match_results = [transition_system.compare_ast(r, e.tgt_ast) for r, e in zip(parse_results, dev_set)]

        if match_results:
            match_acc = sum(match_results) / len(match_results)
            print('[epoch {}] eval acc {:.3f}, eval time {:.0f}'.format(
                epoch, match_acc, time.time() - eval_begin))
        else:
            # Empty dev set: there is no accuracy to compute (dividing by
            # len(match_results) would raise ZeroDivisionError). Reuse best_acc
            # so that the latest parameters are still checkpointed below.
            match_acc = best_acc
            print('[epoch {}] dev set is empty, skipping eval'.format(epoch))

        if match_acc >= best_acc:
            best_acc = match_acc
            # Make sure the checkpoint directory exists before saving; it is
            # not visible here whether parser.save creates it itself.
            checkpoint_dir = os.path.dirname(args.save_to)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            parser.save(args.save_to)