In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import spacy
from torch.utils.data import Dataset, DataLoader


# Set up device: prefer CUDA, then Apple-silicon MPS, otherwise CPU.
# `mps.is_available()` (not `is_built()`) checks that the MPS device can actually
# be used on this machine, not merely that PyTorch was compiled with MPS support.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# read data
# data = pd.read_csv('PoetryFoundationData.csv')

# # count number of non empty cells in column tags
# data['Tags'].count()
# # drop rows with empty cells in column tags
# data = data.dropna(subset=['Tags'])

# # make new df with poem and tags
# data = data[['Poem', 'Tags']]

In [None]:
# # find number of unique words in the vocabulary in train_data.csv
# words = []
# for i in range(len(data)):
#     words.extend(data['Poem'].iloc[i].split())
# words = list(set(words))
# vocab_size = len(words)
# vocab_size
# words


## Auxiliary classes & functions

In [None]:
class DataProcessor(object):
    """Cleans raw text before character-level tokenisation.

    ``preprocess_text`` word-tokenises with NLTK, keeps tokens made only of
    alphabetic characters (lower-cased), removes English stopwords and
    lemmatises with WordNet. Construction downloads the NLTK resources this needs.
    """

    # Stopword set and lemmatizer, built once on first use and shared by all calls.
    _stop_words = None
    _lemmatizer = None

    def __init__(self, ):
        super().__init__()
        # 'punkt_tab' is what word_tokenize loads on recent NLTK releases
        # (>= 3.8.2); 'punkt' is kept for older ones.
        for resource in ('omw-1.4', 'punkt', 'punkt_tab', 'wordnet', 'stopwords'):
            nltk.download(resource)

    @staticmethod
    def _nltk_helpers():
        """Return the cached ``(stopword_set, lemmatizer)`` pair, creating it on first call."""
        if DataProcessor._stop_words is None:
            DataProcessor._stop_words = set(stopwords.words("english"))
            DataProcessor._lemmatizer = WordNetLemmatizer()
        return DataProcessor._stop_words, DataProcessor._lemmatizer

    @staticmethod
    def preprocess_text(text):
        """Return ``text`` as space-joined, lower-cased, stopword-free lemmas.

        Tokens that are not purely alphabetic (punctuation, numbers, mixed
        tokens) are dropped. If NLTK raises TypeError while tokenising (e.g. for
        a non-string input), an error message is printed and "" is returned.
        """
        try:
            tokens = nltk.word_tokenize(text)
        except TypeError as e:
            # print() does no %-formatting, so the message is built explicitly.
            print(f'Error in tokenizing text "{text}": {e}')
            return ""

        stop_words, lemmatizer = DataProcessor._nltk_helpers()
        words = (word.lower() for word in tokens if word.isalpha())
        return " ".join(
            lemmatizer.lemmatize(word) for word in words if word not in stop_words
        )

    def process_batch(self, texts):
        """Apply :meth:`preprocess_text` to every element of ``texts`` and return a list."""
        return [self.preprocess_text(d) for d in texts]

