# GMMM Submission: AI Hackathon, Lockheed Martin

 - Authors: Cesar Ruiz, Edyan Cruz, Angel Morales, Yahid Diaz
 - Date: September 7, 2025

## Install dependencies

In [None]:
!pip install -r requirements.txt

## Download Dataset

In [None]:
import kagglehub
# Fetch the Kaggle dataset through kagglehub and keep the returned value in
# `path`; the config cell later joins "images" onto it to locate the image files.
path = kagglehub.dataset_download("olebro/nasa-geographical-objects-multilabel-dataset")

In [None]:
path

### Importing Libraries

In [None]:
import argparse
import ast
import datetime
import json
import os
import time
from pathlib import Path
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score
from transformers import AutoImageProcessor,AutoModelForImageClassification,get_cosine_schedule_with_warmup
import ast, numpy as np, pandas as pd, torch
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, f1_score, confusion_matrix
import seaborn as sns
from collections import OrderedDict
from typing import List

## Configs

In [None]:
# Seed the global PyTorch and NumPy random number generators.
torch.manual_seed(0)
np.random.seed(0)
# NOTE(review): cudnn.benchmark presumably trades strict run-to-run
# reproducibility for speed -- confirm that is acceptable next to the fixed seeds.
cudnn.benchmark = True

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# CSV files consumed by the dataset class (columns `FileName` and `Label Vector`).
train_csv = "data/train.csv"
val_csv = "data/val.csv"
# Image directory inside the downloaded dataset (`path` comes from the download cell).
images_path = os.path.join(path, "images")
# Height/width requested from the image processor.
image_size = 224
# Hugging Face model id used for both the image processor and the classifier.
model_id = "facebook/dinov2-base"
batch_size = 32
# Passed as `num_workers` to the train/validation DataLoaders.
workers = 15
# Size of the classification head (`num_labels`).
num_classes = 10
# AdamW weight decay applied to the "decay" parameter group.
weight_decay = 0.05
# Fraction of the total optimisation steps used for learning-rate warm-up.
warmup_ratio = 0.05
epochs = 3
lr = 4e-5
# Directory where "best F1" checkpoints are written.
output_dir = "output"
# When True, the "Grade model from path" cell loads `resume_path` and evaluates it.
resume = False
# NOTE(review): absolute, machine-specific path -- adjust before enabling `resume`.
resume_path = "/home/sagemaker-user/satmae_pp-1/final_submission/output/best_f1_0.9469.pt"

## Dataset Definition

