# **Transformer model**
Generate satellite images (natural and segmented) from query ground-view images.

**TODO**:
- Remove the ground-view segmentation maps from the dataset (they are not needed)
- Use a model pre-trained on satellite images to generate segmentation maps --> add the aerial images' segmentation maps to the dataset (ground truth for segmentation-map generation)

In [1]:
%pip install torch torchvision transformers scikit-learn -q

Note: you may need to restart the kernel to use updated packages.


In [10]:
def is_colab():
    """Return True when running inside a Google Colab runtime, False otherwise.

    Detection is done by attempting to import ``google.colab``, which only
    exists on Colab.
    """
    try:
        import google.colab  # noqa: F401  (only whether the import succeeds matters)
    except ImportError:
        return False
    else:
        return True

In [None]:
import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig
from torchvision import transforms
from torch.utils.data import DataLoader
#import segmentation_models_pytorch as smp
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

# On Colab the dataset is stored in Google Drive: mount it before any file access.
if is_colab():
    import google.colab.drive as drive
    drive.mount('/content/drive')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## dataset

In [None]:
class CVUSADataset(torch.utils.data.Dataset):
    """Ground-view / aerial / segmentation-map triplets of the CVUSA subset.

    Each item is ``(ground_tensor, aerial_tensor, seg_tensor)``:
    the ground-view image (model input), the aerial image (reconstruction
    target) and the segmentation map read from ``ground_dir``.
    """

    def __init__(self, ground_dir, aerial_dir, triplet_list, img_size=224, transform=None):
        """
        Args:
            ground_dir: Directory prefix of the ground-view images and of the
                segmentation maps. Paths are built by plain string
                concatenation, so it must end with a path separator.
            aerial_dir: Directory prefix of the aerial images (same rule).
            triplet_list: Sequence of ``(aerial_rel, ground_rel, seg_rel)``
                relative paths, one per sample.
            img_size: Images are resized to ``img_size x img_size`` (default 224).
            transform (callable, optional): If given, it is used for all three
                images (ground, aerial AND segmentation map) instead of the
                default pipelines below.
        """
        self.ground_dir = ground_dir
        self.aerial_dir = aerial_dir
        self.triplet_list = triplet_list
        self.img_size = img_size

        if transform is not None:
            self.ground_transform = transform
            self.aerial_transform = transform
            self.segmentation_transform = transform
            return

        # Ground view (encoder input): standard ImageNet normalization.
        self.ground_transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Aerial image (reconstruction target): the aerial decoder ends in Tanh,
        # so the target is mapped to [-1, 1] (mean 0.5, std 0.5). ImageNet
        # statistics would give roughly [-2.1, 2.6], a range Tanh cannot reach.
        self.aerial_transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        # Segmentation map: nearest-neighbour resize so label values are not blended.
        self.segmentation_transform = transforms.Compose([
            transforms.Resize((img_size, img_size),
                              interpolation=transforms.InterpolationMode.NEAREST),
            transforms.PILToTensor(),
            # (1, H, W) -> (H, W) int64 class-index map.
            # NOTE(review): squeeze(0) only drops a size-1 channel axis; if the
            # maps are colour-coded RGB the result stays (3, H, W) — confirm format.
            transforms.Lambda(lambda x: x.squeeze(0).long()),
        ])

    def __len__(self):
        return len(self.triplet_list)

    def __getitem__(self, idx):
        aerial_rel, ground_rel, seg_rel = self.triplet_list[idx]

        # `with` closes the file handle PIL keeps open for lazy loading.
        # .convert('RGB') guarantees the 3 channels Normalize expects
        # (grayscale / RGBA / palette files would otherwise break it).
        with Image.open(self.ground_dir + ground_rel) as img:
            ground_tensor = self.ground_transform(img.convert('RGB'))
        with Image.open(self.aerial_dir + aerial_rel) as img:
            aerial_tensor = self.aerial_transform(img.convert('RGB'))
        # The segmentation map keeps its native mode so label values survive.
        with Image.open(self.ground_dir + seg_rel) as img:
            seg_tensor = self.segmentation_transform(img)

        return ground_tensor, aerial_tensor, seg_tensor

