# Fine Tuning a LLM for Visualization Funciton Calling

## Generating Data

In [4]:
from matplotlib.colors import CSS4_COLORS
import random

In [5]:
def read_variables_from_file(filename='functioncalling/variables.txt'):
    with open(filename, 'r') as file:
        content = file.read()
        variables = content.strip().split(',')
        variables = [var.strip() for var in variables]
    return variables

In [6]:
variables = read_variables_from_file(filename='functioncalling/variables.txt')
hues      = read_variables_from_file(filename='functioncalling/hues.txt')
plot_types = ['scatter', 'bar', 'violin', 'box', 'pair', 'line', 'histogram', 'heatmap', 'pie', 'area', 'hexbin', 'kde', 'facet', 'joint', 'strip', 'swarm', 'count', 'cat', 'reg', 'dist']
markers = [".", ",", "o", "v", "^", "<", ">", "1", "2", "3", "4", "8", "s", "p", "P", "*", "h", "H", "+", "x", "X", "D", "d", "|", "_"]
linestyles = ['-', '--', '-.', ':']
colors = list(CSS4_COLORS.keys())
bools = [True, False]

In [16]:
def scatter_prompts(xvar, yvar, title, xlabel, ylabel, dpi, figsize, hue, size, style, palette, markers):
    templates = [
        f"Create a scatter with '{xvar}' on the x-axis and '{yvar}' on the y-axis. Customize markers with style '{style}', color by '{hue}', and set the figure size to {figsize}.",
        f"Generate a scatter of '{yvar}' against '{xvar}' with markers styled as '{style}' and color by '{hue}'. Set the xlabel as '{xlabel}' and ylabel as '{ylabel}'.",
        f"Visualize a scatter with '{xvar}' on the x-axis and '{yvar}' on the y-axis. Use '{palette}' for color differentiation and set markers size to '{size}'. Adjust figure dpi to {dpi}.",
        f"Make a scatter using '{xvar}' on the x-axis and '{yvar}' on the y-axis, with markers of size '{size}' and style '{style}'. Include xlabel '{xlabel}' and ylabel '{ylabel}'.",
        f"Plot a scatter plot of '{xvar}' against '{yvar}' with markers color-coded by '{hue}' and styled as '{style}'. Set title as '{title}' and figsize to {figsize}. Also change the dpi to {dpi}."
    ]
    return random.choice(templates)

In [38]:
xvar = "Age"
yvar = "Weight"
title = "Scatter Plot"
xlabel = "Age"
ylabel = "Weight"
dpi = 100
figsize = (10, 8)
hue = "City"
size = 10
style = "o"
palette = "viridis"
markers = True

prompt = scatter_prompts(xvar, yvar, title, xlabel, ylabel, dpi, figsize, hue, size, style, palette, markers)
print(prompt)


Create a scatter plot with 'Age' on the x-axis and 'Weight' on the y-axis. Customize markers with style 'o', color by 'City', and set the figure size to (10, 8).


In [None]:
plot_types = ['scatter', 'bar', 'violin', 'box', 'pair', 'line', 'histogram', 'heatmap', 'pie', 'area', 'hexbin', 'kde', 'facet', 'joint', 'strip', 'swarm', 'count', 'cat', 'reg', 'dist']

linestyles = ['-', '--', '-.', ':']
colors = ['red', 'blue', 'green', 'yellow', 'purple', 'black', 'orange', 'pink']
bool_options = [True, False]

        "line"         : line_plot   ,
        "bar"          : bar_plot    ,
        "histogram"    : histogram   ,
        "box"          : box_plot    ,
        "boxplot"      : box_plot    ,
        "violin"       : violin_plot ,
        "scatter"      : scatter_plot,
        "pair"         : pair_plot   ,
        "heatmap"      : heatmap     ,
        "pie"          : pie_chart   ,
        "area"         : area_plot   ,
        "hexbin"       : hexbin_plot ,
        "kde"          : kde_plot    ,
        "facet grid"   : facet_grid  ,
        "joint"        : joint_plot  ,
        "strip"        : strip_plot  ,
        "swarm"        : swarm_plot  ,
        "count"        : count_plot  ,
        "cat"          : cat_plot    ,
        "reg"          : reg_plot    ,
        "dist"         : dist_plot   ,
        "distribution" : dist_plot


