In [None]:
import os

# Install packages
!pip install --upgrade pip
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
!pip install tqdm matplotlib pillow

# Clone the repo using Python logic instead of Bash
repo_url = "https://github.com/taesungp/contrastive-unpaired-translation.git"
repo_name = "contrastive-unpaired-translation"

if not os.path.exists(repo_name):
    print(f"Cloning {repo_name}...")
    !git clone {repo_url}
else:
    print(f"{repo_name} already exists. Skipping clone.")

In [None]:
# Block 1 - Settings & imports
from pathlib import Path
import shutil
import sys
import os
import json
import math
import random
from pprint import pprint

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset

# ---------- EDIT if different ----------
# Every path below is derived from PROJECT_ROOT, which is resolved relative to
# the notebook's working directory (two levels up, then "Assignment 2/Esther").
PROJECT_ROOT = Path.cwd().parents[1] / "Assignment 2" / "Esther"
DATA_ROOT = PROJECT_ROOT / "AML_project_herbarium_dataset"
TRAIN_LIST = DATA_ROOT / "list/train.txt"
TEST_LIST = DATA_ROOT / "list/test.txt"
REMAPPED_DIR = DATA_ROOT / "list/remapped"
TRAIN_REMAP = REMAPPED_DIR / "train_remapped.txt"
TRAIN_WITH_SYN = REMAPPED_DIR / "train_remapped_with_synth.txt"
SPECIES_LIST_PATH = DATA_ROOT / "list/species_list.txt"  # uploaded file path
CLASS_WITHOUT_PAIRS = DATA_ROOT / "list/class_without_pairs.txt"  # original IDs of unpaired classes, one per line
FASTCUT_WORKDIR = PROJECT_ROOT / "models/fastcut_work"
FASTCUT_ROOT = FASTCUT_WORKDIR / "fastcut_data"
FASTCUT_RESULTS = PROJECT_ROOT / "models/fastcut_results/herb2field_cut_fast/test_latest/images"
FAKE_B_DIR = FASTCUT_RESULTS / "fake_B"
SYN_DEST_DIR = DATA_ROOT / "synthetic"

# GPU ID (used for shell commands)
GPU_ID = 0

# Seed Python's `random` and torch (CPU and all CUDA devices). This makes data
# shuffling, augmentation and head initialisation repeatable when the cells run
# top to bottom. cuDNN kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Default training / fine-tune params (the training cells may override these)
EPOCHS_HEAD = 20
EPOCHS_FINE = 20
LR_HEAD = 1e-4
LR_FINE = 1e-5
BATCH_SIZE = 4

# convenience prints
print("PROJECT_ROOT:", PROJECT_ROOT)
print("DATA_ROOT:", DATA_ROOT)
print("TRAIN_REMAP exists:", TRAIN_REMAP.exists())
print("SPECIES_LIST path (uploaded):", SPECIES_LIST_PATH.exists(), SPECIES_LIST_PATH)
print("FASTCUT fake_B exists:", FAKE_B_DIR.exists())
print("FAKE_B count (if exists):", len(list(FAKE_B_DIR.glob("*.png"))) if FAKE_B_DIR.exists() else 0)


In [None]:
# Block 2 - build mappings from species_list.txt
# Each non-empty line starts with the original species ID; only the first
# ';'-separated field is used. Species receive contiguous new indices 0..N-1 in
# file order. The index comes from the number of species mapped so far, not
# the line number, so blank lines cannot create gaps in the label space.
orig_to_new = {}
new_to_orig = {}

if not SPECIES_LIST_PATH.exists():
    raise FileNotFoundError(f"species_list.txt not found at {SPECIES_LIST_PATH}. Please upload it.")

