# CODE FOR THE FULL MULTITASKING PIPELINE - EXTRACTS ANTHROPOMETRIC MEASURES AND CLASSIFIES IMAGES

In [None]:
# Necessary imports
import os
import random
from glob import glob
from typing import List, Dict

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

# %pip uninstall torch torchvision torchaudio -y
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

# Environment diagnostics: print the PyTorch version, the CUDA and cuDNN versions
# reported by torch, and whether CUDA is currently available.
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.is_available())


from tqdm import tqdm

from sklearn.metrics import (
    mean_absolute_error, root_mean_squared_error, r2_score,
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
)

# Decide the Model

In [None]:

# Per-backbone settings, keyed by backbone name:
#   img_size     -> side length (pixels) of the square input the images are resized to
#   weights      -> name of the pretrained weight set (presumably a torchvision weights tag)
#   target_layer -> name of the backbone sub-module of interest (presumably for Grad-CAM)
MODEL_CONFIGS = {
    "resnet50": dict(
        img_size=224,
        weights="IMAGENET1K_V1",
        target_layer="layer4",
    ),
    "efficientnet_b3": dict(
        img_size=300,
        weights="IMAGENET1K_V1",
        target_layer="features",
    ),
}

# Backbone to train; switch to "efficientnet_b3" to train EfficientNet instead
BACKBONE = "resnet50"
# Input resolution follows from the selected backbone
IMG_SIZE = MODEL_CONFIGS[BACKBONE]["img_size"]


# Setting Hyperparameters

In [None]:
# Hyperparameter Tuning

# Reproducibility seed and compute device (GPU when CUDA is available, otherwise CPU)
SEED = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("CUDA Available:", torch.cuda.is_available())
print("Device:", device)

# Seed Python, NumPy and PyTorch RNGs (in that order) with the same value
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if device.type == "cuda":
    # Also seed every CUDA device's generator
    torch.cuda.manual_seed_all(SEED)

# Paths
MANIFEST = "Anthrovision/manifest.csv"  # manifest with proc_path, child_id, label, split
IMG_ROOT = "."                          # base dir that relative proc_path entries are joined to

# Data loading
BATCH_SIZE = 8                     # children per batch (each child contributes several images)
NUM_IMAGES_PER_CHILD = 5           # images sampled/padded per child

# Optimisation
LR = 1e-4
WEIGHT_DECAY = 1e-5
EPOCHS = 30
GRAD_CLIP = 1.0                    # max global gradient norm
PATIENCE_LR_SCHED = 4              # presumably the LR-scheduler patience (scheduler is set up elsewhere)

# Multi-task loss mixing
REGRESSION_LOSS_WEIGHT = 0.01      # weight for regression loss
CLASSIFICATION_LOSS_WEIGHT = 0.5   # weight for classification loss


# Preprocessing / augmentation pipelines
def make_transforms(img_size):
    """Build the training and validation image pipelines.

    Both pipelines resize to ``img_size`` x ``img_size``, convert to a tensor and
    normalise per channel. The training pipeline additionally applies a random
    horizontal flip, colour jitter (applied half of the time) and a random
    rotation of up to 10 degrees between the resize and the tensor conversion.

    Returns:
        (train_tf, val_tf): two ``T.Compose`` objects.
    """
    target_size = (img_size, img_size)
    # Per-channel statistics (presumably the ImageNet values matching the pretrained weights)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    def _tensor_and_normalize():
        # Fresh transform objects for each pipeline
        return [T.ToTensor(), T.Normalize(mean=channel_mean, std=channel_std)]

    jitter = T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.15, hue=0.02)
    augmentations = [
        T.RandomHorizontalFlip(p=0.5),
        T.RandomApply([jitter], p=0.5),
        T.RandomRotation(degrees=10),
    ]

    train_tf = T.Compose([T.Resize(target_size), *augmentations, *_tensor_and_normalize()])
    val_tf = T.Compose([T.Resize(target_size), *_tensor_and_normalize()])
    return train_tf, val_tf

# Instantiate both pipelines at the resolution of the selected backbone
train_transforms, val_transforms = make_transforms(IMG_SIZE)

# Report the active configuration
print(f"Using backbone: {BACKBONE} with image size {IMG_SIZE}x{IMG_SIZE}")

# Merging Manifest and Original Labels File

In [None]:
# If your original labels csv contains the numeric targets (e.g., head_circumference,height,muac),
# merge them into manifest.csv keyed by child_id. If your manifest already contains them, skip this.

