In [8]:
import json
import os
import random
from collections import Counter, defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertModel, BertTokenizer
print("Transformers version:", transformers.__version__)

Transformers version: 4.53.2


# Data Preparation

In [9]:


# Data Preparation
class FewshotDataset(Dataset):
    """Episodic N-way / K-shot dataset built from a ``{class_name: [text, ...]}`` JSON file.

    Every class is tokenized once in ``__init__``; each ``__getitem__`` call then
    draws a fresh random episode (the index argument is ignored).

    Args:
        data_path: Path to a JSON object mapping class name -> list of texts.
        n_way: Number of classes sampled per episode.
        k_shot: Support examples per class.
        q_queries: Query examples per class.
        max_len: Length every text is padded/truncated to by the tokenizer.
        tokenizer: Callable tokenizer; defaults to ``bert-base-uncased``.
        episodes_per_epoch: Value reported by ``len()``. Defaults to the number
            of classes.

    Raises:
        ValueError: If ``n_way`` is not between 1 and the number of classes, or
            if any class has no examples.
    """

    def __init__(self, data_path, n_way=5, k_shot=5, q_queries=5, max_len=50, tokenizer=None,
                 episodes_per_epoch=None):
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_queries = q_queries
        self.max_len = max_len

        with open(data_path, encoding='utf-8') as f:
            self.raw_data = json.load(f)
        self.classes = list(self.raw_data.keys())

        # Fail here with a clear message instead of later inside random.sample /
        # torch.randint while a DataLoader is iterating.
        if not 1 <= n_way <= len(self.classes):
            raise ValueError(
                f"n_way={n_way} must be between 1 and the number of classes ({len(self.classes)})."
            )
        empty_classes = [cls for cls in self.classes if len(self.raw_data[cls]) == 0]
        if empty_classes:
            raise ValueError(f"Classes without any examples: {empty_classes}")

        self.episodes_per_epoch = len(self.classes) if episodes_per_epoch is None else episodes_per_epoch

        self.tokenizer = tokenizer if tokenizer else BertTokenizer.from_pretrained('bert-base-uncased')
        self.processed_data = {}
        for cls in self.classes:
            encoded = self.tokenizer(
                self.raw_data[cls],
                padding='max_length',
                max_length=max_len,
                truncation=True,
                return_tensors='pt'
            )
            self.processed_data[cls] = {
                'input_ids': encoded['input_ids'],
                'attention_mask': encoded['attention_mask'],
            }

    def __len__(self):
        """Number of episodes served per pass over the dataset."""
        return self.episodes_per_epoch

    def __getitem__(self, idx):
        """Sample one random episode; ``idx`` is ignored.

        Returns a dict of tensors. Support tensors have ``n_way * k_shot`` rows and
        query tensors ``n_way * q_queries`` rows; labels are episode-local indices
        ``0 .. n_way - 1`` in the order the classes were drawn.
        """
        selected_classes = random.sample(self.classes, self.n_way)
        support_input_ids = []
        support_attention_mask = []
        support_labels = []
        query_input_ids = []
        query_attention_mask = []
        query_labels = []

        total_needed = self.k_shot + self.q_queries

        for cls_idx, cls in enumerate(selected_classes):
            all_input_ids = self.processed_data[cls]['input_ids']
            all_attention_mask = self.processed_data[cls]['attention_mask']
            num_samples = len(all_input_ids)

            if num_samples < total_needed:
                # Too few examples: draw with replacement. The same text can then
                # land in both the support and the query split of this episode.
                selected = torch.randint(0, num_samples, (total_needed,))
            else:
                # Enough examples: support and query are disjoint.
                selected = torch.randperm(num_samples)[:total_needed]

            support_idx = selected[:self.k_shot]
            query_idx = selected[self.k_shot:]

            support_input_ids.append(all_input_ids[support_idx])
            support_attention_mask.append(all_attention_mask[support_idx])
            support_labels.extend([cls_idx] * self.k_shot)

            query_input_ids.append(all_input_ids[query_idx])
            query_attention_mask.append(all_attention_mask[query_idx])
            query_labels.extend([cls_idx] * self.q_queries)

        return {
            'support_input_ids': torch.cat(support_input_ids),
            'support_attention_mask': torch.cat(support_attention_mask),
            'support_labels': torch.tensor(support_labels),
            'query_input_ids': torch.cat(query_input_ids),
            'query_attention_mask': torch.cat(query_attention_mask),
            'query_labels': torch.tensor(query_labels)
        }

