In [7]:
# Mount Google Drive at /content/gdrive (uses the google.colab helper) so the
# CSV files, image folders and checkpoints referenced below are reachable.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


## The following are checkpoints for the models

Faster R-CNN: https://drive.google.com/file/d/1WAbR50Eows_lDsevBR8aUV9-RTRKFZ2O/view?usp=sharing

Prototypical Network with ResNet Backbone: https://drive.google.com/file/d/1--ZwmUJAvYlCForsYBtLG-Nbi-kelldr/view?usp=sharing

Matching Network with ResNet Backbone: https://drive.google.com/file/d/1-6Lxza8lPvnURCJrdG-8MGEEtSeptJeX/view?usp=sharing


In [None]:
# Install the extra packages imported below: easyfsl, gradio and detectron2
# (the latter straight from its GitHub repository).
# NOTE(review): versions are not pinned, so a fresh install may pull newer releases.
!pip install easyfsl
!pip install gradio
!pip install git+https://github.com/facebookresearch/detectron2.git

In [2]:
import os 
import pandas as pd 
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from skimage import io
import torch.nn.functional as F
import torch.nn as nn
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
from torch.nn import Module
from easyfsl.utils import plot_images, sliding_average
import torch.optim 
import random
import numpy as np
import time
from tqdm import tqdm
from time import sleep


import glob
import sys
from math import tan, pi
import torchvision
import cv2



import json

import matplotlib.pyplot as plt

from detectron2.structures import BoxMode
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset


from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode, Visualizer

import shutil
from PIL import Image
from tkinter import Tk
from tkinter.filedialog import askdirectory
import csv
from openpyxl import Workbook

#TODO HOOKS
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm



In [21]:
class SignatureDataset(Dataset):
    """Dataset of signature images described by a CSV label file.

    Column 0 of the CSV holds the image file name (relative to ``root_dir``)
    and column 1 holds the integer class label.
    """

    def __init__(self, label_file, root_dir, transform = None):
        '''Arguments:
        label_file: path to the csv file; first column is the image file name,
        second column is the label
        root_dir: directory that contains the image files
        transform: transformations applied to every loaded image (default: None)'''
        self.transform = transform
        self.root_dir = root_dir
        self.df = pd.read_csv(label_file)

    def __len__(self):
        '''Number of rows in the label file.'''
        return self.df.shape[0]

    def _load_sample(self, row):
        '''Read the image and label stored at positional row ``row`` of the label table.'''
        image = io.imread(os.path.join(self.root_dir, str(self.df.iloc[row, 0])))
        label = torch.tensor(int(self.df.iloc[row, 1]))
        if self.transform:
            image = self.transform(image)
        return image, label

    def __getitem__(self, index):
        '''Load a whole episode at once.

        ``index`` is an iterable of row positions (the sampler yields one array
        of positions per episode), not a single integer. Returns a tuple
        ``(images, labels)`` where both entries are stacked tensors.'''
        images = []
        labels = []
        for row in index:
            image, label = self._load_sample(row)
            images.append(image)
            labels.append(label)

        label_batch = torch.stack(labels)
        image_batch = torch.stack(images)
        return (image_batch, label_batch)

In [22]:
class NShotTaskSampler(Sampler):
    def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 episodes_per_epoch: int = None,
                 n: int = None,
                 k: int = None,
                 q: int = None,
                 num_tasks: int = 1
    ):
        """PyTorch Sampler subclass that generates batches of n-shot, k-way tasks.

        Every yielded batch is a 1-D array of values taken from the ``id`` column of
        ``dataset.df``. For each task the first ``n * k`` entries form the support set
        (``n`` samples for each of ``k`` randomly chosen classes, grouped class by class)
        and the following ``q`` entries form the query set.

        Note that the query set contains ``q`` samples in total (not ``q`` per class):
        they are drawn from the rows of the chosen classes whose id is not already in
        the batch, so support and query samples are disjoint.

        # Arguments
            dataset: object exposing a pandas DataFrame ``df`` with ``id`` and
                ``class_id`` columns
            episodes_per_epoch: number of batches generated in one epoch
            n: number of support samples per class
            k: number of classes per task
            q: total number of query samples per task
            num_tasks: number of tasks grouped into a single batch

        # Raises
            ValueError: if ``num_tasks`` is smaller than 1
        """
        # Validate first so that a bad argument fails before any state is set up.
        if num_tasks < 1:
            raise ValueError('num_tasks must be >= 1.')

        super(NShotTaskSampler, self).__init__(dataset)
        self.episodes_per_epoch = episodes_per_epoch
        self.dataset = dataset

        self.num_tasks = num_tasks
        # TODO: Raise errors if initialise badly
        self.k = k
        self.n = n
        self.q = q

        self.i_task = 0

    def __len__(self):
        """Number of batches produced per epoch."""
        return self.episodes_per_epoch

    def __iter__(self):
        """Yield ``episodes_per_epoch`` arrays of sample ids, one array per batch."""
        for _ in range(self.episodes_per_epoch):
            batch = []

            for _task in range(self.num_tasks):
                frame = self.dataset.df
                episode_classes = np.random.choice(frame['class_id'].unique(), size=self.k, replace=False)
                df = frame[frame['class_id'].isin(episode_classes)]

                # Support set: n rows per chosen class, appended class by class.
                for class_id in episode_classes:
                    support = df[df['class_id'] == class_id].sample(self.n)
                    batch.extend(support['id'].tolist())

                # Query set: q rows of the chosen classes that are not in the batch yet.
                query = df[~df['id'].isin(batch)].sample(self.q)
                batch.extend(query['id'].tolist())

            yield np.stack(batch)