In [None]:
class KaggleGeographicalDataset(Dataset):
    """
    Multi-label image dataset driven by a CSV file.

    The CSV must contain a `FileName` column (file name inside `images_dir`)
    and a `Label Vector` column (a Python-literal list such as "[0, 1, 0]").
    Each item is returned as `(pixel_values, label_tensor)`: `pixel_values`
    comes from the Hugging Face image processor and `label_tensor` is a
    float32 tensor built from the label vector.

    Args:
        csv_file: Path of the CSV file to read.
        images_dir: Directory that contains the image files.
        processor_name: Model id / path given to `AutoImageProcessor.from_pretrained`.
        image_size: Target height and width requested from the processor.
        augment: When True, random augmentations are applied in `__getitem__`.
        augment_prob: Probability with which each individual augmentation fires.

    NOTE(review): augmentations draw from the global `np.random` state; confirm
    that DataLoader worker processes are seeded differently from each other.
    """
    def __init__(self, csv_file, images_dir, processor_name, image_size=224, augment=False,
                 augment_prob=0.05):
        self.df = pd.read_csv(csv_file)
        self.images_dir = images_dir
        self.image_size = image_size
        self.augment = augment
        self.augment_prob = augment_prob

        # Resize straight to (image_size, image_size); center cropping is disabled.
        self.processor = AutoImageProcessor.from_pretrained(
            processor_name,
            do_resize=True,
            size={"height": image_size, "width": image_size},
            do_center_crop=False,
        )

    def __len__(self):
        return len(self.df)

    def _hflip(self, img):
        """Mirror an (H, W, C) image left-to-right."""
        return np.ascontiguousarray(img[:, ::-1, :])

    def _vflip(self, img):
        """Mirror an (H, W, C) image top-to-bottom."""
        return np.ascontiguousarray(img[::-1, :, :])

    def _rot90k(self, img, k):
        """Rotate by k * 90 degrees; a multiple of 4 returns the input unchanged."""
        if k % 4 == 0:
            return img
        return np.ascontiguousarray(np.rot90(img, k))

    def _small_rotate(self, img, max_deg=360):
        """Rotate about the image centre by a random angle in [-max_deg, max_deg]."""
        ang = (np.random.rand() * 2 - 1) * max_deg
        h, w = img.shape[:2]
        M = cv2.getRotationMatrix2D((w/2, h/2), ang, 1.0)
        return cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)

    def _random_resized_crop(self, img, scale=(0.6, 1.0), ratio=(0.75, 1.33)):
        """Crop a random region (area fraction in `scale`, aspect in `ratio`) and resize it back."""
        h, w = img.shape[:2]
        area = h * w
        log_ratio = (np.log(ratio[0]), np.log(ratio[1]))
        for _ in range(10):
            target_area = np.random.uniform(*scale) * area
            aspect = np.exp(np.random.uniform(*log_ratio))
            new_w = int(round(np.sqrt(target_area * aspect)))
            new_h = int(round(np.sqrt(target_area / aspect)))
            if 0 < new_w <= w and 0 < new_h <= h:
                x1 = np.random.randint(0, w - new_w + 1)
                y1 = np.random.randint(0, h - new_h + 1)
                crop = img[y1:y1+new_h, x1:x1+new_w]
                return cv2.resize(crop, (w, h), interpolation=cv2.INTER_LINEAR)
        # No valid crop after 10 attempts: fall back to a centred square crop.
        min_side = min(h, w)
        y1 = (h - min_side) // 2
        x1 = (w - min_side) // 2
        crop = img[y1:y1+min_side, x1:x1+min_side]
        return cv2.resize(crop, (w, h), interpolation=cv2.INTER_LINEAR)

    def _color_jitter(self, img, br=0.2, ct=0.2, sat=0.2):
        """Randomly perturb brightness, contrast and saturation of an RGB uint8 image."""
        img_f = img.astype(np.float32)
        if br > 0:
            factor = 1.0 + np.random.uniform(-br, br)
            img_f = img_f * factor
        if ct > 0:
            mean = img_f.mean(axis=(0,1), keepdims=True)
            factor = 1.0 + np.random.uniform(-ct, ct)
            img_f = (img_f - mean) * factor + mean
        img_f = np.clip(img_f, 0, 255)

        if sat > 0:
            # Saturation is scaled on the S channel of the HSV representation.
            hsv = cv2.cvtColor(img_f.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
            s_factor = 1.0 + np.random.uniform(-sat, sat)
            hsv[...,1] = np.clip(hsv[...,1] * s_factor, 0, 255)
            img_f = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB).astype(np.float32)

        return np.clip(img_f, 0, 255).astype(np.uint8)

    def _gaussian_blur(self, img):
        """Blur with a randomly chosen 3x3 or 5x5 Gaussian kernel."""
        k = int(np.random.choice([3, 5]))
        return cv2.GaussianBlur(img, (k, k), 0)

    def _gaussian_noise(self, img, sigma=5.0):
        """Add zero-mean Gaussian noise with standard deviation `sigma`, clipped to uint8."""
        noise = np.random.randn(*img.shape).astype(np.float32) * sigma
        out = img.astype(np.float32) + noise
        return np.clip(out, 0, 255).astype(np.uint8)

    def _random_erasing(self, img, area_ratio=(0.02, 0.12), min_aspect=0.3):
        """Return a copy of `img` with one random rectangle overwritten by uniform noise.

        The input array is left untouched. The rectangle covers a fraction of
        the image area drawn from `area_ratio`, with an aspect ratio drawn from
        [min_aspect, 1 / min_aspect].
        """
        out = img.copy()
        h, w = out.shape[:2]
        area = h * w
        for _ in range(10):
            target = np.random.uniform(*area_ratio) * area
            aspect = np.random.uniform(min_aspect, 1/min_aspect)
            er_w = int(round(np.sqrt(target * aspect)))
            er_h = int(round(np.sqrt(target / aspect)))
            if er_w < w and er_h < h:
                x1 = np.random.randint(0, w - er_w + 1)
                y1 = np.random.randint(0, h - er_h + 1)
                # Match whatever trailing (channel) dimensions the image has.
                fill = np.random.randint(0, 256, (er_h, er_w) + out.shape[2:], dtype=np.uint8)
                out[y1:y1+er_h, x1:x1+er_w] = fill
                return out
        return out

    def _maybe_augment(self, img_rgb):
        """Apply each augmentation independently with probability `augment_prob`."""
        if not self.augment:
            return img_rgb

        # Order of ops: geo -> crop -> color -> blur/noise -> erase
        if np.random.rand() < self.augment_prob:
            img_rgb = self._hflip(img_rgb)
        if np.random.rand() < self.augment_prob:
            img_rgb = self._vflip(img_rgb)
        if np.random.rand() < self.augment_prob:
            img_rgb = self._rot90k(img_rgb, np.random.randint(0, 4))
        if np.random.rand() < self.augment_prob:
            img_rgb = self._small_rotate(img_rgb, max_deg=90)
        if np.random.rand() < self.augment_prob:
            img_rgb = self._random_resized_crop(img_rgb, scale=(0.6, 1.0), ratio=(0.75, 1.33))
        if np.random.rand() < self.augment_prob:
            img_rgb = self._color_jitter(img_rgb, br=0.15, ct=0.15, sat=0.15)
        if np.random.rand() < self.augment_prob:
            img_rgb = self._gaussian_blur(img_rgb)
        if np.random.rand() < self.augment_prob:
            img_rgb = self._gaussian_noise(img_rgb, sigma=5.0)
        if np.random.rand() < self.augment_prob:
            img_rgb = self._random_erasing(img_rgb, area_ratio=(0.02, 0.12), min_aspect=0.3)

        return img_rgb

    def _image_path(self, idx):
        """Return the full path of the image at positional row `idx`."""
        # Use the directory given to this dataset, not a module-level global.
        return os.path.join(self.images_dir, self.df["FileName"].iloc[idx])

    def __getitem__(self, idx):
        """Load, optionally augment and preprocess the sample at row `idx`.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self._image_path(idx)
        image = cv2.imread(img_path)  # BGR
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self._maybe_augment(image)
        label_vec = ast.literal_eval(self.df["Label Vector"].iloc[idx])
        label = torch.tensor(label_vec, dtype=torch.float32)
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)
        return pixel_values, label

## Balance Helper Function

In [None]:
def compute_pos_weight_from_csv(train_csv_path: str, device: torch.device, eps: float = 1e-6):
    """Derive per-class `pos_weight` values for BCEWithLogitsLoss from a label CSV.

    Every entry of the `Label Vector` column is parsed as a Python literal and
    the rows are stacked into an (N, C) matrix. The weight of a class is its
    negatives / positives ratio (`eps` guards against division by zero),
    clipped to the range [1, 100].

    Returns:
        A float32 tensor of length C placed on `device`.
    """
    frame = pd.read_csv(train_csv_path)

    rows = []
    for raw_vector in frame["Label Vector"]:
        rows.append(np.array(ast.literal_eval(raw_vector), dtype=np.float32))
    label_matrix = np.stack(rows)  # (N, C)

    n_samples = label_matrix.shape[0]
    positives = label_matrix.sum(axis=0)
    negatives = n_samples - positives

    # Rare classes get a larger ratio; the clip bounds it on both sides
    # (classes without positives would otherwise explode).
    ratio = negatives / (positives + eps)
    clipped = np.clip(ratio, 1.0, 100.0)
    return torch.tensor(clipped, dtype=torch.float32, device=device)

## F1 Calculator

In [None]:

@torch.no_grad()
def evaluate(model, dataloader, device, criterion):
    """Run `model` over `dataloader` and report loss and F1 scores.

    Predictions are obtained by thresholding the sigmoid of the logits at 0.5.

    Args:
        model: Callable as `model(pixel_values=...)` returning an object with `.logits`.
        dataloader: Iterable of `(pixel_values, labels)` batches.
        device: Device the batches are moved to.
        criterion: Loss called as `criterion(logits, labels)`.

    Returns:
        dict with
          - "loss": sample-weighted mean loss over the samples actually seen,
          - "f1_micro" / "f1_macro": scikit-learn F1 scores over all batches,
          - "logits": CPU tensor with the logits of every evaluated sample in
            loader order (an empty tensor when the loader yields nothing).
    """
    model.eval()

    total_loss = 0.0
    seen = 0
    all_pred, all_true, all_logits = [], [], []

    for pixel_values, labels in dataloader:
        pixel_values = pixel_values.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(pixel_values=pixel_values).logits
        loss = criterion(logits, labels)
        n = pixel_values.size(0)
        total_loss += loss.item() * n
        seen += n

        all_pred.append((logits.sigmoid() > 0.5).int().cpu().numpy())
        all_true.append(labels.int().cpu().numpy())
        # Keep every batch (on the CPU) instead of only the last one on the device.
        all_logits.append(logits.detach().cpu())

    if not all_true:
        # Nothing was evaluated: report neutral statistics instead of failing.
        return {"loss": 0.0, "f1_micro": 0.0, "f1_macro": 0.0, "logits": torch.empty((0, 0))}

    all_pred = np.vstack(all_pred)
    all_true = np.vstack(all_true)

    f1_micro = f1_score(all_true, all_pred, average="micro", zero_division=0)
    f1_macro = f1_score(all_true, all_pred, average="macro", zero_division=0)
    # Normalise by the samples seen: a loader with drop_last=True visits fewer
    # samples than len(dataloader.dataset).
    avg_loss = total_loss / max(seen, 1)

    return {"loss": avg_loss, "f1_micro": f1_micro, "f1_macro": f1_macro,
            "logits": torch.cat(all_logits, dim=0)}

## Training Helper

In [None]:
def train_one_epoch(model, dataloader, optimizer, scheduler, device, scaler, criterion):
    """Train `model` for one pass over `dataloader` with optional mixed precision.

    For every batch: forward pass under autocast, scaled backward pass,
    optimizer step through `scaler`, then one scheduler step.

    Args:
        model: Callable as `model(pixel_values=...)` returning an object with `.logits`.
        dataloader: Iterable of `(pixel_values, labels)` batches.
        optimizer: Optimizer stepped through `scaler.step`.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the batches are moved to; float16 autocast is enabled
            only when this is a CUDA device.
        scaler: Gradient scaler providing `scale`, `step` and `update`.
        criterion: Loss called as `criterion(logits, labels)`.

    Returns:
        List with the label tensor of every batch, in iteration order.
    """
    model.train()

    # Decide on mixed precision from the device actually used, not from whether
    # some GPU happens to be installed on the machine.
    use_amp = torch.device(device).type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"
    amp_dtype = torch.float16 if use_amp else torch.bfloat16

    labeles = []
    for pixel_values, labels in dataloader:
        pixel_values = pixel_values.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        labeles.append(labels)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=amp_device, dtype=amp_dtype, enabled=use_amp):
            logits = model(pixel_values=pixel_values).logits
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

    return labeles

## Main Training Loop

In [None]:
# Datasets: augmentation is enabled for training only.
train_ds = KaggleGeographicalDataset(
        train_csv, images_path, processor_name=model_id,
        image_size=image_size, augment=True,
    )
val_ds = KaggleGeographicalDataset(
        val_csv, images_path, processor_name=model_id,
        image_size=image_size, augment=False
    )

# Loaders: the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True, drop_last=True
    )
val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True
    )


# Pretrained backbone with a `num_classes`-way multi-label head;
# ignore_mismatched_sizes=True is passed so a head of a different size can be loaded.
model = AutoModelForImageClassification.from_pretrained(
        model_id,
        num_labels=num_classes,
        problem_type="multi_label_classification",
        ignore_mismatched_sizes=True,
    ).to(device)


# Two AdamW parameter groups: parameters whose name contains one of the
# `no_decay` substrings get weight_decay=0, everything else gets `weight_decay`.
# NOTE(review): the match is a case-sensitive substring test; confirm that
# "LayerNorm.weight" really occurs in this model's parameter names, otherwise
# normalisation weights are decayed. Changing the grouping alters the layout of
# optimizer.state_dict() and therefore of already saved checkpoints.
no_decay = ["bias", "LayerNorm.weight"]
param_groups = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0},
    ]
optimizer = torch.optim.AdamW(param_groups, lr=lr)
# The scheduler is stepped once per batch, so its horizon is batches * epochs.
total_steps = max(1, len(train_loader)) * epochs
warmup_steps = int(total_steps * warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
# Class-imbalance weights computed from the training CSV.
pos_weight = compute_pos_weight_from_csv(train_csv, device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Gradient scaling is active only when CUDA is available.
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
# Create the checkpoint output directory.
os.makedirs(output_dir, exist_ok=True)

# Train
# NOTE(review): `resume` / `resume_path` are not consulted in this cell;
# training always starts from the pretrained weights loaded above.
best_f1 = 0.0
history_f1_val = []
history_f1_train = []
history_loss_val = []
history_loss_train = []
# Per-epoch values returned by train_one_epoch / evaluate (see below).
actual_labels = []
predicted_labels = []
t0 = time.time()
print(f"Start training for {epochs} epochs")
for epoch in range(epochs):
    actual_label = train_one_epoch(model, train_loader, optimizer, scheduler, device, scaler, criterion)
    val_stats  = evaluate(model, val_loader, device, criterion)
    # NOTE(review): the "train" statistics come from train_loader, i.e. with
    # augmentation enabled, shuffling on and the last incomplete batch dropped.
    train_stats = evaluate(model, train_loader, device, criterion)
    print(f"Epochs: {epoch}\n")
    print(f"[Eval] loss={val_stats['loss']:.4f} | f1_micro={val_stats['f1_micro']:.4f} | f1_macro={val_stats['f1_macro']:.4f}")
    print(f"[Train] loss={train_stats['loss']:.4f} | f1_micro={train_stats['f1_micro']:.4f} | f1_macro={train_stats['f1_macro']:.4f}\n")
    history_f1_val.append(val_stats['f1_micro'])
    history_f1_train.append(train_stats['f1_micro'])
    history_loss_val.append(val_stats["loss"])
    history_loss_train.append(train_stats["loss"])

    # Training-batch labels of this epoch and the "logits" entry reported by
    # the validation pass.
    actual_labels.append(actual_label)
    predicted_labels.append(val_stats["logits"])

    # Write a checkpoint whenever validation micro-F1 improves; the file name
    # embeds the score, so earlier checkpoints are not overwritten.
    if val_stats["f1_micro"] > best_f1:
        best_f1 = val_stats["f1_micro"]
        best_path = os.path.join(output_dir, f"best_f1_{best_f1:.4f}.pt")
        torch.save({
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "pos_weight": pos_weight.detach().cpu().tolist(), 
        }, best_path)

        print(f"New best F1_micro={best_f1:.4f} -> {best_path}")
total = str(datetime.timedelta(seconds=int(time.time() - t0)))
print(f"\nDone. Best F1_micro={best_f1:.4f}. Total time: {total}")

## Data Visualization

In [None]:
# Plot per-epoch loss (left) and micro-F1 (right) for train vs. validation.
# The x axis follows the recorded history, so a shortened or interrupted run
# still plots without a length mismatch.
x_axis = np.arange(len(history_loss_train)) + 1
fig = plt.figure(figsize=(12,4))

ax = fig.add_subplot(1, 2, 1)
ax.plot(x_axis, history_loss_train, '-o', label = 'Train loss')
ax.plot(x_axis, history_loss_val, '--<', label = 'Validation loss')
ax.legend(fontsize=15)
ax.grid(True)
ax.set_xlabel('Epoch', size=15)
ax.set_ylabel('Loss', size=15)

ax = fig.add_subplot(1, 2, 2)
ax.plot(x_axis, history_f1_train, '-o', label='Train F1')
ax.plot(x_axis, history_f1_val, '--<', label='Validation F1')
ax.legend(fontsize=15)
ax.grid(True)
ax.set_xlabel('Epoch', size=15)
# The y axis of this panel shows the F1 score, not the epoch.
ax.set_ylabel('F1 (micro)', size=15)
plt.show()

## Grade model from path

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device, criterion, incorrect_predictions=None):
    """Run `model` over `dataloader` and report loss and F1 scores.

    This definition replaces the `evaluate` from the training section once the
    cell is executed, so it accepts the same four-argument call and returns a
    superset of the same keys.

    Args:
        model: Callable as `model(pixel_values=...)` returning an object with `.logits`.
        dataloader: Iterable of `(pixel_values, labels)` batches.
        device: Device the batches are moved to.
        criterion: Loss called as `criterion(logits, labels)`.
        incorrect_predictions: Accepted for backward compatibility; not used.

    Returns:
        dict with
          - "loss": sample-weighted mean loss over the samples actually seen,
          - "f1_micro" / "f1_macro": scikit-learn F1 scores (sigmoid > 0.5),
          - "logits": CPU tensor with the logits of every evaluated sample
            (an empty tensor when the loader yields nothing).
    """
    model.eval()

    total_loss = 0.0
    seen = 0
    all_pred, all_true, all_logits = [], [], []

    for pixel_values, labels in dataloader:
        pixel_values = pixel_values.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(pixel_values=pixel_values).logits
        loss = criterion(logits, labels)
        n = pixel_values.size(0)
        total_loss += loss.item() * n
        seen += n

        all_pred.append((logits.sigmoid() > 0.5).int().cpu().numpy())
        all_true.append(labels.int().cpu().numpy())
        all_logits.append(logits.detach().cpu())

    if not all_true:
        # Nothing was evaluated: report neutral statistics.
        return {"loss": 0.0, "f1_micro": 0.0, "f1_macro": 0.0, "logits": torch.empty((0, 0))}

    all_pred = np.vstack(all_pred)
    all_true = np.vstack(all_true)

    f1_micro = f1_score(all_true, all_pred, average="micro", zero_division=0)
    f1_macro = f1_score(all_true, all_pred, average="macro", zero_division=0)
    # Normalise by the samples seen rather than by the dataset size.
    avg_loss = total_loss / max(seen, 1)

    return {"loss": avg_loss, "f1_micro": f1_micro, "f1_macro": f1_macro,
            "logits": torch.cat(all_logits, dim=0)}

if resume:
    # Load a saved checkpoint and grade it on the validation loader.
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoint files from a trusted source.
    ckpt = torch.load(resume_path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])

    # Reuse the class weights stored with the checkpoint so the reported loss
    # is computed with the same criterion as during training.
    if "pos_weight" in ckpt:
        pos_weight = torch.tensor(ckpt["pos_weight"], dtype=torch.float32, device=device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # Pass a fresh list as the last argument: the `incorrect_predictions`
    # variable is only created by a later cell, so referring to it here fails
    # with a NameError on a fresh kernel.
    stats = evaluate(model, val_loader, device, criterion, [])
    print(f"[Eval] loss={stats['loss']:.4f} | f1_micro={stats['f1_micro']:.4f} | f1_macro={stats['f1_macro']:.4f}")
    

## Grade Model (on validation set)

In [None]:
def to_numpy(x):
    """Detach tensor `x` from autograd, move it to the CPU and return a NumPy array."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

