In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import ResNet50_Weights
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.image_list import ImageList
from tqdm import tqdm
import time

# Paths
image_dir = '/home/umang.shikarvar/data/lucknow_airshed/val/images'   # pool scored/sampled by the AL loop
label_dir = '/home/umang.shikarvar/data/lucknow_airshed/val/labels'   # per-image .txt labels for the pool
save_path = '/home/umang.shikarvar/AL/AL_detectors'                   # per-round fine-tuned checkpoints
entropy_file = '/home/umang.shikarvar/AL/uncertain_images.txt'        # cumulative list of selected images
model_path = '/home/umang.shikarvar/AL/source_detectors/model_epoch_30.pth'  # starting checkpoint (state_dict)
os.makedirs(save_path, exist_ok=True)

# Image transform: PIL image -> float tensor (C, H, W) in [0, 1]
transform = transforms.Compose([transforms.ToTensor()])

# Set device. GPU 1 is preferred, but on a single-GPU machine "cuda:1" is an
# invalid device ordinal, so fall back to GPU 0 there.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# === Dataset ===
class TIFRCNNDataset(Dataset):
    """Detection dataset over the selected ``.tif`` images of ``image_dir``.

    Each image ``<stem>.tif`` is paired with ``label_dir/<stem>.txt``, whose
    non-blank lines are parsed as ``cls x1 y1 x2 y2``. Class ids are shifted by
    +1 so that 0 stays reserved for the background class. A missing label file
    yields a sample with no boxes.

    Args:
        image_dir: Directory holding the images.
        label_dir: Directory holding the per-image ``.txt`` label files.
        selected_filenames: Container of image file names to include; membership
            is tested once per directory entry, so a ``set`` is the efficient choice.
        transforms: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_dir, label_dir, selected_filenames, transforms=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        # Only names ending exactly in ".tif" (case-sensitive) are loaded.
        self.image_filenames = sorted(
            f for f in os.listdir(image_dir)
            if f.endswith('.tif') and f in selected_filenames
        )

    def __len__(self):
        return len(self.image_filenames)

    @staticmethod
    def _parse_label_file(label_path):
        """Return ``(boxes, labels)`` parsed from a ``cls x1 y1 x2 y2`` label file.

        Blank lines and malformed lines (wrong field count or non-numeric
        tokens) are skipped. ``boxes`` is float32 with shape (N, 4); ``labels``
        is int64 with shape (N,), holding ``cls + 1`` (background = 0).
        """
        boxes, labels = [], []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.split()
                    if not parts:
                        continue
                    try:
                        cls, x1, y1, x2, y2 = map(float, parts)
                        label = int(cls) + 1  # background=0
                    except (ValueError, OverflowError):
                        # Malformed line: skip it, but let real errors (e.g. I/O,
                        # KeyboardInterrupt) propagate instead of swallowing them.
                        continue
                    boxes.append([x1, y1, x2, y2])
                    labels.append(label)

        boxes = torch.tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0, 4))
        labels = torch.tensor(labels, dtype=torch.int64) if labels else torch.zeros((0,), dtype=torch.int64)
        return boxes, labels

    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # The context manager releases the file handle; convert() returns an
        # independent in-memory copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        label_path = os.path.join(self.label_dir, os.path.splitext(img_name)[0] + ".txt")
        boxes, labels = self._parse_label_file(label_path)

        target = {"boxes": boxes, "labels": labels}
        image = self.transforms(image) if self.transforms else image

        return image, target

def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)``.

    Targets carry a variable number of boxes per image, so samples are not
    stacked; the batch is simply regrouped column-wise. An empty batch gives ``()``.
    """
    columns = zip(*batch)
    return tuple(columns)

# === Model ===
# The backbone is built without pretrained weights: the strict load_state_dict
# below overwrites every parameter and buffer, so downloading ImageNet weights
# first would only add a (possibly offline-failing) network fetch.
backbone = resnet_fpn_backbone(
    backbone_name='resnet50',
    weights=None,
)
# num_classes includes background at index 0 (the dataset shifts class ids by +1).
model = FasterRCNN(backbone, num_classes=4)
# NOTE(review): the checkpoint is loaded as a plain state_dict; on torch >= 1.13
# consider torch.load(..., weights_only=True) to avoid unpickling arbitrary objects.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

# === Entropy function ===
def compute_entropy(logits):
    """Return the per-row Shannon entropy (nats) of the foreground class distribution.

    Column 0 (background) of the 2-D ``logits`` is dropped. The remaining
    columns are softmax-normalised per row, and ``-sum(p * log(p + 1e-8))`` is
    returned with shape ``(logits.shape[0],)``. The 1e-8 keeps ``log`` finite
    when a probability underflows to zero.
    """
    foreground = logits[:, 1:]
    p = foreground.softmax(dim=1)
    return -(p * (p + 1e-8).log()).sum(dim=1)

