In [1]:
import os

# Point Hugging Face downloads at the hf-mirror.com endpoint.
# This runs before the `transformers` import in a later cell; presumably the
# library reads the variable when it is imported -- TODO confirm.
os.environ.update(HF_ENDPOINT="https://hf-mirror.com")

In [2]:
import os

# Directory used as the Transformers download cache.
# NOTE(review): this is a machine-specific absolute path; adjust it when
# running the notebook elsewhere.
os.environ["TRANSFORMERS_CACHE"] = cache_dir = "/root/autodl-fs"

In [3]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM
import logging
# Optional: INFO-level logging.
logging.basicConfig(level=logging.INFO)

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell no longer fails on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pretrained uncased BERT tokenizer and masked-language-model head.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()  # switch the module to evaluation (inference) mode
model.to(device)


def predict_masked_sent(text, top_k=5):
    """Predict the token behind the first ``[MASK]`` in ``text``.

    The text is wrapped in ``[CLS]``/``[SEP]``, tokenized with the module-level
    ``tokenizer`` and scored with the module-level ``model``. Each of the
    ``top_k`` most probable tokens is printed together with its softmax
    probability, most probable first.

    Args:
        text: Sentence containing at least one literal ``[MASK]``.
        top_k: Number of candidate tokens to print.

    Returns:
        The most probable token for the first ``[MASK]``, or ``""`` when
        ``top_k`` is 0. (Previously the *last*, i.e. least probable, of the
        printed candidates was returned whenever ``top_k`` was greater than 1.)

    Raises:
        ValueError: If the tokenized text contains no ``[MASK]`` token.
    """
    # Tokenize input
    tokenized_text = tokenizer.tokenize("[CLS] %s [SEP]" % text)
    if "[MASK]" not in tokenized_text:
        raise ValueError("no [MASK] token found in text: %r" % (text,))
    masked_index = tokenized_text.index("[MASK]")
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

    # Build the input on whatever device the model lives on instead of
    # assuming CUDA, so the function works for CPU and GPU models alike.
    device = next(model.parameters()).device
    tokens_tensor = torch.tensor([indexed_tokens], device=device)

    # Predict all tokens
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    # Turn the scores at the masked position into a probability distribution
    # and keep the top_k entries, sorted from most to least probable.
    probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
    top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)

    candidates = tokenizer.convert_ids_to_tokens(top_k_indices.tolist())
    for candidate, token_weight in zip(candidates, top_k_weights.tolist()):
        print("[MASK]: '%s'"%candidate, " | weights:", token_weight)

    # topk is sorted, so the first candidate is the best prediction.
    return candidates[0] if candidates else ""

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
import os

# Define the directories containing the input files and receiving the output
directory = 'raw_data'
output_directory = 'cleared_data'

# Upper bound used for the text handed to the model in one call.
max_length = 512
# Characters of context passed per prediction; two slots are reserved for the
# [CLS]/[SEP] markers that predict_masked_sent adds.
# NOTE(review): this uses characters as a conservative proxy for tokens and
# assumes a token never covers less than one character -- confirm for unusual
# Unicode input.
CONTEXT_CHARS = max_length - 2

# Transcript placeholders whose content is filled in by the language model.
PLACEHOLDERS = ("[inaudible]", "[redacted]", "[crosstalk]")
MASK_TOKEN = "[MASK]"


def _fill_first_mask(text):
    """Replace the first [MASK] in ``text`` with the model's top prediction.

    Only a window of at most ``CONTEXT_CHARS`` characters around the mask is
    sent to the model, so long lines are handled without discarding any of
    their text. ``text`` is returned unchanged when it contains no mask.
    """
    mask_start = text.find(MASK_TOKEN)
    if mask_start == -1:
        return text

    # Roughly centre the window on the mask, then clamp it to the text bounds.
    # The window never starts after the mask, so the first [MASK] inside the
    # window is the one being filled.
    start = max(0, mask_start - (CONTEXT_CHARS - len(MASK_TOKEN)) // 2)
    end = min(len(text), start + CONTEXT_CHARS)
    start = max(0, end - CONTEXT_CHARS)

    replaced = predict_masked_sent(text[start:end], top_k=1)
    return text[:mask_start] + replaced + text[mask_start + len(MASK_TOKEN):]


def _strip_bracketed_terms(text):
    """Remove remaining ``[...]`` annotations from ``text``.

    A term ending in ``?`` (for example ``[word?]``) keeps its inner text
    without the brackets and the ``?``; any other bracketed term is deleted
    entirely. An opening bracket without a later ``]`` is left untouched.
    """
    for _ in range(text.count("[")):
        index_start = text.find("[")
        if index_start == -1:
            break
        index_end = text.find("]", index_start)
        if index_end == -1:
            break
        if text[index_end - 1] == "?":
            # Remove '[' and the '?' before ']', keep the text in between
            text = text[:index_start] + text[index_start + 1:index_end - 1] + text[index_end + 1:]
        else:
            text = text[:index_start] + text[index_end + 1:]
    return text


# Make sure the output directory exists before any file is opened for writing
os.makedirs(output_directory, exist_ok=True)

# Iterate over each file in the directory
for filename in os.listdir(directory):
    if not filename.endswith('.txt'):
        continue

    in_path = os.path.join(directory, filename)
    out_path = os.path.join(output_directory, filename)
    with open(in_path, 'r', encoding='utf-8') as infile, open(out_path, 'w', encoding='utf-8') as outfile:
        for line in infile:
            modified_line = line

            # Masks are only predicted on lines that contain a placeholder.
            if any(placeholder in line for placeholder in PLACEHOLDERS):
                # Replace "[inaudible]", "[redacted]" and "[crosstalk]" with "[MASK]"
                for placeholder in PLACEHOLDERS:
                    modified_line = modified_line.replace(placeholder, MASK_TOKEN)

                # Replace each occurrence of [MASK] with the predicted word,
                # one at a time so earlier predictions become context for
                # later ones.
                for _ in range(modified_line.count(MASK_TOKEN)):
                    modified_line = _fill_first_mask(modified_line)

            # Write the modified line (full length, line ending intact) to the
            # output file
            outfile.write(_strip_bracketed_terms(modified_line))


[MASK]: 'maybe'  | weights: 0.2179914116859436
[MASK]: 'looking'  | weights: 0.18057651817798615
[MASK]: 'look'  | weights: 0.05956616625189781
[MASK]: 'hope'  | weights: 0.041712548583745956
[MASK]: 'try'  | weights: 0.031869422644376755
[MASK]: 'else'  | weights: 0.15114066004753113
[MASK]: 'different'  | weights: 0.10842850804328918
[MASK]: 'better'  | weights: 0.09416163712739944
[MASK]: 'more'  | weights: 0.04375604912638664
[MASK]: 'new'  | weights: 0.03972452133893967
[MASK]: 'on'  | weights: 0.949520468711853
[MASK]: 'in'  | weights: 0.02288668230175972
[MASK]: 'around'  | weights: 0.00320236012339592
[MASK]: ','  | weights: 0.002488137222826481
[MASK]: 'there'  | weights: 0.002456692047417164
[MASK]: 'me'  | weights: 0.670590341091156
[MASK]: 'him'  | weights: 0.09520667791366577
[MASK]: 'her'  | weights: 0.07619468867778778
[MASK]: 'us'  | weights: 0.02384626865386963
[MASK]: 'them'  | weights: 0.009263100102543831
