In [14]:
import random
from typing import List
import numpy as np
import math


class SmoothedMeanWeightUpdater:
    def __init__(
            self,
            dataset_names: List[str],
            weights: List[float],
            smoothing_factor: float = 0.9,
    ):
        self.dataset_names = dataset_names
        self.dataset_map = {name: i for i, name in enumerate(dataset_names)}
        self.num_datasets = len(dataset_names)
        self.weights = weights
        self._estimated_reward = {name: 0.0 for name in dataset_names}
        total_weights = np.sum(weights)
        self._probabilities = {name: weight/total_weights for name, weight in zip(dataset_names, weights)}
        self.eps = 1/self.num_datasets
        self.prev_eps = None
        self.smoothing_factor = smoothing_factor
        self.vars_to_log = ["_probabilities", "_estimated_reward"]

    def update(self, dataset_name: str, reward: float, iteration: int) -> List[float]:
        """
        Updates the weights based on the provided reward.
        """

        # update cumulative estimated reward
        self._estimated_reward[dataset_name] = self.smoothing_factor*self._estimated_reward[dataset_name] + (1-self.smoothing_factor)*math.exp(reward)

        # calculate epsilons
        self.prev_eps = self.eps
        self.eps = min(1/self.num_datasets, math.sqrt(math.log(self.num_datasets)/(self.num_datasets*iteration)))

        # calculate scaling factor
        total_estimated_rewards = sum([math.exp(r*self.prev_eps) for r in self._estimated_reward.values()])
        scaling_factor = (1-self.num_datasets*self.eps)/total_estimated_rewards

        # update weights
        for name in self.dataset_names:
            self.weights[self.dataset_map[name]] = math.exp(self._estimated_reward[name]*self.prev_eps)*scaling_factor + self.eps

        # update probabilities
        total_weights = sum(self.weights)
        for name in self.dataset_names:
            self._probabilities[name] = self.weights[self.dataset_map[name]]/total_weights

        return list(self._probabilities.values())

    def group_update(self, dataset_names: List[str], rewards: List, iteration: int):
        # calculate epsilons
        self.prev_eps = self.eps
        self.eps = min(1/self.num_datasets, math.sqrt(math.log(self.num_datasets)/(self.num_datasets*iteration)))

        # update cumulative estimated reward
        for name, reward in zip(dataset_names, rewards):
            # smoothed mean
            # self._estimated_reward[name] = self.smoothing_factor*self._estimated_reward[name] + (1-self.smoothing_factor)*reward
            # smoothed exponentiated mean
            self._estimated_reward[name] = self.smoothing_factor*self._estimated_reward[name] + (1-self.smoothing_factor)*math.exp(reward)
        # print(f"Rank: {torch.distributed.get_rank()} -- estimated_reward {self._estimated_reward}")

        # calculate normalized scaling factor
        total_estimated_rewards = sum((r*self.prev_eps) for r in self._estimated_reward.values())
        scaling_factor = (1-self.num_datasets*self.eps)/total_estimated_rewards

        # update weights
        for name in self.dataset_names:
            # self.weights[self.dataset_map[name]] = math.exp(self._estimated_reward[name]*self.prev_eps)*scaling_factor + self.eps
            self.weights[self.dataset_map[name]] = self._estimated_reward[name]*self.prev_eps*scaling_factor + self.eps

        # update probabilities
        total_weights = sum(self.weights)
        for name in self.dataset_names:
            self._probabilities[name] = self.weights[self.dataset_map[name]]/total_weights

        return list(self._probabilities.values())

In [48]:
weights = SmoothedMeanWeightUpdater(["a", "b"], [0.5, 0.5])
weights.update("a", 4, 1)
weights.update("b", 4, 2)
weights.update("a", 2, 3)
weights.update("a", 0, 3)

[0.4925946534736098, 0.5074053465263902]

In [2]:
import pandas as pd
from datasets import Dataset, DatasetDict
from numpy import dtype
from pandas import DataFrame
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model_path="base_models/granite-3.2-2b-instruct"
device= "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
from datasets import Dataset, DatasetDict