from torchmetrics.detection.mean_ap import MeanAveragePrecision

def load_txt_as_ground_truth(txt_path):
    """Read a ``cls x1 y1 x2 y2`` label file into a target dict for mAP scoring.

    Lines that do not split into exactly five fields are skipped. A non-numeric
    field in a five-field line raises ``ValueError``. Class ids are kept exactly
    as written (no background shift). A missing file yields empty tensors.

    Returns:
        dict with ``"boxes"`` (float32, shape (N, 4)) and ``"labels"`` (int64, shape (N,)).
    """
    rows = []
    if os.path.exists(txt_path):
        with open(txt_path, 'r') as handle:
            for raw in handle:
                fields = raw.strip().split()
                if len(fields) == 5:
                    rows.append([float(v) for v in fields])

    box_list = [row[1:] for row in rows]
    label_list = [int(row[0]) for row in rows]
    return {
        "boxes": torch.tensor(box_list, dtype=torch.float32) if box_list else torch.zeros((0, 4)),
        "labels": torch.tensor(label_list, dtype=torch.int64) if label_list else torch.zeros((0,), dtype=torch.int64),
    }

def load_txt_as_prediction(txt_path):
    """Read a ``cls conf x1 y1 x2 y2`` prediction file into a detection dict for mAP scoring.

    Lines that do not split into exactly six fields are skipped. A non-numeric
    field in a six-field line raises ``ValueError``. A missing file yields
    empty tensors.

    Returns:
        dict with ``"boxes"`` (float32, (N, 4)), ``"labels"`` (int64, (N,)) and
        ``"scores"`` (float32, (N,)).
    """
    records = []
    if os.path.exists(txt_path):
        with open(txt_path, 'r') as handle:
            for raw in handle:
                tokens = raw.strip().split()
                if len(tokens) == 6:
                    records.append([float(t) for t in tokens])

    label_vals = [int(rec[0]) for rec in records]
    score_vals = [rec[1] for rec in records]
    box_vals = [rec[2:] for rec in records]
    return {
        "boxes": torch.tensor(box_vals, dtype=torch.float32) if box_vals else torch.zeros((0, 4)),
        "labels": torch.tensor(label_vals, dtype=torch.int64) if label_vals else torch.zeros((0,), dtype=torch.int64),
        "scores": torch.tensor(score_vals, dtype=torch.float32) if score_vals else torch.zeros((0,)),
    }

def make_class_agnostic(detection_list):
    """Collapse every detection's labels to class 0, mutating the dicts in place.

    Each dict's ``"labels"`` tensor is replaced by zeros of matching shape,
    dtype and device. The same list object is returned for chaining.
    """
    for detection in detection_list:
        detection.update(labels=torch.zeros_like(detection["labels"]))
    return detection_list

def compute_map(test_img_dir, test_label_dir, pred_dir):
    """Return class-agnostic mAP at IoU 0.50 for the predictions stored in ``pred_dir``.

    Every file in ``test_img_dir`` ending in ``.tif`` (case-insensitive) is
    paired by stem with ``test_label_dir/<stem>.txt`` (ground truth) and
    ``pred_dir/<stem>.txt`` (predictions). Missing files count as empty. All
    labels are collapsed to a single class before scoring, so only box
    localisation and score ranking affect the result.
    """
    stems = [
        os.path.splitext(name)[0]
        for name in sorted(os.listdir(test_img_dir))
        if name.lower().endswith(".tif")
    ]
    # Ground truth and prediction are loaded in lockstep, one stem at a time.
    pairs = [
        (
            load_txt_as_ground_truth(os.path.join(test_label_dir, stem + ".txt")),
            load_txt_as_prediction(os.path.join(pred_dir, stem + ".txt")),
        )
        for stem in stems
    ]
    gt_targets = make_class_agnostic([gt for gt, _ in pairs])
    predictions = make_class_agnostic([pred for _, pred in pairs])

    metric = MeanAveragePrecision()
    metric.update(predictions, gt_targets)
    return metric.compute()["map_50"].item()

# === Active Learning Loop ===
# Each round: score every not-yet-selected pool image by the mean entropy of the
# box classifier over its RPN proposals, add the most uncertain images to the
# cumulative training set, fine-tune on that set, then write test-set
# predictions and report class-agnostic mAP@0.50.
cumulative_filenames = set()
al_rounds = 5
fine_tune_epochs = 6
initial_top_k = 20   # images acquired in the first round
per_round_top_k = 5  # images acquired in every later round

# Test-set locations. Predictions are written to and mAP is read from this one
# directory, so the metric always scores the current round's outputs.
test_image_dir = "/home/umang.shikarvar/data/lucknow_airshed/test/images"
test_label_dir = "/home/umang.shikarvar/data/lucknow_airshed/test/labels"
prediction_output_dir = "/home/umang.shikarvar/data/lucknow_predictions"


