In [9]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
from torchvision import transforms, models
import torch.nn as nn
from torchinfo import summary
from torch.optim import Adam
from torchvision.ops import nms, RoIAlign

In [10]:
import sys
import os
notebook_dir= os.path.dirname(os.path.abspath("__file__"))
project_root= os.path.abspath(os.path.join(notebook_dir, '..'))

src_path= os.path.join(project_root, 'src')
if src_path not in sys.path:
    sys.path.append(src_path)
from config import data_dir, images_train_dir, images_val_dir, labels_train_dir, labels_val_dir, artifacts_dir
import config
from preprocessing import FaceDataset, generate_anchor_boxes, calculate_iou
from utils import draw_image_with_box, visualize_anchors_and_gt, decode_predictions, decode_deltas, smooth_l1_loss, bbox_transform
from models import RPN

In [11]:
NUM_CLASSES= 2
ROI_SIZE= 7
SPATIAL_SCALE= 1/8.0
ROI_PER_IMG= 128
POS_FRACTION= 0.5


In [12]:
class SecondStage(nn.Module):
    def __init__(self, in_channels= 512, roi_size= ROI_SIZE, num_classes= NUM_CLASSES):
        super().__init__()
        self.roi_align= RoIAlign(
            output_size= roi_size,
            spatial_scale= SPATIAL_SCALE,
            sampling_ratio= -1,
            aligned= True
        )

        self.fc= nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * roi_size * roi_size, 1024),
            nn.ReLU(inplace= True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace= True)
        )
        self.cls_head= nn.Linear(1024, num_classes)
        self.reg_head= nn.Linear(1024, 4 * num_classes)
    
    def forward(self, feature_maps, all_proposals, targets= None):
        roi_batch = []
        for b, prop in enumerate(all_proposals):
            idx = torch.full((prop.shape[0], 1), b, device=prop.device, dtype=prop.dtype)
            roi_batch.append(torch.cat([idx, prop], dim=1))
        roi_batch = torch.cat(roi_batch, dim=0)          # (N,5)

        pooled = self.roi_align(feature_maps, roi_batch)   # (N,C,7,7)
        x = self.fc(pooled)                              # (N,1024)
        cls_logits = self.cls_head(x)                    # (N,2)
        reg_deltas = self.reg_head(x)                    # (N,8)

        return self.losses(cls_logits, reg_deltas, roi_batch, targets)
    
    def losses(self, cls_logits, reg_deltas, roi_batch, targets):
        labels_all, reg_targets_all = [], []

        for batch, target in enumerate(targets):
            mask = roi_batch[:, 0] == batch
            proposals = roi_batch[mask, 1:]
            iou = calculate_iou(proposals, target['boxes'])
            best_iou, best_idx = iou.max(dim=1)

            pos = best_iou >= 0.5
            labels_b = torch.zeros(proposals.shape[0], dtype=torch.long, device=proposals.device)
            labels_b[pos] = 1
            labels_all.append(labels_b)

            # regression targets for positives only
            reg_t = bbox_transform(proposals[pos], target['boxes'][best_idx[pos]])
            reg_targets_all.append(reg_t)

        labels_all = torch.cat(labels_all)          # (N,)
        reg_targets_all = torch.cat(reg_targets_all)  # (N_pos, 4)

        # subsample 128 RoIs
        pos_idx = torch.where(labels_all == 1)[0]
        neg_idx = torch.where(labels_all == 0)[0]
        num_pos = max(8, min(len(pos_idx), int(ROI_PER_IMG * POS_FRACTION)))
        num_neg = min(len(neg_idx), 128 - num_pos)
        keep = torch.cat([pos_idx[:num_pos], neg_idx[:num_neg]])

        cls_logits, labels = cls_logits[keep], labels_all[keep]
        reg_deltas = reg_deltas[keep]

        # slice regression targets for the kept positives
        pos_mask = labels == 1
        reg_targets = reg_targets_all[:num_pos]  # because we built only positives
        cls_loss = nn.functional.cross_entropy(cls_logits, labels)
        if pos_mask.sum():
            reg_loss = smooth_l1_loss(reg_deltas[pos_mask, 4:8], reg_targets).mean()
        else:
            torch.tensor(0., device= labels.device)
        return {'cls_loss': cls_loss, 'reg_loss': reg_loss}

In [13]:
resnet50_backbone= torch.load(artifacts_dir + "resnet50_backbone.pth")
state_dict= torch.load(artifacts_dir + "rpn_10epchs_1e-4lr_42anchgt_wghts_bigger_scales.pth")
rpn_model= RPN()
rpn_model.load_state_dict(state_dict)
for p in resnet50_backbone.parameters():
    p.requires_grad= False
for p in rpn_model.parameters():
    p.requires_grad= False

In [14]:
valid_img_extensions= ('.jpg', '.jpeg', '.png')
all_images= [
    img for img in os.listdir(images_train_dir)
    if img.lower().endswith(valid_img_extensions) and os.path.exists(os.path.join(labels_train_dir, img.rsplit('.', 1)[0] + ".txt"))
]
all_images[1:5]

['998faa48943fce6f.jpg',
 'a228f997057aa291.jpg',
 '49fe432784afea63.jpg',
 '0106e273d2de08be.jpg']