def ensure_targets_in_manifest(manifest_path: str, original_labels_path: str = None):
    """Make sure the manifest CSV contains the regression target columns.

    The manifest is checked for the columns 'HC', 'Height' and 'MUAC'. Any that
    are absent are pulled from ``original_labels_path`` (joined on manifest
    ``child_id`` == labels ``tag``) and the manifest file is overwritten in place.

    Only the *missing* columns are merged, so targets already present in the
    manifest are not duplicated with ``_x``/``_y`` suffixes. The labels file is
    reduced to one row per ``tag`` (first occurrence wins) so the left join can
    never multiply manifest rows, and the join key from the labels file is not
    written into the manifest.

    Args:
        manifest_path: CSV with at least a ``child_id`` column.
        original_labels_path: CSV with a ``tag`` column plus the target columns.
            Only required when the manifest lacks some target column.

    Raises:
        ValueError: targets are missing and no ``original_labels_path`` was given.
        KeyError: the labels file lacks ``tag`` or one of the needed columns.
    """
    df = pd.read_csv(manifest_path)

    reg_cols = ['HC', 'Height', 'MUAC']
    missing = [c for c in reg_cols if c not in df.columns]

    if not missing:
        print("Manifest already contains regression target columns.")
        return

    if not original_labels_path:
        raise ValueError(f"Manifest is missing regression target columns {missing}. Provide original_labels_path to merge.")

    print(f"Manifest missing {missing}. Attempting to merge from {original_labels_path}")
    orig = pd.read_csv(original_labels_path)
    print(orig.columns.tolist())

    # One row per child, only the columns we actually need, keyed like the manifest
    targets = (
        orig[['tag'] + missing]
        .drop_duplicates(subset='tag', keep='first')
        .rename(columns={'tag': 'child_id'})
    )
    df = df.merge(targets, on='child_id', how='left')

    df.to_csv(manifest_path, index=False)
    print("Merged and saved manifest with regression targets.")

# Building Dataset - one item = one child with N images

In [None]:
class ChildImageDataset(Dataset):
    """
    PyTorch Dataset whose one item corresponds to one child (not one image).

    The manifest CSV is filtered to the requested split (case-insensitive match on
    the 'split' column) and grouped by 'child_id'. Each item returns a dict with:
        images: tensor stacking num_images_per_child transformed images,
                shape (num_images, ...) where the trailing dims come from the transform.
        reg_targets: float32 tensor of 3 values [HC, Height, MUAC]; NaN where missing.
        class_label: float32 scalar tensor, 0.0 or 1.0 (majority vote over the child's rows).
        child_id: the group key for that child, as read from the manifest.

    Raises (constructor):
        ValueError: num_images_per_child < 1, no rows for the split, or a child
                    whose rows carry no non-missing label.
    """

    def __init__(self, manifest_csv: str, split: str, img_root: str = ".", transforms=None, num_images_per_child: int = 4):
        if num_images_per_child < 1:
            # Fail early: an empty sample cannot be stacked in __getitem__
            raise ValueError(f"num_images_per_child must be >= 1, got {num_images_per_child}")

        self.df = pd.read_csv(manifest_csv)
        self.split = split                                        # split -> test , val , train
        self.img_root = img_root
        self.transforms = transforms
        self.num_images_per_child = num_images_per_child

        # Keep only rows whose split column (case-insensitive) matches the requested split
        self.split_df = self.df[self.df['split'].str.lower() == split.lower()].copy()
        if self.split_df.empty:
            raise ValueError(f"No rows found for split={split} in {manifest_csv}")

        # One entry per child; each group g holds all rows/images of that child
        self.children = []
        for child_id, g in self.split_df.groupby('child_id'):
            labels = g['label'].astype(float).dropna()
            if labels.empty:
                raise ValueError(f"No non-missing label for child_id={child_id} in {manifest_csv}")

            self.children.append({
                'child_id': child_id,
                'paths': g['proc_path'].tolist(),          # List of image paths
                # Majority vote in case labels differ per image; rows with a
                # missing label are ignored instead of poisoning the mean with NaN
                'class_label': float(round(labels.mean())),
                'head_circumference': self._first_nonnull(g, 'HC'),
                'height': self._first_nonnull(g, 'Height'),
                'muac': self._first_nonnull(g, 'MUAC'),
            })

    @staticmethod
    def _first_nonnull(group, col):
        """Return the first non-null value of `col` in `group` as float, or NaN if the column is absent/all-null."""
        if col not in group.columns:
            return np.nan
        vals = group[col].dropna().unique().tolist()
        return float(vals[0]) if vals else np.nan

    # Returns number of unique children in the chosen split
    def __len__(self):
        return len(self.children)

    def _load_and_transform(self, path):
        """Load one image as RGB and apply the configured transforms (ToTensor if none)."""
        # proc_path may be absolute, or relative to img_root
        p = path if os.path.isabs(path) else os.path.join(self.img_root, path)

        # convert('RGB') guarantees 3 channels even for grayscale files and returns
        # an in-memory copy, so the file handle can be closed right away
        with Image.open(p) as raw:
            img = raw.convert('RGB')

        if self.transforms:
            return self.transforms(img)
        return T.ToTensor()(img)

    # Fetch the idx-th item
    def __getitem__(self, idx):
        info = self.children[idx]
        paths = info['paths']
        k = self.num_images_per_child

        if len(paths) >= k:
            # Enough images: draw k distinct ones (uses the global `random` state)
            sampled = random.sample(paths, k)
        else:
            # Too few images: keep them all and pad by repeating the first one
            sampled = list(paths) + [paths[0]] * (k - len(paths))

        imgs_tensor = torch.stack([self._load_and_transform(p) for p in sampled], dim=0)

        # Values are already float-or-NaN; NaN entries are expected to be masked in the loss
        reg_targets = torch.tensor(
            [info['head_circumference'], info['height'], info['muac']],
            dtype=torch.float32,
        )

        return {
            'images': imgs_tensor,
            'reg_targets': reg_targets,
            'class_label': torch.tensor(info['class_label'], dtype=torch.float32),
            'child_id': info['child_id']
        }

