In [1]:
!pip install transformers
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-2.0.0+cu118.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-2.0.0+cu118.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-2.0.0+cu118.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-2.0.0+cu118.html
# The same as Torch version and CUDA version (torch.__version__ is 2.0.0+cu118)
!pip install torch-geometric
!python3 -m spacy download en_core_web_sm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m46.8 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1
Looking in indexes: https://pypi.org/simple, https://

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
print(torch.__version__)
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader
from transformers import BertModel, BertTokenizer, BertConfig
import csv
import spacy
from spacy import displacy
import networkx as nx
#from gensim.models import Word2Vec
from tqdm import tqdm
import xml.etree.ElementTree as ET
from nltk.corpus import wordnet as wn
from nltk.corpus.reader.wordnet import Synset
from nltk.wsd import lesk
import nltk
import pickle
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from google.colab import drive
drive.mount('/content/drive')

nltk.download('wordnet')
nltk.download('punkt')

2.0.0+cu118
Mounted at /content/drive


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [3]:
nlp = spacy.load('en_core_web_sm') # DO NOT split a sentence for more 1 time
# If spliting a sentence more than 1 time, the result may different from the same split
model_path = "/content/drive/My Drive/my_GNN_MTL_PT_model_v4_epoch_12_val_loss_0.3258098229351971.pt"  # Choose your desired path and filename
# Model parameters
bert_model_name = 'bert-base-uncased'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
def Wu_Palmer_similarity(sentence, word_list):
    """
    Wu-Palmer similarity (wup_similarity):
    Range: [0, 1]
    Minimum: 0 (no relationship)
    Maximum: 1 (identical synsets)
    """
    # No self-loop
    # Example sentence
    # sentence = "The cat chased the dog"

    # Calculate semantic relatedness using Wu-Palmer similarity
    # word1 = 'dog'
    # word2 = 'cat'
    # synset1 = lesk(nltk.word_tokenize(sentence), word1, 'n')  # Disambiguate word1
    # synset2 = lesk(nltk.word_tokenize(sentence), word2, 'n')  # Disambiguate word2

    edge_weight = torch.rand((len(word_list) * (len(word_list) - 1),))
    # Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1)
    nltk_word_tokenize_sentence = nltk.word_tokenize(sentence)
    synset_list = []

    for word in word_list:
        synset_list.append(lesk(nltk_word_tokenize_sentence, word, 'n'))

    for i in range(len(word_list)):
        for j in range(i + 1, len(word_list)):
            if isinstance(synset_list[0], Synset) and isinstance(synset_list[1], Synset):
                similarity = synset_list[0].wup_similarity(synset_list[1])
                # similarity = synset1.wup_similarity(synset2)
                # The same as synset2.wup_similarity(synset1)
                # print("Similarity between", word1, "and", word2, ":", similarity)
                edge_weight[j * (len(word_list) - 1) + i + (-1 if i > j else 0)] = similarity
                edge_weight[i * (len(word_list) - 1) + j + (-1 if j > i else 0)] = similarity
            #else:
                # print("Cannot compute similarity. One or both words are not in WordNet.")
                # random samples from a uniform distribution over [0, 1).

    return edge_weight

