In [1]:
import inseq
import pandas as pd
import warnings
import byt5_model
import torch

# Suppress every Python warning for the rest of the session. Note that this
# hides all categories, including ones that may point at real problems.
warnings.filterwarnings("ignore")

ModuleNotFoundError: No module named 'inseq'

In [None]:
!pwd
!pip3 install inseq

In [None]:
# Report the Python version of the running kernel (shown as the cell output).
import platform
platform.python_version()

In [None]:
class InseqAttributer:
    """Holds a model loaded through Inseq and displays attributions computed with it."""

    def __init__(self, model="./models/test/", attribution_method="input_x_gradient") -> None:
        """Load ``model`` with ``inseq.load_model`` using ``attribution_method``.

        The loaded object is stored on ``self.model``.
        """
        self.model = inseq.load_model(model, attribution_method)

    def attribute(self, inp:str, out:str=None):
        """Attribute ``inp`` (optionally together with ``out``) and display the result.

        When both ``inp`` and ``out`` are truthy they are passed as the two
        positional arguments of ``self.model.attribute``; otherwise only
        ``inp`` is passed. ``show()`` is called on the result and nothing is
        returned.
        """
        texts = (inp, out) if (inp and out) else (inp,)
        attribution = self.model.attribute(
            *texts,
            attribute_target=True,
            step_scores=["probability"],
        )
        attribution.show()


def predict_on_data(data, model_config_path, spaces=True, device='auto'):
    """Run the ByT5 model on ``data`` and split the rows by prediction correctness.

    Args:
        data: DataFrame with at least the string columns 'lemma' and
            'features'. The 'inputs' column is added to a copy rather than to
            the caller's frame.
        model_config_path: Forwarded unchanged to ``byt5_model.generate``.
        spaces: If True, lemma and features are joined with a single space to
            build the model input; otherwise they are concatenated directly.
        device: 'auto' (CUDA when available, otherwise CPU), 'cpu',
            'gpu'/'cuda', or an already constructed ``torch.device``.

    Returns:
        Tuple ``(incorrect, correct)`` of DataFrames: the rows whose
        'Expected' value differs from / equals their 'Predicted' value. Both
        carry the added 'inputs' and 'correct' columns, and a 'labels' column
        is renamed to 'plural'. 'Expected' and 'Predicted' are presumably
        supplied by ``byt5_model.comparer`` - verify against that module.

    Raises:
        ValueError: If ``device`` is not one of the accepted values. The
            previous behaviour (print and return ``None``) made callers that
            unpack the result fail with an unrelated ``TypeError``.
    """
    if not isinstance(device, torch.device):
        if device == 'auto':
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif device in ('gpu', 'cuda'):
            device = torch.device("cuda")
        elif device == 'cpu':
            device = torch.device("cpu")
        else:
            raise ValueError(f"Set device to auto, cpu, or gpu/cuda (got {device!r}).")

    # assign() returns a new frame, so a caller passing a slice such as
    # df.head(40) neither gets mutated nor triggers SettingWithCopyWarning.
    separator = ' ' if spaces else ''
    data = data.assign(inputs=data['lemma'] + separator + data['features'])

    generated = byt5_model.generate(model_config_path, data, device)
    data_gen_comparison = byt5_model.comparer(data, generated)
    # NOTE(review): concat(axis=1) aligns on the index; this assumes comparer
    # returns a frame indexed like `data` - confirm against byt5_model.
    data = pd.concat([data, data_gen_comparison], axis=1).rename(columns={'labels': 'plural'})

    # Vectorised comparison; the labels are materialised as a plain list so the
    # assignment is positional, exactly like the former row-by-row loop.
    matches = data['Expected'] == data['Predicted']
    data = data.assign(correct=['correct' if hit else 'incorrect' for hit in matches])

    return (data[data.correct == 'incorrect'], data[data.correct == 'correct'])

def construct_character_spans(elements, lead_tokens='ÃÄÅ'):
    """Return ``(start, end)`` index spans for two-byte characters in ``elements``.

    Each item of ``elements`` must expose its text as ``.token``. A token that
    equals one of ``lead_tokens`` marks the first byte of a two-byte character,
    so it is grouped with the token that follows it.

    The default lead tokens (capital A with accents) are the start of the
    two-byte wide Turkish characters ÇĞİÖŞÜçğıöşü.

    Args:
        elements: Iterable of token objects with a ``.token`` attribute.
        lead_tokens: Iterable of token strings that start a two-byte character.

    Returns:
        List of ``(idx, idx + 2)`` tuples, in order and never overlapping.
        A lead token in the last position has no following token to pair
        with and therefore yields no span.
    """
    tokens = list(elements)
    # Exact membership: the former substring test also matched an empty token.
    lead_set = set(lead_tokens)

    spans = []
    idx = 0
    # Stop one short of the end so every span stays inside the sequence.
    while idx + 1 < len(tokens):
        if tokens[idx].token in lead_set:
            spans.append((idx, idx + 2))
            idx += 2  # the following token belongs to this character
        else:
            idx += 1

    return spans


