In [27]:
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from torch.utils.data import Dataset
from torch import nn
from torch.optim import Adam
from datasets import load_dataset
import random
from copy import deepcopy
from torch.utils.data import DataLoader
import wandb

In [28]:
# Initialize W&B
wandb.init(
    project="maml-fewshot-multiset",
    name="maml-multiset-manual",
)

In [29]:
def load_and_normalize(dataset_name):
    dataset = load_dataset(dataset_name)

    if dataset_name == "amazon_polarity":
        def process(example):
            return {"text": example["content"], "label": example["label"]}
    elif dataset_name == "yelp_polarity":
        def process(example):
            return {"text": example["text"], "label": example["label"]}
    elif dataset_name == "sentiment140":
        def process(example):
            label = 0 if example["sentiment"] == 0 else 1  # Convert (0,4) to (0,1)
            return {"text": example["text"], "label": label}

    dataset["train"] = dataset["train"].map(process)
    dataset["test"] = dataset["test"].map(process)

    return dataset["train"], dataset["test"]

In [30]:
dataset_names = ["amazon_polarity", "yelp_polarity", "sentiment140"]
train_data, test_data = {}, {}

for name in dataset_names:
    train_data[name], test_data[name] = load_and_normalize(name)

In [31]:
class FewShotDataset(Dataset):
    def __init__(self, data, num_support, num_query):
        self.data = data
        self.num_support = num_support
        self.num_query = num_query

    def get_task(self):
        indices = list(range(len(self.data)))
        random.shuffle(indices)

        support_indices = indices[:self.num_support]
        query_indices = indices[self.num_support:self.num_support + self.num_query]

        support_set = [(self.data[i]['text'], self.data[i]['label']) for i in support_indices]
        query_set = [(self.data[i]['text'], self.data[i]['label']) for i in query_indices]

        return support_set, query_set

In [32]:
model_path = "NLP_VER1"
base_model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [33]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model.to(device)