In [15]:
train_images, val_images= train_test_split(all_images[:1000], test_size= 0.2, random_state= 42)
transforms= transforms.Compose([
    transforms.Resize((224,224)),
    transforms.Normalize(mean= [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset= FaceDataset(image_dir= images_train_dir, label_dir= labels_train_dir, image_list= train_images, transform= transforms)
val_dataset= FaceDataset(image_dir= images_train_dir, label_dir= labels_train_dir, image_list= val_images, transform= transforms)
train_loader= DataLoader(train_dataset, batch_size= config.BATCH_SIZE, shuffle= True)
val_loader= DataLoader(val_dataset, batch_size= config.BATCH_SIZE, shuffle= False)


In [None]:
device= 'cuda' if torch.cuda.is_available() else 'cpu'
second_stage= SecondStage().to(device)

rpn_model.to(device)
resnet50_backbone.to(device)
all_anchors = generate_anchor_boxes(
    config.FEATURE_MAP_SHAPE,
    config.ANCHOR_SCALES_2,
    config.ANCHOR_RATIOS_2,
    config.ANCHOR_STRIDE,
    config.NUM_ANCHORS_PER_LOC
).to(device)
optimizer= Adam(second_stage.parameters(), lr= config.LEARNING_RATE)

for epoch in range(config.NUM_EPOCHS):
    second_stage.train()
    train_loss_history, val_loss_history= [], []
    for batch in train_loader:
        images = batch['image'].to(device)
        gt_boxes = [box.to(device) for box in batch['boxes']]

        with torch.no_grad():
            feat = resnet50_backbone(images)          # (B,C,H,W)
            cls_logits, reg_deltas = rpn_model(feat) # (B,N,1) & (B,N,4)

            # decode to list of proposals
            proposals_list = []
            for b in range(images.shape[0]):
                scores_b = cls_logits[b].squeeze(-1)
                deltas_b = reg_deltas[b]
                _, prop_b = decode_predictions(
                    all_anchors, scores_b, deltas_b,
                    pre_nms_topk=1000, post_nms_topk=300,
                    cls_score_threshold=0.8, nms_threshold=0.7
                )
                proposals_list.append(prop_b)

        losses = second_stage(feat, proposals_list,
                              [{'boxes': b} for b in gt_boxes])
        loss = losses['cls_loss'] + losses['reg_loss']
        train_loss_history.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch + 1}/{config.NUM_EPOCHS}: Train Objectness Loss: {losses['cls_loss'].item():.4f} | Train Reg Loss: {losses['reg_loss'].item():.4f}")
    second_stage.eval()
    val_cls_loss, val_reg_loss, val_samples= 0.0, 0.0, 0.0
    with torch.no_grad():
        for batch in val_loader:
            images= batch['image'].to(device)
            gt_boxes= [box.to(device) for box in batch['boxes']]

            features= resnet50_backbone(images)
            cls_logits, reg_deltas= rpn_model(features)

            proposals_list= []
            for b in range(images.shape[0]):
                scores_b= cls_logits[b].squeeze(-1)
                deltas_b= reg_deltas[b]
                _, proposals_b= decode_predictions(
                    all_anchors, scores_b, deltas_b,
                    pre_nms_topk= 1000, post_nms_topk=300,
                    cls_score_threshold= 0.8, nms_threshold= 0.7
                )
                proposals_list.append(proposals_b)
            losses= second_stage(features, proposals_list,
                                 [{'boxes': box} for box in gt_boxes])
            
            val_cls_loss+= losses['cls_loss'].item()
            val_reg_loss+= losses['reg_loss'].item()
            val_samples+= 1
        val_cls_loss/= val_samples
        val_reg_loss/= val_samples
        print(f"Val Objectness Loss: {val_cls_loss:.4f} | Val Regression Loss {val_reg_loss:.4f}")
    

Epoch 1/10: Train Objectness Loss: 0.6639 | Train Reg Loss: 0.0141
Val Objectness Loss: 0.6503 | Val Regression Loss 0.0133
Epoch 2/10: Train Objectness Loss: 0.5768 | Train Reg Loss: 0.0140
Val Objectness Loss: 0.6011 | Val Regression Loss 0.0125
Epoch 3/10: Train Objectness Loss: 0.5371 | Train Reg Loss: 0.0105
Val Objectness Loss: 0.5635 | Val Regression Loss 0.0111
Epoch 4/10: Train Objectness Loss: 0.5173 | Train Reg Loss: 0.0086
Val Objectness Loss: 0.5471 | Val Regression Loss 0.0102
Epoch 5/10: Train Objectness Loss: 0.5664 | Train Reg Loss: 0.0105
Val Objectness Loss: 0.5536 | Val Regression Loss 0.0095
Epoch 6/10: Train Objectness Loss: 0.5135 | Train Reg Loss: 0.0085
Val Objectness Loss: 0.5743 | Val Regression Loss 0.0091
Epoch 7/10: Train Objectness Loss: 0.5587 | Train Reg Loss: 0.0065
Val Objectness Loss: 0.5190 | Val Regression Loss 0.0086
Epoch 8/10: Train Objectness Loss: 0.4461 | Train Reg Loss: 0.0076
Val Objectness Loss: 0.5165 | Val Regression Loss 0.0084
Epoch 9/