<a href="https://colab.research.google.com/github/agemagician/CodeTrans/blob/main/prediction/single%20task/source%20code%20summarization/sql/small_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**<h3>Summarize SQL source code using the CodeTrans single-task training model</h3>**
<h4>You can make free predictions online through this 
<a href="https://huggingface.co/SEBIS/code_trans_t5_small_source_code_summarization_sql">link</a></h4> (When using the online prediction, you need to parse and tokenize the code first, as in step 3 below.)

**1. Load the necessary libraries, including Hugging Face Transformers**

In [1]:
!pip install -q transformers sentencepiece sqlparse



In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, SummarizationPipeline
import torch

**2. Load the summarization pipeline and move it to the GPU if available**

In [3]:
pipeline = SummarizationPipeline(
    model=AutoModelForSeq2SeqLM.from_pretrained("SEBIS/code_trans_t5_small_source_code_summarization_sql"),
    tokenizer=AutoTokenizer.from_pretrained("SEBIS/code_trans_t5_small_source_code_summarization_sql", skip_special_tokens=True),
    device=0 if torch.cuda.is_available() else -1  # first GPU if present, otherwise CPU
)




**3. Provide the SQL code to summarize, then parse and tokenize it**

In [4]:
code = "select time (fieldname) from tablename" #@param {type:"raw"}

In [5]:
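import re
import sqlparse

# Lexer for the contents of string literals (e.g. LIKE patterns): regex/wildcard
# metacharacters become placeholder tokens, runs of ordinary characters collapse
# to R_FREE, and any other single character is dropped (its callback returns None).
scanner = re.Scanner([
  (r"\[[^\]]*\]", lambda scanner, token: token),  # character classes are kept verbatim
  (r"\+", lambda scanner, token: "R_PLUS"),
  (r"\*", lambda scanner, token: "R_KLEENE"),
  (r"%", lambda scanner, token: "R_WILD"),
  (r"\^", lambda scanner, token: "R_START"),
  (r"\$", lambda scanner, token: "R_END"),
  (r"\?", lambda scanner, token: "R_QUESTION"),
  (r"[\.~`;_a-zA-Z0-9\s=:\{\}\-\\]+", lambda scanner, token: "R_FREE"),
  (r".", lambda scanner, token: None),
])

def tokenizeRegex(s):
  # Returns a list of tokens; any unscannable remainder is ignored.
  results, remainder = scanner.scan(s)
  return results

def my_traverse(token_list, statement_list, result_list):
  # Recursively flatten the sqlparse tree, collecting each leaf's token type
  # and text while skipping whitespace.
  for t in token_list:
    if t.ttype is None:  # grouped token (e.g. Statement, Identifier): descend into it
      my_traverse(t, statement_list, result_list)
    elif t.ttype != sqlparse.tokens.Whitespace:
      statement_list.append(t.ttype)
      result_list.append(str(t))
  return statement_list, result_list

def sanitizeSql(sql):
  # Lower-case, ensure a trailing ';', pad parentheses with spaces, drop '#'.
  s = sql.strip().lower()
  if not s.endswith(";"):  # unlike s[-1], also safe for an empty string
    s += ';'
  s = re.sub(r'\(', r' ( ', s)
  s = re.sub(r'\)', r' ) ', s)
  s = s.replace('#', '')
  return s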
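To see what the two string-level helpers do without loading sqlparse or the model, the sketch below copies `sanitizeSql` and the regex scanner behind `tokenizeRegex` into a standalone snippet that needs only `re`. The sample inputs are made up for illustration.

```python
import re

# Copies of the notebook's helpers, reproduced here so the snippet runs on its own.
scanner = re.Scanner([
    (r"\[[^\]]*\]", lambda scanner, token: token),  # character class kept verbatim
    (r"\+", lambda scanner, token: "R_PLUS"),
    (r"\*", lambda scanner, token: "R_KLEENE"),
    (r"%", lambda scanner, token: "R_WILD"),
    (r"\^", lambda scanner, token: "R_START"),
    (r"\$", lambda scanner, token: "R_END"),
    (r"\?", lambda scanner, token: "R_QUESTION"),
    (r"[\.~`;_a-zA-Z0-9\s=:\{\}\-\\]+", lambda scanner, token: "R_FREE"),
    (r".", lambda scanner, token: None),  # anything else (e.g. quotes) is dropped
])

def tokenizeRegex(s):
    results, _remainder = scanner.scan(s)
    return results

def sanitizeSql(sql):
    s = sql.strip().lower()
    if not s.endswith(";"):
        s += ";"
    s = re.sub(r"\(", r" ( ", s)
    s = re.sub(r"\)", r" ) ", s)
    return s.replace("#", "")

print(sanitizeSql("SELECT COUNT(*) FROM Users"))
# select count ( * )  from users;
print(tokenizeRegex("'%abc%'"))
# ['R_WILD', 'R_FREE', 'R_WILD']
```

The double space after `)` is harmless: sqlparse emits it as a whitespace token, which `my_traverse` skips.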

In [6]:

statement_list = []
result_list = []
code = sanitizeSql(code)
tokens = sqlparse.parse(code)
statements, result = my_traverse(tokens, statement_list, result_list)

# Anonymize the query: numeric literals become CODE_* placeholders, string
# literals are lexed with tokenizeRegex, identifiers become colN, and
# identifiers used as tables (before '.' or after FROM) become tabN.
table_map = {}   # colN placeholder -> tabN
column_map = {}  # original identifier -> colN
for i in range(len(statements)):
  if statements[i] in [sqlparse.tokens.Number.Integer, sqlparse.tokens.Literal.Number.Integer]:
    result[i] = "CODE_INTEGER"
  elif statements[i] in [sqlparse.tokens.Number.Float, sqlparse.tokens.Literal.Number.Float]:
    result[i] = "CODE_FLOAT"
  elif statements[i] in [sqlparse.tokens.Number.Hexadecimal, sqlparse.tokens.Literal.Number.Hexadecimal]:
    result[i] = "CODE_HEX"
  elif statements[i] in [sqlparse.tokens.String.Symbol, sqlparse.tokens.String.Single, sqlparse.tokens.Literal.String.Single, sqlparse.tokens.Literal.String.Symbol]:
    # Join the R_* tokens so result[i] stays a string (a list would break
    # the startswith() checks below and the final ' '.join)
    result[i] = ' '.join(tokenizeRegex(result[i]))
  elif statements[i] in [sqlparse.tokens.Name, sqlparse.tokens.Name.Placeholder, sqlparse.sql.Identifier]:
    # Identifier: reuse its colN if already seen, otherwise assign the next one
    old_value = result[i]
    if old_value in column_map:
      result[i] = column_map[old_value]
    else:
      result[i] = 'col' + str(len(column_map))
      column_map[old_value] = result[i]
  elif (result[i] == "." and statements[i] == sqlparse.tokens.Punctuation and i > 0 and result[i-1].startswith('col')):
    # "x.y": the identifier before the dot is a table, so remap its colN to a tabN
    old_value = result[i-1]
    if old_value in table_map:
      result[i-1] = table_map[old_value]
    else:
      result[i-1] = 'tab' + str(len(table_map))
      table_map[old_value] = result[i-1]
  if (result[i].startswith('col') and i > 0 and (result[i-1] in ["from"])):
    # An identifier right after FROM is also a table
    old_value = result[i]
    if old_value in table_map:
      result[i] = table_map[old_value]
    else:
      result[i] = 'tab' + str(len(table_map))
      table_map[old_value] = result[i]

tokenized_code = ' '.join(result)
print("SQL after tokenized: " + tokenized_code)

SQL after tokenized: select time ( col0 ) from tab0 ;
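The renaming loop above is easier to follow on plain data. The sketch below is a simplified, hypothetical re-implementation: instead of sqlparse token types it takes `(kind, text)` pairs, where `kind == "name"` stands in for sqlparse's `Name` tokens, and it leaves out the numeric and string-literal branches. As in the notebook, every identifier first receives a `colN` placeholder, and that placeholder is re-mapped to `tabN` when the identifier appears right after `from` or right before a `.`.

```python
def anonymize(tokens):
    """Simplified sketch of the notebook's colN/tabN renaming.

    `tokens` is a list of (kind, text) pairs; kind "name" marks an identifier
    (a hypothetical stand-in for sqlparse's Name token type).
    """
    column_map = {}  # identifier text -> colN
    table_map = {}   # colN placeholder -> tabN
    out = []

    def as_table(placeholder):
        # Reuse the tabN already assigned to this placeholder, or assign the next one.
        if placeholder not in table_map:
            table_map[placeholder] = f"tab{len(table_map)}"
        return table_map[placeholder]

    for kind, text in tokens:
        if kind == "name":
            if text not in column_map:
                column_map[text] = f"col{len(column_map)}"
            placeholder = column_map[text]
            # An identifier right after FROM is treated as a table.
            if out and out[-1] == "from":
                placeholder = as_table(placeholder)
            out.append(placeholder)
        elif text == "." and out and out[-1].startswith("col"):
            # In "t.a" the identifier before the dot is treated as a table.
            out[-1] = as_table(out[-1])
            out.append(text)
        else:
            out.append(text)
    return out


query = [("keyword", "select"), ("name", "t"), ("punct", "."), ("name", "a"),
         ("keyword", "from"), ("name", "t"), ("punct", ";")]
print(" ".join(anonymize(query)))
# select tab0 . col1 from tab0 ;
```

Keying `table_map` by the `colN` placeholder (as the notebook does) is what makes both occurrences of `t` end up as the same `tab0`.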


**4. Make a Prediction**

In [7]:
pipeline([tokenized_code])

Your max_length is set to 512, but you input_length is only 18. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)


[{'summary_text': 'mysql : how to get the difference of a column in a table ?'}]
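The warning above reports that `max_length` is set to 512 for an 18-token input and, as the message itself suggests, a smaller value can be passed per call. The helper below is a sketch that forwards `max_length` and unwraps the `summary_text` fields; `summarize_batch` and `fake_pipeline` are hypothetical names, and 50 is simply the value the warning suggests, not a tuned setting. The stub stands in for `pipeline` so the snippet runs without downloading the model.

```python
from typing import Callable, List


def summarize_batch(summarizer: Callable, tokenized_queries: List[str],
                    max_length: int = 50) -> List[str]:
    # The pipeline returns one {'summary_text': ...} dict per input string.
    outputs = summarizer(tokenized_queries, max_length=max_length)
    return [out["summary_text"] for out in outputs]


# Stub with the same call shape as the notebook's `pipeline`, for illustration only.
def fake_pipeline(inputs, max_length=512):
    # Ignores max_length; it only mimics the output structure.
    return [{"summary_text": "summary of: " + text} for text in inputs]


print(summarize_batch(fake_pipeline, ["select time ( col0 ) from tab0 ;"]))
# ['summary of: select time ( col0 ) from tab0 ;']
```

In the notebook itself the call would be `summarize_batch(pipeline, [tokenized_code])`.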