<a href="https://colab.research.google.com/github/Ian-Alexis/utils_robotframework/blob/main/Torch_Matching_Net_Omniglot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# One Shot learning on Omniglot Dataset

This notebook tackles the [one-shot learning problem](https://en.wikipedia.org/wiki/One-shot_learning) on the [Omniglot dataset](https://github.com/brendenlake/omniglot). We aim to learn a model that, given a set of support images, their labels, and a target image, assigns the target image the correct label from among the provided labels.
##Example of a one shot classification task
![One shot Task](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAcoAAAHQCAYAAAAh0SohAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz%0AAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo%0AdHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3X10VdWd//HPJRgkBEjAIWqVhwaC%0AU9eMhVHkwWmVRKESS2nLLB9gfJgqq7SLsQGBYVntau3wMATraNXCaActTu3glCoU6AS0TnloqXWc%0AjlQgDMR2CShPSXgQCLm/P/o7x5PmZif33nPPPufc92st1to5yT33ezc355v9vXvvk0gmk0kBAICU%0AutkOAACAMCNRAgBgQKIEAMCARAkAgAGJEgAAAxIlAAAGkUqU27Zt05QpUzRhwgTdfffdOnjwoO2Q%0AYu/cuXNatGiRhg8fTn8HYNOmTZo8ebI+85nP6LbbbtPu3btthxRrGzdu1OTJkzVx4kT6O0Cvvfaa%0Ahg8frj/84Q+2Q+mSyCTKU6dOqaamRo888og2btyoG264QQ8//LDtsGJv5syZKioqsh1GXjh06JDm%0Az5+v2tparV+/XtXV1XrooYdshxVb7733nh5++GE9+eST2rBhgyZOnKgFCxbYDiv2Tp8+rdraWpWU%0AlNgOpcsikyi3b9+uyy+/XFdeeaUk6Qtf+IK2bNmiEydOWI4s3mbOnKlZs2bZDiMvdO/eXbW1tRo6%0AdKgk6a/+6q9UX19vOar4cvr7Yx/7mCRpzJgx2rdvn+Wo4u/xxx/XZz/7WfXq1ct2KF0WmUS5f/9+%0AXX755e7XvXr1UklJid59912LUcXfiBEjbIeQN/r3769PfepT7tevv/66rrrqKosRxduAAQM0btw4%0ASVJLS4t+/OMfq7Ky0nJU8bZr1y5t3bpVd911l+1Q0tLddgBddfr0afXo0aPNsR49eujUqVOWIgJy%0AZ9u2bVq5cqVWrlxpO5TYW7lypZ588kkNHDhQ3/3ud22HE1vJZFIPP/ywHnzwQV1wwQW2w0lLZEaU%0ARUVFOnPmTJtjH374YaSG70BX1NXVaf78+Xr66afdMixy584779T27dt155136tZbb9WHH35oO6RY%0AevHFFzV06FBdffXVtkNJW2QS5cc//vE2Zdbm5mY1NjZq0KBBFqMC/LV161Z9+9vf1rPPPqu/+Iu/%0AsB1OrO3du1dbt26VJCUSCVVXV+vkyZN8TpkjmzZt0qZNmzRu3DiNGzdOBw4c0Be/+EVt377ddmid%0AikyivPbaa/Xee+/p17/+tSTpX//1X3XDDTcwIxOxcfr0af3DP/yDHn/8cZWXl9sOJ/aOHj2quXPn%0A6tChQ5KkN954Q+fOnWszFwL+WbFihbZt26YtW7Zoy5YtuuSSS7R69WqNHj3admidisxnlBdeeKGW%0ALVumb37zmzp9+rQGDhyoRYsW2Q4r1g4fPqxp06a5X0+fPl0FBQVauXKlysrKLEYWT5s2bdLRo0c1%0AZ86cNsd/8IMf6KKLLrIUVXxdc801+vKXv6y7775bra2tKiws1KOPPqri4mLboSFkEtyPEgCAjkWm%0A9AoAgA0kSgAADEiUAAAYkCgBADAgUQIAYGBcHpJIJIKKI/L8mjxMn3edH31Of3cd/R0srinB66jP%0AGVECAGBAogQAwCAyO/Ok4t3ma+/evRYjAQDEFSNKAAAMjFvYhf1DYG/otmPlg/fgMbkkWPR3sLim%0ABI/JPAAAZIBECQCAQWCTeZwh7dGjR91j/fv3b/dzdXV1bnvEiBHGnwUAINcYUQIAYECiBADAILDS%0A6+bNmyVJ48ePd491Nqtr/vz5XT5/OjPE5s2bJ0lasmRJlx+TL2pqatx2c3OzJGnFi
hW2wgEA6xhR%0AAgBgYHUdpfepndFLnz59fD2v47777nPbuRghxWXNk/d1HD9+XJJUWlpqKxwj1vUFKwr9Haa11dmK%0AyzUlSlhHCQBABkiUAAAYhKb0OnToUEn+bG6e6iUF+VqyYbtMQukVHYlCf1N6bS/q/ZAt78duDzzw%0AgNseNmxYu5+l9AoAQAZIlAAAGET6fpRAnDjbN/71X/+1e6ywsDCjc+V7uU366H613Ku265x11LW1%0AtV1+zKZNm3IVTtoqKyu7/LPpvD8YUQIAYMCIEghYppM0Uo0SvX/Np9r1Kt9Glr/85S/ddn19vaS2%0AfeDdecrR0ehp9uzZbnvZsmV+hRhqu3fvlvTRTmodGTRokNtOZxSXa52NbquqqjI6LyNKAAAMSJQA%0AABiEZh2ln8/POsrMectNS5culWQ/po5EbV2f4/Tp0267qKgosOfN9rVGob+9zp8/L0nq1s08Hmhq%0AanLbra2tbrukpMRtDx48WJLU0NDgY4RmcbmmZCpM13FGlAAAGJAoAQAwsFp69XK2E3JmXUltZzCl%0AM1spTEP2dIWpTBL2mZNhLQWePHnSbXtLqxUVFZKkPXv2+P6cHemojzJ53WHt785k2ge2t3OM4zUl%0AHWG6jjOiBADAIDTrKIP8KzvO0vkrdPv27ZKkMWPG5CqcvOQdRdr+a977/N73hvP7lmpj6LjpqA+c%0AtvN7ILWtBgAORpQAABiQKAEAMAhN6TUXnHtc5pPOSn1lZWVu++DBg5I6L9c6mwdLbDBtEvaPD1KV%0AIJ2N2KXMt/eKqrNnz0qSRo8e3enPhvWerPkgDNdxRpQAABiQKAEAMIh16ZUyYXuHDh1y26nKtM56%0AMUnq27evpI/uwmBiY4uvsHFKRGG6P19HXnjhBUnS7bffbjmSYHm3aHTuGmJ7ZnI+8n6c0xnv9cfW%0A/xUjSgAADEKzM4/DG062O/MEGX8cd9FwXpP3w/Tm5ma37R2dpvp+nz59chhd+HaKCftORql4+zCd%0AnWoyFaa+cV6PUw2RwlURCeqacuzYMbed7aQlv2LuCDvzAAAQQiRKAAAMYj2ZB/57//333ba3DFJc%0AXCypbenVW8bYvHmzJKmysjLXIcbKkSNH3Hb//v1zct58tXjxYknS/v373WPV1dVue926dWmf089N%0A6IPive9mtpy1qYWFhe4x7wQqx7Jlyzo9lzPhpyuTCXONESUAAAYkSgAADEJdeqVMFw5dWcd04sSJ%0Adt+fN2+e2160aFGOogsfb39lsv1WqrWsfrjvvvvcdr9+/SS1/T/KN/Pnz5ckXXbZZe6xtWvXuu05%0Ac+a4becjheXLlwcUXXSkKjd7y61dKbOmEqZ18IwoAQAwIFECAGAQ2IYDAwYMkCT17t3b+HPespX3%0AhqrTpk0z/mwqbDiQnVSvKezxpSsXrydXi66zjdUb16lTpyRJvXr1yujxmQrT+yeVVatWue1st/fL%0AdiODoK4p3uf55S9/KanzO6oEObs3nU0x/HwuL0aUAAAYBDaidO57F8QEHWcCRZAfBsdxRJlqHVOY%0A4ovCCGfSpElu2ztRxMS7lq+zx3Q0anEmBHU0GSiT1x2F/o6ToK4p3g3Knd/11tZW91hTU5PbTrXm%0AMk7/p4woAQDIAIkSAACD0N09JKriWHpNtS1dmOLLt1Lgn//5n7vtnTt3dukxfqxnc+Rbf9tm45ri%0AvMc6en9deeWVxu9HHaVXAAAyQKIEAMCA0qtP4lh6dQS5jikdlAKDRX8HK87XlLCi9AoAQAZIlAAA%0AGJAoAQAwIFECAGBAogQAwIBECQCAAYkSAAADEiUAAAYkSgAADEiUAAAYkCgBADAgUQIAYNDddgAI%0APzZVBpDPGFECAGBAogQAwIBECQCAAYkSAAADEiUAAAYkSgAADEiUAAAYJJLJZNJ2EF3xhz/8QRMm%0ATNDll1/uHvv
Lv/xLLVmyxGJU8Xbo0CHNnz9fDQ0N6tWrlx566CFdc801tsOKrQ0bNug73/lOm2P7%0A9u3TG2+8oeLiYktRxdtLL72kZ555RslkUhdffLEeeughDRkyxHZYsbVmzRqtWLFCJ0+e1DXXXKNv%0Af/vbKiwstB1W55IR8fvf/z55ww032A4jr9x1113JZ599NplMJpPbtm1Lzpo1y3JE+WXdunXJr371%0Aq7bDiK36+vrkqFGjkgcPHkwmk8nkCy+8kLz11lstRxVfu3btSo4aNSr53nvvJVtbW5M1NTXJJ554%0AwnZYXULpFSkdOHBAb7/9tqZNmyZJGj16tB577DHLUeWPM2fO6LHHHtMDDzxgO5TY2rt3rwYPHqyy%0AsjJJf3yP79mzx3JU8bV9+3aNHj1al1xyiRKJhO6880797Gc/sx1Wl0QqUZ44cUIzZ87UxIkT9Xd/%0A93fau3ev7ZBi65133tFll12m2tpaTZgwQdOmTdPOnTtth5U3Vq9erZEjR2rgwIG2Q4mtq666Su++%0A+652796tZDKpn/3sZxo7dqztsGIrkUiotbXV/bqoqEjvvvuuxYi6LjKJslevXqqurtaCBQv005/+%0AVOPGjdPMmTPV0tJiO7RYampq0u7du3X11Vdr48aN+uxnP6uvfvWr9HcAWltb9eyzz+qee+6xHUqs%0AlZWVqaamRp/73Oc0atQorVq1SnPmzLEdVmyNGTNGW7Zs0e7du9XS0qJVq1bpzJkztsPqksgkytLS%0AUj300EO67LLL1K1bN9199906fPiw9u/fbzu0WOrdu7f69++vqqoqSdLUqVPV2NhIfwfgzTffVFFR%0AkYYNG2Y7lFjbuXOnnnrqKdXV1WnHjh2aPXu2vvzlLysZjfmNkTN06FB9/etfV01Njf7mb/5GQ4cO%0AVe/evW2H1SWRSZSNjY36/e9/3+ZYa2urunfnBii5cOmll+rkyZNuqSSRSKhbt27q1i0yb5nIeu21%0A1/TpT3/adhixt23bNo0YMUKXXnqpJOnmm29WfX29jh07Zjmy+JoyZYrWrl2r//iP/1BFRYUqKips%0Ah9Qlkbnq/fa3v9Wdd96po0ePSpJ+9KMf6ZJLLmmzXAT+GT58uAYMGKB///d/lyStX79effr04TOz%0AALzzzjsqLy+3HUbsDRkyRG+++aabGH/+85/rz/7sz1RaWmo5snhqaGjQ5MmT1dTUpHPnzunpp5/W%0A5z//edthdUlkhmPXXXedbr/9dt12221KJBIqKyvT448/roKCAtuhxVIikdA///M/a/78+Vq+fLn6%0A9++vxx57jBF8AA4ePKiLLrrIdhixN378eL399tu69dZbJUnFxcX6zne+w/1Xc2TQoEGqrKzU5MmT%0AlUgkNGnSJE2ZMsV2WF0SmQ0HAACwITKlVwAAbCBRAgBgQKIEAMCARAkAgAGJEgAAA+Ncf6ZJd51f%0Ak4fp867zo8/p766jv4PFNSV4HfU5I0oAAAxIlAAAGJAoAQAwIFECAGBAogQAwIBECQCAAYkSAAAD%0AEiUAAAYkSgAADEiUAAAYhOZ29TU1Nb6dq7a21vj9+vp6SdKwYcN8e04AQDwxogQAwCCRNOy8m8lm%0AuuXl5W7bGbnZ0tjYmPJ43759Jfm7WTAbGAePTbqDRX8Hi2tK8NgUHQCADJAoAQAw8L302r37R/OD%0AhgwZYvzZPXv2pH1+PzgvmdJrtEWtFOi83y+66CL3WGlpaWDPn62o9XfUxeWacuzYMbcd9vc7pVcA%0AADJAogQAwMD30msUUHqNh6iVAo8fPy7po1nXQT9/tqLW31EX5mvKpEmT3PbatWt9P79XkO8ZSq8A%0AAGQgNDvzIHxS7Za0bNkyC5HEQ0lJiSTpyJEj7jHvX7CMthBG9957r9tevnx5u+9v3ry53bHx48e7%0A7Xnz5rntJUuWdPl5nd+NMPyOMKIEAMCARAkA
gAGTeXw+Z7aC7HNnu8FMtxq0/f6Iw+SSU6dOue2e%0APXu2+77t+Lzi0N9RYvOa0tFzO9uCOh8j/KlcTFjzxjJ79my3nYuPgZjMAwBABkiUAAAYUHr1+ZzZ%0AykWfdzZrzTsTzTtDLZWOXqeN90rcSoEFBQWSpJaWlpTftx1r3Po77GxcU06ePClJKioqco9dcMEF%0Abruj96bDidnPEqlTzpVyvwaZ0isAABlgROnzObPlZ0xz586VJC1evNg9NmfOHLddW1ub1fm9r9n5%0Aqy/ITY/zYYSTbyN45z0rSTfddJMkqaqqKuvnjSIb15Rsr43OSHLp0qXGn9u+fbvbHjNmTJfPn+s1%0AlYwoAQDIAIkSAAADSq8+nzNbuYhp06ZN7jE/y1iUXnPPu41gqlJ5GDaMTkdn8dbV1bntyspK42NS%0AxeNdl9qrV69MQgyNKJZeu3p+r0zikz66ruXqmubFiBIAAAMSJQAABtw9JGaOHTvW7lgQpYmrr75a%0A0kfb4knS3r17fXveqHLWRkqdr0FLB32bmnf9X2ely7CX3eOiuLi43bGo9T0jSgAADPJ6ROkdfQU5%0ACSWXvJsVe3fHyESqDY47kmpj9aj91ZgLR48eNX5/6NChbtsZcTY0NOQ0prhwJnM4k36ktu+5QYMG%0ASZJ++MMfusdGjx4dUHRwNDc3u+2mpiaLkWSOESUAAAYkSgAADPJyHWWqbZb8vGdaNvyMw9muznss%0A1Vq86upqt71u3Trfnt/WvePSkev3eLbrxsIk6P52nu+FF15wj91xxx1djiuq/eyIyzpKPydVsY4S%0AAIAQIlECAGCQl6VXh/ele++wMX/+/KzOlY1s+/zMmTNuu7Cw0HhOZ9avd6asn3cN8LOkm0rUSq/O%0ADNeoroEMur+d7fu8Hxf06dPHbXtnUzo6i9H7ccCJEyckpb5HaxjYuKY4s9e966FzfZcOSq8AAEQc%0AiRIAAANKrylk8rrDUnr16mpMnc1OTee1eRcUd2WjgmxErfQa9d8nW/3d2Njotr2l166Wsr0bi3g/%0AZsgmpiDYuKakKndTemVECQCAUV5vYRd3zl/cqbaX836/s7/IBw8e7La7d//oLeOdTPH+++9nGiZg%0A5K1MeP/id97XnY1IvNtT+jVKQ9c5W2FKua8y5QojSgAADEiUAAAYMJknhbhM5om7sE7mcSZESLmf%0AFBGksPX3nj17JLW9A0tn5/e+BqckGNY7B9m8pnifu6Kiwm07fZ7puU6dOiVJ6tWrV9rn+dNzOeu0%0A/VyjzWQeAAAyQKIEAMCA0msKlF6jIWylwLgLa397y4HeMmyq56T0mt1zx/33hdIrAAAZyOt1lN6/%0APjtaawgg3IYNG+a2Bw0a5Lb3798vibWTmeDa2BYjSgAADEiUAAAY5PVknoKCArfd0tLitpnMEw1h%0AnVwSV1Htb++9Fb3Cfl9QrinBYzIPAAAZIFECAGCQ17Nez58/77a992QEEB9hL7Ei/BhRAgBgkNeT%0AefzEB+/Bi+rkkqiiv4PFNSV4TOYBACADJEoAAAxIlAAAGJAoAQAwIFECAGBAogQAwIBECQCAAYkS%0AAAADEiUAAAYkSgAADIyJMplMWv939uxZLVy4UBUVFTpw4IB7/Pvf/74mTpyom266SQsWLNCZM2es%0AxumXsPb34cOHddddd6mqqsp6jH72ue3XYOrzJ554QhMmTNBNN92kv//7v1dTUxP9ncP+fuyxx9r0%0Ad2NjY+T7Owx93lF/O/8WLVqkG264wXqcpj4P/Yhy5syZKioqanPsv//7v/Xcc8/pxRdf1IYNG9Tc%0A3Kznn3/eUoTxkqq/jx8/rmnTpqmiosJSVPGWqs83bNigDRs2aPXq1Vq/fr0SiYT+5V/+xVKE8ZKq%0Av9euXautW7dqzZo1Wr9+vVpbW/X0009bijBeUvW345133lFdXV3AEaUvEoly1
qxZbY5t2LBBN998%0As/r06aNEIqEvfOEL2rBhg6UI4yVVfycSCX33u9/V+PHjLUUVb6n6vLy8XAsXLlRxcbG6deumESNG%0AaM+ePZYijJdU/T106FB94xvf0IUXXqhu3bpp1KhR2rdvn6UI4yVVf0tSa2urvvGNb+j++++3EFV6%0AQn8/yhEjRrQ7tn///jYX7csvv1z/93//F2RYsZWqv/v27au+ffvqgw8+sBBR/KXq82HDhrX5+vXX%0AX9c111wTVEixlqq/r7jiCrfd3NysDRs2aPLkyUGGFVup+luSfvjDH6qiokJXXXVVwBGlL/QjylRO%0Anz6twsJC9+sLL7xQp0+fthgRkDtPPfWUjhw5ounTp9sOJfZmz56t6667TgMHDtTnPvc52+HE1gcf%0AfKCVK1dq9uzZtkPpkkgmyp49e+rs2bPu16dPn+6wBg5EWW1trf7zP/9TzzzzDO/xANTW1upXv/qV%0AioqK9MADD9gOJ7YWLlyor3zlK+rbt6/tULokkony4x//uBoaGtyvGxoaNHToUIsRAf57/PHH9Zvf%0A/EbPPfec+vXrZzucWNu2bZv7GXCPHj00depU/eIXv7AcVXy9+uqrWrx4scaNG6cvfvGLOnDggMaN%0AG9dmABQmkUyUn/nMZ7Ru3TodPnxYLS0teu655zRp0iTbYQG++d///V+tWbNGTz/9tIqLi22HE3tv%0AvPGGFi1a5F6oX331VQ0fPtxyVPH15ptvasuWLdqyZYtWr16tSy65RFu2bGnzkVqYhHoyz+HDhzVt%0A2jT36+nTp6ugoEArV67UPffcozvuuEPJZFJjx47VbbfdZjHSeOiov2fMmKHvfe97+vDDD3X48GFN%0AnDhRZWVlWrlypcVo46GjPr/66qvV3NysqVOnut/72Mc+pmeeecZGmLFhuqZ88MEHuuWWWyRJF198%0AsR555BFbYcaGqb/LysosRpaeRNLPla0AAMRMJEuvAAAEhUQJAIABiRIAAAMSJQAABiRKAAAMjMtD%0AEolEUHFEnl+Th+nzrvOjz4Ps71OnTkn6485Sjvr6erf9p/u7SmqzPjjVuj7vEoZc79wTtf6OOq4p%0AweuozxlRAgBgQKIEAMDAuOEAQ/auo0wSvCiUAlPF2NjY6Lb93BR68eLFbnv+/Pm+ndcRhf6OE64p%0AwaP0CgBABhhR+oS//oIXhRGON0bnDjd79+5N+bPOZJ7du3e7x7z361u2bJnb7t79j/Pwzp07l/Jc%0AuXhdUejvOOGaEjxGlAAAZIBECQCAQahvswVEkbdc6tVRyTUTLS0tvp0LgBkjSgAADEiUAAAYUHoF%0AQmLPnj1d/tlU290xuxHIDUaUAAAYRHpEOWDAALd96NAh387LX+bIhl/r36S2m55711HCP135/2pq%0AapLk705KiA5GlAAAGJAoAQAw8G0Lu5qaGklSbW1tRo93JjI423zZlEnple2mghfWLdV69+7ttp2S%0And/P5WxhN2TIEPdYOpOBMhG2/j527Jgkad26de6xadOmpX2eVBOjJOlLX/qS2547d66kttcnP9fF%0AphLUNaWurs5tV1ZW+vKcQfLzPcUWdgAAZIBECQCAQWhKrw5vODbuFZjp80a59Bpkn/spbKXAVLwx%0AlpWVSZLef//9nD5nroShvzuLIVf/n87zxrH0Wl1d7baHDx/uy3OmY+nSpW57zpw5xp9N1Sd+zgan%0A9AoAQAZ8G1GmOo2fIzM//1L8xCc+IUl6++23fXuuuIwovcI+ugzDCKcz27dvd9vXXntt2s+Zzmu0%0AVYFJRyYxlpeXu+36+vp25/Ie8/5sLiZ5xHFEaVuYKlqMKAEAyACJEgAAA9+3sMt26Ox9vHcY7LQz%0APb+3JOOUXFtbW91jBQUFGZ03yrxr/FJpbGyUxLZd2Rg9erTbdt7DxcXF7rETJ060+35Hjh8/7rad%0A96t3zab
38U7Jd8yYMZmEHQnO77T3dzub80hty7ipdFZu9a7JzPW61jByrine92Uq3glEUcCIEgAA%0AAxIlAAAGvs969XPWkrM2U2q7PtPhLUWl2sbK2eJKkkpKSto9Poitj9KV61lfgwYNctv79++XJM2e%0APds9tmLFCrftlFFsz0TrSBRmvXp1NV5vWcr7vs7k/GF7j+d6HWUmz9WVc2ZyriBfq0mu3+PORzSS%0A1KdPn3bfr6iocNu7d+82nsv2tYZZrwAAZMD3EeWmTZuyj+r/845+/Nws3YmxqqrKt3NG5a+/dEYd%0AuR6hZCsMI5xMdBT3fffdJ6ntqD4d3pHoK6+8Iil+I8rO5GpEd/bsWUlSjx49rDx/NnLR596JSt5r%0AszOibG5uNj7eea9L0ve+9z23bfv6wogSAIAMkCgBADDIqvRq4z5m3tJuR6VT5yV5tw87efKk204V%0A6+LFi932/Pnz044rzGWSjiZFdfZc27Ztk9R2LaDt0ohXFEqBqXjj7sr7OZvnoPSaOe8kN2fj7s7O%0AGffSa6r7BofpmpAtSq8AAGSARAkAgIHvpddclRtycScTP9dZhrFM4vDG5l17WlpamvbjvWyXXKJQ%0ACkzlzJkzbruwsDAnsVB6pfTqVxxeuXhfhQmlVwAAMuD7puiZsrFmr6MRlRPLD37wA/eYs9tPlHhH%0A/I6ujiK9OtqoHpn5r//6L7cd1CQ4AJljRAkAgAGJEgAAA6ul11yXW71rI9OxefNmSdIdd9zhHoti%0A6dUp691yyy05Ob+3tOvnGsC48/aV93fAmVyWSXlcko4cOZJdYCHVWbnfz20zYZf3/3LkyJGSMv99%0A8BMjSgAADEiUAAAYZFV69c7Yc8qVnfFup+aVixmumd6FwXld3pKPd7f75cuXZxdYDnm36nOsXbvW%0At/N7/5+ZsZm9vXv3uu3y8vKsztWvXz+3HYdyZKprhfc64ZT+eR8Gz3unGj+vL+PHj3fbq1at8u28%0A2WJECQCAAYkSAACD0Gw4EHbeGbRhLr0WFRW57Tlz5vh+fm+Zy1uabmxslCT17dvX9+eMsyeffNJt%0Ae+/skq24zkL2lqdnzJjR7vv19fVu23l/ZntnIHzEKYc6NwaXcnODbClcKw0YUQIAYJDVpuiZrIPM%0A9N6I6XDi8t4zzTtpIt3z/KlUsdrcwNjWpuWLFi1y2/PmzQvkOb2isEl3Z4qLi912c3OzpPTet+fP%0An3fb3bp99HdvrjfXz5Sf7+/OOP3oHWWmE0tnz+tUUaTUlZQ4boruSKdvOuPtO9u/j2yKDgBABkiU%0AAAAYBD6Zx1tu9a5NhD8qKioCey7vxAin9Orn/fjywYkTJ9od85YKO+tDb7nVe//EOMn2fdTR3W9S%0Aldk6Wg/uXKt69+7tHvNey5y+X7ZsWVaxRoW3T1N9fJDOpL4NGzb4F1iOMKIEAMCARAkAgEFgs173%0A7Nkjqe2MvrCX5tIpI9qYobZ+/XpJ0sSJEzN6fC6cOXPGbRcWFrrtOM3CzDXv67ryyivd9s6dO922%0AszZw7ty57rFcv5a49ndH71kbX4IRAAAPzklEQVTvNoA21qWGedZrXDHrFQCADAQ2onR+Nkqbaod9%0AROnXcwaBEWXXdWVtZK7v5ZpKXPs7rBhRBo8RJQAAGSBRAgBg4Ns6yjgO78P+mrwTo0zS2b7Puyaq%0ArKysy4/LZItApFZQUOC2Oyu/dfU9ACBzjCgBADAgUQIAYJBV6TWdLbOcbbkeffTRbJ4SHrkod3q3%0AVEu1vRrsSVWGpeQN5B4jSgAADLJaR4mPsOYpePm2rs/G2snOnj9dUepv27imBI91lAAAZIBECQCA%0AQeD3owSQvePHj9sOAcgbjCgBADAgUQIAYEDpFYiIOXPmuO0HH3zQYiRAfmFECQCAAesofcKap+Dl%0A87q+Y8eOSZJKS0sDe8587m8buKYEj3WU
AABkgEQJAICBMVEmk0nr/86ePauFCxeqoqJCBw4cUDKZ%0A1EsvvaSRI0dqwoQJ7r/nn3/eapx+CWN/J5NJ7dixQ5MmTVJlZaWmT5+ugwcPWo81Dv3dUZ8vXry4%0Azfv705/+tKZMmeI+pqSkRCUlJfS3T/3d0tKiRx55RDfddJMmTpyo+fPn68SJE5F/f4ehz1P197lz%0A57Rw4UJNmDBB119/vVasWGE9TlOfh35EOXPmTBUVFbU7fuONN2rDhg3uv2nTplmILn5S9feJEyd0%0A//3365FHHlFdXZ2uu+46rVu3zlKE8ZOqz+fOndvm/X399ddrypQpliKMl1T9/dJLL2nnzp165ZVX%0AtG7dOp09e1bLly+3FGG8pOrvH/3oR3rrrbf0k5/8RC+//LJeeukl/frXv7YUYecikShnzZplO4y8%0Akaq/6+rqdOWVV+qTn/ykJOm+++7TPffcYyO8WOrsPb57927t2LFDt912W4BRxVeq/t69e7dGjhyp%0AwsJCdevWTaNGjdKePXssRRgvqfp769atqq6uVo8ePdS7d299/vOf18aNGy1F2LnQJ8oRI0akPP67%0A3/1O06dP14QJE7RgwQI1NzcHHFk8pervXbt2qbS0VF/5ylc0YcIEfe1rX9PRo0ctRBdPHb3HHU88%0A8YS+9KUvqXt3lj37IVV/jx49Wq+//roaGxt15swZvfrqqxo3bpyF6OInVX8nEgm1tra6XxcVFend%0Ad98NMqy0hD5RpjJ48GBVVlbqqaee0po1a3TixAn94z/+o+2wYqupqUm/+MUvNHfuXK1du1aFhYX0%0Ad0AaGhr01ltvqbq62nYosVZVVaUrrrhC48aN0+jRo9Xc3KypU6faDiu2xo4dq9WrV6upqUnHjh3T%0Ayy+/rDNnztgOq0ORTJQjR47UrFmzVFxcrJ49e2rGjBl67bXXbIcVW71799aYMWM0aNAgXXDBBfrb%0Av/1bbdmyxXZYeeGnP/2pbrzxRl1wwQW2Q4m15557TkePHtWOHTu0Y8cOlZeX88dgDk2dOlVjx47V%0A1KlTNWvWLI0dO1Z9+vSxHVaHIpkoDxw40Kb0d/78ecpSOXTppZe2KW0XFBSooKDAYkT547XXXtOn%0APvUp22HE3pYtW3TjjTeqZ8+e6t69uyZOnKgdO3bYDiu2unfvrnnz5mnjxo16/vnnVVBQoIqKCtth%0AdSiSifLf/u3f9OCDD+rcuXM6f/68nn/+eV1//fW2w4qtqqoq7dixQ7t27ZIkvfjiixozZozlqPLD%0Arl27VF5ebjuM2BsyZIhef/11tbS0SPrjHyjDhg2zHFV8vfzyy/ra176m1tZWHTp0SD/+8Y91yy23%0A2A6rQ8Yt7Gw7fPiwu+xj3759GjhwoAoKCrRy5Uo9+uij+s1vfqNEIqGRI0dqwYIF6t27t+WIo83U%0A32+99Zb+6Z/+SYlEQsOGDdO3vvUt9evXz3LE0Wfq8x49eujaa6/Vb3/7WxUWFlqONB5M/b1kyRL9%0Az//8j7p166bBgwfrm9/8psrKyixHHG2m/v7Wt76lnTt3qnv37rr//vt18803W462Y6FOlAAA2BbJ%0A0isAAEEhUQIAYECiBADAgEQJAIABiRIAAAPjKn3ujN11fk0eps+7zo8+p7+7jv4OFteU4HXU54wo%0AAQAwIFECAGDABqnodIu0/fv3u+3z58/nOBoACBdGlAAAGDCizDPHjh1z2yUlJVmdi0kC/k24SNfx%0A48clSaWlpVaeH8gnjCgBADAgUQIAYGC8ewilta6Lyponb5xz5syRJNXW1mb0eC8b75UwrOvzxjB7%0A9mxJ0rJly7I6Z1iFob/zSVSuKXHCOkoAADJAogQAwIBZr3lszZo1aT/GW8bxlinuvfdeSdKKFSuy%0ADwwAQoQRJQAABoGPKGtqaty2dxJJZx84p/qQlQ+ps3PFFVdIkvbu3ZvR4
0+fPu22ly9fLokRJYD4%0AYUQJAIABiRIAAINYT+bxlmsp07a3du1aSZn3zde//nW3vXTpUl9iAoCwYUQJAIABiRIAAIPIlF77%0A9OkjSWpqaur0Z+vq6todmzRpkttet26df4FFTKp1kJmWqNevX++2Kb0CiCtGlAAAGERmRNnc3Gz8%0AvrMhtSRVVla2+/7w4cPddj6PKL2c0WOmmy/v3LnTz3Ai6ezZs27bWRcc103RgXzFiBIAAAMSJQAA%0ABpEpvabSUcnw+PHjkqSSkpIgw4ksb/kwHd5yd746deqU2y4sLLQYCZBfFi9eLEmaO3eueyxX6+UZ%0AUQIAYECiBADAINKl16FDh7rtVHfAyHQ2Z77p0aNHRo/zrp10yt0A8suxY8faHSstLc358zol18GD%0AB+f8uRhRAgBgkEgahl253kg8kx1h0nmM92e9E09ysc7Nr9Fr2Ddv7+h12ojbjz7PNm7vX9PO5LGw%0A/x9mKgz9nUkMUf3/iMo1JVWc9fX1bnvYsGG+n9/Lz9fX0XMxogQAwIBECQCAQWgm89x7772SpOXL%0Al1uOBH/Ku1bQK6olLURXOu85JvMFI9X/ibfvM70hRXFxsfH7mzZt6vK5ssWIEgAAAxIlAAAGoSm9%0Avvfee+2OdTak70hBQYEvMeW78vJySVLPnj3dY5RbYVOqWcad8aME69wHd+TIkSm/n2oddz574YUX%0A3PbatWvddjrXj87uGFVVVZV+YBliRAkAgEHgI8qTJ0+mPP7OO+9I6vgvDucvOi/vbjB9+/b1ITp4%0AeddCAWHQp0+fdse814xBgwa57f3790tqu4NXOrzvf+d5O/udoOLyR3fccYfbvv322932+fPnJWVe%0A9bPVv4woAQAwIFECAGAQeOm1qKgo5fHOPgzv3bt3u2Pecqt3i7qf/OQnktqWSXKxbV0c1dTUtDtG%0AOQlh0a3bR3/bz5kzp933u3dvf0lLZ6KNd81fOqqrqzN6XD7wXj+ciVWZbF9qEyNKAAAMSJQAABgE%0AfveQTO8+kepxHT3GWf/nLb3a2EE/E7bLEN7XsWrVKknStGnTbIVjFIa7WXD3kPRkujY67FKVgb1q%0Aa2vTPmdcrileAwYMkCQdOnTIPfaJT3zCbf/ud79z287r7+y+w37i7iEAAGSARAkAgEHgpVdvqerI%0AkSNuu7NFwd7HOUpLS1P+LKXX9NTV1bntyspKq7Gkg9JrsMLQ39lKdR3xQ2fb6WXyuqN8TelMR5vF%0ApJohG2T8lF4BAMhA4CPKIDCiTE9HsTc2Nkrq+ubTQQvDCIcRZXri2je5EOVrSjq8r/P06dNu27kZ%0AAyNKAABCjkQJAIBBaO5HieClKjN4y9XOBKuobTcVJO+EsjisBwSCtnnzZrc9fvx4i5F0jBElAAAG%0AJEoAAAxiOevV4V0fWFVVldPnisoMtXS2AkznZ20I2yxMG+u+ghS2/o67qFxT/GT7msOsVwAAMhDr%0AEWWQovLXnzdOZ7JORxsN2/7rrjNhG+E48Vx88cXuMe/mz1EXtv6Ou6hcU/zkva/w0qVLJTGiBAAg%0A9EiUAAAYUHr1SVTKJJ3F6WxbJ7XdrNgRpvdE2EqBTjzedWHeTeajLmz9HXdRuabECaVXAAAyQKIE%0AAMCALezyTKoyTE1Njduura1126dOnZIk9erVK/eBxcCmTZskhbfceubMGbddWFgoibIc0BWMKAEA%0AMGAyj0/44D14YZ1c4o1rxowZbnv58uXtfta7bszhrB/LpUxed1j7O664pgSPyTwAAGSARAkAgAGl%0AV59QJgleWEuBx44dc9slJSVpP/748eNd/tlPfvKTbruhoSHt50pHWPs7rrimBI/SKwAAGSBRAgBg%0AQOnVJ5RJgkcpMFj0d7C4pgSP0isAABkgUQIAYECiBADAgEQJAICBcTIPgGCdO3dOtbW1+v73v6+f%0A//z
nuvjii22HBOQ9RpRAiMycOVNFRUW2wwDgQaIEQmTmzJmaNWuW7TAAeJAogRAZMWKE7RAA/AkS%0AJQAABiRKAAAMSJQAABiQKAEAMGAdJRAShw8f1rRp0yRJ+/bt08CBA1VQUKCVK1eqrKzMcnRA/iJR%0AAgBgQOkVAAADEiUAAAYkSgAADEiUAAAYkCgBADDobjsAwKZEImE7hMhggjzyFSNKAAAMSJQAABiQ%0AKAEAMCBRAgBgQKIEAMCARAkAgAGJEgAAAxIlAAAGJEoAAAxIlAAAGJAoAQAwIFECAGDApuiAReXl%0A5W67paVFktTQ0GArHAApMKIEAMCARAkAgEEiyU3mkMds34/S++t3/PhxSVJpaamtcIy4VCBfMaIE%0AAMCARAkAgAGJEgAAAxIlAAAGJEoAAAxIlAAAGJAoAQAwIFECAGBAogQAwIBECQCAAXcPAXzibPG2%0AadMm91hVVZWtcAD4hBElAAAGjCgBn+zdu1eSVFlZ6R6rq6tz24wugWhiRAkAgAGJEgAAA0qvgE+G%0ADh0qSaqurnaPvfLKK27bmexj+x6YANLDiBIAAAMSJQAABomkUw8C8lCQZdDOftWampokSX379g0i%0AnLRxqUC+YkQJAIABI0rkNRsTa7Zt2+a2R48e3e77gwcPdtsNDQ1BhNQlXCqQrxhRAgBgQKIEAMCA%0A0ivymu01jTU1NW67trbW+LPz5s1z20uWLMlZTB3hUoF8xYgSAAADEiUAAAaUXpHXbJdevZxfRWcr%0APEn61a9+5bb79evntmfMmCFJWr58eUDRUXpF/mJECQCAAYkSAAADSq/Ia2EvvTo3g5akU6dOue2e%0APXu2e3yuXwuXCuQrRpQAABhwP0ogIoqKitx2eXm5JKm+vt495h3xhWmkDEQdI0oAAAxIlAAAGFB6%0ABSLImeRDiRXIPUaUAAAYkCgBADAgUQIAYECiBADAgEQJAIABiRIAAAMSJQAABiRKAAAMSJQAABiQ%0AKAEAMCBRAgBgQKIEAMCARAkAgAGJEgAAAxIlAAAGJEoAAAxIlAAAGJAoAQAwSCSTyaTtIABbEomE%0A7RBcqX4Vwx4fkA8YUQIAYNDddgAA/ihMo0cAH2FECQCAAYkSAAADSq/Ia0xQAdAZRpQAABiQKAEA%0AMCBRAgBgQKIEAMCARAkAgAGJEgAAg/8HrjWlqzde0WoAAAAASUVORK5CYII=)

Here the goal is to match the second image of the first row to the single image shown in the last row (as indicated by the labels).
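
To make the task concrete, the sketch below samples one N-way one-shot episode from a dictionary that maps class names to lists of images. The names `sample_episode` and `images_by_class` are hypothetical and strings stand in for images; this illustrates the task layout and is not necessarily how the notebook's own data pipeline builds its batches.

```python
import random

def sample_episode(images_by_class, n_way, rng):
    """Sample one N-way one-shot episode (hypothetical helper, not from the notebook).

    Returns (support, target): `support` is a list of (image, label) pairs with
    exactly one image per class, and `target` is an (image, label) pair whose
    class appears in the support set but whose image does not.
    """
    classes = rng.sample(sorted(images_by_class), n_way)
    target_label = rng.randrange(n_way)
    support = []
    target = None
    for label, cls in enumerate(classes):
        if label == target_label:
            # Draw two distinct images: one for the support set, one as the target.
            support_img, target_img = rng.sample(images_by_class[cls], 2)
            target = (target_img, label)
        else:
            support_img = rng.choice(images_by_class[cls])
        support.append((support_img, label))
    return support, target

# Toy data: strings stand in for images.
toy = {c: [f"{c}_{i}" for i in range(3)] for c in "ABCDE"}
support, target = sample_episode(toy, n_way=4, rng=random.Random(0))
print(support, target)
```

The model never sees the target's label at prediction time; it has to recover it by comparing the target image with the labelled support images.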

#Installing Requirements and Imports

We use **PyTorch** to define and train the model and **comet.ml** to log training statistics.

In [1]:
!pip3 install torch torchvision
!pip3 install comet_ml

Collecting comet_ml
  Downloading comet_ml-3.47.5-py3-none-any.whl.metadata (3.9 kB)
Collecting everett<3.2.0,>=1.0.1 (from everett[ini]<3.2.0,>=1.0.1->comet_ml)
  Downloading everett-3.1.0-py2.py3-none-any.whl.metadata (17 kB)
Collecting python-box<7.0.0 (from comet_ml)
  Downloading python_box-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.8 kB)
Collecting semantic-version>=2.8.0 (from comet_ml)
  Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)
Collecting simplejson (from comet_ml)
  Downloading simplejson-3.19.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.2 kB)
Collecting wurlitzer>=1.0.2 (from comet_ml)
  Downloading wurlitzer-3.1.1-py3-none-any.whl.metadata (2.5 kB)
Collecting dulwich!=0.20.33,>=0.20.6 (from comet_ml)
  Downloading dulwich-0.22.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Collecting configobj (from everett[ini

In [2]:
from comet_ml import Experiment
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Model Architecture

## Convolutional Block
A Standard convolutional block which follows: <br>
Image --> Convolutional layer --> ReLU Activation --> Batch Normalization --> Pooling Layer --> Dropout.<br>
The dropout layer has been added for improved regularization.

In [3]:
class ConvLayerWithBatchNorm(nn.Module):

  def __init__(self, in_channels, out_channels=64, kernel_size=3, padding=1, dropout_probality=0.2):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
    self.ReLU = nn.ReLU()
    self.batch_norm_layer = nn.BatchNorm2d(out_channels)
    self.maxpool = nn.MaxPool2d(2, 2)
    self.dropout = nn.Dropout(dropout_probality) # Dropout to add regularization and improve model generalization

  def forward(self, X):
    x = self.conv(X)
    x = self.ReLU(x)
    x = self.batch_norm_layer(x)
    x = self.maxpool(x)
    x = self.dropout(x)
    return x

##Embedding Layer
This is the embedding model, which is a stack of 4 of the above `ConvLayerWithBatchNorm` layers. <br>
An additional fully connected layer, followed by dropout, is added at the end.

In [4]:
class ConvEmbedding(nn.Module):

  def __init__(self, in_channels=1, embedding_size=256, dropout_probality=0.2):
    super().__init__()
    self.conv1 = ConvLayerWithBatchNorm(in_channels, 64, dropout_probality=dropout_probality)
    self.conv2 = ConvLayerWithBatchNorm(64, 64, dropout_probality=dropout_probality)
    self.conv3 = ConvLayerWithBatchNorm(64, 64, dropout_probality=dropout_probality)
    self.conv4 = ConvLayerWithBatchNorm(64, 64, dropout_probality=dropout_probality)
    self.dense = nn.Linear(64, embedding_size)
    self.dropout = nn.Dropout(dropout_probality) # Dropout to add regularization and improve model generalization
    self.embedding_size = embedding_size

  def forward(self, X):
    # Input shape is (batch_size, 1, 28, 28)
    x = self.conv1(X) # x's Shape (batch_size, 64, 14, 14)
    x = self.conv2(x) # x's Shape (batch_size, 64, 7, 7)
    x = self.conv3(x) # x's Shape (batch_size, 64, 3, 3)
    x = self.conv4(x) # x's Shape (batch_size, 64, 1, 1)
    x = x.view(x.size()[0], -1) # x's Shape (batch_size, 64)
    x = self.dense(x) # x's Shape (batch_size, embedding_size)
    x = self.dropout(x) # x's Shape (batch_size, embedding_size) with a few activations flipped to 0
    return x
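
The shape comments in `forward` follow from the standard output-size formula for convolution and pooling layers. The sketch below reproduces that arithmetic in plain Python; the helper names `conv_out` and `embedding_spatial_sizes` are hypothetical, and it assumes PyTorch's default floor rounding for pooling.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output length along one spatial axis of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

def embedding_spatial_sizes(size=28, blocks=4):
    """Spatial size after each conv block: 3x3 conv (padding 1), then 2x2 max-pool (stride 2)."""
    sizes = []
    for _ in range(blocks):
        size = conv_out(size, kernel=3, padding=1)  # the convolution keeps the size
        size = conv_out(size, kernel=2, stride=2)   # the pooling halves it, rounding down
        sizes.append(size)
    return sizes

sizes = embedding_spatial_sizes()
print(sizes)

# With 64 output channels and a 1x1 spatial map, flattening gives 64 features,
# which is why the dense layer is declared as nn.Linear(64, embedding_size).
flattened = 64 * sizes[-1] * sizes[-1]
print(flattened)
```

If the input resolution changes, the final spatial size may no longer be 1x1, and the input width of the dense layer would have to change with it.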

##  The Fully Conditional Embedding - Target Image
Now we have the fully conditional embedding layer for our target image. This allows us to obtain an embedding for our target image in the context of the support images.

One significant change in our implementation is that, instead of concatenating the hidden state and the attended output, we add them. The reason is that the built-in `LSTMCell` expects a hidden state of width `hidden_size` (its hidden-to-hidden weights map `hidden_size` inputs to `hidden_size` outputs for each gate). To pass the concatenated vector as the hidden state, we would have to write a custom LSTM cell whose hidden weights accept `2 x hidden_size` inputs.

This is mentioned by one of the authors, Oriol Vinyals, [on the machine learning subreddit here](https://www.reddit.com/r/MachineLearning/comments/6efl5g/d_order_matters_attention_mechanisms/).

In [5]:
class FullyConditionalEmbeddingTargetImage(nn.Module):

  def __init__(self, embedding_size, processing_steps=10):
    super().__init__()
    self.lstm_cell = torch.nn.LSTMCell(embedding_size, embedding_size)
    self.processing_steps = processing_steps
    self.embedding_size = embedding_size
    self.attn_softmax = nn.Softmax(dim=1)

  def forward(self, target_image_encoded, support_images_encoded):
    batch_size, num_images, _ = support_images_encoded.shape

    # Start from a zero cell state and the mean support embedding as the hidden state
    # (an all-zero hidden state is the alternative, kept commented out below).
    # hidden_state_prev = torch.zeros(batch_size, self.embedding_size).to(device)
    cell_state_prev = torch.zeros(batch_size, self.embedding_size).to(device)
    hidden_state_prev = torch.sum(support_images_encoded, dim=1) / num_images
    for i in range(self.processing_steps):
      hidden_out, cell_out = self.lstm_cell(target_image_encoded, (hidden_state_prev, cell_state_prev))
      hidden_out = hidden_out + target_image_encoded
      attn = self.attn_softmax(torch.bmm(support_images_encoded, hidden_out.unsqueeze(2)))
      attended_values = torch.sum(attn * support_images_encoded, dim=1)
      hidden_state_prev = hidden_out + attended_values
      cell_state_prev = cell_out

    return hidden_out
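
The choice of adding the hidden state and the attended read-out, instead of concatenating them, can be checked against the shapes that `nn.LSTMCell` works with. The following is a minimal sketch with made-up sizes and placeholder tensors (`h` and `r` are stand-ins, not values from the model):

```python
import torch
import torch.nn as nn

hidden_size, batch_size = 8, 4
cell = nn.LSTMCell(hidden_size, hidden_size)

# PyTorch stacks the four gate matrices, so the hidden-to-hidden weight has
# shape (4 * hidden_size, hidden_size): each gate expects a hidden_size-wide state.
print(cell.weight_hh.shape)

h = torch.zeros(batch_size, hidden_size)  # stand-in for the LSTM hidden output
r = torch.ones(batch_size, hidden_size)   # stand-in for the attended read-out

summed = h + r                       # (batch_size, hidden_size): valid hidden state
concatenated = torch.cat([h, r], 1)  # (batch_size, 2 * hidden_size): too wide for this cell

x = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)
h_next, c_next = cell(x, (summed, c))
print(summed.shape, concatenated.shape, h_next.shape)
```

Addition keeps the state at `hidden_size`, so the stock cell can be reused unchanged; the concatenated variant would need a cell built for a wider hidden state.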

## The Fully Conditional Embedding - Support Images

Here we have the fully conditional embeddings for our support images. This is similar to the previous layer for the target image, except that it is a simple bidirectional LSTM over the sequence of support images.

Our new embeddings are going to be the LSTM's **Forward Activation + Backward Activation + Support Image Embedding**

In [6]:
class FullyConditionalEmbeddingSupportImages(nn.Module):

  def __init__(self, embedding_size):
    super().__init__()
    self.embedding_size = embedding_size
    self.bidirectionalLSTM = nn.LSTM(input_size=embedding_size, hidden_size=embedding_size, bidirectional=True, batch_first=True)

  def initialize_hidden(self, batch_size):
    #Initialize the states needed for our bi-directional LSTM
    hidden_state = torch.zeros(2, batch_size, self.embedding_size).to(device)
    cell_state = torch.zeros(2, batch_size, self.embedding_size).to(device)
    return (hidden_state, cell_state)

  def forward(self, support_embeddings):
    batch_size, num_images, _ = support_embeddings.shape
    # Initialize states
    lstm_states = self.initialize_hidden(batch_size)
    # Get the LSTM Outputs
    support_embeddings_contextual, internal_states = self.bidirectionalLSTM(support_embeddings, lstm_states)
    # Get the forward and backward outputs
    support_embeddings_contextual = support_embeddings_contextual.view(batch_size, num_images, 2, self.embedding_size)
    # Add the forward and backward outputs
    support_embeddings_contextual = torch.sum(support_embeddings_contextual, dim=2)
    # Add the skip connection to our output
    support_embeddings_contextual = support_embeddings_contextual + support_embeddings
    return support_embeddings_contextual
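
The `view` call in `forward` relies on how a bidirectional `nn.LSTM` lays out its output: the forward and backward activations are concatenated along the last axis. The sketch below, with small made-up sizes, checks that layout against the final hidden states the LSTM returns.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_images, embedding_size = 2, 5, 4
lstm = nn.LSTM(input_size=embedding_size, hidden_size=embedding_size,
               bidirectional=True, batch_first=True)

x = torch.rand(batch_size, num_images, embedding_size)
out, (h_n, c_n) = lstm(x)  # out: (batch_size, num_images, 2 * embedding_size)

# Split the last axis into (direction, feature); index 0 is forward, 1 is backward.
directions = out.view(batch_size, num_images, 2, embedding_size)
forward_out, backward_out = directions[:, :, 0], directions[:, :, 1]

# The forward pass finishes at the last image, the backward pass at the first.
print(torch.allclose(forward_out[:, -1], h_n[0]))
print(torch.allclose(backward_out[:, 0], h_n[1]))

# Sum of both directions plus the skip connection, as in the module above.
contextual = directions.sum(dim=2) + x
print(contextual.shape)
```

Because the two directions are summed and not concatenated, the contextual embeddings keep the same width as the original support embeddings.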

##Cosine Distance Module
A module that computes the cosine similarity between the target image and each of the support images. The notebook refers to it as the cosine distance; note that larger values mean a closer match.

In [7]:
class CosineDistance(nn.Module):

  def __init__(self):
    super().__init__()

  def forward(self, target_image, support_images):
    support_images_normed = F.normalize(support_images, p=2, dim=2)
    # 'p=2' selects the L2 (Euclidean) norm, so every embedding is scaled to unit length
    target_image_normed = F.normalize(target_image, p=2, dim=1)
    target_image_normed = target_image_normed.unsqueeze(dim=1).permute(0, 2, 1)
    # The shape is now [batch_size, embedding_size, 1], e.g. [5, 64, 1]
    # Use the normalized target so the result is a true cosine similarity
    similarities = torch.bmm(support_images_normed, target_image_normed)
    # torch.bmm = batch matrix multiply
    # e.g. [5, 20, 64] @ [5, 64, 1]
    # the output shape is [5, 20, 1]
    similarities = similarities.squeeze(dim=2)
    # remove last dimension
    return similarities
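
As a worked example of the quantity this module is meant to compute, the sketch below evaluates cosine similarity in plain Python: both vectors are scaled to unit L2 length and then multiplied as a dot product. The helper names are hypothetical, and the `eps` guard is an assumption modelled on the default of `F.normalize`.

```python
import math

def l2_normalize(v, eps=1e-12):
    """Scale a vector to unit Euclidean (L2) length; eps guards against division by zero."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]

def cosine_similarity(a, b):
    """Dot product of the two L2-normalized vectors."""
    return sum(x * y for x, y in zip(l2_normalize(a), l2_normalize(b)))

target = [1.0, 2.0, 0.0]
support = [[2.0, 4.0, 0.0],    # same direction as the target, different length
           [-2.0, 1.0, 0.0],   # orthogonal to the target
           [-1.0, -2.0, 0.0]]  # opposite direction
similarities = [cosine_similarity(target, s) for s in support]
print(similarities)
```

The result depends only on direction, not on vector length, which is why both the support and the target embeddings are normalized before the batch matrix multiply.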

##One-Hot Converter Module
A helper class that allows us to convert labels into one-hot vectors

In [8]:
# Taken from @activatedgeeks's answer from https://stackoverflow.com/questions/44461772/creating-one-hot-vector-from-indices-given-as-a-tensor
class ConvertOneHot(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, labels, num_classes):
    batch_size, num_images, _ = labels.size()
    one_hot_labels = torch.zeros(batch_size, num_images, num_classes, device=labels.device)
    return one_hot_labels.scatter(2, labels, 1)
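
For a small, concrete view of what the converter produces, the sketch below builds one-hot labels with `scatter`, as in the module above, and compares the result with PyTorch's built-in `F.one_hot`. The label values are made up for illustration.

```python
import torch
import torch.nn.functional as F

batch_size, num_images, num_classes = 2, 3, 3
labels = torch.tensor([[[0], [2], [1]],
                       [[1], [1], [0]]])  # shape (batch_size, num_images, 1)

# Scatter-based construction: write a 1 at the index given by each label.
one_hot_scatter = torch.zeros(batch_size, num_images, num_classes).scatter(2, labels, 1)

# Built-in alternative: drop the trailing axis, then expand to one-hot vectors.
one_hot_builtin = F.one_hot(labels.squeeze(-1), num_classes=num_classes).float()

print(one_hot_scatter[0])
print(torch.equal(one_hot_scatter, one_hot_builtin))
```

Either construction works here; the scatter version keeps the trailing label axis that the rest of the notebook passes around.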

##Matching Network
The orchestrating module that builds the entire matching network.<br> We implement the matching networks described in this [paper](https://arxiv.org/pdf/1606.04080.pdf).<br>

For an overview of the paper, read this blog post by [Adrian Colyer](https://blog.acolyer.org/2017/01/03/matching-networks-for-one-shot-learning/).

**Input to the Network**: *Support Images*, *Support Labels* and *Target Image*<br>
**Output of the Network**: *One-Hot encoded Target label*<br>

1. The model first computes embeddings for the support images and the target image via the `ConvEmbedding` model.
2. Then we calculate the cosine distance between the target image embedding and each support image embedding.
3. We squash the distances via a softmax function. The result acts as our attention weights.
4. We multiply the attention weights by the one-hot encoded support labels and sum them to get the predicted label distribution.
5. We then measure the cross-entropy loss between the predicted label and the expected target label, and backpropagate.
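
The steps above can be sketched on a toy 3-way task in plain Python. The similarity values and labels are made up, and the helper names `softmax` and `predict` are hypothetical. The notebook's training loop passes the resulting values to `F.cross_entropy`, which applies its own log-softmax; the sketch uses the plain negative log-likelihood to keep the arithmetic readable.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def predict(similarities, support_labels, num_classes):
    """Attention-weighted sum of one-hot support labels; returns class probabilities."""
    attention = softmax(similarities)
    probs = [0.0] * num_classes
    for weight, label in zip(attention, support_labels):
        probs[label] += weight  # same as adding weight * one_hot(label)
    return probs

# Toy 3-way task: the target is most similar to the second support image.
similarities = [0.1, 0.9, -0.3]
support_labels = [2, 0, 1]
probs = predict(similarities, support_labels, num_classes=3)
predicted = max(range(3), key=probs.__getitem__)
loss = -math.log(probs[0])  # negative log-likelihood if the true label is 0
print(probs, predicted, loss)
```

Because the attention weights sum to one and each support label is one-hot, the output is already a probability distribution over the classes present in the support set.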

This implementation draws heavy inspiration from the following Matching Networks implementations:


1.   BoyuanJiang's implementation - https://github.com/BoyuanJiang/matching-networks-pytorch
2.   activatedgeek's implementation - https://github.com/activatedgeek/Matching-Networks
3.   Mark's implementation for the Full Context Embeddings - https://github.com/markdtw/matching-networks


In [9]:
class MatchingNet(nn.Module):

  def __init__(self, image_shape, embedding_size=256, dropout_probality=0.2, use_fce=True):
    super().__init__()
    self.attn = nn.Softmax(dim=1)
    self.embedding = ConvEmbedding(embedding_size=embedding_size, dropout_probality=dropout_probality)
    self.distance = CosineDistance()
    self.use_fce = use_fce
    self.onehotconverter = ConvertOneHot()
    if self.use_fce:
      self.full_conditional_embedding_support = FullyConditionalEmbeddingSupportImages(embedding_size=embedding_size)
      self.full_conditional_embedding_target = FullyConditionalEmbeddingTargetImage(embedding_size=embedding_size)
    self.image_shape = image_shape

  def forward(self, support_images, support_labels, target_image):

    batch_size, num_images, _ = support_labels.size()

    # Get the image encodings from convolutional embedding
    target_image_encoded = self.embedding(target_image)
    support_images_encoded = self.embedding(support_images.view(-1, *self.image_shape)).view(-1, num_images, self.embedding.embedding_size)

    if self.use_fce:
      # Get the support images embedding with context
      support_images_encoded = self.full_conditional_embedding_support(support_images_encoded)

      # Get the target image embedding with context
      target_image_encoded = self.full_conditional_embedding_target(target_image_encoded, support_images_encoded)

    # Get the cosine distances between target image and the support images
    distances = self.distance(target_image_encoded, support_images_encoded)

    # Get the attention value based on the distances
    attention = self.attn(distances)

    # Convert the labels into one hot vectors
    support_set_one_hot_labels = self.onehotconverter(support_labels, num_images)

    # Get the prediction scores as attention @ one-hot labels (the batch matrix multiply sums over the support images)
    # squeeze(1) drops only the singleton axis, so a batch of size 1 keeps its batch dimension
    prediction_logits = torch.bmm(attention.unsqueeze(1), support_set_one_hot_labels).squeeze(1)

    # Get the final labels for predictions
    _, prediction_labels = torch.max(prediction_logits, 1)
    return prediction_logits, prediction_labels

#Testing the model

We run the model on random data for a few epochs to check that it is able to learn something and that the gradients propagate correctly, i.e. to catch silly mistakes such as dimension mismatches.

In [10]:
img_shape = (1, 28, 28)
matching_net_trial = MatchingNet(img_shape, dropout_probality=0.1, use_fce=False)
print("Model Summary")
print(matching_net_trial)
epochs = 10

support_images = torch.rand(32, 20, *img_shape)
target_image = torch.rand(32, *img_shape)
support_labels = torch.randint(0, 20, (32, 20, 1))  # random class indices in [0, 20)
target_labels = torch.randint(0, 20, (32,))

matching_net_trial.to(device)
support_images = support_images.to(device)
support_labels = support_labels.to(device)
target_image = target_image.to(device)
target_labels = target_labels.to(device)
optimizer = torch.optim.Adam(matching_net_trial.parameters(), lr=0.001)
for epoch in range(epochs):
  logits, predictions = matching_net_trial(support_images, support_labels,target_image)
  loss = F.cross_entropy(logits, target_labels)
  print(loss.item())
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

Model Summary
MatchingNet(
  (attn): Softmax(dim=1)
  (embedding): ConvEmbedding(
    (conv1): ConvLayerWithBatchNorm(
      (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (ReLU): ReLU()
      (batch_norm_layer): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv2): ConvLayerWithBatchNorm(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (ReLU): ReLU()
      (batch_norm_layer): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv3): ConvLayerWithBatchNorm(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (ReLU): ReLU()
   

#Gathering the Data

We fetch the data from the GitHub repository and unzip the train and evaluation zip files.

In [11]:
!git clone --depth 1 https://github.com/brendenlake/omniglot.git
!unzip -qq omniglot/python/images_background.zip
!unzip -qq omniglot/python/images_evaluation.zip

Cloning into 'omniglot'...
remote: Enumerating objects: 38, done.
remote: Counting objects: 100% (38/38), done.
remote: Compressing objects: 100% (33/33), done.
remote: Total 38 (delta 5), reused 30 (delta 5), pack-reused 0 (from 0)
Receiving objects: 100% (38/38), 85.74 MiB | 41.91 MiB/s, done.
Resolving deltas: 100% (5/5), done.


Helper functions to read and rotate images

In [12]:
from skimage import io, transform
import numpy as np

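def read_image(path, size, angle=0):
  # `as_gray` is the keyword scikit-image accepts; `angle` is unused here (rotation is handled by `rotate`)
  img = io.imread(path, as_gray=True)
  img = transform.resize(img, size, mode='constant')
#   img = 1 - np.expand_dims(img, 0)
  # Add a leading channel axis: (H, W) -> (1, H, W)
  img = np.expand_dims(img, 0)
  return img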

def rotate(img, angle):
  return np.expand_dims(transform.rotate(np.squeeze(img), angle), 0)

In [13]:
PATH = 'images_background'
EVALUATION_PATH = 'images_evaluation'

We recursively go over each Language and Alphabet and extract the images of each Letter present in them into a dictionary.<br>
We also extract the images here into numpy arrays for both the "Train" and "Evaluation" sets.

In [15]:
import os

def load_alphabet(path):
  lang_dict = []
  for alphabet in os.listdir(path):
    alphabet_path = os.path.join(path, alphabet)
    for letter in os.listdir(alphabet_path):
      current_letter_dict = {"alphabet": alphabet, "letter":letter, "images":[], "image_mat": []}
      letter_path = os.path.join(alphabet_path, letter)
      for letter_image in os.listdir(letter_path):
        letter_image_path = os.path.join(letter_path, letter_image)
        current_letter_dict["images"].append(letter_image_path)
        img = read_image(letter_image_path, (28, 28))
        current_letter_dict["image_mat"].append(img)
      lang_dict.append(current_letter_dict)
  return lang_dict

train_dict = load_alphabet(PATH)
test_dict = load_alphabet(EVALUATION_PATH)
print(f"found letters in Training set - {len(train_dict)}")
print(f"found letters in Testing set - {len(test_dict)}")


We also define an augmentation function that additionally generates four rotations (0°, 90°, 180° and 270°) of each letter and treats every rotation as a separate class, thus increasing the number of classes we possess by a factor of 4. There is still some issue with the model when it uses augmented data, so for now we will not be using it.

In [None]:
def augment_dataset(dataset):
  augmented_dataset = []
  for letter in dataset:
    for angle in range(4):
      current_letter_dict = {"alphabet": letter["alphabet"],
                             "letter":letter["letter"],
                             "images":letter["images"],
                             "image_mat": []}
      for image in letter["image_mat"]:
        rotated_image = rotate(image, angle*90)
        current_letter_dict["image_mat"].append(rotated_image)
      augmented_dataset.append(current_letter_dict)
  return augmented_dataset
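
Since the augmentation only uses multiples of 90°, the rotation can also be done exactly, without interpolation, using `np.rot90`. This is a sketch of an alternative, not what the notebook runs; `rotate90` is a hypothetical helper, and it assumes images shaped `(1, H, W)` as produced by `read_image`.

```python
import numpy as np

def rotate90(img, quarter_turns):
    # img has shape (1, H, W); rotate the H x W plane counter-clockwise
    # by quarter_turns * 90 degrees, leaving the channel axis in place
    return np.rot90(img, k=quarter_turns, axes=(1, 2)).copy()

img = np.arange(16, dtype=float).reshape(1, 4, 4)
rotations = [rotate90(img, k) for k in range(4)]
```

Each entry of `rotations` keeps the `(1, H, W)` shape for square images, so it could be appended to `"image_mat"` in the same way as the output of `rotate`.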

##Splitting data into Train/Dev/Test Sets
Using sklearn's `train_test_split` function we create 3 sets of data.


1.   **Training Set** - Used to train our model
2.   **Dev Set** - Used to evaluate our model during the training process in order to perform hyperparameter tuning
3.   **Testing Set** - Used to finally evaluate our model

Before generating the above three sets, we first combine the provided Train and Test sets since they seem to be coming from different distributions and the algorithm was performing poorly when using the provided Train/Test Split.

We set aside 25% of the total data for the testing set.

We then split the remaining 75% in a 9:1 ratio, which gives 67.5% of the total data for the training set and 7.5% of the total data for the dev set.
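
The resulting set sizes can be worked out by hand. The sketch below assumes that all 964 background and 659 evaluation characters of Omniglot were loaded, and that `train_test_split` rounds the test share up and gives the remainder to the training share; `split_sizes` is a hypothetical helper.

```python
import math

def split_sizes(n_samples, test_fraction):
    # Round the test share up; everything else goes to the train share
    n_test = math.ceil(test_fraction * n_samples)
    return n_samples - n_test, n_test

total = 964 + 659                          # background + evaluation characters (assumed complete)
n_rest, n_test = split_sizes(total, 0.25)  # first split: hold out the test set
n_train, n_dev = split_sizes(n_rest, 0.1)  # second split: carve the dev set out of the rest
print(total, n_train, n_dev, n_test)
```

Note that the units being split are letters (classes), not individual images, so the three sets contain disjoint characters.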



In [None]:
from sklearn.model_selection import train_test_split
random_state = 50
total_data = train_dict + test_dict
training_dict, testing_dict = train_test_split(total_data, test_size=0.25, random_state=random_state)
training_dict, dev_dict = train_test_split(training_dict, test_size=0.1, random_state=random_state)

print(f"Total Dataset size - {len(total_data)}")
print(f"Training Dataset size - {len(training_dict)}")
print(f"Dev Dataset size - {len(dev_dict)}")
print(f"Testing Dataset size - {len(testing_dict)}")

In [None]:
# augmented_train_dict = augment_dataset(training_dict)
# augmented_dev_dict = augment_dataset(dev_dict)
augmented_train_dict = training_dict
augmented_dev_dict = dev_dict

##Creating the tasks for our Model

We define a function that takes one of the datasets and creates an ***N-Way*** one-shot task: we choose ***N*** different letters (i.e. unique characters from our dataset) and take one random image from each of them to form the support set. One extra image, the target, is then chosen from one of the ***N*** chosen letters (different from the support image already chosen for that letter). Together these constitute our one-shot task.<br>
The model succeeds on a task if it returns the label of the support image that belongs to the same letter as the target image.

In [None]:
img_size = (28, 28)

def make_oneshot_task(dataset, nway=20):

  # Choose nway random letters from dataset
  letter_choices = np.random.choice(len(dataset), nway, replace=False)

  # Placeholders for our support dataset
  X = np.empty((nway, 1, *img_size))
  y = np.empty((nway), dtype=int)

  # Choose random letter from support set to be the target letter
  random_target = np.random.choice(nway, size=1)
  batch_letters = [dataset[k] for k in letter_choices]
  required_letter = None
  required_class = -1
  for i, letter in enumerate(batch_letters):
      # Choose random image from each letter
      letter_index = np.random.randint(len(letter["images"]))
      X[i,] = letter["image_mat"][letter_index]
      y[i] = i
      if(i == random_target[0]):
        # Fetching a different image for our target image from chosen letter
        required_index = (letter_index + np.random.randint(1, len(letter["images"]))) % len(letter["images"])
        required_letter = letter["image_mat"][required_index]
        required_class = i
  return X, y, required_letter, required_class
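
The target image is picked with a small modular-arithmetic trick: adding an offset between 1 and `n - 1` and wrapping around can never land back on the support image's index. A minimal sketch of just that step follows; `pick_different_index` is a hypothetical helper, and 20 images per letter is the Omniglot default.

```python
import numpy as np

def pick_different_index(first_index, n_images, rng):
    # An offset in [1, n_images - 1] followed by a wrap-around
    # always yields an index different from first_index
    offset = rng.integers(1, n_images)   # upper bound is exclusive
    return (first_index + offset) % n_images

rng = np.random.default_rng(0)
n_images = 20   # Omniglot provides 20 drawings per character
pairs = [(i, pick_different_index(i, n_images, rng)) for i in range(n_images)]
```

This relies on every letter having at least two images; with a single image the offset range would be empty.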

## Visualizing one-shot tasks

Helper Visualization function to ensure our datasets and task generation functions are valid.

In [None]:
import matplotlib.pyplot as plt

def visualize_oneshot_task(X, y, required_letter, required_class):
  fig=plt.figure(figsize=(8, 8))
  columns = 5
  rows = X.shape[0]//columns + 1
  for i in range(1, X.shape[0] + 1):
    img = X[i-1]
    fig.add_subplot(rows, columns, i).set_title(y[i-1])
    plt.axis('off')
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img.reshape(*img_size))
  fig.add_subplot(rows, columns, X.shape[0] + 3).set_title(required_class)
  plt.imshow(required_letter.reshape(*img_size))
  plt.axis('off')
  plt.xticks([])
  plt.yticks([])
  plt.show()

In [None]:
# Visualize our training dataset
visualize_oneshot_task(*make_oneshot_task(augmented_train_dict))

In [None]:
# Visualize our validation/dev dataset
visualize_oneshot_task(*make_oneshot_task(augmented_dev_dict))

In [None]:
# Visualize our test dataset
visualize_oneshot_task(*make_oneshot_task(testing_dict))

Create a PyTorch `Dataset` class that allows us to leverage multiprocessing while loading our data between batches.

In [None]:
from torch.utils.data.dataset import Dataset
from torchvision import transforms
import torch

class OneShotDataset(Dataset):

  def __init__(self, data_dict, iterations=1000):
    self.data_dict = data_dict
    self.length = iterations

  def __getitem__(self, index):
    support_images, support_labels, target_image, target_label = make_oneshot_task(self.data_dict)
    return (torch.from_numpy(support_images).float(),
            torch.from_numpy(support_labels).long().unsqueeze(-1),
            torch.from_numpy(target_image).float(),
            target_label)
  def __len__(self):
    return self.length

Helper function that allows us to evaluate our model

In [None]:
def evaluate_model(model, dataset_dict, iterations=1000, batch_size=20):
  model.eval() # Set the model in eval mode so that Dropout and BatchNorm layers operate in evaluation mode

  # Set up our dataloaders
  dataset = OneShotDataset(dataset_dict, iterations=iterations)
  dataset_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                  batch_size=batch_size,
                                                  num_workers=4,
                                                  shuffle=False)

  # Run in no_grad mode to speed up inference
  with torch.no_grad():
    # Placeholders to accumulate the metrics across batches
    total_accuracy = 0
    total_loss = 0
    for iteration, (support_images, support_labels, images, labels) in enumerate(dataset_loader):
      # Move Data to GPU
      support_batch_torch = support_images.to(device)
      support_labels_torch = support_labels.to(device)
      target_images_torch = images.to(device)
      target_labels_torch =labels.to(device)

      # Perform inference via our model
      logits, predictions = model(support_batch_torch, support_labels_torch, target_images_torch)
      # Calculate loss and accuracy
      loss = F.cross_entropy(logits, target_labels_torch)
      accuracy = torch.mean((predictions == target_labels_torch).float())

      # Accumulate the metrics
      total_loss = total_loss + loss.item()
      total_accuracy = total_accuracy + accuracy.item()

    accuracy = total_accuracy/(iteration+1)
    loss = total_loss/(iteration+1)
    return accuracy, loss
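
`evaluate_model` averages the per-batch means. That matches the per-sample average only when all batches have the same size, which holds for the defaults (1000 tasks in batches of 20). If `iterations` is not a multiple of `batch_size`, a sample-weighted average is the safer choice. The sketch below uses made-up numbers, and `weighted_mean` is a hypothetical helper.

```python
def weighted_mean(batch_means, batch_sizes):
    # Weight each batch mean by the number of samples in that batch
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

batch_means = [0.80, 0.70, 1.00]   # illustrative per-batch accuracies
batch_sizes = [20, 20, 5]          # the last batch is smaller

simple = sum(batch_means) / len(batch_means)          # what a mean of batch means gives
weighted = weighted_mean(batch_means, batch_sizes)    # per-sample accuracy
```

Here the small last batch pulls the simple mean above the true per-sample accuracy.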

##Baseline KNN Accuracy

In [None]:
from sklearn.neighbors import KNeighborsClassifier

def evaluate_one_shot_performance(dataset, nway=20, num_tests=300):
  correct = 0
  for i in range(num_tests):
    support_set, support_labels, required_letter, required_label = make_oneshot_task(dataset, nway=nway)
    images = np.insert(support_set, 0, required_letter, axis=0)
    # Baseline model with just l2 Distance between images
    embeddings = images.reshape(nway+1, -1)
    neigh = KNeighborsClassifier(n_neighbors=1)
    neigh.fit(embeddings[1:], support_labels)
    prediction = neigh.predict(embeddings[:1])
    if(prediction[0] == required_label):
      correct = correct + 1
  print(f"Accuracy - {correct * 100 / num_tests}%")

In [None]:
evaluate_one_shot_performance(testing_dict, nway=20, num_tests=500)
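
With `n_neighbors=1`, the KNN baseline reduces to picking the support image with the smallest L2 distance to the target in raw pixel space. The NumPy sketch below shows that equivalent computation on synthetic data; `nearest_neighbour_label` is a hypothetical helper and the images are random noise, not Omniglot characters.

```python
import numpy as np

def nearest_neighbour_label(support_set, support_labels, target_image):
    # Flatten each image and pick the support image closest to the target in L2 distance
    support_flat = support_set.reshape(len(support_set), -1)
    target_flat = target_image.reshape(1, -1)
    distances = np.linalg.norm(support_flat - target_flat, axis=1)
    return support_labels[np.argmin(distances)]

rng = np.random.default_rng(0)
support = rng.random((20, 1, 28, 28))                    # 20-way support set
labels = np.arange(20)
target = support[7] + 0.01 * rng.random((1, 28, 28))     # slightly perturbed copy of image 7
prediction = nearest_neighbour_label(support, labels, target)
```

For reference, random guessing in a 20-way task is correct 5% of the time, so the baseline should be read against that floor.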

#Training our model

##Hyperparameters

In [None]:
epochs = 60 #@param {type: "slider", min: 1, max: 100}
batch_size = 20 #@param {type: "slider", min: 20, max: 200}
iterations = 10000 #@param {type: "slider", min: 500, max: 10000}
lr = 0.0001 #@param ["0.1", "0.01", "0.001", "0.0001"] {type:"raw", allow-input: true}
embedding_size = 64 #@param {type: "slider", min: 64, max: 512}
dropout_probality=0.1 #@param {type: "slider", min: 0, max: 1, step: 0.1}
use_fce = True #@param ["True", "False"] {type:"raw"}

##Setting up logging for our training

We now set up an experiment with the comet.ml library. This lets us keep track of multiple runs of our code and the hyperparameters used in each run.
It also generates the loss and accuracy graphs over time.
We are mostly interested in logging the following:

1.   Train loss
2.   Train accuracy
3.   Dev loss
4.   Dev accuracy
5.   Model performance (test accuracy)
6.   Model loss (test loss)


In [None]:
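import os

# Read the API key from the COMET_API_KEY environment variable instead of hard-coding it in the notebook
pytorch_experiment = Experiment(api_key=os.environ["COMET_API_KEY"],
                        project_name="PyTorch Matching Net", workspace="ramesharvind")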

In [None]:
# Record the hyperparameters of this run
pytorch_experiment.log_parameters({"Epochs": epochs,
                                   "Batch Size": batch_size,
                                   "Iterations Per Epoch": iterations,
                                   "Learning Rate": lr,
                                   "Embedding Size": embedding_size,
                                   "Dropout": dropout_probality,
                                   "Using FCE": use_fce})

##Training Loop

In [None]:
# Define the model
matching_net = MatchingNet(img_shape, embedding_size=embedding_size, dropout_probality=dropout_probality, use_fce=use_fce)
print("Model Summary")
print(matching_net)

In [None]:
# Move the model to GPU
matching_net.to(device)
# Create the optimizer
optimizer = torch.optim.Adam(matching_net.parameters(), lr=lr)
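# Placeholders to store our best performing metrics on the Dev Set
# (start from infinity so that the first evaluated dev loss always counts as an improvement)
best_dev_loss = float("inf")
best_dev_accuracy = 0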

print("Beginning Training..")

for epoch in range(epochs):
  matching_net.train() # Set our model in train mode

  # Placeholder variables to help track epoch loss and accuracy
  total_loss = 0
  total_accuracy = 0

  # Define our dataset and dataloader
  train_dataset = OneShotDataset(augmented_train_dict, iterations)
  train_dataset_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                  batch_size=batch_size,
                                                  num_workers=4,
                                                  shuffle=False)
  # Main training loop
  for iteration, (support_images, support_labels, images, labels) in enumerate(train_dataset_loader):
    # Move our data to GPU
    support_batch_torch = support_images.to(device)
    support_labels_torch = support_labels.to(device)
    target_images_torch = images.to(device)
    target_labels_torch =labels.to(device)
    # Get the predictions and logits from our model
    logits, predictions = matching_net(support_batch_torch, support_labels_torch, target_images_torch)

    # Calculate loss and accuracy in current iteration
    loss = F.cross_entropy(logits, target_labels_torch)
    accuracy = torch.mean((predictions == target_labels_torch).float())

    # Accumulate the values
    total_loss = total_loss + loss.item()
    total_accuracy = total_accuracy + accuracy.item()

    # Perform Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  total_accuracy = total_accuracy/(iteration+1)
  total_loss = total_loss/(iteration+1)

  # Log the training metrics to comet.ml
  with pytorch_experiment.train():
    pytorch_experiment.log_current_epoch(epoch)
    pytorch_experiment.log_metric("loss",total_loss)
    pytorch_experiment.log_metric("accuracy",total_accuracy)
  print(f"In epoch - {epoch} Train Set - Accuracy {total_accuracy} Loss - {total_loss} - for {(iteration+1)} iterations")


  # Run the model on Dev Set to evaluate performance on unseen data (every 3 epochs)
  if epoch%3 == 0:
    with pytorch_experiment.validate():
      pytorch_experiment.log_current_epoch(epoch)
      dev_accuracy, dev_loss = evaluate_model(matching_net, augmented_dev_dict)
      print(f"--In epoch - {epoch} Dev Set Accuracy - {dev_accuracy} Loss - {dev_loss}")

      # Save the best performing model across all the epochs
      if not best_dev_loss or dev_loss < best_dev_loss:
        print(f"---Found Better Model to save with Accuracy - {dev_accuracy} and loss - {dev_loss}")
        best_dev_loss = dev_loss
        best_dev_accuracy = dev_accuracy
        torch.save(matching_net.state_dict(), "matching_net-%0.2f-accuracy.pt"%(best_dev_accuracy))

      # Log the Dev metrics to comet.ml
      pytorch_experiment.log_metric("accuracy",dev_accuracy)
      pytorch_experiment.log_metric("dev_loss",dev_loss)
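
The checkpointing rule in the loop above can be summarised in isolation: the dev set is evaluated whenever `epoch % 3 == 0`, and the weights are saved whenever the dev loss improves on the best value seen so far. The sketch below uses made-up loss values; `evaluated_epochs` and `best_checkpoint` are hypothetical helpers, and the infinity sentinel is one way of expressing "nothing saved yet".

```python
def evaluated_epochs(num_epochs, every=3):
    # Epochs at which the dev set is evaluated (epoch % every == 0)
    return [epoch for epoch in range(num_epochs) if epoch % every == 0]

def best_checkpoint(dev_losses_by_epoch):
    # Keep the epoch with the lowest dev loss; infinity means "nothing saved yet"
    best_epoch, best_loss = None, float("inf")
    for epoch, loss in dev_losses_by_epoch:
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
    return best_epoch, best_loss

epochs_checked = evaluated_epochs(60)          # 0, 3, ..., 57
made_up_losses = [2.9, 2.6, 2.4, 2.5, 2.45]    # illustrative values only
best = best_checkpoint(zip(epochs_checked, made_up_losses))
```

With 60 epochs this gives 20 dev evaluations, and in the made-up sequence the checkpoint from epoch 6 is the one that would be kept.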

##Evaluating our model with the Test Data

In [None]:
with pytorch_experiment.test():
  matching_net_test = MatchingNet(img_shape, embedding_size=embedding_size, dropout_probality=dropout_probality, use_fce=use_fce)
  matching_net_test.load_state_dict(torch.load("matching_net-%0.2f-accuracy.pt"%(best_dev_accuracy)))
  matching_net_test.to(device)
  test_accuracy, test_loss = evaluate_model(matching_net_test, testing_dict, iterations=5000)
  pytorch_experiment.log_metric("accuracy",test_accuracy)
  pytorch_experiment.log_metric("test_loss",test_loss)
  print(f"Test Set Accuracy - {test_accuracy} Loss - {test_loss}")
pytorch_experiment.end()

# Pre-trained Models

## Downloading our weights

In [None]:
!git clone https://github.com/RameshArvind/Pytorch-Matching-Networks.git

In [None]:
model_weights_with_fce = "Pytorch-Matching-Networks/matching_net-FCE-0.73-accuracy.pt"
model_weights_without_fce = "Pytorch-Matching-Networks/matching_net-WITHOUT-FCE-0.74-accuracy.pt"

## Model Without FCE

In [None]:
embedding_size = 64
dropout_probality=0.1
use_fce = False

matching_net_test = MatchingNet(img_shape, embedding_size=embedding_size, dropout_probality=dropout_probality, use_fce=use_fce)
matching_net_test.load_state_dict(torch.load(model_weights_without_fce))
matching_net_test.to(device)
test_accuracy, test_loss = evaluate_model(matching_net_test, testing_dict, iterations=5000)
pytorch_experiment.log_metric("accuracy",test_accuracy)
pytorch_experiment.log_metric("test_loss",test_loss)
print(f"Without FCE - Test Set Accuracy - {test_accuracy} Loss - {test_loss}")

## Model With FCE

In [None]:
embedding_size = 64
dropout_probality=0.1
use_fce = True

matching_net_test = MatchingNet(img_shape, embedding_size=embedding_size, dropout_probality=dropout_probality, use_fce=use_fce)
matching_net_test.load_state_dict(torch.load(model_weights_with_fce))
matching_net_test.to(device)
test_accuracy, test_loss = evaluate_model(matching_net_test, testing_dict, iterations=5000)
pytorch_experiment.log_metric("accuracy",test_accuracy)
pytorch_experiment.log_metric("test_loss",test_loss)
print(f"With FCE - Test Set Accuracy - {test_accuracy} Loss - {test_loss}")

#Conclusion
We are able to train a model that solves the one-shot 20-way classification problem to a reasonable degree, reaching about **73% accuracy with Fully Conditional Embedding** (and about **77% accuracy without it**) on the test set.
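
To judge how much of the gap between the two variants could be sampling noise, one can attach a rough error bar to each accuracy. The sketch below treats each of the 5000 test tasks as an independent Bernoulli trial, which is a simplifying assumption; `accuracy_standard_error` is a hypothetical helper.

```python
import math

def accuracy_standard_error(accuracy, n_tasks):
    # Binomial standard error of a proportion estimated from n_tasks trials
    return math.sqrt(accuracy * (1.0 - accuracy) / n_tasks)

n_tasks = 5000   # matches iterations=5000 used for the test evaluation
for name, acc in [("with FCE", 0.73), ("without FCE", 0.77)]:
    se = accuracy_standard_error(acc, n_tasks)
    # 1.96 standard errors give an approximate 95% interval
    print(f"{name}: {acc:.2f} +/- {1.96 * se:.3f}")
```

Both intervals come out at roughly ±1.2 percentage points, so under this assumption a four-point difference is larger than what task sampling alone would explain.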

Some approaches we could try in order to reach the reported 93.8% result are:


1.   Use image augmentation
2.   Automated Hyperparameter optimization
3.   Use a correct implementation of the Fully Conditional Embedding for the target image.




#TODO Items
###<s>Make dataloader - To speed up training</s> - DONE - Loading all the numpy images at start speeds up things by a TON
###<s>Figure out what optimizer was used - SGD/ADAM? </s>- ADAM WORKS
### <s>Use image augmentation </s> - Done
###<s>There is still something wrong with the network - Probably something with the way augmentation is done?</s> There isn't a problem with augmentation, rather the training and testing set provided by default have different "distributions". Need to investigate this further. Fixed by mixing the provided test and train set and creating new Train/Dev/Test sets.
###<s>Setup a way to log training of models</s> - Done via integration with comet.ml
###Setup Hyperparameter Optimization
###Use a high level pytorch API to setup training loops and Checkpointing - PyTorch Ignite
###Using a dense layer at the end during embedding was causing the network to converge faster and better - Figure out why
###<s>Work on the LSTM based embedding</s> - Done, albeit incorrectly
### Implement Custom LSTM Cell which has hidden weights which can accept concatenated input
