In [1]:
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from torch.utils.data import Dataset
from torch import nn
from torch.optim import Adam
from datasets import load_dataset
import random
from copy import deepcopy
import optuna
from torch.utils.data import DataLoader
import wandb

In [2]:
# Initialize W&B
wandb.init(
    project="few-shot-multiset",
    name="maml-multiset-optuna",
)

wandb: Currently logged in as: kostic-stojan23 (kostic-stojan23-university-of-belgrade). Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [3]:
def load_and_normalize(dataset_name):
    dataset = load_dataset(dataset_name)

    if dataset_name == "amazon_polarity":
        def process(example):
            return {"text": example["content"], "label": example["label"]}
    elif dataset_name == "yelp_polarity":
        def process(example):
            return {"text": example["text"], "label": example["label"]}
    elif dataset_name == "sentiment140":
        def process(example):
            label = 0 if example["sentiment"] == 0 else 1  # Convert (0,4) to (0,1)
            return {"text": example["text"], "label": label}

    dataset["train"] = dataset["train"].map(process)
    dataset["test"] = dataset["test"].map(process)

    return dataset["train"], dataset["test"]

In [4]:
dataset_names = ["amazon_polarity", "yelp_polarity", "sentiment140"]
train_data, test_data = {}, {}

for name in dataset_names:
    train_data[name], test_data[name] = load_and_normalize(name)

In [5]:
class FewShotDataset(Dataset):
    def __init__(self, data, num_support, num_query):
        self.data = data
        self.num_support = num_support
        self.num_query = num_query

    def get_task(self):
        indices = list(range(len(self.data)))
        random.shuffle(indices)

        support_indices = indices[:self.num_support]
        query_indices = indices[self.num_support:self.num_support + self.num_query]

        support_set = [(self.data[i]['text'], self.data[i]['label']) for i in support_indices]
        query_set = [(self.data[i]['text'], self.data[i]['label']) for i in query_indices]

        return support_set, query_set

In [6]:
model_path = "NLP_VER1"
base_model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model.to(device)