def _load_rgb(path):
    """Open an image file and return an RGB copy, releasing the file handle."""
    with Image.open(path) as img:
        return img.convert("RGB")


def _image_entropy(img_path):
    """Return the mean foreground entropy of the box classifier over RPN proposals.

    The image first goes through ``model.transform`` (the normalisation and
    resizing ``FasterRCNN.forward`` applies), so the backbone sees the same kind
    of input as in training and inference. Returns 0.0 when the RPN yields no
    proposals. Expects to be called under ``torch.no_grad()`` with the model in
    eval mode.
    """
    image_tensor = transform(_load_rgb(img_path)).to(device)
    images, _ = model.transform([image_tensor])
    features = model.backbone(images.tensors)
    proposals, _ = model.rpn(images, features)

    if len(proposals[0]) == 0:
        return 0.0

    box_features = model.roi_heads.box_roi_pool(features, proposals, images.image_sizes)
    box_features = model.roi_heads.box_head(box_features)
    class_logits = model.roi_heads.box_predictor.cls_score(box_features)
    return compute_entropy(class_logits).mean().item()


def _write_test_predictions(out_dir):
    """Predict on every ``.tif`` in ``test_image_dir``; write ``<stem>.txt`` files into ``out_dir``.

    Files already in ``out_dir`` are deleted first so stale predictions never
    reach the metric. Each output line is ``label score x1 y1 x2 y2`` using the
    model's label ids.
    """
    os.makedirs(out_dir, exist_ok=True)
    for stale in os.listdir(out_dir):
        os.remove(os.path.join(out_dir, stale))

    model.eval()
    with torch.no_grad():
        for img_file in tqdm(os.listdir(test_image_dir)):
            if not img_file.lower().endswith(".tif"):
                continue

            image_tensor = transform(_load_rgb(os.path.join(test_image_dir, img_file))).to(device)
            outputs = model([image_tensor])[0]  # single-image output dict

            pred_txt_path = os.path.join(out_dir, os.path.splitext(img_file)[0] + ".txt")
            with open(pred_txt_path, "w") as f:
                for box, label, score in zip(outputs["boxes"], outputs["labels"], outputs["scores"]):
                    x1, y1, x2, y2 = box.tolist()
                    f.write(f"{label.item()} {score.item():.4f} {x1:.1f} {y1:.1f} {x2:.1f} {y2:.1f}\n")


for al_epoch in range(al_rounds):
    print(f"\n🔁 Active Learning Epoch {al_epoch + 1}/{al_rounds}")
    top_k = initial_top_k if al_epoch == 0 else per_round_top_k

    # 1. Score every unselected pool image. Candidates are restricted to exact
    #    ".tif" names because TIFRCNNDataset loads nothing else; selecting any
    #    other file would waste an acquisition slot. Sorting makes tie-breaking
    #    deterministic.
    model.eval()
    candidates = [
        fname for fname in sorted(os.listdir(image_dir))
        if fname.endswith(".tif") and fname not in cumulative_filenames
    ]
    with torch.no_grad():
        image_uncertainties = [
            (fname, _image_entropy(os.path.join(image_dir, fname)))
            for fname in tqdm(candidates)
        ]

    # 2. Acquire the top-k most uncertain images (the sort is stable, so ties
    #    keep name order).
    image_uncertainties.sort(key=lambda item: item[1], reverse=True)
    new_filenames = [fname for fname, _ in image_uncertainties[:top_k]]
    cumulative_filenames.update(new_filenames)

    # Persist the cumulative selection after every round.
    with open(entropy_file, "w") as f:
        f.writelines(f"{fname}\n" for fname in sorted(cumulative_filenames))

    # 3. Rebuild the training set from the cumulative selection.
    dataset = TIFRCNNDataset(
        image_dir=image_dir,
        label_dir=label_dir,
        selected_filenames=cumulative_filenames,
        transforms=transform,
    )
    dataloader = DataLoader(dataset, batch_size=10, shuffle=True, pin_memory=True, collate_fn=collate_fn)

    # 4. Fine-tune. A fresh Adam optimizer is created each round, so its moment
    #    estimates restart from zero.
    model.train()
    optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)

    for epoch in range(fine_tune_epochs):
        epoch_loss = 0.0  # summed (not averaged) over batches
        start_time = time.perf_counter()

        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            epoch_loss += losses.item()

        elapsed = time.perf_counter() - start_time
        print(f"[AL {al_epoch + 1} | Epoch {epoch + 1}/{fine_tune_epochs}] Loss: {epoch_loss:.4f} | Time: {elapsed:.2f}s")

    # Save model after each AL round
    torch.save(model.state_dict(), os.path.join(save_path, f"model_al_round_{al_epoch + 1}.pth"))

    # 5. Run inference on the test set and score it from the same directory.
    print(f"\n📤 Running inference on test set after AL round {al_epoch + 1}")
    _write_test_predictions(prediction_output_dir)

    map50 = compute_map(
        test_img_dir=test_image_dir,
        test_label_dir=test_label_dir,
        pred_dir=prediction_output_dir,
    )

    print(f"✅ Class-Agnostic mAP@0.50 after AL round {al_epoch + 1}: {map50:.4f}")


