# **Second Model**
(Based on our current understanding of the task.)

In [6]:
%pip install pytorch_lightning -q

In [1]:
import torch, torchvision
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from PIL import Image
import random

In [7]:
import pytorch_lightning as pylight
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

In [8]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


## Dataset

In [9]:
# Colour-to-class mapping, keyed by the *binarised* RGB value of a pixel
# (every channel is snapped to 0 or 255 before the lookup, see rgb_to_label).
color_map = {
    (255, 255, 255): 0,   # white -> class 0
    (255, 0, 0): 1,       # red -> class 1
    (0, 0, 255): 2,       # blue -> class 2
    (0, 255, 0): 3,       # green -> class 3
    (255, 255, 0): 4,     # yellow -> class 4
    (0, 255, 255): 5,     # cyan -> class 5

    # Colours that belong to none of the classes above are treated as class 0.
    # (Black is not listed: it keeps the zero the label map is initialised with.)
    (255, 0, 255): 0,     # magenta -> class 0
}


def rgb_to_label(rgb_image):
    """Convert an RGB segmentation image into a class-index tensor.

    Args:
        rgb_image: Anything ``np.asarray`` turns into an ``(H, W, C)`` array
            with ``C >= 3`` (e.g. a PIL RGB/RGBA image). Channels after the
            first three (such as alpha) are ignored.

    Returns:
        ``torch.int64`` tensor of shape ``(1, H, W)`` holding the class index
        of every pixel according to ``color_map``.

    Raises:
        ValueError: If the input is not a 3-D array with at least 3 channels.
    """
    np_image = np.asarray(rgb_image)
    if np_image.ndim != 3 or np_image.shape[-1] < 3:
        raise ValueError(
            f"expected an (H, W, 3) RGB image, got array of shape {np_image.shape}"
        )
    # Drop alpha (or any extra channel) so the comparison with RGB triples is valid.
    np_image = np_image[..., :3]

    # Snap every channel to 0 or 255 so slightly-off colours (compression,
    # anti-aliasing) still hit one of the keys of color_map.
    np_image = np.where(np_image < 128, 0, 255)

    # Pixels whose colour is not in color_map keep the initial class 0.
    label_map = np.zeros(np_image.shape[:2], dtype=np.int64)

    for color, class_idx in color_map.items():
        mask = (np_image == np.array(color)).all(axis=-1)
        label_map[mask] = class_idx

    return torch.from_numpy(label_map).unsqueeze(0)


# path = '/content/drive/MyDrive/CVUSA_subset/polarmap/segmap/output0000008.png'
# img = Image.open(path)
# img_np = np.array(img)
# print(img_np.shape)     # (128, 512, 3)

# res = rgb_to_label(img)
# print(res)
# print(res.shape)

