In [1]:
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from torch.utils.data import Dataset
from torch import nn
from torch.optim import Adam
from datasets import load_dataset
import random
from copy import deepcopy
import wandb
import optuna


In [2]:
# Initialize W&B
wandb.init(
    project="few-shot-yelp",
    name="maml-yelp-few-shot-optuna",
)

wandb: Currently logged in as: kostic-stojan23 (kostic-stojan23-university-of-belgrade). Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [3]:
# Load fine-tuned model and tokenizer
model_path = "NLP_VER1"
base_model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [4]:
# Load Yelp Polarity Dataset
dataset = load_dataset("yelp_polarity")
train_data = dataset["train"]
test_data = dataset["test"]


In [5]:
# Define Few-Shot Dataset
class FewShotDataset(Dataset):
    def __init__(self, data, num_support, num_query):
        self.data = data
        self.num_support = num_support
        self.num_query = num_query

    def get_task(self):
        indices = list(range(len(self.data)))
        random.shuffle(indices)

        support_indices = indices[:self.num_support]
        query_indices = indices[self.num_support:self.num_support + self.num_query]

        support_set = [(self.data[i]['text'], self.data[i]['label']) for i in support_indices]
        query_set = [(self.data[i]['text'], self.data[i]['label']) for i in query_indices]

        return support_set, query_set