# The default collate_fn is used: every item has the same shapes because
# num_images_per_child is fixed, so items stack directly into batch tensors.
def get_dataloaders(manifest_path: str, img_root: str, batch_size: int, num_images_per_child: int, train_transforms, val_transforms, num_workers: int = 0):
    """Build the train / val / test DataLoaders from one manifest.

    Args:
        manifest_path: manifest CSV handed to ChildImageDataset.
        img_root: base directory for relative image paths.
        batch_size: number of children per batch; each step processes
            batch_size * num_images_per_child images.
        num_images_per_child: images sampled/padded per child.
        train_transforms: transforms for the 'train' split.
        val_transforms: transforms for the 'val' and 'test' splits.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        (train_loader, val_loader, test_loader); only the train loader shuffles.
    """
    def _dataset(split, tfms):
        return ChildImageDataset(
            manifest_path, split=split, img_root=img_root,
            transforms=tfms, num_images_per_child=num_images_per_child
        )

    train_ds = _dataset('train', train_transforms)
    val_ds = _dataset('val', val_transforms)
    test_ds = _dataset('test', val_transforms)

    # Page-locked host memory only speeds up host->GPU copies, so request it
    # only when a CUDA device is actually available
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader

# Building a MultiTask Model : ResNet50 + Mean Pooling

In [None]:
# MultiTaskModel is a PyTorch neural network that extracts features from images using a ResNet50 backbone
# It has two heads - classification + regression

class ResNet_MultiTaskModel(nn.Module):
    """ResNet50 backbone with mean pooling over a child's images and two heads.

    Every image of a child is embedded by the backbone, the embeddings are averaged
    per child, and the pooled vector feeds
      * a regression head with `regression_out` outputs (default 3), and
      * a classification head producing one logit (no sigmoid; use BCEWithLogitsLoss).
    """
    def __init__(self, backbone_name='resnet50', pretrained=True, regression_out=3):
        super().__init__()
        if backbone_name != 'resnet50':
            raise NotImplementedError("Backbone not implemented in this script. Only ResNet50 allowed.")

        # `weights=` replaces the deprecated `pretrained=` flag and matches how the
        # EfficientNet model in this file loads its weights
        base = models.resnet50(weights='IMAGENET1K_V1' if pretrained else None)
        feat_dim = base.fc.in_features                    # width of the pooled backbone features
        # Identity returns its input unchanged, so the backbone now outputs the pooled features
        base.fc = nn.Identity()
        self.backbone = base
        self.last_conv_name = 'layer4'                    # last conv stage, recorded for Grad-CAM

        # Feature dimensionality, used to size the heads
        self.feat_dim = feat_dim

        # Regression head: raw continuous outputs, no final activation
        self.regression_head = nn.Sequential(
            nn.Linear(feat_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, regression_out)
        )

        # Classification head: a single logit per child
        self.classification_head = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, images):
        """
        images: Tensor (batch, num_imgs, C, H, W)
        returns:
          reg_out      (batch, regression_out)
          class_logits (batch,)
          pooled       (batch, feat_dim) mean embedding per child
        """
        b, n, c, h, w = images.shape

        # The backbone expects a batch of single images: merge batch and view dims.
        # reshape (rather than view) also accepts non-contiguous inputs.
        images_flat = images.reshape(b * n, c, h, w)

        feats_flat = self.backbone(images_flat)  # (b*n, feat_dim)

        # Regroup so feats[i, j] is the embedding of the j-th image of the i-th child
        feats = feats_flat.reshape(b, n, -1)     # (b, n, feat_dim)

        # Every image contributes equally to the child embedding
        pooled = feats.mean(dim=1)               # (b, feat_dim)

        reg_out = self.regression_head(pooled)                      # (b, regression_out)
        class_logits = self.classification_head(pooled).squeeze(1)  # (b,)

        return reg_out, class_logits, pooled

# Building a Multitask Model - EfficientNet-B3 + Attention Pooling

In [None]:
# This module extracts per-image features with EfficientNet-B3, learns per-image attention weights 
# to fuse multiple views into a single child embedding, and from that embedding predicts 
# the three regression targets (head circumference, height, MUAC) and one classification logit (malnutrition). 
# It also returns the attention weights so you can inspect which images the model relied on.

