#preamble

In [19]:
# Choose python version (3.8.16 should be fine normally)
# https://www.datasciencelearner.com/change-python-version-in-google-colab-steps/
!python --version
#!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1
#!sudo update-alternatives --config python3
#!python3 --version
#!sudo apt install python3-pip

Python 3.10.12


In [20]:
%%capture
!pip install git+https://github.com/huggingface/transformers.git
#!pip install transformers==4.7.0
!pip install tokenizers==0.10.3

In [21]:
import torch
import transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np

In [22]:
! nvidia-smi

Fri Jun 30 17:57:54 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P0    30W /  70W |    645MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [23]:
# Check GPU is available and libraries version
print('Pytorch version ...............{}'.format(torch.__version__))
print('Transformers version ..........{}'.format(transformers.__version__))
print('GPU available .................{}'.format('\u2705' if torch.cuda.device_count() > 0 else '\u274c'))
print('Available devices .............{}'.format(torch.cuda.device_count()))
print('Active CUDA Device: ...........{}'.format(torch.cuda.current_device()))
print('Current cuda device: ..........{}'.format(torch.cuda.current_device()))

Pytorch version ...............2.0.1+cu118
Transformers version ..........4.31.0.dev0
GPU available .................✅
Available devices .............1
Active CUDA Device: ...........0
Current cuda device: ..........0


# Load the pre-trained model and the tokenizer

In [24]:
# Query GPU memory used before loading the model.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')
#memory_used_s = !nvidia-smi --query-gpu=memory.used --format=csv | grep ' MiB'
#memory_used_s = int(memory_used_s[0][:-4])

In [25]:
# Load pretrained model and tokenizer.
# The model will be downloaded from HuggingFace hub and cached.
# It may take ~5 minutes for the first excecution.

model = GPT2LMHeadModel.from_pretrained("asi/gpt-fr-cased-small").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("asi/gpt-fr-cased-small")
tokenizer.add_special_tokens({
  "eos_token": "</s>",
  "bos_token": "<s>",
  "unk_token": "<unk>",
  "pad_token": "<pad>",
  "mask_token": "<mask>"
})

0

In [27]:
# Query GPU memory used after loading the model.
#memory_used_e = !nvidia-smi --query-gpu=memory.used --format=csv | grep ' MiB'
#memory_used_e = int(memory_used_e[0][:-4])
#print("Model loaded in GPU memory and uses {:.2f} Go GPU RAM.".format(float(memory_used_e - memory_used_s)/1024))

In [28]:
# Check number of parameters.
print("Model has {:,} parameters.".format(model.num_parameters()))

Model has 124,242,432 parameters.


# Load fine-tuned model
this will change model from pre-trianed model to fine-tuned model

In [30]:
from google.colab import drive
drive.mount('/content/drive')

path = "/content/drive/MyDrive/gpt2_train/"

model = GPT2LMHeadModel.from_pretrained(path+'saved_model_spk_line_break_epoch30/').to(device)

Mounted at /content/drive


#Probability

## load sd dataset

### ->load Tree and Turn classes