with open("data/coco.ml.txt") as f:
    ml = f.readlines()

with open("data/coco.en.txt") as f:
    eng = f.readlines()

def get_dataset(ml, eng):
    ml = [sen.strip() for sen in ml]
    eng = [sen.strip() for sen in eng]
    return [{"ml": ml, "eng": eng, "content": f'Translate to english:<|end_of_text|>{ml}<|end_of_text|>{eng}<|end_of_text|>'} for ml, eng in zip(ml, eng)]

dataset = get_dataset(ml, eng)
n = 100
train_dataset = Dataset.from_list(dataset[:n // 10 * 8])
valid_dataset = Dataset.from_list(dataset[n // 10 * 8:n])
dataset = DatasetDict({"train": train_dataset, "validation": valid_dataset})
dataset

DatasetDict({
    train: Dataset({
        features: ['ml', 'eng', 'content'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['ml', 'eng', 'content'],
        num_rows: 20
    })
})

In [4]:
from transformers import Trainer, TrainingArguments
args = TrainingArguments("ml_to_en")

def tokenize_function(examples):
    tokenized_inputs = tokenizer(examples["content"], truncation=True, padding=True, return_tensors="pt")
    tokenized_inputs["labels"] = tokenized_inputs["input_ids"].clone()
    return tokenized_inputs

tokenized_datasets = dataset.map(tokenize_function, remove_columns=dataset['train'].column_names, batched=True)
tokenized_datasets

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 20
    })
})

In [20]:
from tqdm.auto import tqdm
from transformers import get_scheduler, DataCollatorWithPadding
from torch.utils.data import DataLoader
from transformers import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=1, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"], batch_size=1, collate_fn=data_collator
)

num_epochs = 3
num_training_steps = 100
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))

data_loaders = {
    "train": train_dataloader,
    "eval": eval_dataloader,
}

data_loader_iters = {k: iter(v) for k, v in data_loaders.items()}
weights = SmoothedMeanWeightUpdater(["train", "eval"], [0.5, 0.5])

model.train()
for i in range(num_training_steps):
    batch_name = random.choices(["train", "eval"], weights=weights.weights)[0]
    print(batch_name)
    try:
        batch = next(data_loader_iters[batch_name]) 
    except StopIteration:
        data_loader_iters[batch_name] = iter(data_loaders[batch_name])
        batch = next(data_loader_iters[batch_name])
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    res = weights.update(batch_name, loss.item(), i + 1)
    print(loss, res)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)

  0%|          | 0/240 [00:00<?, ?it/s]

eval
tensor(0.7216, device='mps:0', grad_fn=<NllLossBackward0>) [0.5, 0.5]
train
tensor(6.5066, device='mps:0', grad_fn=<NllLossBackward0>) [0.5837226944211507, 0.4162773055788494]
eval
tensor(2.7237, device='mps:0', grad_fn=<NllLossBackward0>) [0.6601110032765516, 0.33988899672344836]
eval
tensor(1.1982, device='mps:0', grad_fn=<NllLossBackward0>) [0.7056474942696611, 0.294352505730339]
train
tensor(1.3106, device='mps:0', grad_fn=<NllLossBackward0>) [0.7367231006648045, 0.2632768993351955]
eval
tensor(0.7517, device='mps:0', grad_fn=<NllLossBackward0>) [0.7596620857223804, 0.24033791427761972]
eval
tensor(1.3359, device='mps:0', grad_fn=<NllLossBackward0>) [0.7774899908727757, 0.22251000912722438]
train
tensor(1.2670, device='mps:0', grad_fn=<NllLossBackward0>) [0.7918567751490906, 0.20814322485090947]
train
tensor(1.1806, device='mps:0', grad_fn=<NllLossBackward0>) [0.8037351917354268, 0.19626480826457324]
train
tensor(1.4122, device='mps:0', grad_fn=<NllLossBackward0>) [0.813702116