<a href="https://colab.research.google.com/github/MLenthusiastic/Solution/blob/master/Davis_Contrastive_Working.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
from torch.utils.data import Dataset
import json
import numpy as np
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
from google.colab.patches import cv2_imshow
from IPython.display import Image, display
import torch
import math
import torch.nn.functional as F
from argparse import ArgumentParser
import copy
import torch.nn.functional as F

def _build_parser():
    """Build the command-line parser for the Siamese contrastive-training run."""
    cli = ArgumentParser(description='Solution Network')
    # (flag, value type, default) in the original declaration order.
    options = (
        ('--batch_size', int, 8),
        # Flag spelling (sic) kept so existing invocations keep working.
        ('--constractive_loss_margin', float, 0.8),
        ('--learning_rate', float, 1e-3),
        ('--num_epochs', int, 10),
        ('--weight_decay', float, 1e-5),
        ('--mode', str, 'train'),
        ('--device', str, 'cuda'),
        ('--img_size', int, 200),
    )
    for flag, value_type, default in options:
        cli.add_argument(flag, type=value_type, default=default)
    return cli


parser = _build_parser()

# parse_known_args collects unrecognised argv entries (e.g. ones a notebook
# kernel may be launched with) in `unknown` instead of exiting on them.
args, unknown = parser.parse_known_args()

# Use the requested device only when CUDA is actually available; otherwise CPU.
DEVICE = args.device if torch.cuda.is_available() else 'cpu'

class SimaseDavis(Dataset):
    """Siamese-pair dataset over object crops stored in a memmap.

    Each item pairs object ``index`` with a randomly drawn second object.
    With probability 1/2 the partner comes from the same class
    (``target == 0``); otherwise it comes from a different class
    (``target == 1``). Both crops are resized to ``img_size`` x ``img_size``.

    Args:
        json_data: metadata dict with these keys:
            ``"shape"``: its first entry is the dataset length.
            ``"objects"``: ``str(object index)`` -> dict with ``"width"``,
            ``"height"`` and ``"class"``.
            ``"classes"``: class label -> list of dicts with
            ``"object_idx"``, ``"width"`` and ``"height"``.
        memmap: array sliced as ``memmap[object_idx, :, :width, :height]``.
            Presumably laid out as (N, C, W, H); confirm against the writer.
        img_size: side length of the resized crops. ``None`` (the default)
            falls back to the ``--img_size`` command-line value, as before.
    """

    def __init__(self, json_data, memmap, img_size=None):
        self.memmap = memmap
        self.json_data = json_data
        self.shape = self.json_data["shape"]
        self.objects = self.json_data["objects"]
        self.classes = self.json_data["classes"]
        # Only read the global CLI args when no explicit size is given, so the
        # dataset can be built without the module-level `args`.
        self.img_size = args.img_size if img_size is None else img_size

    def _load_image(self, object_idx, width, height):
        """Crop object ``object_idx`` to width x height and resize it.

        Returns a float32 tensor of shape (1, C, img_size, img_size). The
        leading singleton dim is kept because the training loop squeezes it
        after batching.
        """
        crop = self.memmap[object_idx, :, :width, :height].astype(np.float32)
        batched = torch.from_numpy(crop).unsqueeze(dim=0)
        return F.interpolate(batched, size=(self.img_size, self.img_size))

    def __getitem__(self, index):
        """Return ``(image_1, image_2, target, class_1, class_2)``.

        ``target`` is 0 for a same-class pair and 1 for a different-class pair.

        Raises:
            ValueError: a different-class pair is requested but the dataset
                contains only one class.
        """
        img_object_1 = self.objects[str(index)]
        class_1 = img_object_1['class']
        image_1_tensor = self._load_image(
            index, img_object_1['width'], img_object_1['height'])

        target = np.random.randint(0, 2)

        if target == 0:  # similar classes
            class_2 = class_1
            # Uniform draw over the *other* members of the class. Building the
            # candidate list directly replaces the per-item deepcopy and the
            # draw-then-remove retry, which crashed on single-member classes.
            candidates = [obj for obj in self.classes[class_1]
                          if obj["object_idx"] != index]
            if candidates:
                partner = candidates[np.random.randint(len(candidates))]
                partner_idx = partner["object_idx"]
            else:
                # Single-member class: the only same-class partner is itself.
                partner_idx = index
            img_object_2 = self.objects[str(partner_idx)]
            image_2_tensor = self._load_image(
                partner_idx, img_object_2['width'], img_object_2['height'])
        else:  # different classes
            other_labels = [label for label in self.classes if label != class_1]
            if not other_labels:
                raise ValueError(
                    "cannot build a dissimilar pair: dataset has only one class")
            class_2 = other_labels[np.random.randint(len(other_labels))]
            members = self.classes[class_2]
            img_object_2 = members[np.random.randint(len(members))]
            image_2_tensor = self._load_image(
                img_object_2["object_idx"],
                img_object_2["width"],
                img_object_2["height"])

        return image_1_tensor, image_2_tensor, target, class_1, class_2

    def __len__(self):
        """Number of objects, taken from the first entry of ``shape``."""
        return self.shape[0]