# Grade the current model on the validation CSV (no augmentation) and collect
# the per-batch labels / predictions used by the later analysis cells.
test_ds = KaggleGeographicalDataset("data/val.csv", images_path, processor_name=model_id)
test_dl = DataLoader(
    test_ds, batch_size=batch_size, shuffle=False, num_workers=8, prefetch_factor=8,
    pin_memory=True, persistent_workers=True
)

# Raw CSV rows, used to map a position in the (unshuffled) loader to a file name.
val_data_raw = pd.read_csv(val_csv)

criterion = nn.BCEWithLogitsLoss()
actual_labels = []          # one array of ground-truth labels per batch
predicted_labels = []       # one array of thresholded predictions per batch
image_paths = []            # file names of the processed images, in loader order
incorrect_predictions = []  # dicts with 'image_path', 'actual', 'predicted'

f1_test = 0.0
test_loss = 0.0
image_idx = 0  # number of samples processed so far

# Make sure inference-time behaviour is used regardless of earlier cells.
model.eval()
with torch.no_grad():
    for x, y in tqdm(test_dl):
        # Use the real batch size: the final batch can be smaller than batch_size.
        bs = x.size(0)
        x, y = x.to(device), y.to(device)
        logits = model(pixel_values=x).logits
        loss = criterion(logits, y.float())
        test_loss += loss.item() * bs

        # Convert to numpy for easier handling
        y_np = to_numpy(y)
        pred_np = to_numpy((torch.sigmoid(logits) > 0.5).int())

        # Store batch results
        actual_labels.append(y_np)
        predicted_labels.append(pred_np)

        # The loader is not shuffled, so batch rows line up with CSV rows.
        batch_files = val_data_raw['FileName'].iloc[image_idx:image_idx + bs].tolist()
        image_paths.extend(batch_files)
        for img_path, actual, predicted in zip(batch_files, y_np, pred_np):
            # A sample counts as incorrect when any label differs.
            if not np.array_equal(predicted, actual):
                incorrect_predictions.append({
                    'image_path': img_path,
                    'actual': actual,
                    'predicted': predicted
                })

        # Update image index for next batch
        image_idx += bs