In [None]:
class Tokenizer(object):
    """Character-level tokenizer mapping lower-case text to integer ids.

    Vocabulary (id -> symbol): 0 'pad', 1-26 'a'-'z', 27 ' ', 28 'cls',
    29 '!', 30 '?', 31 '.', 32 ',', 33 ';', 34 ':', 35 '<BREAK>'.

    Args:
        max_length: number of character positions per encoded sequence (the
            leading 'cls' is extra); 0 means "longest text in each encode call".
        special_characters: characters outside the vocabulary that should be
            encoded as 'q' instead of raising KeyError (None -> none).
    """

    def __init__(self, max_length=0, special_characters=None):
        super().__init__()

        self.max_length = max_length
        # None instead of a shared mutable [] default argument
        self.special_characters = [] if special_characters is None else special_characters
        self.alphabet_letters = list('abcdefghijklmnopqrstuvwxyz')

        self.alphabet = self.prepare_alphabet()
        self.decoded_alphabet = self.prepare_decoded_alphabet()
        self.char_map = self.prepare_char_map()

    def prepare_alphabet(self):
        """Build the symbol -> id dictionary; the single source of truth for ids."""
        symbols = (
            ['pad']
            + list(self.alphabet_letters)
            + [' ', 'cls']
            # punctuation tokens, then the line-break token
            + ['!', '?', '.', ',', ';', ':', '<BREAK>']
        )
        return {symbol: idx for idx, symbol in enumerate(symbols)}

    def prepare_decoded_alphabet(self):
        """Build the id -> symbol dictionary by inverting ``self.alphabet``."""
        return {idx: symbol for symbol, idx in self.alphabet.items()}

    def prepare_char_map(self):
        """Map accented / Greek / ligature / math characters to an ASCII letter."""
        return {
            'é': 'e', 'í': 'i', 'á': 'a', 'ó': 'o', 'æ': 'a', 'ä': 'a', 'ū': 'u',
            'à': 'a', 'ç': 'c', 'ë': 'e', 'ñ': 'n', 'ö': 'o', 'ü': 'u', 'ú': 'u',
            'û': 'u', 'å': 'a', 'œ': 'o', 'ß': 's', 'ø': 'o', 'è': 'e', 'ï': 'i',
            'â': 'a', 'ê': 'e', 'î': 'i', 'ô': 'o', 'ō': 'o', 'ā': 'a', 'ī': 'i',
            'ē': 'e', 'ồ': 'o', 'ế': 'e', 'π': 'p', '∞': 'i', '∑': 's', '√': 'r',
            '∫': 'i', '≈': 'a', 'ﬂ': 'f', 'ﬁ': 'f', 'ﬀ': 'f', 'ﬃ': 'f', 'α': 'a',
            'β': 'b', 'γ': 'g', 'δ': 'd', 'ε': 'e', 'ζ': 'z', 'η': 'e', 'θ': 't',
            'ι': 'i', 'κ': 'k', 'λ': 'l', 'μ': 'm', 'ν': 'n', 'ξ': 'x', 'ο': 'o',
            'ρ': 'r', 'σ': 's', 'τ': 't', 'υ': 'u', 'φ': 'f', 'χ': 'c', 'ψ': 'p',
            'ω': 'w'
        }

    def encode(self, texts):
        """Encode texts as a float array of shape (len(texts), max_length + 1).

        Column 0 is 'cls'; shorter texts are right-padded with 'pad' (0) and
        longer ones truncated. Characters in ``char_map`` are transliterated,
        characters in ``special_characters`` become 'q'; any other character
        outside the vocabulary raises KeyError.
        """
        if self.max_length == 0:
            max_length = max((len(text) for text in texts), default=0)
        else:
            max_length = self.max_length

        # float64 kept for backward compatibility; callers cast with .long()
        tokens = np.zeros((len(texts), max_length + 1))
        tokens[:, 0] = self.alphabet['cls']

        special = set(self.special_characters)  # O(1) membership tests
        q_id = self.alphabet['q']
        for i, text in enumerate(texts):
            for j, char in enumerate(text[:max_length]):
                if char in self.char_map:
                    tokens[i, j + 1] = self.alphabet[self.char_map[char]]
                elif char in special:
                    tokens[i, j + 1] = q_id
                else:
                    tokens[i, j + 1] = self.alphabet[char]

        return tokens

    def decode(self, tokens):
        """Decode a 2-D array/tensor of ids into a list of strings.

        Each row is read up to its first 'pad' (0); 'cls' ids are skipped.
        """
        pad_id = self.alphabet['pad']
        texts = []

        for i in range(len(tokens)):
            chars = []
            for token in tokens[i, :]:
                token_id = int(token)  # works for numpy scalars and torch tensors
                if token_id == pad_id:
                    break
                symbol = self.decoded_alphabet[token_id]
                if symbol != 'cls':
                    chars.append(symbol)
            texts.append(''.join(chars))

        return texts

## Create the dataset