## Please set device

In [9]:
# Device used by the models and helper functions below (they read this global).
# Change 'cpu' to e.g. 'cuda' to run on a GPU.
device = torch.device('cpu')

In [4]:
class ConvNet(nn.Module):
    """Five-stage convolutional encoder that returns one flat feature vector per image.

    Every stage is Conv2d(kernel 3) -> ReLU -> BatchNorm2d -> MaxPool2d(kernel 3, stride 2);
    the output of the last stage is flattened.
    """

    def __init__(self, input_channels, hidden_channels):
        super(ConvNet, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        stages = []
        in_channels = self.input_channels
        for _ in range(5):
            stages += [
                nn.Conv2d(in_channels, self.hidden_channels, kernel_size=3),
                nn.ReLU(),
                nn.BatchNorm2d(self.hidden_channels),
                nn.MaxPool2d(kernel_size = 3, stride = 2),
            ]
            # Only the first convolution sees the raw input channels.
            in_channels = self.hidden_channels

        # Layers keep their original positions, so state-dict keys stay ``main.<index>.*``.
        self.main = nn.Sequential(*stages, nn.Flatten())

    def forward(self, x):
        """Encode a batch of images into flat feature vectors."""
        return self.main(x)
class resnet(nn.Module):
    """ResNet-18 feature extractor that accepts 1- or 3-channel image batches.

    The final fully connected layer is replaced by ``nn.Flatten`` so the forward
    pass returns the features that would otherwise feed the classifier.
    """

    def __init__(self):
        super(resnet, self).__init__()
        backbone = resnet18(pretrained=False)
        backbone.fc = nn.Flatten()
        self.resnet_18 = backbone

    def forward(self, x):
        """Run the backbone; ``x.shape[1]`` must be the channel dimension (1 or 3).

        Raises ValueError for any other channel count.
        """
        channels = x.shape[1]
        if channels not in (1, 3):
            raise ValueError('shape[1] is not 1 or 3 or it is not even channel dimension')
        if channels == 1:
            # resnet-18 takes 3-channel input, so the single channel is repeated three times
            x = torch.cat([x] * 3, dim = 1)
        return self.resnet_18(x)

In [5]:
class Proto(nn.Module):
    """Prototypical network: classifies query samples by distance to class prototypes.

    The backbone is the ``resnet`` feature extractor defined above. ``hidden_channels``
    and ``input_channels`` are only stored; the backbone does not use them.
    """

    def __init__(self, hidden_channels, input_channels
                 ):
        super(Proto, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.backbone = resnet()

    def forward(self, x, n_shot, k_way, q):
        """Score every query sample against every class prototype.

        Arguments:
        x: episode images. The first ``n_shot * k_way`` samples are the support set,
           grouped class by class; the remaining samples are queries. A leading
           batch dimension of size 1 (5-D input) is removed.
        n_shot: number of support samples per class
        k_way: number of classes in the episode
        q: accepted for interface compatibility; not used here

        Returns a tensor of shape (num_queries, k_way) holding the negative
        euclidean distance between each query embedding and each prototype.
        """
        # Only strip the leading dimension from 5-D input; an unconditional
        # squeeze(0) would also collapse a 4-D batch that holds a single image.
        if x.dim() == 5:
            x = x.squeeze(0)
        # Follow the model's own parameters instead of a notebook-level global.
        reference = next(self.parameters(), None)
        if reference is not None:
            x = x.to(reference.device)

        embedded_x = self.backbone(x)
        n_support = k_way * n_shot
        support_set = embedded_x[:n_support]  # (n*k, embedding dim)
        q_set = embedded_x[n_support:]  # (num_queries, embedding dim)

        # One prototype per class: the mean of its n_shot support embeddings.
        prototypes = torch.stack(
            [support_set[i * n_shot:(i + 1) * n_shot].mean(dim = 0) for i in range(k_way)]
        )

        return -torch.cdist(q_set, prototypes)  # negative euclidean distance

## Please change n_shot, k_way, q variables as you wish

In [26]:
# Episode layout used by the samplers, training and evaluation below:
# n_shot support samples per class, k_way classes, q query samples per episode.
n_shot = 1
k_way = 5
q = 10

# Datasets: (label csv, image folder, transform). Every image is converted to a
# tensor and resized to 500x500.
# NOTE(review): test_set reads 'val_labels.csv' while val_set reads 'val.csv' - confirm the file names.
train_set = SignatureDataset('/content/gdrive/My Drive/train_labels.csv', '/content/gdrive/My Drive/full_org', T.Compose([T.ToTensor(), T.Resize((500, 500))]))
val_set = SignatureDataset('/content/gdrive/My Drive/val.csv','/content/gdrive/My Drive/full_org_val_1', T.Compose([T.ToTensor(), T.Resize((500, 500))]))
test_set = SignatureDataset('/content/gdrive/My Drive/val_labels.csv','/content/gdrive/My Drive/full_org_test', T.Compose([T.ToTensor(), T.Resize((500, 500))]))


# One sampler per split; each yields 100 episodes per epoch.
train_sampler = NShotTaskSampler(dataset = train_set, episodes_per_epoch = 100, n = n_shot, k = k_way, q=q)
val_sampler = NShotTaskSampler(dataset = val_set, episodes_per_epoch = 100, n = n_shot, k = k_way, q=q)
test_sampler = NShotTaskSampler(dataset = test_set, episodes_per_epoch = 100, n = n_shot, k = k_way, q=q)


# Each sampler item is a whole episode (an array of sample ids) that the
# dataset's __getitem__ loads in one call.
train_dataloader = DataLoader(train_set, sampler = train_sampler) #training dataloader
val_dataloader = DataLoader(val_set, sampler=val_sampler)
test_dataloader = DataLoader(test_set, sampler = test_sampler)


In [19]:
def not_sorted_label(y: torch.Tensor) -> torch.Tensor:
  """Relabel the values of ``y`` to 0, 1, 2, ... in order of first appearance.

  Example: ``[7, 3, 7, 9, 3]`` becomes ``[0, 1, 0, 2, 1]``.

  Arguments:
  y: 1-D tensor of labels. A 2-D tensor is first squeezed on dimension 0 and
     moved to the global ``device``.

  Returns a new tensor with the dtype and device of the (squeezed) input.
  The tensor passed in is left untouched.
  """
  if len(y.shape) == 2:
    y = y.squeeze(0).to(device)
  first_seen = {}
  # setdefault assigns the next free index the first time a value shows up.
  relabelled = [first_seen.setdefault(value, len(first_seen)) for value in y.tolist()]
  return torch.tensor(relabelled, dtype=y.dtype, device=y.device)

In [24]:


def evaluate_one_episode(x, y, n_shot, q, model):
    """Count the correct query predictions of ``model`` on a single episode.

    ``y`` holds the episode labels (support labels first, then ``q`` query labels);
    the number of classes is derived from its length. The predicted label of a
    query is the column with the highest score.

    Returns a tuple (number of correct predictions, number of query samples).
    """
    k_way = (y.shape[0] - q)//n_shot
    query_labels = y[n_shot*k_way:]
    scores = model(x, n_shot, k_way, q).detach().data
    _, predicted = torch.max(scores, 1)
    matches = predicted == query_labels.to(device)
    return matches.sum().item(), len(query_labels)

def evaluate(data_loader, model):
    """Evaluate ``model`` on every episode of ``data_loader`` and print the accuracy.

    The model is switched to eval mode. Episode labels are relabelled with
    ``not_sorted_label`` and the global ``n_shot`` and ``q`` describe the episode
    layout. Nothing is returned; the result is printed.
    """
    total_predictions = 0
    correct_predictions = 0

    model.eval()

    # Inference only: without no_grad every forward pass would build an autograd graph.
    with torch.no_grad():
        for episode_index, (x, y) in tqdm(enumerate(data_loader), total=len(data_loader)):
            y = y.squeeze(0).to(device)

            y = not_sorted_label(y)
            correct, total = evaluate_one_episode(
                x, y, n_shot, q, model
            )

            total_predictions += total
            correct_predictions += correct

    if total_predictions == 0:
        # Avoid a ZeroDivisionError when the loader produced no query samples.
        print(f"Model tested on {len(data_loader)} tasks. No query samples were evaluated.")
        return

    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )
# evaluate(val_dataloader, model)

In [None]:
# Prototypical network to train. It has to be created here: the optimizer below
# needs its parameters, and on a fresh kernel no `model` exists at this point.
model = Proto(32, 1).to(device)

# Cross-entropy over the (negative-distance) scores returned by the model.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def fit(x, y, n_shot, q) -> float:
    """Run one optimisation step of the global ``model`` on a single episode.

    ``y`` holds the episode labels (support labels first, then ``q`` query labels).
    The loss is the global ``criterion`` applied to the model scores and the query
    labels; the global ``optimizer`` performs the update.

    Returns the loss of this episode as a Python float.
    """
    optimizer.zero_grad()

    k_way = (y.shape[0] - q)//n_shot
    query_targets = y[n_shot*k_way:]

    episode_scores = model(x, n_shot, k_way, q)
    episode_loss = criterion(episode_scores, query_targets.to(device))
    episode_loss.backward()
    optimizer.step()

    return episode_loss.item()


In [None]:
# NOTE(review): this only binds a Python variable named CUDA_LAUNCH_BLOCKING; it
# does not set the environment variable of the same name.
CUDA_LAUNCH_BLOCKING=1
# Release cached GPU memory and collect unreachable Python objects before training.
torch.cuda.empty_cache()
import gc
gc.collect()


log_update_frequency = 10
epochs = 25

# Loss of every training episode, across all epochs.
all_loss = []
for epoch in range(epochs):
  model.train()
  with tqdm(enumerate(train_dataloader), total=len(train_dataloader)) as tqdm_train:
      for episode_index, (x, y) in tqdm_train:
          # Drop the batch dimension added by the DataLoader, then map the class
          # ids of the episode to 0..k-1 in order of first appearance.
          y = y.squeeze(0)
          y = not_sorted_label(y)

          loss_value = fit(x, y, n_shot, q)
          all_loss.append(loss_value)

          # Refresh the progress-bar loss every `log_update_frequency` episodes.
          if episode_index % log_update_frequency == 0:
              tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

In [None]:
# `evaluate` requires the model as second argument; pass the network trained above.
evaluate(val_dataloader, model)

100%|██████████| 100/100 [00:23<00:00,  4.25it/s]

Model tested on 100 tasks. Accuracy: 98.20%





In [None]:
# Persist the trained weights (state dict only) to Google Drive.
torch.save(model.state_dict(), '/content/gdrive/My Drive/Proto_ResNet.pth')

In [10]:
# Rebuild the prototypical network and restore the saved weights;
# map_location loads the checkpoint tensors onto `device`.
model = Proto(32, 1
              # , n_shot, k_way, q
              ).to(device)
model.load_state_dict(torch.load('/content/gdrive/My Drive/Proto_ResNet.pth', map_location=device))

<All keys matched successfully>

In [27]:
# Accuracy of the restored prototypical network on the validation episodes.
evaluate(val_dataloader, model)

100%|██████████| 100/100 [05:26<00:00,  3.27s/it]

Model tested on 100 tasks. Accuracy: 96.60%





In [None]:
# TODO NOT FREEZE
def run():
    """Call ``torch.multiprocessing.freeze_support()`` and print ``loop``."""
    mp = torch.multiprocessing
    mp.freeze_support()
    print('loop')

# TODO GETTING DATA
def get_data_dicts(directory, classes, height=800, width=800):
    """Build detectron2 dataset records from the annotation JSON files in ``directory``.

    Every ``*.json`` file is expected to contain an ``imagePath`` entry and a
    ``shapes`` list whose items carry ``points`` (a list of [x, y] pairs) and a
    ``label``.

    Arguments:
    directory: folder containing the JSON files; ``imagePath`` is resolved against it
    classes: list of label names; the position of a label is its ``category_id``
    height, width: image size written into every record (default 800 x 800)

    Returns a list with one dict per JSON file with the keys ``file_name``,
    ``height``, ``width`` and ``annotations``.

    Raises ValueError if a shape label is not in ``classes``.
    """
    dataset_dicts = []
    for json_name in os.listdir(directory):
        if not json_name.endswith('.json'):
            continue
        with open(os.path.join(directory, json_name)) as f:
            img_anns = json.load(f)

        objs = []
        for anno in img_anns["shapes"]:
            px = [point[0] for point in anno['points']]  # x coord
            py = [point[1] for point in anno['points']]  # y coord
            # Flat [x0, y0, x1, y1, ...] polygon for the segmentation entry.
            poly = [coord for point in zip(px, py) for coord in point]

            objs.append({
                # Axis-aligned box enclosing all annotated points.
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": classes.index(anno['label']),
                "iscrowd": 0
            })

        dataset_dicts.append({
            "file_name": os.path.join(directory, img_anns["imagePath"]),
            "height": height,
            "width": width,
            "annotations": objs,
        })
    return dataset_dicts

# Call run() when this code is executed as the main module.
if __name__ == '__main__':
    run()


# Detection classes; the position in this list is the category id.
classes = ['signature', 'date']

# NOTE(review): local Windows path. It is only read when one of the registered
# datasets is actually loaded (the lambda below), not for inference.
data_path = 'C:/Users/kamra/Documents/My Documents/2022 Fall/Deep Learning/Project/ds_reformatted1/'

for d in ["train", "test"]:
    dataset_name = "category_" + d
    # Registering an existing name fails, so drop a previous registration first;
    # this keeps the cell re-runnable.
    if dataset_name in DatasetCatalog.list():
        DatasetCatalog.remove(dataset_name)
    DatasetCatalog.register(
        dataset_name,
        lambda d=d: get_data_dicts(data_path+d, classes)
    )
    MetadataCatalog.get(dataset_name).set(thing_classes=classes)

microcontroller_metadata = MetadataCatalog.get("category_train")


# Model configuration. The architecture and the number of classes have to match
# the checkpoint loaded below; a bare get_cfg() describes a different default
# model, so the Faster R-CNN R101-FPN config is merged in first.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml"))
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)