In [10]:
class CVUSADataset(torch.utils.data.Dataset):
    """Dataset yielding ground / aerial / generated-aerial / generated-segmap samples.

    Each entry of ``triplet_list`` is a ``(ground, aerial, seg)`` triple of
    paths relative to ``dataset_dir``. In addition to the real ground and
    aerial images, the outputs of the first model are loaded from
    ``generated_images/generated_images_model_1/<add_dir>/``; their file name
    is derived from the ground image's base name.

    Args:
        dataset_dir: Root directory; relative paths are appended to it with
            plain string concatenation, so it must end with a path separator.
        triplet_list: List of ``(ground_rel, aerial_rel, seg_rel)`` tuples.
        img_size: Side length every image is resized to.
        train: When true, random flip / rotation / blur augmentation is
            applied to the ground image.
        add_dir: Split sub-directory of the generated images
            (e.g. ``"train_set"``).
    """

    # Location of the first model's outputs, relative to ``dataset_dir``.
    GENERATED_DIR = "generated_images/generated_images_model_1/"
    # How many times loading the images of one sample is attempted.
    MAX_RETRIES = 3

    def __init__(self, dataset_dir, triplet_list, img_size=224, train=False, add_dir = ""):
        self.dataset_dir = dataset_dir
        self.triplet_list = triplet_list
        self.img_size = img_size
        self.add_dir = add_dir

        def rgb_steps():
            # Resize + tensor conversion + standard ImageNet normalization.
            return [
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]

        # Ground view images: data augmentation only while training.
        augmentation = []
        if train:
            augmentation = [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(10),
                transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
            ]
        self.ground_transform = transforms.Compose(augmentation + rgb_steps())

        # Aerial images (real and generated).
        self.aerial_transform = transforms.Compose(rgb_steps())

        # Segmentation maps: nearest-neighbour resize so no new colours are
        # invented, then colour -> class index.
        self.segmentation_transform = transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.Lambda(rgb_to_label),
            transforms.Lambda(lambda x: x.squeeze(0).long())  # (H, W) int64 tensor
        ])

    def __len__(self):
        return len(self.triplet_list)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _prediction_path(self, sub_dir, prefix, stem):
        """Build the path of a first-model output for the ground image ``stem``."""
        return (self.dataset_dir + self.GENERATED_DIR + self.add_dir
                + "/" + sub_dir + "/" + prefix + stem + ".png").strip()

    def __getitem__(self, idx):
        """Return ``(ground, aerial, synthetic_aerial, segmented_aerial)`` tensors.

        The first three are normalized float32 image tensors; the last one is
        the ``(H, W)`` class-index map of the generated segmentation, cast to
        float32.

        Raises:
            Exception: Whatever the image loading raised, once all
                ``MAX_RETRIES`` attempts have failed.
        """
        # The ground-truth segmentation path of the triplet is not used here.
        ground_rel, aerial_rel, _seg_rel = self.triplet_list[idx]
        stem = ground_rel.split("/")[-1].split(".")[0]

        for attempt in range(self.MAX_RETRIES):
            try:
                ground_img = self._load_rgb(self.dataset_dir + ground_rel)
                aerial_img = self._load_rgb(self.dataset_dir + aerial_rel)
                # Generated aerial image -> RGB pipeline; generated segmentation
                # map -> label pipeline. (These two paths used to be swapped,
                # which sent the aerial prediction through rgb_to_label.)
                synth_img = self._load_rgb(
                    self._prediction_path("aerial_predictions", "aerial_pred_", stem))
                seg_synth_img = self._load_rgb(
                    self._prediction_path("segmap_predictions", "seg_pred_", stem))
                break
            except Exception as e:
                print(f"Error loading images for index {idx}: {e}")
                if attempt == self.MAX_RETRIES - 1:
                    raise

        # Apply transforms
        ground_tensor = self.ground_transform(ground_img).to(dtype=torch.float32)
        aerial_tensor = self.aerial_transform(aerial_img).to(dtype=torch.float32)
        synth_tensor = self.aerial_transform(synth_img).to(dtype=torch.float32)
        seg_synth_tensor = self.segmentation_transform(seg_synth_img).to(dtype=torch.float32)

        # ground, candidate aerial, generated aerial, generated segmentation
        return ground_tensor, aerial_tensor, synth_tensor, seg_synth_tensor