with open(SPECIES_LIST_PATH, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        line = line.strip()
        if not line:
            continue
        orig_str = line.split(";")[0].strip()
        try:
            orig_id = int(orig_str)
        except ValueError:
            # Fall back to the first whitespace-separated token of the field.
            orig_id = int(orig_str.split()[0])
        if orig_id in orig_to_new:
            # Re-mapping would orphan the earlier index; keep the first one.
            print(f"Warning: duplicate species id {orig_id} on line {line_no}; keeping first occurrence.")
            continue
        new_idx = len(orig_to_new)
        orig_to_new[orig_id] = new_idx
        new_to_orig[new_idx] = orig_id

print("Total species mapped:", len(orig_to_new))
# print a few examples
for k in list(orig_to_new.keys())[:10]:
    print(k, "->", orig_to_new[k])


In [None]:
# Block 3 - class_without_pairs -> unpaired_new_ids
# Reads the original species IDs of the unpaired classes (first whitespace token
# of each non-empty line) and converts them to new indices via orig_to_new.
if not CLASS_WITHOUT_PAIRS.exists():
    print("Warning: class_without_pairs.txt not found at", CLASS_WITHOUT_PAIRS)
    # To work without the file, replace this check with a hard-coded list, e.g.:
    # unpaired_orig_ids = [105951, 106387, ...]
    raise FileNotFoundError(f"class_without_pairs.txt missing; expected at {CLASS_WITHOUT_PAIRS}")

unpaired_orig_ids = []
with open(CLASS_WITHOUT_PAIRS, "r", encoding="utf-8") as f:
    for ln in f:
        ln = ln.strip()
        if not ln:
            continue
        try:
            # int(ln) alone can never succeed where this fails, so no second attempt.
            unpaired_orig_ids.append(int(ln.split()[0]))
        except ValueError:
            print("Could not parse line:", ln)

# map to new indices
unpaired_new_ids = set()
missing_orig = []
for oid in unpaired_orig_ids:
    if oid in orig_to_new:
        unpaired_new_ids.add(orig_to_new[oid])
    else:
        missing_orig.append(oid)

print("Unpaired original count:", len(unpaired_orig_ids))
print("Mapped to new indices count:", len(unpaired_new_ids))
if missing_orig:
    print("Missing orig IDs from species_list mapping (should be none):", missing_orig)


In [None]:
# Block 3.5 - Generate train_remapped.txt
# Rewrites each "<rel_path> <orig_id>" line of TRAIN_LIST as "<rel_path> <new_id>"
# using orig_to_new (Block 2) and writes the result to TRAIN_REMAP.
# Lines with fewer than two tokens or a non-integer ID are dropped without
# being counted; IDs missing from species_list.txt are dropped and counted.

print(f"Generating {TRAIN_REMAP} from {TRAIN_LIST}...")

# Make sure list/remapped/ exists before writing into it.
TRAIN_REMAP.parent.mkdir(parents=True, exist_ok=True)

remapped_lines = []
skipped_count = 0

if not TRAIN_LIST.exists():
    raise FileNotFoundError(f"Original train list not found at {TRAIN_LIST}")

with open(TRAIN_LIST, "r", encoding="utf-8") as f_in:
    for raw_line in f_in:
        tokens = raw_line.strip().split()
        if len(tokens) < 2:
            continue
        rel_path, id_token = tokens[0], tokens[1]
        try:
            orig_id = int(id_token)  # the original species ID is the second token
        except ValueError:
            continue
        new_id = orig_to_new.get(orig_id)
        if new_id is None:
            skipped_count += 1
        else:
            remapped_lines.append(f"{rel_path} {new_id}")

with open(TRAIN_REMAP, "w", encoding="utf-8") as f_out:
    f_out.writelines(entry + "\n" for entry in remapped_lines)

print(f"Done! Created {TRAIN_REMAP}")
print(f"Total lines written: {len(remapped_lines)}")
print(f"Skipped (ID not in species_list): {skipped_count}")

In [None]:
# Block 4 - check train_remapped.txt
# Loads TRAIN_REMAP as (rel_path, new_class) pairs, shows a preview, and counts
# herbarium images that belong to the unpaired classes.
if not TRAIN_REMAP.exists():
    raise FileNotFoundError("train_remapped.txt not found at expected location: " + str(TRAIN_REMAP))

lines = []
with open(TRAIN_REMAP, "r", encoding="utf-8") as f:
    for ln in f:
        fields = ln.split()
        if len(fields) >= 2:
            lines.append((fields[0], int(fields[1])))

print("First 10 lines of train_remapped.txt:")
for item in lines[:10]:
    print(item)
print("Total remapped train lines:", len(lines))

# Herbarium images (path contains "herbarium") whose class is unpaired
herb_unpaired = [rel for rel, cls in lines if cls in unpaired_new_ids and "herbarium" in rel.lower()]
count_herb_unpaired = len(herb_unpaired)
print("Herbarium images for unpaired classes in train_remapped.txt:", count_herb_unpaired)


In [None]:
# Clean folders: delete old FastCUT checkpoints, old FastCUT test results and
# the prepared FastCUT data so the next training/inference run starts fresh.
import shutil
from pathlib import Path

# Relative to the notebook's working directory, where the repo was cloned.
CHECK_DIR = Path("contrastive-unpaired-translation/checkpoints/herb2field_cut_fast")
# Must match the --results_dir that Block 7 passes to test.py
# ("{PROJECT_ROOT}/models/fastcut_results") plus the experiment name. If it
# pointed anywhere else, old fake_B images would survive and be re-used by Block 9.
RESULT_DIR = Path(PROJECT_ROOT) / "models/fastcut_results/herb2field_cut_fast"

shutil.rmtree(CHECK_DIR, ignore_errors=True)
shutil.rmtree(RESULT_DIR, ignore_errors=True)

print("Cleaned old checkpoints + results.")

shutil.rmtree(FASTCUT_ROOT, ignore_errors=True)
FASTCUT_ROOT.mkdir(parents=True, exist_ok=True)
print("Cleaned FASTCUT data directories.")

In [None]:
# Block 5 - prepare FASTCUT directories (trainA/trainB/testA)
# trainA: every herbarium image listed in TRAIN_REMAP (domain A)
# trainB: every photo/field image listed in TRAIN_REMAP (domain B)
# testA : herbarium images of the unpaired classes; FastCUT inference turns these
#         into synthetic field images, which Block 9 labels and adds to training.
# valA/valB are created empty.
# Images are copied flat (by basename); a later file with the same basename
# overwrites an earlier one, and such overwrites are counted and reported.
from pathlib import Path
import shutil

FASTCUT_ROOT = Path(FASTCUT_WORKDIR) / "fastcut_data"
trainA = FASTCUT_ROOT / "trainA"
trainB = FASTCUT_ROOT / "trainB"
valA = FASTCUT_ROOT / "valA"
valB = FASTCUT_ROOT / "valB"
testA = FASTCUT_ROOT / "testA"

# reset directories so nothing from an earlier run leaks in
for p in (trainA, trainB, valA, valB, testA):
    if p.exists():
        shutil.rmtree(p)
    p.mkdir(parents=True, exist_ok=True)


def _copy_flat(src, dst_dir):
    """Copy src into dst_dir under its basename; return True if that overwrote an earlier copy."""
    dst = dst_dir / src.name
    overwrote = dst.exists()
    shutil.copy(src, dst)
    return overwrote


addedA = 0
addedB = 0
added_testA = 0
name_collisions = 0

# A single pass over the remapped list fills trainA, trainB and testA.
with open(TRAIN_REMAP, "r", encoding="utf-8") as f:
    for ln in f:
        ln = ln.strip()
        if not ln:
            continue
        rel, cls_str = ln.split()
        cls = int(cls_str)
        src = Path(DATA_ROOT) / rel
        if not src.exists():
            continue
        rel_lower = rel.lower()
        if "herbarium" in rel_lower:
            name_collisions += _copy_flat(src, trainA)
            addedA += 1
            if cls in unpaired_new_ids:
                name_collisions += _copy_flat(src, testA)
                added_testA += 1
        elif "photo" in rel_lower or "field" in rel_lower:
            name_collisions += _copy_flat(src, trainB)
            addedB += 1

print("TrainA (herbarium) count:", addedA)
print("TrainB (photo/field) count:", addedB)
print("TestA (herbarium - unpaired candidate) count:", added_testA)
if name_collisions:
    print(f"Warning: {name_collisions} copies overwrote a file with the same basename; "
          "the directories hold fewer files than the counts above.")

In [None]:
# Block 6 - Train FastCUT (Windows Fixed - Corrected Arguments)
# Trains the herbarium (trainA) -> field (trainB) translator on FASTCUT_ROOT.
# Checkpoints are written to
# contrastive-unpaired-translation/checkpoints/herb2field_cut_fast
# (the directory the "Clean folders" cell deletes).
#
# NOTE(review): `--model cut` without `--CUT_mode FastCUT` should train the
# default CUT variant, although the experiment name and comments say FastCUT.
# Add `--CUT_mode FastCUT` if FastCUT is intended. Confirm against the repo's
# cut_model.py options.

# 1. Enter the directory
%cd contrastive-unpaired-translation

# 2. Install dependencies (if not already done)
%pip install dominate visdom

# 3. Run training
# Changes made:
# - Removed backslashes (\) to fit on one line for Windows stability
# - Changed --n_threads 0 to --num_threads 0 (this was the error)
# - Kept -u for unbuffered output so you can see the progress bar
# {FASTCUT_ROOT} and {GPU_ID} are substituted by IPython from kernel variables.
# --n_epochs 10 + --n_epochs_decay 10: 10 epochs at the initial LR, then 10 with linear LR decay.
!python -u train.py --dataroot "{FASTCUT_ROOT}" --name herb2field_cut_fast --model cut --no_dropout --gpu_ids {GPU_ID} --n_epochs 10 --n_epochs_decay 10 --batch_size 4 --save_epoch_freq 10 --print_freq 200 --load_size 192 --crop_size 128 --nce_layers 4,8 --num_patches 128 --no_html --display_id -1 --num_threads 0

# 4. Go back to original directory
%cd -

In [None]:
ckpt_listing = os.listdir("contrastive-unpaired-translation/checkpoints/herb2field_cut_fast")  # files saved by train.py
print(ckpt_listing)

In [None]:
import os

print(f"Listing all files in: {FASTCUT_ROOT}")

# os.walk recurses into every subfolder (like `ls -R`). Sorting `dirs` in place
# fixes the traversal order, and sorting `files` fixes the listing order, so
# output is stable between runs.
for root, dirs, files in os.walk(FASTCUT_ROOT):
    dirs.sort()
    print(f"\n📂 {root}")
    for file in sorted(files):
        print(f"   └── {file}")

In [None]:
import os
# Expose the FastCUT data root as an environment variable; child processes
# (e.g. `!` shell commands) inherit it.
os.environ.update({"FASTCUT_ROOT": str(FASTCUT_ROOT)})
FASTCUT_ROOT

In [None]:
# Block - Setup testB with a dummy image (Python Version)
# CUT's test.py loads testA and testB together, so testB appears to need at
# least one image even though only the testA -> fake_B outputs are used later.
# One trainB image is copied in as "dummy.jpg".
import shutil
from pathlib import Path

# Ensure FASTCUT_ROOT is a Path object
root = Path(FASTCUT_ROOT)
trainB = root / "trainB"
testB = root / "testB"

print(f"FASTCUT_ROOT = {root}")

# 1. Create testB directory
testB.mkdir(parents=True, exist_ok=True)

# 2. Pick one field image from trainB.
# Path.glob yields nothing for a missing directory instead of raising, so check
# explicitly to report the real problem.
if not trainB.is_dir():
    print(f"❌ Error: Directory not found: {trainB}")
else:
    # Sorted so the same file is chosen on every run (directory order is arbitrary).
    candidates = sorted(p for p in trainB.iterdir() if p.is_file())
    if not candidates:
        print("❌ Error: No images found in trainB! Make sure Block 5 ran correctly.")
    else:
        first_field_img = candidates[0]
        print(f"Dummy image source: {first_field_img.name}")

        # 3. Copy dummy image
        dest = testB / "dummy.jpg"
        shutil.copy(first_field_img, dest)
        print(f"Successfully copied to: {dest}")

        # 4. Verify (List files in testB)
        print("Contents of testB:", list(testB.glob("*")))

In [None]:
# Block 7 - FastCUT inference (Windows Fixed)
# Translates every image in FASTCUT_ROOT/testA with the latest herb2field_cut_fast
# checkpoint. Outputs land in
# {PROJECT_ROOT}/models/fastcut_results/herb2field_cut_fast/test_latest/images/,
# whose fake_B/ folder is read by Block 9.

# 1. Enter the directory
%cd contrastive-unpaired-translation

# 2. Run inference on a SINGLE LINE
# Changes:
# - Combined into one line to fix the {GPU_ID} error
# - Changed {DRIVE_ROOT} to {PROJECT_ROOT} so it saves to your local hard drive correctly
# --num_test 999999 effectively removes the cap on the number of images processed.
!python test.py --dataroot "{FASTCUT_ROOT}" --name herb2field_cut_fast --model cut --no_dropout --phase test --serial_batches --results_dir "{PROJECT_ROOT}/models/fastcut_results" --num_test 999999 --gpu_ids {GPU_ID}

# 3. Return to the previous working directory (`%cd -`), i.e. where the notebook was before step 1
%cd -

In [None]:
# verify: check that FastCUT inference wrote fake_B PNGs
from pathlib import Path
fakeB = Path(PROJECT_ROOT) / "models/fastcut_results/herb2field_cut_fast/test_latest/images/fake_B"
png_files = list(fakeB.glob("*.png"))
print("fake_B exists:", fakeB.exists())
print("png count:", len(png_files))


In [None]:
# Block 9 - Integrate synthetic images via filename (stem) matching
#
# Each fake_B image is matched to its testA source by file stem. The expected
# naming is testA/12345.jpg -> fake_B/12345.png: CUT's save_images names
# outputs after the input's basename (confirm for your CUT version). Each match
# is labelled with its source's class, copied into DATA_ROOT/synthetic, and
# appended to a copy of TRAIN_REMAP saved as TRAIN_WITH_SYN.
# Matching by stem, not by position in two sorted lists, means one missing or
# extra file cannot shift every later pairing.

from pathlib import Path
import shutil
import re

# Paths
SYN_SRC_FAKEB = Path(PROJECT_ROOT) / "models/fastcut_results/herb2field_cut_fast/test_latest/images/fake_B"
FASTCUT_TESTA = FASTCUT_ROOT / "testA"
OUTPUT_LIST = TRAIN_WITH_SYN
SYN_DEST_DIR = Path(DATA_ROOT) / "synthetic"
SYN_DEST_DIR.mkdir(parents=True, exist_ok=True)

IMG_SUFFIXES = {".png", ".jpg", ".jpeg"}


def extract_id(path_obj):
    """Return the last run of digits in the file stem as an int, or -1 if there is none.

    Used only to process testA files in a stable, numeric-aware order.
    """
    nums = re.findall(r'\d+', path_obj.stem)
    return int(nums[-1]) if nums else -1


# ------- testA sources (ordered) and fake_B outputs indexed by stem --------
testA_files = sorted(
    (p for p in FASTCUT_TESTA.iterdir() if p.is_file()),
    key=lambda p: (extract_id(p), p.name),  # name breaks ties deterministically
)
fakeB_by_stem = {
    p.stem: p
    for p in SYN_SRC_FAKEB.iterdir()
    if p.is_file() and p.suffix.lower() in IMG_SUFFIXES
}

print("testA count:", len(testA_files))
print("fakeB count:", len(fakeB_by_stem))
if len(testA_files) != len(fakeB_by_stem):
    print("⚠ WARNING: testA and fakeB counts differ; unmatched files are reported below.")

# ------- Load original remap list --------
with open(TRAIN_REMAP, "r", encoding="utf-8") as f:
    orig_train_lines = [ln.strip() for ln in f if ln.strip()]

# Label lookup built with the same selection Block 5 used for testA:
# existing herbarium images of unpaired classes, keyed by basename. Later lines
# overwrite earlier ones, matching how a later same-named copy won in testA/.
testA_name_to_label = {}
for ln in orig_train_lines:
    parts = ln.split()
    if len(parts) < 2:
        continue
    rel, cls = parts[0], int(parts[1])
    if cls in unpaired_new_ids and "herbarium" in rel.lower() and (Path(DATA_ROOT) / rel).exists():
        testA_name_to_label[Path(rel).name] = cls

print("Indexed unpaired herbarium entries:", len(testA_name_to_label))

# ------- NEW list for training with synthetic --------
new_lines = orig_train_lines.copy()
added = 0
missing_match = 0
examples = []

for testA_path in testA_files:
    fakeB_path = fakeB_by_stem.get(testA_path.stem)
    if fakeB_path is None:
        print(f"⚠ No fake_B output for {testA_path.name}, skipping.")
        missing_match += 1
        continue

    cls = testA_name_to_label.get(testA_path.name)
    if cls is None:
        print("⚠ No label found for", testA_path.name)
        missing_match += 1
        continue

    # Always overwrite: images from a re-trained FastCUT must replace older ones.
    dest = SYN_DEST_DIR / fakeB_path.name
    shutil.copy(fakeB_path, dest)

    # Path relative to DATA_ROOT, the root SimpleDataset resolves list entries against.
    syn_rel = dest.relative_to(Path(DATA_ROOT)).as_posix()
    new_lines.append(f"{syn_rel} {cls}")
    added += 1

    if len(examples) < 5:
        examples.append((testA_path.name, fakeB_path.name, cls))

orphan_fakes = set(fakeB_by_stem) - {p.stem for p in testA_files}

print("✨ Synthetic integration complete!")
print("Added synthetic images:", added)
print("Missing matches:", missing_match)
if orphan_fakes:
    print("fake_B files with no testA source (stale results?):", len(orphan_fakes))
print("Examples:", examples)

# ------- Save new training list --------
with open(OUTPUT_LIST, "w", encoding="utf-8") as f:
    for ln in new_lines:
        f.write(ln + "\n")

print("📁 Saved new train-with-synth file:", OUTPUT_LIST)
print("Total training lines:", len(new_lines))

In [None]:
# === Block X: Build validation list using test.txt + groundtruth.txt ===
# Writes "<rel_path> <new_label>" for each test image that has a ground-truth
# entry whose original species ID is mapped in orig_to_new (Block 2).
from pathlib import Path

TEST_LIST = Path(DATA_ROOT) / "list/test.txt"
GROUNDTRUTH = Path(DATA_ROOT) / "list/groundtruth.txt"
VAL_OUT = REMAPPED_DIR / "test_remapped_fixed.txt"
# Must match the classifier's output size (NUM_CLASSES = 100 in Block 11).
EXPECTED_NUM_CLASSES = 100

# sanity checks
if not TEST_LIST.exists():
    raise FileNotFoundError("Missing test.txt at: " + str(TEST_LIST))
if not GROUNDTRUTH.exists():
    raise FileNotFoundError("Missing groundtruth.txt at: " + str(GROUNDTRUTH))
if len(orig_to_new) != EXPECTED_NUM_CLASSES:
    print(f"Warning: species_list.txt maps {len(orig_to_new)} species, "
          f"but labels are checked against {EXPECTED_NUM_CLASSES} classes.")

# Step 1: Load test list (first token of each non-empty line = relative path,
# so an optional trailing label column does not break the groundtruth lookup)
test_paths = []
with open(TEST_LIST, "r", encoding="utf-8") as f:
    for ln in f:
        parts = ln.split()
        if parts:
            test_paths.append(parts[0])

print("Found test images:", len(test_paths))

# Step 2: Load groundtruth mapping: path -> original class ID
gt_map = {}
with open(GROUNDTRUTH, "r", encoding="utf-8") as f:
    for ln in f:
        parts = ln.split()
        if len(parts) != 2:
            continue
        gt_map[parts[0]] = int(parts[1])

print("Groundtruth entries:", len(gt_map))

# Step 3: Build final validation list with new indices
VAL_OUT.parent.mkdir(parents=True, exist_ok=True)
missing_gt = 0
unmapped = 0
written = 0
with open(VAL_OUT, "w", encoding="utf-8") as f_out:
    for rel in test_paths:
        orig_lab = gt_map.get(rel)
        if orig_lab is None:
            missing_gt += 1
            continue
        new_lab = orig_to_new.get(orig_lab)
        if new_lab is None:
            unmapped += 1
            continue
        f_out.write(f"{rel} {new_lab}\n")
        written += 1
missing = missing_gt + unmapped

print("\n=== Validation file built ===")
print("Written lines:", written)
print("Missing gt or unmapped:", missing)
print(f"  (no groundtruth: {missing_gt}, species not in species_list: {unmapped})")
print("Saved to:", VAL_OUT)

# Step 4: Sanity check - every label must lie in 0..EXPECTED_NUM_CLASSES-1
bad = []
with open(VAL_OUT, "r", encoding="utf-8") as f:
    for i, ln in enumerate(f):
        parts = ln.strip().split()
        if len(parts) != 2:
            continue
        lab = int(parts[1])
        if not 0 <= lab < EXPECTED_NUM_CLASSES:
            bad.append((i, lab, ln.strip()))

print("Bad labels remaining:", len(bad))
print(bad[:10])


In [None]:
# Block 10 - Dataset class and dataloaders
from PIL import Image
import torchvision.transforms as transforms
import torch

class SimpleDataset(torch.utils.data.Dataset):
    """Image dataset backed by a whitespace-separated list file.

    Each non-empty line is ``<rel_path> [<class_id> [<domain_id>]]``. Paths are
    resolved against ``root``; a missing class id becomes -1 and a missing domain
    id becomes 0.

    Args:
        root: directory the relative paths are resolved against.
        list_path: list file, read as UTF-8.
        transform: optional callable applied to the RGB PIL image.
        expect_domain: stored on the instance; not used by this class.

    ``__getitem__`` returns ``(image, class_id, domain_id)``, or ``None`` when the
    image cannot be opened. Use ``clean_collate`` to drop those samples.
    """

    def __init__(self, root, list_path, transform=None, expect_domain=False):
        self.root = Path(root)
        self.transform = transform
        self.expect_domain = expect_domain
        self.samples = []
        with open(list_path, "r", encoding="utf-8") as f:
            for ln in f:
                parts = ln.split()
                if not parts:
                    continue
                rel = parts[0]
                cls = int(parts[1]) if len(parts) > 1 else -1
                dom = int(parts[2]) if len(parts) > 2 else 0
                self.samples.append((self.root / rel, cls, dom))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, cls, dom = self.samples[idx]
        try:
            # Image.open is lazy and keeps the file open. convert() loads the
            # pixels into a new image, so the `with` can close the handle here
            # instead of leaking it until garbage collection.
            with Image.open(path) as im:
                img = im.convert("RGB")
        except Exception as e:
            # Warn (on every failure) and return None so clean_collate drops the sample.
            print(f"[WARN] failed to open {path}: {e}")
            return None
        if self.transform:
            img = self.transform(img)
        return img, cls, dom

def clean_collate(batch):
    """Drop ``None`` samples (failed image loads), then default-collate the rest.

    Returns ``None`` when nothing is left, so training/eval loops must be
    prepared to skip ``None`` batches.
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.default_collate(kept) if kept else None

# transforms (match backbone input; adjust sizes as needed)
# 518 x 518 = 37 x 37 patches of 14 px, matching the patch-14 DINOv2 backbone.
# Mean/std are the standard ImageNet normalisation constants.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_tf = transforms.Compose([
    transforms.Resize((518, 518)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
val_tf = transforms.Compose([
    transforms.Resize((518, 518)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Use the config-cell path TRAIN_WITH_SYN directly instead of OUTPUT_LIST,
# which exists only as a side effect of running Block 9 in this kernel.
TRAIN_LIST_FOR_CLS = TRAIN_WITH_SYN
VAL_LIST_FOR_CLS = REMAPPED_DIR / "test_remapped_fixed.txt"
for _list_path in (TRAIN_LIST_FOR_CLS, VAL_LIST_FOR_CLS):
    if not _list_path.exists():
        raise FileNotFoundError(f"List file not found: {_list_path} (build it with Block 9 / the validation-list cell)")

# Build datasets using the remapped lists
train_ds = SimpleDataset(DATA_ROOT, str(TRAIN_LIST_FOR_CLS), transform=train_tf, expect_domain=False)
val_ds = SimpleDataset(DATA_ROOT, str(VAL_LIST_FOR_CLS), transform=val_tf, expect_domain=False)

# clean_collate drops unreadable images; a batch where all fail becomes None.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=clean_collate)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=clean_collate)

print("Train size:", len(train_ds), "Val size:", len(val_ds))


In [None]:
# Block 11 - Load Dinov2 backbone checkpoint and build a wrapper
import timm
import torch.nn as nn

# Fine-tuned backbone weights, loaded into model.backbone by a later cell.
BACKBONE_CHECKPOINT = f"{PROJECT_ROOT}/model_best.pth.tar"  # edit if different
checkpoint_present = Path(BACKBONE_CHECKPOINT).exists()
if not checkpoint_present:
    print("Warning: backbone checkpoint not found at", BACKBONE_CHECKPOINT)
print("Attempting to create backbone model...")

backbone = timm.create_model(
    "vit_base_patch14_reg4_dinov2.lvd142m",
    pretrained=False,  # random init here; real weights come from BACKBONE_CHECKPOINT
    num_classes=0,     # no classifier: forward() returns pooled features
)
# Prefer timm's num_features, then embed_dim, then ViT-B's width of 768.
feat_dim = getattr(backbone, "num_features", None) or getattr(backbone, "embed_dim", 768)
print("Feature dim:", feat_dim)

class DINOWrapper(nn.Module):
    """Feature backbone followed by a single linear classification head.

    Args:
        backbone: module mapping an image batch to (B, feat_dim) features.
        feat_dim: width of the backbone features.
        num_classes: number of logits produced by the head.
    """

    def __init__(self, backbone, feat_dim, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, x, return_feat=False):
        """Return class logits, or the raw backbone features when ``return_feat`` is True."""
        feats = self.backbone(x)
        return feats if return_feat else self.head(feats)

NUM_CLASSES = 100  # presumably the number of species in species_list.txt - confirm
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DINOWrapper(backbone, feat_dim, NUM_CLASSES).to(DEVICE)
print("Model created. Device:", DEVICE)


In [None]:
# Smoke test: one forward pass on a random 518x518 image to check the output shape.
# Runs on whatever device the model already uses. nn.Module.to() works in place,
# so `model.to("cpu")` would move `model` itself off the GPU.
_smoke_device = next(model.parameters()).device
_was_training = model.training
model.eval()
with torch.no_grad():  # no autograd graph needed for a shape check
    x = torch.randn(1, 3, 518, 518, device=_smoke_device)
    out = model(x)
model.train(_was_training)  # restore the previous train/eval mode
print(out.shape)

In [None]:
# Block 12 - permissive checkpoint load (to backbone)
# If BACKBONE_CHECKPOINT exists, load it into model.backbone with strict=False
# after stripping a leading "module." (DataParallel) and "backbone." from keys.
import torch
import argparse
from pathlib import Path


def _strip_prefix(key, prefix):
    """Remove `prefix` from the start of `key` only; str.replace would also cut matches inside the key."""
    return key[len(prefix):] if key.startswith(prefix) else key


ckpt_path = Path(BACKBONE_CHECKPOINT)
if ckpt_path.exists():
    # SECURITY: weights_only=False uses the full unpickler, which can execute
    # arbitrary code from a malicious file; only load checkpoints you trust. It
    # is used because the checkpoint may contain non-tensor objects such as
    # argparse.Namespace. (torch.serialization.add_safe_globals only affects
    # weights_only=True loads, so it is not needed here.)
    ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)

    # find state dict: under "state_dict", a pickled module, or the object itself
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        state = ckpt["state_dict"]
    elif isinstance(ckpt, torch.nn.Module):
        state = ckpt.state_dict()
    else:
        state = ckpt

    new_state = {
        _strip_prefix(_strip_prefix(k, "module."), "backbone."): v
        for k, v in state.items()
    }

    missing, unexpected = model.backbone.load_state_dict(new_state, strict=False)
    print("Backbone load missing:", len(missing), "unexpected:", len(unexpected))
    # A few key names make a prefix mismatch (e.g. everything missing) easy to spot.
    if missing:
        print("  e.g. missing:", list(missing)[:5])
    if unexpected:
        print("  e.g. unexpected:", list(unexpected)[:5])
else:
    print("Checkpoint not found, continuing with random init for backbone.")


In [None]:
# ===== Block X : Build paired_set and unpaired_set =====
# unpaired_set: new indices of the classes listed in class_without_pairs.txt
# paired_set  : every other class index in 0..NUM_CLASSES-1
from pathlib import Path

print("=== Building paired and unpaired sets ===")

# Same file as CLASS_WITHOUT_PAIRS (config cell).
UNPAIRED_PATH = CLASS_WITHOUT_PAIRS

# Same parsing rule as Block 3 (first whitespace token as int), so both cells
# agree. The previous isdigit() test silently dropped lines with extra tokens.
unpaired_orig_ids = []
with open(UNPAIRED_PATH, "r", encoding="utf-8") as f:
    for ln in f:
        ln = ln.strip()
        if not ln:
            continue
        try:
            unpaired_orig_ids.append(int(ln.split()[0]))
        except ValueError:
            print("Could not parse line:", ln)

print("Unpaired ORIGINAL species count:", len(unpaired_orig_ids))

# Convert original IDs -> new mapped IDs
unpaired_set = set()
for oid in unpaired_orig_ids:
    if oid in orig_to_new:
        unpaired_set.add(orig_to_new[oid])
    else:
        print("WARNING: original ID not in mapping:", oid)

print("Unpaired mapped count:", len(unpaired_set))

# Paired set = all classes except unpaired ones
all_classes = set(range(NUM_CLASSES))
paired_set = all_classes - unpaired_set

print("Paired count:", len(paired_set))

print("Sample unpaired mapped:", sorted(unpaired_set)[:10])
print("Sample paired mapped:", sorted(paired_set)[:10])


In [None]:
import os
# Debugging aid: makes CUDA kernel launches synchronous so errors surface at the
# failing call (and slows training).
# NOTE(review): the variable is read when the CUDA context is initialised. An
# earlier cell already moved the model to DEVICE, so in a GPU kernel this likely
# has no effect. Set it before any CUDA use, or remove it. Confirm.
os.environ["CUDA_LAUNCH_BLOCKING"]="1"


In [None]:
# ===== Block 13 : Head-only + Fine-tune training + dual checkpoint saving =====
# Stage 1 trains the linear head with the backbone frozen. Stage 2 unfreezes
# everything. Two checkpoints are kept: best overall val Top-1 and best Top-1
# on the unpaired classes.
import torch, torch.optim as optim, torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

# --- Config (tune these; they override the config-cell defaults) ---
EPOCHS_HEAD = 10       # head-only epochs
EPOCHS_FINE = 10       # full fine-tune epochs
LR_HEAD = 1e-4
LR_FINE = 1e-5
WEIGHT_DECAY = 1e-4
# NOTE: train_loader/val_loader already exist; changing this does not resize their batches.
BATCH_SIZE = 4
SAVE_DIR = PROJECT_ROOT / "models"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

best_overall_path, best_unpaired_path = SAVE_DIR / "best_overall.pth", SAVE_DIR / "best_unpaired.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Put the model on the training device before any optimizer is built; an
# earlier cell may have left it on the CPU (input/weight device mismatch otherwise).
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# -- helpers: compute paired/unpaired accuracy on a loader
@torch.no_grad()
def eval_paired_unpaired(model, loader):
    """Top-1 accuracy (%) split by the global ``paired_set``.

    A sample is "paired" when its label is in ``paired_set`` and "unpaired"
    otherwise. Batches may be ``(imgs, labels)`` or ``(imgs, labels, domain)``.
    ``None`` batches (every image in the batch failed to load, see
    ``clean_collate``) are skipped. Uses the global ``device``.

    Returns:
        (paired_acc, unpaired_acc) in percent; an empty group scores 0.0.
    """
    model.eval()
    paired_correct = paired_total = 0
    unpaired_correct = unpaired_total = 0
    for batch in loader:
        if batch is None:
            continue
        if not (isinstance(batch, (list, tuple)) and len(batch) >= 2):
            raise RuntimeError("Val loader must return at least (imgs, labels)")
        imgs, labels = batch[0].to(device), batch[1].to(device)
        preds = model(imgs).argmax(dim=1)
        # One transfer per batch instead of a GPU sync per element via .item()
        for lab, hit in zip(labels.tolist(), (preds == labels).tolist()):
            if lab in paired_set:
                paired_total += 1
                paired_correct += hit
            else:
                unpaired_total += 1
                unpaired_correct += hit

    paired_acc = 100.0 * paired_correct / max(1, paired_total)
    unpaired_acc = 100.0 * unpaired_correct / max(1, unpaired_total)
    return paired_acc, unpaired_acc

# -- one epoch train (standard)
def train_one_epoch(model, loader, optimizer):
    """Run one training epoch and return ``(mean_loss, top1_percent)``.

    Uses the global ``criterion`` and ``device``. ``None`` batches from
    ``clean_collate`` (all images in the batch unreadable) are skipped instead
    of aborting the epoch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc="train", leave=False)
    for batch in pbar:
        if batch is None:
            continue
        if not (isinstance(batch, (list, tuple)) and len(batch) >= 2):
            raise RuntimeError("Train loader must return (imgs, labels[,domain])")
        imgs, labels = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        bs = imgs.size(0)
        running_loss += loss.item() * bs
        correct += (logits.argmax(1) == labels).sum().item()
        total += bs
        pbar.set_postfix({"loss": running_loss / total})
    return running_loss / max(1, total), 100.0 * correct / max(1, total)