# Settings that were used for training (kept for reference):
# cfg.DATASETS.TRAIN = ("category_train",)
# cfg.DATASETS.TEST = ()
# cfg.TEST.EVAL_PERIOD = 15
# cfg.DATALOADER.NUM_WORKERS = 0
# cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml")
# cfg.SOLVER.IMS_PER_BATCH = 8
# cfg.SOLVER.BASE_LR = 0.00025
# cfg.SOLVER.MAX_ITER = 250


# Inference settings. The checkpoint path is absolute, so joining it with
# cfg.OUTPUT_DIR had no effect and the path is used directly.
cfg.MODEL.WEIGHTS = "/content/gdrive/My Drive/model_final.pth"
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6

predictor = DefaultPredictor(cfg)

# TODO TEST ONLY

import json
import matplotlib.pyplot as plt

def out(imageName):
    """Detect objects in an image and crop the first detected signature.

    Arguments:
    imageName: PIL image (despite the name, not a file name); ``convert`` and
        ``crop`` are called on it.

    Returns a tuple ``(v, crop_img)``: the visualizer output with the predictions
    drawn on the image, and the PIL crop of the first instance whose class is
    "signature".

    Raises IndexError if no "signature" instance is detected.
    """
    im2 = imageName
    rgb = np.array(imageName.convert("RGB"))
    # DefaultPredictor takes BGR arrays while Visualizer takes RGB, so the
    # channel order is reversed for the predictor only.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    outputs = predictor(bgr)
    instances = outputs["instances"].to("cpu")

    class_names = MetadataCatalog.get("category_train").thing_classes
    pred_class_names = [class_names[c] for c in instances.pred_classes.tolist()]

    v = Visualizer(rgb, metadata=microcontroller_metadata)
    v = v.draw_instance_predictions(instances)

    # Position of the first predicted instance labelled "signature".
    signature_index = next(
        (i for i, name in enumerate(pred_class_names) if name == "signature"),
        None,
    )
    if signature_index is None:
        raise IndexError("no 'signature' instance was detected in the image")

    box = instances.pred_boxes.tensor[signature_index].detach().numpy()
    x_top_left, y_top_left, x_bottom_right, y_bottom_right = (int(c) for c in box)
    crop_img = im2.crop((x_top_left, y_top_left, x_bottom_right, y_bottom_right))
    return v, crop_img