In [5]:
def create_dataset(texts, labels, tokenizer, max_length, total_location):
    # Each element of the texts is word list
    assert len(texts) == len(labels)
    dataset = []

    for data_index in tqdm(range(len(labels))):
        text, label, location = texts[data_index], labels[data_index], total_location[data_index]
        # Define the input text
        # text = ['We’ve', 'got', 'children’s', 'church', 'because', 'we', 'know', 'how', 'to', 'kid', 'around']
        corresponding_location = -10000
        id_range = [] # The word belong to which ids [start, end)

        # Process the text with the spaCy NLP pipeline
        words = [str(tok) for sent in (nlp(' '.join(text))).sents for tok in sent]
        #words = text
        #print(words)

        corresponding_index = []
        # consider which token (in tokens[], base on index) to be the corresponding word

        current_index = 0
        start = False
        idx = 0
        #print(words)
        for word_idx, x in enumerate(words):
            #print(x)
            subwords = tokenizer.tokenize((' ' if start else '') + x)
            start = True
            #print(subwords)
            for i in range(1, len(subwords)):
                subwords[i] = subwords[i][2:] # delete the '##'

            # subwords_len = np.zeros((len(subwords),), dtype=float)
            subwords_len = torch.zeros((len(subwords),))
            for count_index, each_token in enumerate(subwords):
                subwords_len[count_index] = len(each_token)
            # longest_index = np.argmax(subwords_len, axis=0)
            longest_index = int(torch.argmax(subwords_len, dim=0))

            #print(longest_index)
            if 1 + current_index + longest_index < max_length - 1:
                # Cannot longer than max_length
                # because the input_ids has a start ID and end ID
                # which needs to +1 and -1 to match the index
                corresponding_index.append(current_index + longest_index)
                id_range.append([current_index, current_index + len(subwords)])
                if word_idx == location:
                    corresponding_location = current_index + longest_index

                current_index += len(subwords)
                idx += 1
            else:
                break

        if label > 0.5 and corresponding_location < 0:
            print("Do not get the corresponding_location!")

        original_idx = idx
        if text[:idx] != words[:idx]:
            for each_word in text[:idx]:
                original_idx -= 1
                sub_sent = nlp(' '.join(text[:original_idx]))
                if [str(tok) for sent in sub_sent.sents for tok in sent] == words[:idx]:
                    break
            else:
                print("Something Wrong!")

        # TODO: Some bugs in here, if the cut word is the last word, it maybe wrong
        
        doc = nlp(' '.join(text[:original_idx]))
        assert [str(tok) for sent in doc.sents for tok in sent] == words[:idx]
        graph = nx.DiGraph()
        for token in doc:
            graph.add_node(token.i, word=token.text)
            graph.add_edge(token.head.i, token.i, relation=token.dep_)

        #for node in graph.nodes(data=True):
        #    print(f"Node: {node}")

        #print(corresponding_index)
        #print(words[:idx])
        #print(id_range)
        corresponding_index = torch.tensor(corresponding_index, dtype=torch.long)
        #print(corresponding_index)
        assert len(corresponding_index) == len(words[:idx])
        assert len(corresponding_index) == len(id_range)

        sentence = ' '.join(words) # be processed in tokenizer
        #tokens = tokenizer.tokenize(sentence)
        #print(tokens)
        
        tokenized = tokenizer(sentence, return_tensors="pt", max_length=max_length, return_token_type_ids=False, truncation=True, padding="max_length")
        #print(tokenized)

        num_nodes = len(corresponding_index)
        source_nodes = [i for j in range(num_nodes) for i in range(num_nodes) if i != j]
        target_nodes = [j for j in range(num_nodes) for i in range(num_nodes) if i != j]
        # Complete directed graph, no self-loop

        # Add the root, which is self-loop, let the temporary index is -1, -1, and it is the last element
        source_nodes.append(-1)
        target_nodes.append(-1)

        source_nodes = torch.tensor(source_nodes, dtype=torch.long)
        target_nodes = torch.tensor(target_nodes, dtype=torch.long)
        edge_index = torch.stack([source_nodes, target_nodes], dim=0)
        #print(edge_index)
        # print(edge_index.shape[1]) # the number of edges

        # edge_weight = torch.rand(edge_index.shape[1])  # Replace this with the actual edge weights of your adjacency matrix
        edge_weight = Wu_Palmer_similarity(sentence, words[:idx]) # numpy array
        # edge_weight = np.append(edge_weight, [10.0], axis=0) # add the ROOT value
        edge_weight = torch.cat((edge_weight, torch.tensor([10.0])), dim=-1)
        # edge_weight = torch.tensor(edge_weight, dtype=torch.float32) # the type must dtype=torch.float32
        assert edge_index.shape[1] == len(edge_weight)

        has_ROOT = False
        # Add dependency parsing tree
        for edge in graph.edges(data=True):
            #print(f"Edge: {edge}")
            source_index = edge[0]
            target_index = edge[1]
            if source_index == target_index:
                # ROOT
                source_nodes[-1] = source_index
                target_nodes[-1] = target_index
                edge_index[0, len(source_nodes) - 1] = source_index
                edge_index[1, len(target_nodes) - 1] = target_index
                edge_weight[len(edge_weight) - 1] = 20.0 # may change this number
                # print(edge_index[:, len(edge_weight) - 1])
                has_ROOT = True
            else:
                # not ROOT
                edge_weight[target_index * (num_nodes - 1) + source_index + (-1 if source_index > target_index else 0)] += 0.8
                edge_weight[target_index * (num_nodes - 1) + source_index + (-1 if source_index > target_index else 0)] *= 10
                # may change this number
                # print(edge_index[:, target_index * (num_nodes - 1) + source_index + (-1 if source_index > target_index else 0)])
        if not has_ROOT:
            print("No ROOT, something wrong!")

        y = torch.tensor(label, dtype=torch.long)

        #tokenized["input_ids"][0] == tokenized["input_ids"].flatten()
        #print(corresponding_index)
        if len(tokenized["input_ids"][0]) != max_length:
            print(len(tokenized["input_ids"][0]))
            print("The length is wrong!")
        for each_corre_idx, each_corre in enumerate(corresponding_index):
            assert id_range[each_corre_idx][0] <= each_corre < id_range[each_corre_idx][1]

        data = Data(data_identification=torch.tensor(data_index, dtype=torch.long), # The ID of the data
                    input_ids=tokenized["input_ids"][0],
                    attention_mask=tokenized["attention_mask"][0],
                    corresponding=corresponding_index,
                    sentence_len=torch.tensor(len(corresponding_index), dtype=torch.long), # original length
                    edge_index=edge_index,
                    edge_weight=edge_weight,
                    location=torch.tensor(location, dtype=torch.long), # pun word index in original sentence
                    corresponding_location=torch.tensor(corresponding_location, dtype=torch.long), # pun word index in ids (start with 0), need to +1
                    id_range=torch.tensor(id_range, dtype=torch.long),
                    y=y,
                    num_nodes=torch.tensor(num_nodes, dtype=torch.long)
                    )
        # DO NOT corresponding_index=corresponding_index!
        # If the end of the name of parameters of Data() has '_index',
        # the package will consider this is index list, it will be automatic update the value list:
        # Example: [1,2,5], [2,3,9], [3,6,1] -> [1,2,5,8,9,15,19,22,17]
        # Not: [1,2,5], [2,3,9], [3,6,1] -> [1,2,5,2,3,9,3,6,1]
        # If set num_nodes as parameters of Data(), it will be sum automatically (is a number finally, not a list)
        dataset.append(data)
        '''
        print('tokenized["input_ids"][0]')
        print(tokenized["input_ids"][0])
        print('tokenized["attention_mask"][0]')
        print(tokenized["attention_mask"][0])
        print('corresponding_index')
        print(corresponding_index)
        print('len(corresponding_index)')
        print(len(corresponding_index))
        print('location')
        print(location)
        print('corresponding_location')
        print(corresponding_location)
        print('id_range')
        print(id_range)
        print('y')
        print(y)
        '''

    return dataset