def inner_loop(model, support_set, num_steps=1, lr=1e-5):
    task_model = deepcopy(model).to(device)  # Move task-specific model to GPU
    optimizer = torch.optim.Adam(task_model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    for _ in range(num_steps):
        for text, label in support_set:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
            labels = torch.tensor([label]).unsqueeze(0).to(device)  # Move labels to GPU

            outputs = task_model(**inputs, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return task_model

In [8]:
def outer_loop(meta_model, tasks, meta_optimizer, num_inner_steps=1):
    meta_optimizer.zero_grad()
    total_loss = 0

    for support_set, query_set in tasks:
        task_model = inner_loop(meta_model, support_set, num_steps=num_inner_steps)

        for text, label in query_set:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
            labels = torch.tensor([label]).unsqueeze(0).to(device)

            outputs = task_model(**inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss

    total_loss.backward()
    meta_optimizer.step()

    return total_loss.item()

In [9]:
def objective(trial):
    meta_lr = trial.suggest_loguniform("meta_lr", 1e-5, 1e-3)
    inner_lr = trial.suggest_loguniform("inner_lr", 1e-5, 1e-3)
    num_support = trial.suggest_int("num_support", 1, 10)
    num_query = trial.suggest_int("num_query", 1, 10)
    inner_steps = trial.suggest_int("inner_steps", 1, 5)
    batch_size = trial.suggest_int("batch_size", 2, 8)
    meta_epochs = trial.suggest_int("meta_epochs", 5, 10)

    dataset_name = random.choice(dataset_names)
    few_shot_dataset = FewShotDataset(train_data[dataset_name], num_support, num_query)

    meta_optimizer = Adam(base_model.parameters(), lr=meta_lr)

    for epoch in range(meta_epochs):
        tasks = [few_shot_dataset.get_task() for _ in range(batch_size)]
        loss = outer_loop(base_model, tasks, meta_optimizer, num_inner_steps=inner_steps)

        wandb.log({
            "trial": trial.number,
            "epoch": epoch + 1,
            "meta_loss": loss,
            "meta_lr": meta_lr,
            "inner_lr": inner_lr,
            "num_support": num_support,
            "num_query": num_query,
            "inner_steps": inner_steps,
            "batch_size": batch_size,
        })

        print(f"[Trial {trial.number}, Epoch {epoch + 1}] Loss: {loss:.4f}")

    return loss

In [11]:
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=34)

best_trial = study.best_trial
print(f"Best Trial: {best_trial.number}, Loss: {best_trial.value}")
print("Best Hyperparameters:", best_trial.params)

[I 2025-01-30 21:52:40,410] A new study created in memory with name: no-name-25124f3e-a74c-4f17-82ff-59787b330d7e
  meta_lr = trial.suggest_loguniform("meta_lr", 1e-5, 1e-3)
  inner_lr = trial.suggest_loguniform("inner_lr", 1e-5, 1e-3)


[Trial 0, Epoch 1] Loss: 23.6549
[Trial 0, Epoch 2] Loss: 4.3525
[Trial 0, Epoch 3] Loss: 15.8666
[Trial 0, Epoch 4] Loss: 11.7129
[Trial 0, Epoch 5] Loss: 6.9439
[Trial 0, Epoch 6] Loss: 9.8187
[Trial 0, Epoch 7] Loss: 5.4228
[Trial 0, Epoch 8] Loss: 12.7705
[Trial 0, Epoch 9] Loss: 4.7034


[I 2025-01-30 22:00:25,321] Trial 0 finished with value: 4.703431606292725 and parameters: {'meta_lr': 0.00012765534849510702, 'inner_lr': 1.9468831996027415e-05, 'num_support': 9, 'num_query': 4, 'inner_steps': 5, 'batch_size': 6, 'meta_epochs': 9}. Best is trial 0 with value: 4.703431606292725.


[Trial 1, Epoch 1] Loss: 34.0209
[Trial 1, Epoch 2] Loss: 13.9256
[Trial 1, Epoch 3] Loss: 13.2510
[Trial 1, Epoch 4] Loss: 8.4297
[Trial 1, Epoch 5] Loss: 12.9661
[Trial 1, Epoch 6] Loss: 18.4708


[I 2025-01-30 22:02:28,265] Trial 1 finished with value: 18.4708251953125 and parameters: {'meta_lr': 1.9320689557340608e-05, 'inner_lr': 0.00019001272780410218, 'num_support': 6, 'num_query': 8, 'inner_steps': 4, 'batch_size': 5, 'meta_epochs': 6}. Best is trial 0 with value: 4.703431606292725.


[Trial 2, Epoch 1] Loss: 6.7252
[Trial 2, Epoch 2] Loss: 2.0755
[Trial 2, Epoch 3] Loss: 0.5456
[Trial 2, Epoch 4] Loss: 3.7954
[Trial 2, Epoch 5] Loss: 4.2484
[Trial 2, Epoch 6] Loss: 4.1412


[I 2025-01-30 22:03:15,191] Trial 2 finished with value: 4.1411824226379395 and parameters: {'meta_lr': 3.9060169472092286e-05, 'inner_lr': 3.182033474887267e-05, 'num_support': 6, 'num_query': 2, 'inner_steps': 2, 'batch_size': 2, 'meta_epochs': 6}. Best is trial 2 with value: 4.1411824226379395.


[Trial 3, Epoch 1] Loss: 5.6219
[Trial 3, Epoch 2] Loss: 10.0801
[Trial 3, Epoch 3] Loss: 10.7330
[Trial 3, Epoch 4] Loss: 18.7148
[Trial 3, Epoch 5] Loss: 11.6823
[Trial 3, Epoch 6] Loss: 5.0801
[Trial 3, Epoch 7] Loss: 4.5051
[Trial 3, Epoch 8] Loss: 13.6569


[I 2025-01-30 22:07:08,144] Trial 3 finished with value: 13.656900405883789 and parameters: {'meta_lr': 2.0907685277119343e-05, 'inner_lr': 1.7668644805978523e-05, 'num_support': 3, 'num_query': 8, 'inner_steps': 3, 'batch_size': 4, 'meta_epochs': 8}. Best is trial 2 with value: 4.1411824226379395.


[Trial 4, Epoch 1] Loss: 3.6463
[Trial 4, Epoch 2] Loss: 6.3260
[Trial 4, Epoch 3] Loss: 0.1661
[Trial 4, Epoch 4] Loss: 4.9609
[Trial 4, Epoch 5] Loss: 1.0969
[Trial 4, Epoch 6] Loss: 0.5304
[Trial 4, Epoch 7] Loss: 4.7963


[I 2025-01-30 22:08:53,425] Trial 4 finished with value: 4.796279430389404 and parameters: {'meta_lr': 0.000915923776610583, 'inner_lr': 0.0002075165155870697, 'num_support': 5, 'num_query': 6, 'inner_steps': 1, 'batch_size': 2, 'meta_epochs': 7}. Best is trial 2 with value: 4.1411824226379395.


[Trial 5, Epoch 1] Loss: 5.1177
[Trial 5, Epoch 2] Loss: 2.0561
[Trial 5, Epoch 3] Loss: 0.2264
[Trial 5, Epoch 4] Loss: 2.5274
[Trial 5, Epoch 5] Loss: 12.1637
[Trial 5, Epoch 6] Loss: 1.2927
[Trial 5, Epoch 7] Loss: 6.7408
[Trial 5, Epoch 8] Loss: 5.7796
[Trial 5, Epoch 9] Loss: 1.4448
[Trial 5, Epoch 10] Loss: 1.1115


[I 2025-01-30 22:10:08,499] Trial 5 finished with value: 1.1114599704742432 and parameters: {'meta_lr': 0.00013603828869305698, 'inner_lr': 0.0003754408822980123, 'num_support': 5, 'num_query': 3, 'inner_steps': 2, 'batch_size': 4, 'meta_epochs': 10}. Best is trial 5 with value: 1.1114599704742432.


[Trial 6, Epoch 1] Loss: 7.6003
[Trial 6, Epoch 2] Loss: 7.7233
[Trial 6, Epoch 3] Loss: 16.8586
[Trial 6, Epoch 4] Loss: 13.8449
[Trial 6, Epoch 5] Loss: 7.7543
[Trial 6, Epoch 6] Loss: 5.9565


[I 2025-01-30 22:10:53,947] Trial 6 finished with value: 5.956506729125977 and parameters: {'meta_lr': 0.0007795386475656917, 'inner_lr': 0.000988254249791263, 'num_support': 4, 'num_query': 7, 'inner_steps': 2, 'batch_size': 2, 'meta_epochs': 6}. Best is trial 5 with value: 1.1114599704742432.


[Trial 7, Epoch 1] Loss: 4.5838
[Trial 7, Epoch 2] Loss: 10.8669
[Trial 7, Epoch 3] Loss: 2.4891
[Trial 7, Epoch 4] Loss: 5.9419
[Trial 7, Epoch 5] Loss: 7.5549


[I 2025-01-30 22:11:55,904] Trial 7 finished with value: 7.554874897003174 and parameters: {'meta_lr': 0.00010657913547510096, 'inner_lr': 0.0005362744747230688, 'num_support': 5, 'num_query': 4, 'inner_steps': 5, 'batch_size': 4, 'meta_epochs': 5}. Best is trial 5 with value: 1.1114599704742432.


[Trial 8, Epoch 1] Loss: 20.3686
[Trial 8, Epoch 2] Loss: 9.7355
[Trial 8, Epoch 3] Loss: 10.6000
[Trial 8, Epoch 4] Loss: 20.0862
[Trial 8, Epoch 5] Loss: 8.9484
[Trial 8, Epoch 6] Loss: 10.8548


[I 2025-01-30 22:16:09,264] Trial 8 finished with value: 10.854832649230957 and parameters: {'meta_lr': 0.00013923684312892843, 'inner_lr': 0.0006217252382166167, 'num_support': 3, 'num_query': 9, 'inner_steps': 5, 'batch_size': 5, 'meta_epochs': 6}. Best is trial 5 with value: 1.1114599704742432.


[Trial 9, Epoch 1] Loss: 3.7133
[Trial 9, Epoch 2] Loss: 4.7007
[Trial 9, Epoch 3] Loss: 5.1362
[Trial 9, Epoch 4] Loss: 9.0248
[Trial 9, Epoch 5] Loss: 6.2047
[Trial 9, Epoch 6] Loss: 13.5139
[Trial 9, Epoch 7] Loss: 6.5189


[I 2025-01-30 22:24:23,360] Trial 9 finished with value: 6.51888370513916 and parameters: {'meta_lr': 7.067011693902542e-05, 'inner_lr': 0.00011113736713401701, 'num_support': 5, 'num_query': 4, 'inner_steps': 5, 'batch_size': 7, 'meta_epochs': 7}. Best is trial 5 with value: 1.1114599704742432.


[Trial 10, Epoch 1] Loss: 3.3815
[Trial 10, Epoch 2] Loss: 1.5572
[Trial 10, Epoch 3] Loss: 0.6036
[Trial 10, Epoch 4] Loss: 0.1500
[Trial 10, Epoch 5] Loss: 5.2355
[Trial 10, Epoch 6] Loss: 8.9201
[Trial 10, Epoch 7] Loss: 0.0986
[Trial 10, Epoch 8] Loss: 2.5974
[Trial 10, Epoch 9] Loss: 4.5319
[Trial 10, Epoch 10] Loss: 0.6450


[I 2025-01-30 22:33:48,238] Trial 10 finished with value: 0.6449998021125793 and parameters: {'meta_lr': 0.00032927120859901817, 'inner_lr': 4.956300757463975e-05, 'num_support': 1, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 10}. Best is trial 10 with value: 0.6449998021125793.


[Trial 11, Epoch 1] Loss: 5.5753
[Trial 11, Epoch 2] Loss: 0.2168
[Trial 11, Epoch 3] Loss: 0.0567
[Trial 11, Epoch 4] Loss: 3.0650
[Trial 11, Epoch 5] Loss: 0.0781
[Trial 11, Epoch 6] Loss: 0.6767
[Trial 11, Epoch 7] Loss: 0.0468
[Trial 11, Epoch 8] Loss: 4.5873
[Trial 11, Epoch 9] Loss: 1.0222
[Trial 11, Epoch 10] Loss: 4.2074


[I 2025-01-30 22:41:44,212] Trial 11 finished with value: 4.207398891448975 and parameters: {'meta_lr': 0.00034125878374823796, 'inner_lr': 4.745993889530881e-05, 'num_support': 1, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 10}. Best is trial 10 with value: 0.6449998021125793.


[Trial 12, Epoch 1] Loss: 9.3527
[Trial 12, Epoch 2] Loss: 9.0494
[Trial 12, Epoch 3] Loss: 0.2407
[Trial 12, Epoch 4] Loss: 11.6954
[Trial 12, Epoch 5] Loss: 4.4752
[Trial 12, Epoch 6] Loss: 5.2841
[Trial 12, Epoch 7] Loss: 0.8155
[Trial 12, Epoch 8] Loss: 1.1046
[Trial 12, Epoch 9] Loss: 12.2626
[Trial 12, Epoch 10] Loss: 0.2172


[I 2025-01-30 22:43:58,821] Trial 12 finished with value: 0.217221200466156 and parameters: {'meta_lr': 0.00033269093101486407, 'inner_lr': 6.15019569557295e-05, 'num_support': 1, 'num_query': 2, 'inner_steps': 2, 'batch_size': 8, 'meta_epochs': 10}. Best is trial 12 with value: 0.217221200466156.


[Trial 13, Epoch 1] Loss: 0.4027
[Trial 13, Epoch 2] Loss: 0.2820
[Trial 13, Epoch 3] Loss: 2.9965
[Trial 13, Epoch 4] Loss: 1.2060
[Trial 13, Epoch 5] Loss: 0.0770
[Trial 13, Epoch 6] Loss: 2.4142
[Trial 13, Epoch 7] Loss: 5.5846
[Trial 13, Epoch 8] Loss: 0.7309
[Trial 13, Epoch 9] Loss: 0.1633


[I 2025-01-30 22:50:49,011] Trial 13 finished with value: 0.16325432062149048 and parameters: {'meta_lr': 0.0003238870723018769, 'inner_lr': 5.3011413355797864e-05, 'num_support': 1, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 9}. Best is trial 13 with value: 0.16325432062149048.


[Trial 14, Epoch 1] Loss: 9.7654
[Trial 14, Epoch 2] Loss: 5.6957
[Trial 14, Epoch 3] Loss: 16.7275
[Trial 14, Epoch 4] Loss: 4.7089
[Trial 14, Epoch 5] Loss: 7.8915
[Trial 14, Epoch 6] Loss: 11.6817
[Trial 14, Epoch 7] Loss: 11.3146
[Trial 14, Epoch 8] Loss: 10.9585
[Trial 14, Epoch 9] Loss: 15.3361


[I 2025-01-30 22:54:01,714] Trial 14 finished with value: 15.3361234664917 and parameters: {'meta_lr': 0.00033794521816553905, 'inner_lr': 8.021191636520669e-05, 'num_support': 2, 'num_query': 2, 'inner_steps': 3, 'batch_size': 7, 'meta_epochs': 9}. Best is trial 13 with value: 0.16325432062149048.


[Trial 15, Epoch 1] Loss: 3.0469
[Trial 15, Epoch 2] Loss: 3.2333
[Trial 15, Epoch 3] Loss: 5.2038
[Trial 15, Epoch 4] Loss: 0.0931
[Trial 15, Epoch 5] Loss: 6.0406
[Trial 15, Epoch 6] Loss: 0.2199
[Trial 15, Epoch 7] Loss: 0.4189
[Trial 15, Epoch 8] Loss: 3.0387
[Trial 15, Epoch 9] Loss: 4.4924


[I 2025-01-30 22:56:58,193] Trial 15 finished with value: 4.492351055145264 and parameters: {'meta_lr': 0.0005426389346370192, 'inner_lr': 9.844128211442174e-05, 'num_support': 8, 'num_query': 1, 'inner_steps': 2, 'batch_size': 7, 'meta_epochs': 9}. Best is trial 13 with value: 0.16325432062149048.


[Trial 16, Epoch 1] Loss: 8.2102
[Trial 16, Epoch 2] Loss: 10.4556
[Trial 16, Epoch 3] Loss: 9.3189
[Trial 16, Epoch 4] Loss: 9.5959
[Trial 16, Epoch 5] Loss: 5.8101
[Trial 16, Epoch 6] Loss: 0.6800
[Trial 16, Epoch 7] Loss: 6.8624
[Trial 16, Epoch 8] Loss: 7.1749


[I 2025-01-30 23:05:00,203] Trial 16 finished with value: 7.174880027770996 and parameters: {'meta_lr': 0.000226673085895108, 'inner_lr': 1.1154669024398464e-05, 'num_support': 2, 'num_query': 3, 'inner_steps': 3, 'batch_size': 8, 'meta_epochs': 8}. Best is trial 13 with value: 0.16325432062149048.


[Trial 17, Epoch 1] Loss: 26.6707
[Trial 17, Epoch 2] Loss: 42.1638
[Trial 17, Epoch 3] Loss: 34.8392
[Trial 17, Epoch 4] Loss: 24.0678
[Trial 17, Epoch 5] Loss: 31.1372
[Trial 17, Epoch 6] Loss: 14.0724
[Trial 17, Epoch 7] Loss: 29.0780
[Trial 17, Epoch 8] Loss: 21.6773
[Trial 17, Epoch 9] Loss: 31.6127


[I 2025-01-30 23:07:43,236] Trial 17 finished with value: 31.612667083740234 and parameters: {'meta_lr': 0.0005118623898126613, 'inner_lr': 5.440703132223809e-05, 'num_support': 1, 'num_query': 5, 'inner_steps': 1, 'batch_size': 6, 'meta_epochs': 9}. Best is trial 13 with value: 0.16325432062149048.


[Trial 18, Epoch 1] Loss: 3.4544
[Trial 18, Epoch 2] Loss: 5.9648
[Trial 18, Epoch 3] Loss: 1.9321
[Trial 18, Epoch 4] Loss: 0.2536
[Trial 18, Epoch 5] Loss: 0.2053
[Trial 18, Epoch 6] Loss: 2.1856
[Trial 18, Epoch 7] Loss: 0.2282
[Trial 18, Epoch 8] Loss: 0.1818
[Trial 18, Epoch 9] Loss: 5.4852
[Trial 18, Epoch 10] Loss: 5.5686


[I 2025-01-30 23:09:23,914] Trial 18 finished with value: 5.5686163902282715 and parameters: {'meta_lr': 1.0427803471929359e-05, 'inner_lr': 0.00016442350367569273, 'num_support': 3, 'num_query': 2, 'inner_steps': 2, 'batch_size': 6, 'meta_epochs': 10}. Best is trial 13 with value: 0.16325432062149048.


[Trial 19, Epoch 1] Loss: 72.1607
[Trial 19, Epoch 2] Loss: 72.0202
[Trial 19, Epoch 3] Loss: 70.6245
[Trial 19, Epoch 4] Loss: 67.9339
[Trial 19, Epoch 5] Loss: 49.7945
[Trial 19, Epoch 6] Loss: 89.9792
[Trial 19, Epoch 7] Loss: 82.3509
[Trial 19, Epoch 8] Loss: 49.3305


[I 2025-01-30 23:21:26,933] Trial 19 finished with value: 49.33052444458008 and parameters: {'meta_lr': 0.00022669005627067745, 'inner_lr': 2.864398423566454e-05, 'num_support': 7, 'num_query': 10, 'inner_steps': 4, 'batch_size': 8, 'meta_epochs': 8}. Best is trial 13 with value: 0.16325432062149048.


[Trial 20, Epoch 1] Loss: 8.7488
[Trial 20, Epoch 2] Loss: 11.9792
[Trial 20, Epoch 3] Loss: 10.2376
[Trial 20, Epoch 4] Loss: 2.8145
[Trial 20, Epoch 5] Loss: 5.2915
[Trial 20, Epoch 6] Loss: 9.4012
[Trial 20, Epoch 7] Loss: 2.1092
[Trial 20, Epoch 8] Loss: 8.3260
[Trial 20, Epoch 9] Loss: 10.5357


[I 2025-01-30 23:23:36,494] Trial 20 finished with value: 10.535721778869629 and parameters: {'meta_lr': 0.00021916980239517213, 'inner_lr': 7.734374914693465e-05, 'num_support': 2, 'num_query': 3, 'inner_steps': 1, 'batch_size': 7, 'meta_epochs': 9}. Best is trial 13 with value: 0.16325432062149048.


[Trial 21, Epoch 1] Loss: 0.6227
[Trial 21, Epoch 2] Loss: 0.1416
[Trial 21, Epoch 3] Loss: 3.5031
[Trial 21, Epoch 4] Loss: 4.7330
[Trial 21, Epoch 5] Loss: 0.0866
[Trial 21, Epoch 6] Loss: 4.6211
[Trial 21, Epoch 7] Loss: 8.2899
[Trial 21, Epoch 8] Loss: 5.0394
[Trial 21, Epoch 9] Loss: 0.7418
[Trial 21, Epoch 10] Loss: 0.0828


[I 2025-01-30 23:26:40,288] Trial 21 finished with value: 0.0828155130147934 and parameters: {'meta_lr': 0.0004379174327349713, 'inner_lr': 4.4169684443381925e-05, 'num_support': 1, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 10}. Best is trial 21 with value: 0.0828155130147934.


[Trial 22, Epoch 1] Loss: 0.3718
[Trial 22, Epoch 2] Loss: 2.7505
[Trial 22, Epoch 3] Loss: 0.2723
[Trial 22, Epoch 4] Loss: 0.1587
[Trial 22, Epoch 5] Loss: 3.5955
[Trial 22, Epoch 6] Loss: 0.2516
[Trial 22, Epoch 7] Loss: 0.2252
[Trial 22, Epoch 8] Loss: 0.2736
[Trial 22, Epoch 9] Loss: 0.1225
[Trial 22, Epoch 10] Loss: 0.4475


[I 2025-01-30 23:30:38,112] Trial 22 finished with value: 0.44751641154289246 and parameters: {'meta_lr': 0.0005244498440478509, 'inner_lr': 3.280215359934574e-05, 'num_support': 1, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 10}. Best is trial 21 with value: 0.0828155130147934.


[Trial 23, Epoch 1] Loss: 6.9694
[Trial 23, Epoch 2] Loss: 6.4900
[Trial 23, Epoch 3] Loss: 8.8578
[Trial 23, Epoch 4] Loss: 9.2669
[Trial 23, Epoch 5] Loss: 0.6901
[Trial 23, Epoch 6] Loss: 1.0208
[Trial 23, Epoch 7] Loss: 4.2944
[Trial 23, Epoch 8] Loss: 0.5860
[Trial 23, Epoch 9] Loss: 2.2045
[Trial 23, Epoch 10] Loss: 3.4968


[I 2025-01-30 23:36:45,177] Trial 23 finished with value: 3.496807813644409 and parameters: {'meta_lr': 6.882335057400459e-05, 'inner_lr': 5.418672312165638e-05, 'num_support': 2, 'num_query': 2, 'inner_steps': 2, 'batch_size': 7, 'meta_epochs': 10}. Best is trial 21 with value: 0.0828155130147934.


[Trial 24, Epoch 1] Loss: 0.0572
[Trial 24, Epoch 2] Loss: 0.6370
[Trial 24, Epoch 3] Loss: 9.6591
[Trial 24, Epoch 4] Loss: 4.6616
[Trial 24, Epoch 5] Loss: 9.5747
[Trial 24, Epoch 6] Loss: 4.7506
[Trial 24, Epoch 7] Loss: 5.1325
[Trial 24, Epoch 8] Loss: 0.0946
[Trial 24, Epoch 9] Loss: 2.7783


[I 2025-01-30 23:49:36,657] Trial 24 finished with value: 2.7782809734344482 and parameters: {'meta_lr': 0.0006880609098856353, 'inner_lr': 2.4567011872998434e-05, 'num_support': 3, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 9}. Best is trial 21 with value: 0.0828155130147934.


[Trial 25, Epoch 1] Loss: 21.5747
[Trial 25, Epoch 2] Loss: 11.5926
[Trial 25, Epoch 3] Loss: 14.5048
[Trial 25, Epoch 4] Loss: 26.3094
[Trial 25, Epoch 5] Loss: 8.8375
[Trial 25, Epoch 6] Loss: 14.2636
[Trial 25, Epoch 7] Loss: 17.9615
[Trial 25, Epoch 8] Loss: 17.6540
[Trial 25, Epoch 9] Loss: 22.3820
[Trial 25, Epoch 10] Loss: 14.9552


[I 2025-01-30 23:57:08,637] Trial 25 finished with value: 14.955235481262207 and parameters: {'meta_lr': 0.00040357533758133993, 'inner_lr': 0.00013608277326632252, 'num_support': 10, 'num_query': 3, 'inner_steps': 2, 'batch_size': 6, 'meta_epochs': 10}. Best is trial 21 with value: 0.0828155130147934.


[Trial 26, Epoch 1] Loss: 12.8727
[Trial 26, Epoch 2] Loss: 11.7881
[Trial 26, Epoch 3] Loss: 7.7704
[Trial 26, Epoch 4] Loss: 7.2614
[Trial 26, Epoch 5] Loss: 11.3088
[Trial 26, Epoch 6] Loss: 14.0818
[Trial 26, Epoch 7] Loss: 13.3330
[Trial 26, Epoch 8] Loss: 11.9786
[Trial 26, Epoch 9] Loss: 9.2271


[I 2025-01-31 00:08:42,688] Trial 26 finished with value: 9.227128982543945 and parameters: {'meta_lr': 0.00019132031555773533, 'inner_lr': 7.057665012524313e-05, 'num_support': 4, 'num_query': 5, 'inner_steps': 1, 'batch_size': 7, 'meta_epochs': 9}. Best is trial 21 with value: 0.0828155130147934.


[Trial 27, Epoch 1] Loss: 7.0552
[Trial 27, Epoch 2] Loss: 1.7751
[Trial 27, Epoch 3] Loss: 6.5154
[Trial 27, Epoch 4] Loss: 2.6272
[Trial 27, Epoch 5] Loss: 9.3366
[Trial 27, Epoch 6] Loss: 0.3050
[Trial 27, Epoch 7] Loss: 9.8438
[Trial 27, Epoch 8] Loss: 5.0220


[I 2025-01-31 00:14:13,994] Trial 27 finished with value: 5.0219926834106445 and parameters: {'meta_lr': 0.00028526053252671583, 'inner_lr': 3.779658024658e-05, 'num_support': 1, 'num_query': 2, 'inner_steps': 2, 'batch_size': 8, 'meta_epochs': 8}. Best is trial 21 with value: 0.0828155130147934.


[Trial 28, Epoch 1] Loss: 0.0268
[Trial 28, Epoch 2] Loss: 0.0634
[Trial 28, Epoch 3] Loss: 0.8177
[Trial 28, Epoch 4] Loss: 0.0292
[Trial 28, Epoch 5] Loss: 0.0214
[Trial 28, Epoch 6] Loss: 3.3039
[Trial 28, Epoch 7] Loss: 0.0155
[Trial 28, Epoch 8] Loss: 0.0286
[Trial 28, Epoch 9] Loss: 0.0179
[Trial 28, Epoch 10] Loss: 0.9515


[I 2025-01-31 00:15:10,906] Trial 28 finished with value: 0.9514747262001038 and parameters: {'meta_lr': 0.0009507391356933477, 'inner_lr': 1.0484071063564204e-05, 'num_support': 2, 'num_query': 1, 'inner_steps': 1, 'batch_size': 3, 'meta_epochs': 10}. Best is trial 21 with value: 0.0828155130147934.


[Trial 29, Epoch 1] Loss: 10.6670
[Trial 29, Epoch 2] Loss: 14.9919
[Trial 29, Epoch 3] Loss: 4.1333
[Trial 29, Epoch 4] Loss: 5.8490
[Trial 29, Epoch 5] Loss: 12.1909
[Trial 29, Epoch 6] Loss: 8.3282
[Trial 29, Epoch 7] Loss: 13.5090
[Trial 29, Epoch 8] Loss: 12.4762
[Trial 29, Epoch 9] Loss: 15.5459


[I 2025-01-31 00:25:32,070] Trial 29 finished with value: 15.54587459564209 and parameters: {'meta_lr': 0.00017839248714827044, 'inner_lr': 1.6116423869207887e-05, 'num_support': 4, 'num_query': 4, 'inner_steps': 2, 'batch_size': 7, 'meta_epochs': 9}. Best is trial 21 with value: 0.0828155130147934.


[Trial 30, Epoch 1] Loss: 2.8531
[Trial 30, Epoch 2] Loss: 3.0713
[Trial 30, Epoch 3] Loss: 0.1044
[Trial 30, Epoch 4] Loss: 2.8426
[Trial 30, Epoch 5] Loss: 1.4634
[Trial 30, Epoch 6] Loss: 0.1281
[Trial 30, Epoch 7] Loss: 14.8360
[Trial 30, Epoch 8] Loss: 1.8040
[Trial 30, Epoch 9] Loss: 3.8561


[I 2025-01-31 00:26:33,813] Trial 30 finished with value: 3.8561410903930664 and parameters: {'meta_lr': 0.0004846190927487317, 'inner_lr': 0.0002630738800558842, 'num_support': 1, 'num_query': 2, 'inner_steps': 3, 'batch_size': 5, 'meta_epochs': 9}. Best is trial 21 with value: 0.0828155130147934.


[Trial 31, Epoch 1] Loss: 0.1481
[Trial 31, Epoch 2] Loss: 3.7985
[Trial 31, Epoch 3] Loss: 0.1614
[Trial 31, Epoch 4] Loss: 0.4472
[Trial 31, Epoch 5] Loss: 2.1438
[Trial 31, Epoch 6] Loss: 4.3411
[Trial 31, Epoch 7] Loss: 0.0665
[Trial 31, Epoch 8] Loss: 2.3211
[Trial 31, Epoch 9] Loss: 2.7652
[Trial 31, Epoch 10] Loss: 0.1597


[I 2025-01-31 00:36:44,657] Trial 31 finished with value: 0.15974721312522888 and parameters: {'meta_lr': 0.0006764764640143719, 'inner_lr': 3.7697683888844164e-05, 'num_support': 1, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 10}. Best is trial 21 with value: 0.0828155130147934.


[Trial 32, Epoch 1] Loss: 1.1204
[Trial 32, Epoch 2] Loss: 0.9243
[Trial 32, Epoch 3] Loss: 8.6580
[Trial 32, Epoch 4] Loss: 5.2228
[Trial 32, Epoch 5] Loss: 2.1078
[Trial 32, Epoch 6] Loss: 8.3319
[Trial 32, Epoch 7] Loss: 6.8379
[Trial 32, Epoch 8] Loss: 5.0200
[Trial 32, Epoch 9] Loss: 6.1424
[Trial 32, Epoch 10] Loss: 6.6841


[I 2025-01-31 00:43:40,338] Trial 32 finished with value: 6.684129238128662 and parameters: {'meta_lr': 0.0006922877388649675, 'inner_lr': 4.230670857459767e-05, 'num_support': 2, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 10}. Best is trial 21 with value: 0.0828155130147934.


[Trial 33, Epoch 1] Loss: 15.6074
[Trial 33, Epoch 2] Loss: 21.4301
[Trial 33, Epoch 3] Loss: 12.6821
[Trial 33, Epoch 4] Loss: 15.6322
[Trial 33, Epoch 5] Loss: 6.1121
[Trial 33, Epoch 6] Loss: 13.2195
[Trial 33, Epoch 7] Loss: 15.3118
[Trial 33, Epoch 8] Loss: 13.5080
[Trial 33, Epoch 9] Loss: 7.3450
[Trial 33, Epoch 10] Loss: 17.0412


[I 2025-01-31 00:50:56,025] Trial 33 finished with value: 17.041229248046875 and parameters: {'meta_lr': 0.0004067884856220484, 'inner_lr': 2.4374013218275132e-05, 'num_support': 1, 'num_query': 2, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 10}. Best is trial 21 with value: 0.0828155130147934.


Best Trial: 21, Loss: 0.0828155130147934
Best Hyperparameters: {'meta_lr': 0.0004379174327349713, 'inner_lr': 4.4169684443381925e-05, 'num_support': 1, 'num_query': 1, 'inner_steps': 1, 'batch_size': 8, 'meta_epochs': 10}


In [12]:
best_params = best_trial.params
dataset_name = random.choice(dataset_names)
few_shot_dataset = FewShotDataset(train_data[dataset_name], best_params["num_support"], best_params["num_query"])
meta_optimizer = Adam(base_model.parameters(), lr=best_params["meta_lr"])

for epoch in range(best_params["meta_epochs"]):
    tasks = [few_shot_dataset.get_task() for _ in range(best_params["batch_size"])]
    loss = outer_loop(base_model, tasks, meta_optimizer, num_inner_steps=best_params["inner_steps"])
    wandb.log({"epoch": epoch + 1, "meta_loss": loss})
    print(f"Epoch {epoch + 1}, Meta Loss: {loss:.4f}")

Epoch 1, Meta Loss: 4.1011
Epoch 2, Meta Loss: 3.2963
Epoch 3, Meta Loss: 4.5106
Epoch 4, Meta Loss: 3.4007
Epoch 5, Meta Loss: 7.4569
Epoch 6, Meta Loss: 9.8577
Epoch 7, Meta Loss: 5.2001
Epoch 8, Meta Loss: 11.9975
Epoch 9, Meta Loss: 10.9002
Epoch 10, Meta Loss: 6.7005


In [13]:
new_dataset_name = random.choice(dataset_names)
new_task_data = FewShotDataset(test_data[new_dataset_name], best_params["num_support"], best_params["num_query"])
new_support_set, new_query_set = new_task_data.get_task()

adapted_model = inner_loop(base_model, new_support_set, num_steps=best_params["inner_steps"])
adapted_model = adapted_model.to(device)
total_loss = 0

for text, label in new_query_set:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    labels = torch.tensor([label]).unsqueeze(0).to(device)


    outputs = adapted_model(**inputs, labels=labels)
    loss = outputs.loss
    total_loss += loss

wandb.log({"new_task_loss": total_loss.item()})
print(f"New Task ({new_dataset_name}) Evaluation Loss: {total_loss.item():.4f}")

New Task (amazon_polarity) Evaluation Loss: 0.0311


In [None]:
#This code works but it takes hours to complete because of milions of instances
from sklearn.metrics import accuracy_score, classification_report
def evaluate_model_on_all_datasets(model, test_data):
    model.eval()
    all_preds = []
    all_labels = []

    for dataset_name in ["sentiment140", "amazon_polarity", "yelp_polarity"]:
        encodings = tokenizer(test_data[dataset_name]["text"], padding=True, truncation=True, return_tensors="pt", max_length=512)

        test_dataset = torch.utils.data.TensorDataset(encodings.input_ids, encodings.attention_mask, torch.tensor(test_data[dataset_name]["label"]))
        test_loader = DataLoader(test_dataset, batch_size=16)

        with torch.no_grad():
            for batch in test_loader:
                input_ids, attention_mask, labels = batch
                input_ids = input_ids.to(model.device)
                attention_mask = attention_mask.to(model.device)
                labels = labels.to(model.device)

                outputs = model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                preds = torch.argmax(logits, dim=1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        accuracy = accuracy_score(all_labels, all_preds)
        wandb.log({f"{dataset_name}_accuracy": accuracy})

        print(f"Evaluation on {dataset_name}: Accuracy = {accuracy:.4f}")
        print(f"Classification Report for {dataset_name}:")
        print(classification_report(all_labels, all_preds))

evaluate_model_on_all_datasets(adapted_model, test_data)

In [38]:
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# 20 unseen sentences for testing
test_sentences = [
    "The customer service was beyond my expectations!",
    "I waited an hour for my order, and it was still wrong.",
    "This new phone update is a complete disaster.",
    "I'm absolutely thrilled with my new laptop!",
    "The food was bland and overpriced, not coming back.",
    "Best vacation spot ever, can't wait to return!",
    "The product broke after just two uses, very disappointed.",
    "Excellent book, well-written and engaging.",
    "Movie was predictable and boring, nothing special.",
    "Customer support was extremely helpful and quick to respond.",
    "I regret purchasing this item, waste of money.",
    "One of the best restaurants in town, highly recommended!",
    "The new policy changes are frustrating and unnecessary.",
    "I love how comfortable and stylish these shoes are!",
    "The concert was an unforgettable experience.",
    "The software crashes frequently, making it unusable.",
    "Brilliant storytelling, kept me hooked from start to finish.",
    "Shipping took forever, and the package arrived damaged.",
    "Great workout program, helped me get in shape quickly.",
    "The app's interface is confusing and hard to navigate."
]

inputs = tokenizer(test_sentences, padding=True, truncation=True, return_tensors="pt", max_length=512)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

inputs = {key: value.to(device) for key, value in inputs.items()}

with torch.no_grad():
    outputs = adapted_model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=1)

if isinstance(predictions, torch.Tensor):
    predictions = predictions.cpu().numpy()
predictions = predictions.tolist()

import torch.nn.functional as F

probs = F.softmax(logits, dim=1).cpu().numpy()
for i, (sentence, prob) in enumerate(zip(test_sentences, probs)):
    print(f"{i+1}. {sentence}\n   ➝ Positive: {prob[1]:.2f}, Negative: {prob[0]:.2f}\n")


1. The customer service was beyond my expectations!
   ➝ Positive: 0.92, Negative: 0.08

2. I waited an hour for my order, and it was still wrong.
   ➝ Positive: 0.01, Negative: 0.99

3. This new phone update is a complete disaster.
   ➝ Positive: 0.00, Negative: 1.00

4. I'm absolutely thrilled with my new laptop!
   ➝ Positive: 0.98, Negative: 0.02

5. The food was bland and overpriced, not coming back.
   ➝ Positive: 0.00, Negative: 1.00

6. Best vacation spot ever, can't wait to return!
   ➝ Positive: 0.99, Negative: 0.01

7. The product broke after just two uses, very disappointed.
   ➝ Positive: 0.00, Negative: 1.00

8. Excellent book, well-written and engaging.
   ➝ Positive: 1.00, Negative: 0.00

9. Movie was predictable and boring, nothing special.
   ➝ Positive: 0.00, Negative: 1.00

10. Customer support was extremely helpful and quick to respond.
   ➝ Positive: 0.98, Negative: 0.02

11. I regret purchasing this item, waste of money.
   ➝ Positive: 0.00, Negative: 1.00

12. O

In [17]:
wandb.finish()

0,1
batch_size,▆▅▁▃▁▃▁▅▇██████▇▇▇▆▆▇▇▇████▇██▆▇██▂█████
epoch,▃▃▃▁▃▄▄▃▃▁▁▂▄▇▅▃▄▁▆▂▂▃▄▅▃▇█▁▃▇▄▄▂▁▂▄▁█▁▇
inner_lr,▁▂▂▁▂▄█▅▅▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▃▁▁▁▁▁▁▁
inner_steps,█▆▁▁▁▃██▁▁▃▁▁▁▃▁▃▃▆▁▁▁▁▁▃▁▁▃▃▃▅▅▅▅▁▁▁▁▁▁
meta_loss,▂▂▁▂▂▁▁▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▇▆█▁▂▂▁▂▂▁▂▁▁▁▂▂▁▁
meta_lr,▂▂▁▂▂▂▂▂▂▂▃▃▃▃▃▅▃▃▁▁▃▃▄▄▄▅▆▄▄▂███▂▅▅▅▆▆▄
new_task_loss,▁
num_query,▃▃▆▂▂▆▅▅▆▆▇▇▁▂▂▁▃▃▄▄▂█▃▁▁▂▂▁▁▁▄▄▂▃▃▂▁▁▂▂
num_support,█▅▅▃▅▅▄▅▅▃▁▁▁▁▁▂▂▇▂▂▂▁▃▂▂▁▂▂▂▂▃▄▁▁▁▄▁▁▂▁
sentiment140_accuracy,▁

0,1
batch_size,8.0
epoch,10.0
inner_lr,2e-05
inner_steps,1.0
meta_loss,6.70054
meta_lr,0.00041
new_task_loss,0.03112
num_query,2.0
num_support,1.0
sentiment140_accuracy,0.76707