In [None]:
import gradio as gr

def detection(image):
  """Run ``out`` on ``image`` and return (annotated RGB PIL image, cropped image)."""
  visualisation, cropped_image = out(image)
  rendered = np.uint8(visualisation.get_image())
  annotated = Image.fromarray(rendered).convert('RGB')
  return annotated, cropped_image


# Gradio demo for the detector: one image input (handed to `detection` as a PIL
# image) and two image outputs (annotated image and cropped image).
demo_detection = gr.Interface(
    fn=detection,
    inputs=[gr.Image(type="pil")],
    outputs =[gr.Image(), gr.Image()]
)
demo_detection.launch(debug= True)

In [None]:
import gradio as gr
import requests

signatures = [] #transformed reference signatures added by registering_new_signature
names = [] #name given for each registered signature, in the same order
def label_to_name(names):
  """Map every name to its position in ``names`` (the last position wins for duplicates)."""
  positions = {}
  for position, name in enumerate(names):
    positions[name] = position
  return positions

# Preprocessing shared by registration and classification:
# tensor conversion, grayscale, resize to 200x200.
transform = T.Compose([T.ToTensor(), T.Grayscale(), T.Resize((200, 200))])
def registering_new_signature(name,image):
    """Store a reference signature and the name it belongs to.

    The image is passed through the global ``transform`` and appended to
    ``signatures``; ``name`` is appended to ``names``.

    Returns the first three dimensions of the transformed image as a
    comma-separated string.
    """
    prepared = transform(image)
    signatures.append(prepared)
    names.append(name)
    dims = prepared.shape
    return f"{dims[0]},{dims[1]},{dims[2]}"

