<a href="https://colab.research.google.com/github/AbhirKarande/OCRandProductRecognition/blob/main/FewShotTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/AbhirKarande/OCRandProductRecognition.git

Cloning into 'OCRandProductRecognition'...
remote: Enumerating objects: 43, done.[K
remote: Counting objects: 100% (43/43), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 43 (delta 8), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (43/43), 5.77 MiB | 14.26 MiB/s, done.


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from PIL import Image

In [6]:
class PrototypicalNetwork(nn.Module):
    def __init__(self, num_classes, feature_dim):
        super(PrototypicalNetwork, self).__init__()
        self.encoder = nn.Linear(feature_dim, num_classes)
        
    def forward(self, support, query):
        prototypes = torch.mean(support, dim=1)
        logits = self.encoder(query)
        return logits, prototypes

In [3]:
!pip install keras

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [20]:
train_dir = '/content/OCRandProductRecognition/WholeFoodsTrainingImages/train'
batch_size = 2
image_size = (224, 224)
num_classes = 4
label_names = ["DavesKillerBread", "GothamGreensBasil", "LabellePatrimoineHeritageEggs", "WholeFoodsMarketBrownButterChocolateChunk"]
num_epochs = 10
num_classes = len(label_names)
feature_dim = 512
num_support = 2
num_query = 2
learning_rate = 0.001

In [21]:
# Define the few-shot learning model
class PrototypicalNetwork(nn.Module):
    def __init__(self, num_classes, feature_dim):
        super(PrototypicalNetwork, self).__init__()
        self.encoder = nn.Linear(feature_dim, num_classes)
        
    def forward(self, support, query):
        prototypes = torch.mean(support, dim=1)
        logits = self.encoder(query)
        return logits, prototypes

In [22]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [24]:
# Load the pre-trained CNN model
cnn = resnet18(pretrained=True)
cnn = nn.Sequential(*list(cnn.children())[:-1])  # Remove the last classification layer
cnn.to(device)

# Create the few-shot learning model
model = PrototypicalNetwork(num_classes, feature_dim)
model.to(device)

# Define the optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [27]:

import os
for epoch in range(num_epochs):
    # Iterate over each class
    for class_idx, label_name in enumerate(label_names):
        class_dir = os.path.join(train_dir, label_name)
        image_files = os.listdir(class_dir)
        support_images = []
        query_images = []
        
        # Load support and query images for the current class
        for i, image_file in enumerate(image_files):
            image_path = os.path.join(class_dir, image_file)
            image = Image.open(image_path).convert("RGB")
            image = transform(image).unsqueeze(0)
            if i < num_support:
                support_images.append(image)
            else:
                query_images.append(image)
        
        # Prepare support and query tensors
        support_images = torch.cat(support_images, dim=0).to(device)
        query_images = torch.cat(query_images, dim=0).to(device)
        
        # Extract features using the pre-trained CNN
        support_features = cnn(Variable(support_images)).squeeze()
        query_features = cnn(Variable(query_images)).squeeze()
        
        # Forward pass
        logits, prototypes = model(support_features, query_features)
        
        # Compute loss
        num_query_samples = query_features.size(0)  # Get the actual number of query samples
        labels = torch.tensor([class_idx] * num_query_samples).to(device)  # Adjust the labels tensor size
        loss = criterion(logits, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print loss
        print(f"Epoch [{epoch+1}/{num_epochs}], Class {label_name}, Loss: {loss.item():.4f}")


Epoch [1/10], Class DavesKillerBread, Loss: 1.3567
Epoch [1/10], Class GothamGreensBasil, Loss: 1.5450
Epoch [1/10], Class LabellePatrimoineHeritageEggs, Loss: 2.1043
Epoch [1/10], Class WholeFoodsMarketBrownButterChocolateChunk, Loss: 3.0192
Epoch [2/10], Class DavesKillerBread, Loss: 0.8479
Epoch [2/10], Class GothamGreensBasil, Loss: 1.2347
Epoch [2/10], Class LabellePatrimoineHeritageEggs, Loss: 1.6904
Epoch [2/10], Class WholeFoodsMarketBrownButterChocolateChunk, Loss: 2.3781
Epoch [3/10], Class DavesKillerBread, Loss: 1.2057
Epoch [3/10], Class GothamGreensBasil, Loss: 1.2237
Epoch [3/10], Class LabellePatrimoineHeritageEggs, Loss: 1.3709
Epoch [3/10], Class WholeFoodsMarketBrownButterChocolateChunk, Loss: 1.7780
Epoch [4/10], Class DavesKillerBread, Loss: 1.5638
Epoch [4/10], Class GothamGreensBasil, Loss: 1.3716
Epoch [4/10], Class LabellePatrimoineHeritageEggs, Loss: 1.2718
Epoch [4/10], Class WholeFoodsMarketBrownButterChocolateChunk, Loss: 1.3902
Epoch [5/10], Class DavesKil

In [28]:
import torch

# Save the entire model
torch.save(model.state_dict(), 'few_shot_model.pth')


In [None]:
# Once the model is trained, you can use it for inference on new, unseen images
# Extract features from the new image using the pre-trained CNN
new_image = Image.open("path_to_new_image.jpg").convert("RGB")
new_image = transform(new_image).unsqueeze(0).to(device)
new_features = cnn(Variable(new_image)).squeeze()

In [None]:
# Calculate distances to class prototypes
with torch.no_grad():
    new_features = new_features.unsqueeze(0)
    _, prototypes = model(None, new_features)
    distances = torch.cdist(new_features, prototypes)

In [None]:
# Get the predicted class label
predicted_class = torch.argmin(distances).item()
predicted_label = label_names[predicted_class]

print("Predicted Label:", predicted_label)