# -- validation that returns loss + overall Top1% (for monitoring)
@torch.no_grad()
def validate_overall(model, loader):
    """Return ``(mean_loss, top1_percent)`` over ``loader`` in eval mode.

    Uses the global ``criterion`` and ``device``; ``None`` batches from
    ``clean_collate`` are skipped.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch in loader:
        if batch is None:
            continue
        if not (isinstance(batch, (list, tuple)) and len(batch) >= 2):
            raise RuntimeError("Val loader must return (imgs, labels[,domain])")
        imgs, labels = batch[0].to(device), batch[1].to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        bs = imgs.size(0)
        running_loss += loss.item() * bs
        correct += (logits.argmax(1) == labels).sum().item()
        total += bs
    return running_loss / max(1, total), 100.0 * correct / max(1, total)

# --- Storage for curves (one entry per epoch: head stage, then fine-tune stage)
train_losses = []; train_accs = []
val_losses = []; val_accs = []
paired_curve = []; unpaired_curve = []

# NOTE(review): checkpoints are selected on val_loader, which is built from
# test.txt, so the best accuracies reported below are optimistic estimates of
# test performance.


def _save_checkpoint(path, epoch, val_top1, unpaired_top1):
    """Save the current model weights together with the metrics that triggered the save."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "val_top1": val_top1,
        "unpaired_top1": unpaired_top1,
    }, path)