In [31]:
class DepTree:
    '''all indexes of edges are modified so that they can start from 0
    e.g. 2 est 0 root -> 1 est -1 root,
         3 -ce 2 suj  -> 2 -ce  1  suj
         0: [(0, 'root', 2)] -> -1: [(-1, 'root', 1)'''
    ROOT_TOKEN = "<root>"

    def __init__(self, edges, wordlist=None, with_root=False, mwe_range=None):

        self.gov2dep = {}
        self.has_gov = set()  # set of nodes with a governor

        for (gov, label, dep) in edges:
            self.add_arc(gov, label, dep)

        if with_root:
            self.add_root()

        if wordlist is None:
            wordlist = []
        self.words =  wordlist + [DepTree.ROOT_TOKEN]
        self.mwe_ranges = [] if mwe_range is None else mwe_range

        self.dep2gov = {dep: (label, gov)
                   for node in self.gov2dep
                   for (gov, label, dep) in self.gov2dep[node]}

    def is_leaf(self,node):
        """
        Indicates if the node is leaf or not
        """
        return node not in self.gov2dep

    def get_all_edges(self):
        """
        Returns the list of edges found in this graph
        """
        return [edge for gov in self.gov2dep for edge in self.gov2dep[gov]]

    def get_all_labels(self):
        """
        Returns the list of dependency labels found on the arcs
        """
        all_labels = []
        for gov in self.gov2dep:
            all_labels.extend([label for (gov, label, dep) in self.gov2dep[gov]])
        return all_labels

    def get_dep_arcs(self,gov):
        """
        Returns all arcs governed by gov. An arc is a triple (gov,label,dep)
        """
        return self.gov2dep.get(gov,[])

    def get_arc(self, gov, dep):
        """
        Returns the arc between gov and dep if it exists or None otherwise
        Args:
            gov (int): node idx
            dep (int): node idx
        Returns:
            A triple (gov,label,dep) or None.
        """
        if gov in self.gov2dep:
            for (_gov, deplabel, _dep) in self.gov2dep[gov]:
                if _dep == dep:
                    return (_gov, deplabel, _dep)
        return None

    def add_root(self):
        '''
        <root> is at the -1 position of wordlist
        '''
        if self.gov2dep and -1 not in self.gov2dep: # the root has not been added to gov2dep
            root = list(set(self.gov2dep) - self.has_gov) # index of root
            if len(root) == 1:
                self.add_arc(-1, "root", root[0])
            else:
                assert False  # no single root... problem.
        elif not self.gov2dep:  # single word sentence
            self.add_arc(-1, "root", 0)

    def add_arc(self, gov, label, dep):
        """
        Adds an arc to the dep graph
        """
        if gov in self.gov2dep:
            self.gov2dep[gov].append((gov, label, dep))
        else:
            self.gov2dep[gov] = [(gov, label, dep)]

        self.has_gov.add(dep)

    def is_cyclic_add(self, gov, dep):
        """
        Checks if the addition of an arc from gov to dep would create
        a cycle in the dep tree
        """
        return gov in self.span(dep)

    def _gap_degree(self, node):
        """
        Returns the gap degree of a node
        Args :
            node (int): a dep tree node
        """
        # not adapted for id starting from 0
        nspan = list(self.span(node))
        nspan.sort()
        gd = 0
        for idx in range(len(nspan)):
            if idx > 0:
                if nspan[idx] - nspan[idx - 1] > 1:
                    gd += 1
        return gd

    def gap_degree(self):
        """
        Returns the gap degree of a tree (suboptimal)
        """
        return max(self._gap_degree(node) for node in self.gov2dep)

    def is_projective(self):
        """
        Returns true if this tree is projective
        """
        # not adapted for id starting from 0
        return self.gap_degree() == 0

    def span(self, gov):
        """
        Returns the list of nodes in the yield of this node
        the set of j such that (i -*> j).
        """
        agenda = [gov]
        closure = set([gov])
        while agenda:
            node = agenda.pop()
            succ = (
                [dep for (gov, label, dep) in self.gov2dep[node]]
                if node in self.gov2dep
                else []
            )
            agenda.extend([node for node in succ if node not in closure])
            closure.update(succ)
        return closure

    def is_dag_add(self, gov, dep):
        """
        Checks if the addition of an arc from gov to dep would create
        a Dag
        """
        return dep in self.has_gov


    def __str__(self):
        """
        Conll string for the dep tree
        """
        lines = []
        revdeps = [
            (dep, (label, gov))
            for node in self.gov2dep
            for (gov, label, dep) in self.gov2dep[node]
        ]
        revdeps = dict(revdeps)
        for node in range(0, len(self.words)-1): # node start from 0
            L = ["_"] * 10
            L[0] = str(node)
            L[1] = self.words[node] #self.words: [<root>,w1,w2...]
            label, head = revdeps[node] if node in revdeps else ("root", -1)
            L[6] = str(head)
            L[7] = label
            mwe_list = [
                (left, right, word)
                for (left, right, word) in self.mwe_ranges
                if left == L[0]
            ]
            for mwe in mwe_list:
                MWE = ["_"] * 10
                MWE[0] = "-".join(mwe[:2])
                MWE[1] = mwe[2]
                lines.append("\t".join(MWE))
            lines.append("\t".join(L))
        return "\n".join(lines)

    def __len__(self):
        return len(self.words)

