<a href="https://colab.research.google.com/github/Amaljayaranga/Solution/blob/master/Davis_Simple_Contrastive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from torch.utils.data import Dataset
import json
import numpy as np
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
from google.colab.patches import cv2_imshow
from IPython.display import Image, display
import torch
import math
import torch.nn.functional as F
from argparse import ArgumentParser
import copy

parser = ArgumentParser(description='Solution Network')
parser.add_argument('--batch_size', type=int, default=32)
# The original flag name is misspelled; it is kept (together with the
# ``args.constractive_loss_margin`` attribute) for existing callers, and the
# correct spelling is accepted as an alias writing to the same destination.
parser.add_argument('--constractive_loss_margin', '--contrastive_loss_margin',
                    dest='constractive_loss_margin', type=float, default=0.8)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--weight_decay', type=float, default=1e-5)
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--img_size', type=int, default=200)

# parse_known_args (rather than parse_args) ignores argv entries this parser
# does not define instead of exiting -- presumably because the script runs in
# a notebook kernel whose command line carries its own arguments.
args, unknown = parser.parse_known_args()

# Use the requested device only when CUDA is available; otherwise fall back
# to the CPU, and say so when something other than the CPU was requested.
DEVICE = args.device
if not torch.cuda.is_available():
    if DEVICE != 'cpu':
        print(f"CUDA is not available; using 'cpu' instead of {DEVICE!r}")
    DEVICE = 'cpu'


class SimaseDavis(Dataset):
    """Siamese pair dataset over annotated object crops.

    Every item is a freshly sampled pair of object crops plus a binary label:
    ``0`` when both crops come from the same class, ``1`` when they come from
    two different classes (each case drawn with probability 1/2). The
    ``index`` passed to ``__getitem__`` is ignored; sampling uses numpy's
    global RNG.

    Args:
        json_data: Parsed annotation JSON. ``json_data["classes"]`` maps a
            class name to a list of object records (dicts with ``index``,
            ``x``, ``y``, ``width``, ``height``); the first entry of
            ``json_data["shape"]`` is used as the dataset length.
        memmap: Frame memmap. Stored on the instance but not read here --
            crops are loaded from JPEG files instead.
        need_classes: Class names to sample from. A repeated name raises that
            class's sampling weight but never yields a same-class negative.
        image_dir: Directory containing the ``<index>.jpg`` frames.
        img_size: Side length crops are resized to; ``None`` means use the
            module-level ``args.img_size``.
    """

    def __init__(self, json_data, memmap, need_classes,
                 image_dir='./drive/My Drive/Thesis_2020/dataset/',
                 img_size=None):
        self.memmap = memmap
        self.json_data = json_data
        self.need_classes = need_classes
        self.having_classes = self.json_data["classes"]
        self.shape = self.json_data["shape"]
        self.image_dir = image_dir
        self.img_size = img_size

    def image_to_tensor(self, image, mean=0, std=1.):
        """Convert an ``(H, W, C)`` array into a float32 ``(C, H, W)`` tensor.

        Values are normalised as ``(image - mean) / std``; with the defaults
        they keep their original range (0-255 for uint8 input).
        """
        image = image.astype(np.float32)
        image = (image - mean) / std
        return torch.from_numpy(image.transpose((2, 0, 1)))

    def crop_objects(self, img_list):
        """Load, crop and resize every object record into a tensor.

        Args:
            img_list: Iterable of object records (dicts with ``index``,
                ``x``, ``y``, ``width``, ``height``).

        Returns:
            A list with one float32 ``(3, size, size)`` tensor per record, in
            input order (channel order is OpenCV's default BGR).

        Raises:
            FileNotFoundError: if OpenCV cannot read a frame.
            ValueError: if a bounding box produces an empty crop.
        """
        size = self.img_size if self.img_size is not None else args.img_size
        objects = []
        for obj in img_list:
            # int() tolerates coordinates stored as floats in the JSON; float
            # values would otherwise be rejected as slice indices.
            x, y = int(obj['x']), int(obj['y'])
            width, height = int(obj['width']), int(obj['height'])
            img_path = self.image_dir + str(obj['index']) + '.jpg'
            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread reports a missing/unreadable file by returning
                # None instead of raising.
                raise FileNotFoundError(f"could not read image: {img_path}")
            crop_img = image[y:y + height, x:x + width]
            if crop_img.size == 0:
                raise ValueError(
                    f"empty crop (x={x}, y={y}, w={width}, h={height}) in {img_path}")
            img_resized = cv2.resize(crop_img, (size, size),
                                     interpolation=cv2.INTER_AREA)
            objects.append(self.image_to_tensor(img_resized))
        return objects

    def __getitem__(self, index):
        """Return ``(crop_a, crop_b, target)`` for a newly sampled pair.

        ``target`` is 0 for a same-class pair and 1 for a different-class
        pair. ``index`` is ignored.
        """
        target = np.random.randint(0, 2)
        class_label = np.random.choice(self.need_classes)
        images_per_class = self.having_classes[class_label]

        if target == 0:
            # Prefer two distinct records so a positive pair is not the same
            # crop twice; single-record classes fall back to replacement.
            pair = np.random.choice(images_per_class, 2,
                                    replace=len(images_per_class) < 2)
        else:
            img1 = np.random.choice(images_per_class)
            # Exclude *every* occurrence of the anchor class: list.remove()
            # only drops the first, so a duplicated name in need_classes
            # could produce a "different-class" pair from the same class.
            other_classes = [c for c in self.need_classes if c != class_label]
            if not other_classes:
                raise ValueError(
                    "need_classes must contain at least two distinct classes "
                    "to build a negative pair")
            other_class = np.random.choice(other_classes)
            img2 = np.random.choice(self.having_classes[other_class])
            pair = [img1, img2]

        crops = self.crop_objects(pair)
        return crops[0], crops[1], target

    def __len__(self):
        """Number of items per epoch: the first entry of ``json_data["shape"]``."""
        return self.shape[0]