class Encoder(nn.Module):
    """Siamese CNN encoder mapping square RGB images to 256-d embeddings.

    Both inputs run through the same conv + fc stack (shared weights) in a
    single batched forward pass.

    Args:
        img_size: side length of the square input images. It should match
            the size the data is resized to (``--img_size``). The default of
            200 reproduces the original hard-coded 128 * 21 * 21 fc input.
    """

    def __init__(self, img_size=200):
        super(Encoder, self).__init__()

        # conv and fc work together as the encoder
        self.conv = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2, stride=2),
                                  nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2, stride=2),
                                  nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2, stride=2)
                                  )

        # e.g. img_size=200 -> conv output (128, 21, 21)
        feat_size = self._conv_output_size(img_size)
        self.fc = nn.Sequential(nn.Linear(128 * feat_size * feat_size, 1024),
                                nn.ReLU(),
                                nn.Linear(1024, 1024),
                                nn.ReLU(),
                                nn.Linear(1024, 256)
                                )

    @staticmethod
    def _conv_output_size(img_size):
        """Return the spatial side length after the three conv/pool stages.

        Raises:
            ValueError: ``img_size`` is too small for the conv stack.
        """
        size = img_size
        for _ in range(3):
            # An unpadded 5x5 conv trims 4 pixels; a 2x2/stride-2 max-pool
            # then halves the size, rounding down.
            size = (size - 4) // 2
        if size < 1:
            raise ValueError(f"img_size={img_size} is too small for Encoder")
        return size

    def forward(self, in1, in2):
        """Embed both input batches and return ``(z_out1, z_out2)``.

        Assumes in1/in2 are (B, 3, img_size, img_size) tensors; the batch
        sizes may differ.
        """
        x = torch.cat((in1, in2), dim=0)
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # Split by the actual input sizes rather than halving the batch, which
        # was wrong whenever len(in1) != len(in2).
        z_out1, z_out2 = torch.split(x, [in1.size(0), in2.size(0)], dim=0)
        return z_out1, z_out2


class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over a pair of embedding batches.

    For Euclidean distance ``d`` between paired embeddings:
      * target 0 (similar pair) contributes ``0.5 * d**2``;
      * target 1 (dissimilar pair) contributes ``0.5 * max(margin - d, 0)**2``.
    The result is averaged over the batch.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, output, target):
        """Compute the mean loss.

        Args:
            output: indexable whose items 0 and 1 are the two embedding batches.
            target: per-pair labels, 0 for similar and 1 for dissimilar.
        """
        first, second = output[0], output[1]
        dist = F.pairwise_distance(first, second)
        pull = (1 - target) * dist.pow(2)
        push = target * (self.margin - dist).clamp(min=0.00).pow(2)
        return (0.5 * pull + 0.5 * push).mean()
        
# --- Data loading ------------------------------------------------------------
# Paths are relative to ./drive/My Drive/ (presumably the Google Drive mount
# set up earlier in the notebook; confirm before running elsewhere).
folder_path = './drive/My Drive/'
# Dataset metadata. "shape" is read here; SimaseDavis also reads "objects" and
# "classes" from the same dict.
with open(folder_path+'sdavis_json.txt') as json_file:
    davis_json = json.load(json_file)