In [32]:
class Turn:
    '''read one turn'''
    def __init__(self,doc_id,sent_id,speaker_id,interval,overlapping,raw_text,conll):
        self.doc_id = doc_id
        self.sent_id = int(sent_id)
        self.speaker_id = speaker_id
        self.interval = interval
        self.overlapping = overlapping
        self.raw_text = raw_text
        self.conll = conll

        self.mwe_ranges = []
        # parse conll only if it is a real speaker turn with at least a real word
        if self.conll and not re.match(r'(com)|(silence)', self.speaker_id):
            self.data = self.read_conll()
            self.tree = DepTree(self.data['edges'],self.data['words'],with_root=True,mwe_range=self.mwe_ranges)

    @staticmethod
    def read_turn(istream,doc_id):
        '''Read a turn from input stream'''
        conll = []
        non_words = []
        line = istream.readline()
        if line.strip() == "": # the last line
            return None
        while istream and line.isspace(): # ignore the first line of a turn composed by space
            line = istream.readline()
        while istream and not line.strip() == "": # stop when we meet the beginning of the next turn
            if re.match("#", line):
                name,value = line.lower().split('=',1)
                name = name.strip()
                if name == "# newdoc id":
                    doc_id = value.strip()
                elif name == "# sent_id":
                    sent_id = value.strip()
                elif name == "# interval":
                    interval = value.strip()
                elif name == "# speaker id":
                    speaker_id = value.strip()
                elif name == "# overlapping":
                    overlapping = value.strip()
                elif name == "# text":
                    raw_text = value.strip()
            else:
                if re.match(r"[0-9]+\t", line): # a real word
                    conll.append(line.strip().split("\t"))
                elif re.match(r"[0-9]+\.[0-9]+\t", line):
                    non_words.append(line.strip().split("\t"))
            line = istream.readline()

        return Turn(doc_id,sent_id,speaker_id,interval,overlapping,raw_text,conll)

    def read_conll(self):
        ''' conll contains real words (ignore not empty word)'''
        words = []
        lemmas = []
        postags = []
        postags_feats = []
        edges = []
        misc = []

        for dataline in self.conll:
            if len(dataline) < 10: # pad dataline
                dataline.extend(["_"] * (10-len(dataline)))
            if "-" in dataline[0]: # for mwt, -1 for all word indexes
                self.mwe_ranges.append([(i-1, j-1) for (i,j) in dataline[0].split("-")] + [dataline[1]])
                continue
            else:
                words.append(dataline[1])
                lemmas.append(dataline[2])
                postags.append(dataline[3])
                postags_feats.append(dataline[5])
                misc.append(dataline[9])

                # -1 for all word idx in tree so that all idx start from 1
                if dataline[6] != "0":  # do not add root immediately
                    edges.append( (int(dataline[6])-1, dataline[7], int(dataline[0])-1)  )

        return dict(zip(['words','lemmas','postags','postags_feats','edges','misc'],
                        [words,lemmas,postags,postags_feats,edges,misc]))

    def update_misc(self,i,new_info):
        '''update the information colomn of the word i
        misc: dict or _
        new_info: dict'''
        if self.data['misc'][i] == '_':
            self.data['misc'][i] = new_info
        else:
            self.data['misc'][i].update(new_info)

    def __str__(self):
        '''print conll of a turn'''
        #lines = ['\t'.join(line) for line in self.conll]
        turn_printed = self.doc_id + '\t' + str(self.sent_id) + '\t' + self.speaker_id + '\n' + self.raw_text
        return turn_printed

### ->load speaker_turns_content

In [45]:
import re

# load the data and get speaker turn with true words
corpus_file = '/content/drive/MyDrive/gpt2_train/MPF_corpus/MPF_version4b_full.conll'
istream = open(corpus_file)
speaker_turns_content = [] # real speaker turn with content (words)
doc_id = ""
turn = Turn.read_turn( istream,doc_id ) # the first turn

while turn:
    # ignore commentaires or silences turn and turns without real words
    if len(turn.conll) != 0 and not re.match(r'(com)|(silence)', turn.speaker_id):
        speaker_turns_content.append(turn)
    doc_id = turn.doc_id
    turn = Turn.read_turn( istream,doc_id )

istream.close()
print(len(speaker_turns_content)) #73971

73971


## ->update SD

In [46]:
# read csv and update misc
import csv
from collections import defaultdict
import re

def read_mycsv(filename, fieldnames):
    '''only update lines which are not NA'''
    updated_lines = defaultdict(list)
    with open(filename+'.csv',newline='',encoding='latin-1') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            if row['sd'] not in ['na','abandon']:
                updated_lines[int(row['sent_id'])].append((int(row['verb_i']),
                                                           {field: row[field] for field in fieldnames}))
    return updated_lines