def inner_loop(model, support_set, num_steps=1, lr=1e-5):
    task_model = deepcopy(model).to(device)
    optimizer = torch.optim.Adam(task_model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    for _ in range(num_steps):
        for text, label in support_set:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
            labels = torch.tensor([label]).unsqueeze(0).to(device)  # Move labels to GPU

            outputs = task_model(**inputs, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return task_model

In [34]:
def outer_loop(meta_model, tasks, meta_optimizer, num_inner_steps=1, lr = 1e-5):
    meta_optimizer.zero_grad()
    total_loss = 0

    for support_set, query_set in tasks:
        task_model = inner_loop(meta_model, support_set, num_steps=num_inner_steps, lr=lr)

        for text, label in query_set:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
            labels = torch.tensor([label]).unsqueeze(0).to(device)

            outputs = task_model(**inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss

    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(meta_model.parameters(), max_norm=0.5)
    meta_optimizer.step()

    return total_loss.item()

In [35]:
meta_lr = 1e-5
inner_lr = 1e-6
num_support = 5
num_query = 5
inner_steps = 5
batch_size = 4
meta_epochs = 20

In [36]:
few_shot_datasets = {
    name: FewShotDataset(train_data[name], num_support, num_query)
    for name in dataset_names
}

meta_optimizer = Adam(base_model.parameters(), lr=meta_lr)

for epoch in range(meta_epochs):
    tasks = []
    for name in dataset_names:
        tasks.extend([few_shot_datasets[name].get_task() for _ in range(batch_size)])

    loss = outer_loop(base_model, tasks, meta_optimizer, num_inner_steps=inner_steps, lr=inner_lr)

    wandb.log({
        "epoch": epoch + 1,
        "meta_loss": loss,
        "meta_lr": meta_lr,
        "inner_lr": inner_lr,
        "num_support": num_support,
        "num_query": num_query,
        "inner_steps": inner_steps,
        "batch_size": batch_size,
    })

    print(f"Epoch {epoch + 1}, Meta Loss: {loss:.4f}")

Epoch 1, Meta Loss: 39.6423
Epoch 2, Meta Loss: 37.4839
Epoch 3, Meta Loss: 42.3168
Epoch 4, Meta Loss: 21.7572
Epoch 5, Meta Loss: 30.6382
Epoch 6, Meta Loss: 32.8623
Epoch 7, Meta Loss: 31.2697
Epoch 8, Meta Loss: 30.9950
Epoch 9, Meta Loss: 17.8127
Epoch 10, Meta Loss: 23.9224
Epoch 11, Meta Loss: 21.5387
Epoch 12, Meta Loss: 23.4956
Epoch 13, Meta Loss: 19.8614
Epoch 14, Meta Loss: 32.4785
Epoch 15, Meta Loss: 22.6566
Epoch 16, Meta Loss: 28.3319
Epoch 17, Meta Loss: 34.7042
Epoch 18, Meta Loss: 19.4609
Epoch 19, Meta Loss: 27.8355
Epoch 20, Meta Loss: 20.2789


In [37]:
for dataset_name in dataset_names:
    new_task_data = FewShotDataset(test_data[dataset_name], num_support, num_query)
    new_support_set, new_query_set = new_task_data.get_task()

    adapted_model = inner_loop(base_model, new_support_set, num_steps=inner_steps)
    adapted_model = adapted_model.to(device)
    total_loss = 0

    for text, label in new_query_set:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
        labels = torch.tensor([label]).unsqueeze(0).to(device)

        outputs = adapted_model(**inputs, labels=labels)
        loss = outputs.loss
        total_loss += loss

    wandb.log({f"{dataset_name}_task_loss": total_loss.item()})
    print(f"Final Evaluation Loss on {dataset_name}: {total_loss.item():.4f}")

Final Evaluation Loss on amazon_polarity: 0.1861
Final Evaluation Loss on yelp_polarity: 0.2266
Final Evaluation Loss on sentiment140: 4.6426


In [38]:
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# 20 unseen sentences for testing
test_sentences = [
    "The customer service was beyond my expectations!",
    "I waited an hour for my order, and it was still wrong.",
    "This new phone update is a complete disaster.",
    "I'm absolutely thrilled with my new laptop!",
    "The food was bland and overpriced, not coming back.",
    "Best vacation spot ever, can't wait to return!",
    "The product broke after just two uses, very disappointed.",
    "Excellent book, well-written and engaging.",
    "Movie was predictable and boring, nothing special.",
    "Customer support was extremely helpful and quick to respond.",
    "I regret purchasing this item, waste of money.",
    "One of the best restaurants in town, highly recommended!",
    "The new policy changes are frustrating and unnecessary.",
    "I love how comfortable and stylish these shoes are!",
    "The concert was an unforgettable experience.",
    "The software crashes frequently, making it unusable.",
    "Brilliant storytelling, kept me hooked from start to finish.",
    "Shipping took forever, and the package arrived damaged.",
    "Great workout program, helped me get in shape quickly.",
    "The app's interface is confusing and hard to navigate."
]

inputs = tokenizer(test_sentences, padding=True, truncation=True, return_tensors="pt", max_length=512)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

inputs = {key: value.to(device) for key, value in inputs.items()}

with torch.no_grad():
    outputs = adapted_model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=1)

if isinstance(predictions, torch.Tensor):
    predictions = predictions.cpu().numpy()
predictions = predictions.tolist()

import torch.nn.functional as F

probs = F.softmax(logits, dim=1).cpu().numpy()
for i, (sentence, prob) in enumerate(zip(test_sentences, probs)):
    print(f"{i+1}. {sentence}\n   ➝ Positive: {prob[1]:.2f}, Negative: {prob[0]:.2f}\n")


1. The customer service was beyond my expectations!
   ➝ Positive: 0.98, Negative: 0.02

2. I waited an hour for my order, and it was still wrong.
   ➝ Positive: 0.02, Negative: 0.98

3. This new phone update is a complete disaster.
   ➝ Positive: 0.00, Negative: 1.00

4. I'm absolutely thrilled with my new laptop!
   ➝ Positive: 0.99, Negative: 0.01

5. The food was bland and overpriced, not coming back.
   ➝ Positive: 0.00, Negative: 1.00

6. Best vacation spot ever, can't wait to return!
   ➝ Positive: 0.99, Negative: 0.01

7. The product broke after just two uses, very disappointed.
   ➝ Positive: 0.00, Negative: 1.00

8. Excellent book, well-written and engaging.
   ➝ Positive: 1.00, Negative: 0.00

9. Movie was predictable and boring, nothing special.
   ➝ Positive: 0.00, Negative: 1.00

10. Customer support was extremely helpful and quick to respond.
   ➝ Positive: 0.99, Negative: 0.01

11. I regret purchasing this item, waste of money.
   ➝ Positive: 0.00, Negative: 1.00

12. O

In [39]:
from sklearn.metrics import accuracy_score, classification_report
def evaluate_model_on_all_datasets(model, test_data):
    model.eval()
    all_preds = []
    all_labels = []

    for dataset_name in ["sentiment140", "amazon_polarity", "yelp_polarity"]:
        encodings = tokenizer(test_data[dataset_name]["text"], padding=True, truncation=True, return_tensors="pt", max_length=512)

        test_dataset = torch.utils.data.TensorDataset(encodings.input_ids, encodings.attention_mask, torch.tensor(test_data[dataset_name]["label"]))
        test_loader = DataLoader(test_dataset, batch_size=16)

        with torch.no_grad():
            for batch in test_loader:
                input_ids, attention_mask, labels = batch
                input_ids = input_ids.to(model.device)
                attention_mask = attention_mask.to(model.device)
                labels = labels.to(model.device)

                outputs = model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                preds = torch.argmax(logits, dim=1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        accuracy = accuracy_score(all_labels, all_preds)
        wandb.log({f"{dataset_name}_accuracy": accuracy})

        print(f"Evaluation on {dataset_name}: Accuracy = {accuracy:.4f}")
        print(f"Classification Report for {dataset_name}:")
        print(classification_report(all_labels, all_preds))

evaluate_model_on_all_datasets(adapted_model, test_data)

Evaluation on sentiment140: Accuracy = 0.7952
Classification Report for sentiment140:
              precision    recall  f1-score   support

           0       0.80      0.57      0.66       177
           1       0.80      0.92      0.85       321

    accuracy                           0.80       498
   macro avg       0.80      0.74      0.76       498
weighted avg       0.80      0.80      0.79       498

Evaluation on amazon_polarity: Accuracy = 0.8728
Classification Report for amazon_polarity:
              precision    recall  f1-score   support

           0       0.95      0.79      0.86    200177
           1       0.82      0.96      0.88    200321

    accuracy                           0.87    400498
   macro avg       0.88      0.87      0.87    400498
weighted avg       0.88      0.87      0.87    400498

Evaluation on yelp_polarity: Accuracy = 0.8758
Classification Report for yelp_polarity:
              precision    recall  f1-score   support

           0       0.95  

In [40]:
wandb.finish()

0,1
amazon_polarity_accuracy,▁
amazon_polarity_task_loss,█▁
batch_size,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▃▃▃▃▃▃▄▄▅▁▁▁▂▂▂▂▃▄▄▄▄▄▄▅▅▆▆▆▇▇██▁▂▂▂▂
inner_lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
inner_steps,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█████
meta_loss,█▂▁▂▁▁▁▂▁▂▂▂▂▂▁▂▂▂▁▂▂▂▂▂▁▂▁▁▂▂▂▂▂▂▁▁▂▂▂▁
meta_lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
num_query,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
num_support,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
amazon_polarity_accuracy,0.87276
amazon_polarity_task_loss,0.18609
batch_size,4.0
epoch,20.0
inner_lr,0.0
inner_steps,5.0
meta_loss,20.27886
meta_lr,1e-05
num_query,5.0
num_support,5.0


In [42]:
output_dir = "./NLP_MAML_model"
adapted_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

('./NLP_MAML_model\\tokenizer_config.json',
 './NLP_MAML_model\\special_tokens_map.json',
 './NLP_MAML_model\\vocab.txt',
 './NLP_MAML_model\\added_tokens.json')