# Attention Pooling Module
class AttentionPooling(nn.Module):
    """
    Fuses the image embeddings of one child into a single vector using learned weights.

    Input:  feats  -> (batch, num_imgs, feat_dim)
    Output: pooled -> (batch, feat_dim), weights -> (batch, num_imgs, 1)
    """
    def __init__(self, feat_dim):
        super().__init__()
        # Scoring MLP, applied to each embedding independently:
        # feat_dim -> 256 -> one unnormalised attention logit
        hidden = 256
        self.attn = nn.Sequential(
            nn.Linear(feat_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, feats):
        # feats: (B, N, D)
        scores = self.attn(feats)                 # (B, N, 1) one logit per image
        alpha = torch.softmax(scores, dim=1)      # normalise over the N images of each child
        fused = torch.sum(alpha * feats, dim=1)   # (B, D) attention-weighted sum
        return fused, alpha
# Note : Averaging treats each view equally. Attention lets the network learn which views are useful 
# (e.g., frontal posture may be more informative for MUAC; lateral might help height). 
# This typically improves performance when image quality or pose varies.


class EfficientNet_MultiTaskModel(nn.Module):
    """EfficientNet-B3 backbone, attention pooling over a child's images, two heads.

    The pooled child embedding feeds a regression head (`regression_out` raw
    outputs) and a classification head (one logit, meant for BCEWithLogitsLoss).
    """

    @staticmethod
    def _mlp_head(in_dim, hidden_dim, out_dim):
        # Linear -> ReLU -> Dropout(0.3) -> Linear, no activation on the output
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, out_dim),
        )

    def __init__(self, backbone_name='efficientnet_b3', pretrained=True, regression_out=3):
        super().__init__()

        # Only EfficientNet-B3 is supported
        if backbone_name != 'efficientnet_b3':
            raise NotImplementedError(f"Backbone {backbone_name} not implemented")

        net = models.efficientnet_b3(weights='IMAGENET1K_V1' if pretrained else None)
        embed_dim = net.classifier[1].in_features   # input width of the stock classifier layer
        net.classifier = nn.Identity()              # drop the stock classifier; custom heads follow
        self.backbone = net
        self.last_conv_name = 'features'            # for Grad-CAM if needed

        # Embedding width and the attention pooling that fuses a child's images
        self.feat_dim = embed_dim
        self.attn_pool = AttentionPooling(embed_dim)

        # Task heads (raw outputs / logits)
        self.regression_head = self._mlp_head(embed_dim, 512, regression_out)
        self.classification_head = self._mlp_head(embed_dim, 256, 1)

    def forward(self, images):
        """
        images: Tensor (batch, num_imgs, C, H, W)
        Returns:
            reg_out: (batch, regression_out)
            class_logits: (batch,)
            pooled_feats: (batch, feat_dim)
            attn_weights: (batch, num_imgs, 1)
        """
        batch, views, channels, height, width = images.shape

        # Fold the view axis into the batch axis so the backbone sees plain images
        per_image = self.backbone(images.view(batch * views, channels, height, width))  # (batch*views, feat_dim)

        # Unfold again: one row of `views` embeddings per child
        per_child = per_image.view(batch, views, -1)

        # Attention-weighted fusion of each child's embeddings
        pooled, attn_weights = self.attn_pool(per_child)

        reg_out = self.regression_head(pooled)
        # squeeze(1) drops the singleton logit dimension -> (batch,)
        class_logits = self.classification_head(pooled).squeeze(1)

        return reg_out, class_logits, pooled, attn_weights

# Model Utility Functions

In [None]:
# Element-wise regression loss averaged only over targets that actually exist.
# Some children may lack HC/Height/MUAC (NaN); those entries are masked out of the loss.
def masked_regression_loss(preds, targets, mask=None, loss_fn=nn.SmoothL1Loss(reduction='none')):
    """
    preds:   (B, 3) model predictions
    targets: (B, 3) ground truth, may contain NaN for missing measurements
    mask:    optional boolean (B, 3), True where the target is valid;
             derived from ~isnan(targets) when omitted
    loss_fn: element-wise loss (must use reduction='none'); Smooth L1 by default

    Returns a scalar: sum of the per-element losses at valid positions divided by
    the number of valid positions. With no valid position the result is 0 and is
    still attached to `preds` in the autograd graph (zero gradients), so callers
    can always call .backward() on it.
    """
    if mask is None:
        mask = ~torch.isnan(targets)            # False indicates missing target (NaN)

    # NaN targets are replaced by 0 only so the loss stays finite; they are masked out below
    loss_per_elem = loss_fn(preds, torch.nan_to_num(targets, nan=0.0))

    # Zero out invalid entries
    loss_per_elem = loss_per_elem * mask.float()

    # Clamping the count to >= 1 turns the "no valid target" case into 0 / 1 = 0
    # without a data-dependent branch (and without a host sync via .item())
    valid_counts = mask.sum().clamp(min=1)

    return loss_per_elem.sum() / valid_counts.float()
    # The returned scalar is the mean element-wise MSE across the whole batch for all available regression entries