if image_idx:
    test_loss /= image_idx
    # Micro-F1 over the whole set. Averaging per-batch micro-F1 scores is not
    # the same quantity and over-weights a short final batch.
    f1_test = f1_score(
        np.vstack(actual_labels), np.vstack(predicted_labels),
        average='micro', zero_division=0
    )

print(f'Validation F1: {f1_test}')

## Confusion Matrix

In [None]:
def create_multilabel_confusion_matrix(y_true_arrays: List[list], y_pred_arrays: list, test_data_raw: pd.DataFrame):
    """Build a confusion matrix whose classes are whole label vectors.

    Every distinct `Label Vector` in `test_data_raw` becomes one class, named
    by the matching `Label String`. Any true or predicted vector that is not
    one of those combinations is counted under an extra "Other" class.

    Args:
        y_true_arrays: Iterable of 2-D arrays (one per batch) of true label rows.
        y_pred_arrays: Iterable of 2-D arrays (one per batch) of predicted label rows.
        test_data_raw: DataFrame with `Label Vector` and `Label String` columns.
            `Label Vector` may hold Python-literal strings ("[0, 1]") or sequences.

    Returns:
        (cm, cm_df, labels): the confusion matrix returned by scikit-learn,
        the same matrix as a labelled DataFrame, and the list of class names
        (ending with "Other").
    """
    # Map each known label combination to its class index (first occurrence wins).
    class_index = {}
    labels = []
    for _, row in test_data_raw.drop_duplicates(subset=['Label Vector']).iterrows():
        raw_vector = row['Label Vector']
        if isinstance(raw_vector, str):
            # Same parser as the dataset class; tolerant of spacing differences.
            raw_vector = ast.literal_eval(raw_vector)
        vector_tuple = tuple(int(v) for v in raw_vector)
        if vector_tuple in class_index:
            continue
        class_index[vector_tuple] = len(labels)
        labels.append(row['Label String'])

    # Index reserved for the "Other" category.
    other_idx = len(labels)

    def to_class_indices(arrays):
        # tolist() yields plain Python numbers, so 1.0 and 1 hash and compare
        # equal and float label rows still match the integer combinations.
        return [
            class_index.get(tuple(np.asarray(row).tolist()), other_idx)
            for array in arrays
            for row in array
        ]

    y_true_indices = to_class_indices(y_true_arrays)
    y_pred_indices = to_class_indices(y_pred_arrays)

    labels = labels + ["Other"]

    cm = confusion_matrix(y_true_indices, y_pred_indices,
                          labels=list(range(len(labels))))

    # DataFrame view with readable row/column names.
    cm_df = pd.DataFrame(cm, index=labels, columns=labels)

    return cm, cm_df, labels

