# **Second Model**
(Based on our current understanding of the task.)

## First Half

In [None]:
class FirstHalf(nn.Module):
  """Aerial-side encoder producing a 512-d embedding.

  Two ImageNet-pretrained VGG16 networks are used as backbones:
    * ``vgg1`` encodes the segmented aerial image;
    * ``vgg2`` is shared by the synthetic and the candidate aerial images.
  The three 1000-d VGG outputs are concatenated (3000-d) and projected to a
  512-d embedding by the ``FNN`` MLP.
  """

  def __init__(self):
    super().__init__()

    self.vgg1 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    self.vgg2 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

    # Freeze the convolutional ``features`` stack of each backbone; the VGG
    # classifier heads and the projection MLP stay trainable for fine-tuning.
    for backbone in (self.vgg1, self.vgg2):
      for weight in backbone.features.parameters():
        weight.requires_grad = False

    # Projection MLP: 3000 -> 2048 -> 1024 -> 512, each hidden step being
    # Linear + LayerNorm + ReLU.  TODO: decide the final embedding size.
    layers = []
    for width_in, width_out in ((3000, 2048), (2048, 1024)):
      layers.extend((nn.Linear(width_in, width_out), nn.LayerNorm(width_out), nn.ReLU()))
    layers.append(nn.Linear(1024, 512))
    self.FNN = nn.Sequential(*layers)

  def forward(self, ground_view, synthetic_aerial, segmented_aerial, candidate_aerial):
    """Embed one batch of aerial inputs.

    ``ground_view`` is accepted for signature symmetry but is not used.
    Returns ``FNN(cat([vgg2(synthetic), vgg1(segmented), vgg2(candidate)]))``.
    """
    # Backbones are evaluated in this order (segmented first) on purpose:
    # VGG's classifier contains dropout, so call order determines which random
    # masks each input receives in training mode.
    seg_feat = self.vgg1(segmented_aerial)
    syn_feat = self.vgg2(synthetic_aerial)
    cand_feat = self.vgg2(candidate_aerial)

    stacked = torch.cat([syn_feat, seg_feat, cand_feat], dim=-1)
    return self.FNN(stacked)

## Second Half

In [None]:
class SecondHalf(nn.Module):
  """Ground-side encoder producing a 512-d embedding.

  Two ImageNet-pretrained VGG16 networks are used as backbones:
    * ``vgg2`` encodes the RGB ground view;
    * ``vgg1`` encodes the segmented ground view.
  The two 1000-d VGG outputs are concatenated (2000-d) and projected to a
  512-d embedding by the ``FNN`` MLP.
  """

  def __init__(self):
    super().__init__()

    self.vgg1 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    self.vgg2 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

    # Freeze the convolutional ``features`` stack of each backbone; the VGG
    # classifier heads and the projection MLP stay trainable for fine-tuning.
    for backbone in (self.vgg1, self.vgg2):
      for weight in backbone.features.parameters():
        weight.requires_grad = False

    # Projection MLP: 2000 -> 1024 -> 512.  TODO: decide the final embedding size.
    self.FNN = nn.Sequential(
        nn.Linear(2000, 1024),
        nn.LayerNorm(1024),
        nn.ReLU(),

        nn.Linear(1024, 512),
    )

  def forward(self, ground_view, segmented_ground):
    """Embed one batch of ground inputs.

    Returns ``FNN(cat([vgg2(ground_view), vgg1(segmented_ground)]))``.
    """
    x_ground = self.vgg2(ground_view)
    # Previously ``self.vgg3``, which is never created (AttributeError).
    # ``vgg1`` was otherwise unused; using it for the segmentation input gives
    # each input its own backbone.
    x_segmented = self.vgg1(segmented_ground)

    x = torch.cat((x_ground, x_segmented), dim=-1)
    return self.FNN(x)

## Complete Network

In [None]:
class CompNet(nn.Module):
  """Two-branch network returning a (ground embedding, aerial embedding) pair.

  ``SH`` (``SecondHalf``) consumes the ground inputs and ``FH`` (``FirstHalf``)
  consumes the aerial inputs.
  """

  def __init__(self):
    super().__init__()
    # Construction order kept (FH, then SH) so parameter initialisation
    # consumes the RNG in the same order as before.
    self.FH = FirstHalf()
    self.SH = SecondHalf()

  def forward(self, ground_view, segmented_ground, synthetic_aerial, segmented_aerial, candidate_aerial):
    """Return ``(ground_embedding, aerial_embedding)``.

    Previously ``FH`` was called with the two ground inputs and ``SH`` with the
    three aerial inputs. Neither matches their ``forward`` signatures
    (``FH(ground_view, synthetic, segmented, candidate)`` and
    ``SH(ground_view, segmented_ground)``), so every call raised ``TypeError``.
    The output order is kept as the original call arguments intended: the
    ground-derived embedding comes first.
    """
    ground_embedding = self.SH(ground_view, segmented_ground)
    aerial_embedding = self.FH(ground_view, synthetic_aerial, segmented_aerial, candidate_aerial)
    return ground_embedding, aerial_embedding

## Triplet Loss

