# A decoder transformer model for wine reviews
https://www.kaggle.com/datasets/zynicide/wine-reviews

In [1]:
# imports
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import json
import re
import string
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


## Load the wine reviews data

In [3]:
with open('data/wine-reviews/winemag-data-130k-v2.json') as json_data:
    wine_data = json.load(json_data)

# Filter the dataset
filtered_data = [
    "wine review | "
    + x["variety"]
    + " from "
    + x["province"]
    + ", "
    + x["country"]
    + ": "
    + x["description"]
    for x in wine_data
    if x["country"] is not None
    and x["province"] is not None
    and x["variety"] is not None
    and x["description"] is not None
]

In [4]:
n_recipes = len(filtered_data)
print(f"{n_recipes} recipes loaded")
example = filtered_data[25]
print(example)

129907 recipes loaded
wine review | Pinot Noir from California, US: Oak and earth intermingle around robust aromas of wet forest floor in this vineyard-designated Pinot that hails from a high-elevation site. Small in production, it offers intense, full-bodied raspberry and blackberry steeped in smoky spice and smooth texture.


## Process and Tokenize data

In [5]:
def pad_punctuation(str):
    # add spaces before and after punctuation
    str = re.sub(f"([{string.punctuation}])", r" \1 ", str)
    # replace one or more spaces with one space
    str = re.sub(' +', ' ', str)
    return str

text_data = [pad_punctuation(x) for x in filtered_data]
text_data = [x.lower() for x in text_data]

In [6]:
# default tokenizer just splits by spaces
tokenizer = get_tokenizer(None)
def yield_tokens(data_iter):
    for text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(
    yield_tokens(iter(text_data)),
    specials=["<stop>", "<unk>"],
    special_first=True,
    max_tokens=10000
)
vocab.set_default_index(vocab['<unk>']) # set default unknown to <unk>

In [7]:
print(len(vocab))
text_to_vec = lambda x: vocab(tokenizer(x))
print(vocab.lookup_tokens(range(0, 10)))

10000
['<stop>', '<unk>', ',', '.', 'and', 'the', 'wine', 'a', 'of', 'from']


## Create dataset

In [8]:
MAX_SEQ_LEN = 80 + 1
STOP = 0
def pad_tokens(tok_list):
    if len(tok_list) < MAX_SEQ_LEN:
        return tok_list + [STOP for x in range(MAX_SEQ_LEN - len(tok_list))]
    else:
        return tok_list[:MAX_SEQ_LEN]

def prepare_inputs(text):
    tokenized_sentences = list(map(text_to_vec, text))
    tokenized_sentences = list(map(pad_tokens, tokenized_sentences))
    tokenized_sentences = torch.tensor(tokenized_sentences).to(device)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y

train_dataloader = torch.utils.data.DataLoader(
    text_data,
    batch_size=32,
    shuffle=True,
    collate_fn=prepare_inputs)

In [9]:
x, y  = next(iter(train_dataloader))
# notice that y is just every token shifted over by one
print(x[0], "\n", y[0])