def _run_stage(tag, optimizer, n_epochs, epoch_offset, total_epochs, best_overall, best_unpaired):
    """Train for n_epochs, append metrics to the curve lists and refresh both best checkpoints.

    Epochs are numbered epoch_offset+1 .. epoch_offset+n_epochs, both in the log
    and in the checkpoint's "epoch" field. Returns the updated
    (best_overall, best_unpaired).
    """
    for ep in range(1, n_epochs + 1):
        epoch_index = epoch_offset + ep
        t_loss, t_acc = train_one_epoch(model, train_loader, optimizer)
        v_loss, v_acc = validate_overall(model, val_loader)
        p_acc, u_acc = eval_paired_unpaired(model, val_loader)

        train_losses.append(t_loss); train_accs.append(t_acc)
        val_losses.append(v_loss); val_accs.append(v_acc)
        paired_curve.append(p_acc); unpaired_curve.append(u_acc)

        print(f"[{tag}] Epoch {epoch_index}/{total_epochs} | Train loss {t_loss:.4f} Train Top1 {t_acc:.2f}%")
        print(f"       Val loss {v_loss:.4f} Val Top1 {v_acc:.2f}% | Paired {p_acc:.2f}% Unpaired {u_acc:.2f}%")

        if v_acc > best_overall:
            best_overall = v_acc
            _save_checkpoint(best_overall_path, epoch_index, v_acc, u_acc)
            print("💾 Saved best_overall ->", best_overall_path)
        if u_acc > best_unpaired:
            best_unpaired = u_acc
            _save_checkpoint(best_unpaired_path, epoch_index, v_acc, u_acc)
            print("💾 Saved best_unpaired ->", best_unpaired_path)
    return best_overall, best_unpaired