In [None]:
def read_triplets_csv(csv_path):
    """Read a split CSV file into a list of ``(aerial, ground, seg)`` path triplets.

    Each non-empty row must have at least three comma-separated columns:
    aerial image path, ground-view image path and ground-view segmentation
    map path. Extra columns are ignored and surrounding whitespace is
    stripped from every field. Blank lines (e.g. a trailing newline) are skipped.

    Raises:
        ValueError: if a non-empty row has fewer than three columns.
    """
    import csv  # local import keeps this notebook cell self-contained

    triplets = []
    with open(csv_path, 'r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        for row in reader:
            if not any(cell.strip() for cell in row):
                continue  # blank line
            if len(row) < 3:
                raise ValueError(
                    f"{csv_path}, line {reader.line_num}: expected 3 columns "
                    f"(aerial, ground, seg), got {len(row)}"
                )
            aerial, ground, seg = (cell.strip() for cell in row[:3])
            triplets.append((aerial, ground, seg))
    return triplets


# Dataset root: on Colab the data lives on the mounted Google Drive, otherwise
# under ./CV_dataset relative to the working directory. Images AND split CSVs
# are resolved from the same root so both environments read consistent data.
if is_colab():
    dataset_root = "/content/drive/MyDrive/CV_dataset/CVPR_subset/"
else:
    dataset_root = "./CV_dataset/CVPR_subset/"

ground_dir = dataset_root + "streetview/"
aerial_dir = dataset_root + "bingmap/"
splits_dir = dataset_root + "splits/splits/"

# train-19zl.csv is split 85/15 into training / validation (fixed seed for
# reproducibility); val-19zl.csv is used as the held-out test set.
train_triplets = read_triplets_csv(splits_dir + "train-19zl.csv")
train_triplets, val_triplets = train_test_split(train_triplets, test_size=0.15, random_state=19)
test_triplets = read_triplets_csv(splits_dir + "val-19zl.csv")

train_dataset = CVUSADataset(ground_dir, aerial_dir, train_triplets)
val_dataset = CVUSADataset(ground_dir, aerial_dir, val_triplets)
test_dataset = CVUSADataset(ground_dir, aerial_dir, test_triplets)

## model  

```
                                                       ---> Aerial Decoder  
                                                     /  
Ground Image --> Patch Embedding --> ViT Encoder ---  
                                                     \  
                                                       ---> Segmentation Decoder  
```

In [None]:
class GroundToAerialTransformer(nn.Module):
    """Ground-view image -> (aerial image, segmentation map) model.

    A ViT encoder (shared backbone) turns the ground-view image into a grid of
    patch tokens; two convolutional decoders upsample that grid into
      * an aerial RGB image in [-1, 1] (Tanh output), and
      * a segmentation map of raw per-class logits (use softmax(dim=1) for
        probabilities; feed the logits directly to nn.CrossEntropyLoss).

    ``forward(x)`` returns ``(aerial_output, seg_output)`` with shapes
    ``(B, 3, 16*g, 16*g)`` and ``(B, num_seg_classes, 16*g, 16*g)``, where
    ``g = image_size // patch_size`` of the ViT config (g = 14, i.e. 224x224
    outputs, for the default checkpoint).
    """

    def __init__(self, num_seg_classes=7, pretrained=True):
        """
        Args:
            num_seg_classes: Number of segmentation classes (output channels
                of the segmentation head).
            pretrained: Load pretrained ViT weights; otherwise the encoder is
                randomly initialized from the same config.
        """
        super().__init__()

        # ViT encoder (shared backbone): ViT-Base, 16x16 patches, 224x224 input.
        model_name = 'google/vit-base-patch16-224-in21k'
        self.vit_config = ViTConfig.from_pretrained(model_name)
        if pretrained:
            self.vit = ViTModel.from_pretrained(model_name)
        else:
            self.vit = ViTModel(self.vit_config)

        hidden_size = self.vit_config.hidden_size
        # Side length of the patch-token grid (224 // 16 = 14 for the default checkpoint).
        self.grid_size = self.vit_config.image_size // self.vit_config.patch_size

        # Aerial image decoder: 4 doubling stages (g -> 16g), then an RGB projection.
        self.aerial_decoder = nn.Sequential(
            *self._upsampling_layers(hidden_size, (512, 256, 128, 64)),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh(),  # output in [-1, 1]
        )

        # Segmentation head: same upsampling path, then per-class logits.
        # No trailing Softmax: nn.CrossEntropyLoss applies log-softmax itself,
        # so a Softmax here would normalize twice and squash the gradients.
        self.segmentation_head = nn.Sequential(
            *self._upsampling_layers(hidden_size, (256, 128, 64, 32)),
            nn.Conv2d(32, num_seg_classes, kernel_size=3, padding=1),
        )

        # Learnable positional embedding added to the patch tokens before decoding.
        self.aerial_pos_embed = nn.Parameter(
            torch.zeros(1, self.grid_size ** 2, hidden_size)
        )
        nn.init.trunc_normal_(self.aerial_pos_embed, std=0.02)

    @staticmethod
    def _upsampling_layers(in_channels, widths):
        """Return [ConvTranspose2d -> BatchNorm2d -> ReLU] per width; each stage doubles H and W."""
        layers = []
        for out_channels in widths:
            layers += [
                # k=4, s=2, p=1: out = (in - 1) * 2 - 2 + 4 = 2 * in (exact doubling).
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
            in_channels = out_channels
        return layers

    def forward(self, x):
        # last_hidden_state: (B, 1 + g*g, hidden) — [CLS] token followed by the patch tokens.
        hidden_states = self.vit(x).last_hidden_state

        # Drop [CLS] (only patch tokens map to spatial locations) and add the
        # learned positional embedding for the aerial layout.
        patch_tokens = hidden_states[:, 1:] + self.aerial_pos_embed

        # (B, g*g, hidden) -> (B, g, g, hidden) -> (B, hidden, g, g)
        g = self.grid_size
        feature_map = patch_tokens.reshape(patch_tokens.size(0), g, g, -1).permute(0, 3, 1, 2)

        aerial_output = self.aerial_decoder(feature_map)
        seg_output = self.segmentation_head(feature_map)
        return aerial_output, seg_output

## training

In [None]:
# Initialize the model on the selected device. `.cuda()` would raise on
# machines without CUDA; `device` already falls back to the CPU there.
# NOTE(review): 5 classes here vs. the class default of 7 — confirm against the label set.
model = GroundToAerialTransformer(num_seg_classes=5).to(device)

# Loss functions
aerial_loss_fn = nn.L1Loss()          # pixel-wise L1 between predicted and target aerial images
seg_loss_fn = nn.CrossEntropyLoss()   # applies log-softmax internally: expects raw logits

# Combined loss with weighting
def total_loss(aerial_pred, aerial_true, seg_pred, seg_true, img_weight=0.7, seg_weight=0.3):
    """Weighted sum of the aerial reconstruction loss and the segmentation loss.

    Args:
        aerial_pred: Predicted aerial images (passed to ``aerial_loss_fn``).
        aerial_true: Target aerial images.
        seg_pred: Segmentation output of the model (passed to ``seg_loss_fn``).
        seg_true: Target segmentation maps.
        img_weight: Weight of the reconstruction term (default 0.7).
        seg_weight: Weight of the segmentation term (default 0.3).

    Returns:
        ``img_weight * aerial_loss_fn(...) + seg_weight * seg_loss_fn(...)``.
    """
    img_loss = aerial_loss_fn(aerial_pred, aerial_true)   # image reconstruction loss
    seg_loss = seg_loss_fn(seg_pred, seg_true)            # segmentation loss
    return img_weight * img_loss + seg_weight * seg_loss

# AdamW with per-module learning rates: the pretrained ViT backbone is
# fine-tuned with a smaller LR than the freshly initialized decoders and the
# positional embedding.
encoder_lr, head_lr = 5e-5, 1e-4
optimizer = torch.optim.AdamW(
    [
        {'params': model.vit.parameters(), 'lr': encoder_lr},
        {'params': model.aerial_decoder.parameters(), 'lr': head_lr},
        {'params': model.segmentation_head.parameters(), 'lr': head_lr},
        {'params': [model.aerial_pos_embed], 'lr': head_lr},
    ],
    weight_decay=0.01,
)

# Cosine-annealed learning rate, stepped once per epoch by the training loop.
# NOTE(review): T_max=100 while the training loop runs 50 epochs, so the LR
# only decays to about half its initial value — confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

In [None]:
def train_epoch(model, dataloader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module mapping a ground-view batch to ``(aerial_pred, seg_pred)``.
        dataloader: Iterable of ``(ground, aerial, seg)`` batches, e.g. a
            DataLoader over ``CVUSADataset``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch ``total_loss`` values, or NaN if ``dataloader``
        yields no batches.
    """
    model.train()
    # Must not be called `total_loss`: that local would shadow the loss function.
    running_loss = 0.0
    num_batches = 0

    # The dataset yields a flat (ground, aerial, seg) triplet per sample.
    for ground, aerial, seg in dataloader:
        ground = ground.to(device)
        aerial = aerial.to(device)
        seg = seg.to(device)

        optimizer.zero_grad()                    # reset gradients from the previous batch
        aerial_pred, seg_pred = model(ground)
        loss = total_loss(aerial_pred, aerial, seg_pred, seg)
        loss.backward()                          # backpropagate
        optimizer.step()                         # update weights

        running_loss += loss.item()
        num_batches += 1

    return running_loss / num_batches if num_batches else float('nan')


def evaluate(model, dataloader, device):
    """Return the mean per-batch ``total_loss`` of ``model`` over ``dataloader``.

    The model is put in eval mode (BatchNorm uses its running statistics) and
    autograd is disabled, so no weights are updated. Returns NaN if
    ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for ground, aerial, seg in dataloader:
            aerial_pred, seg_pred = model(ground.to(device))
            loss = total_loss(aerial_pred, aerial.to(device), seg_pred, seg.to(device))
            running_loss += loss.item()
            num_batches += 1
    return running_loss / num_batches if num_batches else float('nan')


# Only the training loader needs shuffling; a fixed order keeps the
# validation / test losses reproducible from epoch to epoch.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


# Main training
num_epochs = 50
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, device)
    val_loss = evaluate(model, val_loader, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"\tTrain Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    scheduler.step()    # adjust the learning rate once per epoch

    # Save a checkpoint every 5 epochs
    if (epoch+1) % 5 == 0:
        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")