In [1]:
!pip install datasets
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
Collec

In [2]:
import copy
def preprocess(example):
  """Linearize one ToTTo-style example into a tagged string.

  The table is rendered as ``<table> <row> <c> value ... </c> ... </row> </table>``,
  optionally preceded by ``<page_title>`` and ``<section_title>`` elements.
  Each cell is followed by the header cells heuristically associated with it
  (column headers first, then row headers). Cells listed in
  ``highlighted_cells`` are wrapped in ``<highlighted_cell>`` instead of ``<c>``.

  Args:
    example: Mapping with the keys
      ``table``: list of rows, each a list of cell dicts providing ``value``
        (concatenated with strings, so it must be a str), ``is_header`` and
        ``column_span``.
      ``highlighted_cells``: iterable of ``(row_index, col_index)`` pairs;
        both lists and tuples are accepted.
      ``table_page_title`` / ``table_section_title``: optional; a missing or
        empty title is simply left out of the output.

  Returns:
    The same ``example`` object, with ``linearized_table`` set to ``<s>``
    followed by the linearized table and two newline characters.
  """

  def _add_adjusted_col_offsets(table):
    """Add adjusted column offsets to take into account multi-column cells."""
    adjusted_table = []
    for row in table:
      real_col_index = 0
      adjusted_row = []
      for cell in row:
        # Work on a copy so the caller's cell dicts are never modified.
        adjusted_cell = copy.deepcopy(cell)
        adjusted_cell["adjusted_col_start"] = real_col_index
        real_col_index += adjusted_cell["column_span"]
        adjusted_cell["adjusted_col_end"] = real_col_index
        adjusted_row.append(adjusted_cell)
      adjusted_table.append(adjusted_row)
    return adjusted_table

  def _get_heuristic_row_headers(adjusted_table, row_index, col_index):
    """Heuristic: header cells to the left of the cell, in the same row."""
    return [
        cell for cell in adjusted_table[row_index][:col_index]
        if cell["is_header"]
    ]

  def _get_heuristic_col_headers(adjusted_table, row_index, col_index):
    """Heuristic: header cells in earlier rows whose column range overlaps."""
    target = adjusted_table[row_index][col_index]
    col_start = target["adjusted_col_start"]
    col_end = target["adjusted_col_end"]
    return [
        cell
        for row in adjusted_table[:row_index]
        for cell in row
        if (cell["adjusted_col_start"] < col_end and
            cell["adjusted_col_end"] > col_start) and cell["is_header"]
    ]

  table = example['table']

  # Normalise the index pairs to tuples in a set: membership tests become O(1)
  # and pairs supplied as tuples match just like pairs supplied as lists.
  highlighted = {tuple(index) for index in example["highlighted_cells"]}

  # Collect the fragments and join once instead of growing a string in a loop.
  parts = []

  page_title = example.get('table_page_title')
  if page_title:
    parts.append("<page_title> " + page_title + " </page_title> ")
  section_title = example.get('table_section_title')
  if section_title:
    parts.append("<section_title> " + section_title + " </section_title> ")

  parts.append("<table> ")
  adjusted_table = _add_adjusted_col_offsets(table)
  for r_index, row in enumerate(table):
    parts.append("<row> ")
    for c_index, col in enumerate(row):
      row_headers = _get_heuristic_row_headers(adjusted_table, r_index, c_index)
      col_headers = _get_heuristic_col_headers(adjusted_table, r_index, c_index)

      # Distinguish between highlighted and non-highlighted cells.
      if (r_index, c_index) in highlighted:
        start_cell_marker = "<highlighted_cell> "
        end_cell_marker = "</highlighted_cell> "
      else:
        start_cell_marker = "<c> "
        end_cell_marker = "</c> "

      # The value of the cell.
      parts.append(start_cell_marker + col["value"] + " ")

      # All the column headers associated with this cell.
      for col_header in col_headers:
        parts.append("<col_header> " + col_header["value"] + " </col_header> ")

      # All the row headers associated with this cell.
      for row_header in row_headers:
        parts.append("<row_header> " + row_header["value"] + " </row_header> ")

      parts.append(end_cell_marker)

    parts.append("</row> ")

  parts.append("</table>")

  example['linearized_table'] = '<s>' + "".join(parts) + '\n' + '\n'
  return example

In [3]:
from datasets import load_dataset
from transformers import BloomTokenizerFast, BloomForCausalLM

In [4]:
# Hub identifier of the checkpoint; judging by its name, a BLOOM-560m model
# fine-tuned for ToTTo table-to-text generation.
ckpt = "mrm8488/bloom-560m-finetuned-totto-table-to-text"