best_overall = 0.0
best_unpaired = 0.0

# --- Stage 1: freeze backbone, train the head only ---
for p in model.backbone.parameters():
    p.requires_grad = False

optim_head = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                         lr=LR_HEAD, weight_decay=WEIGHT_DECAY)

print("==== HEAD-ONLY TRAINING ====")
best_overall, best_unpaired = _run_stage(
    "HEAD", optim_head, EPOCHS_HEAD, 0, EPOCHS_HEAD, best_overall, best_unpaired)

# --- Stage 2: unfreeze backbone and fine-tune all weights ---
for p in model.backbone.parameters():
    p.requires_grad = True

optim_fine = optim.AdamW(model.parameters(), lr=LR_FINE, weight_decay=WEIGHT_DECAY)

print("\n==== FINE-TUNE FULL MODEL ====")
best_overall, best_unpaired = _run_stage(
    "FINE", optim_fine, EPOCHS_FINE, EPOCHS_HEAD, EPOCHS_HEAD + EPOCHS_FINE, best_overall, best_unpaired)

# --- Final plots ---
epochs = range(1, len(train_losses) + 1)
fig, axes = plt.subplots(1, 3, figsize=(14, 5))
panels = [
    ("Loss", [(train_losses, "Train Loss"), (val_losses, "Val Loss")], "Loss"),
    ("Top-1 Acc", [(train_accs, "Train Top1"), (val_accs, "Val Top1")], "Top-1 (%)"),
    ("Paired / Unpaired", [(paired_curve, "Paired Top1"), (unpaired_curve, "Unpaired Top1")], "Top-1 (%)"),
]
for ax, (title, series, ylabel) in zip(axes, panels):
    for values, label in series:
        ax.plot(epochs, values, label=label)
    # Dashed line marks the switch from head-only training to full fine-tuning.
    ax.axvline(EPOCHS_HEAD + 0.5, color="gray", linestyle="--", linewidth=1)
    ax.set(title=title, xlabel="Epoch", ylabel=ylabel)
    ax.legend()
    ax.grid()

