# **Transformer model**
From query ground-view images, generate satellite images (natural and segmented).

**TODO**:
- rimuovere segmentation maps delle ground view images del dataset (non servono)
- modello pre-trained su satellite images per generare segmentation maps --> aggiungere al dataset le segmentation maps delle aerial images (ground truth per la generazione di segmentation maps)

In [1]:
%pip install -q torch torchvision transformers scikit-learn pytorch_lightning

Note: you may need to restart the kernel to use updated packages.


In [1]:
def is_colab():
    """Tell whether the notebook is executing inside Google Colab.

    The check simply tries to import ``google.colab``: the import succeeding
    means Colab, an ``ImportError`` means any other environment.

    Returns:
        bool: ``True`` when ``google.colab`` can be imported, else ``False``.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        running_in_colab = False
    else:
        running_in_colab = True
    return running_in_colab

print("Running in Colab:", is_colab())

Running in Colab: False


In [2]:
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split

import torch.nn as nn
from transformers import SegformerForSemanticSegmentation

import pytorch_lightning as pylight
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

import os
import glob
from torch import amp
import time
import torch.optim as optim
from transformers import get_cosine_schedule_with_warmup

import torchvision
import matplotlib.pyplot as plt


# When running in Colab, mount Google Drive at /content/drive.
if is_colab():
    from google.colab import drive
    drive.mount('/content/drive')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## dataset

In [None]:
class CVUSADataset(torch.utils.data.Dataset):
    """Triplets of ground-view image, aerial image and segmentation map.

    Args:
        dataset_dir: Root folder against which the relative paths in
            ``triplet_list`` are resolved (with or without a trailing slash).
        triplet_list: Sequence of ``(aerial, ground, seg)`` relative paths,
            i.e. the column order produced by ``read_triplets_csv``.
        img_size: Side length in pixels that the default transforms resize to.
        transform: Optional callable used for all three images instead of the
            default transforms.

    Each item is the tuple ``(ground_tensor, aerial_tensor, seg_tensor)``.
    """

    # Number of attempts made to read one triplet from disk before giving up.
    max_retries = 3

    def __init__(self, dataset_dir, triplet_list, img_size=224, transform=None):
        self.dataset_dir = dataset_dir
        self.triplet_list = triplet_list
        self.img_size = img_size

        # Default transforms if none provided
        if transform is None:
            # For ground view images
            self.ground_transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])     # standard ImageNet normalization
            ])
            # For aerial images
            self.aerial_transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])     # standard ImageNet normalization
            ])
            # For segmentation maps: nearest-neighbour resizing keeps the
            # class ids intact (no interpolated, non-existent labels).
            self.segmentation_transform = transforms.Compose([
                transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST),
                transforms.PILToTensor(),
                # A named function (rather than a lambda) keeps the dataset
                # picklable for DataLoader worker processes.
                transforms.Lambda(self._mask_to_long)
            ])
        else:
            self.ground_transform = transform
            self.aerial_transform = transform
            self.segmentation_transform = transform

    @staticmethod
    def _mask_to_long(mask):
        """Drop a leading singleton channel and cast to int64.

        NOTE(review): assumes single-channel label images; an RGB annotation
        would keep its 3 channels here -- confirm against the dataset files.
        """
        return mask.squeeze(0).long()

    def _resolve_path(self, rel_path):
        """Join ``rel_path`` onto ``dataset_dir`` regardless of slashes."""
        # Stripping leading separators stops os.path.join from discarding
        # dataset_dir when the CSV stores paths such as "/streetview/...".
        return os.path.join(self.dataset_dir, rel_path.lstrip('/\\'))

    def _open_image(self, rel_path, mode=None):
        """Load an image fully into memory and close the underlying file."""
        with Image.open(self._resolve_path(rel_path)) as img:
            # convert()/copy() force the pixel data to be read before the
            # file handle is released by the context manager.
            return img.convert(mode) if mode is not None else img.copy()

    def __len__(self):
        return len(self.triplet_list)

    def __getitem__(self, idx):
        # Triplets are stored as (aerial, ground, seg); unpacking them in any
        # other order silently swaps the model's input and its target.
        aerial_rel, ground_rel, seg_rel = self.triplet_list[idx]

        for attempt in range(1, self.max_retries + 1):
            try:
                # Load images
                ground_img = self._open_image(ground_rel, mode='RGB')
                aerial_img = self._open_image(aerial_rel, mode='RGB')
                seg_map = self._open_image(seg_rel)

                # Break if successful
                break
            except OSError as e:
                # Covers missing files and unreadable/corrupt images.
                print(f"Error loading images for index {idx}: {e}")
                if attempt == self.max_retries:
                    raise

        # Apply transforms
        ground_tensor = self.ground_transform(ground_img)
        aerial_tensor = self.aerial_transform(aerial_img)
        seg_tensor = self.segmentation_transform(seg_map)

        return ground_tensor, aerial_tensor, seg_tensor

In [25]:
def read_triplets_csv(csv_path):
    """Read a CSV file into a list of (aerial, ground, seg) path triplets.

    Every non-blank line must hold at least three comma-separated fields:
    aerial image path, ground image path, segmentation map path. Surrounding
    whitespace is stripped from each field; blank lines are skipped.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        list[tuple[str, str, str]]: One ``(aerial, ground, seg)`` tuple per
        non-blank line, in file order.

    Raises:
        IndexError: If a non-blank line has fewer than three fields.
    """
    triplets = []
    with open(csv_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # An empty line would otherwise fail on parts[1] below.
                continue
            parts = line.split(',')
            triplets.append((
                parts[0].strip(),  # aerial path
                parts[1].strip(),  # ground path
                parts[2].strip()   # seg path (ground view segmented map)
            ))
    return triplets


# Root folder of the CVUSA subset. Both variants end with a separator so that
# relative image paths can be appended to it directly.
if is_colab():
    dataset_dir = "/content/drive/MyDrive/CVUSA_subset/"
else:
    dataset_dir = "./CVUSA_subset/"

# The CSV files live in the dataset root, so derive their paths from it
# instead of repeating the Colab/local branches.
train_triplets = read_triplets_csv(os.path.join(dataset_dir, "train.csv"))
train_triplets, val_triplets = train_test_split(train_triplets, test_size=0.15, random_state=19)  # training/validation set
test_triplets = read_triplets_csv(os.path.join(dataset_dir, "val.csv"))        # test set

train_dataset = CVUSADataset(dataset_dir, train_triplets)
val_dataset = CVUSADataset(dataset_dir, val_triplets)
test_dataset = CVUSADataset(dataset_dir, test_triplets)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# Validation is evaluated in a fixed order: shuffling brings no benefit and
# makes per-batch results harder to compare between epochs.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## model  

```
                                                             ---> Aerial Decoder  
                                                           /  
Ground Image --> Patch Embedding --> SegFormer Encoder ---  
                                                           \  
                                                             ---> Segmentation Decoder  
```

In [None]:
class DualTaskSegFormer(nn.Module):
    """SegFormer encoder/decoder shared by two prediction heads.

    A ground-view image goes through the pretrained SegFormer encoder and its
    MLP decode head; the decoded features feed
      * ``aerial_head`` -- a 3-channel image squashed to [-1, 1] by ``Tanh``;
      * ``seg_head``    -- ``num_classes`` segmentation logits.

    Args:
        pretrained_model: Hugging Face model id or path for the SegFormer
            weights.
        num_classes: Number of segmentation classes.

    ``forward`` returns ``(aerial_output, seg_logits)``, both resized to the
    spatial size of the input so they can be compared with full-resolution
    targets.

    NOTE(review): ``Tanh`` bounds the aerial output to [-1, 1]; targets that
    are normalised with ImageNet mean/std can fall outside that range --
    confirm the intended target normalisation.
    """

    def __init__(self, pretrained_model="nvidia/mit-b1", num_classes=6):
        super().__init__()
        # Load pretrained SegFormer
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_model,
            num_labels=num_classes,
            return_dict=False
        )
        # Remove original classification head so the decode head returns its
        # fused feature map rather than class logits.
        self.segformer.decode_head.classifier = nn.Identity()

        # Get decoder hidden size
        decoder_hidden_size = self.segformer.config.decoder_hidden_size

        # dual-task heads
        self.aerial_head = nn.Sequential(
            nn.Conv2d(decoder_hidden_size, 3, kernel_size=1),
            nn.Tanh()  # Output in [-1,1] range
        )

        self.seg_head = nn.Conv2d(decoder_hidden_size, num_classes, kernel_size=1)

    def forward(self, pixel_values):
        """Predict the aerial image and segmentation logits for a batch.

        Args:
            pixel_values: Batch of input images, ``(batch, channels, H, W)``.

        Returns:
            tuple: ``(aerial_output, seg_logits)`` with spatial size ``(H, W)``.
        """
        # Encoder processing. The decode head consumes the feature map of
        # every encoder stage, so the hidden states must be requested
        # explicitly (the last hidden state alone is not enough).
        encoder_outputs = self.segformer.segformer(
            pixel_values,
            output_hidden_states=True,
            return_dict=True,
        )

        # Decoder processing
        decoder_output = self.segformer.decode_head(encoder_outputs.hidden_states)

        # Dual-task outputs
        aerial_output = self.aerial_head(decoder_output)
        seg_logits = self.seg_head(decoder_output)

        # The decode head works at a reduced resolution (1/4 of the input in
        # the Hugging Face implementation); upsample so the outputs line up
        # pixel-for-pixel with the targets used by the losses.
        target_size = pixel_values.shape[-2:]
        aerial_output = nn.functional.interpolate(
            aerial_output, size=target_size, mode="bilinear", align_corners=False
        )
        seg_logits = nn.functional.interpolate(
            seg_logits, size=target_size, mode="bilinear", align_corners=False
        )

        return aerial_output, seg_logits

## lightning wrapper

In [None]:
class LightningWrapper(pylight.LightningModule):
  """PyTorch Lightning module that trains a dual-task model.

  The wrapped model maps a ground-view batch to ``(aerial_pred, seg_pred)``.
  The loss is ``CrossEntropy(seg) + AERIAL_LOSS_WEIGHT * L1(aerial)``, logged
  as ``train_loss`` / ``val_loss`` / ``test_loss``.

  Args:
    device: Device that batches are moved to before the forward pass.
    model: Network to train. When ``None`` a fresh ``DualTaskSegFormer()`` is
      created.
  """

  # Weight of the aerial reconstruction term; < 1 favours segmentation accuracy.
  AERIAL_LOSS_WEIGHT = 0.7

  def __init__(self, device=device, model=None):
    super().__init__()
    self.dvc=device
    # Build the default network here rather than in the signature: a default
    # argument is evaluated once, when the class is defined, which would load
    # pretrained weights at definition time and share one instance between
    # every wrapper created without an explicit model.
    self.model = DualTaskSegFormer() if model is None else model
    self.criterion_aerial=nn.L1Loss()
    # Pixels labelled -1 are excluded from the segmentation loss.
    self.criterion_segmap=nn.CrossEntropyLoss(ignore_index=-1)

  def forward(self, x):
    return self.model(x)

  def _shared_step(self, batch, stage):
    """Run one forward pass, log ``{stage}_loss`` and return the loss."""
    ground, aerial, seg = batch
    ground = ground.to(self.dvc)
    aerial = aerial.to(self.dvc)
    seg = seg.to(self.dvc)

    # Forward pass
    aerial_pred, seg_pred = self.model(ground)

    # Compute loss
    loss_seg = self.criterion_segmap(seg_pred, seg)
    loss_aerial = self.criterion_aerial(aerial_pred, aerial)
    loss = loss_seg + self.AERIAL_LOSS_WEIGHT * loss_aerial  # Weighted sum, favor segmentation accuracy
    self.log(f"{stage}_loss", loss, prog_bar=True)

    return loss

  def training_step(self, batch, batch_idx):
    return self._shared_step(batch, "train")

  def validation_step(self, batch, batch_idx):
    return {"val_loss": self._shared_step(batch, "val")}

  def test_step(self, batch, batch_idx):
    return {"test_loss": self._shared_step(batch, "test")}

  def configure_optimizers(self):
    # The pretrained SegFormer gets a small learning rate; the freshly
    # initialised heads get a larger one.
    optimizer = optim.AdamW([
        {'params': self.model.segformer.parameters(), 'lr': 5e-5},
        {'params': self.model.seg_head.parameters(), 'lr': 1e-3},
        {'params': self.model.aerial_head.parameters(), 'lr': 1e-3}
    ], weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    return {"optimizer":optimizer, "lr_scheduler":scheduler}
  


## Training + Testing (with lightning)

### training

In [None]:
model = LightningWrapper(device, model=DualTaskSegFormer()).to(device)

if is_colab():
    LOG_DIR = "/content/drive/MyDrive/transformer_logs/"
else:
    LOG_DIR = "./transformer_logs/"

ckpt_path = LOG_DIR + "checkpoints/"

EPOCHS = 20
start_from_epoch = None    # None to start from scratch, or specify epoch number to resume

checkpoint_callback = ModelCheckpoint(
    dirpath=ckpt_path,
    filename="best-checkpoint-{epoch:02d}-{val_loss:-2f}",
    save_top_k=5,
    verbose=True,
    monitor="val_loss",
    mode="min",
    save_last=True,
    every_n_epochs=2
)

if start_from_epoch is not None:
    # Trainer.fit needs a concrete checkpoint file, not a wildcard pattern, so
    # resolve it here. Two spellings are tried because Lightning may prefix
    # the formatted value with the metric name ("epoch=07") in the filename.
    candidates = []
    for prefix in ("best-checkpoint-epoch=", "best-checkpoint-"):
        candidates = sorted(glob.glob(f"{ckpt_path}{prefix}{start_from_epoch:02d}-*.ckpt"))
        if candidates:
            break
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint for epoch {start_from_epoch} found in {ckpt_path}"
        )
    start_from_path = candidates[-1]
else:
    trainer0 = Trainer(
        enable_checkpointing=False,
        enable_progress_bar=True,
        max_epochs=2
    )
    # `model` is the Lightning wrapper; the network itself is `model.model`
    # and its SegFormer encoder is `model.model.segformer.segformer`.
    encoder = model.model.segformer.segformer

    # Freeze encoder for first 2 epochs
    for param in encoder.parameters():
        param.requires_grad = False

    trainer0.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=None)

    # Unfreeze encoder
    for param in encoder.parameters():
        param.requires_grad = True

    start_from_path = None


trainer = Trainer(
    enable_checkpointing=True,
    default_root_dir=LOG_DIR,
    callbacks=[checkpoint_callback],
    enable_progress_bar=True,
    max_epochs=EPOCHS
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=start_from_path)

### testing

In [None]:
model.model.eval()

# ImageNet statistics applied by the dataset's default transforms; they are
# needed to undo the normalisation before an image can be displayed.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def to_displayable_rgb(image):
    """Turn one normalised (3, H, W) tensor into an (H, W, 3) array in [0, 1]."""
    image = image.detach().cpu().float() * imagenet_std + imagenet_mean
    return image.clamp(0, 1).permute(1, 2, 0).numpy()


PLOT_EVERY = 443  # only every 443rd test sample is visualised

with torch.no_grad():
    for i, (ground, aerial, seg) in enumerate(test_loader):
        if i % PLOT_EVERY != 0:
            # Nothing is accumulated for skipped samples, so skip the
            # forward pass as well.
            continue

        ground = ground.to(device)
        aerial_pred, seg_pred = model(ground)

        # Batches carry a leading batch dimension; show the first sample.
        ground_img = to_displayable_rgb(ground[0])
        aerial_img = to_displayable_rgb(aerial[0])
        aerial_pred_img = to_displayable_rgb(aerial_pred[0])
        seg_map = seg[0].cpu().numpy()
        # Class logits -> predicted class id per pixel.
        seg_pred_map = seg_pred.argmax(dim=1)[0].cpu().numpy()
        max_class = seg_pred.shape[1] - 1  # same colour scale for both maps

        plt.figure(figsize=(15, 10))

        plt.subplot(2, 3, 1)
        plt.title('Ground view')
        plt.imshow(ground_img)
        plt.axis('off')

        plt.subplot(2, 3, 2)
        plt.title('Aerial view')
        plt.imshow(aerial_img)
        plt.axis('off')

        plt.subplot(2, 3, 3)
        plt.title('Aerial segmap')
        plt.imshow(seg_map, vmin=0, vmax=max_class)
        plt.axis('off')

        plt.subplot(2, 3, 4)
        plt.title('Aerial view prediction')
        plt.imshow(aerial_pred_img)
        plt.axis('off')

        plt.subplot(2, 3, 5)
        plt.title('Aerial segmap prediction')
        plt.imshow(seg_pred_map, vmin=0, vmax=max_class)
        plt.axis('off')

        plt.show()

## Training + Testing (no lightning)

### training

In [27]:
# Number of epochs the training loop runs up to.
EPOCHS = 20
# Folder where the per-epoch checkpoints are written and looked up.
if is_colab():
    CHECKPOINT_DIR = "/content/drive/MyDrive/transformer_checkpoints/"
else:
    CHECKPOINT_DIR = "./transformer_checkpoints/"

SAVE_INTERVAL = 1   # Save checkpoint every SAVE_INTERVAL epochs

RESUME = True  # Set True to resume from latest checkpoint

# Create checkpoint directory if needed
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


# Initialize model
model = DualTaskSegFormer().to(device)
# Separate parameter groups: a small learning rate for the pretrained
# SegFormer, a larger one for the two task heads.
optimizer = optim.AdamW([
    {'params': model.segformer.parameters(), 'lr': 5e-5},
    {'params': model.seg_head.parameters(), 'lr': 1e-3},
    {'params': model.aerial_head.parameters(), 'lr': 1e-3}
], weight_decay=1e-4)

# Learning rate scheduler: cosine schedule with warmup, sized in optimizer
# steps (batches per epoch * epochs).
total_steps = len(train_loader) * EPOCHS
scheduler = get_cosine_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=100,
    num_training_steps=total_steps
)

# Loss functions
seg_loss = nn.CrossEntropyLoss(ignore_index=-1)  # pixels labelled -1 are ignored
aerial_loss = nn.L1Loss()


# ******************************************************************************************* Checkpoint loading function
def load_latest_checkpoint():
    """Load the checkpoint with the highest epoch number in CHECKPOINT_DIR.

    Files are expected to be named ``checkpoint_epoch_<N>.pt``; files whose
    ``<N>`` part is not an integer are ignored.

    Returns:
        tuple: ``(epoch, checkpoint)`` where ``epoch`` is the number parsed
        from the filename and ``checkpoint`` is the object returned by
        ``torch.load``. ``(0, None)`` when no usable checkpoint exists.
    """
    checkpoint_files = glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint_epoch_*.pt"))

    if not checkpoint_files:
        print("No checkpoints found. Starting from scratch.")
        return 0, None

    # Map each epoch number to the file it came from, so the file that was
    # actually found is the one loaded (a zero-padded name such as
    # "checkpoint_epoch_05.pt" cannot be rebuilt from the integer 5).
    files_by_epoch = {}
    for path in checkpoint_files:
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            files_by_epoch[int(stem.rsplit("_", 1)[-1])] = path
        except ValueError:
            # Not a numbered checkpoint (e.g. "checkpoint_epoch_best.pt").
            continue

    if not files_by_epoch:
        print("No valid checkpoints found. Starting from scratch.")
        return 0, None

    latest_epoch = max(files_by_epoch)
    latest_file = files_by_epoch[latest_epoch]

    print(f"Loading checkpoint: {latest_file}")
    # map_location lets a checkpoint saved on a GPU be resumed on a machine
    # without one (and vice versa).
    checkpoint = torch.load(latest_file, map_location=device)

    return latest_epoch, checkpoint
# *******************************************************************************************


# Mixed precision scaler. GradScaler takes the device type as a string
# ("cuda" / "cpu"), not a torch.device. It is created before resuming so a
# saved scaler state can be restored into it.
scaler = amp.GradScaler(device.type)

# Load checkpoint if resuming
START_EPOCH = 0
if RESUME:
    latest_epoch, checkpoint = load_latest_checkpoint()

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Optional key: only restored when the checkpoint carries a non-empty
        # scaler state.
        if checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        # Checkpoint files are numbered by completed epochs, so the number is
        # also the 0-based index of the next epoch to run.
        START_EPOCH = latest_epoch
        print(f"Resuming training from epoch {START_EPOCH + 1}")

    else:
        print("Starting training from scratch")
else:
    print("Starting training from scratch (RESUME=False)")

config.json: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


pytorch_model.bin:   0%|          | 0.00/14.4M [00:00<?, ?B/s]

ValueError: Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply when loading files with safetensors.
See the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434

model.safetensors:   0%|          | 0.00/14.3M [00:00<?, ?B/s]

In [None]:
# Training loop: one pass over train_loader per epoch (mixed precision),
# followed by validation and an optional checkpoint.
for epoch in range(START_EPOCH, EPOCHS):
    start_time = time.time()

    model.train()
    epoch_loss = 0.0

    # Freeze encoder for first 2 epochs if starting from scratch
    freeze_encoder = epoch < 2 and START_EPOCH == 0
    for param in model.segformer.segformer.parameters():
        param.requires_grad = not freeze_encoder

    for i, (ground, aerial_true, seg_true) in enumerate(train_loader):
        ground = ground.to(device)
        aerial_true = aerial_true.to(device)
        seg_true = seg_true.to(device)

        optimizer.zero_grad()

        # autocast expects the device type as a string ("cuda" / "cpu");
        # passing a torch.device object is rejected.
        with amp.autocast(device_type=device.type):
            aerial_pred, seg_pred = model(ground)

            # Compute losses
            loss_seg = seg_loss(seg_pred, seg_true)
            loss_aerial = aerial_loss(aerial_pred, aerial_true)
            total_loss = loss_seg + 0.7 * loss_aerial  # Weighted sum, favor segmentation accuracy

        # Backpropagation
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # the schedule is defined per optimizer step

        epoch_loss += total_loss.item()

        # Log every 50 batches
        if i % 50 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch [{epoch+1}/{EPOCHS}] | Batch [{i}/{len(train_loader)}] | "
                  f"Loss: {total_loss.item():.4f} | LR: {current_lr:.2e}")

    # Calculate epoch metrics
    avg_loss = epoch_loss / len(train_loader)
    epoch_time = time.time() - start_time

    # --------------------------------------------------------------------------- Validation
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for ground, aerial_true, seg_true in val_loader:
            ground = ground.to(device)
            aerial_true = aerial_true.to(device)
            seg_true = seg_true.to(device)

            aerial_pred, seg_pred = model(ground)
            loss_seg = seg_loss(seg_pred, seg_true)
            loss_aerial = aerial_loss(aerial_pred, aerial_true)
            total_val_loss = loss_seg + 0.7 * loss_aerial
            val_loss += total_val_loss.item()

        avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch [{epoch+1}/{EPOCHS}] completed in {epoch_time:.2f}s | "
          f"Train Loss: {avg_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # --------------------------------------------------------------------------- Save checkpoint
    if (epoch + 1) % SAVE_INTERVAL == 0:
        # The filename carries the number of completed epochs (1-based),
        # while the stored 'epoch' field is the 0-based loop index.
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # Saved so the loss scale can be restored on resume.
            'scaler_state_dict': scaler.state_dict(),
            'train_loss': avg_loss,
            'val_loss': avg_val_loss
        }, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

print("Training completed!")

### testing

In [None]:
model.eval()

# ImageNet statistics applied by the dataset's default transforms; they are
# needed to undo the normalisation before an image can be displayed.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def to_displayable_rgb(image):
    """Turn one normalised (3, H, W) tensor into an (H, W, 3) array in [0, 1]."""
    image = image.detach().cpu().float() * imagenet_std + imagenet_mean
    return image.clamp(0, 1).permute(1, 2, 0).numpy()


with torch.no_grad():
    # Each batch from the loader is (ground, aerial, seg); the segmentation
    # map is not needed for this preview.
    for ground, aerial, _ in test_loader:
        ground = ground.to(device)
        # The model returns (aerial_pred, seg_logits); only the image is shown.
        aerial_pred, _ = model(ground)

        # Batches carry a leading batch dimension; show the first sample.
        aerial_img = to_displayable_rgb(aerial[0])
        aerial_pred_img = to_displayable_rgb(aerial_pred[0])

        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title('Satellite Image RGB')
        plt.imshow(aerial_img)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(aerial_pred_img)
        plt.title('Satellite prediction')
        plt.axis('off')

        plt.show()
        break  # preview the first test sample only

In [None]:
# `ground` is the batch left over from the previous cell. to_pil_image needs a
# single 2-D/3-D image, so take the first sample and undo the ImageNet
# normalisation applied by the dataset's default transforms.
# NOTE(review): assumes `ground` is a (batch, 3, H, W) normalised tensor.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
g = (ground[0].detach().cpu().float() * _std + _mean).clamp(0, 1)
g = torchvision.transforms.functional.to_pil_image(g, mode=None)
plt.imshow(g)