def classify_signature(image):
    """Detect the signature in ``image`` and compare it with the registered ones.

    The signature is cropped with ``out``, preprocessed with the global
    ``transform`` and scored by the global ``model`` as a 1-shot episode in which
    every registered signature is its own class.

    Returns a tuple: a dict mapping each registered name to its probability, and
    the annotated detection image (RGB PIL image).

    Raises ValueError if no signature has been registered yet.
    """
    if not signatures:
        raise ValueError("no signatures registered yet; register at least one before classifying")

    v, cropped_image = out(image)
    v = Image.fromarray(np.uint8(v.get_image())).convert('RGB')
    query = transform(cropped_image)
    # Support set (registered signatures) first, the single query last.
    x = torch.cat((torch.stack(signatures), query.unsqueeze(0)))

    # Inference only: eval mode so BatchNorm uses its running statistics, and
    # no_grad so no autograd graph is built. The previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            probs = torch.softmax(model(x, 1, len(signatures), 1), dim = 1).squeeze(0)
    finally:
        model.train(was_training)

    return {str(names[i]):float(probs[i]) for i in range(len(signatures))}, v
    
# Tab 1: register a reference signature under a name.
# NOTE(review): this gr.Image() does not set type="pil" like the classifier input
# does - confirm the format `transform` receives here is the intended one.
demo_register = gr.Interface(
    fn=registering_new_signature,
    inputs=['text', gr.Image()],
    outputs = "text"
)

