In [1]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification

tokenizer =  AutoTokenizer.from_pretrained('../../models/mbert/tokenizer')
full_size_model = AutoModelForSequenceClassification.from_pretrained('../../models/mbert/normalCase/fullSize/2022-06-16_16-57-36/model', num_labels=3)
full_size_lower_model = AutoModelForSequenceClassification.from_pretrained('../../models/mbert/lowerCase/fullSize/2022-07-04_14-12-24/model', num_labels=3)

for param in full_size_model.parameters():
    param.requires_grad_(False)


In [3]:
import re
import pandas as pd

In [19]:
import torch.nn.functional as F
@torch.no_grad()
def sentenceCategoryMbertVersion(text: str, model) -> int:
    tokenized_text = tokenizer(text, padding="longest", truncation=True, return_tensors='pt')
    prediction = model(input_ids=tokenized_text["input_ids"], attention_mask=tokenized_text["attention_mask"], token_type_ids=tokenized_text["token_type_ids"])
    return [f'{r*100:.2f}%' for r in F.softmax(prediction.logits, dim=-1).detach().cpu().numpy()[0]]

phrase = "Ko ngā ērā This is a trial Ko ngā ērā This is a trial Ko ngā ērā This is a trial"
print(sentenceCategoryMbertVersion(phrase, full_size_model))
print(sentenceCategoryMbertVersion(phrase.lower(), full_size_lower_model))

['100.00%', '0.00%', '0.00%']
['100.00%', '0.00%', '0.00%']


In [4]:
@torch.no_grad()
def sentenceCategoryMbertVersion(text: str, model) -> int:
    tokenized_text = tokenizer(text, padding="longest", truncation=True, return_tensors='pt')
    prediction = model(input_ids=tokenized_text["input_ids"], attention_mask=tokenized_text["attention_mask"], token_type_ids=tokenized_text["token_type_ids"])
    return prediction.logits.detach().cpu().numpy().argmax()

def detectCodeSwitchingPointMbertVersion(x: str, w: int, model) -> list():
    wordsList = x.split()
    end = len(wordsList)
    if w >= end and end > 2:
        w = end - 1
    elif end == 1:
        w = 1
    elif end == 2:
        w = 2
    else:
        pass

    if end < 1:
        return []

    elif end == 1:
        if re.search(u'[āēīōūĀĒĪŌŪ]', x):
            return [1]
        elif re.search(u'[bBcCdDfFgGjJlLqQsSvVxXyYzZ]', x):
            return [2]
        else:
            return [sentenceCategoryMbertVersion(x, model)]

    elif end == 2:
        if not re.search(u'[āēīōūĀĒĪŌŪ]', x):
            if sentenceCategoryMbertVersion(x, model) == 1 and not re.search(u'[bBcCdDfFgGjJlLqQsSvVxXyYzZ]', x):
                return [1, 1]
            elif sentenceCategoryMbertVersion(x, model) == 2:
                return [2, 2]
            else:
                if sentenceCategoryMbertVersion(wordsList[0], model) == 1 and not re.search(u'[bBcCdDfFgGjJlLqQsSvVxXyYzZ]', wordsList[0]):
                    return [1, 2]
                else:
                    return [2, 1]
        else:
            if re.search(u'[āēīōūĀĒĪŌŪ]', wordsList[0]) and re.search(u'[āēīōūĀĒĪŌŪ]', wordsList[1]):
                return [1, 1]
            if re.search(u'[āēīōūĀĒĪŌŪ]', wordsList[0]) and not re.search(u'[āēīōūĀĒĪŌŪ]', wordsList[1]):
                return [1, 2]
            else:
                return [2, 1]
    
    else:
        result = []
        ptr = 0
        while ptr < end:
            thisWindow = wordsList[ptr:ptr+w]
            if ptr + w > end:
                w = end - ptr
            else:
                pass
            if sentenceCategoryMbertVersion(" ".join(thisWindow), model) == 1 and not re.search(u'[bBcCdDfFgGjJlLqQsSvVxXyYzZ]', " ".join(thisWindow)):
                result.extend([1 for _ in range(w)])
            elif sentenceCategoryMbertVersion(" ".join(thisWindow), model) == 2 and not re.search(u'[āēīōūĀĒĪŌŪ]', " ".join(thisWindow)):
                result += [2 for _ in range(w)]
            else:
                if w >= 4 and w % 2 == 0:
                    result += detectCodeSwitchingPointMbertVersion(" ".join(thisWindow), w-2, model)
                else:
                    result += detectCodeSwitchingPointMbertVersion(" ".join(thisWindow), w-1, model)
            ptr += w
        return result

In [5]:
phrase = "Ko ngā ērā This is a trial Ko ngā ērā This is a trial Ko ngā ērā This is a trial"
print(detectCodeSwitchingPointMbertVersion(phrase, 250, full_size_model))
print(detectCodeSwitchingPointMbertVersion(phrase.lower(), 250, full_size_lower_model))

[1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2]
[1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2]


In [6]:
df = pd.read_csv("../../small_data.csv")