## Minimal Example of Training

In [13]:
pairedData = [
    {
        'instruction': "Create a scatter plot with Age on the x-axis and Weight on the y-axis.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Weight', 'xlabel': 'Age', 'ylabel': 'Weight', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot of Height versus Weight.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Make a scatter plot of Sales against Profit.",
        'output': "{'plot_type': 'scatter', 'x': 'Sales', 'y': 'Profit', 'xlabel': 'Sales', 'ylabel': 'Profit', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Plot a scatter plot with Age on the x-axis and Blood Pressure on the y-axis. Color points by Gender.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Blood Pressure', 'xlabel': 'Age', 'ylabel': 'Blood Pressure', 'marker_size': None, 'hue': 'Gender'}"
    },
    {
        'instruction': "Create a scatter plot showing Income versus Savings. Label the x-axis as 'Income' and y-axis as 'Savings'.",
        'output': "{'plot_type': 'scatter', 'x': 'Income', 'y': 'Savings', 'xlabel': 'Income', 'ylabel': 'Savings', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot of Height versus Weight, color-coded by Age group.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': 'Age group'}"
    },
    {
        'instruction': "Make a scatter plot with Age on the x-axis and Income on the y-axis, differentiating points by Education level.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Income', 'xlabel': 'Age', 'ylabel': 'Income', 'marker_size': None, 'hue': 'Education level'}"
    },
    {
        'instruction': "Create a scatter plot of Temperature versus Humidity.",
        'output': "{'plot_type': 'scatter', 'x': 'Temperature', 'y': 'Humidity', 'xlabel': 'Temperature', 'ylabel': 'Humidity', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot with Distance on the x-axis and Speed on the y-axis.",
        'output': "{'plot_type': 'scatter', 'x': 'Distance', 'y': 'Speed', 'xlabel': 'Distance', 'ylabel': 'Speed', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Plot a scatter plot of Age versus Salary. Color points by Industry.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Salary', 'xlabel': 'Age', 'ylabel': 'Salary', 'marker_size': None, 'hue': 'Industry'}"
    },
    {
        'instruction': "Make a scatter plot of Height versus Weight. Label the x-axis as 'Height' and y-axis as 'Weight'.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Create a scatter plot of Sales versus Expenses.",
        'output': "{'plot_type': 'scatter', 'x': 'Sales', 'y': 'Expenses', 'xlabel': 'Sales', 'ylabel': 'Expenses', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot with Age on the x-axis and Blood Pressure on the y-axis. Use different colors for each Region.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Blood Pressure', 'xlabel': 'Age', 'ylabel': 'Blood Pressure', 'marker_size': None, 'hue': 'Region'}"
    },
    {
        'instruction': "Make a scatter plot of Temperature versus Pressure. Color points by Altitude.",
        'output': "{'plot_type': 'scatter', 'x': 'Temperature', 'y': 'Pressure', 'xlabel': 'Temperature', 'ylabel': 'Pressure', 'marker_size': None, 'hue': 'Altitude'}"
    },
    {
        'instruction': "Create a scatter plot with Age on the x-axis and Weight on the y-axis. Color points by Gender.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Weight', 'xlabel': 'Age', 'ylabel': 'Weight', 'marker_size': None, 'hue': 'Gender'}"
    },
    {
        'instruction': "Generate a scatter plot of Height versus Weight, labeling the x-axis as 'Height' and y-axis as 'Weight'.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Make a scatter plot of Sales versus Profit. Color points by Quarter.",
        'output': "{'plot_type': 'scatter', 'x': 'Sales', 'y': 'Profit', 'xlabel': 'Sales', 'ylabel': 'Profit', 'marker_size': None, 'hue': 'Quarter'}"
    },
    {
        'instruction': "Create a scatter plot of Age versus Income.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Income', 'xlabel': 'Age', 'ylabel': 'Income', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot of Height versus Weight, color-coded by BMI group.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': 'BMI group'}"
    },
    {
        'instruction': "Make a scatter plot with Age on the x-axis and Blood Pressure on the y-axis. Differentiate points by Medication status.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Blood Pressure', 'xlabel': 'Age', 'ylabel': 'Blood Pressure', 'marker_size': None, 'hue': 'Medication status'}"
    },
    {
        'instruction': "Create a scatter plot showing Temperature versus Humidity. Label the x-axis as 'Temperature' and y-axis as 'Humidity'.",
        'output': "{'plot_type': 'scatter', 'x': 'Temperature', 'y': 'Humidity', 'xlabel': 'Temperature', 'ylabel': 'Humidity', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot of Distance versus Speed.",
        'output': "{'plot_type': 'scatter', 'x': 'Distance', 'y': 'Speed', 'xlabel': 'Distance', 'ylabel': 'Speed', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Make a scatter plot of Age versus Salary. Color points by Department.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Salary', 'xlabel': 'Age', 'ylabel': 'Salary', 'marker_size': None, 'hue': 'Department'}"
    },
    {
        'instruction': "Create a scatter plot of Height versus Weight. Use different colors for each Gender.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': 'Gender'}"
    },
    {
        'instruction': "Generate a scatter plot of Sales versus Expenses. Color points by Region.",
        'output': "{'plot_type': 'scatter', 'x': 'Sales', 'y': 'Expenses', 'xlabel': 'Sales', 'ylabel': 'Expenses', 'marker_size': None, 'hue': 'Region'}"
    },
    {
        'instruction': "Make a scatter plot of Age versus Blood Pressure.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Blood Pressure', 'xlabel': 'Age', 'ylabel': 'Blood Pressure', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Create a scatter plot with Age on the x-axis and Weight on the y-axis.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Weight', 'xlabel': 'Age', 'ylabel': 'Weight', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot of Height versus Weight.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Make a scatter plot of Sales against Profit.",
        'output': "{'plot_type': 'scatter', 'x': 'Sales', 'y': 'Profit', 'xlabel': 'Sales', 'ylabel': 'Profit', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Plot a scatter plot with Age on the x-axis and Blood Pressure on the y-axis. Color points by Gender.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Blood Pressure', 'xlabel': 'Age', 'ylabel': 'Blood Pressure', 'marker_size': None, 'hue': 'Gender'}"
    },
    {
        'instruction': "Create a scatter plot showing Income versus Savings. Label the x-axis as 'Income' and y-axis as 'Savings'.",
        'output': "{'plot_type': 'scatter', 'x': 'Income', 'y': 'Savings', 'xlabel': 'Income', 'ylabel': 'Savings', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot of Height versus Weight, color-coded by Age group.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': 'Age group'}"
    },
    {
        'instruction': "Make a scatter plot with Age on the x-axis and Income on the y-axis, differentiating points by Education level.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Income', 'xlabel': 'Age', 'ylabel': 'Income', 'marker_size': None, 'hue': 'Education level'}"
    },
    {
        'instruction': "Create a scatter plot of Temperature versus Humidity.",
        'output': "{'plot_type': 'scatter', 'x': 'Temperature', 'y': 'Humidity', 'xlabel': 'Temperature', 'ylabel': 'Humidity', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot with Distance on the x-axis and Speed on the y-axis.",
        'output': "{'plot_type': 'scatter', 'x': 'Distance', 'y': 'Speed', 'xlabel': 'Distance', 'ylabel': 'Speed', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Plot a scatter plot of Age versus Salary. Color points by Industry.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Salary', 'xlabel': 'Age', 'ylabel': 'Salary', 'marker_size': None, 'hue': 'Industry'}"
    },
    {
        'instruction': "Make a scatter plot of Height versus Weight. Label the x-axis as 'Height' and y-axis as 'Weight'.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Create a scatter plot of Sales versus Expenses.",
        'output': "{'plot_type': 'scatter', 'x': 'Sales', 'y': 'Expenses', 'xlabel': 'Sales', 'ylabel': 'Expenses', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot with Age on the x-axis and Blood Pressure on the y-axis. Use different colors for each Region.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Blood Pressure', 'xlabel': 'Age', 'ylabel': 'Blood Pressure', 'marker_size': None, 'hue': 'Region'}"
    },
    {
        'instruction': "Make a scatter plot of Temperature versus Pressure. Color points by Altitude.",
        'output': "{'plot_type': 'scatter', 'x': 'Temperature', 'y': 'Pressure', 'xlabel': 'Temperature', 'ylabel': 'Pressure', 'marker_size': None, 'hue': 'Altitude'}"
    },
    {
        'instruction': "Create a scatter plot with Age on the x-axis and Weight on the y-axis. Color points by Gender.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Weight', 'xlabel': 'Age', 'ylabel': 'Weight', 'marker_size': None, 'hue': 'Gender'}"
    },
    {
        'instruction': "Generate a scatter plot of Height versus Weight, labeling the x-axis as 'Height' and y-axis as 'Weight'.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Make a scatter plot of Sales versus Profit. Color points by Quarter.",
        'output': "{'plot_type': 'scatter', 'x': 'Sales', 'y': 'Profit', 'xlabel': 'Sales', 'ylabel': 'Profit', 'marker_size': None, 'hue': 'Quarter'}"
    },
    {
        'instruction': "Create a scatter plot of Age versus Income.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Income', 'xlabel': 'Age', 'ylabel': 'Income', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot of Height versus Weight, color-coded by BMI group.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': 'BMI group'}"
    },
    {
        'instruction': "Make a scatter plot with Age on the x-axis and Blood Pressure on the y-axis. Differentiate points by Medication status.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Blood Pressure', 'xlabel': 'Age', 'ylabel': 'Blood Pressure', 'marker_size': None, 'hue': 'Medication status'}"
    },
    {
        'instruction': "Create a scatter plot showing Temperature versus Humidity. Label the x-axis as 'Temperature' and y-axis as 'Humidity'.",
        'output': "{'plot_type': 'scatter', 'x': 'Temperature', 'y': 'Humidity', 'xlabel': 'Temperature', 'ylabel': 'Humidity', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Generate a scatter plot of Distance versus Speed.",
        'output': "{'plot_type': 'scatter', 'x': 'Distance', 'y': 'Speed', 'xlabel': 'Distance', 'ylabel': 'Speed', 'marker_size': None, 'hue': None}"
    },
    {
        'instruction': "Make a scatter plot of Age versus Salary. Color points by Department.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Salary', 'xlabel': 'Age', 'ylabel': 'Salary', 'marker_size': None, 'hue': 'Department'}"
    },
    {
        'instruction': "Create a scatter plot of Height versus Weight. Use different colors for each Gender.",
        'output': "{'plot_type': 'scatter', 'x': 'Height', 'y': 'Weight', 'xlabel': 'Height', 'ylabel': 'Weight', 'marker_size': None, 'hue': 'Gender'}"
    },
    {
        'instruction': "Generate a scatter plot of Sales versus Expenses. Color points by Region.",
        'output': "{'plot_type': 'scatter', 'x': 'Sales', 'y': 'Expenses', 'xlabel': 'Sales', 'ylabel': 'Expenses', 'marker_size': None, 'hue': 'Region'}"
    },
    {
        'instruction': "Make a scatter plot of Age versus Blood Pressure.",
        'output': "{'plot_type': 'scatter', 'x': 'Age', 'y': 'Blood Pressure', 'xlabel': 'Age', 'ylabel': 'Blood Pressure', 'marker_size': None, 'hue': None}"
    }
]