tensor([   6,   11,   12,   26,   29,    9,  104,    2,   22,   10,   53,  798,
          29,    8,   93,    2,  101,    2,  287,    4, 1453,    2,   14,    6,
          76,   24,   16,  342,   28,    8,   98,    4,   68,   27,  173,   21,
         524,   24,   18,    2,   13,  259,  558,   35, 1227,  653,    3,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0], device='cuda:0') 
 tensor([  11,   12,   26,   29,    9,  104,    2,   22,   10,   53,  798,   29,
           8,   93,    2,  101,    2,  287,    4, 1453,    2,   14,    6,   76,
          24,   16,  342,   28,    8,   98,    4,   68,   27,  173,   21,  524,
          24,   18,    2,   13,  259,  558,   35, 1227,  653,    3,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0, 

## Create the causal attention mask function

In [10]:
def causal_attention_mask(n_dest, n_src, dtype):
    i = torch.arange(0, n_dest)[:, None] # num rows
    j = torch.arange(0, n_src)           # num cols
    # note we flip the inequality that's in the book because
    # torch's MHA expects `true` to be masked
    m = i < j - n_src + n_dest         
    mask = m.to(dtype)
    # note torch broadcasts automatically, so no need to add an extra batch dim
    # mask = torch.reshape(mask, (1, n_dest, n_src)) # add extra dim for batch
    # mult = [batch_size, 1, 1]
    return mask

np.transpose(causal_attention_mask(10, 10, dtype=torch.bool))

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.uint8)

## Create a transformer block
Note that pytorch actually has a [transformer layer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html) for plug-and-play use.

In [11]:
class TransformerBlock(nn.Module):
    def __init__(self, num_heads, key_dim, embed_dim, ff_dim, dropout_rate=0.1):
        super(TransformerBlock, self).__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate
        
        # returns output of shape (batch, seq_len, emb_dim)
        # does value_dim*num_heads = embed dim?
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            kdim=key_dim,
            batch_first=True
        )
        self.dropout1 = nn.Dropout(self.dropout_rate)
        # normalize only over the embedding dimention
        self.ln1 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.ffn1 = nn.Linear(
            in_features=embed_dim,
            out_features=ff_dim,
        )
        self.ffn2 = nn.Linear(
            in_features=ff_dim,
            out_features=embed_dim,
        )
        self.dropout2 = nn.Dropout(self.dropout_rate)
        self.ln2 = nn.LayerNorm(embed_dim, eps=1e-6)

        
    def forward(self, x):
        batch_size, seq_len = x.shape[0], x.shape[1]
        causal_mask = causal_attention_mask(
            seq_len, seq_len, dtype=torch.bool
        ).to(device)
        attn_out, attn_out_weights = self.attn(
            query=x,
            key=x,
            value=x,
            attn_mask=causal_mask
        )
        attn_out = self.dropout1(attn_out)
        out1 = self.ln1(x + attn_out)
        ffn_1 = F.relu(self.ffn1(out1))
        ffn_2 = self.ffn2(ffn_1)
        ffn_out = self.dropout2(ffn_2)
        return self.ln2(out1 + ffn_out), attn_out_weights

## Build the Token and Position Embedding

In [12]:
class TokenAndPositionEmbedding(nn.Module):
    def __init__(self, max_len, vocab_size, embed_dim):
        super(TokenAndPositionEmbedding, self).__init__()
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        
        self.token_emb = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim
        )
        self.pos_emb = nn.Embedding(num_embeddings=max_len, embedding_dim=embed_dim)
           
    def forward(self, x):
        max_len = x.shape[-1]
        positions = torch.arange(0, max_len).to(device)
        pos_emb = self.pos_emb(positions)
        tok_emb = self.token_emb(x)
        return tok_emb + pos_emb

## Build the GPT model

In [13]:
class GPT(nn.Module):
    def __init__(self, max_len, vocab_size, embed_dim, num_heads, key_dim, ff_dim):
        super(GPT, self).__init__()
        self.tok_pos_emb = TokenAndPositionEmbedding(max_len, vocab_size, embed_dim)
        # we're only using one transformer block
        self.transformer = TransformerBlock(num_heads, key_dim, embed_dim, ff_dim)
        self.dense = nn.Linear(embed_dim, vocab_size)
        
    def forward(self, x):
        x = self.tok_pos_emb(x)
        x, attn_out_weights = self.transformer(x)
        # no softmax because CELoss does it in there
        x = self.dense(x)
        return x, attn_out_weights

## Train

In [14]:
model = GPT(
    max_len = 80, 
    vocab_size = 10000, 
    embed_dim = 256, 
    num_heads = 2, 
    key_dim = 256, 
    ff_dim = 256).to(device)
optim = torch.optim.Adam(model.parameters()) 
loss_fn = torch.nn.CrossEntropyLoss()
epochs = 20

In [15]:
def sample_from(probs, temperature):
        probs = probs.to('cpu').numpy()
        probs = probs ** (1 / temperature)
        probs = probs / np.sum(probs)
        return np.random.choice(len(probs), p=probs), probs

def generate(model, start_prompt, max_tokens, temperature):
    model.eval()

    start_tokens = vocab.lookup_indices(tokenizer(pad_punctuation(start_prompt)))
    start_tokens = torch.IntTensor(start_tokens).to(device)

    next_token = None
    info = []
    while len(start_tokens) < max_tokens and next_token != 0:
        x = start_tokens
        y, attn_out_weights = model(x.unsqueeze(0))
        y = F.softmax(y, dim=-1)
        with torch.no_grad():
          next_token, probs = sample_from(y[0][-1], temperature)
        info.append(
                {
                    "prompt": start_prompt,
                    "word_probs": probs,
                    "atts": attn_out_weights,
                }
            )
        start_tokens = torch.cat(
            (start_tokens, 
             torch.IntTensor([next_token]).to(device)))
        start_prompt = start_prompt + " " + vocab.lookup_token(next_token)
    print(f"\ngenerated text:\n{start_prompt}\n")
    return info