def update_turns(turns,updated_lines):
    '''turns: old speaker_turns_content
        updated_lines: a dictionary with sent_id as key
        and list of updated idx and dictionaries as value
        only update turns whose id was in the csv'''
    n = 0
    updated_turns = []
    for turn in turns:
        if turn.sent_id in updated_lines:
            for line in updated_lines[turn.sent_id]:
                turn.update_misc(*line)
                updated_turns.append(turn)
                n+=1
    print(n)
    return updated_turns

fieldnames = ['sent_id', 'verb_i', 'speaker_name', 'sd', 'position',
              's1_i', 's2_i', 's1', 's2',
              's1_i_np', 's1_np',
              's1_lemma', 's1_freq', 's1_type', 's1_animacy', 's1_len',
              'v_lemma', 'v_freq', 'v_type',
              'finite_v_i', 'finite_v_lemma',
              'dist_v2s',
              'polarity', 'polarity_item', 'polarity_position',
              'clause',
              #'cg', 'v', 'cd'
             ]


# s1 constituent
sd_file = '/content/drive/MyDrive/gpt2_train/MPF_SD/sd_left_np_n4671'

updated_lines = read_mycsv(sd_file, fieldnames)
update_turns(speaker_turns_content, updated_lines) #4671
print()

4671



## ->construct context precedent (cps) cp
the head of the NP subject will be the last token

In [47]:
import json
import re

def clean(w):
    w = re.sub(r'@+(a|x|s)?$', '', w)
    w = re.sub(r'^(.+)[.,!?]$', '\\1', w)
    return w

def construct_cp_np(turns, window, min_len, add_speaker, add_line_break):
    '''construct context precedent the target sentence, end by s1'''
    cps = []
    for k,turn in enumerate(turns):
        for i,annot in enumerate(turn.data['misc']):
            if 'sd' in annot:
                conditions = [annot['sd'] in ['no_sd', 'sd'],
                            annot['position'] in ['left', 'na'],
                            annot['s1_type'] != 'expression',
                            annot['s1_lemma'] != 'anonyme',
                            annot['s1_i_np'] != '',
                            '-' not in  annot['s1_i']]
                if all(conditions): # same condition as in audio extraction and text version
                #if re.match(r'^[0-9]+$', annot['s1_i'].strip()) and not re.match(r'x@', annot['s1'].lower()) and annot['s1'].lower() != 'x' and annot['s1_lemma']!='anonyme': #uni-token NP subject and the name is not anonymized as X
                    cp = []
                    for j in range(k-window, k):
                        # same document
                        if turns[j].doc_id == turn.doc_id and k >= window:
                            if add_speaker is False:
                                cp += turns[j].data['words']
                            elif add_speaker is True and add_line_break is True:
                                # a blank will be added after "speaker:" latter in this function
                                cp += [f"{turns[j].speaker_id.capitalize()}:"] + turns[j].data['words'] + ["\n"]
                        else:
                            break #cp remains an empty list i.e., no previous context will be added
                    if cp != []: # only append the target sentence when a certain window size contexts have been added
                        if add_speaker is False:
                            cp += turn.data['words'][:int(annot['s1_i'])+1] #s1 will be the last token
                        elif add_speaker is True:
                            cp += [f"{turn.speaker_id.capitalize()}:"] + turn.data['words'][:int(annot['s1_i'])+1] #s1 will be the last token

                    # add more precedent sents to the left if the length is short
                    extra = 1
                    while len(cp)<=min_len:
                        if turns[k-window-extra].doc_id == turn.doc_id and k >= window+extra:
                            if add_speaker is False:
                                cp = turns[k-window-extra].data['words'] + cp
                            elif add_speaker is True and add_line_break is True:
                                cp = [f"{turns[k-window-extra].speaker_id.capitalize()}:"] + turns[k-window-extra].data['words'] + ["\n"] + cp
                            extra += 1
                        else:
                            break

                    if len(cp)>min_len:
                        # stock cp only when it attains minimal length
                        #print(turn.sent_id, cp)
                        cp = ' '.join([clean(w) for w in cp])
                        # suppress blank before \n (the blank was added by join function)
                        cp = cp.replace(" \n ", "\n")
                        #print(turn.sent_id, cp)
                        #turn.update_misc((i, {'cp': cp}))
                        #cps.append([turn.sent_id, i, annot['sd'], annot['position'], annot['s1'], cp])
                        cps.append([turn.sent_id, i, annot['sd'], annot['position'],
                                    annot['s1'], annot['s1_i_np'], annot['s1_np'], cp])


    return cps