In [14]:
data = {
    'instruction' : [],
    'output'      : []
}
for i in range(len(pairedData)):
    data['instruction'].append(pairedData[i]['instruction'])
    data['output'].append(     pairedData[i]['output'])
data

{'instruction': ['Create a scatter plot with Age on the x-axis and Weight on the y-axis.',
  'Generate a scatter plot of Height versus Weight.',
  'Make a scatter plot of Sales against Profit.',
  'Plot a scatter plot with Age on the x-axis and Blood Pressure on the y-axis. Color points by Gender.',
  "Create a scatter plot showing Income versus Savings. Label the x-axis as 'Income' and y-axis as 'Savings'.",
  'Generate a scatter plot of Height versus Weight, color-coded by Age group.',
  'Make a scatter plot with Age on the x-axis and Income on the y-axis, differentiating points by Education level.',
  'Create a scatter plot of Temperature versus Humidity.',
  'Generate a scatter plot with Distance on the x-axis and Speed on the y-axis.',
  'Plot a scatter plot of Age versus Salary. Color points by Industry.',
  "Make a scatter plot of Height versus Weight. Label the x-axis as 'Height' and y-axis as 'Weight'.",
  'Create a scatter plot of Sales versus Expenses.',
  'Generate a scatte

In [15]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments

In [16]:
# Convert to dataset
dataset = Dataset.from_dict(data)

# Load T5 model and tokenizer
model_name = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Preprocess data
def preprocess_function(examples):
    inputs = [f"translate English to JSON: {ex}" for ex in examples['instruction']]
    model_inputs = tokenizer(inputs, max_length=512, padding='max_length', truncation=True)
    
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples['output'], max_length=512, padding='max_length', truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/52 [00:00<?, ? examples/s]

In [17]:
# Check lengths of tokenized inputs and labels
print("Tokenized input lengths:")
for input_ids in tokenized_dataset["input_ids"]:
    print(len(input_ids))

print("Tokenized label lengths:")
for label_ids in tokenized_dataset["labels"]:
    print(len(label_ids))
    
# Split the dataset
train_test_split = tokenized_dataset.train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
test_dataset = train_test_split['test']


Tokenized input lengths:
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
Tokenized label lengths:
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512
512


In [18]:
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Set up training arguments
training_args  = TrainingArguments(
    output_dir                  = './results',
    evaluation_strategy         = "epoch",
    learning_rate               = 2e-5,
    per_device_train_batch_size = 4,
    per_device_eval_batch_size  = 4,
    num_train_epochs            = 3,
    weight_decay                = 0.01,
)

# Initialize Trainer
trainer = Trainer(
    model = model,
    args  = training_args,
    train_dataset = train_dataset,
    eval_dataset  = test_dataset,
    data_collator = data_collator
)

# Train model
trainer.train()


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,8.301361
2,No log,6.546083
3,No log,6.012856


TrainOutput(global_step=33, training_loss=7.966746937144887, metrics={'train_runtime': 405.3684, 'train_samples_per_second': 0.303, 'train_steps_per_second': 0.081, 'total_flos': 16647041581056.0, 'train_loss': 7.966746937144887, 'epoch': 3.0})

In [41]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

new_instructions = [
    'Given this dataframe, create a bar chart of Sales and Month with x label Month, y label Sales, and color by Region.',
    'Generate a line plot with Date on the x-axis and Temperature on the y-axis. Label the x-axis as Date and the y-axis as Temperature. Use different colors for each City.'
]

# Encode the new inputs
# new_inputs = [f"Fill in the following JSON object based on the instruction: The JSON object is {'x': None, 'y': None, 'hue': None, 'xlabel': None, 'ylabel': None, 'title': None}. Based on the instruction, provide values for 'x' (variable for the x-axis), 'y' (variable for the y-axis), 'hue' (variable for color differentiation), 'xlabel' (label for the x-axis), 'ylabel' (label for the y-axis), and 'title' (title of the plot).: {instruction}" for instruction in new_instructions]
new_inputs = [f"Fill in the following JSON object based on the instruction: The JSON object is {{'x': None, 'y': None, 'hue': None, 'xlabel': None, 'ylabel': None, 'title': None}}. Based on the instruction, provide values for 'x' (variable for the x-axis), 'y' (variable for the y-axis), 'hue' (variable for color differentiation), 'xlabel' (label for the x-axis), 'ylabel' (label for the y-axis), and 'title' (title of the plot).: {instruction}" for instruction in new_instructions]
# new_inputs = [f"construct a python dictionary with plotting arguments from the instruction: {instruction}" for instruction in new_instructions]
tokenized_inputs = tokenizer(new_inputs, max_length=512, padding='max_length', truncation=True, return_tensors='pt')

# Move inputs to the correct device
tokenized_inputs = {key: value.to(device) for key, value in tokenized_inputs.items()}

# Generate outputs
outputs = model.generate(**tokenized_inputs, max_length=512, num_beams=4, early_stopping=True)

# Decode the outputs
decoded_outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

for input_text, output_text in zip(new_instructions, decoded_outputs):
    print(f"Instruction: {input_text}")
    print(f"Generated JSON: {output_text}")
    print()


NameError: name 'tokenizer' is not defined