In [6]:
# Custom GNN and BiLSTM layers
class GNN(torch.nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        # input_dim is number of features of each node
        super(GNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index, edge_weight, batch):
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)

        # Global Pooling (stack different aggregations)
        hidden = torch.cat([global_mean_pool(x, batch),
                            global_max_pool(x, batch)], dim=1)
        return x, hidden


class BERT_GNN_MTL_Classifier(nn.Module):
    def __init__(self, num_classes, hidden_dim, num_lstm_layers, gnn_hidden_dim, gnn_output_dim, dropout, max_length, batch_size):
        super().__init__()
        self.max_length = max_length
        self.batch_size = batch_size
        self.gnn_output_dim = gnn_output_dim
        self.bert_config = BertConfig.from_pretrained(bert_model_name)
        self.bert = BertModel.from_pretrained(bert_model_name, config=self.bert_config)
        # Embedding size of BERT is 768
        self.gnn = GNN(self.bert_config.hidden_size, gnn_hidden_dim, gnn_output_dim)
        #self.bilstm = BiLSTM(self.bert_config.hidden_size, hidden_dim, num_lstm_layers)
        #self.bilstm = BiLSTM(gnn_output_dim, hidden_dim, num_lstm_layers)
        self.bilstm = nn.LSTM(
            input_size=gnn_output_dim,
            hidden_size=hidden_dim,
            num_layers=num_lstm_layers,
            bidirectional=True,
            batch_first=True,
            #dropout=dropout if num_layers > 1 else 0
        )
        self.bilstm_2 = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=hidden_dim,
            num_layers=num_lstm_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_lstm_layers > 1 else 0
        )
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_dim * 2, 1) # 2 is bidirectional LSTM, has 2, 1 is the output
        # self.fc = nn.Linear(hidden_dim * 2, num_classes) # multi-class classfication
        # self.sigmoid = nn.Sigmoid()
        #self.classifier = nn.Linear(hidden_dim * (2 if self.bilstm.bidirectional else 1), 1) # sigmoid
        self.fc2 = nn.Linear(gnn_output_dim * 2 + hidden_dim * 2, 256)
        self.classifier2 = nn.Linear(hidden_dim * (2 if self.bilstm.bidirectional else 1) + 256, 1) # sigmoid

    def forward(self, data, token_index):
        # token_index[i] should be start with 0 and step is 1, because the GNN use only one id per word
        input_ids, attention_mask = data.input_ids, data.attention_mask
        corresponding_index, sentence_len = data.corresponding, data.sentence_len
        edge_index, edge_weight, num_nodes, batch = data.edge_index, data.edge_weight, data.num_nodes, data.batch
        #print(sentence_len)
        assert sum(sentence_len) == len(corresponding_index)
        actual_batch_size = len(input_ids) // self.max_length

        first_cut = 0
        corres = [] # the corresponding index of each sentence
        # print(corresponding_index)
        # print(sentence_len)
        for each_sent_len_i in range(len(sentence_len) - 1):
            corres.append(corresponding_index[first_cut:first_cut+sentence_len[each_sent_len_i]])
            first_cut += sentence_len[each_sent_len_i]
        corres.append(corresponding_index[first_cut:first_cut+sentence_len[-1]])
        # print(corres)
        assert len(corres) == actual_batch_size

        input_ids = torch.reshape(input_ids, (actual_batch_size, self.max_length))
        attention_mask = torch.reshape(attention_mask, (actual_batch_size, self.max_length))
        # print("input_ids.shape:", input_ids.shape) # torch.Size([self.batch_size, self.max_length])

        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        hidden_states = bert_output.last_hidden_state
        # print("hidden_states.shape:", hidden_states.shape) # torch.Size([self.batch_size, self.max_length, 768])

        gnn_input = torch.zeros((len(corresponding_index), hidden_states.shape[-1])) # not include the start and end ID
        each_batch_index = 0
        start_index = 0
        each_counter = 0
        # print(corresponding_index)
        for each_corr_i in corresponding_index:
            # each_corr_i + 1 because the input_ids has a start ID in the front,
            # which needs to add 1 to match the index
            gnn_input[start_index] = hidden_states[each_batch_index, each_corr_i + 1]
            start_index += 1
            each_counter += 1
            if each_counter == sentence_len[each_batch_index]:
                each_batch_index += 1
                each_counter = 0
        each_counter = 0
        assert each_batch_index == actual_batch_size
        assert start_index == len(corresponding_index)
        
        gnn_output, gnn_hidden = self.gnn(x=gnn_input, edge_index=edge_index, edge_weight=edge_weight, batch=batch)
        #print(gnn_hidden.shape)

        bilstm_input = torch.zeros((actual_batch_size, self.max_length, self.gnn_output_dim))
        each_batch_index = 0
        start_index = 0
        # Similar to padding 0
        for each_sen_len in sentence_len:
            for each_word_i in range(each_sen_len):
                bilstm_input[each_batch_index, each_word_i] = gnn_output[start_index]
                start_index += 1
            each_batch_index += 1

        each_batch_index = 0
        start_index = 0

        lstm_output, (hidden, _) = self.bilstm(bilstm_input)
        '''
        Anthor way is using only last hidden state of the LSTM cell:

        Using lstm_output[:, -1] selects the last hidden state of the LSTM cell
        for each sequence in the batch. The reason we use this approach in the
        example provided is that the last hidden state is often a good
        representation of the entire sequence in many sequence-to-sequence
        models, especially for classification tasks.

        When we use lstm_output[:, -1], we're selecting the hidden states of the
        LSTM cells at the last time step (i.e., the last token in the input sequence)
        for each sequence in the batch. This can be a good representation of the
        entire sequence for classification tasks since it captures information
        from both the forward and backward passes of the sequence.
        '''
        pooled_output = torch.mean(lstm_output, 1)

        if self.bilstm.bidirectional:
            hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
            # hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=-1)
        else:
            hidden = hidden[-1, :, :]

        #print(hidden.shape)
        hidden_gnn_bilstm = torch.cat([gnn_hidden, hidden], dim=-1)
        #print(hidden_gnn_bilstm.shape)
        connected_layer = self.relu(self.fc2(self.dropout(hidden_gnn_bilstm)))
        #print("connected_layer.shape", connected_layer.shape)

        # Get the hidden state of the word at the specified index
        assert lstm_output.size(0) == len(input_ids)
        assert actual_batch_size == lstm_output.size(0)
        assert lstm_output.size(2) == self.bilstm.hidden_size * (2 if self.bilstm.bidirectional else 1)

        lstm_output_2, (hidden_2, _) = self.bilstm_2(hidden_states)

        assert lstm_output_2.size(2) == self.bilstm_2.hidden_size * (2 if self.bilstm_2.bidirectional else 1)

        focused_word_hidden = torch.zeros((lstm_output_2.size(0), lstm_output_2.size(2)))
        for sentence_index in range(lstm_output_2.size(0)):
            focused_word_hidden[sentence_index] = lstm_output_2[sentence_index, token_index[sentence_index], :]
        #classification_output = self.classifier(focused_word_hidden)
        classification_output = self.classifier2(torch.cat([focused_word_hidden, connected_layer], dim=-1))

        return self.fc(pooled_output), classification_output
        