# Tab 2: detect the signature in an uploaded image and compare it with the
# registered ones; shows the per-name probabilities and the annotated image.
demo_classifier = gr.Interface(
    fn = classify_signature,
    inputs = gr.Image(type = 'pil'),
    outputs=[gr.Label(), gr.Image()]
)

demo = gr.TabbedInterface([demo_register, demo_classifier], ["Registering", "Detection & Classification"])
# Names registered so far, printed once before the demo is launched.
print(names)
demo.launch(debug= True)

In [None]:
# `evaluate` requires the model as second argument.
evaluate(val_dataloader, model)

100%|██████████| 100/100 [01:04<00:00,  1.55it/s]

Model tested on 100 tasks. Accuracy: 86.00%





In [None]:
class AttLSTM(nn.Module):
    def __init__(self, K, input_size):
        '''Attention LSTM with skip connections.

        Arguments
        ---------
        K: number of processing steps (how many times the LSTM cell is applied)
        input_size: feature size; also used as the hidden size of the LSTM cell'''
        super(AttLSTM, self).__init__()
        self.processing = K
        self.input_size = input_size
        self.lstm = nn.LSTMCell(self.input_size, self.input_size)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, f_x, g_S):
        '''Refine the query embeddings by attending over the support embeddings.

        f_x: query embeddings, shape (num_queries, feature_size)
        g_S: support embeddings, shape (num_support, feature_size)

        Returns a tensor with the shape of ``f_x``; with K == 0 it is ``f_x`` itself.'''
        h = f_x
        # Cell state starts at zero on the same device as the queries, so the
        # module does not depend on a notebook-level device variable.
        c = f_x.new_zeros((f_x.shape[0], f_x.shape[1]), dtype=torch.float)
        for _ in range(self.processing):
            product = torch.matmul(h, g_S.T)  # (num_queries, num_support)
            a = self.softmax(product)  # attention over the support samples
            r = torch.matmul(a, g_S)  # attention-weighted sum of support embeddings
            # The read-out is added to (not concatenated with) the hidden state.
            h, c = self.lstm(f_x, (h + r, c))
            h = h + f_x  # skip connection
        return h