# Model Architecture (Prototypical Network with BERT)

In [10]:
class ProtoNet(nn.Module):
    """Prototypical network whose encoder is a pretrained BERT model.

    Texts are embedded as the hidden state at sequence position 0; a class
    prototype is the mean embedding of its support examples, and queries are
    scored by negative squared Euclidean distance to each prototype.
    """

    def __init__(self, bert_model = 'bert-base-uncased', hidden_size = 768):
        super().__init__()
        # No parameters are frozen here, so the whole encoder is trainable.
        self.bert = BertModel.from_pretrained(bert_model)
        self.hidden_size = hidden_size

    def forward(self, input_ids, attention_mask):
        """Return the position-0 hidden state for every input sequence."""
        encoder_out = self.bert(input_ids=input_ids, attention_mask = attention_mask)
        first_token_states = encoder_out.last_hidden_state[:, 0, :]
        return first_token_states

    def compute_prototypes(self, support_embeddings, support_labels):
        """Average the support embeddings of each label.

        Rows of the result follow the order of ``torch.unique(support_labels)``.
        """
        flat_labels = support_labels.view(-1)
        # Collapse any leading dimensions so rows line up with the flat labels.
        flat_embeddings = support_embeddings.view(-1, support_embeddings.size(-1))

        class_means = [
            flat_embeddings[(flat_labels == label).nonzero(as_tuple=True)[0]].mean(dim=0)
            for label in torch.unique(flat_labels)
        ]
        return torch.stack(class_means)

    def compute_distances(self, prototypes, query_embeddings):
        """Return ``-||query - prototype||^2`` with shape (n_queries, n_prototypes)."""
        # Broadcast (n_queries, 1, dim) against (1, n_prototypes, dim).
        offsets = query_embeddings.unsqueeze(1) - prototypes.unsqueeze(0)
        squared_distances = (offsets ** 2).sum(dim=-1)
        return -squared_distances

# Training Process

In [11]:

# Training Process
def train_meta_learner(data_path, epochs = 10, batch_size=1, n_ways=5, k_shot=5, q_queries=5, use_gpu=True,
                       lr=1e-4, max_grad_norm=1.0):
    """Train a ``ProtoNet`` on randomly sampled few-shot episodes.

    Args:
        data_path: JSON file passed to ``FewshotDataset``.
        epochs: Number of passes over the dataloader.
        batch_size: Episodes per optimizer step.
        n_ways: Classes per episode.
        k_shot: Support examples per class.
        q_queries: Query examples per class.
        use_gpu: Use CUDA when it is available.
        lr: Adam learning rate.
        max_grad_norm: Gradient-norm clipping threshold.

    Returns:
        The trained ``ProtoNet``.

    The accuracy printed per epoch is measured on the query sets of the
    training episodes themselves; it is not a held-out metric.
    """
    dataset = FewshotDataset(data_path=data_path, n_way=n_ways, k_shot=k_shot, q_queries=q_queries)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
    print(f"Using device: {device}")

    model = ProtoNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0

        for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}'):
            batch = {key: value.to(device) for key, value in batch.items()}

            # The final batch can hold fewer than `batch_size` episodes, so size
            # the loop and the averages from the tensors themselves.
            n_episodes = batch['query_labels'].size(0)

            batch_loss = 0
            batch_correct = 0

            # Each episode has its own label space, so episodes are scored one at a time.
            for i in range(n_episodes):
                episode_query_labels = batch['query_labels'][i]

                support_embeddings = model(
                    input_ids=batch['support_input_ids'][i],
                    attention_mask=batch['support_attention_mask'][i]
                )
                query_embeddings = model(
                    input_ids=batch['query_input_ids'][i],
                    attention_mask=batch['query_attention_mask'][i]
                )

                prototypes = model.compute_prototypes(support_embeddings, batch['support_labels'][i])
                logits = model.compute_distances(prototypes, query_embeddings)

                batch_loss += criterion(logits, episode_query_labels)

                pred = torch.argmax(logits, dim=1)
                batch_correct += (pred == episode_query_labels).sum().item()

            # Mean loss over the episodes actually present in this batch.
            loss = batch_loss / n_episodes
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            total_loss += loss.item()
            total_correct += batch_correct
            total_samples += batch['query_labels'].numel()

            # Return cached, unused blocks to the driver once per step.
            if device.type == 'cuda':
                torch.cuda.empty_cache()

        avg_loss = total_loss / len(dataloader)
        avg_accuracy = total_correct / total_samples
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}')
        print(f'Total Samples: {total_samples}, Total Correct: {total_correct}')
        print(f'Total Loss: {total_loss:.4f}')
        print()

    return model