def plot_confusion_matrix(cm_df, labels, figsize=(15, 15)):
    """Render `cm_df` as an annotated heatmap and return the current figure.

    `labels` is accepted for interface symmetry but is not used here; the
    row and column names come from `cm_df`.
    """
    plt.figure(figsize=figsize)
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')

    # Axis captions and title
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix for Multilabel Classification')

    # Slanted x tick labels stay readable; y tick labels are kept horizontal
    plt.yticks(rotation=0)
    plt.xticks(rotation=45, ha='right')

    plt.tight_layout()
    figure = plt.gcf()
    return figure
# Build the label-combination confusion matrix from the graded validation
# batches and display it.
val_df = pd.read_csv(val_csv)
cm, cm_df, labels = create_multilabel_confusion_matrix(
    y_true_arrays=actual_labels,
    y_pred_arrays=predicted_labels,
    test_data_raw=val_df,
)
fig = plot_confusion_matrix(cm_df, labels)
plt.show()

## Where the model is underperforming

In [None]:
# Now plot the incorrect images
def plot_incorrect_predictions(incorrect_preds, label_names=None, max_images=20, images_dir=None):
    """
    Plot images with incorrect predictions

    Parameters:
    - incorrect_preds: List of dicts with keys 'image_path', 'actual', 'predicted'
    - label_names: List of label names corresponding to positions in the label vector
    - max_images: Maximum number of images to plot
    - images_dir: Directory containing the images; defaults to the module-level
      `images_path` when omitted

    Returns None. When there is nothing to plot, a short message is printed and
    no figure is created.
    """
    # Limit number of images to display
    num_to_show = min(max_images, len(incorrect_preds))
    if num_to_show <= 0:
        print("No incorrect predictions to display.")
        return

    root_dir = images_path if images_dir is None else images_dir

    def describe(vector):
        # Positions set to 1, rendered as names when available, else as indices.
        active = [j for j, val in enumerate(vector) if val == 1]
        if label_names:
            return ', '.join(label_names[j] for j in active) or "None"
        return f"[{', '.join(map(str, active))}]"

    # Smallest square grid that fits all images
    grid_size = int(np.ceil(np.sqrt(num_to_show)))

    plt.figure(figsize=(20, 20))

    for i in range(num_to_show):
        entry = incorrect_preds[i]
        plt.subplot(grid_size, grid_size, i + 1)

        # Load and display image
        img = plt.imread(os.path.join(root_dir, entry['image_path']))
        plt.imshow(img)

        actual_labels_text = describe(entry['actual'])
        pred_labels_text = describe(entry['predicted'])

        # set title
        plt.title(f"Actual: {actual_labels_text}\nPredicted: {pred_labels_text}", fontsize=10)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Plot the incorrect predictions collected by the validation grading loop.