fig.tight_layout()
plt.show()

print("Done. Best overall Top1:", best_overall, " Best unpaired Top1:", best_unpaired)

In [None]:
# ===== Block 13 : RESUME MODE (Skip Head-Only if checkpoint exists) =====
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

# --- Config ---
EPOCHS_HEAD = 10
EPOCHS_FINE = 10
LR_HEAD = 1e-4
LR_FINE = 5e-5  # Increased slightly to help DINOv2 adapt faster
WEIGHT_DECAY = 1e-4
SAVE_DIR = PROJECT_ROOT / "models"
BATCH_SIZE = 4
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Two checkpoint files: best overall val Top-1, and best Top-1 on unpaired classes.
best_overall_path = SAVE_DIR / "best_overall.pth"
best_unpaired_path = SAVE_DIR / "best_unpaired.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# CRITICAL FIX: move the model to the target device before the optimizers below are built.
model = model.to(device)

# ============================================================
# RESUME LOGIC
# If best_overall.pth exists: load its weights, restore the "best so far"
# baselines and skip the head-only phase. Otherwise train both phases.
# ============================================================
run_head_training = True
start_epoch_fine = 1  # NOTE(review): not read anywhere in this cell
best_overall = 0.0
best_unpaired = 0.0

# Metric histories for the final plots (only this session's epochs are recorded)
train_losses, train_accs = [], []
val_losses, val_accs = [], []
paired_curve, unpaired_curve = [], []

