In [7]:
from bs4 import BeautifulSoup
import pickle
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
import torch
from torch import nn
from transformers import RobertaModel, RobertaTokenizer
import spacy

In [8]:
class RobertaForSequenceClassification(nn.Module):
    """RoBERTa encoder followed by a classification head on its pooled output.

    The submodule names ``roberta_single`` and ``single_hidden2tag`` must not be
    renamed: they are the prefixes of the checkpoint's state_dict keys
    (e.g. ``roberta_single.embeddings.position_ids``).

    Args:
        num_labels: number of output logits produced by the head.
        pretrained_model_name_or_path: encoder weights passed to
            ``RobertaModel.from_pretrained`` (defaults to ``"roberta-large"``).
    """

    def __init__(self, num_labels=3, pretrained_model_name_or_path="roberta-large"):
        super().__init__()
        self.roberta_single = RobertaModel.from_pretrained(pretrained_model_name_or_path)
        # Size the head from the encoder's own config rather than hard-coding
        # 1024, which is only correct for roberta-large.
        hidden_size = self.roberta_single.config.hidden_size
        self.single_hidden2tag = RobertaClassificationHead(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return classification logits for a batch of token ids.

        Args:
            input_ids: token id tensor fed to the RoBERTa encoder.
            attention_mask: optional mask forwarded to the encoder.
            token_type_ids: accepted for call-site compatibility but ignored;
                it is not forwarded to the encoder.

        Returns:
            Logits from the classification head applied to the encoder's
            pooled output.
        """
        encoder_outputs = self.roberta_single(input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output is the pooled representation.
        pooled_output = encoder_outputs[1]
        return self.single_hidden2tag(pooled_output)


class RobertaClassificationHead(nn.Module):
    """Classification head: dropout -> dense -> tanh -> dropout -> projection.

    Args:
        hidden_size: width of the incoming features and of the intermediate
            dense layer.
        num_labels: number of logits produced.
    """

    def __init__(self, hidden_size, num_labels):
        super().__init__()
        # Attribute names become state_dict keys, so they must stay exactly
        # ``dense`` / ``dropout`` / ``out_proj`` for checkpoint loading.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        """Map features (..., hidden_size) to logits (..., num_labels)."""
        # The same dropout module is applied both before and after the dense layer.
        activated = torch.tanh(self.dense(self.dropout(x)))
        return self.out_proj(self.dropout(activated))


def remove_html_tags_bs4(text):
    """Strip HTML markup from ``text`` and return only its text content.

    Parsing uses the stdlib-backed ``"html.parser"`` backend of BeautifulSoup.
    """
    return BeautifulSoup(text, "html.parser").get_text()


# (tokenizer, model) pairs keyed by (model_name, device). Re-running the cell
# resets the cache; otherwise each checkpoint is loaded once per session instead
# of on every call.
SUMMARIZER_CACHE = {}


def _load_summarizer(model_name, device):
    """Return a cached (tokenizer, model) pair for ``model_name`` on ``device``."""
    key = (model_name, str(device))
    if key not in SUMMARIZER_CACHE:
        tokenizer = BartTokenizer.from_pretrained(model_name)
        model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
        model.eval()
        SUMMARIZER_CACHE[key] = (tokenizer, model)
    return SUMMARIZER_CACHE[key]


def genearate_summary(
    text,
    max_length=130,
    min_length=30,
    length_penalty=2.0,
    num_beams=4,
    model_name="facebook/bart-large-cnn",
):
    """Summarise ``text`` (which may contain HTML) with a BART model.

    HTML tags are stripped first. The input is truncated to 1024 tokens before
    beam-search generation.

    Args:
        text: source text, possibly containing HTML markup.
        max_length: maximum summary length passed to ``generate``.
        min_length: minimum summary length passed to ``generate``.
        length_penalty: length penalty passed to ``generate``.
        num_beams: beam width passed to ``generate``.
        model_name: Hugging Face checkpoint to load (cached after first use).

    Returns:
        The decoded summary string with special tokens removed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer, model = _load_summarizer(model_name, device)

    inputs = tokenizer(
        remove_html_tags_bs4(text), max_length=1024, truncation=True, return_tensors="pt"
    ).to(device)
    summary_ids = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly instead of letting generate infer it.
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        min_length=min_length,
        length_penalty=length_penalty,
        num_beams=num_beams,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def check_entailment(tokenizer, model, premise, hypothesis, return_probs=False):
    """Score one (premise, hypothesis) pair with a sequence-pair classifier.

    Args:
        tokenizer: callable taking a text pair plus tokenizer kwargs and
            returning an encoding with ``input_ids``/``attention_mask`` entries
            and a ``.to(device)`` method (e.g. a Hugging Face tokenizer).
        model: classifier called as ``model(input_ids=..., attention_mask=...)``.
            It may return a logits tensor or an object with a ``.logits``
            attribute.
        premise: first text of the pair.
        hypothesis: second text of the pair. The pair is truncated jointly to
            512 tokens.
        return_probs: if True, return softmax probabilities over the last
            dimension instead of raw logits.

    Returns:
        Row 0 of the model's logits (or of their softmax when ``return_probs``).
    """
    # Put the inputs on the device that actually holds the model's weights.
    # Parameterless callables fall back to the CUDA-availability guess.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = tokenizer(
        premise,
        hypothesis,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )

    # Accept both plain logits tensors and model-output objects with `.logits`.
    logits = getattr(outputs, "logits", outputs)[0]
    if return_probs:
        return torch.softmax(logits, dim=-1)
    return logits