df = df.replace({'Labels_Final': {'P': 2, 'M': 1, 'B':0}})
df['Labels_Final'] = df['Labels_Final'].astype(int)
df.head()

Unnamed: 0,id,number,text,label,Labels_Final
0,H20031118,36,Will the Tertiary Education Commission be meas...,"P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,M,P,P,P,P,P",0
1,H20031118,54,What progress is being made on Treaty of Waita...,"P,P,P,P,P,P,P,P,M,P",0
2,H20031118,59,We will also be shortly signing another deed o...,"P,P,P,P,P,P,P,P,P,P,P,M,M",0
3,H20031118,71,"The Office of Treaty Settlements, with Te Puni...","P,P,P,P,P,P,M,M,M,P,P,P,P,P,P,P,P,P,P,P,P,P,P,...",0
4,H20031118,74,When will the Minister undertake a comprehensi...,"P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,P,...",0


In [8]:
# Test full size model
for window in range(2, 10):
    sentence_label_error = 0
    word_label_error = 0

    for ind, row in df.iterrows():
        x = row['text']
        l = row['Labels_Final']
        lw = row['label']
        ly = lw.split(",")
        ly = [item.replace("P", "2") for item in ly]
        ly = [item.replace("M", "1") for item in ly]

        for i, j in zip(detectCodeSwitchingPointMbertVersion(x, window, full_size_model), ly):
            if i != int(j):
                word_label_error += 1
                # break
            
    total_words = df['text'].apply(lambda x: len(str(x).split(' '))).sum()

    print(" ")
    print("------------------------------------------")
    print("Model: full size")
    print("Window size: ", window)
    print("Total sentence label error", sentence_label_error)
    print("Total number of words",  total_words)
    print("Total word label error in bilingual sentences", word_label_error)

# Test full size lower model
for window in range(2, 10):
    sentence_label_error = 0
    word_label_error = 0

    for ind, row in df.iterrows():
        x = row['text'].lower()
        l = row['Labels_Final']
        lw = row['label']
        ly = lw.split(",")
        ly = [item.replace("P", "2") for item in ly]
        ly = [item.replace("M", "1") for item in ly]

        for i, j in zip(detectCodeSwitchingPointMbertVersion(x, window, full_size_lower_model), ly):
            if i != int(j):
                word_label_error += 1
                # break
            
    total_words = df['text'].apply(lambda x: len(str(x).split(' '))).sum()

    print(" ")
    print("------------------------------------------")
    print("Model: full size lower")
    print("Window size: ", window)
    print("Total sentence label error", sentence_label_error)
    print("Total number of words",  total_words)
    print("Total word label error in bilingual sentences", word_label_error)

# # Test size 2 model
# sentence_label_error = 0
# word_label_error = 0

# for ind, row in df.iterrows():
#     x = row['text']
#     l = row['Labels_Final']
#     lw = row['label']
#     ly = lw.split(",")
#     ly = [item.replace("P", "2") for item in ly]
#     ly = [item.replace("M", "1") for item in ly]

#     for i, j in zip(detectCodeSwitchingPointMbertVersion(x, 2, size_2_model), ly):
#         if i != int(j):
#             word_label_error += 1
#             # break
        
# total_words = df['text'].apply(lambda x: len(str(x).split(' '))).sum()

# print(" ")
# print("------------------------------------------")
# print("Model: size 2")
# print("Window size: ", 2)
# print("Total sentence label error", sentence_label_error)
# print("Total number of words",  total_words)
# print("Total word label error in bilingual sentences", word_label_error)

# # Test size 3 model
# sentence_label_error = 0
# word_label_error = 0

# for ind, row in df.iterrows():
#     x = row['text']
#     l = row['Labels_Final']
#     lw = row['label']
#     ly = lw.split(",")
#     ly = [item.replace("P", "2") for item in ly]
#     ly = [item.replace("M", "1") for item in ly]

#     for i, j in zip(detectCodeSwitchingPointMbertVersion(x, 2, size_3_model), ly):
#         if i != int(j):
#             word_label_error += 1
#             # break
        
# total_words = df['text'].apply(lambda x: len(str(x).split(' '))).sum()

# print(" ")
# print("------------------------------------------")
# print("Model: size 3")
# print("Window size: ", 3)
# print("Total sentence label error", sentence_label_error)
# print("Total number of words",  total_words)
# print("Total word label error in bilingual sentences", word_label_error)

 
------------------------------------------
Model: full size
Window size:  2
Total sentence label error 0
Total number of words 2872
Total word label error in bilingual sentences 13
 
------------------------------------------
Model: full size
Window size:  3
Total sentence label error 0
Total number of words 2872
Total word label error in bilingual sentences 20
 
------------------------------------------
Model: full size
Window size:  4
Total sentence label error 0
Total number of words 2872
Total word label error in bilingual sentences 12
 
------------------------------------------
Model: full size
Window size:  5
Total sentence label error 0
Total number of words 2872
Total word label error in bilingual sentences 17
 
------------------------------------------
Model: full size
Window size:  6
Total sentence label error 0
Total number of words 2872
Total word label error in bilingual sentences 13
 
------------------------------------------
Model: full size
Window size:  7
Total s