@@ -0,0 +1,1147 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMbj+7m976fmpJkGXHf5tPL",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PradipNichite/Youtube-Tutorials/blob/main/Youtube_GPT_3_Finetuning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"Dataset: **ATIS (Airline Travel Information System)**\n",
"\n",
"[https://www.kaggle.com/datasets/hassanamin/atis-airlinetravelinformationsystem](https://www.kaggle.com/datasets/hassanamin/atis-airlinetravelinformationsystem)"
],
"metadata": {
"id": "p7kMGK5fVrIg"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qlJ_HpV1TIMz"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"# Seed NumPy's global RNG so the random sampling later in the notebook is repeatable on a fresh kernel.\n",
"np.random.seed(0)"
]
},
{
"cell_type": "code",
"source": [
"# header=None: the file is read without a header row, so the columns get the integer labels 0 and 1.\n",
"# NOTE(review): the absolute /content/ path presumably assumes the CSV was uploaded to a Colab runtime — confirm.\n",
"data = pd.read_csv(\"/content/atis_intents.csv\",header=None)\n",
"data.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "g3CienaxTTQN",
"outputId": "2eddbc79-e9b8-46f8-d15f-9e741da2c9d4"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" 0 1\n",
"0 atis_flight i want to fly from boston at 838 am and arriv...\n",
"1 atis_flight what flights are available from pittsburgh to...\n",
"2 atis_flight_time what is the arrival time in san francisco for...\n",
"3 atis_airfare cheapest airfare from tacoma to orlando\n",
"4 atis_airfare round trip fares from pittsburgh to philadelp..."
],
"text/html": [
"\n",
" <div id=\"df-a39d7310-3154-4432-b172-d175cf57dcc4\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>atis_flight</td>\n",
" <td>i want to fly from boston at 838 am and arriv...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>atis_flight</td>\n",
" <td>what flights are available from pittsburgh to...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>atis_flight_time</td>\n",
" <td>what is the arrival time in san francisco for...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>atis_airfare</td>\n",
" <td>cheapest airfare from tacoma to orlando</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>atis_airfare</td>\n",
" <td>round trip fares from pittsburgh to philadelp...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a39d7310-3154-4432-b172-d175cf57dcc4')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-a39d7310-3154-4432-b172-d175cf57dcc4 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-a39d7310-3154-4432-b172-d175cf57dcc4');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"source": [
"# Give the two positional columns descriptive names: intent label first, utterance text second.\n",
"data = data.rename(columns={0: 'intent', 1: 'text'})"
],
"metadata": {
"id": "09_iuCZ_WrHI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# List every distinct intent label present in the raw data.\n",
"data['intent'].unique()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hlUkeTobTX4U",
"outputId": "56ac478b-4824-45cb-e35c-30787d288915"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array(['atis_flight', 'atis_flight_time', 'atis_airfare', 'atis_aircraft',\n",
" 'atis_ground_service', 'atis_airport', 'atis_airline',\n",
" 'atis_distance', 'atis_abbreviation', 'atis_ground_fare',\n",
" 'atis_quantity', 'atis_city', 'atis_flight_no', 'atis_capacity',\n",
" 'atis_flight#atis_airfare', 'atis_meal', 'atis_restriction',\n",
" 'atis_airline#atis_flight_no',\n",
" 'atis_ground_service#atis_ground_fare',\n",
" 'atis_airfare#atis_flight_time', 'atis_cheapest',\n",
" 'atis_aircraft#atis_flight#atis_flight_no'], dtype=object)"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"source": [
"# Count the distinct intent labels.\n",
"data['intent'].nunique()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uNhdMgTXTjFl",
"outputId": "8a0dda10-2017-4bfc-bd4e-69f88061caf5"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"22"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"# Combined labels join several intents with '#'; swap it for '_' as a plain-text substitution.\n",
"# regex=False pins literal matching, independent of the pandas-version-specific default for `regex`.\n",
"data['intent'] = data['intent'].str.replace('#', '_', regex=False)\n",
"data['intent'].unique()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uujzu0U_T-Xy",
"outputId": "ac39b033-f372-46a4-f6e6-4c2b0a9a4379"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array(['atis_flight', 'atis_flight_time', 'atis_airfare', 'atis_aircraft',\n",
" 'atis_ground_service', 'atis_airport', 'atis_airline',\n",
" 'atis_distance', 'atis_abbreviation', 'atis_ground_fare',\n",
" 'atis_quantity', 'atis_city', 'atis_flight_no', 'atis_capacity',\n",
" 'atis_flight_atis_airfare', 'atis_meal', 'atis_restriction',\n",
" 'atis_airline_atis_flight_no',\n",
" 'atis_ground_service_atis_ground_fare',\n",
" 'atis_airfare_atis_flight_time', 'atis_cheapest',\n",
" 'atis_aircraft_atis_flight_atis_flight_no'], dtype=object)"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"# Drop every 'atis_' marker (combined labels contain it more than once), matched as literal text.\n",
"# regex=False pins literal matching, independent of the pandas-version-specific default for `regex`.\n",
"data['intent'] = data['intent'].str.replace('atis_', '', regex=False)\n",
"data['intent'].unique()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BdqUTN7LVMb3",
"outputId": "d550fef3-c487-4c02-d6cc-d4821812fd8d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array(['flight', 'flight_time', 'airfare', 'aircraft', 'ground_service',\n",
" 'airport', 'airline', 'distance', 'abbreviation', 'ground_fare',\n",
" 'quantity', 'city', 'flight_no', 'capacity', 'flight_airfare',\n",
" 'meal', 'restriction', 'airline_flight_no',\n",
" 'ground_service_ground_fare', 'airfare_flight_time', 'cheapest',\n",
" 'aircraft_flight_flight_no'], dtype=object)"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"# Number of examples per intent; the classes are heavily imbalanced.\n",
"data['intent'].value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wN_Mnz5XVDP6",
"outputId": "20db46e3-2b4a-4791-de65-fbae8e9b9d55"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"flight 3666\n",
"airfare 423\n",
"ground_service 255\n",
"airline 157\n",
"abbreviation 147\n",
"aircraft 81\n",
"flight_time 54\n",
"quantity 51\n",
"flight_airfare 21\n",
"airport 20\n",
"distance 20\n",
"city 19\n",
"ground_fare 18\n",
"capacity 16\n",
"flight_no 12\n",
"meal 6\n",
"restriction 6\n",
"airline_flight_no 2\n",
"ground_service_ground_fare 1\n",
"airfare_flight_time 1\n",
"cheapest 1\n",
"aircraft_flight_flight_no 1\n",
"Name: intent, dtype: int64"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"source": [
"# The five intents kept for the fine-tuning experiment; rows with any other intent are filtered out next.\n",
"labels = ['flight','ground_service','airfare','abbreviation','flight_time']"
],
"metadata": {
"id": "HYwbkYVwWLxK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Restrict the frame to the selected intents and re-check the class sizes.\n",
"is_selected = data['intent'].isin(labels)\n",
"data = data.loc[is_selected]\n",
"data['intent'].value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oj86WhtpXTYB",
"outputId": "f17eaea1-bda3-4913-945e-5572c97e9b83"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"flight 3666\n",
"airfare 423\n",
"ground_service 255\n",
"abbreviation 147\n",
"flight_time 54\n",
"Name: intent, dtype: int64"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"# Balance the classes by drawing the same number of rows from every intent.\n",
"# random_state makes the draw repeatable no matter how often (or in what order) cells are re-run;\n",
"# relying on the global NumPy seed alone yields a different sample on every re-execution.\n",
"SAMPLES_PER_INTENT = 40\n",
"sample_data = (\n",
"    data.groupby('intent')\n",
"    .sample(n=SAMPLES_PER_INTENT, random_state=0)\n",
"    .reset_index(drop=True)  # discard the original row labels; later cells index by position 0..N-1\n",
")\n",
"sample_data.intent.value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9kmxqUy3X71W",
"outputId": "905decb5-9da4-486d-f68d-ff710a6c4a93"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"abbreviation 40\n",
"airfare 40\n",
"flight 40\n",
"flight_time 40\n",
"ground_service 40\n",
"Name: intent, dtype: int64"
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"source": [
"# Keep a plain CSV copy of the balanced sample (relative path, index column omitted).\n",
"sample_data.to_csv(\"sample_data.csv\",index=False)"
],
"metadata": {
"id": "OmFeffyMuPN_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Put the utterance first and the label second (prompt, then completion).\n",
"# .copy() yields an independent frame, so the column assignments in later cells do not\n",
"# trigger pandas' SettingWithCopyWarning.\n",
"sample_data = sample_data[['text','intent']].copy()"
],
"metadata": {
"id": "L9m4n4m3aMXg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Preview the first rows of the reordered sample.\n",
"sample_data.head(5)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "E5dCsBQoYt5L",
"outputId": "41d1f05f-7f94-4d3d-c3c6-0578854f8c07"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text intent\n",
"0 fare code y what does that mean abbreviation\n",
"1 does dl stand for delta abbreviation\n",
"2 what does code qw mean abbreviation\n",
"3 please explain fare code f abbreviation\n",
"4 what is ewr abbreviation"
],
"text/html": [
"\n",
" <div id=\"df-6d543367-4320-4c51-9324-47bc5af9b651\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>intent</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>fare code y what does that mean</td>\n",
" <td>abbreviation</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>does dl stand for delta</td>\n",
" <td>abbreviation</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>what does code qw mean</td>\n",
" <td>abbreviation</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>please explain fare code f</td>\n",
" <td>abbreviation</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>what is ewr</td>\n",
" <td>abbreviation</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-6d543367-4320-4c51-9324-47bc5af9b651')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-6d543367-4320-4c51-9324-47bc5af9b651 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-6d543367-4320-4c51-9324-47bc5af9b651');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"source": [
"# Trim leading/trailing whitespace from both columns before the prompt/completion markers are added.\n",
"for column in ('text', 'intent'):\n",
"    sample_data[column] = sample_data[column].str.strip()"
],
"metadata": {
"id": "pLTOXkLcjI7Q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Markers for the prompt/completion format: every prompt ends with PROMPT_SUFFIX, every completion\n",
"# starts with a space and ends with COMPLETION_STOP (the stop sequence passed at inference time).\n",
"PROMPT_SUFFIX = \"\\n\\nIntent:\\n\\n\"\n",
"COMPLETION_STOP = \" END\"\n",
"\n",
"# Guard against re-running the cell: without it each execution appends the markers again.\n",
"if not sample_data['text'].str.endswith(PROMPT_SUFFIX).all():\n",
"    sample_data['text'] = sample_data['text'] + PROMPT_SUFFIX\n",
"    sample_data['intent'] = \" \" + sample_data['intent'] + COMPLETION_STOP\n",
"# Alternative prompt format (left disabled): prefix each prompt with an instruction such as\n",
"# \"Classify the text into one of the intents: <comma-separated labels>. Text: \" before the utterance.\n",
"sample_data.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "OR1BJULNh29M",
"outputId": "49998e49-c377-4247-c0ea-e2c9e4c8d921"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text intent\n",
"0 fare code y what does that mean\\n\\nIntent:\\n\\n abbreviation END\n",
"1 does dl stand for delta\\n\\nIntent:\\n\\n abbreviation END\n",
"2 what does code qw mean\\n\\nIntent:\\n\\n abbreviation END\n",
"3 please explain fare code f\\n\\nIntent:\\n\\n abbreviation END\n",
"4 what is ewr\\n\\nIntent:\\n\\n abbreviation END"
],
"text/html": [
"\n",
" <div id=\"df-ed53e3ba-f468-4a26-98f2-903f35fa4b52\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>intent</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>fare code y what does that mean\\n\\nIntent:\\n\\n</td>\n",
" <td>abbreviation END</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>does dl stand for delta\\n\\nIntent:\\n\\n</td>\n",
" <td>abbreviation END</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>what does code qw mean\\n\\nIntent:\\n\\n</td>\n",
" <td>abbreviation END</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>please explain fare code f\\n\\nIntent:\\n\\n</td>\n",
" <td>abbreviation END</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>what is ewr\\n\\nIntent:\\n\\n</td>\n",
" <td>abbreviation END</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-ed53e3ba-f468-4a26-98f2-903f35fa4b52')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-ed53e3ba-f468-4a26-98f2-903f35fa4b52 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-ed53e3ba-f468-4a26-98f2-903f35fa4b52');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"source": [
"# Show the first prompt with its newlines rendered, exactly as it is written to the training file.\n",
"print(sample_data.loc[0, 'text'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PHfA-51Vi56x",
"outputId": "520abf81-ded2-4839-9d3c-274b3fc1dbc3"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"fare code y what does that mean\n",
"\n",
"Intent:\n",
"\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Show the first completion, including its leading space and END marker.\n",
"print(sample_data.loc[0, 'intent'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cYo95VjijZK1",
"outputId": "46a9e64f-b0cf-4ee9-962e-0268004fab55"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" abbreviation END\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Use the field names that appear as keys in the exported JSONL records.\n",
"sample_data = sample_data.rename(columns={'text': 'prompt', 'intent': 'completion'})"
],
"metadata": {
"id": "4twzXMkKZwUX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Write one JSON object per row, one per line (JSON Lines): orient='records' combined with lines=True.\n",
"sample_data.to_json(\"intent_sample.jsonl\", orient='records', lines=True)"
],
"metadata": {
"id": "0JW51Cb7YcZO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# sample_data.to_json(\"intent_.json\", orient='records')"
],
"metadata": {
"id": "v0Pg58AHYzXF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Pin the SDK instead of using --upgrade: the `openai.Completion` call at the end of this notebook is the\n",
"# pre-1.0 interface, which openai>=1.0 no longer provides.\n",
"# %pip (rather than !pip) installs into the environment of the running kernel.\n",
"%pip install -q openai==0.28.1"
],
"metadata": {
"id": "1jc26bsGY7V8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Validate the JSONL file and write the *_prepared_train / *_prepared_valid files.\n",
"# -q auto-accepts the tool's suggestions (the same answers that were given interactively), so the cell\n",
"# does not block on [Y/n] prompts during Restart & Run All.\n",
"!openai tools fine_tunes.prepare_data -f intent_sample.jsonl -q"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0bhyKnRuZYG9",
"outputId": "7e287cf7-1149-4273-a4ea-1be0b71427d4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Analyzing...\n",
"\n",
"- Your file contains 200 prompt-completion pairs\n",
"- Based on your data it seems like you're trying to fine-tune a model for classification\n",
"- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n",
"- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training\n",
"- There are 10 duplicated prompt-completion sets. These are rows: [11, 15, 19, 24, 27, 34, 39, 129, 152, 179]\n",
"- All prompts end with suffix `\\n\\nIntent:\\n\\n`. This suffix seems very long. Consider replacing with a shorter suffix, such as `\\n\\n###\\n\\n`\n",
"\n",
"Based on the analysis we will perform the following actions:\n",
"- [Recommended] Remove 10 duplicate rows [Y/n]: Y\n",
"- [Recommended] Would you like to split into training and validation set? [Y/n]: Y\n",
"\n",
"\n",
"Your data will be written to a new JSONL file. Proceed [Y/n]: Y\n",
"\n",
"Wrote modified files to `intent_sample_prepared_train.jsonl` and `intent_sample_prepared_valid.jsonl`\n",
"Feel free to take a look!\n",
"\n",
"Now use that file when fine-tuning:\n",
"> openai api fine_tunes.create -t \"intent_sample_prepared_train.jsonl\" -v \"intent_sample_prepared_valid.jsonl\" --compute_classification_metrics --classification_n_classes 5\n",
"\n",
"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `\\n\\nIntent:\\n\\n` for the model to start generating completions, rather than continuing with the prompt. Make sure to include `stop=[\" END\"]` so that the generated texts ends at the expected place.\n",
"Once your model starts training, it'll approximately take 6.89 minutes to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"# Never store the key in the notebook: reuse it when the environment already provides it,\n",
"# otherwise prompt for it without echoing. The `!openai ...` commands below read this variable.\n",
"if not os.environ.get('OPENAI_API_KEY'):\n",
"    os.environ['OPENAI_API_KEY'] = getpass('OpenAI API key: ')"
],
"metadata": {
"id": "HXH7FcA74j6q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Start the fine-tune on the prepared train/validation files using the `davinci` base model.\n",
"# NOTE(review): prepare_data suggested adding --compute_classification_metrics --classification_n_classes 5\n",
"# and a cheaper model such as `ada` for classification; neither suggestion is applied here.\n",
"!openai api fine_tunes.create -t \"intent_sample_prepared_train.jsonl\" -v \"intent_sample_prepared_valid.jsonl\" -m 'davinci'"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cR27GSpl4lQ0",
"outputId": "b48ec37b-a4d1-414e-edca-a7444de6451c"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found potentially duplicated files with name 'intent_sample_prepared_train.jsonl', purpose 'fine-tune' and size 17810 bytes\n",
"file-ujyQ2W8fzEBP5nKviYzjBE6d\n",
"Enter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: \n",
"Upload progress: 100% 17.8k/17.8k [00:00<00:00, 29.4Mit/s]\n",
"Uploaded file from intent_sample_prepared_train.jsonl: file-qFZO4gBx4D1AfG6nXLZRPYRO\n",
"Found potentially duplicated files with name 'intent_sample_prepared_valid.jsonl', purpose 'fine-tune' and size 4643 bytes\n",
"file-KMcFoGeWhZwDVKd2rij0fW2f\n",
"Enter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: \n",
"Upload progress: 100% 4.64k/4.64k [00:00<00:00, 6.03Mit/s]\n",
"Uploaded file from intent_sample_prepared_valid.jsonl: file-PoccQPUMQqex4ctcDqC95gF7\n",
"Created fine-tune: ft-L68gCxl9xH1Cf6JZ7HYGevDV\n",
"Streaming events until fine-tuning is complete...\n",
"\n",
"(Ctrl-C will interrupt the stream, but not cancel the fine-tune)\n",
"[2022-09-11 18:52:30] Created fine-tune: ft-L68gCxl9xH1Cf6JZ7HYGevDV\n",
"[2022-09-11 18:52:34] Fine-tune costs $0.40\n",
"[2022-09-11 18:52:35] Fine-tune enqueued. Queue number: 0\n",
"[2022-09-11 18:52:36] Fine-tune started\n",
"\n",
"Stream interrupted (client disconnected).\n",
"To resume the stream, run:\n",
"\n",
" openai api fine_tunes.follow -i ft-L68gCxl9xH1Cf6JZ7HYGevDV\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Re-attach to the event stream of the fine-tune job after the previous stream was interrupted.\n",
"# The ft-... id belongs to one specific run; substitute the id printed by your own fine_tunes.create call.\n",
"!openai api fine_tunes.follow -i ft-L68gCxl9xH1Cf6JZ7HYGevDV"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "r3Qa8OXw759G",
"outputId": "ec455502-201b-4f79-f1ca-bc1a079d8f4b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[2022-09-11 18:52:30] Created fine-tune: ft-L68gCxl9xH1Cf6JZ7HYGevDV\n",
"[2022-09-11 18:52:34] Fine-tune costs $0.40\n",
"[2022-09-11 18:52:35] Fine-tune enqueued. Queue number: 0\n",
"[2022-09-11 18:52:36] Fine-tune started\n",
"[2022-09-11 18:59:59] Completed epoch 1/4\n",
"[2022-09-11 19:00:46] Completed epoch 2/4\n",
"[2022-09-11 19:01:36] Completed epoch 3/4\n",
"[2022-09-11 19:02:23] Completed epoch 4/4\n",
"[2022-09-11 19:05:11] Uploaded model: davinci:ft-personal-2022-09-11-19-05-11\n",
"[2022-09-11 19:05:13] Uploaded result file: file-VWYf7Mgq4dL8lXEDP0u85f17\n",
"[2022-09-11 19:05:13] Fine-tune succeeded\n",
"\n",
"Job complete! Status: succeeded 🎉\n",
"Try out your fine-tuned model:\n",
"\n",
"openai api completions.create -m davinci:ft-personal-2022-09-11-19-05-11 -p <YOUR_PROMPT>\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Test utterances; each ends with the same Intent separator that was appended to the training prompts.\n",
"# Uncomment one of the alternatives to try a different query.\n",
"# prompt = \"Do we have london flight on Monday\\n\\nIntent:\\n\\n\"\n",
"# prompt = \"what is the ap57 restriction\\n\\nIntent:\\n\\n\"\n",
"prompt = \"show me ground transportation in baltimore\\n\\nIntent:\\n\\n\""
],
"metadata": {
"id": "ev9FUSJEgGQA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "lC1qLV9Vwdvi"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"import openai\n",
"\n",
"# Take the key from the environment instead of embedding a literal in the notebook\n",
"# (a literal here would also override the key exported for the CLI commands).\n",
"openai.api_key = os.environ['OPENAI_API_KEY']\n",
"\n",
"# Model name as reported in the fine-tune job output; replace it with the name of your own model.\n",
"FINE_TUNED_MODEL = \"davinci:ft-personal-2022-09-11-19-05-11\"\n",
"\n",
"response = openai.Completion.create(\n",
"    model=FINE_TUNED_MODEL,\n",
"    prompt=prompt,\n",
"    max_tokens=5,  # hard cap on the length of the generated label\n",
"    temperature=0,\n",
"    top_p=1,\n",
"    frequency_penalty=0,\n",
"    presence_penalty=0,\n",
"    stop=[\" END\"],  # same marker that terminates every training completion\n",
")\n",
"# The returned text keeps the leading space that every training completion starts with.\n",
"print(response['choices'][0]['text'])"
],
"metadata": {
"id": "nKlaspj-fFbm",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d32648a5-5e10-4472-cebe-d5288107bbe7"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" ground_service\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "Jt22aIK29d_Z"
},
"execution_count": null,
"outputs": []
}
]
}