In [6]:
# multimodal_damage_improved.py
# HARD-CODED VERSION (Option C)
# -------------------------------------------------------------------------
# This version removes ALL argparse usage and works exactly like your original
# script: just run it in a notebook or Python file, no CLI arguments required.
# All config values and file paths are defined below. Nothing else.
# -------------------------------------------------------------------------

import os, re, random, json
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report

# ======================================================================
# ------------------------------ CONFIG --------------------------------
# ======================================================================
SEED = 42
MAX_LEN = 64
BATCH_SIZE = 16
EPOCHS = 25
LR = 2e-4
IMG_SIZE = 128
MIN_FREQ = 2
NUM_WORKERS = 4
MODEL_DIR = "./checkpoints_multimodal_improved"
os.makedirs(MODEL_DIR, exist_ok=True)

TRAIN_FILES = [
    "/Volumes/Extreme SSD/DL_Proj/CrisisMMD_v2.0/crisismmd_datasplit_all/task_damage_text_img_dev.tsv",
    "/Volumes/Extreme SSD/DL_Proj/CrisisMMD_v2.0/crisismmd_datasplit_all/task_damage_text_img_train.tsv",
]
TEST_FILES = [
    "/Volumes/Extreme SSD/DL_Proj/CrisisMMD_v2.0/crisismmd_datasplit_all/task_damage_text_img_test.tsv",
]

DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# ======================================================================
# --------------------------- REPRO SETUP -------------------------------
# ======================================================================

def set_seed(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed()

# ======================================================================
# ----------------------------- LOADING ---------------------------------
# ======================================================================

def clean_text(text):
    text = str(text)
    text = re.sub(r"http\\S+", "", text)
    text = re.sub(r"@\\w+", "", text)
    text = re.sub(r"#\\w+", "", text)
    text = re.sub(r"[^A-Za-z0-9\\s.,!?']", " ", text)
    text = re.sub(r"\\s+", " ", text).strip().lower()
    return text

def load_tsv(path):
    df = pd.read_csv(path, sep="\t")
    df["tweet_text"] = df["tweet_text"].astype(str).apply(clean_text)
    return df

train_df = pd.concat([load_tsv(p) for p in TRAIN_FILES], ignore_index=True)
test_df = pd.concat([load_tsv(p) for p in TEST_FILES], ignore_index=True)

label_le = LabelEncoder()
train_df["label_id"] = label_le.fit_transform(train_df["label"])
test_df["label_id"] = label_le.transform(test_df["label"])

n_classes = len(label_le.classes_)
print("Classes:", label_le.classes_.tolist())

# ======================================================================
# ------------------------------- TOKENIZER -----------------------------
# ======================================================================

def basic_tokenizer(text):
    return re.findall(r"\\b[\\w']+\\b", text.lower())

def build_vocab(texts, min_freq=2):
    counter = Counter()
    for t in texts:
        counter.update(basic_tokenizer(t))
    vocab = {"<unk>":0, "<pad>":1, "<cls>":2}
    for w,f in counter.items():
        if f >= min_freq:
            vocab[w] = len(vocab)
    return vocab

vocab = build_vocab(train_df["tweet_text"], MIN_FREQ)
vocab_size = len(vocab)
print("Vocab size:", vocab_size)

with open(os.path.join(MODEL_DIR, "vocab.json"), "w") as f:
    json.dump(vocab, f, indent=2)


def encode_text(text):
    tokens = basic_tokenizer(text)[:MAX_LEN-1]
    ids = [vocab["<cls>"]] + [vocab.get(t, vocab["<unk>"]) for t in tokens]
    ids += [vocab["<pad>"]] * (MAX_LEN - len(ids))
    return torch.tensor(ids, dtype=torch.long)

# ======================================================================
# ------------------------------- DATASET -------------------------------
# ======================================================================

img_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

class CrisisDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = encode_text(row.tweet_text)
        try:
            img = Image.open(row.image).convert("RGB")
        except:
            img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (0,0,0))
        img = img_transform(img)
        label = torch.tensor(row.label_id, dtype=torch.long)
        return {"text":text, "image":img, "label":label}


def collate_fn(batch):
    return {
        "text": torch.stack([b["text"] for b in batch]),
        "image": torch.stack([b["image"] for b in batch]),
        "label": torch.stack([b["label"] for b in batch])
    }

train_dataset = CrisisDataset(train_df)
test_dataset = CrisisDataset(test_df)

counts = train_df["label_id"].value_counts().sort_index()
weights = 1.0 / torch.tensor(counts.values, dtype=torch.float)
sample_weights = [weights[l] for l in train_df["label_id"]]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)

# ======================================================================
# -------------------------------- MODEL -------------------------------
# ======================================================================

class TransformerBlock(nn.Module):
    def __init__(self, d, heads, ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, heads, dropout=0.1)
        self.ff = nn.Sequential(
            nn.Linear(d, ff), nn.ReLU(), nn.Dropout(0.1), nn.Linear(ff, d)
        )
        self.n1 = nn.LayerNorm(d)
        self.n2 = nn.LayerNorm(d)
    def forward(self, x):
        a,_ = self.attn(x, x, x)
        x = self.n1(x + a)
        f = self.ff(x)
        x = self.n2(x + f)
        return x

class TextEncoder(nn.Module):
    def __init__(self, vocab_size, d=128, heads=4, layers=3, ff=256):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d, padding_idx=1)
        self.pos = nn.Parameter(torch.zeros(1, MAX_LEN, d))
        self.blocks = nn.ModuleList([TransformerBlock(d, heads, ff) for _ in range(layers)])
    def forward(self, x):
        b, s = x.size()
        h = self.emb(x) + self.pos[:, :s]
        h = h.transpose(0,1)
        for blk in self.blocks:
            h = blk(h)
        return h[0]

class ImageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3,32,3,padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32,64,3,padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64,128,3,padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128,256,3,padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.AdaptiveAvgPool2d(1)
        )
    def forward(self, x):
        x = self.cnn(x)
        return x.view(x.size(0), -1)

class DamageNetMM(nn.Module):
    def __init__(self, vocab_size, n_classes):
        super().__init__()
        self.text = TextEncoder(vocab_size)
        self.img = ImageEncoder()
        self.t_proj = nn.Sequential(nn.Linear(128,128), nn.Re

usage: ipykernel_launcher.py [-h] --train_tsvs TRAIN_TSVS [TRAIN_TSVS ...]
                             --test_tsvs TEST_TSVS [TEST_TSVS ...]
                             [--model_dir MODEL_DIR] [--seed SEED]
                             [--max_len MAX_LEN] [--batch_size BATCH_SIZE]
                             [--epochs EPOCHS] [--lr LR] [--img_size IMG_SIZE]
                             [--min_freq MIN_FREQ] [--num_workers NUM_WORKERS]
                             [--save_every SAVE_EVERY]
ipykernel_launcher.py: error: the following arguments are required: --train_tsvs, --test_tsvs


SystemExit: 2

In [13]:
# damagenet_text_final_improved_minority.py
"""
Final patched script focused on improving minority-class performance.
Key changes:
 - Hard-example retries only for MINORITY misclassified samples (1 pass)
 - Minority token dropout increased to 0.20
 - Soft targets applied only to minority samples
 - Stronger logit adjustment scale
 - On-the-fly cheap augmentations for minority samples (random deletion / swap)
 - BT augmentation still available, augmented rows are tagged and excluded from hard mining
 - Safe BalancedBatchSampler (terminates)
"""

import os, re, json, random, time
from collections import Counter
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report

# -------------------- CONFIG --------------------
SEED = 42
MAX_LEN = 128
BATCH_SIZE = 30            # try to keep divisible by n_classes=3
EPOCHS = 20
LR = 2e-4

EMB_DIM = 256
N_HEADS = 8
N_LAYERS = 4
FF_DIM = 512
DROPOUT = 0.1

MODEL_DIR = "./damagenet_text_minority_ckpt"
os.makedirs(MODEL_DIR, exist_ok=True)

TRAIN_FILES = [
    "/Volumes/Extreme SSD/DL_Proj/CrisisMMD_v2.0/crisismmd_datasplit_all/task_damage_text_img_train.tsv",
    "/Volumes/Extreme SSD/DL_Proj/CrisisMMD_v2.0/crisismmd_datasplit_all/task_damage_text_img_dev.tsv",
]
TEST_FILES = [
    "/Volumes/Extreme SSD/DL_Proj/CrisisMMD_v2.0/crisismmd_datasplit_all/task_damage_text_img_test.tsv",
]

# Back-translation remains available
ENABLE_BACKTRANSLATION = True
BT_LANGS = ["fr", "de", "es", "it"]
BT_CYCLES = 2

# Minority-focused hyperparams (tuned for better minority recall)
MINORITY_TOKEN_DROPOUT = 0.20   # increased
USE_SOFT_LABELS = True
SOFT_CONF = 0.88                # slightly higher confidence for minority soft targets

# Hard-example mining
HARD_RETRY_PASSES = 1           # reduced to 1
HARD_SAMPLE_LIMIT = 2000

# Cheap augmentations for minority samples per epoch
MINORITY_AUG_USES_PER_EPOCH = 1  # how many augmented passes of each minority sample inserted per epoch
MINORITY_AUG_PROB = 0.6          # probability to apply augmentation in dataset __getitem__ for minority

# Data loader
NUM_WORKERS = 0
PIN_MEMORY = True

# Device
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
else:
    DEVICE = torch.device("cpu")
print("Device:", DEVICE)

# -------------------- REPRO --------------------
def set_seed(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed()

# -------------------- optional translator --------------------
TRANSLATOR = None
TRANSLATOR_NAME = None
if ENABLE_BACKTRANSLATION:
    try:
        from googletrans import Translator as GoogleTranslator
        TRANSLATOR = GoogleTranslator()
        TRANSLATOR_NAME = "googletrans"
        print("BT: using googletrans")
    except Exception:
        try:
            from deep_translator import GoogleTranslator as DeepTranslator
            TRANSLATOR = DeepTranslator(source="auto", target="en")
            TRANSLATOR_NAME = "deep_translator"
            print("BT: using deep_translator")
        except Exception:
            TRANSLATOR = None
            TRANSLATOR_NAME = None
            ENABLE_BACKTRANSLATION = False
            print("BT: translator not available -> disabling back-translation")

# -------------------- TEXT UTIL --------------------
def clean_text(text):
    text = str(text)
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"@\w+", "", text)
    text = re.sub(r"#\w+", "", text)
    text = re.sub(r"[^A-Za-z0-9\s.,!?']", " ", text)
    text = re.sub(r"\s+", " ", text).strip().lower()
    return text

def basic_tokenizer(text):
    return re.findall(r"\b[\w']+\b", text.lower())

def build_vocab(texts, min_freq=2):
    counter = Counter()
    for t in texts:
        counter.update(basic_tokenizer(t))
    vocab = {"<pad>":0, "<unk>":1, "<cls>":2}
    for w,c in counter.items():
        if c >= min_freq:
            vocab[w] = len(vocab)
    return vocab

def encode_text_ids_from_tokens(tokens, vocab):
    ids = [vocab.get(t, vocab["<unk>"]) for t in tokens[:MAX_LEN-1]]
    ids = [vocab["<cls>"]] + ids
    if len(ids) < MAX_LEN:
        ids += [vocab["<pad>"]] * (MAX_LEN - len(ids))
    return ids[:MAX_LEN]

def encode_text(text, vocab):
    toks = basic_tokenizer(text)
    return torch.tensor(encode_text_ids_from_tokens(toks, vocab), dtype=torch.long)

# -------------------- LOAD DATA --------------------
def load_tsv(path):
    df = pd.read_csv(path, sep="\t")
    df["tweet_text"] = df["tweet_text"].astype(str).apply(clean_text)
    return df

train_df = pd.concat([load_tsv(p) for p in TRAIN_FILES], ignore_index=True)
test_df  = pd.concat([load_tsv(p) for p in TEST_FILES], ignore_index=True)

label_le = LabelEncoder()
train_df["label_id"] = label_le.fit_transform(train_df["label"])
test_df["label_id"]  = label_le.transform(test_df["label"])
n_classes = len(label_le.classes_)
print("Classes:", label_le.classes_.tolist())

# -------------------- targeted back-translation helpers --------------------
def should_backtranslate(text):
    toks = basic_tokenizer(text)
    if len(toks) <= 7:
        return True
    if len(set(toks)) <= max(2, len(toks)//2):
        return True
    for tok in ["collapsed","destroyed","flooded","burned","fire","injured","killed","trapped"]:
        if tok in text:
            return False
    return False

def back_translate_googletrans(text, lang_chain):
    try:
        cur = text
        for lang in lang_chain:
            cur = TRANSLATOR.translate(cur, dest=lang).text
        cur = TRANSLATOR.translate(cur, dest="en").text
        return clean_text(cur)
    except Exception:
        return text

def back_translate_deeptrans(text, lang_chain):
    try:
        from deep_translator import GoogleTranslator as DeepTranslator
        cur = text
        for lang in lang_chain:
            cur = DeepTranslator(source='auto', target=lang).translate(cur)
        cur = DeepTranslator(source='auto', target='en').translate(cur)
        return clean_text(cur)
    except Exception:
        return text

def back_translate(text, lang_chain):
    if TRANSLATOR is None:
        return text
    if TRANSLATOR_NAME == "googletrans":
        return back_translate_googletrans(text, lang_chain)
    elif TRANSLATOR_NAME == "deep_translator":
        return back_translate_deeptrans(text, lang_chain)
    return text

# -------------------- BT augmentation (minority-targeted) --------------------
class_counts = train_df["label_id"].value_counts().sort_index()
mean_count = class_counts.mean()
minority_classes = class_counts[class_counts < mean_count].index.tolist()
print("Minority class ids:", minority_classes)

# we'll tag rows that were BT-augmented so they can be excluded from hard mining
train_df["_aug_bt"] = False

if ENABLE_BACKTRANSLATION and TRANSLATOR is not None:
    aug_rows = []
    subset = train_df[train_df["label_id"].isin(minority_classes)].reset_index(drop=True)
    print("BT: selecting minority samples for augmentation...")
    for _, row in tqdm(subset.iterrows(), total=len(subset), desc="BT select"):
        txt = row["tweet_text"]
        if not should_backtranslate(txt):
            continue
        for c in range(BT_CYCLES):
            lang_chain = BT_LANGS[:]
            random.shuffle(lang_chain)
            aug_text = back_translate(txt, lang_chain)
            new_row = row.copy()
            new_row["tweet_text"] = aug_text
            new_row["_aug_bt"] = True
            aug_rows.append(new_row)
    if aug_rows:
        aug_df = pd.DataFrame(aug_rows)
        train_df = pd.concat([train_df, aug_df], ignore_index=True).sample(frac=1.0, random_state=SEED).reset_index(drop=True)
        print(f"BT: added {len(aug_df)} augmented minority samples.")

# -------------------- physical oversampling --------------------
def oversample_to_max(df):
    counts = df["label_id"].value_counts()
    max_c = counts.max()
    parts = []
    for cls, cnt in counts.items():
        subset = df[df["label_id"] == cls]
        times = int(np.ceil(max_c / max(1, cnt)))
        parts.extend([subset] * times)
    out = pd.concat(parts, ignore_index=True).sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    print("Oversampled dataset size:", len(out))
    return out

train_df = oversample_to_max(train_df)

# -------------------- cheap on-the-fly augmentations (minority) --------------------
def random_deletion(tokens, p=0.2):
    if len(tokens) <= 3: 
        return tokens
    keep = [t for t in tokens if random.random() > p]
    if len(keep) == 0:
        return tokens[:max(1, len(tokens)//2)]
    return keep

def random_swap(tokens, n_swaps=1):
    toks = tokens[:]
    for _ in range(n_swaps):
        if len(toks) < 2: break
        i,j = random.sample(range(len(toks)), 2)
        toks[i], toks[j] = toks[j], toks[i]
    return toks

# -------------------- vocab & dataset --------------------
vocab = build_vocab(train_df["tweet_text"].tolist(), min_freq=2)
vocab_size = len(vocab)
print("Vocab size:", vocab_size)
with open(os.path.join(MODEL_DIR, "vocab.json"), "w") as f:
    json.dump(vocab, f, indent=2)

minority_set = set(minority_classes)

class BalancedTextDataset(Dataset):
    def __init__(self, df, vocab, minority_set=None, token_dropout=0.0, aug_prob=0.0):
        self.df = df.reset_index(drop=True)
        self.vocab = vocab
        self.minority_set = minority_set or set()
        self.token_dropout = token_dropout
        self.aug_prob = aug_prob
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = row["tweet_text"]
        label = int(row["label_id"])
        toks = basic_tokenizer(text)
        # Minority cheap augmentation on-the-fly (random deletion / swap)
        if label in self.minority_set and random.random() < self.aug_prob:
            if random.random() < 0.5:
                toks = random_deletion(toks, p=0.2)
            else:
                toks = random_swap(toks, n_swaps=1)
        ids = encode_text_ids_from_tokens(toks, self.vocab)
        ids = torch.tensor(ids, dtype=torch.long)
        # minority token dropout
        if self.token_dropout > 0 and label in self.minority_set:
            prob = self.token_dropout
            keep_mask = torch.rand(ids.size(0)) > prob
            keep_mask[0] = True  # keep CLS
            kept = ids[keep_mask].tolist()
            kept = kept[:MAX_LEN]
            if len(kept) < MAX_LEN:
                kept += [self.vocab["<pad>"]] * (MAX_LEN - len(kept))
            ids = torch.tensor(kept, dtype=torch.long)
        return {"input_ids": ids, "label": torch.tensor(label, dtype=torch.long), "_aug_bt": row.get("_aug_bt", False)}

train_dataset = BalancedTextDataset(train_df, vocab, minority_set=minority_set, token_dropout=MINORITY_TOKEN_DROPOUT, aug_prob=MINORITY_AUG_PROB)
test_dataset  = BalancedTextDataset(test_df, vocab, minority_set=minority_set, token_dropout=0.0, aug_prob=0.0)

# -------------------- SAFE BalancedBatchSampler --------------------
class BalancedBatchSampler(Sampler):
    def __init__(self, labels, batch_size):
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.classes = np.unique(self.labels)
        self.num_classes = len(self.classes)
        self.samples_per_class = max(1, batch_size // self.num_classes)
        self.idx_by_class = {c: np.where(self.labels == c)[0].tolist() for c in self.classes}
        for c in self.classes:
            if len(self.idx_by_class[c]) == 0:
                raise ValueError(f"No samples found for class {c}.")
        smallest = min(len(v) for v in self.idx_by_class.values())
        self.num_batches = max(1, smallest // self.samples_per_class)
    def __len__(self):
        return self.num_batches
    def __iter__(self):
        pools = {c: np.random.permutation(v).tolist() for c, v in self.idx_by_class.items()}
        ptrs = {c: 0 for c in self.classes}
        for _ in range(self.num_batches):
            batch = []
            for c in self.classes:
                start = ptrs[c]; end = start + self.samples_per_class
                if end > len(pools[c]):
                    pools[c] = np.random.permutation(pools[c]).tolist()
                    ptrs[c] = 0
                    start = 0; end = self.samples_per_class
                batch.extend(pools[c][start:end])
                ptrs[c] += self.samples_per_class
            if len(batch) < self.batch_size:
                all_idx = np.arange(len(self.labels))
                extra = self.batch_size - len(batch)
                batch.extend(np.random.choice(all_idx, extra, replace=False).tolist())
            random.shuffle(batch)
            yield batch

train_sampler = BalancedBatchSampler(train_df["label_id"].values, batch_size=BATCH_SIZE)
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# -------------------- model --------------------
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.ff = nn.Sequential(nn.Linear(dim, ff_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ff_dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
    def forward(self, x, mask=None):
        attn_out = self.attn(x, x, x, key_padding_mask=mask)[0]
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ff(x))
        return x

class TextEncoder(nn.Module):
    def __init__(self, vocab_size, emb_dim=EMB_DIM, n_heads=N_HEADS, n_layers=N_LAYERS, ff_dim=FF_DIM):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos = nn.Parameter(torch.zeros(1, MAX_LEN, emb_dim))
        self.layers = nn.ModuleList([TransformerBlock(emb_dim, n_heads, ff_dim, dropout=DROPOUT) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(emb_dim)
    def forward(self, ids):
        mask = ids.eq(0)
        x = self.emb(ids) + self.pos[:, :ids.size(1)]
        for l in self.layers:
            x = l(x, mask)
        return self.norm(x[:, 0])

class DamageNetTextFromScratch(nn.Module):
    def __init__(self, vocab_size, n_classes, emb_dim=EMB_DIM):
        super().__init__()
        self.encoder = TextEncoder(vocab_size)
        self.classifier = nn.Sequential(nn.Linear(emb_dim, emb_dim//2), nn.ReLU(), nn.Dropout(0.3), nn.Linear(emb_dim//2, n_classes))
        self._init_weights()
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
    def forward(self, ids, logit_adjust=None):
        z = self.encoder(ids)
        logits = self.classifier(z)
        if logit_adjust is not None:
            logits = logits + logit_adjust
        return logits

# -------------------- Class-balanced loss --------------------
class CBLoss(nn.Module):
    def __init__(self, samples_per_class, beta=0.9999, gamma=2.0, device=DEVICE):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        effective_num = 1.0 - np.power(beta, samples_per_class)
        weights = (1.0 - beta) / (effective_num + 1e-12)
        weights = weights / np.sum(weights) * len(samples_per_class)
        self.weights = torch.tensor(weights, dtype=torch.float32, device=device)
    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)
        focal = ((1 - pt) ** self.gamma) * ce
        w = self.weights[targets]
        return (w * focal).mean()

# -------------------- utilities --------------------
def build_soft_targets(y_batch, n_classes, soft_conf=SOFT_CONF):
    bs = y_batch.size(0)
    soft = torch.full((bs, n_classes), (1.0 - soft_conf) / (n_classes - 1), device=y_batch.device)
    for i, lab in enumerate(y_batch):
        soft[i, lab] = soft_conf
    return soft

def collect_hard_examples_minority(model, dataset, minority_set, limit=HARD_SAMPLE_LIMIT):
    """
    Collect misclassified training samples that are MINORITY and NOT BT-augmented.
    Returns (xh, yh) or (None, None).
    """
    model.eval()
    xs, ys = [], []
    loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=NUM_WORKERS)
    with torch.no_grad():
        for b in loader:
            x = b["input_ids"].to(DEVICE)
            y = b["label"].to(DEVICE)
            aug_bt_flags = b.get("_aug_bt", None)
            logits = model(x)
            preds = logits.argmax(1)
            mask = preds != y
            # filter only minority and non-BT augmented
            mask_indices = []
            if aug_bt_flags is None:
                aug_mask = [False] * len(mask)
            else:
                aug_mask = aug_bt_flags
            for i_val, m in enumerate(mask.cpu().numpy()):
                if not m:
                    continue
                yi = int(y.cpu()[i_val].item())
                if yi in minority_set and not bool(aug_mask[i_val]):
                    mask_indices.append(i_val)
            if len(mask_indices) > 0:
                xs.append(x[mask_indices].cpu())
                ys.append(y[mask_indices].cpu())
    if not xs:
        return None, None
    xh = torch.cat(xs)[:limit]
    yh = torch.cat(ys)[:limit]
    return xh, yh

def evaluate(model, loader, logit_adjust=None):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for b in loader:
            x = b["input_ids"].to(DEVICE)
            y = b["label"].to(DEVICE)
            logits = model(x, logit_adjust)
            preds.extend(logits.argmax(1).cpu().tolist())
            trues.extend(y.cpu().tolist())
    acc = accuracy_score(trues, preds) if len(trues) > 0 else 0.0
    return acc, trues, preds

# -------------------- training loop --------------------
def train_loop():
    model = DamageNetTextFromScratch(vocab_size=vocab_size, n_classes=n_classes).to(DEVICE)
    print("Model params:", sum(p.numel() for p in model.parameters()))

    samples = train_df["label_id"].value_counts().sort_index().values.astype(np.int64)
    criterion_cb = CBLoss(samples_per_class=samples, beta=0.9999, gamma=2.0, device=DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    # initial logit adjust: stronger boost to minority
    priors = samples / samples.sum()
    logit_adjust = torch.log(1.0 / (torch.tensor(priors, dtype=torch.float32) + 1e-12)).to(DEVICE)
    logit_adjust = (logit_adjust - logit_adjust.mean()) * 1.0  # stronger scaling than before

    best = 0.0
    patience = 0

    for epoch in range(1, EPOCHS+1):
        model.train()
        total_loss = 0.0
        it = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        for batch in pbar:
            it += 1
            x = batch["input_ids"].to(DEVICE)
            y = batch["label"].to(DEVICE)

            optimizer.zero_grad()
            if scaler:
                with torch.cuda.amp.autocast():
                    logits = model(x, logit_adjust)
                    # apply soft targets only to minority items in batch
                    if USE_SOFT_LABELS:
                        mask_min = torch.tensor([int(int(li) in minority_set) for li in y.cpu().tolist()], device=DEVICE).bool()
                        if mask_min.any():
                            # create soft for minority rows and CE for rest
                            y_min_idx = mask_min.nonzero(as_tuple=False).squeeze(1)
                            y_min = y[y_min_idx]
                            y_soft = build_soft_targets(y_min, n_classes)
                            logits_min = logits[y_min_idx]
                            loss_min = F.kl_div(F.log_softmax(logits_min, dim=1), y_soft, reduction="batchmean")
                            if (~mask_min).any():
                                y_maj = y[~mask_min]
                                logits_maj = logits[~mask_min]
                                loss_maj = criterion_cb(logits_maj, y_maj)
                                loss = 0.6 * loss_min + 0.4 * loss_maj
                            else:
                                loss = loss_min
                        else:
                            loss = criterion_cb(logits, y)
                    else:
                        loss = criterion_cb(logits, y)
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer); scaler.update()
            else:
                logits = model(x, logit_adjust)
                if USE_SOFT_LABELS:
                    mask_min = torch.tensor([int(int(li) in minority_set) for li in y.cpu().tolist()], device=DEVICE).bool()
                    if mask_min.any():
                        y_min_idx = mask_min.nonzero(as_tuple=False).squeeze(1)
                        y_min = y[y_min_idx]
                        y_soft = build_soft_targets(y_min, n_classes)
                        logits_min = logits[y_min_idx]
                        loss_min = F.kl_div(F.log_softmax(logits_min, dim=1), y_soft, reduction="batchmean")
                        if (~mask_min).any():
                            y_maj = y[~mask_min]
                            logits_maj = logits[~mask_min]
                            loss_maj = criterion_cb(logits_maj, y_maj)
                            loss = 0.6 * loss_min + 0.4 * loss_maj
                        else:
                            loss = loss_min
                    else:
                        loss = criterion_cb(logits, y)
                else:
                    loss = criterion_cb(logits, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            total_loss += float(loss.item())
            pbar.set_postfix(loss=total_loss/it)

        scheduler.step()

        val_acc, val_trues, val_preds = evaluate(model, test_loader, logit_adjust)
        print(f"\nEpoch {epoch} | Train Loss: {total_loss/it:.4f} | Val Acc: {val_acc:.4f}")
        print(classification_report(val_trues, val_preds, target_names=label_le.classes_, zero_division=0))

        # small logit_adjust update based on per-class recall (mild)
        report = classification_report(val_trues, val_preds, output_dict=True, zero_division=0)
        recs = []
        for i, cls in enumerate(label_le.classes_):
            rec = report.get(cls, {}).get("recall", 0.0)
            recs.append(rec)
        recs = np.array(recs)
        adjust = (1.0 - recs)
        adjust_t = torch.tensor(adjust, dtype=torch.float32, device=DEVICE)
        logit_adjust = logit_adjust + (adjust_t - adjust_t.mean()) * 0.25  # slightly stronger update
        logit_adjust = torch.clamp(logit_adjust, -4.0, 4.0)

        # checkpoint
        ckpt = {"epoch": epoch, "model_state": model.state_dict(), "optim_state": optimizer.state_dict(), "logit_adjust": logit_adjust.cpu().numpy()}
        torch.save(ckpt, os.path.join(MODEL_DIR, f"epoch_{epoch}.pth"))

        # hard example mining - only minority misclassified AND not BT-augmented
        hard_x, hard_y = collect_hard_examples_minority(model, train_dataset, minority_set, limit=HARD_SAMPLE_LIMIT)
        if hard_x is not None and HARD_RETRY_PASSES > 0:
            hard_x = hard_x.to(DEVICE)
            hard_y = hard_y.to(DEVICE)
            print(f"Retrying {len(hard_x)} minority hard samples for {HARD_RETRY_PASSES} pass(es)")
            for r in range(HARD_RETRY_PASSES):
                model.train()
                CH = 256
                idx = 0
                total_hloss = 0.0
                it_h = 0
                while idx < len(hard_x):
                    xb = hard_x[idx: idx+CH]
                    yb = hard_y[idx: idx+CH]
                    optimizer.zero_grad()
                    if scaler:
                        with torch.cuda.amp.autocast():
                            logits = model(xb, logit_adjust)
                            # for hard minority, use soft targets (makes them learn distribution)
                            y_soft = build_soft_targets(yb, n_classes, soft_conf=SOFT_CONF)
                            hloss = F.kl_div(F.log_softmax(logits, dim=1), y_soft, reduction="batchmean")
                        scaler.scale(hloss).backward()
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                        scaler.step(optimizer); scaler.update()
                    else:
                        logits = model(xb, logit_adjust)
                        y_soft = build_soft_targets(yb, n_classes, soft_conf=SOFT_CONF)
                        hloss = F.kl_div(F.log_softmax(logits, dim=1), y_soft, reduction="batchmean")
                        hloss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                        optimizer.step()
                    total_hloss += float(hloss.item())
                    it_h += 1
                    idx += CH
                print(f"  Hard pass {r+1} loss: {total_hloss/it_h:.4f}")

        # best model
        if val_acc > best:
            best = val_acc
            patience = 0
            torch.save(model.state_dict(), os.path.join(MODEL_DIR, "best_minority.pt"))
            print("Saved new best model.")
        else:
            patience += 1
            if patience >= 6:
                print("Early stopping.")
                break

    print("Training finished. Best val acc:", best)

# -------------------- ENTRY --------------------
if __name__ == "__main__":
    train_loop()


Device: mps
BT: translator not available -> disabling back-translation
Classes: ['little_or_no_damage', 'mild_damage', 'severe_damage']
Minority class ids: [0, 1]
Oversampled dataset size: 6039
Vocab size: 3455
Model params: 3059459


Epoch 1/20: 100%|██████████| 188/188 [00:39<00:00,  4.73it/s, loss=0.49] 



Epoch 1 | Train Loss: 0.4902 | Val Acc: 0.6049
                     precision    recall  f1-score   support

little_or_no_damage       0.25      0.21      0.23        71
        mild_damage       0.36      0.10      0.15       126
      severe_damage       0.67      0.88      0.76       332

           accuracy                           0.60       529
          macro avg       0.43      0.40      0.38       529
       weighted avg       0.54      0.60      0.55       529

Retrying 1856 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.5660
Saved new best model.


Epoch 2/20: 100%|██████████| 188/188 [00:40<00:00,  4.60it/s, loss=0.428]



Epoch 2 | Train Loss: 0.4284 | Val Acc: 0.6030
                     precision    recall  f1-score   support

little_or_no_damage       0.22      0.18      0.20        71
        mild_damage       0.45      0.16      0.24       126
      severe_damage       0.67      0.86      0.75       332

           accuracy                           0.60       529
          macro avg       0.45      0.40      0.40       529
       weighted avg       0.56      0.60      0.56       529

Retrying 1368 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.6188


Epoch 3/20: 100%|██████████| 188/188 [00:40<00:00,  4.66it/s, loss=0.362]



Epoch 3 | Train Loss: 0.3622 | Val Acc: 0.6144
                     precision    recall  f1-score   support

little_or_no_damage       0.25      0.17      0.20        71
        mild_damage       0.42      0.20      0.27       126
      severe_damage       0.68      0.87      0.76       332

           accuracy                           0.61       529
          macro avg       0.45      0.41      0.41       529
       weighted avg       0.56      0.61      0.57       529

Retrying 974 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.7510
Saved new best model.


Epoch 4/20: 100%|██████████| 188/188 [00:40<00:00,  4.59it/s, loss=0.275]



Epoch 4 | Train Loss: 0.2754 | Val Acc: 0.5936
                     precision    recall  f1-score   support

little_or_no_damage       0.30      0.23      0.26        71
        mild_damage       0.35      0.19      0.25       126
      severe_damage       0.67      0.83      0.74       332

           accuracy                           0.59       529
          macro avg       0.44      0.41      0.42       529
       weighted avg       0.55      0.59      0.56       529

Retrying 651 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.7747


Epoch 5/20: 100%|██████████| 188/188 [00:39<00:00,  4.74it/s, loss=0.214]



Epoch 5 | Train Loss: 0.2143 | Val Acc: 0.6200
                     precision    recall  f1-score   support

little_or_no_damage       0.36      0.23      0.28        71
        mild_damage       0.43      0.23      0.30       126
      severe_damage       0.68      0.85      0.75       332

           accuracy                           0.62       529
          macro avg       0.49      0.44      0.44       529
       weighted avg       0.58      0.62      0.58       529

Retrying 443 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.9017
Saved new best model.


Epoch 6/20: 100%|██████████| 188/188 [00:39<00:00,  4.74it/s, loss=0.174]



Epoch 6 | Train Loss: 0.1744 | Val Acc: 0.6049
                     precision    recall  f1-score   support

little_or_no_damage       0.29      0.24      0.26        71
        mild_damage       0.39      0.15      0.22       126
      severe_damage       0.67      0.86      0.75       332

           accuracy                           0.60       529
          macro avg       0.45      0.42      0.41       529
       weighted avg       0.55      0.60      0.56       529

Retrying 529 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 1.1515


Epoch 7/20: 100%|██████████| 188/188 [00:39<00:00,  4.78it/s, loss=0.14] 



Epoch 7 | Train Loss: 0.1396 | Val Acc: 0.6333
                     precision    recall  f1-score   support

little_or_no_damage       0.38      0.15      0.22        71
        mild_damage       0.44      0.27      0.33       126
      severe_damage       0.69      0.87      0.77       332

           accuracy                           0.63       529
          macro avg       0.50      0.43      0.44       529
       weighted avg       0.59      0.63      0.59       529

Retrying 303 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 1.0291
Saved new best model.


Epoch 8/20: 100%|██████████| 188/188 [00:39<00:00,  4.76it/s, loss=0.124]



Epoch 8 | Train Loss: 0.1244 | Val Acc: 0.6219
                     precision    recall  f1-score   support

little_or_no_damage       0.33      0.14      0.20        71
        mild_damage       0.41      0.17      0.24       126
      severe_damage       0.67      0.90      0.76       332

           accuracy                           0.62       529
          macro avg       0.47      0.40      0.40       529
       weighted avg       0.56      0.62      0.56       529

Retrying 369 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 1.0566


Epoch 9/20: 100%|██████████| 188/188 [00:39<00:00,  4.79it/s, loss=0.102] 



Epoch 9 | Train Loss: 0.1022 | Val Acc: 0.6219
                     precision    recall  f1-score   support

little_or_no_damage       0.33      0.20      0.25        71
        mild_damage       0.42      0.17      0.25       126
      severe_damage       0.67      0.88      0.76       332

           accuracy                           0.62       529
          macro avg       0.48      0.42      0.42       529
       weighted avg       0.57      0.62      0.57       529

Retrying 306 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.9123


Epoch 10/20: 100%|██████████| 188/188 [00:40<00:00,  4.59it/s, loss=0.0943]



Epoch 10 | Train Loss: 0.0943 | Val Acc: 0.6276
                     precision    recall  f1-score   support

little_or_no_damage       0.34      0.17      0.23        71
        mild_damage       0.44      0.21      0.29       126
      severe_damage       0.68      0.88      0.77       332

           accuracy                           0.63       529
          macro avg       0.49      0.42      0.43       529
       weighted avg       0.58      0.63      0.58       529

Retrying 255 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.8625


Epoch 11/20: 100%|██████████| 188/188 [00:39<00:00,  4.71it/s, loss=0.0819]



Epoch 11 | Train Loss: 0.0819 | Val Acc: 0.6144
                     precision    recall  f1-score   support

little_or_no_damage       0.29      0.21      0.25        71
        mild_damage       0.42      0.17      0.24       126
      severe_damage       0.68      0.87      0.76       332

           accuracy                           0.61       529
          macro avg       0.46      0.42      0.42       529
       weighted avg       0.56      0.61      0.57       529

Retrying 262 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.8075


Epoch 12/20: 100%|██████████| 188/188 [00:39<00:00,  4.82it/s, loss=0.0763]



Epoch 12 | Train Loss: 0.0763 | Val Acc: 0.6163
                     precision    recall  f1-score   support

little_or_no_damage       0.30      0.18      0.23        71
        mild_damage       0.41      0.19      0.26       126
      severe_damage       0.68      0.87      0.76       332

           accuracy                           0.62       529
          macro avg       0.46      0.41      0.42       529
       weighted avg       0.56      0.62      0.57       529

Retrying 218 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 1.0528


Epoch 13/20: 100%|██████████| 188/188 [00:39<00:00,  4.74it/s, loss=0.0642]



Epoch 13 | Train Loss: 0.0642 | Val Acc: 0.6181
                     precision    recall  f1-score   support

little_or_no_damage       0.29      0.21      0.25        71
        mild_damage       0.46      0.19      0.27       126
      severe_damage       0.68      0.87      0.76       332

           accuracy                           0.62       529
          macro avg       0.48      0.42      0.43       529
       weighted avg       0.57      0.62      0.57       529

Retrying 199 minority hard samples for 1 pass(es)
  Hard pass 1 loss: 0.9218
Early stopping.
Training finished. Best val acc: 0.6332703213610587


In [15]:
# ============================================================
#  test_damagenet_text_final.py
#  Standalone testing script for DamageNetTextFromScratch
# ============================================================

import os, re, json, torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter

from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix
)
from sklearn.preprocessing import LabelEncoder

# ------------------------------------------------------------
# CONFIG (EDIT THESE)
# ------------------------------------------------------------
VOCAB_PATH = "/Volumes/Extreme SSD/DL_Proj/damagenet_text_minority_ckpt/vocab.json"
CHECKPOINT_PATH = "/Volumes/Extreme SSD/DL_Proj/damagenet_text_minority_ckpt/best_minority.pt"

TEST_FILE = "/Volumes/Extreme SSD/DL_Proj/CrisisMMD_v2.0/crisismmd_datasplit_all/task_damage_text_img_test.tsv"

MAX_LEN = 128
EMB_DIM = 256
N_HEADS = 8
N_LAYERS = 4
FF_DIM = 512
DROPOUT = 0.1
BATCH_SIZE = 32

DEVICE = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
print("DEVICE:", DEVICE)


# ------------------------------------------------------------
# UTILITIES
# ------------------------------------------------------------

def clean_text(text):
    text = str(text)
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"@\w+", "", text)
    text = re.sub(r"#\w+", "", text)
    text = re.sub(r"[^A-Za-z0-9\s.,!?']", " ", text)
    return re.sub(r"\s+", " ", text).strip().lower()

def basic_tokenizer(text):
    return re.findall(r"\b[\w']+\b", text.lower())

def encode_text_ids_from_tokens(tokens, vocab, max_len=MAX_LEN):
    ids = [vocab.get(t, vocab["<unk>"]) for t in tokens[:max_len - 1]]
    ids = [vocab["<cls>"]] + ids
    if len(ids) < max_len:
        ids += [vocab["<pad>"]] * (max_len - len(ids))
    return ids[:max_len]

# ------------------------------------------------------------
# LOAD VOCAB
# ------------------------------------------------------------

print("Loading vocab:", VOCAB_PATH)
with open(VOCAB_PATH, "r") as f:
    vocab = json.load(f)

vocab_size = len(vocab)
print("Vocab size:", vocab_size)


# ------------------------------------------------------------
# LOAD TEST TSV
# ------------------------------------------------------------

print("Loading test file:", TEST_FILE)
df = pd.read_csv(TEST_FILE, sep="\t")
df["tweet_text"] = df["tweet_text"].astype(str).apply(clean_text)

label_le = LabelEncoder()
df["label_id"] = label_le.fit_transform(df["label"].astype(str))
n_classes = len(label_le.classes_)
print("CLASSES:", label_le.classes_.tolist())


# ------------------------------------------------------------
# DATASET
# ------------------------------------------------------------

from torch.utils.data import Dataset, DataLoader

class TestDataset(Dataset):
    def __init__(self, df, vocab):
        self.df = df.reset_index(drop=True)
        self.vocab = vocab

    def __len__(self): 
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        toks = basic_tokenizer(row["tweet_text"])
        ids = encode_text_ids_from_tokens(toks, self.vocab)
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "label": torch.tensor(int(row["label_id"]), dtype=torch.long)
        }

test_dataset = TestDataset(df, vocab)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)


# ------------------------------------------------------------
# MODEL ARCHITECTURE (MATCHES TRAINING)
# ------------------------------------------------------------

class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        out = self.attn(x, x, x, key_padding_mask=mask)[0]
        x = self.norm1(x + out)
        x = self.norm2(x + self.ff(x))
        return x

class TextEncoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=0)
        self.pos = nn.Parameter(torch.zeros(1, MAX_LEN, EMB_DIM))
        self.layers = nn.ModuleList([
            TransformerBlock(EMB_DIM, N_HEADS, FF_DIM, DROPOUT)
            for _ in range(N_LAYERS)
        ])
        self.norm = nn.LayerNorm(EMB_DIM)

    def forward(self, ids):
        mask = ids.eq(0)
        x = self.emb(ids) + self.pos[:, :ids.size(1)]
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x[:, 0])

class DamageNetTextFromScratch(nn.Module):
    def __init__(self, vocab_size, n_classes):
        super().__init__()
        self.encoder = TextEncoder(vocab_size)
        self.classifier = nn.Sequential(
            nn.Linear(EMB_DIM, EMB_DIM // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(EMB_DIM // 2, n_classes)
        )

    def forward(self, ids, logit_adjust=None):
        z = self.encoder(ids)
        logits = self.classifier(z)
        return logits + logit_adjust if logit_adjust is not None else logits


# ------------------------------------------------------------
# LOAD CHECKPOINT (HANDLES ALL FORMATS)
# ------------------------------------------------------------

print("Loading checkpoint:", CHECKPOINT_PATH)
ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

model = DamageNetTextFromScratch(vocab_size, n_classes).to(DEVICE)

if isinstance(ckpt, dict) and "model_state" in ckpt:
    print("Checkpoint format: {model_state: ...}")
    model.load_state_dict(ckpt["model_state"])
    logit_adjust = ckpt.get("logit_adjust", None)

elif isinstance(ckpt, dict) and any(k.startswith("encoder") or k.startswith("classifier") for k in ckpt.keys()):
    print("Checkpoint format: raw state_dict dict")
    model.load_state_dict(ckpt)
    logit_adjust = None

else:
    print("Checkpoint format: raw tensor-only state_dict")
    model.load_state_dict(ckpt)
    logit_adjust = None

if logit_adjust is not None:
    logit_adjust = torch.tensor(logit_adjust, dtype=torch.float32, device=DEVICE)
    print("Loaded logit adjustment.")
else:
    print("No logit adjustment found.")


# ------------------------------------------------------------
# EVALUATION LOOP
# ------------------------------------------------------------

print("\nRunning inference...")
model.eval()
preds = []
trues = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        ids = batch["input_ids"].to(DEVICE)
        y = batch["label"].to(DEVICE)

        logits = model(ids, logit_adjust)
        pred = logits.argmax(1)

        preds.extend(pred.cpu().tolist())
        trues.extend(y.cpu().tolist())


# ------------------------------------------------------------
# METRICS
# ------------------------------------------------------------

acc = accuracy_score(trues, preds)

print("\n================= TEST RESULTS =================")
print(f"Accuracy: {acc:.4f}\n")

print(classification_report(
    trues,
    preds,
    target_names=label_le.classes_,
    digits=4,
    zero_division=0
))

print("\nConfusion Matrix:")
print(confusion_matrix(trues, preds))

print("\n================================================\n")


DEVICE: mps
Loading vocab: /Volumes/Extreme SSD/DL_Proj/damagenet_text_minority_ckpt/vocab.json
Vocab size: 3455
Loading test file: /Volumes/Extreme SSD/DL_Proj/CrisisMMD_v2.0/crisismmd_datasplit_all/task_damage_text_img_test.tsv
CLASSES: ['little_or_no_damage', 'mild_damage', 'severe_damage']
Loading checkpoint: /Volumes/Extreme SSD/DL_Proj/damagenet_text_minority_ckpt/best_minority.pt


  ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)


Checkpoint format: raw state_dict dict
No logit adjustment found.

Running inference...


100%|██████████| 17/17 [00:00<00:00, 20.01it/s]


Accuracy: 0.6276

                     precision    recall  f1-score   support

little_or_no_damage     0.3714    0.1831    0.2453        71
        mild_damage     0.4430    0.2778    0.3415       126
      severe_damage     0.6843    0.8554    0.7604       332

           accuracy                         0.6276       529
          macro avg     0.4996    0.4388    0.4490       529
       weighted avg     0.5849    0.6276    0.5915       529


Confusion Matrix:
[[ 13   9  49]
 [  9  35  82]
 [ 13  35 284]]





