# Solving Cryptic Crosswords with LLMs: Part 1
date created: 04.08.2023

## Installing modules

In [None]:
! pip install git+https://github.com/huggingface/transformers

In [None]:
! pip install torch datasets evaluate accelerate sentencepiece

In [None]:
# import modules
import pandas as pd
import numpy as np
import ast
import json
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# import parameters
from utils.parameters import *

## Data Extraction and Transformation

In [None]:
# Import data: drop incomplete rows, then shuffle.
# The shuffle is seeded so that, for a given input file, the row order (and
# therefore every split derived from it) is the same on each run.
clues_raw = pd.read_csv(clues_path_raw).dropna().sample(frac=1, random_state=200)

# Select the needed columns first, then copy, so only those columns are
# duplicated and clues_raw itself is left untouched
clues = clues_raw[['rowid', 'clue', 'answer', 'definition']].copy()

# Transform columns into the format required by the trainer module
clues['rowid'] = clues['rowid'].astype(str)
clues['question'] = clues['clue']
clues['context'] = clues['definition']
# One dict per row holding the answer text and a start offset (always written as 0)
clues['answers'] = clues['answer'].map(lambda x : {"text" : [x], "answer_start" : [0]})
clues = clues.rename(columns={'rowid' : 'id'})
clues = clues[['id', 'question', 'context', 'answers']].dropna()

# Print examples
clues.head()

In [None]:
# Partition the clues into train, validation and test sets

# Draw 90% of the rows for train+validation; the rows left over form the test set
train_val = clues.sample(random_state=200, frac=0.9)
test = clues[~clues.index.isin(train_val.index)]

# Repeat the same 90/10 draw inside train_val to separate out the validation set
train = train_val.sample(random_state=200, frac=0.9)
validation = train_val[~train_val.index.isin(train.index)]

# Write the full processed set and each partition to CSV, omitting the index column
for _frame, _path in (
    (clues, clues_path_processed),
    (train, clues_path_train),
    (validation, clues_path_validation),
    (test, clues_path_test),
):
    _frame.to_csv(_path, index=False)

## Fine Tuning

N.B. the modelling parameters are defined in the `utils/parameters.py` file.

In [None]:
! python ./utils/run_seq2seq_qa.py \
  --model_name_or_path {model_name_t5} \
  --train_file {clues_path_train} \
  --validation_file {clues_path_validation} \
  --test_file {clues_path_test} \
  --question_column question \
  --context_column context \
  --answer_column answers \
  --do_train \
  --do_eval \
  --do_pred \
  --predict_with_generate \
  --version_2_with_negative \
  --per_device_train_batch_size {batch_size} \
  --learning_rate {lr} \
  --num_train_epochs {num_epochs} \
  --max_seq_length {max_seq_length} \
  --overwrite_output_dir {overwrite_dir} \
  --output_dir {output_dir}

## Predictions

In [None]:
# Load the prediction output file and expose its 'id' column under the name 'rowid'
predictions = pd.read_json(predictions_path).rename(columns={'id' : 'rowid'})

# Preview the first few predictions
predictions.head()

In [None]:
# Match every prediction to its clue on 'rowid' and add three 0/1 indicator columns:
#   correct_len   - prediction and answer have the same number of characters
#   correct       - prediction equals the answer exactly
#   correct_len_1 - the two lengths differ by at most one character
compare = (
    clues_raw
    .merge(predictions, on='rowid')
    [['clue', 'definition', 'answer', 'prediction_text']]
    .assign(
        correct_len=lambda d: (
            d['prediction_text'].str.len() == d['answer'].str.len()
        ).astype(int),
        correct=lambda d: (d['prediction_text'] == d['answer']).astype(int),
        correct_len_1=lambda d: (
            (d['prediction_text'].str.len() - d['answer'].str.len()).abs() <= 1
        ).astype(int),
    )
)

# Row counts for every combination of the three indicators
compare.groupby(['correct', 'correct_len', 'correct_len_1']).count()

## Plot loss vs epochs

In [None]:
# Load the trainer state written to the output directory
with open(f'{output_dir}/trainer_state.json', 'rb') as f:
    tr = json.load(f)

# Start every series with an epoch-0 point: there is no loss value yet,
# and the learning-rate series begins at the configured `lr`
epoch_list = [0]
loss_list = [None]
learning_rate_list = [lr]

# Collect the list of each metric.
# Only records carrying all three keys are used, so records of another kind
# (for example an evaluation result or an end-of-run summary) are skipped
# instead of raising KeyError, and no valid record is discarded by position.
required_keys = {'epoch', 'loss', 'learning_rate'}
for x in tr['log_history']:
    if not required_keys <= x.keys():
        continue
    epoch_list.append(x['epoch'])
    loss_list.append(x['loss'])
    learning_rate_list.append(x['learning_rate'])

In [None]:
# Tabular view of the collected metrics, one row per logged point
df = pd.DataFrame({
    'epoch': epoch_list,
    'loss': loss_list,
    'learning_rate': learning_rate_list,
})

# Single subplot with a secondary y-axis, so loss and learning rate
# can share the epoch axis while keeping separate scales
fig = make_subplots(specs=[[{"secondary_y": True}]])

# Loss as a plain line on the primary axis
fig.add_trace(
    go.Scatter(x=epoch_list, y=loss_list, mode='lines', name='loss')
)

# Learning rate as a line with markers on the secondary axis
fig.add_trace(
    go.Scatter(x=epoch_list, y=learning_rate_list, mode='lines+markers', name='learning_rate'),
    secondary_y=True,
)

# Label the axes; the chained call evaluates to the figure, so the cell displays it
(
    fig
    .update_xaxes(title_text="Epoch")
    .update_yaxes(title_text="Loss", secondary_y=False)
    .update_yaxes(title_text="Learning Rate", secondary_y=True)
)