def inseq_on_predictions(samples, inseq_model, expected=False):
    """Attribute every row of ``samples`` with Inseq, display and return the results.

    Args:
        samples: DataFrame whose rows provide 'lemma', 'features', 'Expected',
            'Predicted' and 'correct'.
        inseq_model: Object exposing the loaded Inseq model as ``.model``
            (for example an ``InseqAttributer``).
        expected: If True the 'Expected' (gold) form is used as the target of
            the attribution; otherwise the 'Predicted' form is used.

    Returns:
        List with one entry per row, in row order: the object that was
        displayed for that row (the span-aggregated attribution when two-byte
        characters were found, otherwise the raw attribution output).
    """
    attributions = []

    for count, (idx, sample) in enumerate(samples.iterrows(), start=1):
        if count > 1:
            print('\n')
        print('='*80)

        inp = sample['lemma'] + ' ' + sample['features']
        if expected:
            print('Testing Expected (gold) outcome (what would the model look at to get to the expected outcome)')
            output_type = 'gold'
            target_text = sample['Expected']
        else:
            print('Testing Predicted (pred) outcome (what did the model take from the input during generation)')
            output_type = 'pred'
            target_text = sample['Predicted']

        print('Sample: {}\nIndex: {}\nThis prediction is {}\nInput -> output ({}): \'{}\' -> \'{}\''.format(
             count, idx, sample['correct'], output_type, inp, target_text
        ))
        if output_type == 'pred':
            print(f"The expected or gold output would be '{sample['Expected']}'")
        print('='*80)

        # Attribution is deliberately run over the whole target. Restricting it
        # to the plural suffix ('ler'/'lar') through attr_pos_start/attr_pos_end
        # was disabled because it interferes with the span aggregation below.
        # If it is ever re-enabled, the offsets have to be byte offsets
        # (len(text[:i].encode('utf-8'))), not character offsets.
        attribution = inseq_model.model.attribute(
            inp,
            target_text,
            show_progress=False,
            attribute_target=True,
            step_scores=["probability"],
        )

        # Construct source and target spans to aggregate characters that are two bytes wide
        sequence_attribution = attribution.sequence_attributions[0]
        source_spans = construct_character_spans(sequence_attribution.source) or None
        target_spans = construct_character_spans(sequence_attribution.target) or None
        print(source_spans, target_spans)

        if source_spans or target_spans:
            attribution = inseq.data.aggregator.ContiguousSpanAggregator.aggregate(
                sequence_attribution,
                source_spans=source_spans,
                target_spans=target_spans,
            )
        attribution.show()
        attributions.append(attribution)

    return attributions
            

## Turkish

We focus on the relation between the last vowel of the stem and the first vowel of the suffix (for the plural, `-ler` or `-lar`). We expect the last vowel of the stem to be salient when the model predicts the first vowel of the suffix. The analysis is restricted to plural nouns.

In [None]:
# Location of the model to analyse (a fine-tuned Turkish checkpoint, judging by
# the path); it is loaded through Inseq with the "input_x_gradient" method.
model_tur_finetuned = "./drive-symlink/NLP_project_morphological_inflection/finetuned_tur_3"
inseq_tur_finetuned = InseqAttributer(model_tur_finetuned, "input_x_gradient")

In [None]:
# sample_runner(data_sampler('./data/tur.gold', model_tur_finetuned, device='cpu'), inseq_tur_finetuned, expected=True)

In [None]:
# Column names for the tab-separated training file. Because `names` is passed,
# pandas treats every line of the file (including the first) as data.
header_names = ['lemma', 'labels', 'features']

turkish_train = pd.read_csv('./data/tur_large.train', sep='\t', names=header_names)

In [None]:
# Preview the first rows of the training data.
turkish_train.head()

In [None]:
# Keep only plural nouns, identified through the feature string.
filtered_turkish = turkish_train[
    turkish_train['features'].str.startswith('N;')  # Ensure it is a noun
  & turkish_train['features'].str.contains('PL;')]  # And plural
# Report the number of matching rows; DataFrame.size would instead give
# rows x columns, i.e. three times the sample count for this frame.
print(len(filtered_turkish))
filtered_turkish.head()

In [None]:
# Run the model on the first 40 plural nouns; predict_on_data returns the
# incorrectly inflected rows first and the correctly inflected rows second.
incorrect_turkish, correct_turkish = predict_on_data(filtered_turkish.head(40), model_tur_finetuned)
print(f"Number of incorrect inflections: {len(incorrect_turkish)}, and correct: {len(correct_turkish)}")

In [None]:
# Attribute the model's own (predicted) output for every incorrectly inflected sample.
incorrect_attributions = inseq_on_predictions(incorrect_turkish, inseq_tur_finetuned, expected=False)

In [None]:
# Attribute the model's own (predicted) output for every correctly inflected sample.
correct_attributions = inseq_on_predictions(correct_turkish, inseq_tur_finetuned, expected=False)