def save2json(data, outputname):
    json_str = json.dumps(data, indent=4)
    with open(outputname + '.json', 'w') as json_file:
        json_file.write(json_str)

# test functions
cps = construct_cp_np(speaker_turns_content, window=2, min_len=1, add_speaker=True, add_line_break=True)
print(cps[:3])
#save to json
#save2json(cps, path+'MPF_sd_left_cps_sent2_len20')


#load json
#filename = 'MPF_sd_left_cps_sent5_len40.json'
#cps = json.load(open(path+filename, encoding='utf-8'))
#print(len(cps)) #4532


[[15, 7, 'sd', 'left', 'projet', '3-4', 'le projet', 'Julie: Co- comment vous raconteriez ce ce projet ?\nKimia: Tu v-\nElikia: Ben déjà euh le projet'], [224, 10, 'sd', 'left', 'vie', '6-7', 'la vie', 'Elikia: Pas du tout .\nJulie: Pourquoi ils sont partis tes parents ?\nElikia: Ben parce que ils pensaient que la vie'], [231, 4, 'sd', 'left', 'pÃ¨re', '0-1', 'Mon pÃ¨re', "Julie: Ils le pensent plus aujourd'hui ?\nElikia: Non plus forcément .\nElikia: Mon père"]]


## proba estimation

###->mapping (sub)words with the original token

In [48]:
#all subwords begin with Ġ (corresponding to a whitespace), except the first token
def mapping(input_ids, converted_tokens):
    '''input: a tensor of ids and corresponding tokens
    the first token of the sentence does not begin with a space, the remaining tokens start with a space
    we use the space as frointier marker of words
    output = [([idx in the input list, ], [idx in the voc]), ]'''

    i = 0
    start = 0
    output = []
    current_token_id = []

    # always stock previous subwords when encountering a token begins with Ġ (beginning of a new token)
    for i,token in enumerate(converted_tokens):
        # the first line is necessary to assure the first token is never directly included to output (to avoid multi inclusion)
        if i == 0: #first token does not begin with a Ġ, stock the id(s) of the first token (but not append it until the seoncd round)
            pass
        #output always stocks previous token
        elif re.match('Ġ', token): #not a suffix (important: make sure there is no empty token)
            #if start == i-1: #previous token is made of one subword (don't understand why I made these two conditions)
            #    output.append((i-1, current_token_id))
            #else: #previous token is made of more than one subword
            output.append(( [k for k in range(start, i) ], current_token_id)) # output previous subwords (forming a token)
            current_token_id = []
            start = i

        current_token_id.append(input_ids[0][i].item()) # .item(): take the int from a tensor

    #last token (info stocked in current_token_id)
    #if start == i: #the last token is made of one subword (don't understand why I made these two conditions)
    #    output.append((i, current_token_id))
    #else: #the last token is made of more than one subword
    output.append(( [k for k in range(start, i+1) ], current_token_id))
    print(output)
    return output

# test
input_sent = " Mon daron il est là . Il est pas content contre le GPT"
input_ids = tokenizer.encode(input_sent, return_tensors='pt')

print(input_ids[0])
converted_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
print(converted_tokens)
print(len(converted_tokens))

mapping(input_ids, converted_tokens)

tensor([ 1388,   207, 11717,   361,   314,  1116,  2937,   389,   314,   322,
         5192,   677,   239, 18252,    56])
['ĠMon', 'Ġd', 'aron', 'Ġil', 'Ġest', 'ĠlÃł', 'Ġ.', 'ĠIl', 'Ġest', 'Ġpas', 'Ġcontent', 'Ġcontre', 'Ġle', 'ĠGP', 'T']
15
[([0], [1388]), ([1, 2], [207, 11717]), ([3], [361]), ([4], [314]), ([5], [1116]), ([6], [2937]), ([7], [389]), ([8], [314]), ([9], [322]), ([10], [5192]), ([11], [677]), ([12], [239]), ([13, 14], [18252, 56])]


[([0], [1388]),
 ([1, 2], [207, 11717]),
 ([3], [361]),
 ([4], [314]),
 ([5], [1116]),
 ([6], [2937]),
 ([7], [389]),
 ([8], [314]),
 ([9], [322]),
 ([10], [5192]),
 ([11], [677]),
 ([12], [239]),
 ([13, 14], [18252, 56])]

test: next token proba

In [None]:
input_sentence = "Mon daron il est là . Il est pas content contre le GPT"