# spaCy pipelines keyed by package name. Each model is loaded at most once per
# session instead of on every call.
SPACY_PIPELINE_CACHE = {}


def split_sentences(text, model_name="en_core_web_sm"):
    """Split ``text`` into sentence strings using a spaCy pipeline.

    Args:
        text: plain text to segment.
        model_name: spaCy package to load (cached after first use).

    Returns:
        A list with the ``.text`` of each sentence in ``doc.sents``, in order.
    """
    nlp = SPACY_PIPELINE_CACHE.get(model_name)
    if nlp is None:
        nlp = spacy.load(model_name)
        SPACY_PIPELINE_CACHE[model_name] = nlp
    return [sentence.text for sentence in nlp(text).sents]

In [6]:
# DocNLI RoBERTa checkpoint (a state_dict) and the Stack Overflow answer to analyse.
nli_model_file = "/content/drive/MyDrive/ICSE2023/DocNLI.pretrained.RoBERTA.model.pt"
answer_body = "This should be replaced with SO post with html tags."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
model = RobertaForSequenceClassification(num_labels=2)
model.to(device)

# SECURITY NOTE(review): torch.load unpickles the file and can execute arbitrary
# code. Only load checkpoints from a trusted source. Consider weights_only=True
# if the installed torch supports it and the file is a plain state_dict.
state_dict = torch.load(
    nli_model_file,
    map_location="cpu",
)
# Drop the checkpoint's position_ids buffer. This is presumably to avoid a
# version mismatch with the installed transformers model; confirm.
state_dict.pop("roberta_single.embeddings.position_ids", None)
load_result = model.load_state_dict(state_dict, strict=False)
# strict=False also hides genuinely missing weights (e.g. a renamed head), which
# would stay randomly initialised and make every score meaningless.
missing_weights = [
    key for key in load_result.missing_keys if not key.endswith("position_ids")
]
if missing_weights:
    raise RuntimeError(f"Checkpoint is missing weights: {missing_weights}")
if load_result.unexpected_keys:
    print("Ignored unexpected checkpoint keys:", load_result.unexpected_keys)
model.eval()

summary = genearate_summary(answer_body)
for sentence in split_sentences(remove_html_tags_bs4(answer_body)):
    # check_entailment returns raw logits; softmax turns them into probabilities.
    # NOTE(review): which index means "entailment" depends on the DocNLI label
    # order used in training. Confirm before reading off a single value.
    probs = torch.softmax(check_entailment(tokenizer, model, summary, sentence), dim=-1)
    print("For sentence:", sentence)
    print(f"The possibility of it being important is: {probs}")

Using device: cpu


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


{'entailment': 0.0003622408548835665, 'neutral': 0.002163344295695424, 'contradiction': 0.9974743723869324}


In [5]:
# Run the summariser on its own to inspect the generated summary; the last
# expression is displayed as the cell output.
genearate_summary(text=answer_body)

Using device: cpu


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


KeyboardInterrupt: 