# SET-UP

In [None]:
import os

# IF THE DIRECTORY ALREADY EXISTS (ALREADY IMPORTED)
if not os.path.exists('/kaggle/working/taming-transformers'):
    print("Cloning taming-transformers repository...")
    !git clone https://github.com/CompVis/taming-transformers.git
    print("Repository cloned successfully!")
else:
    print("Repository already exists, skipping clone.")

# MOVE TO THE DIRECTORY (KAGGLE STARTS FROM THE ROOT DIRECTORY OF WORKING)
%cd taming-transformers



# 1. Downgrade PyTorch alle versioni compatibili
!pip install torch # ==1.7.0 
!pip install torchvision # ==0.8.1

# 2. Install PyTorch Lightning versione specifica
!pip install pytorch-lightning # ==1.0.8

# 3. Install altre dependencies con versioni esatte
!pip install albumentations # ==0.4.3
!pip install opencv-python #==4.1.2.30
!pip install omegaconf # ==2.0.0
!pip install einops # ==0.3.0
!pip install transformers==4.3.1

!pip install imageio # ==2.9.0
!pip install streamlit # >=0.73.1

#INFERED
!pip install tensorboard # ==2.2.0

# 4. Install il package
!pip install -e .

print ("ready final")

# VQGAN CREATION

In [None]:
# ------------------------------------------------------ 
#                   IMPORTS & SET-UP
# ------------------------------------------------------ 

import taming                     # Import Taming Transformers (for Image Generation)
import torch                      # PyTorch
import torchvision                # Pytorch for CV
import os

# FIX ISSUES -------------------------------------------------------
# `torch._six` no longer exists in recent PyTorch releases, but taming's data utils import
# `string_classes` from it. Patch the source file in place to alias it to `str`.
utils_file = '/kaggle/working/taming-transformers/taming/data/utils.py'

# Read the file
with open(utils_file, 'r') as f:
    content = f.read()

# Replace the problematic import, and only rewrite the file when something actually changed
# (previously "Fixed" was printed even when the import was not found).
old_import = 'from torch._six import string_classes'
if old_import in content:
    content = content.replace(old_import, 'string_classes = str')
    with open(utils_file, 'w') as f:
        f.write(content)
    print("Fixed torch._six import issue!")
else:
    print("torch._six import not found (already patched?), file left unchanged.")

# ----------------------------------------------------------------------

import urllib.request

# Create models directory
os.makedirs('/kaggle/working/models', exist_ok=True)

# Download VQGAN ImageNet f16 model (config + checkpoint)
model_urls = {
    'vqgan_imagenet_f16_16384.yaml': 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1',
    'vqgan_imagenet_f16_16384.ckpt': 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1'
}

for filename, url in model_urls.items():
    filepath = f'/kaggle/working/models/{filename}'
    if not os.path.exists(filepath):
        print(f"Downloading {filename}...")
        # Download to a temporary name and rename only after the transfer completed: an
        # interrupted download would otherwise leave a truncated file that the existence
        # check above would accept on every later run.
        partial_path = filepath + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, filepath)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(f"Downloaded {filename}")
    else:
        print(f"{filename} already exists")

# ----------------------------------------------------------------------

import yaml
from taming.models.vqgan import VQModel

def load_vqgan_model(config_path, checkpoint_path):
    """Build a ``VQModel`` from a taming YAML config and load checkpoint weights into it.

    Args:
        config_path: YAML file whose ``model.params`` mapping is passed as keyword
            arguments to ``VQModel``.
        checkpoint_path: checkpoint file whose ``state_dict`` entry holds the weights.

    Returns:
        The model, switched to eval mode.
    """
    with open(config_path, 'r') as config_file:
        model_params = yaml.safe_load(config_file)['model']['params']

    vqgan = VQModel(**model_params)

    # weights_only=False unpickles arbitrary objects; only use trusted checkpoints.
    state = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    # strict=False: keys that do not match the model are ignored silently.
    vqgan.load_state_dict(state['state_dict'], strict=False)

    vqgan.eval()
    return vqgan

# Load the VQ-GAN from the files downloaded above
config_path, checkpoint_path = (
    '/kaggle/working/models/vqgan_imagenet_f16_16384.yaml',
    '/kaggle/working/models/vqgan_imagenet_f16_16384.ckpt',
)

vqgan_model = load_vqgan_model(config_path, checkpoint_path)
print("Model loaded!")

# CONFIGURATION

In [None]:
from omegaconf import OmegaConf

# Complete configuration: VQ-GAN first stage (also reused as the conditioning stage)
# plus the GPT transformer that models the token sequence.
config = OmegaConf.create(dict(
    first_stage_config=dict(
        target="taming.models.vqgan.VQModel",
        params=dict(
            ckpt_path="/kaggle/working/models/vqgan_imagenet_f16_16384.ckpt",
            embed_dim=256,
            n_embed=16384,
            ddconfig=dict(
                double_z=False,
                z_channels=256,
                resolution=256,
                in_channels=3,
                out_ch=3,
                ch=128,
                ch_mult=[1, 1, 2, 2, 4],
                num_res_blocks=2,
                attn_resolutions=[16],
                dropout=0.0,
            ),
            lossconfig=dict(target="taming.modules.losses.DummyLoss"),
        ),
    ),
    # Use the same VQ-GAN for both images (ground conditioning and satellite target)
    cond_stage_config="__is_first_stage__",
    transformer_config=dict(
        target="taming.modules.transformer.mingpt.GPT",
        params=dict(
            vocab_size=16384,   # must equal the VQ-GAN codebook size (n_embed)
            block_size=512,     # sequence length: 16x16 = 256 tokens per image, x2 images
            n_layer=12,
            n_head=8,
            n_embd=512,
            # Dropout for regularization
            embd_pdrop=0.2,     # on embeddings
            resid_pdrop=0.2,    # on residual connections
            attn_pdrop=0.2,     # on attention weights
        ),
    ),
))

print("Configuration created!")
print(f"VQ-GAN codebook size: {config.first_stage_config.params.n_embed}")
print(f"Transformer vocab size: {config.transformer_config.params.vocab_size}")

# Pick the compute device once; later cells use this global
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# TOP K

In [None]:
import torch
import torch.nn.functional as F
from torch import Tensor