if best_overall_path.exists():
    print(f"Found saved checkpoint: {best_overall_path}")
    print("Loading weights...")
    checkpoint = torch.load(best_overall_path, map_location=device)

    # Restore weights and the overall baseline so a worse epoch cannot
    # overwrite best_overall.pth.
    model.load_state_dict(checkpoint['model_state_dict'])
    best_overall = checkpoint.get('val_top1', 0.0)
    best_unpaired = checkpoint.get('unpaired_top1', 0.0)

    # best_overall.pth only records the unpaired accuracy of the *overall*-best
    # epoch, which can be lower than the score already stored in best_unpaired.pth.
    # Use that file's own score as the baseline so fine-tuning cannot overwrite
    # it with a weaker model.
    if best_unpaired_path.exists():
        unpaired_meta = torch.load(best_unpaired_path, map_location="cpu")
        best_unpaired = unpaired_meta.get('unpaired_top1', best_unpaired)
        del unpaired_meta  # only the score is needed; free the extra copy of the weights

    print(f"‚úÖ Weights loaded! Best Overall: {best_overall:.2f}%, Best Unpaired: {best_unpaired:.2f}%")
    print("Skipping Head-Only training phase.")
    run_head_training = False
else:
    print("No checkpoint found. Starting from scratch.")

criterion = nn.CrossEntropyLoss()

# --- Helper Functions (Required inside the block for standalone running) ---

@torch.no_grad()
def eval_paired_unpaired(model, loader):
    """Top-1 accuracy split by whether the true class belongs to ``unpaired_set``.

    Args:
        model: classifier whose output is argmax-ed per row; switched to ``eval()``.
        loader: iterable of ``(imgs, labels, domain)`` batches; ``None`` entries are skipped.

    Returns:
        ``(paired_acc, unpaired_acc)`` as percentages. A group with no samples scores 0.0.

    Reads the module-level ``device`` and ``unpaired_set``.
    """
    model.eval()
    # Counters keyed by "is the true label an unpaired class?"
    hits = {True: 0, False: 0}
    seen = {True: 0, False: 0}
    for batch in loader:
        if batch is None:
            continue
        imgs, labels, _dom = batch
        imgs = imgs.to(device)
        labels = labels.to(device)
        guesses = model(imgs).argmax(dim=1)
        for true_cls, guess in zip(labels.tolist(), guesses.tolist()):
            is_unpaired = true_cls in unpaired_set
            seen[is_unpaired] += 1
            if guess == true_cls:
                hits[is_unpaired] += 1
    paired_acc = 100.0 * hits[False] / max(1, seen[False])
    unpaired_acc = 100.0 * hits[True] / max(1, seen[True])
    return paired_acc, unpaired_acc

def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass of ``model`` over ``loader``.

    Args:
        model: network to train; switched to ``train()`` mode.
        loader: iterable of ``(imgs, labels, domain)`` batches; ``None`` entries are skipped.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(mean_loss, top1_percent)`` over all samples seen. Both are 0.0 if no batch ran.
        Loss/accuracy use the logits computed *before* each optimizer step.

    Reads the module-level ``device`` and ``criterion``.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    progress = tqdm(loader, desc="train", leave=False)
    for batch in progress:
        if batch is None:
            continue
        imgs, labels, _dom = batch
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = imgs.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += batch_size
        progress.set_postfix({"loss": loss_sum / max(1, n_seen)})

    denom = max(1, n_seen)
    return loss_sum / denom, 100.0 * n_correct / denom