# Tokenizer and weights both come from the same checkpoint.
tokenizer = BloomTokenizerFast.from_pretrained(ckpt)
model = BloomForCausalLM.from_pretrained(ckpt)
model = model.to("cpu")  # inference runs on the CPU

Downloading (…)okenizer_config.json:   0%|          | 0.00/268 [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/96.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/873 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

In [5]:
def explain_hl_cells(text, max_length=2048):
    """Run the notebook-level ``model`` on ``text`` and return the decoded output.

    Args:
        text: Prompt string. The notebook passes the ``linearized_table``
            value produced by ``preprocess``.
        max_length: Forwarded to ``model.generate`` as ``max_length``. Per the
            transformers documentation this bounds the total sequence length
            (prompt plus generated tokens). Defaults to 2048.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens kept.
    """
    # Place the inputs on whatever device the model lives on instead of
    # hard-coding "cpu", so the function keeps working if the model is moved.
    device = model.device

    inputs = tokenizer(text, return_tensors='pt')
    output = model.generate(
        inputs.input_ids.to(device),
        attention_mask=inputs.attention_mask.to(device),
        max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(output[0], skip_special_tokens=False)


In [6]:
def table_to_totto(
    table_dict: dict,
    page_title: str = "Invitation to Company Event.",
    section_title: str = "Invitation to Company Event.",
) -> dict:
    """Wrap a flat key/value dict into a ToTTo-style example dict.

    The result holds a two-row table: the first row contains one header cell
    per key, the second row one non-header cell per value, in the dict's
    iteration order. Every cell has ``column_span`` and ``row_span`` of 1.

    Args:
        table_dict: Mapping of field name to field value. Values are stored
            unchanged. NOTE(review): ``preprocess`` concatenates cell values
            with strings, so non-str values will presumably fail there.
        page_title: Stored as ``table_page_title``.
        section_title: Stored as ``table_section_title``.

    Returns:
        A dict with the keys ``table_page_title``, ``table_webpage_url``,
        ``table_section_title``, ``table_section_text``, ``table``,
        ``highlighted_cells``, ``example_id`` and ``sentence_annotations``.
    """

    def _cell(value, is_header):
        # Single-span cell in the layout used for every table entry.
        return {
            "column_span": 1,
            "is_header": is_header,
            "row_span": 1,
            "value": value,
        }

    header_row = [_cell(key, True) for key in table_dict]
    value_row = [_cell(value, False) for value in table_dict.values()]

    return {
        "table_page_title": page_title,
        "table_webpage_url": "",
        "table_section_title": section_title,
        "table_section_text": "",
        "table": [header_row, value_row],
        # Every cell of row 0 is highlighted.
        # NOTE(review): row 0 is the header row (the keys), not the values;
        # confirm that highlighting the headers is what the model expects.
        "highlighted_cells": [[0, col] for col in range(len(header_row))],
        "example_id": 0,
        "sentence_annotations": [],
    }


# Sample input for table_to_totto(): the parts of an invitation e-mail, keyed
# by a short label for each part. Keys become header cells, values data cells.
example = {
 "Recipient Name": "Receiver",
 "Greetings": "Dear Receiver,",
 "Opening": "I hope this email finds you well. ",
 "Reason for writing": "I am writing to extend a cordial invitation to you for our upcoming company event",
 "Date and Time": "tomorrow at 12.30 in the meeting room",
 "Event Purpose": "to celebrate the remarkable achievements of our organization over the past year and to express our gratitude to all the individuals who have contributed to our success",
 "Activities": "keynote speeches from renowned industry experts, interactive workshops, and a networking reception",
 "Importance of Attendance": "We highly value your presence and the unique perspective you bring to our organization. Your attendance would greatly contribute to the overall success of the event. We believe that your participation would not only strengthen our professional network but also foster valuable collaborations in the future.",
 "Confirmation Request": "Kindly confirm your availability by 16.30 today so that we can make the necessary arrangements for seating and catering. If you have any dietary restrictions or special requirements, please let us know in advance so that we can accommodate them accordingly.",
 "Closing": "Thank you for your time, and we anticipate a memorable evening celebrating our collective achievements.",
 "Closing Greetings": "Best regards,"
}

In [7]:
from datasets import Dataset
# Convert the sample dict into a ToTTo-style record and wrap it in a
# single-row Dataset so it can be processed with Dataset.map.
my_dict = table_to_totto(example)
dataset = Dataset.from_list([my_dict])


In [8]:
# Run preprocess on every row; it adds the 'linearized_table' field.
# Note that this rebinds `dataset` to the mapped result.
dataset = dataset.map(preprocess)

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [None]:
# Feed the linearized table of the first row to the model and print the
# decoded sequence.
print(explain_hl_cells(dataset[0]['linearized_table']))