def top_k_top_p_filtering(
    logits: Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Filtering works along the last dimension, so ``logits`` may have any number of leading
    (batch) dimensions, including none. Filtered entries are overwritten *in place* with
    ``filter_value`` and the same tensor is returned.

    Args:
        logits: logits distribution of shape (..., vocabulary size),
            e.g. (batch size, vocabulary size).
        top_k: if > 0, keep only tokens whose logit is at least the k-th largest
            (top-k filtering).
        top_p: if < 1.0, keep the top tokens until their cumulative probability reaches
            ``top_p`` (nucleus filtering, Holtzman et al., http://arxiv.org/abs/1904.09751).
        filter_value: value written into the filtered positions.
        min_tokens_to_keep: keep at least this many tokens per distribution.

    Returns:
        ``logits`` with the filtered positions set to ``filter_value``.

    Adapted from: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a logit below the k-th largest one
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens once the cumulative probability exceeds the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Never remove the first min_tokens_to_keep tokens
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift right so the first token that crosses the threshold is kept as well
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask from sorted order back to vocabulary order. Scatter along the last
        # dimension: a hard-coded dim=1 only works for 2-D (batch, vocab) inputs.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

# Expose the implementation above as `transformers.top_k_top_p_filtering`.
# NOTE(review): presumably needed because taming code imports this helper from
# transformers while the installed release does not provide it — confirm.
import transformers
setattr(transformers, "top_k_top_p_filtering", top_k_top_p_filtering)

print("✅ Real top_k_top_p_filtering function added to transformers!")

# VQGAN - LOAD

In [None]:
from taming.models.cond_transformer import Net2NetTransformer


import copy

# Load the VQ-GAN checkpoint ourselves with explicit weights_only=False (it may contain
# non-tensor objects; only do this for trusted files). Map it to CPU: the model built below
# starts on CPU and is moved to `device` in a later cell.
vqgan_state = torch.load("/kaggle/working/models/vqgan_imagenet_f16_16384.ckpt",
                         map_location="cpu",
                         weights_only=False)

# Disable the automatic checkpoint load (Net2NetTransformer would otherwise try to load
# ckpt_path itself) on a *deep* copy, so the shared `config` keeps its ckpt_path.
# `config.copy()` is shallow and may share nested nodes with `config`.
config_no_ckpt = copy.deepcopy(config)
config_no_ckpt.first_stage_config.params.ckpt_path = None

# Create model without loading checkpoint
model = Net2NetTransformer(
    transformer_config=config.transformer_config,
    first_stage_config=config_no_ckpt.first_stage_config,
    cond_stage_config=config.cond_stage_config,
    first_stage_key="satellite",
    cond_stage_key="ground",
    unconditional=False
)

# Manually load the VQ-GAN weights. strict=False ignores checkpoint keys the VQModel does
# not have (and model keys missing from the checkpoint).
model.first_stage_model.load_state_dict(vqgan_state["state_dict"], strict=False)

print("Model created and VQ-GAN weights loaded successfully!")

# FREEZE

In [None]:
# Freeze the VQ-GAN so that only the transformer is trained
model.first_stage_model.requires_grad_(False)

# cond_stage_model is the same module as first_stage_model ("__is_first_stage__"),
# so it is frozen as well.

# Sanity check: count the parameters that still require gradients
vqgan_params, transformer_params = (
    sum(p.numel() for p in module.parameters() if p.requires_grad)
    for module in (model.first_stage_model, model.transformer)
)

print("After freezing VQ-GAN:")
print(f"Trainable VQ-GAN parameters: {vqgan_params}")
print(f"Trainable Transformer parameters: {transformer_params}")
print(f"Total trainable parameters: {vqgan_params + transformer_params}")

# CSV PATH FIXING (TRAIN / VALIDATION SPLITS)

In [None]:
import csv
import os
from pathlib import Path

def check_csv_structure(csv_path):
    """
    Check the structure of the CSV file and print sample rows (no pandas).

    The first row is treated as a header unless one of its cells contains '.png',
    in which case the file is assumed to have no header and every row is data.

    Args:
        csv_path: path of the CSV file to inspect.

    Returns:
        (rows, has_proper_headers): up to the first five raw rows and whether the first
        row is a header; (None, None) if the file is missing or empty.
    """
    print(f"📁 Checking CSV structure: {csv_path}")

    # Check if file exists
    if not os.path.exists(csv_path):
        print(f"❌ File not found: {csv_path}")
        return None, None

    # Read first few lines to understand structure
    with open(csv_path, 'r', newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        rows = []
        for i, row in enumerate(reader):
            rows.append(row)
            if i >= 4:  # Read first 5 rows
                break

    if not rows:
        print(f"❌ Empty CSV file")
        return None, None

    first_row = rows[0]
    second_row = rows[1] if len(rows) > 1 else None

    print(f"🔍 Raw file inspection:")
    print(f"  First row:  {','.join(first_row)}")
    if second_row:
        print(f"  Second row: {','.join(second_row)}")

    # Check if first row looks like data (contains .png) rather than headers
    has_proper_headers = not any('.png' in cell for cell in first_row)

    # Count total rows (second pass over the whole file)
    with open(csv_path, 'r', newline='', encoding='utf-8') as file:
        total_rows = sum(1 for _ in csv.reader(file))

    if has_proper_headers:
        print(f"✅ File has proper headers")
        data_rows = rows[1:]  # Skip header
        headers = first_row
        actual_data_count = total_rows - 1
    else:
        print(f"⚠️  File has NO proper headers - first row is data!")
        print(f"📝 Using custom column names...")
        data_rows = rows  # All rows are data
        headers = ['satellite_path', 'ground_path', 'duplicate_ground_path']
        actual_data_count = total_rows

    print(f"\n📊 CSV Info:")
    print(f"  - Total rows: {total_rows}")
    print(f"  - Data rows: {actual_data_count}")
    print(f"  - Columns: {len(headers)} -> {headers}")

    print(f"\n📋 First 5 data rows:")
    for i, row in enumerate(data_rows[:5]):
        print(f"  Row {i+1}: {row}")

    print(f"\n🔍 Sample paths analysis:")
    if data_rows:
        sample_row = data_rows[0]
        for i, (header, value) in enumerate(zip(headers, sample_row)):
            print(f"  Column {i+1} ({header}): {value}")

    return rows, has_proper_headers


def _fix_ground_path(ground_path):
    """Remove 'input' from the filename of *ground_path* and turn a '.png' suffix into '.jpg'.

    Only the last '/'-separated component is edited, so directory names containing
    'input' are left untouched (streetview/input0026840.png -> streetview/0026840.jpg).
    """
    directory, separator, filename = ground_path.rpartition('/')
    filename = filename.replace('input', '')
    if filename.endswith('.png'):
        filename = filename[:-len('.png')] + '.jpg'
    return f"{directory}{separator}{filename}"


def _fix_polarmap_path(path):
    """Return 'polarmap/normal/<filename of *path*>'."""
    return f"polarmap/normal/{Path(path).name}"


def fix_csv_paths(input_csv_path, output_csv_path=None):
    """
    Fix the CSV file paths according to the requirements (no pandas):
    - Column 2: remove 'input' from the filename and change a .png extension to .jpg
      (streetview/input0026840.png -> streetview/0026840.jpg)
    - Column 3: change to polarmap/normal/<original filename>
      (e.g. polarmap/normal/input0026840.png)

    Column 1 and any columns after the third are copied unchanged; rows with fewer than
    three columns are copied as-is (with a warning).

    Args:
        input_csv_path: CSV file to read.
        output_csv_path: where to write the fixed CSV; defaults to
            /kaggle/working/<input stem>_fixed<input suffix>.

    Returns:
        (fixed_rows, headers), or None if the input file is missing or empty.
    """
    # Check structure first (this only reads first 5 rows for inspection)
    preview_rows, has_proper_headers = check_csv_structure(input_csv_path)
    if preview_rows is None:
        return None

    # Now read ALL rows from the file
    print(f"🔄 Reading complete file: {input_csv_path}")
    with open(input_csv_path, 'r', newline='', encoding='utf-8') as file:
        all_rows = list(csv.reader(file))

    print(f"📊 Complete file loaded: {len(all_rows)} total rows")

    print(f"\n🔧 Applying transformations...")

    # Determine data start and headers using the complete dataset
    if has_proper_headers:
        headers = all_rows[0]
        data_rows = all_rows[1:]
        print(f"  - Using existing headers: {headers}")
    else:
        headers = ['satellite_path', 'ground_path', 'polarmap_path']
        data_rows = all_rows  # All rows are data
        print(f"  - Using custom headers: {headers}")

    print(f"  - Processing {len(data_rows)} data rows...")
    print(f"  - Fixing column 2: removing 'input' from filenames AND changing .png to .jpg")
    print(f"  - Fixing column 3: changing to polarmap/normal/input{{ID}}.png")

    fixed_rows = []
    for row_idx, row in enumerate(data_rows):
        if len(row) < 3:
            print(f"⚠️  Row {row_idx + 1} has fewer than 3 columns: {row}")
            fixed_rows.append(row)  # Keep as-is
            continue

        satellite_path, ground_path, third_path = row[:3]
        fixed_rows.append([
            satellite_path,                   # kept as-is
            _fix_ground_path(ground_path),
            _fix_polarmap_path(third_path),
            *row[3:],                         # any additional columns
        ])

    # Show before/after comparison
    if data_rows:
        print(f"\n📊 Transformation Results:")
        print(f"Original sample row:")
        for i, val in enumerate(data_rows[0][:3]):
            print(f"  Col {i+1}: {val}")

        print(f"\nFixed sample row:")
        for i, val in enumerate(fixed_rows[0][:3]):
            print(f"  Col {i+1}: {val}")

    # Determine output path
    if output_csv_path is None:
        input_path = Path(input_csv_path)
        output_csv_path = Path("/kaggle/working") / f"{input_path.stem}_fixed{input_path.suffix}"

    # Write the fixed CSV: header row first, then the data
    with open(output_csv_path, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(headers)
        writer.writerows(fixed_rows)

    print(f"\n✅ Fixed CSV saved to: {output_csv_path}")
    print(f"📁 Location: Kaggle working directory (writable)")
    print(f"📊 Output summary: {len(headers)} columns, {len(fixed_rows)} data rows")

    return fixed_rows, headers

def process_train_test_csvs(train_csv_path, test_csv_path):
    """
    Process both training and test CSV files (no pandas).

    Each file is fixed with ``fix_csv_paths`` (output written next to the defaults in
    /kaggle/working).

    Returns:
        (train_result, test_result): the ``fix_csv_paths`` results; an entry is None when
        that file could not be processed.
    """
    print("="*60)
    print("🚀 Processing Training and Test CSV Files")
    print("="*60)

    # Process training CSV
    print("\n📚 TRAINING CSV:")
    train_result = fix_csv_paths(train_csv_path)

    print("\n" + "="*60)

    # Process test CSV
    print("\n🧪 TEST CSV:")
    test_result = fix_csv_paths(test_csv_path)

    print("\n" + "="*60)
    # Only claim success when both files were actually processed
    if train_result is None or test_result is None:
        print("⚠️  At least one CSV file could not be processed (see messages above).")
    else:
        print("✅ Both CSV files processed successfully!")

    return train_result, test_result

def load_csv_data(csv_path):
    """
    Simple utility to load CSV data back into memory (no pandas).

    The first row is taken as the header row.

    Returns:
        (data_rows, headers); both are empty lists for an empty file.
    """
    with open(csv_path, 'r', newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        # Default avoids leaking a bare StopIteration on an empty file
        headers = next(reader, [])
        data_rows = list(reader)  # Rest are data

    return data_rows, headers

# Example usage (in a notebook __name__ is "__main__", so this always runs):
if __name__ == "__main__":
    csv_input_dir = "/kaggle/input/cvusa-subset"
    train_csv = f"{csv_input_dir}/train-19zl.csv"
    test_csv = f"{csv_input_dir}/val-19zl.csv"  # NOTE: the "test" split is the validation CSV
    fixed_train_csv = "/kaggle/working/train-19zl_fixed.csv"
    fixed_test_csv = "/kaggle/working/val-19zl_fixed.csv"

    # Inspect both files before transforming them
    print("Checking CSV structures...")
    check_csv_structure(train_csv)
    print("\n" + "=" * 40)
    check_csv_structure(test_csv)

    print("\n" + "=" * 60)

    # Fix both files; outputs go to /kaggle/working/
    train_result, test_result = process_train_test_csvs(train_csv, test_csv)

    print(f"\n📁 Fixed files are now available at:")
    print(f"   Training: {fixed_train_csv}")
    print(f"   Test:     {fixed_test_csv}")

    # Read the fixed training file back as a quick check
    print(f"\n🔄 Example: Loading fixed training data...")
    train_data, train_headers = load_csv_data(fixed_train_csv)
    print(f"   Headers: {train_headers}")
    print(f"   Data rows: {len(train_data)}")
    print(f"   Sample: {train_data[0] if train_data else 'No data'}")


# DATASET AND DATALOADER

In [None]:
import os
import csv
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np

class CVUSADataset(Dataset):
    """Paired satellite/ground image dataset read from a (fixed) CVUSA CSV file.

    Each CSV row after the header must hold at least three relative paths:
    ``bingmap_path, ground_path, polarmap_path``. The satellite image is the polar map
    when ``polar`` is True, otherwise the Bing map. Rows whose two image files do not both
    exist under ``data_root`` are skipped.

    Args:
        csv_path: CSV file; its first row is a header and is skipped.
        data_root: directory the relative CSV paths are resolved against.
        size: side length in pixels of both output images.
        polar: use the polar-transformed satellite image instead of the Bing map.
        is_train: apply random augmentations to the ground image (never to the satellite).

    Items are dicts ``{"satellite": Tensor, "ground": Tensor}``, each of shape
    (3, size, size), normalized to [-1, 1].
    """

    def __init__(self, csv_path, data_root, size=256, polar=True, is_train=True):
        self.data_root = data_root
        self.size = size
        self.polar = polar

        # Load (satellite, ground) file pairs from the CSV
        self.file_pairs = []

        print(f"📂 Loading dataset from: {csv_path}")

        with open(csv_path, 'r', newline='', encoding='utf-8') as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header row (tolerates an empty file)
            for row in reader:
                if len(row) < 3:
                    continue

                bingmap_path, ground_path, polarmap_path = row[0], row[1], row[2]
                satellite_relative_path = polarmap_path if self.polar else bingmap_path

                satellite_full_path = os.path.join(data_root, satellite_relative_path)
                ground_full_path = os.path.join(data_root, ground_path)

                if os.path.exists(satellite_full_path) and os.path.exists(ground_full_path):
                    self.file_pairs.append((satellite_full_path, ground_full_path))

        print(f"✅ Found {len(self.file_pairs)} valid image pairs.")

        # Pipeline for TARGET images (satellite): never augmented.
        self.satellite_transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        # Pipeline for INPUT images (ground): augmented only when is_train=True.
        if is_train:
            self.ground_transform = transforms.Compose([
                transforms.RandomApply([
                    transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
                    transforms.RandomPerspective(distortion_scale=0.3, p=0.5)
                ], p=0.5),
                transforms.RandomApply([
                    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
                    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
                ], p=0.5),
                # --- End of Augmentations ---
                # No Resize here: __getitem__ already resizes to (size, size).
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
            print(f"   -> Mode: TRAINING (applying augmentations to ground images)")
        else:
            # Validation: the ground image uses the same (non-augmenting) pipeline as the satellite one.
            self.ground_transform = self.satellite_transform
            print(f"   -> Mode: VALIDATION (no augmentations)")

    def __len__(self):
        return len(self.file_pairs)

    def __getitem__(self, idx):
        satellite_path, ground_path = self.file_pairs[idx]
        target_size = (self.size, self.size)

        try:
            # Load images; `with` closes the underlying files once converted
            with Image.open(satellite_path) as img:
                satellite_img = img.convert('RGB')
            with Image.open(ground_path) as img:
                ground_img = img.convert('RGB')

            # Pre-resize with area interpolation to the configured size. This used to be a
            # hard-coded 256x256, which left training ground images at 256 whenever
            # size != 256 (the training ground pipeline has no Resize).
            sat_array = cv2.resize(np.array(satellite_img), target_size, interpolation=cv2.INTER_AREA)
            ground_array = cv2.resize(np.array(ground_img), target_size, interpolation=cv2.INTER_AREA)
            satellite_img = Image.fromarray(sat_array)
            ground_img = Image.fromarray(ground_array)

            satellite_tensor = self.satellite_transform(satellite_img)
            ground_tensor = self.ground_transform(ground_img)

            return {
                "satellite": satellite_tensor,  # The clean, non-augmented target
                "ground": ground_tensor         # The potentially augmented input
            }

        except Exception as e:
            # Best-effort: an unreadable pair becomes all-zero images instead of stopping the epoch.
            # NOTE(review): such samples are still used for training; consider filtering them out.
            print(f"❌ Error loading images at index {idx}: {e}")
            dummy_tensor = torch.zeros(3, self.size, self.size)
            return {"satellite": dummy_tensor, "ground": dummy_tensor}


def create_dataloaders(data_root="/kaggle/input/cvusa-subset", batch_size=8, polar=True):
    """Build the training and validation DataLoaders from the fixed CVUSA CSVs.

    The CSVs are read from /kaggle/working. The training set augments ground images and is
    shuffled; the validation set does neither. Images are 256x256 and loading happens
    in the main process (num_workers=0).

    Args:
        data_root: directory the CSV image paths are relative to.
        batch_size: batch size for both loaders.
        polar: use polar-transformed satellite images.

    Returns:
        (train_dataloader, test_dataloader)
    """
    fixed_csv_dir = "/kaggle/working"

    def build_dataset(csv_name, is_train):
        return CVUSADataset(
            csv_path=f"{fixed_csv_dir}/{csv_name}",
            data_root=data_root,
            size=256,
            polar=polar,
            is_train=is_train,
        )

    print("\n🚀 Creating train and test datasets...")

    train_dataset = build_dataset("train-19zl_fixed.csv", is_train=True)
    print("\n" + "=" * 50)
    test_dataset = build_dataset("val-19zl_fixed.csv", is_train=False)

    train_dataloader, test_dataloader = (
        DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)
        for dataset, shuffle in ((train_dataset, True), (test_dataset, False))
    )

    print(f"\n✅ Dataloaders created:")
    print(f"   📚 Training: {len(train_dataset)} samples, {len(train_dataloader)} batches")
    print(f"   🧪 Testing:  {len(test_dataset)} samples, {len(test_dataloader)} batches")

    return train_dataloader, test_dataloader


# Smoke-test the dataloaders by pulling one batch from each
if __name__ == "__main__":
    train_loader, test_loader = create_dataloaders(
        data_root="/kaggle/input/cvusa-subset",
        batch_size=20,
        polar=True,
    )

    def _show_first_batch(split_name, loader):
        """Fetch the first batch of *loader*, print its tensor shapes and return it."""
        print(f"\n🔍 Testing {split_name.lower()} dataloader...")
        first_batch = next(iter(loader))
        print(f"{split_name} batch shapes:")
        print(f"  Ground: {first_batch['ground'].shape}")
        print(f"  Satellite: {first_batch['satellite'].shape}")
        return first_batch

    train_batch = _show_first_batch("Train", train_loader)
    test_batch = _show_first_batch("Test", test_loader)

    print("\n✅ Both dataloaders working correctly!")


# SHAPE FIX

In [None]:
# NOTE: disabled debugging snippet. It is wrapped in a string literal, so it never runs.
# It references `batch`, which this notebook never defines (the dataloader smoke test
# creates `train_batch` / `test_batch`); rename it before re-enabling.
# The flat-indices -> (batch, tokens) reshaping prototyped here is what
# `manual_forward_pass` does with `.view(batch_size, -1)`.
"""

with torch.no_grad():
    print("Check both encodings:")
    
    # Satellite encoding
    sat_quant_z, _, sat_info = model.first_stage_model.encode(batch['satellite'])
    sat_indices = sat_info[2]
    print(f"Satellite indices shape: {sat_indices.shape}")
    print(f"Satellite indices dim: {sat_indices.dim()}")
    
    # Ground encoding  
    ground_quant_c, _, ground_info = model.cond_stage_model.encode(batch['ground'])
    ground_indices = ground_info[2]
    print(f"Ground indices shape: {ground_indices.shape}")
    print(f"Ground indices dim: {ground_indices.dim()}")
    
    # Fix both to 2D
    batch_size = batch['ground'].shape[0]
    
    if sat_indices.dim() == 1:
        sat_tokens_per_image = sat_indices.shape[0] // batch_size
        sat_indices_2d = sat_indices.view(batch_size, sat_tokens_per_image)
        print(f"Fixed satellite indices: {sat_indices_2d.shape}")
    
    if ground_indices.dim() == 1:
        ground_tokens_per_image = ground_indices.shape[0] // batch_size  
        ground_indices_2d = ground_indices.view(batch_size, ground_tokens_per_image)
        print(f"Fixed ground indices: {ground_indices_2d.shape}")
        
    # Test concatenation
    cz_indices = torch.cat((ground_indices_2d, sat_indices_2d), dim=1)
    print(f"✅ Concatenation works! Shape: {cz_indices.shape}")

"""

# MANUAL FORWARD PASS

In [None]:


# IMPORTANT: move the model to the selected device (GPU when available, otherwise CPU)
model = model.to(device)
print("Model moved to", device)

def manual_forward_pass(model, satellite_imgs, ground_imgs):
    """Run the conditional-transformer training forward pass with explicit token reshaping.

    The ground image tokens condition the satellite image tokens. Both index tensors
    returned by ``encode_to_z`` / ``encode_to_c`` are reshaped to (batch, tokens), because
    they may come back flattened.

    Args:
        model: Net2NetTransformer-like object providing ``encode_to_z``, ``encode_to_c``
            and ``transformer``.
        satellite_imgs: target image batch (first dimension is the batch).
        ground_imgs: conditioning image batch.

    Returns:
        (logits, target): the transformer outputs for the satellite-token positions, and the
        satellite token indices of shape (batch, n_satellite_tokens). Assuming the transformer
        returns one output per input position, logits has n_satellite_tokens positions.
    """
    n_images = satellite_imgs.shape[0]

    # Encode images into codebook indices (satellite first, then ground), forced to 2-D
    target = model.encode_to_z(satellite_imgs)[1].view(n_images, -1)
    cond = model.encode_to_c(ground_imgs)[1].view(n_images, -1)
    n_cond = cond.shape[1]

    # Teacher forcing: feed [cond, target] without its last token, so position i predicts token i+1
    sequence = torch.cat((cond, target), dim=1)
    logits = model.transformer(sequence[:, :-1])[0]

    # Drop the predictions made inside the conditioning part (keep from the last cond token on)
    return logits[:, n_cond - 1:], target

print("Manual Pass created")

# LOAD MODEL FUNCTIONS

In [None]:
import torch
import os
from datetime import datetime
from taming.models.cond_transformer import Net2NetTransformer
from omegaconf import OmegaConf
import glob

def _build_net2net_config(model_config):
    """Return the OmegaConf config that rebuilds the Net2NetTransformer for *model_config*.

    The VQ-GAN (first stage) is the fixed ImageNet f16/16384 architecture used in this
    notebook, with ``ckpt_path`` None because its weights are loaded separately. The
    transformer size comes from the ``transformer_*`` entries of *model_config*.

    NOTE(review): the notebook's training config sets transformer dropout to 0.2, but no
    dropout rates are passed here, so the GPT constructor's defaults apply. This only
    matters when training is resumed from the loaded model.
    """
    return OmegaConf.create({
        "first_stage_config": {
            "target": "taming.models.vqgan.VQModel",
            "params": {
                "ckpt_path": None,  # We'll load this separately
                "embed_dim": 256,
                "n_embed": 16384,
                "ddconfig": {
                    "double_z": False,
                    "z_channels": 256,
                    "resolution": 256,
                    "in_channels": 3,
                    "out_ch": 3,
                    "ch": 128,
                    "ch_mult": [1, 1, 2, 2, 4],
                    "num_res_blocks": 2,
                    "attn_resolutions": [16],
                    "dropout": 0.0
                },
                "lossconfig": {
                    "target": "taming.modules.losses.DummyLoss"
                }
            }
        },
        "cond_stage_config": "__is_first_stage__",
        "transformer_config": {
            "target": "taming.modules.transformer.mingpt.GPT",
            "params": {
                "vocab_size": model_config['transformer_vocab_size'],
                "block_size": model_config['transformer_block_size'],
                "n_layer": model_config['transformer_n_layer'],
                "n_head": model_config['transformer_n_head'],
                "n_embd": model_config['transformer_n_embd']
            }
        }
    })


def _find_vqgan_checkpoint():
    """Return the first existing VQ-GAN checkpoint among the known locations, or None."""
    possible_paths = [
        "/kaggle/working/models/vqgan_imagenet_f16_16384.ckpt",
        "/kaggle/input/*/vqgan_imagenet_f16_16384.ckpt",
        "./models/vqgan_imagenet_f16_16384.ckpt",
        "/content/models/vqgan_imagenet_f16_16384.ckpt"  # For Colab
    ]
    for path_pattern in possible_paths:
        if '*' in path_pattern:
            # Sorted so the choice among several matches is deterministic
            matches = sorted(glob.glob(path_pattern))
            if matches:
                return matches[0]
        elif os.path.exists(path_pattern):
            return path_pattern
    return None


def load_saved_model(checkpoint_path, device='cuda', vqgan_checkpoint_path=None):
    """
    Load a model saved with your save_model_with_timestamp function.

    Args:
        checkpoint_path: Path to the .pth checkpoint file; it must contain 'epoch',
            'model_state_dict' and 'model_config' (with the transformer_* sizes).
        device: Device to load the model on ('cuda' or 'cpu'); 'cuda' falls back to CPU
            when CUDA is unavailable.
        vqgan_checkpoint_path: Optional path to VQ-GAN checkpoint (auto-detected if None)

    Returns:
        model: Loaded model (VQ-GAN frozen) ready for inference or continued training
        checkpoint_info: Dictionary with epoch, loss, timestamp, model_config and the
            device actually used

    Raises:
        FileNotFoundError: checkpoint_path does not exist.
        ValueError: the checkpoint or its model_config lacks required keys.
        RuntimeError: the checkpoint cannot be read, the model cannot be built, or the
            transformer weights cannot be loaded.
    """
    # Check if checkpoint file exists
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Auto-detect device if CUDA not available
    if device == 'cuda' and not torch.cuda.is_available():
        print("⚠️  CUDA not available, falling back to CPU")
        device = 'cpu'

    # Load checkpoint (weights_only=False unpickles arbitrary objects: trusted files only)
    print(f"Loading checkpoint from: {checkpoint_path}")
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    except Exception as e:
        raise RuntimeError(f"Failed to load checkpoint: {e}") from e

    # Validate checkpoint structure
    required_keys = ['epoch', 'model_state_dict', 'model_config']
    missing_keys = [key for key in required_keys if key not in checkpoint]
    if missing_keys:
        raise ValueError(f"Checkpoint missing required keys: {missing_keys}")

    # Print checkpoint info
    print(f"Checkpoint info:")
    print(f"  - Epoch: {checkpoint['epoch']}")
    print(f"  - Loss: {checkpoint.get('loss', 'N/A')}")
    print(f"  - Timestamp: {checkpoint.get('timestamp', 'N/A')}")

    # Validate model config
    model_config = checkpoint['model_config']
    required_config_keys = ['transformer_vocab_size', 'transformer_block_size',
                            'transformer_n_layer', 'transformer_n_head', 'transformer_n_embd']
    missing_config_keys = [key for key in required_config_keys if key not in model_config]
    if missing_config_keys:
        raise ValueError(f"Model config missing required keys: {missing_config_keys}")

    config = _build_net2net_config(model_config)

    # Create the model
    try:
        model = Net2NetTransformer(
            transformer_config=config.transformer_config,
            first_stage_config=config.first_stage_config,
            cond_stage_config=config.cond_stage_config,
            first_stage_key="satellite",
            cond_stage_key="ground",
            unconditional=False
        )
        print("✅ Model architecture created successfully")
    except Exception as e:
        raise RuntimeError(f"Failed to create model architecture: {e}") from e

    # Freeze the VQ-GAN up front so it stays frozen even if its checkpoint cannot be
    # loaded below (previously it was only frozen after a successful VQ-GAN load).
    for param in model.first_stage_model.parameters():
        param.requires_grad = False

    # Auto-detect VQ-GAN checkpoint path if not provided
    if vqgan_checkpoint_path is None:
        vqgan_checkpoint_path = _find_vqgan_checkpoint()

    # Load the VQ-GAN weights
    if vqgan_checkpoint_path and os.path.exists(vqgan_checkpoint_path):
        print(f"Loading VQ-GAN weights from: {vqgan_checkpoint_path}")
        try:
            # Map to CPU: the model has not been moved to `device` yet
            vqgan_state = torch.load(vqgan_checkpoint_path, map_location='cpu', weights_only=False)

            # Check if the checkpoint has the expected structure
            if "state_dict" not in vqgan_state:
                raise ValueError("VQ-GAN checkpoint missing 'state_dict' key")

            model.first_stage_model.load_state_dict(vqgan_state["state_dict"], strict=False)
            print("✅ VQ-GAN weights loaded and frozen!")

        except Exception as e:
            print(f"⚠️  Failed to load VQ-GAN weights: {e}")
            print("Model will work but VQ-GAN may not be properly initialized")
    else:
        print("⚠️  VQ-GAN checkpoint not found. You may need to:")
        print("   1. Download it again")
        print("   2. Provide the correct path via vqgan_checkpoint_path parameter")
        print("   3. Ensure the file exists in one of the expected locations")

    # Load the trained transformer weights
    print("Loading trained transformer weights...")
    try:
        # strict=False tolerates key mismatches; they are reported below
        missing_state_keys, unexpected_state_keys = model.load_state_dict(
            checkpoint['model_state_dict'], strict=False)

        if missing_state_keys:
            print(f"⚠️  Missing keys in checkpoint: {len(missing_state_keys)} keys")
            if len(missing_state_keys) <= 5:
                print(f"    {missing_state_keys}")
        if unexpected_state_keys:
            print(f"⚠️  Unexpected keys in checkpoint: {len(unexpected_state_keys)} keys")
            if len(unexpected_state_keys) <= 5:
                print(f"    {unexpected_state_keys}")

        print("✅ Transformer weights loaded!")

    except Exception as e:
        raise RuntimeError(f"Failed to load transformer weights: {e}") from e

    # Move to device (fall back to CPU if that fails)
    try:
        model = model.to(device)
        print(f"✅ Model moved to {device}")
    except Exception as e:
        print(f"⚠️  Failed to move model to {device}: {e}")
        print("Falling back to CPU...")
        device = 'cpu'
        model = model.to(device)

    print(f"✅ Model loaded successfully on {device}")

    # Return model and checkpoint info
    checkpoint_info = {
        'epoch': checkpoint['epoch'],
        'loss': checkpoint.get('loss', None),
        'timestamp': checkpoint.get('timestamp', None),
        'model_config': checkpoint['model_config'],
        'device': device
    }

    return model, checkpoint_info

def load_with_optimizer(checkpoint_path, device='cuda', lr=5e-4, vqgan_checkpoint_path=None):
    """
    Load model with optimizer state for continued training.

    The optimizer is AdamW over ``model.transformer`` parameters. If the
    checkpoint contains ``optimizer_state_dict`` it is restored (moments,
    weight decay, betas), and then ``lr`` is re-applied on top of it, because
    ``Optimizer.load_state_dict`` would otherwise silently restore the saved
    learning rate and ignore this argument.

    Args:
        checkpoint_path: Path to the .pth checkpoint file
        device: Requested device; if ``load_saved_model`` falls back to CPU,
            the device it actually used is followed instead
        lr: Learning rate for the optimizer (overrides the saved one)
        vqgan_checkpoint_path: Optional path to VQ-GAN checkpoint

    Returns:
        model: Loaded model
        optimizer: Optimizer with loaded state
        checkpoint_info: Dictionary with training information

    Raises:
        RuntimeError: if the optimizer cannot be created or the checkpoint
            cannot be re-read.
    """
    # Load the model first
    model, checkpoint_info = load_saved_model(checkpoint_path, device, vqgan_checkpoint_path)
    # load_saved_model may have fallen back to CPU; keep the optimizer state
    # on the same device as the model parameters.
    device = checkpoint_info.get('device', device)

    # Create optimizer
    try:
        optimizer = torch.optim.AdamW(model.transformer.parameters(), lr=lr, betas=(0.9, 0.95))
        print("✅ Optimizer created")
    except Exception as e:
        raise RuntimeError(f"Failed to create optimizer: {e}") from e

    # Load checkpoint again to get optimizer state
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    except Exception as e:
        raise RuntimeError(f"Failed to reload checkpoint for optimizer: {e}") from e

    # Load optimizer state
    if 'optimizer_state_dict' in checkpoint:
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("✅ Optimizer state loaded!")
        except Exception as e:
            print(f"⚠️  Failed to load optimizer state: {e}")
            print("Optimizer will use default initialization")
        # load_state_dict restores the saved hyper-parameters, including 'lr'
        # (and 'initial_lr' if an LR scheduler was attached during training).
        # Re-apply the requested lr so the argument is actually honoured.
        for group in optimizer.param_groups:
            group['lr'] = lr
            if 'initial_lr' in group:
                group['initial_lr'] = lr
    else:
        print("⚠️  No optimizer state found in checkpoint")

    return model, optimizer, checkpoint_info

def find_latest_checkpoint(directory="/kaggle/working/", base_name="cvusa_ground2satellite"):
    """
    Find the most recently modified checkpoint file in a directory.

    A checkpoint is a regular file whose name starts with ``base_name`` and
    ends with ``.pth``. Sub-directories that happen to match the pattern are
    ignored.

    Args:
        directory: Directory to search in
        base_name: Filename prefix of the checkpoint files

    Returns:
        latest_checkpoint_path: Path to the newest matching checkpoint

    Raises:
        FileNotFoundError: if ``directory`` does not exist or contains no
            matching checkpoint files.
        PermissionError: if ``directory`` cannot be listed.
    """
    if not os.path.exists(directory):
        raise FileNotFoundError(f"Directory not found: {directory}")

    try:
        files = os.listdir(directory)
    except PermissionError as err:
        raise PermissionError(f"Permission denied accessing directory: {directory}") from err

    checkpoint_files = []  # (mtime, full path, filename)

    for filename in files:
        if not (filename.startswith(base_name) and filename.endswith('.pth')):
            continue
        filepath = os.path.join(directory, filename)
        try:
            # Get file modification time
            mtime = os.path.getmtime(filepath)
        except OSError as e:
            print(f"⚠️  Could not access file {filename}: {e}")
            continue
        # A directory named like a checkpoint is not loadable by torch.load.
        if not os.path.isfile(filepath):
            continue
        checkpoint_files.append((mtime, filepath, filename))

    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoint files found with base name '{base_name}' in {directory}")

    # Sort by modification time (newest first)
    checkpoint_files.sort(reverse=True)
    latest_checkpoint = checkpoint_files[0][1]

    print(f"Found {len(checkpoint_files)} checkpoint files:")
    for i, (_, _, filename) in enumerate(checkpoint_files[:3]):  # Show top 3
        print(f"  {i+1}. {filename}")
    if len(checkpoint_files) > 3:
        print(f"  ... and {len(checkpoint_files) - 3} more")

    print(f"Latest checkpoint: {os.path.basename(latest_checkpoint)}")

    return latest_checkpoint

def setup_for_inference(model):
    """
    Put a loaded model into evaluation mode and try to compile it.

    ``torch.compile`` is only attempted when the installed PyTorch provides
    it (2.0+). If compilation raises, the warning is printed and the
    uncompiled model is returned. Note that compilation is typically deferred
    until the first forward call, so some failures may only surface then.

    Args:
        model: Loaded model

    Returns:
        model: The model in eval mode (compiled wrapper when compilation
        succeeded, otherwise the original module)
    """
    model.eval()

    compile_fn = getattr(torch, 'compile', None)
    if compile_fn is None:
        # Older PyTorch: nothing to compile with.
        return model

    try:
        compiled = compile_fn(model)
    except Exception as e:
        print(f"⚠️  Model compilation failed: {e}")
        return model

    print("✅ Model compiled for better performance")
    return compiled

def get_model_info(model, checkpoint_info):
    """
    Print a summary of a loaded checkpoint and the model's parameter counts.

    Args:
        model: Loaded model (anything exposing ``.parameters()``)
        checkpoint_info: Dict with at least 'epoch' and 'model_config' (the
            transformer hyper-parameters). 'loss', 'timestamp' and 'device'
            are optional; missing keys and ``None`` values are both shown
            with a placeholder.
    """
    def _field(key, fallback):
        # load_saved_model stores an absent loss/timestamp as None, so a
        # plain dict.get(key, fallback) would print "None".
        value = checkpoint_info.get(key)
        return fallback if value is None else value

    print("\n" + "="*50)
    print("MODEL INFORMATION")
    print("="*50)

    # Checkpoint info
    print(f"Epoch: {checkpoint_info['epoch']}")
    print(f"Loss: {_field('loss', 'N/A')}")
    print(f"Timestamp: {_field('timestamp', 'N/A')}")
    print(f"Device: {_field('device', 'Unknown')}")

    # Model config
    config = checkpoint_info['model_config']
    print("\nTransformer Configuration:")
    print(f"  - Vocab size: {config['transformer_vocab_size']}")
    print(f"  - Block size: {config['transformer_block_size']}")
    print(f"  - Layers: {config['transformer_n_layer']}")
    print(f"  - Heads: {config['transformer_n_head']}")
    print(f"  - Embedding dim: {config['transformer_n_embd']}")

    # Parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("\nParameters:")
    print(f"  - Total: {total_params:,}")
    print(f"  - Trainable: {trainable_params:,}")
    print(f"  - Frozen: {total_params - trainable_params:,}")

    print("="*50)


# INFERENCE SET UP

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os

def single_image_inference(model, ground_image_path, device=device, temperature=1.0, top_k=600, top_p=0.92, save_image=False, nameadd=""):
    """
    Generate a satellite image from one ground-level image and display both.

    Pipeline:
      1. Load the image, resize to 256x256, and normalize each channel with
         mean 0.5 / std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
      2. Encode it with ``model.cond_stage_model.encode`` and take the token
         indices from element [2] of the returned info tuple.
      3. Autoregressively sample 256 satellite tokens from
         ``model.transformer``, conditioned on the ground tokens, using
         temperature scaling plus ``top_k_top_p_filtering`` (defined
         elsewhere in the notebook) and multinomial sampling.
      4. Look the tokens up in ``model.first_stage_model.quantize.embedding``
         as a 16x16 grid and decode with ``model.first_stage_model.decode``.
      5. Show input and output side by side with matplotlib; optionally save
         the output PNG to the current working directory.

    Args:
        model: Model exposing ``cond_stage_model``, ``transformer`` and
            ``first_stage_model`` (taming-transformers style). It is switched
            to eval mode and left there.
        ground_image_path: Path to the ground-view image file.
        device: Device for the input tensor. NOTE: the default is the value
            of the global ``device`` at the time this cell was executed.
        temperature: Logit divisor; must be > 0 (it is used as a divisor).
        top_k: Top-k cutoff passed to ``top_k_top_p_filtering``.
        top_p: Nucleus (top-p) cutoff passed to ``top_k_top_p_filtering``.
        save_image: If True, save the generated image as
            ``generated_<input stem>_<nameadd>_temp<T>_k<K>_p<P>.png``.
        nameadd: Extra tag inserted into the saved filename.

    Returns:
        (generated_satellite_pil, ground_pil): the generated image and the
        original (un-resized) ground image as PIL images.
    """

    # Load and preprocess the ground image
    ground_pil = Image.open(ground_image_path).convert('RGB')
    
    # Use same transform as your training
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    
    ground_tensor = transform(ground_pil).unsqueeze(0).to(device)  # Add batch dimension
    
    model.eval()
    
    with torch.no_grad():
        # Encode ground image to conditioning tokens
        ground_quant_c, _, ground_info = model.cond_stage_model.encode(ground_tensor)
        ground_indices = ground_info[2]
        
        # Handle reshape only if needed (like your training code).
        # NOTE(review): only a flat 1-D index tensor is reshaped to
        # (batch, tokens); any other index shape is passed through unchanged.
        batch_size = ground_tensor.shape[0]
        if ground_indices.dim() == 1:
            ground_tokens_per_image = ground_indices.shape[0] // batch_size  
            ground_indices = ground_indices.view(batch_size, ground_tokens_per_image)
        
        # Generate satellite tokens autoregressively
        sequence = ground_indices  # Start with conditioning
        satellite_seq_length = 256  # 16x16 tokens
        
        # The full (growing) sequence is re-fed each step (no KV cache), so the
        # final input length is len(ground tokens) + 255 -- presumably within
        # the transformer's block_size; confirm against the model config.
        for i in range(satellite_seq_length):
            logits, _ = model.transformer(sequence)
            next_token_logits = logits[:, -1, :] / temperature
            
            # Use your top_k_top_p_filtering function
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            
            # Sample instead of argmax
            probs = torch.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            
            sequence = torch.cat([sequence, next_token], dim=1)
        
        # Extract generated satellite tokens
        generated_tokens = sequence[:, -satellite_seq_length:]
        
        # Decode using first_stage_model
        h = w = 16
        z_indices_spatial = generated_tokens.view(batch_size, h, w)
        
        # Get quantized features from codebook
        quant_z = model.first_stage_model.quantize.embedding(z_indices_spatial)
        quant_z = quant_z.permute(0, 3, 1, 2).contiguous()  # [batch, embed_dim, h, w]
        
        generated_satellite_tensor = model.first_stage_model.decode(quant_z)
    
    # Convert tensors back to PIL images for display
    def tensor_to_displayable(tensor):
        # Map [-1, 1] -> [0, 1] and move channels last (H, W, C) for imshow.
        # squeeze(0) drops the batch dim, so this assumes a batch of one.
        img = ((tensor.squeeze(0) + 1) / 2).cpu()
        return img.permute(1, 2, 0).clamp(0, 1)
    
    # Convert to displayable format
    ground_display = tensor_to_displayable(ground_tensor)
    generated_display = tensor_to_displayable(generated_satellite_tensor)
    
    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    
    # Show INPUT
    axes[0].imshow(ground_display)
    axes[0].set_title("INPUT\n(Ground View)", fontsize=16, fontweight='bold')
    axes[0].axis('off')
    
    # Show OUTPUT
    axes[1].imshow(generated_display)
    axes[1].set_title("OUTPUT\n(Generated Satellite)", fontsize=16, fontweight='bold')
    axes[1].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Convert back to PIL for return (ToPILImage expects channels-first)
    to_pil = transforms.ToPILImage()
    generated_satellite_pil = to_pil(generated_display.permute(2, 0, 1))

    # Create a unique filename based on the input image and sampling parameters
    if save_image:
        # Extract filename from path (without extension)
        input_filename = os.path.splitext(os.path.basename(ground_image_path))[0]
        
        # Create unique filename with parameters (saved to the current working directory)
        unique_filename = f"generated_{input_filename}_{nameadd}_temp{temperature}_k{top_k}_p{top_p}.png"
        
        generated_satellite_pil.save(unique_filename)
        print(f"✅ Generated image saved as '{unique_filename}'")
    
    return generated_satellite_pil, ground_pil

# Example calls:
#single_image_inference(model, '/kaggle/input/ground-normal/0000008.jpg', temperature=0.0001, top_k=5000, top_p=1.0, save_image=True) # Near-deterministic (very low temperature)
#single_image_inference(model, '/kaggle/input/ground-normal/0000008.jpg', temperature=1.0, top_k=300, top_p=0.92, save_image=True)    # Moderately conservative



# SAVE AND TRAIN

In [None]:
import torch.nn.functional as F
import torch.optim as optim
import torch
import os
from datetime import datetime
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

def save_model_with_timestamp(model, optimizer, epoch, loss, base_name="cvusa_ground2satellite", save_dir="/kaggle/working/"):
    """
    Save a training checkpoint with an automatic, timestamped filename.

    The file is named ``{base_name}_epoch{epoch}_loss{loss:.3f}_{YYYYmmdd_HHMMSS}.pth``
    and contains the model and optimizer state dicts, epoch, loss, timestamp
    and the transformer hyper-parameters under 'model_config'.

    If ``model`` is a ``torch.compile`` wrapper, the underlying module is
    saved instead. The wrapper's state_dict keys carry an ``_orig_mod.``
    prefix that would not match an uncompiled model at load time.

    Args:
        model: Model (optionally compiled) with a ``.transformer.config``.
        optimizer: Optimizer whose state is saved alongside the model.
        epoch: Epoch number recorded in the file and filename.
        loss: Loss value (float) recorded in the file and filename.
        base_name: Filename prefix.
        save_dir: Output directory (created if missing); defaults to the
            Kaggle working directory.

    Returns:
        Full path of the written checkpoint.
    """
    # Unwrap torch.compile's OptimizedModule so keys have no "_orig_mod." prefix.
    base_model = getattr(model, '_orig_mod', model)

    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{base_name}_epoch{epoch}_loss{loss:.3f}_{timestamp}.pth"
    full_path = os.path.join(save_dir, filename)

    config = base_model.transformer.config
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': base_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'timestamp': timestamp,
        'model_config': {
            'transformer_vocab_size': config.vocab_size,
            'transformer_block_size': config.block_size,
            'transformer_n_layer': config.n_layer,
            'transformer_n_head': config.n_head,
            'transformer_n_embd': config.n_embd,
        }
    }

    torch.save(checkpoint, full_path)
    print(f"✅ Model saved as: {full_path}")
    print(f"📁 Location: {save_dir}")
    return full_path

def evaluate_model(model, test_dataloader, device):
    """
    Evaluate the model on the test set and return metrics.

    The loss is the mean cross-entropy over all target tokens (token-weighted,
    so a smaller final batch is not over-weighted). Accuracy is the fraction
    of tokens where the argmax prediction equals the target. No label
    smoothing is applied. The model's previous train/eval mode is restored
    afterwards.

    Args:
        model: Model evaluated through ``manual_forward_pass``.
        test_dataloader: Iterable of dicts with 'ground' and 'satellite' tensors.
        device: Device the batch tensors are moved to.

    Returns:
        dict with 'loss', 'token_accuracy', 'total_tokens', 'total_batches'.

    Raises:
        ValueError: if ``test_dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss_sum = 0.0  # summed per-token cross-entropy
    total_tokens = 0
    correct_predictions = 0
    num_batches = 0

    try:
        with torch.no_grad():
            for batch in test_dataloader:
                # Move to device
                ground_imgs = batch['ground'].to(device)
                satellite_imgs = batch['satellite'].to(device)

                # Forward pass
                logits, target = manual_forward_pass(model, satellite_imgs, ground_imgs)

                # Summed (not averaged) loss so the final mean is per token
                batch_loss_sum = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)), target.reshape(-1), reduction='sum'
                )

                # Token accuracy
                predictions = logits.argmax(dim=-1)
                correct_predictions += (predictions == target).sum().item()

                total_loss_sum += batch_loss_sum.item()
                total_tokens += target.numel()
                num_batches += 1
    finally:
        # Restore whatever mode the caller had (training loops expect train()).
        model.train(was_training)

    if num_batches == 0 or total_tokens == 0:
        raise ValueError("test_dataloader yielded no data; cannot compute evaluation metrics")

    return {
        'loss': total_loss_sum / total_tokens,
        'token_accuracy': correct_predictions / total_tokens,
        'total_tokens': total_tokens,
        'total_batches': num_batches
    }

def run_visual_inference_samples(model, device, sample_images=None):
    """
    Run visual inference on a handful of sample ground images.

    Each image is rendered twice with ``single_image_inference``: a
    conservative setting (temperature 0.8, top_k 100, top_p 0.9) and a
    balanced one (temperature 1.0, top_k 300, top_p 0.92). Nothing is saved
    to disk. Images that do not exist are skipped with a warning, because
    this runs inside the training loop and should not abort training.

    Args:
        model: Model passed through to ``single_image_inference``.
        device: Device for the input tensors.
        sample_images: Optional iterable of image paths; defaults to five
            fixed CVUSA street-view samples under /kaggle/input.
    """
    print("🎨 VISUAL INFERENCE SAMPLES:")
    print("Conservative (temp=0.8) then Balanced (temp=1.0)")

    if sample_images is None:
        sample_images = (
            '/kaggle/input/cvusa-subset/streetview/0001149.jpg',
            '/kaggle/input/cvusa-subset/streetview/0003140.jpg',
            '/kaggle/input/cvusa-subset/streetview/0003932.jpg',
            '/kaggle/input/cvusa-subset/streetview/0007602.jpg',
            '/kaggle/input/cvusa-subset/streetview/0008545.jpg',
        )

    for img_path in sample_images:
        print(f"\n--- {os.path.basename(img_path)} ---")
        if not os.path.isfile(img_path):
            print(f"⚠️  Sample image not found, skipping: {img_path}")
            continue
        # Conservative - don't save during training to avoid clutter
        single_image_inference(model, img_path, device=device, temperature=0.8, top_k=100, top_p=0.9, save_image=False)
        # Balanced - don't save during training
        single_image_inference(model, img_path, device=device, temperature=1.0, top_k=300, top_p=0.92, save_image=False)

def train_one_epoch(model, train_dataloader, optimizer, scaler, device, label_smoothing=0.1):
    """
    Train for one epoch with mixed precision and return the average loss.

    The optimized objective is cross-entropy with ``label_smoothing``. The
    loss that is printed and returned is plain (unsmoothed) cross-entropy, so
    it is on the same scale as the test loss from ``evaluate_model`` and the
    train/test gap is meaningful. Gradients are unscaled and clipped to norm
    1.0 before each optimizer step.

    Args:
        model: Model trained through ``manual_forward_pass``.
        train_dataloader: Iterable of dicts with 'ground' and 'satellite' tensors.
        optimizer: Optimizer stepping the trainable parameters.
        scaler: ``GradScaler`` used for mixed-precision loss scaling.
        device: Device the batch tensors are moved to.
        label_smoothing: Label smoothing for the optimized loss (default 0.1).

    Returns:
        Mean per-batch unsmoothed cross-entropy over the epoch.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(train_dataloader):
        # Move to device
        ground_imgs = batch['ground'].to(device)
        satellite_imgs = batch['satellite'].to(device)

        # Forward pass with mixed precision
        with autocast():
            logits, target = manual_forward_pass(model, satellite_imgs, ground_imgs)
            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_target = target.reshape(-1)
            loss = F.cross_entropy(flat_logits, flat_target, label_smoothing=label_smoothing)

        # Reporting-only loss (no graph): label smoothing adds an offset that
        # would otherwise dominate the train/test comparison.
        with torch.no_grad():
            report_loss = F.cross_entropy(flat_logits.float(), flat_target).item()

        # Backward pass with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()

        # Unscale before clipping so the norm is measured on true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        # Track loss
        epoch_loss += report_loss
        num_batches += 1

        # Print progress every 50 batches
        if batch_idx % 50 == 0:
            print(f"  Batch {batch_idx}, Loss: {report_loss:.4f}")

    if num_batches == 0:
        raise ValueError("train_dataloader yielded no batches")

    return epoch_loss / num_batches

def train_model_with_evaluation(model, train_dataloader, test_dataloader, num_epochs=50, lr=5e-4, inference_check=False):
    """
    Train with a train/test split, cosine LR decay, checkpointing and a
    simple overfitting check.

    Only ``model.transformer`` parameters are optimized (AdamW, betas
    (0.9, 0.95), weight decay 0.1). After every epoch the test set is
    evaluated. Whenever the test loss improves, a new ``*_improve``
    checkpoint is written and the previous best file is then removed. Every
    5 epochs a ``*_routine`` checkpoint is written and, optionally, visual
    inference samples are shown.

    Args:
        model: Model with a ``.transformer`` submodule.
        train_dataloader: Loader yielding dicts with 'ground' and 'satellite' tensors.
        test_dataloader: Loader with the same batch format, used for evaluation.
        num_epochs: Number of epochs; also the cosine-annealing period (T_max).
        lr: Initial AdamW learning rate.
        inference_check: If True, run visual inference samples every 5 epochs.

    Returns:
        Path of the best ("improve") checkpoint, or None if none was saved.
    """
    weight_decay = 0.1  # weight decay 0.01 was not enough

    # Setup device and compiled model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    raw_model = model.to(device)
    # torch.compile returns a wrapper whose state_dict keys gain an
    # "_orig_mod." prefix. Checkpoints are taken from raw_model (same
    # parameters) so they load cleanly into an uncompiled model.
    model = torch.compile(raw_model)  # Can give 10-20% speedup

    # Optimizer with weight decay for regularization
    optimizer = optim.AdamW(raw_model.transformer.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=weight_decay)
    scaler = GradScaler()
    # One scheduler step per epoch -> anneal over exactly num_epochs.
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    print(f"🚀 Starting training for {num_epochs} epochs...")
    print(f"📊 Training set: {len(train_dataloader)} batches")
    print(f"📊 Test set: {len(test_dataloader)} batches")
    print(f"⚙️  Learning rate: {lr}, Weight decay: {weight_decay}")
    print(f"🖥️  Training on device: {device}")

    # Tracking variables
    best_test_loss = float('inf')
    previous_gap = 0
    best_model_path = None  # Last saved best model (replaced on improvement)

    for epoch in range(num_epochs):
        print(f"\n{'='*60}")
        print(f"📅 EPOCH {epoch + 1}/{num_epochs}")
        print(f"{'='*60}")

        # 1. TRAINING PHASE
        print("🏋️ Training phase...")
        train_loss = train_one_epoch(model, train_dataloader, optimizer, scaler, device)

        # 2. EVALUATION PHASE
        print("🧪 Evaluation phase...")
        test_metrics = evaluate_model(model, test_dataloader, device)
        test_loss = test_metrics['loss']
        test_accuracy = test_metrics['token_accuracy']

        # 3. UPDATE LEARNING RATE
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # 4. OVERFITTING DETECTION (absolute train/test gap vs previous epoch)
        current_gap = abs(test_loss - train_loss)
        gap_status = ""
        if epoch > 0:  # Nothing to compare against on the first epoch
            if current_gap > previous_gap:
                gap_status = "⚠️  WARNING: Gap widening (potential overfitting!)"
            else:
                gap_status = "✅ Gap stable/improving"
        previous_gap = current_gap

        # 5. PRINT EPOCH SUMMARY
        print(f"\n📊 EPOCH {epoch + 1} SUMMARY:")
        print(f"   🏋️  Train Loss:     {train_loss:.4f}")
        print(f"   🧪 Test Loss:      {test_loss:.4f}")
        print(f"   🎯 Test Accuracy:  {test_accuracy:.3f} ({test_accuracy*100:.1f}%)")
        print(f"   📈 Learning Rate:  {current_lr:.2e}")
        print(f"   📏 Loss Gap (abs): {current_gap:.4f}")
        if gap_status:
            print(f"   {gap_status}")

        # 6. BEST MODEL SAVING (when test loss improves)
        if test_loss < best_test_loss:
            # Describe the improvement before best_test_loss is overwritten
            if best_model_path is None:
                improvement = "first save"
            else:
                improvement = f"improved by {best_test_loss - test_loss:.4f}"

            best_test_loss = test_loss
            print(f"   🏆 New best test loss! Saving 'improve' model... ({improvement})")
            previous_best_path = best_model_path
            best_model_path = save_model_with_timestamp(raw_model, optimizer, epoch+1, test_loss,
                                                        base_name="cvusa_ground2satellite_improve")

            # Delete the previous best only after the new one is safely on disk
            if (previous_best_path and previous_best_path != best_model_path
                    and os.path.exists(previous_best_path)):
                os.remove(previous_best_path)
                print(f"   🗑️  Deleted previous best model: {os.path.basename(previous_best_path)}")

        # 7. ROUTINE SAVING + DETAILED METRICS (every 5 epochs)
        if (epoch + 1) % 5 == 0:
            print(f"\n📅 EPOCH {epoch + 1} - DETAILED EVALUATION:")

            # Perplexity = exp(mean cross-entropy)
            perplexity = torch.exp(torch.tensor(test_loss))
            print(f"   📈 Test Perplexity: {perplexity:.2f}")
            print(f"      (Model is as 'confused' as choosing randomly from ~{perplexity:.0f} options)")

            # Routine save
            print("   💾 Routine save...")
            save_model_with_timestamp(raw_model, optimizer, epoch+1, test_loss,
                                      base_name="cvusa_ground2satellite_routine")

            # Visual inference samples
            if inference_check:
                run_visual_inference_samples(model, device)
            else:
                print("Inference check not enabled\n")

            print(f"   ✅ Epoch {epoch + 1} detailed evaluation complete")

    print("\n🎉 TRAINING COMPLETED!")
    print(f"🏆 Best test loss achieved: {best_test_loss:.4f}")
    print("📁 All models saved in: /kaggle/working/")
    if best_model_path:
        print(f"🥇 Best model: {os.path.basename(best_model_path)}")

    return best_model_path


# Run configuration for the training call below.
epochs = 75
learning_rate = 5e-4 # NOTE: 5e-4 may be too high; consider lowering it if training is unstable
    
# model, train_loader and test_loader are presumably defined in earlier cells.
train_model_with_evaluation(
        model=model,  # Your existing model
        train_dataloader=train_loader,  # Training dataloader
        test_dataloader=test_loader,    # Test (held-out) dataloader
        num_epochs=epochs,
        lr=learning_rate,
        inference_check=False,  # set True to render sample generations every 5 epochs
    )