# No label_names are passed, so titles show the indices of the active labels.
plot_incorrect_predictions(incorrect_predictions)

## Model Insights

### Attention Visualization

In [None]:
# Count all parameters and the trainable subset in a single pass.
total_params = 0
trainable_params = 0
for p in model.parameters():
    count = p.numel()
    total_params += count
    if p.requires_grad:
        trainable_params += count
n_params = total_params

print(f"Model params: {n_params/1e6:.2f}M")
print(f"Model parameters: {total_params/1e6:.2f}M total")
print(f"Trainable parameters: {trainable_params/1e6:.2f}M")

# Show the architecture
print("\nModel architecture:\n")
print(model)

# Optional: per-layer parameter counts (trainable parameters only)
print("\nParameter breakdown per layer:\n")
trainable_named = ((name, param) for name, param in model.named_parameters() if param.requires_grad)
for name, param in trainable_named:
    print(f"{name:<60} {param.numel()/1e6:.3f}M")


In [None]:
import math
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import Optional, List, Tuple, Union

def _extract_pixel_values(batch, device):
    pv = None
    if isinstance(batch, dict):
        for k in ("pixel_values", "images", "inputs"):
            if k in batch:
                pv = batch[k]; break
        if pv is None:
            raise KeyError("Batch dict missing 'pixel_values'/'images'/'inputs'.")
    elif isinstance(batch, (list, tuple)):
        if len(batch) > 0 and isinstance(batch[0], dict):
            if "pixel_values" in batch[0]:
                pv = torch.stack([b["pixel_values"] for b in batch], dim=0)
            elif "images" in batch[0]:
                pv = torch.stack([b["images"] for b in batch], dim=0)
            else:
                raise KeyError("List of dicts lacks 'pixel_values'/'images'.")
        else:
            pv = batch[0]
    else:
        raise TypeError(f"Unsupported batch type: {type(batch)}")

    if pv.ndim == 3:
        pv = pv.unsqueeze(0)
    elif pv.ndim != 4:
        raise ValueError(f"Unexpected pixel_values shape: {tuple(pv.shape)}")
    if pv.dtype == torch.uint8:
        pv = pv.float() / 255.0
    return pv.to(device)