In [None]:
class Poems(Dataset):
    """Dataset of tokenised "poem + tags" sequences, one LongTensor row per poem.

    Args:
        dataprocessor: object with ``process_batch(iterable_of_str) -> list_of_str``.
        tokenizer: object with ``encode(list_of_str) -> 2-D array``.
        dataset: DataFrame with 'Poem' and 'Tags' columns.
        dataset_type: 'train', 'val', or anything else (treated as 'test');
            only 'train' honours ``num_training_data``.
        num_training_data: optional cap on the number of training examples.
        transforms: optional callable applied to each sample in ``__getitem__``.
    """

    def __init__(self, dataprocessor, tokenizer, dataset, dataset_type, num_training_data=None, transforms=None):
        # Join poem and tags with a space: plain `+` glued the last poem word onto
        # the first tag, and a missing tag made the whole row NaN (which the
        # processor turns into an empty string, discarding the poem).
        combined = dataset['Poem'] + ' ' + dataset['Tags'].fillna('')
        texts = dataprocessor.process_batch(combined)

        encoded = tokenizer.encode(texts)
        if dataset_type == 'train' and num_training_data is not None:
            # slice after encoding so the padded width matches the full dataset
            encoded = encoded[:num_training_data]

        self.data = torch.tensor(encoded).long()
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transforms:
            sample = self.transforms(sample)
        return sample