input_ids = tokenizer.encode(input_sentence, return_tensors='pt').to(device)
print(input_ids)
print(input_ids.size())
print(tokenizer.decode(input_ids[0]))
print(tokenizer.convert_ids_to_tokens(input_ids[0]))
print(tokenizer.convert_ids_to_tokens(input_ids[0])[0][0])

outputs = model(input_ids) # an object of CausalLMOutputWithCrossAttentions
print(type(outputs))
print('logits size', outputs.logits.size()) # (batch_size, sequence_length, config.vocab_size)

next_token_logits = outputs.logits[:, 13, :] #the logit of the word generated given the word in the postion 13 (start from 0)
print('next token', )
print(next_token_logits.size())
print(next_token_logits)

next_token_probas = torch.nn.functional.softmax(next_token_logits, dim=-1) # A dimension along which Softmax will be computed (so every slice along dim will sum to 1)
print('next_token_probas', next_token_probas.size())

next_token_log_probas = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
print('next_token_log_probas', next_token_log_probas.size())

top10 = torch.topk(next_token_log_probas, k=10, dim=-1)
top10_log_proba = top10.values
top10_indices = top10.indices
print('top10 log proba:', top10_log_proba)
print('top10 indices:', top10_indices)
print(tokenizer.convert_ids_to_tokens(top10_indices[0]))

tensor([[ 8566,   207, 11717,   361,   314,  1116,  2937,   389,   314,   322,
          5192,   677,   239, 18252,    56]])
torch.Size([1, 15])
Mon daron il est là. Il est pas content contre le GPT
['Mon', 'Ġd', 'aron', 'Ġil', 'Ġest', 'ĠlÃł', 'Ġ.', 'ĠIl', 'Ġest', 'Ġpas', 'Ġcontent', 'Ġcontre', 'Ġle', 'ĠGP', 'T']
M
<class 'transformers.modeling_outputs.CausalLMOutputWithCrossAttentions'>
logits size torch.Size([1, 15, 50000])
next token
torch.Size([1, 50000])
tensor([[-10.1105, -19.7650, -19.3792,  ...,  -9.5175,  -8.3068,  -2.7066]],
       grad_fn=<SliceBackward0>)
next_token_probas torch.Size([1, 50000])
next_token_log_probas torch.Size([1, 50000])
top10 log proba: tensor([[-2.0544, -2.0715, -2.3568, -3.0034, -3.0649, -3.0669, -3.2693, -3.3223,
         -3.7563, -3.8741]], grad_fn=<TopkBackward0>)
top10 indices: tensor([[ 18,  57,  48, 490,  22,  21, 214,  16, 748,  23]])
['.', 'U', 'L', 'Ġ?', '2', '1', 'Ġde', ',', 'Ġ!', '3']


###->get last token proba, ranking, etc

In [49]:
def next_token_prediction(input_ids, position, i_gold, model_outputs, k, printk):
    '''predict the word following the word in the position 'position'
    position: index of the target s1
    i_gold: index of the gold s1 in the vocabulary '''

    #next token logit
    next_token_logits = model_outputs.logits[:, position, :] #the logit of the word generated given the word in the postion "position" (start from 0)
    next_token_probas = torch.nn.functional.softmax(next_token_logits, dim=-1)
    next_token_log_probas = torch.nn.functional.log_softmax(next_token_logits, dim=-1)

    topk = torch.topk(next_token_log_probas, k=k, dim=-1)
    topk_log_proba = topk.values
    topk_indices = topk.indices #topk_indices size = [batch_size, k]

    # winner
    i_w = topk_indices[:, 0] #index in the matrix = index in the voc, i_w is a tensor of size [1]
    word_w = tokenizer.convert_ids_to_tokens(i_w.squeeze(dim=0).item())
    score_w = next_token_logits[:, i_w].squeeze(dim=0).item() #tensor
    proba_w = next_token_probas[:, i_w].squeeze(dim=0).item() #tensor
    log_proba_w = next_token_log_probas[:, i_w].squeeze(dim=0).item() #tensor
    ratio_w = score_w/score_w
    ranking_w = topk_indices.squeeze(dim=0).tolist().index(i_w.item()) + 1 #index start from 0, should be 1 for the winner

    # gold
    word_g = tokenizer.convert_ids_to_tokens(i_gold) #i_gold is an int
    score_g = next_token_logits[:, i_gold].squeeze(dim=0).item() #tensor
    proba_g = next_token_probas[:, i_gold].squeeze(dim=0).item() #tensor
    log_proba_g = next_token_log_probas[:, i_gold].squeeze(dim=0).item() #tensor
    ratio_g = score_g/score_w

    if i_gold in topk_indices.squeeze(dim=0).tolist():
        ranking_g = topk_indices.squeeze(dim=0).tolist().index(i_gold) + 1 #index start from 0, should be 1 for the winner
    else:
        ranking_g = ">" + str(k)

    # topk
    print('---generation---')
    print('input_ids size:', input_ids[0].size())
    print('the index of the word at the previous position', input_ids[0][position].unsqueeze(dim=0))
    print('at the postion', position)
    print('previous word:', tokenizer.convert_ids_to_tokens(input_ids[0][position].unsqueeze(dim=0)))
    print('top%i log proba:' %k, topk_log_proba)
    print('top%i indices:' %k, topk_indices)
    print('top%i candidates:' %k, tokenizer.convert_ids_to_tokens(topk_indices[0])) #index in the matrix, equivalent to index in the voc
    print()

    return [word_w, score_w, proba_w, log_proba_w, ratio_w, ranking_w], [word_g, score_g, proba_g, log_proba_g, ratio_g, ranking_g], tokenizer.convert_ids_to_tokens(topk_indices[0, :printk])