def _denorm_rgb(pixel_values: torch.Tensor,
                mean: List[float] = [0.485, 0.456, 0.406],
                std:  List[float] = [0.229, 0.224, 0.225]) -> List:
    mean_t = torch.tensor(mean, device=pixel_values.device).view(1,3,1,1)
    std_t  = torch.tensor(std,  device=pixel_values.device).view(1,3,1,1)
    rgb = (pixel_values * std_t + mean_t).clamp(0,1).detach().cpu()
    return [rgb[i].permute(1,2,0).numpy() for i in range(rgb.shape[0])]


def _infer_grid_hw_from_tokens(T_no_cls: int) -> Tuple[int,int]:
    s = int(round(math.sqrt(T_no_cls)))
    if s * s == T_no_cls:
        return s, s
    best = (1, T_no_cls); best_diff = T_no_cls - 1
    for h in range(1, int(math.sqrt(T_no_cls)) + 1):
        if T_no_cls % h == 0:
            w = T_no_cls // h
            if abs(h - w) < best_diff:
                best_diff = abs(h - w); best = (h, w)
    return best


def _grid_hw_from_model_and_inputs(model, pixel_values, tokens_no_cls: int) -> Tuple[int,int]:
    cfg = getattr(model, "config", None)
    ph = pw = None
    image_size = getattr(cfg, "image_size", None) if cfg else None
    patch_size = getattr(cfg, "patch_size", None) if cfg else None
    if isinstance(patch_size, (list, tuple)) and len(patch_size) == 2:
        ph, pw = patch_size
    elif isinstance(patch_size, int):
        ph = pw = patch_size
    B, C, H, W = pixel_values.shape

    if image_size is not None and ph is not None and pw is not None:
        if isinstance(image_size, (list, tuple)):
            ih, iw = image_size
        else:
            ih = iw = int(image_size)
        gh, gw = max(1, ih // ph), max(1, iw // pw)
        if gh * gw == tokens_no_cls:
            return gh, gw
    if ph is not None and pw is not None and (H % ph == 0) and (W % pw == 0):
        gh, gw = H // ph, W // pw
        if gh * gw == tokens_no_cls:
            return gh, gw
    return _infer_grid_hw_from_tokens(tokens_no_cls)


@torch.no_grad()
def _with_eager_attn_collect_attentions(model, pixel_values):
    """
    Temporarily set attn_implementation='eager' so output_attentions=True is supported.
    Restores the previous implementation afterwards.
    """
    was_training = model.training
    model.eval()

    cfg = getattr(model, "config", None)
    orig_impl = getattr(cfg, "attn_implementation", None) if cfg else None

    # Prefer official setter if present; else set on config directly.
    set_ok = False
    if hasattr(model, "set_attn_implementation"):
        try:
            model.set_attn_implementation("eager")
            set_ok = True
        except Exception:
            pass
    if not set_ok and cfg is not None:
        # Some configs gate via property; set private field directly if needed.
        try:
            cfg.attn_implementation = "eager"
            set_ok = True
        except Exception:
            # last resort for older versions
            setattr(cfg, "_attn_implementation", "eager")
            set_ok = True

    if not set_ok:
        raise RuntimeError("Could not switch attention implementation to 'eager'.")

    # Now we can safely request attentions
    outputs = model(pixel_values=pixel_values, output_attentions=True, return_dict=True)
    atts = outputs.attentions

    # restore previous impl
    try:
        if hasattr(model, "set_attn_implementation") and orig_impl is not None:
            model.set_attn_implementation(orig_impl)
        elif cfg is not None and orig_impl is not None:
            cfg.attn_implementation = orig_impl
    except Exception:
        pass

    if was_training:
        model.train()

    return atts


@torch.no_grad()
def get_attention_heatmap(
    model,
    pixel_values: torch.Tensor,
    layer_idx: int = -1,
    head_idx: Optional[int] = None,
    rollout: bool = False,
    cls_to_patch: bool = True
) -> torch.Tensor:
    """Return (B, H', W') heatmaps in [0,1].

    With `rollout=False` the attention of layer `layer_idx` is used, averaged
    over heads unless `head_idx` selects one. With `rollout=True` the
    head-averaged attention of every layer is combined: the identity is added,
    rows are re-normalised and the layers are multiplied in order.
    `cls_to_patch` chooses between the first-token row (True) and the
    first-token column (False); the first token is treated as the only
    non-patch token. Each map is min-max normalised per image.
    """
    attentions = _with_eager_attn_collect_attentions(model, pixel_values)  # tuple (L) of (B, heads, T, T)
    batch, _, n_tokens, _ = attentions[-1].shape
    grid_h, grid_w = _grid_hw_from_model_and_inputs(model, pixel_values, n_tokens - 1)

    if not rollout:
        layer_att = attentions[layer_idx]
        attn = layer_att[:, head_idx] if head_idx is not None else layer_att.mean(dim=1)
    else:
        identity = torch.eye(n_tokens, device=pixel_values.device).unsqueeze(0).expand(batch, n_tokens, n_tokens)
        attn = identity.clone()
        for layer_att in attentions:
            # Head-averaged attention plus the identity, renormalised row-wise.
            layer_hat = layer_att.mean(dim=1).clamp_min(0) + identity
            layer_hat = layer_hat / layer_hat.sum(dim=-1, keepdim=True)
            attn = torch.bmm(attn, layer_hat)

    # Drop the first token and lay the remaining tokens out on the patch grid.
    token_scores = attn[:, 0, 1:] if cls_to_patch else attn[:, 1:, 0]
    heat = token_scores.reshape(batch, grid_h, grid_w)

    # Per-image min-max normalisation.
    shifted = heat - heat.amin(dim=(1,2), keepdim=True)
    return shifted / shifted.amax(dim=(1,2), keepdim=True).clamp_min(1e-8)


@torch.no_grad()
def visualize_attention_overlays(
    model,
    pixel_values: torch.Tensor,
    rgb_images: Optional[List] = None,
    layer_idx: int = -1,
    head_idx: Optional[int] = None,
    rollout: bool = False,
    cmap: str = "jet",
    alpha: float = 0.45,
    figsize: Union[int, float] = 4,
    title_prefix: str = "Attn",
    save_dir: Optional[str] = None
):
    """Overlay attention heatmaps on each image of a batch.

    The heatmaps from `get_attention_heatmap` are bilinearly upsampled to the
    input resolution and drawn semi-transparently over `rgb_images[i]`, or
    over a min-max scaled view of the first input channel when `rgb_images`
    is None. Each figure is shown, or written as a PNG into `save_dir` (and
    closed) when `save_dir` is given.
    """
    model_device = next(model.parameters()).device
    pixel_values = pixel_values.to(model_device)

    heat = get_attention_heatmap(model, pixel_values, layer_idx, head_idx, rollout)
    n_images, _, height, width = pixel_values.shape
    heat_up = F.interpolate(
        heat.unsqueeze(1), size=(height, width), mode="bilinear", align_corners=False
    ).squeeze(1)

    for i in range(n_images):
        plt.figure(figsize=(figsize, figsize))

        # Background: the supplied RGB image or a grayscale view of channel 0.
        if rgb_images is None:
            channel = pixel_values[i].detach().cpu()[0]
            scaled = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8)
            plt.imshow(scaled, cmap="gray")
        else:
            plt.imshow(rgb_images[i])

        plt.imshow(heat_up[i].detach().cpu(), cmap=cmap, alpha=alpha)
        plt.title(f"{title_prefix} | layer={layer_idx} | head={'avg' if head_idx is None else head_idx} | rollout={rollout}")
        plt.axis('off')
        plt.tight_layout()

        if save_dir is not None:
            import os
            os.makedirs(save_dir, exist_ok=True)
            fname = f"attn_layer{layer_idx}_head{'avg' if head_idx is None else head_idx}_rollout{rollout}_idx{i}.png"
            out_path = os.path.join(save_dir, fname)
            plt.savefig(out_path, dpi=200, bbox_inches="tight")
            plt.close()
        else:
            plt.show()


