In [1]:
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from torch.utils.data import Dataset
from torch import nn
from torch.optim import Adam
from datasets import load_dataset
import random
from copy import deepcopy
import wandb


In [2]:
# Initialize W&B
wandb.init(
    project="few-shot-yelp",
    name="maml-yelp-few-shot",
    config={
        "num_support": 5,
        "num_query": 5,
        "meta_epochs": 3,
        "batch_size": 2,
        "inner_steps": 1,
        "meta_lr": 1e-4,
        "inner_lr": 1e-5
    }
)

wandb: Currently logged in as: kostic-stojan23 (kostic-stojan23-university-of-belgrade). Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [3]:
# Load fine-tuned model and tokenizer
model_path = "NLP_VER1"
base_model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [4]:
# Load Yelp Polarity Dataset
dataset = load_dataset("yelp_polarity")
train_data = dataset["train"]
test_data = dataset["test"]

In [5]:
# Few-Shot Dataset for Yelp
class FewShotDataset(Dataset):
    def __init__(self, data, num_support, num_query):
        self.data = data
        self.num_support = num_support
        self.num_query = num_query

    def get_task(self):
        indices = list(range(len(self.data)))
        random.shuffle(indices)

        support_indices = indices[:self.num_support]
        query_indices = indices[self.num_support:self.num_support + self.num_query]

        support_set = [(self.data[i]['text'], self.data[i]['label']) for i in support_indices]
        query_set = [(self.data[i]['text'], self.data[i]['label']) for i in query_indices]

        return support_set, query_set

In [6]:
# Inner loop: Task adaptation
def inner_loop(model, support_set, num_steps=1, lr=1e-5):
    task_model = deepcopy(model)
    optimizer = Adam(task_model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    for _ in range(num_steps):
        for text, label in support_set:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            labels = torch.tensor([label])

            outputs = task_model(**inputs, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return task_model

In [7]:
# Outer loop: Meta-training
def outer_loop(meta_model, tasks, meta_optimizer, num_inner_steps=1):
    meta_optimizer.zero_grad()
    total_loss = 0

    for support_set, query_set in tasks:
        task_model = inner_loop(meta_model, support_set, num_steps=num_inner_steps)

        for text, label in query_set:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            labels = torch.tensor([label])

            outputs = task_model(**inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss

    total_loss.backward()
    meta_optimizer.step()

    return total_loss.item()

In [8]:
# Create Few-Shot Dataset
num_support = wandb.config.num_support
num_query = wandb.config.num_query
few_shot_dataset = FewShotDataset(train_data, num_support, num_query)

# Define meta-optimizer
meta_optimizer = Adam(base_model.parameters(), lr=wandb.config.meta_lr)

# Meta-training loop
meta_epochs = wandb.config.meta_epochs
batch_size = wandb.config.batch_size
inner_steps = wandb.config.inner_steps

for epoch in range(meta_epochs):
    tasks = [few_shot_dataset.get_task() for _ in range(batch_size)]
    loss = outer_loop(base_model, tasks, meta_optimizer, num_inner_steps=inner_steps)

    wandb.log({"epoch": epoch + 1, "meta_loss": loss})
    print(f"Epoch {epoch + 1}, Meta Loss: {loss:.4f}")

Epoch 1, Meta Loss: 5.7888
Epoch 2, Meta Loss: 1.0674
Epoch 3, Meta Loss: 1.5546


In [9]:
# Evaluate on a new Few-Shot task
new_task_data = FewShotDataset(test_data, num_support=5, num_query=5)
new_support_set, new_query_set = new_task_data.get_task()

adapted_model = inner_loop(base_model, new_support_set, num_steps=inner_steps)
total_loss = 0

for text, label in new_query_set:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    labels = torch.tensor([label])

    outputs = adapted_model(**inputs, labels=labels)
    loss = outputs.loss
    total_loss += loss

wandb.log({"new_task_loss": total_loss.item()})
print(f"New Task Evaluation Loss: {total_loss.item():.4f}")

# Finish W&B logging
wandb.finish()

New Task Evaluation Loss: 0.2210


0,1
epoch,▁▅█
meta_loss,█▁▂
new_task_loss,▁

0,1
epoch,3.0
meta_loss,1.55464
new_task_loss,0.22104