# Real-Time Inference System

In [12]:
def update_support_set(self, new_texts, new_labels):
    """Append labelled examples to the classifier's support set.

    NOTE(review): this is written as a method (it takes ``self``) but is defined
    at module level in this cell; presumably it belongs to the classifier class
    used later in the notebook — confirm.

    Reads ``self.tokenizer``, ``self.max_len``, ``self.device`` and updates
    ``self.current_classes``, ``self.support_input_ids``,
    ``self.support_attention_mask`` and ``self.support_labels``.

    Args:
        new_texts: Sequence of strings to add.
        new_labels: Sequence of class names, one per text. Unseen names are
            appended to ``self.current_classes`` in first-seen order.

    Raises:
        ValueError: If ``new_texts`` and ``new_labels`` differ in length.
    """
    # Validate before touching any state so a bad call cannot leave the text
    # tensors and the label tensor with different lengths.
    if len(new_texts) != len(new_labels):
        raise ValueError(
            f"Got {len(new_texts)} texts but {len(new_labels)} labels; they must match one-to-one."
        )
    if len(new_texts) == 0:
        return  # nothing to add

    encoded = self.tokenizer(
        new_texts, padding='max_length',
        max_length=self.max_len, truncation=True,
        return_tensors='pt'
    )

    # Register unseen class names, preserving first-seen order.
    if self.current_classes is None:
        self.current_classes = []
    for label in new_labels:
        if label not in self.current_classes:
            self.current_classes.append(label)

    # A class index is its position in current_classes.
    label_tensor = torch.tensor([self.current_classes.index(label) for label in new_labels])

    new_input_ids = encoded['input_ids'].to(self.device)
    new_attention_mask = encoded['attention_mask'].to(self.device)
    new_label_tensor = label_tensor.to(self.device)

    if self.support_input_ids is None:
        self.support_input_ids = new_input_ids
        self.support_attention_mask = new_attention_mask
        self.support_labels = new_label_tensor
    else:
        self.support_input_ids = torch.cat([self.support_input_ids, new_input_ids])
        self.support_attention_mask = torch.cat([self.support_attention_mask, new_attention_mask])
        self.support_labels = torch.cat([self.support_labels, new_label_tensor])

# Example Usage

In [13]:
# Toy intent corpus: five classes with twelve short utterances each.
example_data = {
    "greeting": [
        "hello", "hi there", "good morning", "hey", "howdy", "greetings",
        "hello there", "hi", "good day", "morning", "afternoon", "evening",
    ],
    "goodbye": [
        "bye", "see you later", "goodbye", "farewell", "take care", "see ya",
        "later", "bye bye", "good night", "until next time", "catch you later", "peace out",
    ],
    "question": [
        "what's up?", "how are you?", "what's new?", "how's it going?",
        "how are things?", "what's happening?", "how's life?", "what's going on?",
        "how have you been?", "what's the news?", "how are you doing?", "what's the story?",
    ],
    "purchase": [
        "I want to buy", "I'd like to purchase", "can I get", "I need to order",
        "I want to order", "can I buy", "I'd like to get", "I need to purchase",
        "I want to get", "can I order", "I'd like to buy", "I need to get",
    ],
    "complaint": [
        "this is broken", "I'm not happy", "this doesn't work", "poor quality",
        "this is defective", "I'm disappointed", "this is faulty", "not satisfied",
        "this is damaged", "terrible service", "bad product", "doesn't function",
    ],
}

In [14]:


# Write the toy corpus to disk; the training cell loads it by this path.
serialized_intents = json.dumps(example_data)
with open('intent_data.json', 'w') as f:
    f.write(serialized_intents)

In [15]:


# Tokenizer handed to the real-time classifier further down.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)

print("Training meta-learner...")


Training meta-learner...


In [16]:

# Try the GPU first; if CUDA runs out of memory, retry the same run on the CPU.
TRAIN_KWARGS = dict(epochs=5, batch_size=1)

gpu_out_of_memory = False
try:
    model = train_meta_learner('intent_data.json', use_gpu=True, **TRAIN_KWARGS)
except RuntimeError as e:
    if "out of memory" not in str(e):
        raise  # bare raise keeps the original traceback
    gpu_out_of_memory = True

# Retry outside the except block: while the handler is active the exception still
# references the failed run's stack frames, which keeps its tensors alive.
if gpu_out_of_memory:
    print("GPU out of memory, falling back to CPU...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model = train_meta_learner('intent_data.json', use_gpu=False, **TRAIN_KWARGS)

print("Training complete!")

Using device: cuda


Epoch 1/5: 100%|██████████| 5/5 [00:17<00:00,  3.50s/it]
Epoch 1/5: 100%|██████████| 5/5 [00:17<00:00,  3.50s/it]


Epoch 1/5, Loss: 3.1480, Accuracy: 0.8160
Total Samples: 125, Total Correct: 102
Total Loss: 15.7402



Epoch 2/5: 100%|██████████| 5/5 [00:17<00:00,  3.49s/it]
Epoch 2/5: 100%|██████████| 5/5 [00:17<00:00,  3.49s/it]


Epoch 2/5, Loss: 0.2951, Accuracy: 0.9600
Total Samples: 125, Total Correct: 120
Total Loss: 1.4753



Epoch 3/5: 100%|██████████| 5/5 [00:15<00:00,  3.09s/it]
Epoch 3/5: 100%|██████████| 5/5 [00:15<00:00,  3.09s/it]


Epoch 3/5, Loss: 0.3175, Accuracy: 0.9760
Total Samples: 125, Total Correct: 122
Total Loss: 1.5876



Epoch 4/5: 100%|██████████| 5/5 [00:15<00:00,  3.05s/it]
Epoch 4/5: 100%|██████████| 5/5 [00:15<00:00,  3.05s/it]


Epoch 4/5, Loss: 0.0258, Accuracy: 0.9920
Total Samples: 125, Total Correct: 124
Total Loss: 0.1289



Epoch 5/5: 100%|██████████| 5/5 [00:17<00:00,  3.57s/it]

Epoch 5/5, Loss: 0.0013, Accuracy: 1.0000
Total Samples: 125, Total Correct: 125
Total Loss: 0.0063

Training complete!





In [19]:


# Build the real-time classifier from the trained model and the tokenizer created above.
# NOTE(review): RealTimeFewShotClassifier is not defined in any cell shown here
# (execution counts jump from In [16] to In [19]); confirm that its defining cell
# is present and runs before this one on a fresh kernel.
classifier = RealTimeFewShotClassifier(model, tokenizer)

In [20]:
banner = "=" * 60
print(banner)
print("REAL-TIME FEW-SHOT CLASSIFIER DEMO")
print(banner)

# 1. Seed the support set with three examples for each of three intents.
print("\n1. Setting up initial support set...")
initial_support = [
    ("greeting", ["hello there", "good morning", "hi"]),
    ("goodbye", ["goodbye", "see you later", "bye"]),
    ("question", ["how are you?", "what's up?", "how's it going?"]),
]
# Flatten into the parallel text/label lists that update_support_set expects.
initial_texts = [text for _, texts in initial_support for text in texts]
initial_labels = [label for label, texts in initial_support for _ in texts]

classifier.update_support_set(initial_texts, initial_labels)
support_size = len(classifier.support_labels)
print(f"Current classes: {classifier.current_classes}")
print(f"Support set size: {support_size} examples")

REAL-TIME FEW-SHOT CLASSIFIER DEMO

1. Setting up initial support set...
Current classes: ['goodbye', 'greeting', 'question']
Support set size: 9 examples


In [22]:
# 2. Test initial predictions
print("\n2. Testing initial predictions...")
test_queries = [
    "hey there friend",
    "see ya later",
    "what's happening today?",
    "good evening"
]

# Take labels and probabilities from ONE call so they describe the same forward
# pass; two separate calls can disagree with each other.
# NOTE(review): assumes the first element returned with return_probs=True is the
# same label list that predict(queries) returns — confirm against the class.
predictions, probs = classifier.predict(test_queries, return_probs=True)

for query, pred, query_probs in zip(test_queries, predictions, probs):
    confidence = query_probs.max()
    print(f"Query: '{query}' -> Predicted: {pred} (Confidence: {confidence:.3f})")


2. Testing initial predictions...
Query: 'hey there friend' -> Predicted: greeting (Confidence: 1.000)
Query: 'see ya later' -> Predicted: goodbye (Confidence: 1.000)
Query: 'what's happening today?' -> Predicted: question (Confidence: 1.000)
Query: 'good evening' -> Predicted: greeting (Confidence: 1.000)
Query: 'hey there friend' -> Predicted: greeting (Confidence: 1.000)
Query: 'see ya later' -> Predicted: goodbye (Confidence: 1.000)
Query: 'what's happening today?' -> Predicted: question (Confidence: 1.000)
Query: 'good evening' -> Predicted: greeting (Confidence: 1.000)


In [23]:
# 3. Register a brand-new intent at runtime from a handful of examples.
print("\n3. Adding new class 'complaint' with examples...")
complaint_examples = [
    "this product is broken",
    "I'm not satisfied with the service",
    "this doesn't work properly",
    "poor quality item",
    "I want to return this",
]
classifier.add_new_class("complaint", complaint_examples)

updated_classes = classifier.current_classes
support_size = len(classifier.support_labels)
print(f"Updated classes: {updated_classes}")
print(f"New support set size: {support_size} examples")


3. Adding new class 'complaint' with examples...
Updated classes: ['goodbye', 'greeting', 'question', 'complaint']
New support set size: 14 examples


In [24]:
 #4. Test with the new class
print("\n4. Testing with new class included...")
new_test_queries = [
    "hello everyone",           # should be greeting
    "this is terrible quality", # should be complaint
    "farewell my friend",       # should be goodbye
    "how have you been?",       # should be question
    "this product is defective" # should be complaint
]

# Take labels and probabilities from ONE call so the printed label always matches
# the printed distribution; two separate calls can disagree with each other.
# NOTE(review): assumes the first element returned with return_probs=True is the
# same label list that predict(queries) returns — confirm against the class.
predictions, probs = classifier.predict(new_test_queries, return_probs=True)

for query, pred, query_probs in zip(new_test_queries, predictions, probs):
    confidence = query_probs.max()
    prob_dist = dict(zip(classifier.current_classes, query_probs))
    print(f"Query: '{query}'")
    print(f"  -> Predicted: {pred} (Confidence: {confidence:.3f})")
    print(f"  -> Full distribution: {prob_dist}")
    print()
    print()


4. Testing with new class included...
Query: 'hello everyone'
  -> Predicted: greeting (Confidence: 1.000)
  -> Full distribution: {'goodbye': np.float32(1.14958495e-29), 'greeting': np.float32(1.0), 'question': np.float32(0.0), 'complaint': np.float32(2.587118e-11)}

Query: 'this is terrible quality'
  -> Predicted: complaint (Confidence: 1.000)
  -> Full distribution: {'goodbye': np.float32(0.0), 'greeting': np.float32(3.605713e-24), 'question': np.float32(0.0), 'complaint': np.float32(1.0)}

Query: 'farewell my friend'
  -> Predicted: greeting (Confidence: 1.000)
  -> Full distribution: {'goodbye': np.float32(0.9999968), 'greeting': np.float32(3.2066764e-06), 'question': np.float32(0.0), 'complaint': np.float32(1.8438376e-12)}

Query: 'how have you been?'
  -> Predicted: question (Confidence: 1.000)
  -> Full distribution: {'goodbye': np.float32(0.0), 'greeting': np.float32(0.0), 'question': np.float32(1.0), 'complaint': np.float32(0.0)}

Query: 'this product is defective'
  -> Pre

In [25]:
# 5. Grow two classes that already exist, without introducing new ones.
print("5. Adding more examples to existing classes...")
more_greetings = ["good afternoon", "howdy partner", "greetings"]
more_questions = ["what's new with you?", "how are things going?"]

additions = (("greeting", more_greetings), ("question", more_questions))
for label, extra_texts in additions:
    classifier.update_support_set(extra_texts, [label] * len(extra_texts))

final_support_size = len(classifier.support_labels)
print(f"Final support set size: {final_support_size} examples")

5. Adding more examples to existing classes...
Final support set size: 19 examples


In [26]:
# 6. Register a second new intent: purchase.
print("\n6. Adding 'purchase' class...")
new_class_name = "purchase"
purchase_examples = [
    "I want to buy this item",
    "can I purchase this product",
    "I'd like to order something",
    "add this to my cart",
    "I need to get this",
]
classifier.add_new_class(new_class_name, purchase_examples)


6. Adding 'purchase' class...


In [27]:
# 7. Final comprehensive test
print("\n7. Final comprehensive test with all classes...")
final_test_queries = [
    "good day to you",              # greeting
    "until we meet again",          # goodbye
    "what's your opinion?",         # question
    "this is absolutely horrible",  # complaint
    "I want to buy three of these", # purchase
    "hey how's everything?",        # question/greeting - test ambiguous case
    "this service is amazing",      # might be tricky - not clearly any class
]

# Take labels and probabilities from ONE call so the reported label and its
# confidence describe the same forward pass.
# NOTE(review): assumes the first element returned with return_probs=True is the
# same label list that predict(queries) returns — confirm against the class.
predictions, probs = classifier.predict(final_test_queries, return_probs=True)

print(f"Final classes: {classifier.current_classes}")
print(f"Total support examples: {len(classifier.support_labels)}")
print("\nPrediction Results:")
print("-" * 50)

# Confidence bands used only for the status line below.
HIGH_CONFIDENCE = 0.7
MODERATE_CONFIDENCE = 0.5

for query, pred, query_probs in zip(final_test_queries, predictions, probs):
    confidence = query_probs.max()
    # Indices of the two largest probabilities, best first.
    top2_indices = query_probs.argsort()[-2:][::-1]
    top2_classes = [classifier.current_classes[idx] for idx in top2_indices]
    top2_probs = [query_probs[idx] for idx in top2_indices]

    print(f"Query: '{query}'")
    print(f"  Top prediction: {pred} ({confidence:.3f})")
    print(f"  Second choice: {top2_classes[1]} ({top2_probs[1]:.3f})")

    if confidence > HIGH_CONFIDENCE:
        print("  Status: HIGH CONFIDENCE ✓")
    elif confidence > MODERATE_CONFIDENCE:
        print("  Status: MODERATE CONFIDENCE ~")
    else:
        print("  Status: LOW CONFIDENCE ⚠")
    print()


7. Final comprehensive test with all classes...
Final classes: ['goodbye', 'greeting', 'question', 'complaint', 'purchase']
Total support examples: 24

Prediction Results:
--------------------------------------------------
Query: 'good day to you'
  Top prediction: greeting (1.000)
  Second choice: purchase (0.000)
  Status: HIGH CONFIDENCE ✓

Query: 'until we meet again'
  Top prediction: goodbye (1.000)
  Second choice: greeting (0.000)
  Status: HIGH CONFIDENCE ✓

Query: 'what's your opinion?'
  Top prediction: question (1.000)
  Second choice: purchase (0.000)
  Status: HIGH CONFIDENCE ✓

Query: 'this is absolutely horrible'
  Top prediction: complaint (1.000)
  Second choice: purchase (0.000)
  Status: HIGH CONFIDENCE ✓

Query: 'I want to buy three of these'
  Top prediction: purchase (1.000)
  Second choice: greeting (0.000)
  Status: HIGH CONFIDENCE ✓

Query: 'hey how's everything?'
  Top prediction: question (1.000)
  Second choice: greeting (0.000)
  Status: HIGH CONFIDENCE ✓

In [28]:
# 8. Demonstrate error handling: re-adding an existing class should be rejected.
print("8. Testing error handling...")
try:
    classifier.add_new_class("greeting", ["hello again"])
except ValueError as e:
    print(f"Caught expected error: {e}")
else:
    # Without this branch a missing duplicate check would pass silently and
    # leave the extra example in the support set unnoticed.
    print("WARNING: adding duplicate class 'greeting' did not raise ValueError.")


8. Testing error handling...
Caught expected error: Class greeting already exists in the support set.


In [29]:
# 9. Show current state summary
print("\n9. Current classifier state summary:")
print("=" * 40)
print(f"Total classes: {len(classifier.current_classes)}")
print(f"Classes: {classifier.current_classes}")
print(f"Total support examples: {len(classifier.support_labels)}")

# Count examples per class. `Counter` comes from the notebook's import cell;
# .tolist() yields plain Python ints, so the keys index current_classes directly.
label_counts = Counter(classifier.support_labels.cpu().tolist())
print("\nExamples per class:")
for class_idx, count in label_counts.items():
    class_name = classifier.current_classes[class_idx]
    print(f"  {class_name}: {count} examples")

print("\n" + "=" * 60)
print("DEMO COMPLETE - Classifier ready for real-time use!")
print("=" * 60)



9. Current classifier state summary:
Total classes: 5
Classes: ['goodbye', 'greeting', 'question', 'complaint', 'purchase']
Total support examples: 24

Examples per class:
  greeting: 6 examples
  goodbye: 3 examples
  question: 5 examples
  complaint: 5 examples
  purchase: 5 examples

DEMO COMPLETE - Classifier ready for real-time use!


In [30]:
# 10. Interactive prediction function (example)
def interactive_predict(text_input, clf=None):
    """Classify a single text and return the label with its probabilities.

    Args:
        text_input: The text to classify.
        clf: Classifier exposing ``predict(texts, return_probs=True)`` and
            ``current_classes``. Defaults to the notebook-level ``classifier``.

    Returns:
        dict with keys ``'input'``, ``'prediction'``, ``'confidence'`` (largest
        probability) and ``'all_probabilities'`` (class name -> probability).
    """
    if clf is None:
        clf = classifier

    # One call only: the label and the probabilities then come from the same
    # forward pass and cannot disagree.
    # NOTE(review): assumes the first element returned with return_probs=True is
    # the same label list that predict(texts) returns — confirm against the class.
    predictions, probs = clf.predict([text_input], return_probs=True)
    text_probs = probs[0]

    return {
        'input': text_input,
        'prediction': predictions[0],
        'confidence': text_probs.max(),
        'all_probabilities': dict(zip(clf.current_classes, text_probs)),
    }

# Run the helper on a few free-form inputs and print each result.
print("\n10. Interactive prediction examples:")
interactive_examples = [
    "I need help with something",
    "thanks for your help, goodbye",
    "this is not working at all",
]

for result in map(interactive_predict, interactive_examples):
    print(f"Input: '{result['input']}'")
    print(f"Prediction: {result['prediction']} (confidence: {result['confidence']:.3f})")
    print(f"All probabilities: {result['all_probabilities']}")
    print()


10. Interactive prediction examples:
Input: 'I need help with something'
Prediction: purchase (confidence: 0.989)
All probabilities: {'goodbye': np.float32(0.0), 'greeting': np.float32(2.230397e-09), 'question': np.float32(0.0), 'complaint': np.float32(0.010685488), 'purchase': np.float32(0.98931444)}

Input: 'I need help with something'
Prediction: purchase (confidence: 0.989)
All probabilities: {'goodbye': np.float32(0.0), 'greeting': np.float32(2.230397e-09), 'question': np.float32(0.0), 'complaint': np.float32(0.010685488), 'purchase': np.float32(0.98931444)}

Input: 'thanks for your help, goodbye'
Prediction: goodbye (confidence: 1.000)
All probabilities: {'goodbye': np.float32(1.0), 'greeting': np.float32(6.7234066e-13), 'question': np.float32(0.0), 'complaint': np.float32(3.8552064e-26), 'purchase': np.float32(4.8823356e-17)}

Input: 'thanks for your help, goodbye'
Prediction: goodbye (confidence: 1.000)
All probabilities: {'goodbye': np.float32(1.0), 'greeting': np.float32(6.7