In [None]:
class BidirectionalLSTM(nn.Module):
    def __init__(self, size: int, layers: int):
        """Bidirectional LSTM producing fully conditional embeddings (FCE) of the support set,
        as described in the Matching Networks paper.

        # Arguments
            size: Size of both the input and the hidden layer. They are kept equal so
                the skip connection in ``forward`` can add inputs and outputs.
            layers: Number of LSTM layers
        """
        super(BidirectionalLSTM, self).__init__()
        self.num_layers = layers
        self.batch_size = 1
        # input_size == hidden_size is required by the skip connection
        # (Appendix A.1 and A.2 of Matching Networks).
        self.lstm = nn.LSTM(
            input_size=size,
            num_layers=layers,
            hidden_size=size,
            bidirectional=True,
        )

    def forward(self, inputs):
        """Return ``(output, hn, cn)`` where ``output`` has the shape of ``inputs``."""
        # Passing None lets the LSTM create its own initial hidden state.
        output, (hn, cn) = self.lstm(inputs, None)

        # The last dimension holds the forward half followed by the backward half.
        forward_output, backward_output = torch.split(output, self.lstm.hidden_size, dim=2)

        # g(x_i, S) = h_forward_i + h_backward_i + g'(x_i), i.e. a skip connection
        # between inputs and outputs (Appendix A.2).
        return forward_output + backward_output + inputs, hn, cn


In [None]:
class Matching(nn.Module):
    """Matching network on top of the ``resnet`` encoder.

    With ``fce`` enabled, the support embeddings go through a bidirectional LSTM
    and the query embeddings through an attention LSTM before distances are
    computed. ``hidden_channels`` and ``input_channels`` are only stored.
    """

    def __init__(self, hidden_channels, K, layers, input_channels, fce = False, ):
        super(Matching, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.K = K
        self.layers = layers
        # Both LSTMs work on 512-dimensional vectors (input size == hidden size).
        self.attLSTM = AttLSTM(self.K, input_size = 512)
        self.biLSTM = BidirectionalLSTM(512, self.layers)
        self.encoder = resnet()
        self.fce = fce

    def forward(self, x, n_shot, k_way, q):
        """Score every query sample against every support sample.

        Arguments:
        x: episode images; the first ``n_shot * k_way`` samples are the support set
           and the rest are queries. A leading batch dimension (5-D input) is removed.
        n_shot, k_way: episode layout
        q: accepted for interface compatibility; not used here

        Returns a tensor of shape (num_queries, n_shot * k_way) with the negative
        euclidean distance between query and support embeddings.
        """
        if x.dim() == 5:  # (1, N, C, H, W) -> (N, C, H, W)
            x = x.squeeze(0)
        # Move the input to the model's device for 4-D and 5-D input alike.
        reference = next(self.parameters(), None)
        if reference is not None:
            x = x.to(reference.device)

        embedding_x = self.encoder(x)  # (num_samples, feature_dim)
        n_support = n_shot * k_way
        support = embedding_x[:n_support]
        query = embedding_x[n_support:]

        if not self.fce:
            return -torch.cdist(query, support)

        # The LSTM takes (sequence, batch, feature): the support set is a single
        # sequence, so a batch dimension of size 1 is inserted and removed again.
        g_S, _, _ = self.biLSTM(support.unsqueeze(1))
        g_S = g_S.squeeze(1)
        f_x = self.attLSTM(query, g_S)
        return -torch.cdist(f_x, g_S)


In [None]:
# NOTE(review): this only binds a Python variable named CUDA_LAUNCH_BLOCKING; it
# does not set the environment variable of the same name.
CUDA_LAUNCH_BLOCKING=1
# Release cached GPU memory and collect unreachable Python objects.
torch.cuda.empty_cache()
import gc
gc.collect()




# From here on `model`, `criterion` and `optimizer` refer to the matching network.
# NLLLoss is used because fit_Matching passes log-probabilities to the criterion.
criterion = nn.NLLLoss()
model = Matching(hidden_channels = 16, K  = 2, layers = 1, input_channels = 1, fce = True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def fit_Matching(x, y, n_shot, k_way, q
) -> float:
    """Run one optimisation step of the global matching ``model`` on a single episode.

    Arguments:
    x: episode images (support samples first, then queries)
    y: episode labels in the same order; a 2-D tensor is squeezed on dimension 1
    n_shot, k_way, q: episode layout, forwarded to the model

    The attention over the support samples (softmax of the model scores) is
    combined with the one-hot support labels into class probabilities, whose log
    is passed to the global ``criterion``.

    Returns the loss of this episode as a Python float.
    """
    if len(y.shape) == 2:
      y = y.squeeze(1)
    n_support = n_shot*k_way
    y_support = y[:n_support]
    y_query = y[n_support:]
    ohe_y_support = F.one_hot(y_support).to(device).float()
    optimizer.zero_grad()
    scores = model(x, n_shot, k_way, q)
    attention = torch.softmax(scores, dim = 1)  # (num_queries, num_support)
    prediction = torch.matmul(attention, ohe_y_support)  # (num_queries, num_classes)
    # Clamp before the log so a probability that underflowed to 0 cannot give -inf.
    # Targets follow the configured `device` instead of a hard-coded .cuda().
    loss = criterion(prediction.clamp_min(1e-12).log(), y_query.to(device))
    loss.backward()
    optimizer.step()

    return loss.item()



In [None]:
# NOTE(review): this only binds a Python variable named CUDA_LAUNCH_BLOCKING; it
# does not set the environment variable of the same name.
CUDA_LAUNCH_BLOCKING=1
# Release cached GPU memory and collect unreachable Python objects before training.
torch.cuda.empty_cache()
import gc
gc.collect()

log_update_frequency = 10
epochs = 25

# Loss of every training episode, across all epochs.
all_loss = []
for epoch in range(epochs):
  model.train()
  with tqdm(enumerate(train_dataloader), total=len(train_dataloader)) as tqdm_train:
      for episode_index, (x, y) in tqdm_train:
          # Drop the batch dimension added by the DataLoader, then map the class
          # ids of the episode to 0..k-1 in order of first appearance.
          y = y.squeeze(0)
          y = not_sorted_label(y)
          loss_value = fit_Matching(x, y, n_shot, k_way, q)
          all_loss.append(loss_value)

          # Refresh the progress-bar loss every `log_update_frequency` episodes.
          if episode_index % log_update_frequency == 0:
              tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))
  