def get_last_token_logits(token_ids, voc_ids, outputs):
    '''get last token logits
    token_ids: s1 index in the sentence
    voc_ids: s1 index in the vocabulary'''
    last_token_logits = []

    if type(token_ids) == int: #whole word
        last_token_logits.append(outputs.logits[:, token_ids-1, :]) #minus 1 to have the word preceding s1 in the sent as input
    else: #multiple subwords
        for t in range(len(token_ids)):
            #print("tokens id")
            #print(token_ids)
            last_token_logits.append(outputs.logits[:, token_ids[t]-1, :])

    return last_token_logits

def logits2proba(last_token_logits, token_ids, voc_ids):
    '''turn logits of the target position to proba'''
    last_token_probas = []
    last_token_log_probas = []
    for i,logit in enumerate(last_token_logits):
        # softmax and log_softmax
        last_token_position_probas = torch.nn.functional.softmax(logit, dim=-1) # A dimension along which Softmax will be computed (so every slice along dim will sum to 1)
        last_token_position_log_probas = torch.nn.functional.log_softmax(logit, dim=-1)
        #print('last_token_position_probas', last_token_position_probas.size())
        #print('last_token_position_log_probas', last_token_position_log_probas.size())

        last_token_probas.append(last_token_position_probas[:, voc_ids[i]].squeeze(dim=0).item())
        last_token_log_probas.append(last_token_position_log_probas[:, voc_ids[i]].squeeze(dim=0).item())
    print('---target words---')
    print('last token probas:', last_token_probas)
    print('last token log probas:', last_token_log_probas)

    for i in range(len(voc_ids)):
        print('last token', voc_ids[i], tokenizer.convert_ids_to_tokens(voc_ids[i]))

    return last_token_probas, last_token_log_probas




### ->main function to get sd left subject probas (get_scores_sents)