# Runs the model forward, computes the masked regression loss and a binary classification loss,
# sums them (with configurable weights), backpropagates, and updates the optimizer
def train_epoch(model, loader, optimizer, epoch, scheduler=None):
    """Train `model` for one pass over `loader` and return the mean loss per child.

    Args:
        model: returns a tuple/list whose first two entries are the regression
            predictions (B, 3) and the classification logits (B,); any further
            entries (pooled features, attention weights) are ignored.
        loader: yields dicts with 'images', 'reg_targets' and 'class_label'.
        optimizer: optimizer over the model parameters.
        epoch: zero-based epoch index, used only for the progress-bar label.
        scheduler: accepted for interface compatibility; it is NOT stepped here.

    Returns:
        Weighted loss averaged over the children actually trained on, or NaN when
        no batch contributed (empty loader or every batch skipped).

    Mixed precision (autocast + gradient scaling) is enabled only when the global
    `device` is CUDA; on CPU both are explicitly disabled. Batches whose
    predictions contain NaN/inf are reported and skipped.
    """
    model.train()           # training mode (dropout on, BatchNorm updating)
    total_loss = 0.0        # loss accumulated over the epoch, weighted by batch size
    total_samples = 0       # number of children processed

    use_amp = device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    cls_loss_fn = nn.BCEWithLogitsLoss()                   # sigmoid + binary cross entropy
    reg_loss_fn = nn.SmoothL1Loss(reduction='none')        # element-wise, masked afterwards

    # Iterate the progress bar itself so it actually advances with the batches
    pbar = tqdm(loader, desc=f"Training Epoch {epoch+1}", leave=False)

    for batch in pbar:
        images = batch['images'].to(device)               # (B, n, C, H, W)
        reg_targets = batch['reg_targets'].to(device)     # (B, 3) may have NaN
        class_labels = batch['class_label'].to(device)    # (B,)

        # Clears old gradients
        optimizer.zero_grad()

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            # Only the first two outputs are needed for the loss
            reg_preds, class_logits = outputs[0], outputs[1]

            # Skip batches with non-finite predictions before they reach the loss
            if not torch.isfinite(reg_preds).all():
                print("NaN or inf in regression predictions!")
                print("Batch mean:", reg_preds.mean().item(), "std:", reg_preds.std().item())
                continue

            if not torch.isfinite(class_logits).all():
                print("NaN or inf in classification logits!")
                continue

            # Regression loss (missing targets masked out)
            reg_loss = masked_regression_loss(reg_preds, reg_targets, None, reg_loss_fn)
            # Classification loss
            cls_loss = cls_loss_fn(class_logits, class_labels)
            # Weighted sum of both task losses
            loss = REGRESSION_LOSS_WEIGHT * reg_loss + CLASSIFICATION_LOSS_WEIGHT * cls_loss

        # Backward pass on the (possibly scaled) loss
        scaler.scale(loss).backward()

        # Gradients must be unscaled before clipping so GRAD_CLIP applies to true norms
        if GRAD_CLIP:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

        # Parameter update (skipped internally by the scaler on inf/NaN gradients)
        scaler.step(optimizer)
        scaler.update()

        batch_children = images.shape[0]
        total_loss += loss.item() * batch_children
        total_samples += batch_children

        pbar.set_postfix({"cls": f"{cls_loss.item():.4f}",
                          "reg": f"{reg_loss.item():.4f}",
                          "tot": f"{loss.item():.4f}"})

    if total_samples == 0:
        # Nothing was trained on: report it instead of dividing by zero
        print("No batches were trained in this epoch (empty loader or every batch skipped).")
        return float('nan')

    # Average loss per child
    return total_loss / total_samples