In [7]:
def get_random_id_out_of_range(id_range, id_index):

    # check the id_range is correct or not:
    for i, x in enumerate(id_range[:-1]):
        assert x[0] < x[1]
        assert x[1] == id_range[i + 1][0]
    assert id_range[-1][0] < id_range[-1][1]

    the_list = [i for i in range(int(id_range[-1][1]))]
    for ele in id_range:
        if int(ele[0]) <= id_index < int(ele[1]):
            for j in range(int(ele[0]), int(ele[1])):
                the_list.remove(j)
            break
    else:
        print("Something wrong, cannot get the id range.")

    return torch.tensor(np.random.choice(the_list, 1)[0])

In [8]:
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
max_length = 80  # Adjust the maximum length based on your dataset

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [82]:
def check_sentence(sentence, gold=0, location=-10000):

    assert type(sentence) == str
    assert (gold == 0 and location == -10000) or (gold == 1 and location >= 0)
    # location is the location of pun word which after spaCy spliting

    print(sentence)
    doc = nlp(sentence)

    sentence_spaCy_split = [str(tok) for sent in doc.sents for tok in sent]
    print("After spaCy splitting:")
    print(sentence_spaCy_split)

    the_dict = {}
    for i, x in enumerate(sentence_spaCy_split):
        the_dict[i] = x
    print(the_dict)

    if location >= 0:
        print("The pun word is: " + sentence_spaCy_split[location])

    print("Dependency Parsing Tree (by spaCy):")
    %matplotlib inline
    displacy.render(doc, style="dep", jupyter=True)

    # Create data loaders
    dataset = create_dataset(texts=[sentence.split()], labels=[gold],
                            tokenizer=tokenizer, max_length=max_length,
                            total_location=[location])

    # Load the the saved file
    loaded_model = torch.load(model_path).to(device)

    # Set the model to evaluation mode if you plan to use it for inference
    loaded_model.eval()

    location_loader = DataLoader(dataset, batch_size=1, shuffle=False)

    with torch.no_grad():

        for ele in location_loader:
            attention_mask = ele.attention_mask.reshape(1, max_length).to(device)
            #label = ele.y.to(device)
            #token_index = torch.clone(ele.location).to(device) # After GNN, the same as original, start with 0
            id_range = ele.id_range.to(device)
            
            scores_list = torch.zeros((int(sum(attention_mask[0])) - 2,))
            for used_token_index in range(1, int(sum(attention_mask[0])) - 1):
                # After GNN, the index is same as original, start with 0
                logits_1, logits_2 = loaded_model(ele, torch.tensor(used_token_index).reshape(1,))
                scores_list[used_token_index - 1] = float(torch.sigmoid(logits_2.view(-1))) # Get each ID scores

            print()
            print("The probability of the sentence contains a pun is:", float(torch.sigmoid(logits_1.view(-1))))

            # Matched token ID is belonged to the pun word
            for i, x in enumerate(id_range):
                if int(x[0]) <= int(torch.argmax(scores_list, dim=0)) < int(x[1]):
                    print("The predicted pun word is: " + sentence_spaCy_split[i])
                    print("The score of the predicted pun word is:", float(torch.max(scores_list)))
                    if i == location:
                        print("The predict location is correct!")
                    break

            print()
            print("Each score of each word:")
            for i, x in enumerate(id_range):
                print((str(i) + " ").ljust(3) + sentence_spaCy_split[i].ljust(15) + "\t" +
                      str(float(torch.max(scores_list[int(x[0]):int(x[1])]))) + "\tScore list: " + str(scores_list[int(x[0]):int(x[1])]))
            print()

