<a href="https://colab.research.google.com/github/AbhirKarande/OCRandProductRecognition/blob/main/PrototypicalNetworksForProductClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Clone the project repository; skip when it already exists so the cell can be re-run.
![ -d OCRandProductRecognition ] || git clone https://github.com/AbhirKarande/OCRandProductRecognition.git

Cloning into 'OCRandProductRecognition'...
remote: Enumerating objects: 87, done.[K
remote: Counting objects: 100% (38/38), done.[K
remote: Compressing objects: 100% (35/35), done.[K
remote: Total 87 (delta 10), reused 0 (delta 0), pack-reused 49[K
Unpacking objects: 100% (87/87), 6.79 MiB | 12.18 MiB/s, done.


In [6]:
# NOTE(review): this install fails while building its `qpth` dependency (see the cell
# output), and `learn2learn` is not imported anywhere in this notebook — consider removing it.
!pip install learn2learn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting learn2learn
  Using cached learn2learn-0.1.7.tar.gz (841 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gsutil (from learn2learn)
  Using cached gsutil-5.24.tar.gz (3.0 MB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting qpth>=0.0.15 (from learn2learn)
  Using cached qpth-0.0.15.tar.gz (11 kB)
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>[0m See above for output.

[1;35mnote[0m: This is an issue with the

In [14]:
import argparse
import os
from statistics import mean, stdev

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


In [12]:
# Quote the requirement: unquoted, the shell parses `>=1.4` as a stdout redirect to a
# file named `=1.4`, and the version constraint is silently dropped.
%pip install --quiet "pytorch-lightning>=1.4"


In [13]:
import pytorch_lightning as pl

In [16]:
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint


In [20]:
# Folder where the datasets (CIFAR100 in this notebook) are/should be downloaded.
# Override with the PATH_DATASETS environment variable.
DATASET_PATH = os.environ.get("PATH_DATASETS", "/data")
# Folder holding pretrained `<ModelClass>.ckpt` files; training runs also write their
# checkpoints under it. Override with the PATH_CHECKPOINT environment variable.
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "../saved_models/tutorial16")

In [21]:
from torchvision.datasets import CIFAR100, SVHN

# Load both CIFAR100 splits as tensors; files are downloaded into DATASET_PATH if missing.
CIFAR_train_set = CIFAR100(
    root=DATASET_PATH,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
CIFAR_test_set = CIFAR100(
    root=DATASET_PATH,
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)


Files already downloaded and verified
Files already downloaded and verified


In [9]:
def get_convnet(output_size):
    """Build the DenseNet encoder that maps an image to an ``output_size``-dim vector.

    Args:
        output_size: Width of the network's final layer, used as the embedding /
            prototype dimensionality.

    Returns:
        A freshly constructed (randomly initialised) ``torchvision.models.DenseNet``.
    """
    # The notebook's `import torchvision.transforms as transforms` binds only `transforms`,
    # not `torchvision`, so import the model class explicitly.
    from torchvision.models import DenseNet

    return DenseNet(
        growth_rate=32,
        block_config=(6, 6, 6, 6),
        bn_size=2,
        num_init_features=64,
        num_classes=output_size,  # output dimensionality
    )

In [15]:
class ProtoNet(pl.LightningModule):
    """Prototypical Network as a PyTorch Lightning module.

    A DenseNet encoder (``get_convnet``) maps images to ``proto_dim``-dimensional
    features. Each class prototype is the mean feature of that class's support
    examples. Query examples are classified by a log-softmax over the negative
    squared Euclidean distance to every prototype.
    """

    def __init__(self, proto_dim, lr):
        """
        Args:
            proto_dim: Dimensionality of the feature / prototype space.
            lr: Learning rate for AdamW.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = get_convnet(output_size=self.hparams.proto_dim)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        # Decay the LR by 10x at epochs 140 and 180.
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[140, 180], gamma=0.1)
        return [optimizer], [scheduler]

    @staticmethod
    def split_batch(imgs, targets):
        """Split a batch into its support half and query half.

        NOTE(review): assumes the loader yields the support set followed by the query
        set, with equal sizes (as a few-shot batch sampler does) — confirm against the
        train/val loaders, which are not defined in this notebook.

        Returns:
            (support_imgs, query_imgs, support_targets, query_targets)
        """
        support_imgs, query_imgs = imgs.chunk(2, dim=0)
        support_targets, query_targets = targets.chunk(2, dim=0)
        return support_imgs, query_imgs, support_targets, query_targets

    @staticmethod
    def calculate_prototypes(features, targets):
        """Average the features of each class.

        Returns:
            prototypes: Tensor of shape [num_classes, feature_dim], one row per class.
            classes: Sorted unique class ids; row i of ``prototypes`` belongs to classes[i].
        """
        classes, _ = torch.unique(targets).sort()
        prototypes = torch.stack([features[targets == c].mean(dim=0) for c in classes], dim=0)
        return prototypes, classes

    def classify_feats(self, prototypes, classes, feats, targets):
        """Classify ``feats`` by distance to ``prototypes``.

        Returns:
            preds: Log-probabilities of shape [num_queries, num_classes].
            labels: For each query, the column index of its target class in ``classes``.
            acc: Scalar tensor with the fraction of queries classified correctly.
        """
        # Squared Euclidean distance between every query (row) and prototype (column).
        dist = torch.pow(prototypes[None, :] - feats[:, None], 2).sum(dim=2)
        preds = F.log_softmax(-dist, dim=1)
        # Map each target class id to the index of its prototype. A target missing from
        # `classes` would silently map to index 0.
        labels = (classes[None, :] == targets[:, None]).long().argmax(dim=-1)
        acc = (preds.argmax(dim=1) == labels).float().mean()
        return preds, labels, acc

    def calculate_loss(self, batch, mode):
        """Compute and log the loss/accuracy of one support+query batch; return the loss."""
        imgs, targets = batch
        features = self.model(imgs)  # encode support and query images in one pass
        support_feats, query_feats, support_targets, query_targets = self.split_batch(features, targets)
        prototypes, classes = ProtoNet.calculate_prototypes(support_feats, support_targets)
        preds, labels, acc = self.classify_feats(prototypes, classes, query_feats, query_targets)
        # `preds` already holds log-probabilities, so the negative log-likelihood is the
        # matching loss (same value as cross_entropy, since log_softmax is idempotent).
        loss = F.nll_loss(preds, labels)

        self.log(f"{mode}_loss", loss)
        self.log(f"{mode}_acc", acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self.calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        _ = self.calculate_loss(batch, mode="val")

In [None]:
def train_model(model_class, train_loader, val_loader, **kwargs):
    """Train ``model_class`` with PyTorch Lightning, or load a saved pretrained model.

    Args:
        model_class: LightningModule subclass to build (``model_class(**kwargs)``) or load.
        train_loader: DataLoader used for training.
        val_loader: DataLoader used for validation. The model must log ``"val_acc"``
            during validation, because the checkpoint callback monitors it.
        **kwargs: Hyperparameters forwarded to the ``model_class`` constructor.

    Returns:
        The model loaded from ``CHECKPOINT_PATH/<ClassName>.ckpt`` if that file exists;
        otherwise the best checkpoint (highest ``val_acc``) of a fresh training run.
    """
    # `device` is never assigned in this notebook, so choose the accelerator directly.
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, model_class.__name__),
                         accelerator=accelerator,
                         devices=1,
                         max_epochs=200,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                                    LearningRateMonitor("epoch")],
                         enable_progress_bar=False)
    # Disable the logger's default hp_metric entry; guard against a disabled logger.
    if trainer.logger is not None:
        trainer.logger._default_hp_metric = None

    # If a pretrained model exists, load it and skip training.
    pretrained_filename = os.path.join(CHECKPOINT_PATH, model_class.__name__ + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        # Automatically loads the model with the saved hyperparameters
        model = model_class.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # for reproducibility
        model = model_class(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        # Reload the best checkpoint (by val_acc) rather than the last epoch's weights.
        model = model_class.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    return model

In [None]:
# Train (or load) the ProtoNet with 64-dimensional prototypes.
# NOTE(review): `train_data_loader` and `val_data_loader` are not defined in any cell of
# this notebook — TODO build the few-shot train/val loaders before running this cell.
protonet_model = train_model(
    ProtoNet,
    train_loader=train_data_loader,
    val_loader=val_data_loader,
    proto_dim=64,
    lr=2e-4,
)

In [None]:
@torch.no_grad()
def test_proto_net(model, dataset, data_feats=None, k_shot=4):
    """
    Inputs
        model - Pretrained ProtoNet model
        dataset - The dataset on which the test should be performed.
                  Should be instance of ImageDataset
        data_feats - The encoded features of all images in the dataset.
                     If None, they will be newly calculated, and returned
                     for later usage.
        k_shot - Number of examples per class in the support set.
    """
    model = model.to(device)
    model.eval()
    num_classes = dataset.targets.unique().shape[0]
    exmps_per_class = dataset.targets.shape[0]//num_classes  # We assume uniform example distribution here

    # The encoder network remains unchanged across k-shot settings. Hence, we only need
    # to extract the features for all images once.
    if data_feats is None:
        # Dataset preparation
        dataloader = data.DataLoader(dataset, batch_size=128, num_workers=4, shuffle=False, drop_last=False)

        img_features = []
        img_targets = []
        for imgs, targets in tqdm(dataloader, "Extracting image features", leave=False):
            imgs = imgs.to(device)
            feats = model.model(imgs)
            img_features.append(feats.detach().cpu())
            img_targets.append(targets)
        img_features = torch.cat(img_features, dim=0)
        img_targets = torch.cat(img_targets, dim=0)
        # Sort by classes, so that we obtain tensors of shape [num_classes, exmps_per_class, ...]
        # Makes it easier to process later
        img_targets, sort_idx = img_targets.sort()
        img_targets = img_targets.reshape(num_classes, exmps_per_class).transpose(0, 1)
        img_features = img_features[sort_idx].reshape(num_classes, exmps_per_class, -1).transpose(0, 1)
    else:
        img_features, img_targets = data_feats

    # We iterate through the full dataset in two manners. First, to select the k-shot batch.
    # Second, the evaluate the model on all other examples
    accuracies = []
    for k_idx in tqdm(range(0, img_features.shape[0], k_shot), "Evaluating prototype classification", leave=False):
        # Select support set and calculate prototypes
        k_img_feats, k_targets = img_features[k_idx:k_idx+k_shot].flatten(0,1), img_targets[k_idx:k_idx+k_shot].flatten(0,1)
        prototypes, proto_classes = model.calculate_prototypes(k_img_feats, k_targets)
        # Evaluate accuracy on the rest of the dataset
        batch_acc = 0
        for e_idx in range(0, img_features.shape[0], k_shot):
            if k_idx == e_idx:  # Do not evaluate on the support set examples
                continue
            e_img_feats, e_targets = img_features[e_idx:e_idx+k_shot].flatten(0,1), img_targets[e_idx:e_idx+k_shot].flatten(0,1)
            _, _, acc = model.classify_feats(prototypes, proto_classes, e_img_feats, e_targets)
            batch_acc += acc.item()
        batch_acc /= img_features.shape[0]//k_shot-1
        accuracies.append(batch_acc)

    return (mean(accuracies), stdev(accuracies)), (img_features, img_targets)

In [None]:
# Evaluate the ProtoNet for increasing numbers of shots. The encoded features are
# returned by the first call and reused via `data_feats` for the remaining k values.
# NOTE(review): `test_set` is not defined in any cell of this notebook — TODO define it.
protonet_accuracies = {}
data_feats = None
for k in (2, 4, 8, 16, 32):
    protonet_accuracies[k], data_feats = test_proto_net(
        protonet_model,
        test_set,
        data_feats=data_feats,
        k_shot=k,
    )
    print(f"Accuracy for k={k}: {100.0*protonet_accuracies[k][0]:4.2f}% (+-{100*protonet_accuracies[k][1]:4.2f}%)")