100%|██████████| 100/100 [00:37<00:00,  2.69it/s, loss=0.61]
100%|██████████| 100/100 [00:37<00:00,  2.67it/s, loss=0.278]
100%|██████████| 100/100 [00:37<00:00,  2.67it/s, loss=0.207]
100%|██████████| 100/100 [00:37<00:00,  2.67it/s, loss=0.309]
100%|██████████| 100/100 [00:37<00:00,  2.66it/s, loss=0.168]
100%|██████████| 100/100 [00:37<00:00,  2.67it/s, loss=0.267]
100%|██████████| 100/100 [00:37<00:00,  2.69it/s, loss=0.119]
100%|██████████| 100/100 [00:37<00:00,  2.67it/s, loss=0.272]
100%|██████████| 100/100 [00:37<00:00,  2.69it/s, loss=0.251]
100%|██████████| 100/100 [00:37<00:00,  2.69it/s, loss=0.202]
100%|██████████| 100/100 [00:37<00:00,  2.67it/s, loss=0.175]
100%|██████████| 100/100 [00:37<00:00,  2.69it/s, loss=0.0438]
100%|██████████| 100/100 [00:37<00:00,  2.67it/s, loss=0.142]
100%|██████████| 100/100 [00:37<00:00,  2.68it/s, loss=0.0464]
100%|██████████| 100/100 [00:37<00:00,  2.69it/s, loss=0.0383]
100%|██████████| 100/100 [00:37<00:00,  2.66it/s, loss=0.0632]
100%|

In [None]:
# Accuracy of the matching network trained above on the validation episodes.
evaluate(val_dataloader, model)

100%|██████████| 100/100 [00:24<00:00,  4.11it/s]

Model tested on 100 tasks. Accuracy: 94.90%





In [None]:
# Persist the trained matching-network weights (state dict only) to Google Drive.
torch.save(model.state_dict(), '/content/gdrive/My Drive/Matching_ResNet.pth')

In [None]:
# Rebuild the matching network and restore the saved weights. map_location loads
# the checkpoint onto `device`, as is done for the prototypical-network checkpoint,
# so a checkpoint saved on a GPU can be loaded on a CPU-only machine.
model2 = Matching(hidden_channels = 16, K  = 2, layers = 1, input_channels = 1, fce = True).to(device)
model2.load_state_dict(torch.load('/content/gdrive/My Drive/Matching_ResNet.pth', map_location=device))
# Evaluate with the fully conditional embeddings switched off.
model2.fce = False
evaluate(val_dataloader, model2)

100%|██████████| 100/100 [00:24<00:00,  4.11it/s]

Model tested on 100 tasks. Accuracy: 95.20%





In [None]:
# Repeat the evaluation with the fully conditional embeddings switched off.
# Episodes are sampled randomly, so the accuracy differs between runs.
model2.fce = False
evaluate(val_dataloader, model2)

100%|██████████| 100/100 [00:24<00:00,  4.09it/s]

Model tested on 100 tasks. Accuracy: 95.90%