In [8]:
class WeightedSoftMarginTripletLoss(nn.Module):
  """Weighted soft-margin triplet loss.

  loss = mean( log(1 + exp(margin * (||a - p|| - ||a - n||))) )

  Note: ``margin`` acts as a multiplicative weight on the distance gap, not as
  an additive margin. Inputs broadcast, e.g. ``anchor``/``positive`` of shape
  ``(D,)`` with ``negatives`` of shape ``(N, D)`` or ``(D,)``.
  """

  def __init__(self, margin=0.2):
    super().__init__()
    self.margin = margin

  def forward(self, anchor, positive, negatives):
    d_pos = torch.norm(anchor - positive, dim=-1, keepdim=True)
    d_neg = torch.norm(anchor - negatives, dim=-1, keepdim=True)
    gap = self.margin * (d_pos - d_neg)

    # log(exp(0) + exp(gap)) == softplus(gap), computed stably.  The zero term
    # now follows gap's device/dtype/shape; the old ``torch.zeros((n, 1))`` was
    # always CPU float32 and broke ``torch.cat`` on CUDA or for 1-D ``gap``.
    return torch.logaddexp(torch.zeros_like(gap), gap).mean()


### Training

In [None]:
def _embed_batch(model, batch, device):
    """Move one dataloader batch to ``device`` and run ``model`` on it.

    ``batch`` is unpacked as ``(ground, aerial, segmented_ground,
    synthetic_aerial, segmented_aerial)``, and the model is called with
    ``(ground, segmented_ground, synthetic_aerial, segmented_aerial, aerial)``.
    Returns the model's output, expected to be ``(anchors, positives)``.
    """
    ground, aerial, segmented_ground, synthetic_aerial, segmented_aerial = (
        tensor.to(device) for tensor in batch
    )
    return model(ground, segmented_ground, synthetic_aerial, segmented_aerial, aerial)


def _mean_in_batch_loss(anchors, positives, criterion):
    """Average ``criterion`` over the batch using in-batch negatives.

    For sample ``i`` the anchor is ``anchors[i]``, the positive is
    ``positives[i]``, and the negatives are all other rows of ``positives``.
    """
    batch_size = len(anchors)
    total = 0.0
    for i in range(batch_size):
        # Every row except i.  Replaces ``range(...).pop(i)`` and
        # ``predictions(indeces)``, which raised at runtime.
        negatives = torch.cat((positives[:i], positives[i + 1:]), dim=0)
        total = total + criterion(anchors[i], positives[i], negatives)
    return total / batch_size


def train_epoch(model, dataloader, optimizer, criterion=None, scheduler=None, device=None):
    """Train ``model`` for one epoch and return the mean per-sample loss.

    Args:
      model: network returning ``(anchors, positives)``; see ``_embed_batch``.
      dataloader: iterable of 5-tensor batches; see ``_embed_batch``.
      optimizer: optimizer over ``model``'s parameters.
      criterion: ``(anchor, positive, negatives) -> scalar tensor``.  ``None``
        means a fresh ``WeightedSoftMarginTripletLoss()``, created at call time
        rather than once at definition time.
      scheduler: accepted for backward compatibility but unused; the caller
        steps the scheduler once per epoch.
      device: where batches are moved.  Defaults to the device of the model's
        first parameter.

    Batches with fewer than 2 samples are skipped because they have no in-batch
    negatives. The optimiser minimises the batch-mean loss (previously the sum),
    so the optimised quantity matches the reported value.

    Returns:
      Mean loss over processed batches, or ``nan`` if no batch was processed.
    """
    if criterion is None:
        criterion = WeightedSoftMarginTripletLoss()
    if device is None:
        device = next(model.parameters()).device

    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        optimizer.zero_grad()
        anchors, positives = _embed_batch(model, batch, device)
        if len(anchors) < 2:
            continue

        loss = _mean_in_batch_loss(anchors, positives, criterion)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")


def evaluate(model, dataloader, device, criterion=None):
    """Return the mean per-sample loss of ``model`` over ``dataloader``.

    Same batch layout, criterion default and size-1 skipping as
    ``train_epoch``, but it runs in eval mode under ``torch.no_grad()`` and
    performs no optimiser step. The previous version called the global
    ``optimizer.zero_grad()`` here, which was both unnecessary and undefined
    in this scope.

    Returns:
      Mean loss over processed batches, or ``nan`` if no batch was processed.
    """
    if criterion is None:
        criterion = WeightedSoftMarginTripletLoss()

    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in dataloader:
            anchors, positives = _embed_batch(model, batch, device)
            if len(anchors) < 2:
                continue
            total_loss += _mean_in_batch_loss(anchors, positives, criterion).item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")

In [None]:
# Main training loop.  Relies on `model`, `train_loader`, `val_loader`,
# `optimizer`, `scheduler` and `device` being defined in earlier cells.
num_epochs = 1
checkpoint_every = 5
for epoch in range(num_epochs):
    epoch_num = epoch + 1

    # Only the required arguments are passed.  The previous call passed
    # `device` as the 4th positional argument, which train_epoch binds to
    # `criterion`.
    train_loss = train_epoch(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader, device)

    print(f"Epoch {epoch_num}/{num_epochs}")
    print(f"\tTrain Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    scheduler.step()    # adjusts learning rate after each epoch

    # Save a checkpoint periodically and always after the last epoch.
    # Previously, with num_epochs < 5 nothing was ever saved.
    if epoch_num % checkpoint_every == 0 or epoch_num == num_epochs:
        torch.save(model.state_dict(), f"model_epoch_{epoch_num}.pth")