# Runs the model over a loader without gradients and computes per-target regression metrics (MAE, RMSE, R²) and
# classification metrics (accuracy, precision, recall, f1, ROC AUC)
@torch.no_grad()            # disables gradient tracking to save memory and compute
def evaluate(model, loader, threshold=None):
    """Evaluate `model` on `loader`.

    Args:
        model: returns a tuple/list whose first two entries are the regression
            predictions (B, 3) and the classification logits (B,).
        loader: yields dicts with 'images', 'reg_targets' and 'class_label'.
        threshold: probability cut-off for the positive class. When None (default)
            the cut-off maximising F1 on this very loader is searched over
            0.05..0.95; pass a fixed value (e.g. one tuned on the validation
            split) to avoid tuning on the data being reported.

    Returns:
        (reg_metrics, cls_metrics)
        reg_metrics: {'head_circumference'|'height'|'muac': {'MAE','RMSE','R2'}},
            computed over non-NaN targets; all None when a target has no valid value.
        cls_metrics: 'accuracy', 'precision', 'recall', 'f1', 'roc_auc' (None when
            ROC AUC cannot be computed) and 'threshold' (the cut-off actually used).

    Raises:
        ValueError: the loader yielded no batches.
    """
    model.eval()            # dropout off, BatchNorm in evaluation mode

    # Per-batch numpy arrays, concatenated after the loop
    all_reg_true = []
    all_reg_pred = []
    all_cls_true = []
    all_cls_prob = []

    # Iterate the progress bar itself so it actually advances with the batches
    pbar = tqdm(loader, desc=f"Evaluating", leave=False)

    for batch in pbar:
        images = batch['images'].to(device, non_blocking=True)
        reg_targets = batch['reg_targets'].to(device, non_blocking=True)
        class_labels = batch['class_label'].to(device, non_blocking=True)

        outputs = model(images)
        # Only the first two outputs are needed for the metrics
        reg_preds, class_logits = outputs[0], outputs[1]
        # Probabilities in [0,1] computed from logits by sigmoid
        class_probs = torch.sigmoid(class_logits)

        all_reg_true.append(reg_targets.cpu().numpy())
        all_reg_pred.append(reg_preds.cpu().numpy())
        all_cls_true.append(class_labels.cpu().numpy())
        all_cls_prob.append(class_probs.cpu().numpy())

    if not all_reg_true:
        raise ValueError("evaluate() received an empty loader; nothing to evaluate.")

    # N = total children in the loader
    all_reg_true = np.vstack(all_reg_true)       # (N,3)
    all_reg_pred = np.vstack(all_reg_pred)
    y_true_cls = np.concatenate(all_cls_true)    # (N,)
    y_prob_cls = np.concatenate(all_cls_prob)

    # Regression metrics, per target, over the rows where the target is present
    reg_metrics = {}
    for i, name in enumerate(['head_circumference', 'height', 'muac']):
        valid = ~np.isnan(all_reg_true[:, i])
        if valid.sum() == 0:
            reg_metrics[name] = {'MAE': None, 'RMSE': None, 'R2': None}
            continue
        y_true = all_reg_true[valid, i]
        y_pred = all_reg_pred[valid, i]
        reg_metrics[name] = {
            'MAE': float(mean_absolute_error(y_true, y_pred)),
            'RMSE': float(root_mean_squared_error(y_true, y_pred)),
            'R2': float(r2_score(y_true, y_pred)),
        }

    # Decision threshold: caller-supplied, or the F1-maximising one on this loader
    if threshold is None:
        best_thresh, best_f1 = 0.5, 0.0
        for t in np.linspace(0.05, 0.95, 91):   # test thresholds from 0.05 to 0.95
            f1 = f1_score(y_true_cls, (y_prob_cls >= t).astype(int), zero_division=0)
            if f1 > best_f1:
                best_f1, best_thresh = f1, t
    else:
        best_thresh = threshold
    best_thresh = float(best_thresh)

    # A child is predicted positive when its probability reaches the threshold
    y_pred_cls = (y_prob_cls >= best_thresh).astype(int)

    try:
        roc_auc = float(roc_auc_score(y_true_cls, y_prob_cls))
    except ValueError:
        roc_auc = None  # if only one class present, since both are required

    cls_metrics = {
        'accuracy': float(accuracy_score(y_true_cls, y_pred_cls)),
        'precision': float(precision_score(y_true_cls, y_pred_cls, zero_division=0)),
        'recall': float(recall_score(y_true_cls, y_pred_cls, zero_division=0)),
        'f1': float(f1_score(y_true_cls, y_pred_cls, zero_division=0)),
        'roc_auc': roc_auc,
        'threshold': best_thresh,
    }

    return reg_metrics, cls_metrics

# GRAD-CAM Utility