In [83]:
check_sentence("don t trust people that do acupuncture they re back stabbers") # Pun

don t trust people that do acupuncture they re back stabbers
After spaCy splitting:
['don', 't', 'trust', 'people', 'that', 'do', 'acupuncture', 'they', 're', 'back', 'stabbers']
{0: 'don', 1: 't', 2: 'trust', 3: 'people', 4: 'that', 5: 'do', 6: 'acupuncture', 7: 'they', 8: 're', 9: 'back', 10: 'stabbers'}
Dependency Parsing Tree (by spaCy):


100%|██████████| 1/1 [00:00<00:00, 33.92it/s]



The probability of the sentence contains a pun is: 0.014523979276418686
The predicted pun word is: stabbers
The score of the predicted pun word is: 0.2131805419921875

Each score of each word:
0  don            	0.001561130746267736	Score list: tensor([0.0016])
1  t              	0.0011132418876513839	Score list: tensor([0.0011])
2  trust          	0.0010025023948401213	Score list: tensor([0.0010])
3  people         	0.0008891725447028875	Score list: tensor([0.0009])
4  that           	0.0009093284606933594	Score list: tensor([0.0009])
5  do             	0.0010441169142723083	Score list: tensor([0.0010])
6  acupuncture    	0.004450162407010794	Score list: tensor([0.0022, 0.0017, 0.0018, 0.0045])
7  they           	0.004637404344975948	Score list: tensor([0.0046])
8  re             	0.010740486904978752	Score list: tensor([0.0107])
9  back           	0.060074519366025925	Score list: tensor([0.0601])
10 stabbers       	0.2131805419921875	Score list: tensor([0.2132, 0.0527])