In [54]:
def get_scores_sents(cps, get_proba='constituent', k=10, printk=10):
    '''
    cps: head: [(turn.sent_id, i, annot['sd'], annot['position'], annot['s1'], cp)]
        constituent: [(turn.sent_id, i, annot['sd'], annot['position'], annot['s1'], annot['s1_i_np'], annot['s1_np'], cp)]

    get_proba='head': estimate the probability of the head of subject
    get_proba='constituent': estimate the probability of the sujbect constituent'''

    sd = []
    for n, sent in enumerate(cps):
        print('processing the number ---%s--- sd' %str(n))
        print(sent[-1].split(' ')[-6:])
        # get the index of subject or subject constituent in the vocabulary
        # sent[-1] is cp
        input_ids = tokenizer.encode(sent[-1], return_tensors='pt').to(device) # sent[-1] is cp
        #print(input_ids[0])
        converted_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
        #print(converted_tokens)

        ## get model outputs (to get logits, i.e., scores)
        with torch.no_grad():
            outputs = model(input_ids) # an object of CausalLMOutputWithCrossAttentions

        input_mapped_ids = mapping(input_ids, converted_tokens)
        #print(input_mapped_ids)

        # estimate the probability of the head of subject
        if get_proba == 'head':
            token_ids, voc_ids = input_mapped_ids[-1] #the last token (s1) #token_ids can be int or list, voc_ids is a list
        # compute proba of the constituent
        elif get_proba == 'constituent':
            # sent[-3] is s1_np_i only when get_proba is 'constituent'
            s1_np_i = sent[-3]
            if '-' in s1_np_i:
                b, e = s1_np_i.split('-')
                token_ids, voc_ids = list(zip(*input_mapped_ids[-(int(e)-int(b)+1):]))
                print(input_mapped_ids[-(int(e)-int(b)+1):])
                print(token_ids)
                token_ids = [x for j in token_ids for x in j]
                voc_ids = [x for j in voc_ids for x in j]

            else:
                 token_ids, voc_ids = input_mapped_ids[-1]

        # get last token logits
        last_token_logits = get_last_token_logits(token_ids, voc_ids, outputs)
        # compute proba of the gold word based on logits
        last_token_probas, last_token_log_probas = logits2proba(last_token_logits, token_ids, voc_ids)

        #test whether the token at the right position is predicted
        winners = []
        golds = []
        topk_predictions = []
        if type(token_ids) == int:
            winner, gold, topk_prediction = next_token_prediction(input_ids, token_ids-1, voc_ids[0], outputs, k=k, printk=printk)
            winners.append(winner)
            golds.append(gold)
            topk_predictions.append(topk_prediction)
        else:
            assert len(token_ids) == len(voc_ids)
            for i, id in enumerate(token_ids):
                winner, gold, topk_prediction = next_token_prediction(input_ids, id-1, voc_ids[i], outputs, k=k, printk=printk)
                winners.append(winner)
                golds.append(gold)
                topk_predictions.append(topk_prediction)

        sd.append([sent, last_token_probas, last_token_log_probas, topk_predictions, winners, golds])
    return sd


#sd_probas = get_scores_sents(cps, k=30, printk=15)
#print(sd_probas)

### ->write out

In [55]:
def save2json(data, outputname):
    import json
    json_str = json.dumps(data, indent=4)
    with open(outputname + '.json', 'w', encoding='utf-8') as json_file:
        json_file.write(json_str)

#save to json
#save2json(sd_probas, path+'MPF_sd_winner_gold_ratios_win2_len20')
#len(sd_probas) #4636

In [53]:
print(path)

/content/drive/MyDrive/gpt2_train/


### ->> main function: write out a bunch of probas with different window size

In [None]:
for win in range(6,11):
    cps = construct_cp_np(speaker_turns_content, window=win, min_len=1, add_speaker=True, add_line_break=True)
    #save2json(cps, path+'MPF_sd_left_constituent_cps_sent_spk_win'+str(win))

    #filename = 'MPF_sd_left_cps_sent' + str(win) + '.json'
    #cps = json.load(open(path+filename, encoding='utf-8'))

    ## estimate the probability of the head of subject
    #sd_probas = get_scores_sents(cps, get_proba='head', k=30, printk=10)

    ## estimate the probability of the subject constituent (NP)
    sd_probas = get_scores_sents(cps, get_proba='constituent', k=30, printk=10)
    save2json(sd_probas, path+'MPF_sd_constituent_winner_gold_ratios_spk_new_win'+str(win))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
         -5.4850, -5.5238, -6.2674, -6.3006, -6.4072, -6.4315, -6.4403, -6.7511,
         -6.8473, -6.9739, -7.0105, -7.1660, -7.1998, -7.2117, -7.2581, -7.3626,
         -7.3715, -7.3802, -7.3898, -7.4740, -7.5000, -7.5002]],
       device='cuda:0')
top30 indices: tensor([[  421,   281,   533,   217,   376,   650,   260,   546,   757,   239,
          1871,   960,  1238,   210,   234,  4471,   628,   722,  1729,   503,
           171,  1679, 16576,   509,   294,   426,   640,   929,   506,   207]],
       device='cuda:0')
top30 candidates: ['Ġje', 'Ġj', 'Ġvous', 'Ġc', 'Ġce', 'Ġtu', 'Ġles', 'Ġnous', 'ĠÃ§a', 'Ġle', 'Ġmaintenant', 'Ġmoi', 'Ġbon', 'Ġl', 'Ġla', 'Ġoui', 'Ġces', 'Ġdit', 'Ġtoi', 'Ġcette', 'Ċ', 'Ġparce', 'Ġeuh', 'Ġtout', 'Ġpour', 'Ġsont', 'Ġma', 'Ġmes', 'Ġmon', 'Ġd']

---generation---
input_ids size: torch.Size([547])
the index of the word at the previous position tensor([239], device='cuda:0')
at the postion 545