In [None]:
# GradCAM produces a visual explanation (a heatmap) that highlights 
# which parts of an input image the network used to make a particular decision
class GradCAM:
    """
    Minimal Grad-CAM for the multi-task models used in this notebook.

    The wrapped model must expose:
      - `backbone`: feature extractor; calling it must yield features that flatten to (B, D)
      - `last_conv_name`: name (as listed by `backbone.named_modules()`) of the layer to explain,
        e.g. 'layer4' or 'features'
      - `classification_head`: maps the flattened features to a single logit per sample

    Usage:
      cam = GradCAM(model)
      mask = cam.generate_cam(input_tensor)
      cam.remove_hooks()
    `mask` is a float NumPy array in 0..1 with the same height/width as the input image.
    """
    def __init__(self, model: nn.Module):
        self.model = model
        self.gradients = None       # gradient of the target logit w.r.t. the hooked layer's output
        self.activations = None     # output of the hooked layer from the last forward pass
        self.hook_handles = []      # kept so the hooks can be detached in remove_hooks()
        self._register_hooks()

    def _register_hooks(self):
        """
        Attach a forward and a backward hook to `model.backbone.<last_conv_name>`.

        Raises AttributeError if the model has no `last_conv_name`, and ValueError if that
        name is not a sub-module of the backbone.
        """
        if not hasattr(self.model, "last_conv_name"):
            raise AttributeError("Model must define `last_conv_name` (e.g., 'layer4' or 'features').")

        target_layer_name = self.model.last_conv_name
        # Map names such as 'layer4' to the module objects inside the backbone (built once, reused below)
        named_modules = dict(self.model.backbone.named_modules())

        if target_layer_name not in named_modules:
            raise ValueError(f"Target layer '{target_layer_name}' not found in backbone modules: {list(named_modules.keys())[:10]}...")

        target_module = named_modules[target_layer_name]

        # Forward hook: remember the layer's output; detached so the stored copy carries no graph
        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        # Backward hook: grad_out[0] is the gradient flowing into the layer's (first) output
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()

        self.hook_handles.append(target_module.register_forward_hook(forward_hook))
        self.hook_handles.append(target_module.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach every registered hook so later forward/backward passes are unaffected."""
        for h in self.hook_handles:
            h.remove()
        # Drop the now-dead handles so a repeated call is a clean no-op
        self.hook_handles.clear()

    def generate_cam(self, input_tensor: torch.Tensor, class_idx: int = None):
        """
        Compute the Grad-CAM heatmap for the first sample of `input_tensor`.

        input_tensor: (1, C, H, W), or (1, num_imgs, C, H, W) in which case only the first
                      image of the set is explained.
        class_idx: accepted for API compatibility only. The classification head emits a single
                   logit, so the map always explains that logit and this argument is ignored.
        Returns: NumPy float array of shape (H, W) with values in 0..1.
        Raises RuntimeError if the hooks did not capture activations/gradients.

        Side effects: puts the model in eval mode and zeroes its gradients before the pass;
        the backward pass then leaves fresh gradients on the trainable parameters.
        """
        # Reset model state and anything captured by a previous call
        self.model.eval()
        self.model.zero_grad()
        self.gradients = None
        self.activations = None

        # If input has shape (1, n, C, H, W) take the first image only -> (1, C, H, W)
        x = input_tensor[:, 0, ...] if input_tensor.ndim == 5 else input_tensor
        # Move the input to wherever the model parameters live
        x = x.to(next(self.model.parameters()).device)

        # Grad-CAM needs a backward pass, so force autograd on even if the caller is
        # inside a torch.no_grad() block (e.g. an evaluation loop).
        with torch.enable_grad():
            feats = self.model.backbone(x)          # triggers the forward hook on the target layer
            pooled = feats.view(feats.size(0), -1)  # flatten features to (B, D)
            cls_logits = self.model.classification_head(pooled).squeeze(1)
            # Back-propagate from the first sample's logit; the graph is not needed afterwards
            cls_logits[0].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients or activations not captured — check hooks.")

        # Standard Grad-CAM: weight each channel by its spatially averaged gradient, then sum
        grads = self.gradients[0]       # (C, H', W') for the first sample
        acts = self.activations[0]      # (C, H', W') for the first sample
        weights = grads.mean(dim=(1, 2), keepdim=True)  # (C,1,1)
        # ReLU keeps only the regions that push the logit up
        cam = F.relu((weights * acts).sum(dim=0)).float().cpu()  # (H', W')
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize 0..1

        # Upsample the coarse map to the size of the image that was actually fed in.
        # (Using a module-level image size here would break for any backbone whose
        # input resolution differs from that global.)
        cam_up = F.interpolate(cam[None, None], size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False)
        return cam_up[0, 0].numpy()

def overlay_cam_on_image(img_tensor, cam_mask, alpha=0.4):
    """
    Draw an image with a Grad-CAM heatmap blended on top of it.

    img_tensor: (C,H,W) tensor normalized with the ImageNet mean/std; it is always
                de-normalized here, so an un-normalized image would come out wrong
    cam_mask: 2-D array in 0..1; if its resolution differs from the image it is
              stretched to cover the whole image
    alpha: opacity of the heatmap layer
    Returns the matplotlib figure (it also stays the current figure, so plt.show() works)
    """
    # Inverse of Normalize(mean, std): x * std + mean, expressed as another Normalize
    inv_norm = T.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )

    # Detach + move to CPU, denormalize, permute (C,H,W) -> (H,W,C) for plotting, convert to NumPy
    img = inv_norm(img_tensor.detach().cpu()).permute(1, 2, 0).numpy()
    # Keep RGB values in the valid display range
    img = np.clip(img, 0, 1)
    height, width = img.shape[:2]

    fig = plt.figure(figsize=(4,4))
    plt.imshow(img)
    # Pin the heatmap to the image's pixel extent so both layers line up even when
    # the mask has a different resolution (same as the default when sizes match).
    plt.imshow(cam_mask, cmap='jet', alpha=alpha, extent=(-0.5, width - 0.5, height - 0.5, -0.5))
    plt.axis('off')
    plt.tight_layout()
    return fig

# Main Driver

In [None]:
def get_model(backbone_name: str, pretrained=True, regression_out=3):
    """
    Build the multi-task model for the requested backbone.

    backbone_name: "resnet50" or "efficientnet_b3"
    pretrained / regression_out: forwarded unchanged to the model constructor
    Raises ValueError for any other backbone name.
    """
    shared_kwargs = dict(pretrained=pretrained, regression_out=regression_out)

    if backbone_name == "resnet50":
        return ResNet_MultiTaskModel(backbone_name='resnet50', **shared_kwargs)
    if backbone_name == "efficientnet_b3":
        return EfficientNet_MultiTaskModel(backbone_name='efficientnet_b3', **shared_kwargs)

    raise ValueError(f"Unsupported backbone: {backbone_name}")


def main(BACKBONE : str):
    """
    Train, validate and test one backbone end to end, then show one Grad-CAM example.

    BACKBONE: key into MODEL_CONFIGS ("resnet50" or "efficientnet_b3"); a missing key
              raises KeyError from the config lookup.

    Steps: build the data loaders at the backbone's image size, freeze part of the
    backbone, train for EPOCHS epochs while stepping a ReduceLROnPlateau scheduler on a
    validation proxy loss, checkpoint the best epoch to `best_<BACKBONE>.pth`, evaluate
    that checkpoint on the test split and plot a Grad-CAM overlay for one test child.
    """
    cfg = MODEL_CONFIGS[BACKBONE]
    img_size = cfg["img_size"]
    target_layer = cfg["target_layer"]

    # Ensure regression targets present (pass original labels path if needed)
    ensure_targets_in_manifest(MANIFEST, original_labels_path="Anthrovision/anthrovision_labels.csv")

    # Load transforms with correct size
    train_tf, val_tf = make_transforms(img_size)

    train_loader, val_loader, test_loader = get_dataloaders(MANIFEST, IMG_ROOT, BATCH_SIZE, NUM_IMAGES_PER_CHILD, train_tf, val_tf)

    # Initialize model
    model = get_model(BACKBONE, pretrained=True, regression_out=3)
    model = model.to(device)

    # Freeze every backbone parameter whose name does not contain the target layer name,
    # so only that block and the heads are fine-tuned (faster, less overfitting on small data).
    # NOTE(review): this is a substring match on parameter names; for "features" it may match
    # all or none of the backbone depending on how the backbone is wrapped — confirm.
    for name, param in model.backbone.named_parameters():
        if target_layer not in name:
            param.requires_grad = False

    # Pass only the still-trainable parameters to the optimizer
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=PATIENCE_LR_SCHED)

    # Best validation proxy seen so far, and where that epoch's weights are stored
    best_val_loss = float('inf')
    ckpt_path = f"best_{BACKBONE}.pth"
    saved_checkpoint = False    # True once THIS run has written ckpt_path

    for epoch in range(1, EPOCHS + 1):
        # The plateau scheduler is stepped below on the validation proxy, not inside train_epoch
        train_loss = train_epoch(model, train_loader, optimizer, epoch, scheduler=None)
        # Runs validation and returns the regression and classification metric dictionaries
        reg_metrics, cls_metrics = evaluate(model, val_loader)

        # Validation proxy = mean of the available RMSEs (used as-is, not rescaled)
        # + (1 - ROC AUC) when the AUC could be computed. Lower is better.
        rmses = [m['RMSE'] for m in reg_metrics.values() if m['RMSE'] is not None]
        val_loss_proxy = sum(rmses) / len(rmses) if rmses else 0.0
        if cls_metrics['roc_auc'] is not None:
            val_loss_proxy += (1.0 - cls_metrics['roc_auc'])

        # Let the scheduler lower the LR when the proxy stops improving
        scheduler.step(val_loss_proxy)

        # Prints the metrics
        print(f"Epoch {epoch}/{EPOCHS} | train_loss: {train_loss:.4f}")
        print(" Val regression metrics:")
        for k, v in reg_metrics.items():
            print(f"  {k} => MAE: {v['MAE']}, RMSE: {v['RMSE']}, R2: {v['R2']}")
        print(" Val classification metrics:", cls_metrics)

        # Save best model by val_loss_proxy (a NaN proxy never compares as better, so it is never saved)
        if val_loss_proxy < best_val_loss:
            best_val_loss = val_loss_proxy
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'config': cfg
            }, ckpt_path)
            saved_checkpoint = True
            print(" Saved best model.")

    # Final evaluation on test set: restore the best weights when a checkpoint exists.
    # If this run never saved one (no epochs, or the proxy was NaN every epoch) fall back
    # explicitly instead of crashing or silently reusing an old file.
    if saved_checkpoint or os.path.exists(ckpt_path):
        if not saved_checkpoint:
            print(f" Warning: no checkpoint was saved in this run; loading existing '{ckpt_path}'.")
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print(" Warning: no checkpoint available; evaluating the current in-memory weights.")

    reg_metrics, cls_metrics = evaluate(model, test_loader)
    print("Final test regression metrics:")
    for k, v in reg_metrics.items():
        print(f"  {k} => MAE: {v['MAE']}, RMSE: {v['RMSE']}, R2: {v['R2']}")
    print("Final test classification metrics:", cls_metrics)

    # Example: produce a Grad-CAM overlay for the first child of the first test batch
    print("Generating Grad-CAM examples for a few test samples...")
    samples = next(iter(test_loader), None)
    if samples is None:
        print(" Test loader is empty; skipping Grad-CAM.")
        return

    grad_cam = GradCAM(model)
    try:
        model.eval()
        images = samples['images']  # (B, n, C, H, W)
        first_child_images = images[0:1].to(device)           # (1,n,C,H,W)
        cam_mask = grad_cam.generate_cam(first_child_images)
        # Overlay on the child's first image (denormalized inside the helper)
        overlay_cam_on_image(first_child_images[0, 0], cam_mask, alpha=0.5)
        plt.show()
    finally:
        # Always detach the hooks, even if CAM generation or plotting failed
        grad_cam.remove_hooks()

# Train the ResNet50 Model

In [None]:
# Run the full train / validate / test / Grad-CAM pipeline with the ResNet50 backbone
main(BACKBONE="resnet50")

In [None]:
# Free memory between the two training runs: collect unreachable Python objects
# first, then release the unoccupied blocks held by the CUDA caching allocator.
import gc

import torch

gc.collect()
torch.cuda.empty_cache()


# Train the EfficientNet-B3 Model

In [None]:
# Run the full train / validate / test / Grad-CAM pipeline with the EfficientNet-B3 backbone
main(BACKBONE="efficientnet_b3")