In [84]:
check_sentence("When the church bought gas for their annual barbecue, proceeds went from the sacred to the propane.") # Pun

When the church bought gas for their annual barbecue, proceeds went from the sacred to the propane.
After spaCy splitting:
['When', 'the', 'church', 'bought', 'gas', 'for', 'their', 'annual', 'barbecue', ',', 'proceeds', 'went', 'from', 'the', 'sacred', 'to', 'the', 'propane', '.']
{0: 'When', 1: 'the', 2: 'church', 3: 'bought', 4: 'gas', 5: 'for', 6: 'their', 7: 'annual', 8: 'barbecue', 9: ',', 10: 'proceeds', 11: 'went', 12: 'from', 13: 'the', 14: 'sacred', 15: 'to', 16: 'the', 17: 'propane', 18: '.'}
Dependency Parsing Tree (by spaCy):


100%|██████████| 1/1 [00:00<00:00, 27.64it/s]



The probability of the sentence contains a pun is: 0.9964553117752075
The predicted pun word is: propane
The score of the predicted pun word is: 0.9759582281112671

Each score of each word:
0  When           	0.020838286727666855	Score list: tensor([0.0208])
1  the            	0.015309997834265232	Score list: tensor([0.0153])
2  church         	0.013559304177761078	Score list: tensor([0.0136])
3  bought         	0.013483976945281029	Score list: tensor([0.0135])
4  gas            	0.013726837933063507	Score list: tensor([0.0137])
5  for            	0.014190418645739555	Score list: tensor([0.0142])
6  their          	0.014400123618543148	Score list: tensor([0.0144])
7  annual         	0.014814143069088459	Score list: tensor([0.0148])
8  barbecue       	0.016122013330459595	Score list: tensor([0.0161])
9  ,              	0.01772751286625862	Score list: tensor([0.0177])
10 proceeds       	0.021320126950740814	Score list: tensor([0.0213])
11 went           	0.029796333983540535	Score list:

In [85]:
check_sentence("when the cannibal showed up late to the luncheon they gave him the cold shoulder") # Pun

when the cannibal showed up late to the luncheon they gave him the cold shoulder
After spaCy splitting:
['when', 'the', 'cannibal', 'showed', 'up', 'late', 'to', 'the', 'luncheon', 'they', 'gave', 'him', 'the', 'cold', 'shoulder']
{0: 'when', 1: 'the', 2: 'cannibal', 3: 'showed', 4: 'up', 5: 'late', 6: 'to', 7: 'the', 8: 'luncheon', 9: 'they', 10: 'gave', 11: 'him', 12: 'the', 13: 'cold', 14: 'shoulder'}
Dependency Parsing Tree (by spaCy):


100%|██████████| 1/1 [00:00<00:00, 59.36it/s]



The probability of the sentence contains a pun is: 0.9949048757553101
The predicted pun word is: cold
The score of the predicted pun word is: 0.9702832102775574

Each score of each word:
0  when           	0.020784731954336166	Score list: tensor([0.0208])
1  the            	0.015479343011975288	Score list: tensor([0.0155])
2  cannibal       	0.017805064097046852	Score list: tensor([0.0178, 0.0147, 0.0142])
3  showed         	0.013679572381079197	Score list: tensor([0.0137])
4  up             	0.013883010484278202	Score list: tensor([0.0139])
5  late           	0.013987056910991669	Score list: tensor([0.0140])
6  to             	0.014320753514766693	Score list: tensor([0.0143])
7  the            	0.014724026434123516	Score list: tensor([0.0147])
8  luncheon       	0.01625022478401661	Score list: tensor([0.0154, 0.0163])
9  they           	0.018966130912303925	Score list: tensor([0.0190])
10 gave           	0.0279456228017807	Score list: tensor([0.0279])
11 him            	0.05927911773

In [86]:
check_sentence("I used to be a banker but I lost interest.") # Pun

I used to be a banker but I lost interest.
After spaCy splitting:
['I', 'used', 'to', 'be', 'a', 'banker', 'but', 'I', 'lost', 'interest', '.']
{0: 'I', 1: 'used', 2: 'to', 3: 'be', 4: 'a', 5: 'banker', 6: 'but', 7: 'I', 8: 'lost', 9: 'interest', 10: '.'}
Dependency Parsing Tree (by spaCy):


100%|██████████| 1/1 [00:00<00:00, 38.11it/s]



The probability of the sentence contains a pun is: 0.9934548735618591
The predicted pun word is: lost
The score of the predicted pun word is: 0.9751361608505249