In [11]:
def read_triplets_csv(csv_path):
    """Read a CSV file into a list of (ground, aerial, seg) path triplets.

    Every non-blank line is split on commas; the first three fields
    (street-view image, polar aerial image, polar segmentation map) are
    stripped of surrounding whitespace and kept, any further field is
    ignored. Blank lines are skipped.

    Args:
        csv_path: Path of the CSV file.

    Returns:
        List of ``(streetview, polar_aerial, polar_seg)`` string tuples, in
        file order.

    Raises:
        IndexError: If a non-blank line has fewer than three fields.
    """
    triplets = []
    with open(csv_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate empty lines (typically a trailing newline).
                continue
            parts = line.split(',')
            triplets.append((
                parts[0].strip(),  # streetview
                parts[1].strip(),  # polar aerial
                parts[2].strip()   # polar seg
            ))
    return triplets


# Root of the CVUSA subset; relative paths are appended to it directly, so keep the trailing slash.
dataset_dir = "/content/drive/MyDrive/CVUSA_subset/"

In [12]:
# train.csv is split 85/15 into training and validation sets (fixed seed for
# reproducibility); val.csv is held out and used as the test set.
# Paths are built from dataset_dir so the dataset root is defined in one place.
train_triplets = read_triplets_csv(dataset_dir + "train.csv")
train_triplets, val_triplets = train_test_split(train_triplets, test_size=0.15, random_state=19)  # training/validation set
test_triplets = read_triplets_csv(dataset_dir + "val.csv")        # test set

In [13]:
BATCH_SIZE = 8

# One dataset per split; only the training split uses data augmentation.
train_dataset = CVUSADataset(
    dataset_dir,
    train_triplets,
    train=True,
    add_dir="train_set",
)
val_dataset = CVUSADataset(
    dataset_dir,
    val_triplets,
    add_dir="val_set",
)
test_dataset = CVUSADataset(
    dataset_dir,
    test_triplets,
    add_dir="test_set",
)

# Settings shared by every loader: incomplete final batches are dropped.
_loader_kwargs = dict(batch_size=BATCH_SIZE, drop_last=True)

# Only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

## Auxiliary module





In [14]:
class Identity(nn.Module):
    """Pass-through layer: ``forward`` hands back its input untouched.

    Used to replace layers of a pretrained network that should be disabled.
    """

    def forward(self, x):
        return x

## Upper Branch

### simplified version

In [None]:
class GroundBranchSim(nn.Module):
    """Simplified ground branch: VGG16 logits followed by a small MLP.

    Two pretrained VGG16 networks are created with frozen convolutional
    features. ``vgg1`` gets a fresh single-channel input convolution, but only
    ``vgg2`` is used by ``forward``. ``use_seg`` is stored and not read here.
    """

    def __init__(self, use_seg=False):
        super().__init__()

        self.use_seg = use_seg

        self.vgg1 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        self.vgg2 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

        # Freeze the convolutional features of both backbones for finetuning.
        for backbone in (self.vgg1, self.vgg2):
            for weight in backbone.features.parameters():
                weight.requires_grad = False

        # Swap vgg1's first convolution for a single-channel one (added after
        # the freeze, so it stays trainable). The input may need 3 channels.
        self.vgg1.features[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        nn.init.kaiming_normal_(self.vgg1.features[0].weight)

        # MLP turning the 1000 VGG outputs into a 512-d embedding.
        # TODO: decide final size...
        mlp_layers = [
            nn.Linear(1000, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
        ]
        self.FNN = nn.Sequential(*mlp_layers)

    def forward(self, ground_view):
        """Embed ``ground_view`` with ``vgg2`` followed by the MLP."""
        return self.FNN(self.vgg2(ground_view))

### amplified version

In [15]:
class GroundBranch(nn.Module):
    """Ground-view branch built on a pretrained VGG16 without its classifier.

    ``forward`` only returns the VGG feature map. The pooling (``GAP``) and
    embedding MLP (``FNN``) are created here but are applied by the caller
    (see ``LightningWrapper``). ``conv``, ``sigma`` and ``use_seg`` are
    defined but not used by this class.
    """

    def __init__(self, use_seg=False):
        super().__init__()

        self.use_seg = use_seg

        self.vgg1 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

        # Freeze the first six parameter tensors of the feature extractor
        # (in torchvision's VGG16: weight and bias of the first three convs).
        for tensor in list(self.vgg1.features.parameters())[:6]:
            tensor.requires_grad = False

        # Drop the fully connected classifier: the network now outputs the
        # flattened feature map.
        self.vgg1.classifier = nn.Sequential(Identity())

        self.conv = nn.Conv2d(512, 512, 1)
        self.sigma = nn.Sigmoid()
        self.GAP = nn.AdaptiveAvgPool2d((1, 1))

        # MLP turning pooled features into a 512-d embedding.
        # TODO: decide final size...
        mlp_layers = [
            nn.Linear(512, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
        ]
        self.FNN = nn.Sequential(*mlp_layers)

    def forward(self, ground_view):
        """Return the VGG features of ``ground_view`` viewed as ``(-1, 512, 7, 7)``."""
        features = self.vgg1(ground_view)
        # Disabled variant kept for reference; it would return the embedding:
        #   x = self.GAP(x_ground); x = x.view(x.shape[0], -1); x = self.FNN(x)
        return features.view(-1, 512, 7, 7)

## Lower Branch

### simplified version

In [None]:
class AerialBranchSim(nn.Module):
    """Simplified aerial branch: concatenated VGG16 logits followed by an MLP.

    ``vgg1`` (single-channel input) processes the segmentation map, while
    ``vgg2`` is shared by the synthetic and the candidate aerial image. The
    three 1000-d outputs are concatenated and mapped to a 512-d embedding.
    """

    def __init__(self):
        super().__init__()

        self.vgg1 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        self.vgg2 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

        # Freeze the convolutional features of both backbones for finetuning.
        for backbone in (self.vgg1, self.vgg2):
            for weight in backbone.features.parameters():
                weight.requires_grad = False

        # Single-channel input convolution for the segmentation backbone
        # (added after the freeze, so it stays trainable).
        self.vgg1.features[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        nn.init.kaiming_normal_(self.vgg1.features[0].weight)

        # MLP turning the 3 x 1000 VGG outputs into the embedding.
        # TODO: decide final size...
        mlp_layers = [
            nn.Linear(3000, 2048),
            nn.LayerNorm(2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
        ]
        self.FNN = nn.Sequential(*mlp_layers)

    def forward(self, synthetic_aerial, segmented_aerial, candidate_aerial):
        """Embed the (synthetic, segmented, candidate) aerial inputs jointly."""
        seg_out = self.vgg1(segmented_aerial)
        synth_out = self.vgg2(synthetic_aerial)
        cand_out = self.vgg2(candidate_aerial)

        # Concatenation order: synthetic, segmented, candidate.
        joint = torch.cat((synth_out, seg_out, cand_out), dim=-1)
        return self.FNN(joint)

### amplified version

In [16]:
class AerialBranch(nn.Module):
    """Aerial branch: segmentation-gated VGG16 features mapped to an embedding.

    ``vgg1`` (single-channel input) encodes the segmentation map; ``vgg2`` is
    shared by the synthetic and the candidate aerial image. A 1x1 convolution
    of the segmentation features multiplies both aerial feature maps; the
    three feature maps are then average-pooled, concatenated and passed
    through an MLP producing a 512-d embedding.
    """

    def __init__(self):
        super().__init__()

        self.vgg1 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        self.vgg2 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

        # Freeze the first six parameter tensors of each feature extractor
        # (in torchvision's VGG16: weight and bias of the first three convs).
        for n, (param1, param2) in enumerate(zip(self.vgg1.features.parameters(), self.vgg2.features.parameters())):
            if n >= 6:
                break
            param1.requires_grad = False
            param2.requires_grad = False

        # Single-channel input convolution for the segmentation backbone. It
        # replaces the frozen first conv, so this new layer is trainable.
        self.vgg1.features[0] = nn.Conv2d(
            1,
            64,
            kernel_size=3,
            padding=1
        )
        nn.init.kaiming_normal_(self.vgg1.features[0].weight)

        # Remove the fully connected classifiers: both networks now output the
        # flattened feature map.
        self.vgg1.classifier = nn.Sequential(Identity())
        self.vgg2.classifier = nn.Sequential(Identity())

        self.conv = nn.Conv2d(512, 512, 1)
        # NOTE(review): sigma is defined but never applied to the mask in
        # forward -- confirm whether the mask was meant to be sigmoid-gated.
        self.sigma = nn.Sigmoid()
        self.GAP = nn.AdaptiveAvgPool2d((1, 1))

        # MLP turning the 3 x 512 pooled features into the embedding.
        # TODO: decide final size...
        self.FNN = nn.Sequential(
            nn.Linear(1536, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),

            nn.Linear(1024, 512)
        )

    def forward(self, synthetic_aerial, segmented_aerial, candidate_aerial):
        """Embed the (synthetic, segmented, candidate) aerial inputs jointly.

        Args:
            synthetic_aerial: Batch of generated aerial images for ``vgg2``.
            segmented_aerial: Batch of segmentation maps, either with an
                explicit channel axis ``(B, 1, H, W)`` or without it
                ``(B, H, W)``; the channel axis is added when missing.
            candidate_aerial: Batch of candidate aerial images for ``vgg2``.

        Returns:
            Output of the embedding MLP, one row per sample.
        """
        # Add the channel axis for any batch size (the previous code only
        # handled a batch of exactly 8 images of 224x224).
        if segmented_aerial.dim() == 3:
            segmented_aerial = segmented_aerial.unsqueeze(1)

        x_segmented = self.vgg1(segmented_aerial).view(-1, 512, 7, 7)
        x_synthetic = self.vgg2(synthetic_aerial).view(-1, 512, 7, 7)
        x_candidate = self.vgg2(candidate_aerial).view(-1, 512, 7, 7)

        mask = self.conv(x_segmented)

        # Out-of-place products: the operands are views of backbone outputs
        # that autograd may still need, so they are not modified in place.
        x_candidate = x_candidate * mask
        x_synthetic = x_synthetic * mask

        x_candidate = self.GAP(x_candidate).flatten(1)
        x_synthetic = self.GAP(x_synthetic).flatten(1)
        x_segmented = self.GAP(x_segmented).flatten(1)

        x = torch.cat((x_synthetic, x_segmented, x_candidate), dim=-1)
        x = self.FNN(x)

        return x

## Complete Network

In [17]:
class CompNet(nn.Module):
    """Complete two-branch network: ground branch plus aerial branch.

    Args:
        ground_branch: Module applied to the ground view. A new
            ``GroundBranch`` is created when omitted.
        aerial_branch: Module applied to the (synthetic, segmented,
            candidate) aerial inputs. A new ``AerialBranch`` is created when
            omitted.
    """

    def __init__(self, ground_branch=None, aerial_branch=None):
        super().__init__()
        # The branches are built here, not in the signature: default argument
        # values are evaluated once at definition time, which would load the
        # pretrained networks on import and share them between all instances.
        self.GB = GroundBranch() if ground_branch is None else ground_branch
        self.AB = AerialBranch() if aerial_branch is None else aerial_branch

    def forward(self, ground_view, synthetic_aerial, segmented_aerial, candidate_aerial):
        """Return ``(ground_output, aerial_output)`` of the two branches."""
        return self.GB(ground_view), self.AB(synthetic_aerial, segmented_aerial, candidate_aerial)

### Triplet Loss

In [18]:
class WeightedSoftMarginTripletLoss(nn.Module):
    """Soft-margin triplet loss ``mean(log(1 + exp(margin * (d_pos - d_neg))))``.

    ``d_pos`` / ``d_neg`` are the Euclidean distances (over the last axis)
    between the anchor and the positive / negative embeddings. ``margin``
    acts as a scaling weight on the distance difference.

    Args:
        margin: Factor multiplying ``d_pos - d_neg`` inside the exponential.
    """

    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the scalar loss for broadcast-compatible embedding tensors."""
        d_pos = torch.norm(anchor - positive, dim=-1, keepdim=True)
        d_neg = torch.norm(anchor - negative, dim=-1, keepdim=True)
        arg = self.margin * (d_pos - d_neg)

        # log(1 + exp(arg)) computed stably as logsumexp([0, arg]). zeros_like
        # keeps the zeros on the device and dtype of the inputs, and stacking
        # on a new axis also supports unbatched (1-D) embeddings.
        zeros = torch.zeros_like(arg)
        return torch.logsumexp(torch.stack((zeros, arg), dim=-1), dim=-1).mean()


### Lightning Wrapper

In [19]:
class LightningWrapper(pylight.LightningModule):
    """Lightning module training ``CompNet`` with in-batch hard negatives.

    For every sample the anchor is the ground embedding, the positive is the
    aerial-branch embedding of its own aerial inputs, and the negative is the
    aerial-branch embedding obtained by swapping the candidate aerial image
    for that of the most similar *other* ground view in the batch.

    Args:
        device: Device the batch tensors are moved to in each step.
        model: Network to train; a new ``CompNet`` is created when omitted.
    """

    def __init__(self, device, model=None):
        super().__init__()
        self.dvc = device

        # Built here rather than as a default argument, which would be
        # evaluated once at definition time and shared between wrappers.
        self.model = CompNet() if model is None else model
        self.criterion = WeightedSoftMarginTripletLoss()

    def forward(self, ground_view, synthetic_aerial, segmented_aerial, candidate_aerial):
        return self.model(ground_view, synthetic_aerial, segmented_aerial, candidate_aerial)

    def _triplet_loss(self, batch):
        """Compute the triplet loss of one batch (shared by train and validation)."""
        ground, aerial, synthetic_aerial, segmented_aerial = batch

        ground = ground.to(self.dvc)
        aerial = aerial.to(self.dvc)
        synthetic_aerial = synthetic_aerial.to(self.dvc)
        segmented_aerial = segmented_aerial.to(self.dvc)

        # Forward pass through the wrapped network (not a global variable).
        ground_features, positives = self.model(ground, synthetic_aerial, segmented_aerial, aerial)

        # Hard-negative mining: cosine similarity between the ground feature
        # maps; the diagonal is excluded so a sample never selects itself
        # (unless the batch holds a single sample). Works for any batch size.
        with torch.no_grad():
            flat = F.normalize(ground_features.flatten(1), dim=-1)
            similarity = flat @ flat.T
            similarity.fill_diagonal_(float("-inf"))
            closest_idx = torch.argmax(similarity, dim=1)

        # Ground embedding: pooled feature map passed through the branch MLP.
        anchors = self.model.GB.GAP(ground_features)
        anchors = anchors.view(anchors.shape[0], -1)
        anchors = self.model.GB.FNN(anchors)

        negatives = self.model.AB(synthetic_aerial, segmented_aerial, aerial[closest_idx])

        # Sanity check
        if negatives.shape[0] != len(positives):
            print("Not Sane")

        return self.criterion(anchors, positives, negatives)

    def training_step(self, batch, batch_idx):
        loss = self._triplet_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._triplet_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return {"val_loss":loss}

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        return {"optimizer":optimizer, "lr_scheduler":scheduler}

### Training

### Without Lightning

In [None]:
def _in_batch_triplet_loss(model, criterion, ground, aerial, synthetic_aerial, segmented_aerial):
    """Sum the triplet loss over a batch, using all other samples as negatives.

    Returns ``(summed_loss, batch_size)``. For sample ``i`` the anchor is its
    ground embedding, the positive its aerial embedding and the negatives are
    the aerial embeddings of every other sample in the batch.

    Raises:
        ValueError: If the batch holds fewer than two samples.
    """
    labels, predictions = model(ground, synthetic_aerial, segmented_aerial, aerial)

    # The ground branch may return a feature map instead of an embedding; in
    # that case it is pooled and embedded the same way LightningWrapper does.
    # NOTE(review): assumes the model exposes GB.GAP / GB.FNN like CompNet.
    if labels.dim() > 2:
        labels = model.GB.GAP(labels).flatten(1)
        labels = model.GB.FNN(labels)

    batch_size = len(labels)
    if batch_size < 2:
        raise ValueError("in-batch negatives need at least two samples per batch")

    batch_loss = 0.0
    for i in range(batch_size):
        others = [j for j in range(batch_size) if j != i]
        # Slices keep the leading axis so anchor/positive broadcast against
        # the (batch_size - 1) negatives.
        batch_loss = batch_loss + criterion(
            labels[i:i + 1], predictions[i:i + 1], predictions[others]
        )
    return batch_loss, batch_size


def train_epoch(model, dataloader, optimizer, criterion=None, scheduler=None):
    """Train ``model`` for one epoch and return the mean per-sample batch loss.

    Args:
        model: Network called as ``model(ground, synthetic_aerial,
            segmented_aerial, aerial)`` and returning ``(ground_output,
            aerial_embedding)``.
        dataloader: Iterable of ``(ground, aerial, synthetic_aerial,
            segmented_aerial)`` batches, as produced by ``CVUSADataset``.
        optimizer: Optimizer updating the model parameters.
        criterion: Triplet loss ``criterion(anchor, positive, negatives)``;
            ``WeightedSoftMarginTripletLoss()`` when omitted.
        scheduler: Optional learning-rate scheduler, stepped once at the end
            of the epoch.

    Returns:
        Average over batches of ``batch_loss / batch_size`` (``0.0`` for an
        empty dataloader).
    """
    if criterion is None:
        criterion = WeightedSoftMarginTripletLoss()

    model.train()
    # Inputs follow the model: use the device its parameters live on.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    tot_loss = 0.0
    n_batches = 0

    for ground, aerial, synthetic_aerial, segmented_aerial in dataloader:
        ground = ground.to(run_device)
        aerial = aerial.to(run_device)
        synthetic_aerial = synthetic_aerial.to(run_device)
        segmented_aerial = segmented_aerial.to(run_device)

        optimizer.zero_grad()       # resets gradients from previous batch
        batch_loss, batch_size = _in_batch_triplet_loss(
            model, criterion, ground, aerial, synthetic_aerial, segmented_aerial
        )
        batch_loss.backward()       # computes gradients via backpropagation
        optimizer.step()            # updates weights using gradients

        tot_loss += batch_loss.item() / batch_size
        n_batches += 1

    if scheduler is not None:
        scheduler.step()            # adjusts learning rate after each epoch

    return tot_loss / n_batches if n_batches else 0.0


def evaluate(model, dataloader, device, criterion=None):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Network with the same call convention as in ``train_epoch``.
        dataloader: Iterable of ``(ground, aerial, synthetic_aerial,
            segmented_aerial)`` batches.
        device: Device the batch tensors are moved to.
        criterion: Triplet loss ``criterion(anchor, positive, negatives)``;
            ``WeightedSoftMarginTripletLoss()`` when omitted.

    Returns:
        Average over batches of ``batch_loss / batch_size`` (``0.0`` for an
        empty dataloader).
    """
    if criterion is None:
        criterion = WeightedSoftMarginTripletLoss()

    model.eval()
    tot_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for ground, aerial, synthetic_aerial, segmented_aerial in dataloader:
            ground = ground.to(device)
            aerial = aerial.to(device)
            synthetic_aerial = synthetic_aerial.to(device)
            segmented_aerial = segmented_aerial.to(device)

            batch_loss, batch_size = _in_batch_triplet_loss(
                model, criterion, ground, aerial, synthetic_aerial, segmented_aerial
            )

            tot_loss += batch_loss.item() / batch_size
            n_batches += 1

    return tot_loss / n_batches if n_batches else 0.0

In [None]:
# Network for the plain PyTorch training loop, built with the default ground and aerial branches.
model = CompNet()

In [None]:
# Main training
num_epochs = 1

# The model and the batches must live on the same device.
model = model.to(device)
# Same optimizer settings as LightningWrapper.configure_optimizers.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

for epoch in range(num_epochs):
    # train_epoch takes (model, dataloader, optimizer[, criterion, scheduler]);
    # the device is not one of its parameters.
    train_loss = train_epoch(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"\tTrain Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Save checkpoint
    if (epoch+1) % 5 == 0:
        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")

### With Lightning

In [None]:
model = LightningWrapper(device).to(device)

log_path = "/content/drive/MyDrive/SavedModels/CV/first_attempt"
# Checkpoints go in a sub-directory of the log directory (the path separator
# was missing, which produced ".../first_attemptcheckpoints").
ckpt_path = log_path + "/checkpoints"

EPOCHS = 100

# Keep the two best checkpoints by validation loss, plus the last one.
checkpoint_callback = ModelCheckpoint(
    dirpath=ckpt_path,
    filename="best-checkpoint-{epoch:02d}-{val_loss:.2f}",  # val_loss with two decimals
    save_top_k=2,
    verbose=True,
    monitor="val_loss",
    mode="min",
    save_last=True,
    every_n_epochs=5
)

# Training stops at EPOCHS epochs or after 8 hours, whichever comes first.
trainer = Trainer(
    max_epochs=EPOCHS,
    enable_checkpointing=True,
    default_root_dir=log_path,
    callbacks=[checkpoint_callback],
    enable_progress_bar=True,
    max_time = {"hours":8}
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=None)

## TEST