class Encoder(nn.Module):
    """Shared-weight CNN encoder mapping two image batches to 256-d embeddings.

    Three ``Conv2d(kernel_size=5) -> ReLU -> MaxPool2d(2, 2)`` stages
    (3 -> 32 -> 64 -> 128 channels) feed a three-layer MLP ending in 256
    features. Both inputs pass through the same weights.

    Args:
        img_size: Side length of the square input images. It sizes the first
            linear layer; the default 200 gives the original
            ``128 * 21 * 21`` input width.

    Raises:
        ValueError: if ``img_size`` is too small for the conv stack.
    """

    def __init__(self, img_size=200):
        super(Encoder, self).__init__()
        side = self._conv_output_size(img_size)

        # conv and fc together form the encoder
        self.conv = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2, stride=2),
                                  nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2, stride=2),
                                  nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2, stride=2)
                                  )

        # conv output is (128, side, side); side == 21 for img_size == 200
        self.fc = nn.Sequential(nn.Linear(128 * side * side, 1024),
                                nn.ReLU(),
                                nn.Linear(1024, 1024),
                                nn.ReLU(),
                                nn.Linear(1024, 256)
                                )

    @staticmethod
    def _conv_output_size(img_size):
        """Return the spatial side length produced by ``self.conv`` for ``img_size``."""
        side = img_size
        for _ in range(3):
            # Unpadded 5x5 conv shrinks by 4; 2x2/stride-2 max-pool halves (floor).
            side = (side - 4) // 2
            if side < 1:
                raise ValueError(f"img_size={img_size} is too small for Encoder")
        return side

    def forward(self, in1, in2):
        """Embed two batches in one pass through the shared network.

        Args:
            in1: Tensor of shape ``(N1, 3, img_size, img_size)``.
            in2: Tensor of shape ``(N2, 3, img_size, img_size)``.

        Returns:
            ``(z_out1, z_out2)`` with shapes ``(N1, 256)`` and ``(N2, 256)``.
        """
        x = torch.cat((in1, in2), dim=0)
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # Split by the real input sizes so unequal batches come back intact.
        z_out1, z_out2 = torch.split(x, [in1.size(0), in2.size(0)], dim=0)
        return z_out1, z_out2