@torch.no_grad()
def visualize_batch_from_loader(
    model,
    data_loader,
    device,
    n: int = 4,
    layer_idx: int = -1,
    head_idx: Optional[int] = None,
    rollout: bool = False,
    mean: Union[List[float], Tuple[float, ...]] = (0.485, 0.456, 0.406),
    std:  Union[List[float], Tuple[float, ...]] = (0.229, 0.224, 0.225),
    save_dir: Optional[str] = None
):
    """Visualise attention for the first `n` images of the loader's first batch.

    Args:
        model: Model whose attention is visualised; it is switched to eval mode.
        data_loader: Iterable of batches; only the first batch is consumed.
        device: Device the pixel values are moved to.
        n: Maximum number of images taken from the batch.
        layer_idx, head_idx, rollout: Forwarded to the attention visualisation.
        mean, std: Per-channel statistics used to de-normalise the images
            for display (tuples by default; lists are accepted as well).
        save_dir: When given, figures are saved there instead of shown.
    """
    model.eval()
    batch = next(iter(data_loader))
    pixel_values = _extract_pixel_values(batch, device)
    if pixel_values.shape[0] > n:
        pixel_values = pixel_values[:n]
    rgb_images = _denorm_rgb(pixel_values, mean=mean, std=std)
    visualize_attention_overlays(
        model, pixel_values, rgb_images=rgb_images,
        layer_idx=layer_idx, head_idx=head_idx, rollout=rollout,
        title_prefix="DINOv2 Attention", save_dir=save_dir
    )
# Put the model on the target device in eval mode, then visualise attention on
# the first four images of the first validation batch in three ways.
model.to(device).eval()
# 1) last layer, averaged over all heads
visualize_batch_from_loader(model, val_loader, device, n=4, layer_idx=-1, head_idx=None, rollout=False)
# 2) attention rollout across all layers
visualize_batch_from_loader(model, val_loader, device, n=4, rollout=True)
# 3) last layer, head index 3 only
visualize_batch_from_loader(model, val_loader, device, n=4, layer_idx=-1, head_idx=3, rollout=False)