🔁 Active Learning Epoch 1/5


100%|██████████| 65/65 [00:03<00:00, 20.32it/s]


[AL 1 | Epoch 1/6] Loss: 0.9822 | Time: 1.37s
[AL 1 | Epoch 2/6] Loss: 0.5722 | Time: 1.38s
[AL 1 | Epoch 3/6] Loss: 0.6498 | Time: 1.39s
[AL 1 | Epoch 4/6] Loss: 0.4468 | Time: 1.41s
[AL 1 | Epoch 5/6] Loss: 0.4391 | Time: 1.40s
[AL 1 | Epoch 6/6] Loss: 0.3414 | Time: 1.39s

📤 Running inference on test set after AL round 1


100%|██████████| 533/533 [00:28<00:00, 18.49it/s]


✅ Class-Agnostic mAP@0.50 after AL round 1: 0.7204

🔁 Active Learning Epoch 2/5


100%|██████████| 65/65 [00:02<00:00, 28.49it/s]


[AL 2 | Epoch 1/6] Loss: 1.5966 | Time: 1.71s
[AL 2 | Epoch 2/6] Loss: 1.1208 | Time: 1.73s
[AL 2 | Epoch 3/6] Loss: 0.9528 | Time: 1.72s
[AL 2 | Epoch 4/6] Loss: 0.7816 | Time: 1.70s
[AL 2 | Epoch 5/6] Loss: 0.6952 | Time: 1.74s
[AL 2 | Epoch 6/6] Loss: 0.6557 | Time: 1.73s

📤 Running inference on test set after AL round 2


100%|██████████| 533/533 [00:29<00:00, 18.26it/s]


✅ Class-Agnostic mAP@0.50 after AL round 2: 0.7935

🔁 Active Learning Epoch 3/5


100%|██████████| 65/65 [00:02<00:00, 31.72it/s]


[AL 3 | Epoch 1/6] Loss: 0.6525 | Time: 2.07s
[AL 3 | Epoch 2/6] Loss: 0.6590 | Time: 2.03s
[AL 3 | Epoch 3/6] Loss: 0.5611 | Time: 2.04s
[AL 3 | Epoch 4/6] Loss: 0.5164 | Time: 2.01s
[AL 3 | Epoch 5/6] Loss: 0.4558 | Time: 2.08s
[AL 3 | Epoch 6/6] Loss: 0.4079 | Time: 2.14s

📤 Running inference on test set after AL round 3


100%|██████████| 533/533 [00:28<00:00, 18.42it/s]


✅ Class-Agnostic mAP@0.50 after AL round 3: 0.7397

🔁 Active Learning Epoch 4/5


100%|██████████| 65/65 [00:01<00:00, 35.75it/s]


[AL 4 | Epoch 1/6] Loss: 0.7298 | Time: 2.26s
[AL 4 | Epoch 2/6] Loss: 0.6932 | Time: 2.40s
[AL 4 | Epoch 3/6] Loss: 0.6262 | Time: 2.38s
[AL 4 | Epoch 4/6] Loss: 0.6388 | Time: 2.53s
[AL 4 | Epoch 5/6] Loss: 0.5021 | Time: 2.46s
[AL 4 | Epoch 6/6] Loss: 0.4910 | Time: 2.49s

📤 Running inference on test set after AL round 4


100%|██████████| 533/533 [00:28<00:00, 18.79it/s]


✅ Class-Agnostic mAP@0.50 after AL round 4: 0.7820

🔁 Active Learning Epoch 5/5


100%|██████████| 65/65 [00:01<00:00, 43.13it/s]


[AL 5 | Epoch 1/6] Loss: 0.5710 | Time: 2.65s
[AL 5 | Epoch 2/6] Loss: 0.4591 | Time: 2.82s
[AL 5 | Epoch 3/6] Loss: 0.3707 | Time: 2.70s
[AL 5 | Epoch 4/6] Loss: 0.3074 | Time: 2.76s
[AL 5 | Epoch 5/6] Loss: 0.2626 | Time: 2.70s
[AL 5 | Epoch 6/6] Loss: 0.2279 | Time: 2.73s

📤 Running inference on test set after AL round 5


100%|██████████| 533/533 [00:28<00:00, 18.53it/s]


✅ Class-Agnostic mAP@0.50 after AL round 5: 0.8272