In [6]:
# Inner loop for task-specific adaptation
def inner_loop(model, support_set, num_steps=1, lr=1e-5):
    task_model = deepcopy(model)
    optimizer = Adam(task_model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    for _ in range(num_steps):
        for text, label in support_set:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            labels = torch.tensor([label])

            outputs = task_model(**inputs, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return task_model

In [7]:
# Outer loop for meta-learning
def outer_loop(meta_model, tasks, meta_optimizer, num_inner_steps=1):
    meta_optimizer.zero_grad()
    total_loss = 0

    for support_set, query_set in tasks:
        task_model = inner_loop(meta_model, support_set, num_steps=num_inner_steps)

        for text, label in query_set:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            labels = torch.tensor([label])

            outputs = task_model(**inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss

    total_loss.backward()
    meta_optimizer.step()

    return total_loss.item()

In [8]:
# Optuna objective function with W&B logging
def objective(trial):
    # Suggest hyperparameters
    meta_lr = trial.suggest_loguniform("meta_lr", 1e-5, 1e-3)
    inner_lr = trial.suggest_loguniform("inner_lr", 1e-5, 1e-3)
    num_support = trial.suggest_int("num_support", 1, 10)
    num_query = trial.suggest_int("num_query", 1, 10)
    inner_steps = trial.suggest_int("inner_steps", 1, 5)
    batch_size = trial.suggest_int("batch_size", 2, 8)
    meta_epochs = trial.suggest_int("meta_epochs", 3, 5)

    # Create Few-Shot Dataset
    few_shot_dataset = FewShotDataset(train_data, num_support, num_query)

    # Define meta-optimizer
    meta_optimizer = Adam(base_model.parameters(), lr=meta_lr)

    # Meta-training loop
    for epoch in range(meta_epochs):
        tasks = [few_shot_dataset.get_task() for _ in range(batch_size)]
        loss = outer_loop(base_model, tasks, meta_optimizer, num_inner_steps=inner_steps)

        # Log metrics to W&B
        wandb.log({
            "trial": trial.number,
            "epoch": epoch + 1,
            "meta_loss": loss,
            "meta_lr": meta_lr,
            "inner_lr": inner_lr,
            "num_support": num_support,
            "num_query": num_query,
            "inner_steps": inner_steps,
            "batch_size": batch_size,
        })

    return loss

In [9]:
# Optuna study setup and optimization
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)

# Log the best trial to W&B
best_trial = study.best_trial
wandb.log({
    "best_trial": best_trial.number,
    **best_trial.params,
    "best_loss": best_trial.value,
})

# Meta-training with best parameters
best_params = best_trial.params
few_shot_dataset = FewShotDataset(train_data, best_params["num_support"], best_params["num_query"])
meta_optimizer = Adam(base_model.parameters(), lr=best_params["meta_lr"])

for epoch in range(best_params["meta_epochs"]):
    tasks = [few_shot_dataset.get_task() for _ in range(best_params["batch_size"])]
    loss = outer_loop(base_model, tasks, meta_optimizer, num_inner_steps=best_params["inner_steps"])
    wandb.log({"epoch": epoch + 1, "meta_loss": loss})
    print(f"Epoch {epoch + 1}, Meta Loss: {loss:.4f}")

[I 2025-01-24 22:58:43,455] A new study created in memory with name: no-name-ac778edf-dd23-4515-922e-182905a61003
  meta_lr = trial.suggest_loguniform("meta_lr", 1e-5, 1e-3)
  inner_lr = trial.suggest_loguniform("inner_lr", 1e-5, 1e-3)
[I 2025-01-24 23:09:00,472] Trial 0 finished with value: 23.132558822631836 and parameters: {'meta_lr': 0.0007043030472696078, 'inner_lr': 6.557933362066147e-05, 'num_support': 6, 'num_query': 10, 'inner_steps': 4, 'batch_size': 8, 'meta_epochs': 4}. Best is trial 0 with value: 23.132558822631836.
[I 2025-01-24 23:12:42,890] Trial 1 finished with value: 1.9332340955734253 and parameters: {'meta_lr': 0.00012328944677160724, 'inner_lr': 0.00011600755636253694, 'num_support': 10, 'num_query': 9, 'inner_steps': 3, 'batch_size': 2, 'meta_epochs': 5}. Best is trial 1 with value: 1.9332340955734253.
[I 2025-01-24 23:18:24,629] Trial 2 finished with value: 9.519462585449219 and parameters: {'meta_lr': 0.0005056738982375394, 'inner_lr': 4.63185013957147e-05, 'num

Epoch 1, Meta Loss: 0.0162
Epoch 2, Meta Loss: 0.0075
Epoch 3, Meta Loss: 1.4160
Epoch 4, Meta Loss: 0.0067


In [10]:
# Evaluate on a new task
new_task_data = FewShotDataset(test_data, best_params["num_support"], best_params["num_query"])
new_support_set, new_query_set = new_task_data.get_task()

adapted_model = inner_loop(base_model, new_support_set, num_steps=best_params["inner_steps"])
total_loss = 0

for text, label in new_query_set:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    labels = torch.tensor([label])

    outputs = adapted_model(**inputs, labels=labels)
    loss = outputs.loss
    total_loss += loss

wandb.log({"new_task_loss": total_loss.item()})
print(f"New Task Evaluation Loss: {total_loss.item():.4f}")

New Task Evaluation Loss: 0.0010


In [11]:
# Classification metrics
from sklearn.metrics import accuracy_score, classification_report

true_labels = []
predicted_labels = []
for text, label in new_query_set:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    outputs = adapted_model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=1)

    true_labels.append(label)
    predicted_labels.append(predictions.item())

print("Accuracy:", accuracy_score(true_labels, predicted_labels))
print("Classification Report:\n", classification_report(true_labels, predicted_labels))

wandb.finish()

Accuracy: 1.0
Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00         1

    accuracy                           1.00         1
   macro avg       1.00      1.00      1.00         1
weighted avg       1.00      1.00      1.00         1



0,1
batch_size,▁▂▃▇▅▇▇▁▁▁▁▂▃▃▅██▅▅▅▂▂▂▂▂▂▃▃▃▁▂▂▂▁▁▁▁▁▁▁
best_loss,▁
best_trial,▁
epoch,▁▅▁█▃▅▅▃▅▁▁▁█▅▁▅▃▆▅▁▆▃▆▅▆▁▆▆█▃▆▁▃▅▁▁▆▅▁▁
inner_lr,▂▁▁▁▆▅▅▁▁▁▁▃▃▃█▂▁▁▂▂▃▃▃▂▄▄▄▄▇▅▆▆▂▂▂▄▄▅▇▅
inner_steps,▆▆▃▃▃▁▁▁▅▅▃▁▁▃▆▃▃▃▅▅▅▆▆▆▆▆▆▆██▆▆▆▆▆█████
meta_epochs,▁
meta_loss,█▆▁▁▂▁▄▂▁▁▂▁▁▁▁▅▂▁▁▂▁▁▁▁▅▁▁▁▂▂▂▂▁▁▁▁▂▁▃▁
meta_lr,▆▂▂▅▅▁▃▃▃▃▁▃▃▂▁▁▄▂▂▄██▆▆▄▄▄▄▅▅▅▅▃▃▅▆▇▇▇▅
new_task_loss,▁

0,1
batch_size,2.0
best_loss,0.00313
best_trial,43.0
epoch,4.0
inner_lr,0.00044
inner_steps,4.0
meta_epochs,4.0
meta_loss,0.00671
meta_lr,0.00072
new_task_loss,0.001


In [15]:
# Define file paths for saving
model_save_path = "maml_adapted_model.pth"
tokenizer_save_path = "maml_tokenizer/"

# Save the adapted model
torch.save(adapted_model.state_dict(), model_save_path)
print(f"Model saved at {model_save_path}")

# Save the tokenizer
tokenizer.save_pretrained(tokenizer_save_path)
print(f"Tokenizer saved at {tokenizer_save_path}")

Model saved at maml_adapted_model.pth
Tokenizer saved at maml_tokenizer/
