Skip to content

Commit

Permalink
Merge pull request #623 from donggrant/fix-extra-quotations
Browse files Browse the repository at this point in the history
Fixed quotation issues
  • Loading branch information
qiyanjun committed Mar 20, 2022
2 parents 14eb686 + fb5a35c commit 11782c4
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions textattack/commands/augment_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,36 @@ def run(self, args):
# Read in CSV file as a list of dictionaries. Use the CSV sniffer to
# try and automatically infer the correct CSV format.
csv_file = open(args.input_csv, "r")

# mark each double quote in the raw lines by appending "/" right after it
def markQuotes(lines):
    """Lazily yield each line with every double quote followed by a "/" marker.

    Args:
        lines: An iterable of strings (for example, an open text file).

    Yields:
        Each input string with every ``"`` replaced by ``"/``; all other
        characters are left untouched.
    """
    yield from (text.replace('"', '"/') for text in lines)

dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
csv_file.seek(0)
rows = [
row
for row in csv.DictReader(
csv_file, dialect=dialect, skipinitialspace=True
markQuotes(csv_file),
dialect=dialect,
skipinitialspace=True,
)
]

# undo the markings: drop a "/" that directly follows a quote, otherwise turn it back into a quote
for row in rows:
for item in row:
i = 0
while i < len(row[item]):
if row[item][i] == "/":
if row[item][i - 1] == '"':
row[item] = row[item][:i] + row[item][i + 1 :]
else:
row[item] = row[item][:i] + '"' + row[item][i + 1 :]
i += 1

# Validate input column.
row_keys = set(rows[0].keys())
if args.input_column not in row_keys:
Expand All @@ -174,20 +196,30 @@ def run(self, args):
augmented_row = row.copy()
augmented_row[args.input_column] = augmentation
output_rows.append(augmented_row)

# Print to file.
with open(args.output_csv, "w") as outfile:
csv_writer = csv.writer(
outfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
outfile, delimiter=",", quotechar="/", quoting=csv.QUOTE_MINIMAL
)
# Write header.
csv_writer.writerow(output_rows[0].keys())
# Write rows.
for row in output_rows:
csv_writer.writerow(row.values())

textattack.shared.logger.info(
f"Wrote {len(output_rows)} augmentations to {args.output_csv} in {time.time() - start_time}s."
)

# Remove extra markings in output file
with open(args.output_csv, "r") as file:
data = file.readlines()
for i in range(len(data)):
data[i] = data[i].replace("/", "")
with open(args.output_csv, "w") as file:
file.writelines(data)

@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
Expand Down

0 comments on commit 11782c4

Please sign in to comment.