@torch.no_grad()
def validate_overall(model, loader):
    """Mean loss and Top-1 accuracy of ``model`` over every sample in ``loader``.

    Args:
        model: classifier evaluated in ``eval()`` mode without gradients.
        loader: iterable of ``(imgs, labels, domain)`` batches; ``None`` entries are skipped.

    Returns:
        ``(mean_loss, top1_percent)``; both 0.0 when the loader yields no usable batch.

    Reads the module-level ``device`` and ``criterion``.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for batch in (b for b in loader if b is not None):
        imgs, labels, _dom = batch
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        batch_size = imgs.size(0)
        loss_sum += criterion(logits, labels).item() * batch_size
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += batch_size
    denom = max(1, n_seen)
    return loss_sum / denom, 100.0 * n_correct / denom

# =========================================
# PHASE 1: HEAD ONLY (Skipped if Resuming)
# =========================================
if run_head_training:
    # Freeze the backbone so only the remaining (head) parameters receive gradients.
    for p in model.backbone.parameters():
        p.requires_grad = False

    head_params = [p for p in model.parameters() if p.requires_grad]
    optim_head = optim.AdamW(head_params, lr=LR_HEAD, weight_decay=WEIGHT_DECAY)

    print("\n==== HEAD-ONLY TRAINING ====")
    for ep in range(1, EPOCHS_HEAD + 1):
        t_loss, t_acc = train_one_epoch(model, train_loader, optim_head)
        v_loss, v_acc = validate_overall(model, val_loader)
        p_acc, u_acc = eval_paired_unpaired(model, val_loader)

        for history, value in (
            (train_losses, t_loss), (train_accs, t_acc),
            (val_losses, v_loss), (val_accs, v_acc),
            (paired_curve, p_acc), (unpaired_curve, u_acc),
        ):
            history.append(value)

        print(f"[HEAD] Epoch {ep}/{EPOCHS_HEAD} | Val Top1 {v_acc:.2f}% | Paired {p_acc:.2f}% Unpaired {u_acc:.2f}%")

        # Only the overall-best checkpoint is tracked during the head-only phase.
        if not v_acc > best_overall:
            continue
        best_overall = v_acc
        torch.save(
            {
                "epoch": ep,
                "model_state_dict": model.state_dict(),
                "val_top1": v_acc,
                "unpaired_top1": u_acc,
            },
            best_overall_path,
        )
        print("üíæ Saved best_overall")

# =========================================
# PHASE 2: FINE-TUNING (Always Runs)
# =========================================
# Make every backbone weight trainable again before building the optimizer.
for p in model.backbone.parameters():
    p.requires_grad = True

optim_fine = optim.AdamW(model.parameters(), lr=LR_FINE, weight_decay=WEIGHT_DECAY)

print("\n==== FINE-TUNE FULL MODEL ====")
for ep in range(1, EPOCHS_FINE + 1):
    t_loss, t_acc = train_one_epoch(model, train_loader, optim_fine)
    v_loss, v_acc = validate_overall(model, val_loader)
    p_acc, u_acc = eval_paired_unpaired(model, val_loader)

    for history, value in (
        (train_losses, t_loss), (train_accs, t_acc),
        (val_losses, v_loss), (val_accs, v_acc),
        (paired_curve, p_acc), (unpaired_curve, u_acc),
    ):
        history.append(value)

    # Epochs are numbered as if they follow the head-only phase.
    epoch_index = EPOCHS_HEAD + ep
    print(f"[FINE] Epoch {epoch_index}/{EPOCHS_HEAD+EPOCHS_FINE} | Loss {t_loss:.3f} | Val Acc {v_acc:.2f}% | Unpaired {u_acc:.2f}%")

    beats_overall = v_acc > best_overall
    beats_unpaired = u_acc > best_unpaired
    if beats_overall or beats_unpaired:
        # One snapshot serves whichever "best" file(s) this epoch improves.
        snapshot = {
            "epoch": epoch_index,
            "model_state_dict": model.state_dict(),
            "val_top1": v_acc,
            "unpaired_top1": u_acc,
        }
        if beats_overall:
            best_overall = v_acc
            torch.save(snapshot, best_overall_path)
            print("üíæ Saved best_overall")
        if beats_unpaired:
            best_unpaired = u_acc
            torch.save(snapshot, best_unpaired_path)
            print("üíæ Saved best_unpaired")

# --- Final plots ---
# Three side-by-side panels (loss, Top-1 accuracy, paired vs. unpaired Top-1),
# drawn only when at least one epoch was recorded in this session.
if train_losses:
    panels = (
        ("Loss", ((train_losses, "Train Loss"), (val_losses, "Val Loss"))),
        ("Top-1 Acc", ((train_accs, "Train Top1"), (val_accs, "Val Top1"))),
        ("Paired / Unpaired", ((paired_curve, "Paired Top1"), (unpaired_curve, "Unpaired Top1"))),
    )
    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    for ax, (title, series) in zip(axes, panels):
        for values, label in series:
            ax.plot(values, label=label)
        ax.set_title(title)
        ax.legend()
        ax.grid()

    plt.tight_layout()
    plt.show()

print("Done. Best overall:", best_overall, " Best unpaired:", best_unpaired)

In [None]:
# ===== Block 14: FINAL EVALUATION=====
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path


MODEL_PATH = SAVE_DIR / "best_unpaired.pth"
TOPK = 5  # k for the "Top-5" metrics below; clamped to the number of classes

print(f"‚öñÔ∏è  Loading Model:     {MODEL_PATH}")

if MODEL_PATH.exists():
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    # Load the state dictionary
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    print("‚úÖ Model weights loaded successfully.")
else:
    raise FileNotFoundError(f"Could not find model at {MODEL_PATH}")

model.eval()

total = 0
top1 = 0
top5 = 0

paired_correct_1 = paired_correct_5 = 0
paired_total = 0

unpaired_correct_1 = unpaired_correct_5 = 0
unpaired_total = 0

print("Running evaluation...")
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Eval"):
        if batch is None:
            continue

        # val_loader returns imgs, labels, domain
        imgs, labels, dom = batch
        imgs = imgs.to(device); labels = labels.to(device)

        logits = model(imgs)                                           # (B, num_classes)
        _, preds_top5 = logits.topk(min(TOPK, logits.size(1)), dim=1)  # (B, k)

        # Per-sample hit flags are computed once on-device and transferred in a
        # single .tolist() each, instead of one .item()/tensor-membership test per
        # sample per metric (the original also checked top-5 twice per sample).
        hit1 = (logits.argmax(dim=1) == labels).tolist()
        hit5 = (preds_top5 == labels.unsqueeze(1)).any(dim=1).tolist()

        for lab, h1, h5 in zip(labels.tolist(), hit1, hit5):
            total += 1
            top1 += h1
            top5 += h5
            # ----- PAIRED / UNPAIRED TOP-1 / TOP-5 -----
            if lab in unpaired_new_ids:
                unpaired_total += 1
                unpaired_correct_1 += h1
                unpaired_correct_5 += h5
            else:
                paired_total += 1
                paired_correct_1 += h1
                paired_correct_5 += h5

# --- PRINT RESULTS ---
print("\n===== FINAL EVALUATION =====")
print(f"Overall Top-1 : {100.0 * top1 / max(1, total):.2f}%")
print(f"Overall Top-5 : {100.0 * top5 / max(1, total):.2f}%")
print("-" * 30)
print(f"Paired Top-1  : {100.0 * paired_correct_1 / max(1, paired_total):.2f}%")
print(f"Paired Top-5  : {100.0 * paired_correct_5 / max(1, paired_total):.2f}%")
print("-" * 30)
print(f"Unpaired Top-1: {100.0 * unpaired_correct_1 / max(1, unpaired_total):.2f}%")
print(f"Unpaired Top-5: {100.0 * unpaired_correct_5 / max(1, unpaired_total):.2f}%")
print("=" * 30)

‚öñÔ∏è  Loading Model:     c:\Users\William\School\Swinburne\Computer Science\2025 Semester 2\COS30082 Applied Machine Learning\Assignment 2\Esther\models\best_unpaired.pth
‚úÖ Model weights loaded successfully.
Running evaluation...


Eval: 100%|██████████| 52/52 [00:11<00:00,  4.46it/s]


===== FINAL EVALUATION =====
Overall Top-1 : 72.95%
Overall Top-5 : 84.54%
------------------------------
Paired Top-1  : 86.27%
Paired Top-5  : 98.04%
------------------------------
Unpaired Top-1: 35.19%
Unpaired Top-5: 46.30%





In [33]:
# Block 15 - save final weights & metadata
# Bundle the current weights with the id-mapping / unpaired-id metadata built in
# earlier cells, so the checkpoint can be used without re-running them.
final_checkpoint = dict(
    model_state_dict=model.state_dict(),
    orig_to_new=orig_to_new,
    new_to_orig=new_to_orig,
    unpaired_orig_ids=unpaired_orig_ids,
    unpaired_new_ids=sorted(unpaired_new_ids),
)
out_path = Path(PROJECT_ROOT) / "models" / "final_finetuned_with_synth.pth.tar"
torch.save(final_checkpoint, out_path)
print("Saved final checkpoint:", out_path)


Saved final checkpoint: c:\Users\William\School\Swinburne\Computer Science\2025 Semester 2\COS30082 Applied Machine Learning\Assignment 2\Esther\models\final_finetuned_with_synth.pth.tar
