In [None]:
# The process of implementing this has been moved into a Python script in order to run it as a Slurm job.

In [1]:
import os
import torch
from datasets import load_dataset, load_from_disk
from pprint import PrettyPrinter
from transformers import AutoModelForCausalLM
import matplotlib.pyplot as plt
import pandas as pd
import random
from time import sleep

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# globals
DEFAULT_MODEL = "LumiOpen/Poro-34B"
DATASET_PATH = "../data/europarl_en-fi"
pprint = PrettyPrinter(compact=True).pprint

In [3]:
if not os.path.exists(DATASET_PATH):
    print("Predownloaded dataset not found. Starting download.")
    dataset = load_dataset("Helsinki-NLP/europarl", "en-fi" ,split="train")
    pprint(dataset)
    dataset.save_to_disk(dataset_path=DATASET_PATH)

In [4]:
dataset = load_from_disk(dataset_path=DATASET_PATH)
pprint(f'{dataset["translation"][0]["en"]}   {dataset["translation"][0]["fi"]}')

'Resumption of the session   Istuntokauden uudelleenavaaminen'


In [5]:
translation_dicts = dataset["translation"]
print(translation_dicts[:10])

[{'en': 'Resumption of the session', 'fi': 'Istuntokauden uudelleenavaaminen'}, {'en': 'I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.', 'fi': 'Julistan perjantaina joulukuun 17. päivänä keskeytetyn Euroopan parlamentin istunnon avatuksi ja esitän vielä kerran vilpittömän toiveeni siitä, että teillä olisi ollut oikein mukava joululoma.'}, {'en': "Although, as you will have seen, the dreaded 'millennium bug' failed to materialise, still the people in a number of countries suffered a series of natural disasters that truly were dreadful.", 'fi': 'Kuten olette varmaan saattaneet huomata, vuodenvaihteeseen 2000 povattuja suuria tietokoneongelmia ei ilmennytkään. Sen sijaan todella kauheat luonnonkatastrofit koettelivat kansalaisia joissakin unionimme maissa.'}, {'en': 'You have requested a debate on this subject in the course of the

In [6]:
sorted_list = sorted(translation_dicts, key=lambda d: len(d['en']))
filtered_list = [d for d in sorted_list if len(d['en']) >= 10]
df = pd.DataFrame.from_dict(filtered_list)
df = df.drop_duplicates()

In [7]:
len(df)

1919415

In [8]:
# Takes a sorted dataframe as the input
# Bands control the amount of reference points
# Per is the amount of objects per band
# Thold is the percentage amount of variation within a band
def picker(dframe, bands, per, thold):
    df_list = dframe.values.tolist()
    df_len = len(df_list)
    variables = []

    band_len = int(df_len / bands)
    low_b = int(df_len / bands / 2) # 
    high_b = int(df_len - low_b)

    print(f"{low_b}, {high_b}, {band_len}")

    for band_no in range(0, bands, 1):
        phase_dict = {}
        entry_list = []
        rand_idxs = []
        
        if band_no == 0:
            band_loc = low_b
        elif band_no == bands-1:
            band_loc = high_b
        else:
            band_loc = int(low_b + band_no * band_len)

        phase_dict["band_loc"] = band_loc
        phase_dict["band_no"] = band_no
        phase_dict["median_len"] = len(df_list[band_loc][0])

        high_t = int(band_loc * (thold+1))
        low_t = int(band_loc * (1-thold))

        for i in range(per):
            flag = True # used to assure the following loop runs at least once, essentially a do-while
            rand_idx = 0
            
            while flag:
                rand_idx = int(random.uniform(low_t, high_t))

                # TEST - change False to True to test
                if i == 1 and False:
                    rand_idxs.append(rand_idx)  
                    
                if rand_idx in rand_idxs:
                    print("Detected random index clash! Will sleep a bit before retry.") # don't overload
                    sleep(0.5)
                else:
                    rand_idxs.append(rand_idx)
                    flag = False

            entry_list.append(df_list[rand_idx])

        phase_dict["entries"] = entry_list
        variables.append(phase_dict)

    pprint(variables)
        
picker(dframe=df, bands=10, per=5, thold=0.05)

95970, 1823445, 191941
[{'band_loc': 95970,
  'band_no': 0,
  'entries': [['This vote will also help in that regard.',
               'Tämä äänestys on avuksi myös siinä suhteessa.'],
              ['The foregoing is the reason for our vote.',
               'Edellä mainitut asiat olivat syynä äänestyskäyttäytymiseemme.'],
              ['The whole process is highly questionable.',
               'Koko asia on erittäin arveluttava.'],
              ['We must remain very vigilant about this.',
               'Meidän tulee valvoa tarkasti, että niin ei tapahdu.'],
              ['Children and young people expect as much.',
               'Sitä lapset ja nuoret odottavat.']],
  'median_len': 41},
 {'band_loc': 287911,
  'band_no': 1,
  'entries': [['I will not bore you with the details, but we finally arrived '
               'here.',
               'En halua pitkästyttää teitä yksityiskohdilla, mutta lopulta '
               'pääsimme Strasbourgiin.'],
              ['I leave you with on