class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss (Hadsell, Chopra & LeCun, 2006).

    For pair distance ``d`` and label ``y`` the per-pair loss is
    ``0.5 * (1 - y) * d**2 + 0.5 * y * max(margin - d, 0)**2``, averaged
    over the batch. ``y == 0`` marks a similar pair (pulled together) and
    ``y == 1`` a dissimilar pair (pushed at least ``margin`` apart).

    Args:
        margin: Distance beyond which dissimilar pairs stop contributing.
    """

    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output, target):
        """Compute the mean contrastive loss.

        Args:
            output: Sequence whose first two items are embedding batches of
                equal shape ``(N, D)``.
            target: ``N`` pair labels (0 = similar, 1 = dissimilar) as a
                tensor (int, float or bool) or any array-like.

        Returns:
            Scalar loss tensor.
        """
        eq_distance = F.pairwise_distance(output[0], output[1])
        # Bring labels to the distance's dtype/device: ``1 - bool_tensor`` is
        # an error in torch, and plain lists do not support the arithmetic.
        target = torch.as_tensor(target, dtype=eq_distance.dtype,
                                 device=eq_distance.device)
        similar_term = (1 - target) * eq_distance.pow(2)
        dissimilar_term = target * torch.clamp(self.margin - eq_distance, min=0.0).pow(2)
        return (0.5 * (similar_term + dissimilar_term)).mean()
        


# --- Data: annotation JSON + frame memmap wrapped in the pair dataset -------
folder_path = './drive/My Drive/Thesis_2020/'
with open(folder_path + 'davis-json.txt', encoding='utf-8') as json_file:
    davis_json = json.load(json_file)

memmap_path = folder_path + 'memmap/'
shape = davis_json["shape"]
complete_memmap = memmap_path + 'davis.mmap'
# Read-only view; assumes the file was written as uint8 with exactly this shape.
davis_memmap = np.memmap(complete_memmap, dtype='uint8', mode='r', shape=tuple(shape))

# Classes to draw pairs from. Each name is listed once: a duplicate (earlier
# 'sports ball' appeared twice) doubles that class's sampling weight and could
# let a "different-class" pair be drawn from that same class.
need_classes = ['person', 'tennis racket', 'sports ball', 'horse', 'surfboard',
                'skateboard', 'sheep', 'motorcycle', 'backpack', 'kite', 'bird',
                'dog', 'bicycle', 'handbag']

davis_dataset = SimaseDavis(davis_json, davis_memmap, need_classes)
davis_dataloader = torch.utils.data.DataLoader(davis_dataset, batch_size=args.batch_size, shuffle=True)

# --- Model, loss and optimiser ----------------------------------------------
encoder = Encoder()
encoder = encoder.to(DEVICE)

criterion = ContrastiveLoss(margin=args.constractive_loss_margin)
optimizer = torch.optim.Adam(params=encoder.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

# --- Training loop: print the mean batch loss after every epoch -------------
print('training started')

encoder.train()
torch.set_grad_enabled(True)

for epoch in range(args.num_epochs):

    batch_losses = []

    for batch in davis_dataloader:
        img_objects_1, img_objects_2, target = batch
        img_objects_1 = img_objects_1.to(DEVICE)
        img_objects_2 = img_objects_2.to(DEVICE)
        target = target.to(DEVICE)

        z_out1, z_out2 = encoder(img_objects_1, img_objects_2)
        z_out = [z_out1, z_out2]
        loss = criterion(z_out, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    print('Epoch : ', epoch+1, 'loss : ', np.mean(batch_losses))
  
  





  

training started
Epoch :  1 loss :  6.376217713113874
Epoch :  2 loss :  0.08317986456677318
Epoch :  3 loss :  0.09014493541326374
Epoch :  4 loss :  0.0715072265593335
Epoch :  5 loss :  0.07269156526308507
Epoch :  6 loss :  0.06413766578771174
Epoch :  7 loss :  0.05644102836959064
Epoch :  8 loss :  0.050493661779910326
Epoch :  9 loss :  0.0497823950718157
Epoch :  10 loss :  0.04216685553546995