Each score of each word:
0  I              	0.0265132337808609	Score list: tensor([0.0265])
1  used           	0.020817534998059273	Score list: tensor([0.0208])
2  to             	0.02100871503353119	Score list: tensor([0.0210])
3  be             	0.024932967498898506	Score list: tensor([0.0249])
4  a              	0.03860398381948471	Score list: tensor([0.0386])
5  banker         	0.20281043648719788	Score list: tensor([0.2028])
6  but            	0.12722423672676086	Score list: tensor([0.1272])
7  I              	0.3514883816242218	Score list: tensor([0.3515])
8  lost           	0.9751361608505249	Score list: tensor([0.9751])
9  interest       	0.9647613763809204	Score list: tensor([0.9648])
10 .              	0.42225807905197144	Score list: tensor([0.4223])



In [87]:
check_sentence("I believe there s a lot of thinking about that time and America in general") # Not pun

I believe there s a lot of thinking about that time and America in general
After spaCy splitting:
['I', 'believe', 'there', 's', 'a', 'lot', 'of', 'thinking', 'about', 'that', 'time', 'and', 'America', 'in', 'general']
{0: 'I', 1: 'believe', 2: 'there', 3: 's', 4: 'a', 5: 'lot', 6: 'of', 7: 'thinking', 8: 'about', 9: 'that', 10: 'time', 11: 'and', 12: 'America', 13: 'in', 14: 'general'}
Dependency Parsing Tree (by spaCy):


100%|██████████| 1/1 [00:00<00:00, 61.43it/s]



The probability of the sentence contains a pun is: 0.01039062812924385
The predicted pun word is: I
The score of the predicted pun word is: 0.0004786340577993542

Each score of each word:
0  I              	0.0004786340577993542	Score list: tensor([0.0005])
1  believe        	0.00034323340514674783	Score list: tensor([0.0003])
2  there          	0.0002874478814192116	Score list: tensor([0.0003])
3  s              	0.00026404240634292364	Score list: tensor([0.0003])
4  a              	0.0002494723885320127	Score list: tensor([0.0002])
5  lot            	0.0002454558270983398	Score list: tensor([0.0002])
6  of             	0.00024277826014440507	Score list: tensor([0.0002])
7  thinking       	0.00024333490000572056	Score list: tensor([0.0002])
8  about          	0.00024578554439358413	Score list: tensor([0.0002])
9  that           	0.0002496751258149743	Score list: tensor([0.0002])
10 time           	0.00025437239673919976	Score list: tensor([0.0003])
11 and            	0.00026249221991

In [88]:
check_sentence("All but one of the workers were back on the job several hours later") # Not pun

All but one of the workers were back on the job several hours later
After spaCy splitting:
['All', 'but', 'one', 'of', 'the', 'workers', 'were', 'back', 'on', 'the', 'job', 'several', 'hours', 'later']
{0: 'All', 1: 'but', 2: 'one', 3: 'of', 4: 'the', 5: 'workers', 6: 'were', 7: 'back', 8: 'on', 9: 'the', 10: 'job', 11: 'several', 12: 'hours', 13: 'later'}
Dependency Parsing Tree (by spaCy):


100%|██████████| 1/1 [00:00<00:00, 48.75it/s]



The probability of the sentence contains a pun is: 0.011249003931879997
The predicted pun word is: All
The score of the predicted pun word is: 0.000453153537819162

Each score of each word:
0  All            	0.000453153537819162	Score list: tensor([0.0005])
1  but            	0.00032360796467401087	Score list: tensor([0.0003])
2  one            	0.00027077022241428494	Score list: tensor([0.0003])
3  of             	0.00024628263781778514	Score list: tensor([0.0002])
4  the            	0.000235644169151783	Score list: tensor([0.0002])
5  workers        	0.00023227986821439117	Score list: tensor([0.0002])
6  were           	0.00022912234999239445	Score list: tensor([0.0002])
7  back           	0.00023070459428709	Score list: tensor([0.0002])
8  on             	0.00023233875981532037	Score list: tensor([0.0002])
9  the            	0.00023659826547373086	Score list: tensor([0.0002])
10 job            	0.0002476967347320169	Score list: tensor([0.0002])
11 several        	0.000260876491665