In [None]:
class LossFun(nn.Module):
    """Token-level negative log-likelihood for B x T x V log-probabilities.

    ``reduction='sum'`` sums the NLL over every token in the batch;
    ``reduction='mean'`` sums over the tokens of each sequence and averages over
    sequences. Any other value raises ValueError.
    """

    def __init__(self,):
        super().__init__()

        # per-token losses; reduction is done manually in forward()
        self.loss = nn.NLLLoss(reduction='none')

    def forward(self, y_model, y_true, reduction='sum'):
        B, T, V = y_model.size()

        flat_log_probs = y_model.view(B * T, V)
        flat_targets = y_true.view(B * T,)
        per_token = self.loss(flat_log_probs, flat_targets)  # B*T

        if reduction == 'sum':
            return per_token.sum()
        if reduction == 'mean':
            per_sequence = per_token.view(B, T).sum(1)
            return per_sequence.mean()
        raise ValueError('Reduction could be either `sum` or `mean`.')

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head, optionally causal, scaled dot-product self-attention.

    Every head works in the full embedding size D: the key/query/value
    projections map D -> D*H and ``w_c`` maps the concatenated heads back to D.
    """

    def __init__(self, num_emb, num_heads=8):
        super().__init__()

        self.D = num_emb
        self.H = num_heads
        projected = self.D * self.H

        # per-head key / query / value projections (all heads at once)
        self.w_k = nn.Linear(self.D, projected)
        self.w_q = nn.Linear(self.D, projected)
        self.w_v = nn.Linear(self.D, projected)

        # merges the concatenated heads back to D
        self.w_c = nn.Linear(projected, self.D)

    def _split_heads(self, t, B, T, D):
        """Reshape B x T x (H*D) into (B*H) x T x D for batched matmuls."""
        return t.view(B, T, self.H, D).transpose(1, 2).contiguous().view(B * self.H, T, D)

    def forward(self, x, causal=True):
        """Attend over x (B x T x D); with ``causal`` a position sees only itself and earlier ones."""
        B, T, D = x.size()

        # scaling both q and k by D**0.25 gives the usual 1/sqrt(D) overall
        scale = D ** 0.25
        k = self._split_heads(self.w_k(x), B, T, D) / scale
        q = self._split_heads(self.w_q(x), B, T, D) / scale
        v = self._split_heads(self.w_v(x), B, T, D)

        scores = torch.bmm(q, k.transpose(1, 2))  # B*H x T x T
        if causal:
            # mask built on the scores' device: block attention to later positions
            future = torch.ones(T, T, dtype=torch.bool, device=scores.device).triu(diagonal=1)
            scores = scores.masked_fill(future, float('-inf'))

        weights = F.softmax(scores, dim=2)
        heads = torch.bmm(weights, v).view(B, self.H, T, D)  # B x H x T x D
        merged = heads.transpose(1, 2).contiguous().view(B, T, D * self.H)  # B x T x D*H

        return self.w_c(merged)  # B x T x D


In [None]:
class TransformerBlock(nn.Module):
    """Post-LayerNorm transformer block: self-attention then an MLP, each wrapped
    in a residual connection followed by LayerNorm.

    Args:
        num_emb: embedding size D.
        num_neurons: hidden-width multiplier of the MLP (hidden size = num_neurons * D).
        num_heads: number of attention heads.
    """

    def __init__(self, num_emb, num_neurons, num_heads=4):
        super().__init__()

        self.D = num_emb
        self.H = num_heads
        self.neurons = num_neurons
        hidden = self.neurons * self.D

        # sub-modules are created in the same order as before (affects seeded init)
        self.msha = MultiHeadSelfAttention(num_emb=self.D, num_heads=self.H)
        self.layer_norm1 = nn.LayerNorm(self.D)
        self.layer_norm2 = nn.LayerNorm(self.D)
        self.mlp = nn.Sequential(
            nn.Linear(self.D, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.D),
        )

    def forward(self, x, causal=True):
        """Transform x (B x T x D) and return a tensor of the same shape."""
        attended = self.layer_norm1(self.msha(x, causal) + x)
        return self.layer_norm2(self.mlp(attended) + attended)

In [None]:
class DecoderTransformer(nn.Module):
    """Decoder-only transformer language model over integer token ids.

    Args:
        num_tokens: maximum sequence length; size of the positional embedding and
            length of the sequences produced by :meth:`sample`.
        num_token_vals: vocabulary size (number of distinct token ids).
        num_emb: embedding dimensionality D.
        num_neurons: hidden-width multiplier of each block's MLP.
        num_heads: attention heads per block.
        dropout_prob: dropout applied to the summed token + position embeddings.
        num_blocks: number of stacked TransformerBlocks.
        device: device on which :meth:`sample` creates its token tensors.
    """

    def __init__(self, num_tokens, num_token_vals, num_emb, num_neurons, num_heads=2, dropout_prob=0.1, num_blocks=10, device='cpu'):
        super().__init__()

        # hyperparams
        self.device = device
        self.num_tokens = num_tokens
        self.num_token_vals = num_token_vals
        self.num_emb = num_emb
        self.num_blocks = num_blocks

        # embedding layer
        self.embedding = torch.nn.Embedding(num_token_vals, num_emb)

        # positional embedding
        self.positional_embedding = nn.Embedding(num_tokens, num_emb)

        # transformer blocks
        self.transformer_blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.transformer_blocks.append(TransformerBlock(num_emb=num_emb, num_neurons=num_neurons, num_heads=num_heads))

        # output layer; kept as a Sequential so saved parameter names ('logits.0.*') stay valid
        self.logits = nn.Sequential(nn.Linear(num_emb, num_token_vals))

        # dropout layer
        self.dropout = nn.Dropout(dropout_prob)

        # loss function
        self.loss_fun = LossFun()

    def transformer_forward(self, x, causal=True, temperature=1.0):
        """Return log-probabilities of shape B x T x num_token_vals for token ids x (B x T)."""
        x = self.embedding(x)  # B x T x D
        # positions are created on the activations' device, so any input device works
        pos = torch.arange(x.shape[1], dtype=torch.long, device=x.device).unsqueeze(0)
        x = self.dropout(x + self.positional_embedding(pos))

        for block in self.transformer_blocks:
            x = block(x, causal)  # honour the caller's `causal` flag

        out = self.logits(x)
        return F.log_softmax(out / temperature, 2)

    @torch.no_grad()
    def sample(self, batch_size=4, temperature=1.0, start_token=None):
        """Autoregressively sample ``batch_size`` sequences of length ``num_tokens``.

        Args:
            batch_size: number of sequences.
            temperature: softmax temperature used for every step.
            start_token: id placed at position 0 of every sequence; defaults to
                ``num_token_vals - 1``. NOTE(review): the notebook's Tokenizer
                starts every training sequence with 'cls' (id 28), which differs
                from that default — callers probably want ``start_token=28``.

        Returns:
            numpy int64 array of shape (batch_size, num_tokens).
        """
        if start_token is None:
            start_token = self.num_token_vals - 1
        x_seq = torch.full((batch_size, 1), start_token, dtype=torch.long, device=self.device)

        for i in range(self.num_tokens - 1):
            x_log_probs = self.transformer_forward(x_seq, temperature=temperature)
            # the distribution at the last position (i) predicts token i + 1
            x_next = torch.multinomial(torch.exp(x_log_probs[:, i]), 1)
            x_seq = torch.cat((x_seq, x_next), dim=1)

        return x_seq.cpu().numpy()

    @torch.no_grad()
    def top1_rec(self, x, causal=True):
        """Sum over the batch of per-sequence top-1 next-token accuracy (scalar tensor)."""
        log_probs = self.transformer_forward(x, causal=causal)[:, :-1, :]
        predictions = log_probs.argmax(dim=2)
        targets = x[:, 1:].to(predictions.device)
        return (predictions == targets).float().mean(1).sum()

    def forward(self, x, causal=True, temperature=1.0, reduction='mean'):
        """Return the NLL of predicting ``x[:, 1:]`` from ``x[:, :-1]`` (reduced by LossFun)."""
        log_prob = self.transformer_forward(x, causal=causal, temperature=temperature)

        return self.loss_fun(log_prob[:, :-1].contiguous(), x[:, 1:].contiguous(), reduction=reduction)

In [None]:
import matplotlib.pyplot as plt
def evaluation(test_loader, name=None, model_best=None, epoch=None, device='cuda'):
    """Compute per-sequence NLL and top-1 next-token accuracy over a loader.

    Args:
        test_loader: iterable of LongTensor batches (B x T).
        name: path prefix of a saved model (``name + '.model'``); used only when
            ``model_best`` is None.
        model_best: model exposing ``eval()``, ``forward(x, reduction='sum')`` and
            ``top1_rec(x)``.
        epoch: when given, a validation line is printed instead of the final one.
        device: device the batches (and a loaded model) are moved to.

    Returns:
        (loss, rec): summed NLL divided by the number of sequences, and the
        summed per-sequence accuracies divided by the number of sequences.

    Raises:
        ValueError: if ``test_loader`` yields no examples.
    """
    if model_best is None:
        # map_location lets a checkpoint saved on another device load here.
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which rejects whole pickled modules like those saved by training().
        model_best = torch.load(name + '.model', map_location=device).to(device)

    model_best.eval()
    loss = 0.
    rec = 0.  # running sum of per-sequence accuracies, so it must start at zero
    N = 0
    with torch.no_grad():  # evaluation only: don't build autograd graphs
        for test_batch in test_loader:
            test_batch = test_batch.to(device)
            loss += model_best.forward(test_batch, reduction='sum').item()
            rec += model_best.top1_rec(test_batch).item()
            N += test_batch.shape[0]

    if N == 0:
        raise ValueError('test_loader yielded no examples')
    loss = loss / N
    rec = rec / N

    if epoch is None:
        print(f'FINAL LOSS: nll={loss}, rec={rec}')
    else:
        print(f'Epoch: {epoch}, val nll={loss}, val rec={rec}')

    return loss, rec

def plot_curve(name, nll_val, ylabel='nll'):
    """Plot a per-epoch validation curve, save it to
    ``<name>_<ylabel>_val_curve.pdf`` and display it.

    A new figure is created explicitly so the curve never lands on some other
    figure that happens to be current.
    """
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(nll_val)), nll_val, linewidth=3)
    ax.set_xlabel('epochs')
    ax.set_ylabel(ylabel)
    fig.savefig(name + '_' + ylabel + '_val_curve.pdf', bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
import os
cwd = os.getcwd()


def save_texts(sampled_texts, name=''):
    """Write each text on its own line to ``<cwd>/samples_<name>.txt``, overwriting it."""
    out_path = cwd + '/samples_' + name + '.txt'
    with open(out_path, 'w') as out_file:
        out_file.writelines("%s\n" % item for item in sampled_texts)

In [None]:
def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader, device='cuda', tokenizer=None):
    """Train ``model`` with early stopping on validation NLL.

    After every epoch the model is evaluated on ``val_loader``. After the first
    epoch, and whenever the validation NLL improves, the whole model is saved to
    ``name + '.model'`` and 64 decoded samples are written via ``save_texts``.
    Training stops once the NLL has failed to improve for more than
    ``max_patience`` consecutive epochs.

    Args:
        name: path prefix for the checkpoint and the saved metric arrays.
        max_patience: number of non-improving epochs tolerated.
        num_epochs: maximum number of epochs.
        model: model whose ``forward(batch)`` returns a scalar loss.
        optimizer: optimizer over ``model``'s parameters.
        training_loader, val_loader: iterables of LongTensor batches.
        device: device batches are moved to.
        tokenizer: object whose ``decode`` turns sampled ids into text; defaults
            to a fresh ``Tokenizer()`` (decoding only needs the fixed vocabulary).

    Returns:
        (nll_val, rec_val): numpy arrays of per-epoch validation NLL and top-1
        accuracy, also saved to ``name + '_nll_val.npy'`` / ``'_rec_val.npy'``.
    """
    if tokenizer is None:
        tokenizer = Tokenizer()

    nll_val = []
    rec_val = []
    best_nll = float('inf')
    patience = 0

    for e in range(num_epochs):
        # TRAINING
        model.train()
        for batch in training_loader:
            loss = model.forward(batch.to(device))

            optimizer.zero_grad()
            # every step builds a fresh graph, so retain_graph is not needed
            loss.backward()
            optimizer.step()

        # VALIDATION
        loss_val, r_val = evaluation(val_loader, model_best=model, epoch=e, device=device)
        nll_val.append(loss_val)  # save for plotting
        rec_val.append(r_val)

        if e == 0 or loss_val < best_nll:
            print('saved!')
            torch.save(model, name + '.model')
            best_nll = loss_val
            patience = 0

            sampled_tokens = model.sample(batch_size=64, temperature=1.0)
            save_texts(tokenizer.decode(sampled_tokens), name='epoch_' + str(e))
        else:
            patience += 1

        if patience > max_patience:
            break

    nll_val = np.asarray(nll_val)
    rec_val = np.asarray(rec_val)

    np.save(name + '_nll_val.npy', nll_val)
    np.save(name + '_rec_val.npy', rec_val)

    return nll_val, rec_val

In [None]:
# Quick look at the Tags column of the pre-split training data
tr = pd.read_csv('train_data.csv')
tr['Tags']

## Preprocessing

In [None]:
# Import the raw data. Select the two columns we use *before* dropping NaNs so
# rows are not discarded because of gaps in unrelated columns.
data = pd.read_csv('PoetryFoundationData.csv')[['Poem', 'Tags']].dropna()

data = data.assign(
    Poem=lambda df: (
        df['Poem']
        # mark every line break with a '<LINE>' token
        .str.replace('\n', '<LINE>', regex=False)
        # collapse any run of consecutive markers (not just pairs) into one
        .str.replace(r'(?:<LINE>)+', '<LINE>', regex=True)
        # drop leading/trailing markers; str.strip('<LINE>') would instead strip
        # the individual characters < L I N E > (e.g. the 'I' of a leading "In")
        .str.replace(r'^(?:\s*<LINE>)+|(?:<LINE>\s*)+$', '', regex=True)
        # lowercase, then collapse runs of whitespace into a single space
        .str.lower()
        .str.split()
        .str.join(' ')
    ),
    # lowercase tags and trim surrounding whitespace
    Tags=lambda df: df['Tags'].str.lower().str.strip(),
)

data.head()

In [None]:
import re
import pandas as pd

def find_special_characters(text):
    """Return, in order, every character of ``text`` that is not an ASCII letter,
    digit, whitespace, one of . , ! ? ; : ( ) ' " - or < >.

    The hyphen must stay last in the character class: placed between '"' and
    '<' it forms the range '"'..'<' and silently whitelists # $ % & * + / too.
    """
    return re.findall(r'[^a-zA-Z0-9\s.,!?;:()\'"<>-]', text)

# Report missing values, and drop any remaining ones *before* scanning the poems
print(data.isnull().sum())
data = data.dropna()

# Collect every special character that appears in any poem
special_lists = data['Poem'].apply(find_special_characters)
all_special_characters = {char for chars in special_lists for char in chars}

# sorted() gives a deterministic order across runs (set order of str is not)
unique_special_characters = sorted(all_special_characters)


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

num_training_data = None  # None to take all training data

dataprocessor = DataProcessor()
tokenizer = Tokenizer(max_length=1200, special_characters=unique_special_characters)


def _load_split(csv_path):
    """Read one pre-split CSV and drop rows that have no poem."""
    return pd.read_csv(csv_path).dropna(subset=['Poem'])


train_data = _load_split('train_data.csv')
val_data = _load_split('val_data.csv')
test_data = _load_split('test_data.csv')

# Tokenised datasets (only the training split honours num_training_data)
train_dataset = Poems(dataprocessor, tokenizer, dataset=train_data, dataset_type='train', num_training_data=num_training_data)
validation_dataset = Poems(dataprocessor, tokenizer, dataset=val_data, dataset_type='val')
test_dataset = Poems(dataprocessor, tokenizer, dataset=test_data, dataset_type='test')

BATCH_SIZE = 32

# Only the training loader is shuffled
training_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Sanity-check the training loader: one batch's shape, batch size and dataset size
batch = next(iter(training_loader), None)
if batch is not None:
    print(f"Batch shape: {batch.shape}")
print(f"Batch size: {training_loader.batch_size}")
print(f"Dataset length: {len(training_loader.dataset)}")
print(f"Sample indices: {list(range(len(training_loader.dataset))[:10])}")

In [None]:
name = 'decoder_results'
# os.path.join supplies the separator that plain `cwd + name` omits
results_dir = os.path.join(cwd, name)
os.makedirs(results_dir, exist_ok=True)
# NOTE(review): later cells build file prefixes as `cwd + name` (no separator),
# so their files are written to cwd's parent directory with a
# '<cwd basename>decoder_results' prefix, not inside results_dir.

# Sizes derived from the tokenizer so they cannot drift out of sync with it:
# every sequence is 'cls' + max_length characters, and every id the tokenizer
# can emit (0 .. len(alphabet) - 1) needs an embedding row.
num_tokens = tokenizer.max_length + 1      # 1201 for max_length=1200
num_token_vals = len(tokenizer.alphabet)   # ids 0-35
num_neurons = 32
num_heads = 8
num_blocks = 4
num_emb = num_heads * 8
causal = True

lr = 1e-2
num_epochs = 400
max_patience = 10

In [None]:
from pytorch_model_summary import summary

# Build the model on the selected device and print a compact layer summary
model = DecoderTransformer(
    num_tokens=num_tokens,
    num_token_vals=num_token_vals,
    num_emb=num_emb,
    num_neurons=num_neurons,
    num_heads=num_heads,
    num_blocks=num_blocks,
    device=device,
).to(device)

example_input = torch.zeros(1, num_tokens, dtype=torch.long, device=device)
print(summary(model, example_input, show_input=False, show_hierarchical=False))

In [None]:
# Optimise only the trainable parameters with AdamW
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=lr)

# Train with early stopping on validation NLL
nll_val, rec_val = training(
    name=cwd + name,
    max_patience=max_patience,
    num_epochs=num_epochs,
    model=model,
    optimizer=optimizer,
    training_loader=training_loader,
    val_loader=val_loader,
    device=device,
)

In [None]:

# Final evaluation of the best checkpoint on the test split
test_loss, test_rec = evaluation(name=cwd + name, test_loader=test_loader, device=device)

# the `with` block closes the file; no explicit close() needed
with open(cwd + name + '_test_loss.txt', "w") as f:
    f.write(f'Test NLL: {test_loss}\nTest REC: {test_rec}')

for ylabel, curve in (('nll', nll_val), ('rec', rec_val)):
    plot_curve(cwd + name, curve, ylabel=ylabel)

In [None]:

# Load the best checkpoint and switch it to inference mode (eval() returns the module)
model_best = torch.load(cwd + name + '.model').eval()

# Sampling settings
temperature = 1.0
num_samples = 31

sampled_tokens = model_best.sample(batch_size=num_samples, temperature=temperature)
sampled_texts = tokenizer.decode(sampled_tokens)
print(sampled_texts)

save_texts(sampled_texts, name='FINAL_' + str(temperature))