generate(model, "wine review", max_tokens=80, temperature=1.0);


generated text:
wine review crown strikes martin favorite pick brawny neighborhood shores supportive parmigiano helps exemplified overwhelmingly clarksburg worked maintains whistle opposite magliocco beef 87 haven sleekness 50 becomes denying reims allowing aromatically ripens magnum prevails rush texture bunch flush tom sensuous devillard rolland acids garda overrides dissipates demonstrate wide basque dill 1987 aniseed grapefruits similar loureiro cayenne 2013 mesquite saffron formula provided figs catalonian durbanville outline demeter capers novelty mode layon flush mentholated woody vera pungently delineation garys kitchen crust carpet



In [16]:
for epoch in range(epochs):
    train_loss = 0
    model.train()
    for batch_num, (curr, target) in enumerate(train_dataloader, start=0):
        batch_size = curr.shape[0]
        seq_len = curr.shape[1]

        # curr [32, 200], pred [32, 200, 10000]
        pred, weights = model(curr)
        cache_p = pred
        # CE loss expects n x c for predicted values, where C is a list of probabilities
        # so we need to reshape our [batch, seq, vocab] into [batch * seq, vocab]
        pred = pred.reshape(batch_size * seq_len, -1)
        target = target.reshape(-1)

        loss = loss_fn(pred, target) 
        optim.zero_grad() # reset gradients
        loss.backward()
        optim.step()

        if batch_num % 10 == 0:
            print(f"\rBatch {batch_num}/{len(train_dataloader)}, loss {loss:.4f}", end='')
            print(" ", [round(x, 3) for x in torch.topk(cache_p[8][-1], 5).values.tolist()],
                  " ",torch.topk(cache_p[8][-1], 5).indices.tolist(), end='           ')
        train_loss += loss 
    train_loss /= len(train_dataloader)

    print(f'\nEpoch:{epoch}, Train Loss:{train_loss:.4f}')
    generate(model, "wine review", max_tokens=80, temperature=1.0);

Batch 4050/4060, loss 2.5119  [18.415, 6.191, 5.94, 5.907, 5.817]   [0, 3, 36, 4, 805]           
Epoch:0, Train Loss:2.4556

generated text:
wine review | gewürztraminer from washington , us : the aromatic lemon notes dominate the nose of this wine from the oldest grape varieties and cool tar that are followed by the expected crispness of roussanne that charms off the fresh , brimming with generous aromatics . the wine should pair nicely with shrimp , it gains almost no . <stop>

Batch 4050/4060, loss 1.9432  [20.329, 7.625, 7.434, 6.607, 6.334]   [0, 805, 36, 1713, 428]           
Epoch:1, Train Loss:2.0833

generated text:
wine review | arinto from tejo , portugal : this is a flowery wine with juicy berry character . often goes grape layers of fruit and light acidity , it ' s bright in style of obvious vintage , it is ready to drink . <stop>

Batch 4050/4060, loss 1.6425  [20.387, 6.923, 6.901, 6.48, 5.908]   [0, 805, 36, 428, 1713]           
Epoch:2, Train Loss:1.9941

generated t

## Save model

In [17]:
torch.save(model.state_dict(), "wineGPT")
checkpoint = torch.load("wineGPT", map_location=device)
model.load_state_dict(checkpoint)

<All keys matched successfully>

## Generate some reviews!

In [18]:
generate(model, "wine review | ruby port from napa, california ", max_tokens=80, temperature=.5);


generated text:
wine review | ruby port from napa, california  , us : this is a simple , dry and crisp wine , with a clean , fresh acidity . it has flavors of cherry and raspberry , cola and spice . <stop>



In [19]:
generate(model, "wine review | sweet red blend from ", max_tokens=80, temperature=.5);


generated text:
wine review | sweet red blend from  central italy , italy : this dessert wine is made from syrah . it offers a lot of appeal , but the wine ' s also <unk> with an attractive floral fragrance of red berries and black pepper . the wine has a touch of sweetness that would pair well with pasta or lasagna . <stop>