# Image data: a read-only uint8 memmap whose shape comes from the JSON metadata.
shape = davis_json["shape"]
memmap_path = folder_path+'s_davis.mmap'
davis_memmap = np.memmap(memmap_path, dtype='uint8', mode='r', shape=tuple(shape))

davis_dataset = SimaseDavis(davis_json,davis_memmap)
# Shuffled mini-batches of (image_1, image_2, target, class_1, class_2).
davis_dataloader = torch.utils.data.DataLoader(davis_dataset, batch_size =args.batch_size, shuffle=True)

# Disabled debugging snippet. The triple-quoted string below is an unassigned
# literal, so it never runs. When pasted into a cell it shows the first four
# pairs of one batch with cv2_imshow, after dropping the singleton dim and
# transposing each image with (2, 1, 0).
'''

#showing pairs
itf = next(iter(davis_dataloader))
img1, img2, target, class_1, class_2 = itf
img1 = img1.squeeze(dim=1)
img2 = img2.squeeze(dim=1)

for k in range(4):
  i = img1[k].numpy().transpose(2,1,0)
  cv2_imshow(i)
  j = img2[k].numpy().transpose(2,1,0)
  cv2_imshow(j)

'''
# --- Model, loss and optimiser ----------------------------------------------
# NOTE(review): Encoder() sizes its first fc layer for a 128 x 21 x 21 conv
# output, which is what 200x200 inputs (the --img_size default) produce. Other
# --img_size values will presumably fail with a shape mismatch in encoder.fc.
encoder = Encoder()
encoder = encoder.to(DEVICE)

# How the loss treats each target value:
#   target 0 (same-class pair): penalises the squared embedding distance.
#   target 1 (different-class pair): penalises distances below the margin.
criterion = ContrastiveLoss(margin=args.constractive_loss_margin)
optimizer = torch.optim.Adam(params=encoder.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

print('training started')

encoder.train()
# Autograd is enabled by default. This call re-enables it explicitly, presumably
# in case an earlier notebook cell turned it off.
torch.set_grad_enabled(True)

# --- Training loop (epochs numbered from 1) ---------------------------------
for epoch in range(1,args.num_epochs+1):

  batch_losses = []
  
  for batch in davis_dataloader:
      img_objects_1, img_objects_2, target , class_1, class_2 = batch
      
      # Each dataset item is a (1, C, S, S) tensor, so collated batches carry
      # an extra singleton dim at position 1; drop it to get (B, C, S, S).
      # NOTE(review): pixels reach the network as unscaled float32 copies of
      # the uint8 memmap values. Normalising them may help (cf. the large
      # first-epoch loss); confirm before changing.
      img_objects_1_t = img_objects_1.squeeze(dim=1)
      img_objects_2_t = img_objects_2.squeeze(dim=1)
      img_obj_1 = img_objects_1_t.to(DEVICE)
      img_obj_2 = img_objects_2_t.to(DEVICE)
   
      # Both batches go through the shared-weight encoder in one forward pass.
      z_out1, z_out2 = encoder(img_obj_1, img_obj_2)

      # ContrastiveLoss reads the two embedding batches as output[0] / output[1].
      z_out = [z_out1, z_out2]

      target = target.to(DEVICE)

      loss = criterion(z_out, target)

      # Clear stale gradients, backpropagate, then update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      
      batch_losses.append(loss.item())
  
  # Reported value is the unweighted mean of this epoch's per-batch losses.
  print('Epoch : ', epoch, 'Loss : ', np.mean(batch_losses))

training started
Epoch :  1 Loss :  48.87554228305817
Epoch :  2 Loss :  0.08663841262459755
Epoch :  3 Loss :  0.372602079808712
Epoch :  4 Loss :  0.08054494934853892
Epoch :  5 Loss :  0.1343840226787029
Epoch :  6 Loss :  0.16552110090851785
Epoch :  7 Loss :  0.13352459820684773
Epoch :  8 Loss :  0.1900014728307724
Epoch :  9 Loss :  0.10929850118932109
Epoch :  10 Loss :  0.14953649642643313