In [20]:
generate(model, "wine review | strawberry ", max_tokens=80, temperature=.5);


generated text:
wine review | strawberry  - style red blend from washington , us : this is a blend of syrah ( 33 % ) and syrah ( 33 % ) and petit verdot . the aromas are light and reserved in style , with a sense of restraint and balance . the tannins are soft and the flavors are light and elegant . <stop>



## Bonus: attention visualizer
Darker blue indicates more attention on the word.

In [21]:
from IPython.display import display, HTML

def print_probs(info, vocab, top_k=5):
    for i in info:
        highlighted_text = []
        attns = i["atts"][0][-1]
        max_attn = max(attns).item()
        for word, att_score in zip(
            i["prompt"].split(), attns
        ):
            highlighted_text.append(
                '<span style="background-color:rgba(135,206,250,'
                + str(att_score.item()/max_attn)
                + ');">'
                + word
                + "</span>"
            )
        highlighted_text = " ".join(highlighted_text)
        display(HTML(highlighted_text))

        word_probs = i["word_probs"]
        p_sorted = np.sort(word_probs)[::-1][:top_k]
        i_sorted = np.argsort(word_probs)[::-1][:top_k]
        for p, i in zip(p_sorted, i_sorted):
            print(f"{vocab.lookup_token(i)}:   \t{np.round(100*p,2)}%")
        print("--------\n")

In [22]:
info = generate(model, "wine review | dry riesling ", max_tokens=80, temperature=.5);
print_probs(info, vocab)


generated text:
wine review | dry riesling  from california , us : a nice everyday grigio , with pleasant citrus , pear and spice flavors . <stop>



from:   	100.0%
-:   	0.0%
,:   	0.0%
blend:   	0.0%
<unk>:   	0.0%
--------



northeastern:   	49.05%
california:   	33.48%
sicily:   	7.04%
<unk>:   	4.68%
new:   	2.07%
--------



,:   	100.0%
-:   	0.0%
and:   	0.0%
but:   	0.0%
with:   	0.0%
--------



us:   	100.0%
france:   	0.0%
italy:   	0.0%
chile:   	0.0%
germany:   	0.0%
--------



::   	100.0%
like:   	0.0%
,:   	0.0%
and:   	0.0%
that:   	0.0%
--------



this:   	54.28%
a:   	32.77%
made:   	4.56%
there:   	3.71%
with:   	1.6%
--------



nice:   	28.69%
little:   	23.57%
good:   	15.43%
very:   	11.12%
pretty:   	3.1%
--------



,:   	34.11%
everyday:   	14.33%
wine:   	13.91%
white:   	12.8%
cocktail:   	4.5%
--------



white:   	94.14%
wine:   	1.99%
pinot:   	1.25%
,:   	0.9%
sipper:   	0.52%
--------



,:   	83.35%
with:   	7.21%
that:   	6.49%
for:   	1.98%
.:   	0.91%
--------



with:   	99.41%
but:   	0.27%
sushi:   	0.07%
and:   	0.04%
if:   	0.04%
--------



a:   	27.48%
pleasant:   	13.73%
sushi:   	13.53%
crisp:   	12.47%
citrus:   	4.91%
--------



citrus:   	77.1%
flavors:   	7.7%
,:   	5.9%
peach:   	3.65%
lemon:   	1.11%
--------



,:   	63.43%
and:   	35.13%
flavors:   	1.01%
fruit:   	0.27%
-:   	0.14%
--------



pear:   	27.72%
green:   	19.08%
peach:   	14.35%
tropical:   	12.86%
lemongrass:   	5.2%
--------



and:   	56.83%
,:   	43.17%
fruit:   	0.0%
-:   	0.0%
flavors:   	0.0%
--------



spice:   	28.29%
white:   	22.09%
peach:   	15.71%
mineral:   	6.16%
vanilla:   	5.59%
--------



flavors:   	99.99%
,:   	0.0%
.:   	0.0%
notes:   	0.0%
aromas:   	0.0%
--------



.:   	87.77%
,:   	8.55%
that:   	3.53%
and:   	0.12%
wrapped:   	0.02%
--------



<stop>:   	96.29%
it:   	1.08%
the:   	0.45%
with:   	0.41%
a:   	0.31%
--------

