In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import datetime
import numpy as np
import torchvision


from sklearn.metrics import average_precision_score
from model import *
from tqdm.notebook import tqdm
from utils import *
from torch import optim
from torchvision.ops import roi_pool, nms
from pretrained import VGG_CNN_F, VGG_CNN_M_1024, VGG_VD_1024
from datasets import VOCDectectionDataset


In [7]:
class WSDDN_ALEX(nn.Module):
    """WSDDN (weakly supervised deep detection network) on an ImageNet-pretrained AlexNet.

    The AlexNet conv trunk (without its final max-pool) is RoI-pooled per proposal,
    passed through AlexNet's fc6/fc7, and split into two streams:
    a classification stream (softmax over classes) and a detection stream
    (softmax over regions). Their element-wise product gives per-region,
    per-class scores.
    """

    def __init__(self, num_classes=20):
        """Build the network.

        Args:
            num_classes: number of object classes (default 20, as used for PASCAL VOC).
        """
        super(WSDDN_ALEX, self).__init__()
        alexnet = torchvision.models.alexnet(pretrained=True)
        # Drop the trailing MaxPool so RoI pooling runs on the last conv feature map.
        self.features = nn.Sequential(*list(alexnet.features)[:-1])
        # Keep fc6/fc7 (with their dropout/ReLU) but drop AlexNet's 1000-way output layer.
        self.fc67 = nn.Sequential(*list(alexnet.classifier)[:-1])

        self.num_classes = num_classes
        self.roi_output_size = (6, 6)

        self.fc8c = nn.Linear(4096, num_classes)
        self.fc8d = nn.Linear(4096, num_classes)
        self.cls_softmax = nn.Softmax(dim=1)  # over classes, independently per region
        self.det_softmax = nn.Softmax(dim=0)  # over regions, independently per class

    def forward(self, x, regions, scores=None):
        """Score every proposal of one image.

        Args:
            x: image batch, (bs, c, h, w). Only batch size 1 is effectively supported:
                only the first image's regions are pooled.
            regions: proposal boxes, (bs, R, 4). They are pooled with spatial scale
                1/16, so presumably they are in input-image pixel coordinates; confirm
                against the dataset.
            scores: optional proposal scores, (bs, R) or (bs, R, 1). When given, each
                region's pooled feature vector is scaled by ``10 * score + 1``.

        Returns:
            combined: (R, num_classes) region scores (class softmax * 2 * detection softmax).
            fc7: (R, 4096) fc7 features of every region.
        """
        regions = [regions[0]]  # roi_pool requires a list of Tensor(K, 4), one per image
        R = len(regions[0])
        features = self.features(x)  # bs, 256, ~h/16, ~w/16
        # (R, 256, 6, 6) -> (R, 9216)
        pool_features = roi_pool(features, regions, self.roi_output_size, 1.0 / 16).view(R, -1)

        if scores is not None:
            # Shape the scale as (R, 1) so it broadcasts over each region's own feature
            # vector. A bare (R,) tensor would broadcast against the 9216-wide feature
            # axis instead. Cast to the feature dtype so fc6 still sees float input.
            region_scale = 10 * scores[0].reshape(R, 1).to(pool_features.dtype) + 1
            pool_features = pool_features * region_scale

        fc7 = self.fc67(pool_features)
        cls_score = self.cls_softmax(self.fc8c(fc7))  # R, num_classes
        # NOTE(review): the factor 2 lets the per-class sum over regions reach 2.
        # The training loop clamps that sum to [0, 1]. This appears to be a local
        # tweak rather than part of WSDDN; confirm it is intended.
        det_score = self.det_softmax(self.fc8d(fc7)) * 2
        combined = cls_score * det_score

        return combined, fc7

    def spatial_regulariser(self, regions, fc7, combine_scores, labels, iou_th=0.6, top_k=10):
        """WSDDN spatial regulariser, averaged over the image's positive classes.

        For every positive class, take its ``top_k`` highest-scoring regions. Those
        whose IoU with the best one exceeds ``iou_th`` have their fc7 features pulled
        towards the best region's, weighted by their (detached) score s_r:
        ``0.5 * sum_r s_r**2 * ||fc7_r - fc7_best||**2``.

        Args:
            regions: (R, 4) boxes of a single image.
            fc7: (R, D) region features returned by ``forward``.
            combine_scores: (R, C) region scores returned by ``forward``.
            labels: image-level labels indexed by class; 0 marks an absent class.
            iou_th: IoU threshold above which a top region is regularised.
            top_k: number of top-scoring regions per class (capped at R).

        Returns:
            Scalar tensor on fc7's device. It is 0 when the image has no positive class.
        """
        # topk raises if asked for more elements than exist, so cap at R.
        K = min(top_k, combine_scores.size(0))
        reg = fc7.new_zeros(())
        cls_num = 0
        for c in range(combine_scores.size(1)):
            # Only positive classes are regularised.
            if labels[c].item() == 0:
                continue
            cls_num += 1
            topk_scores, topk_filter = combine_scores[:, c].topk(K, dim=0)
            topk_boxes = regions[topk_filter]
            topk_fc7 = fc7[topk_filter]

            # Mask of top-K boxes overlapping the best box by more than iou_th.
            iou_mask = one2allbox_iou(topk_boxes[0:1, :], topk_boxes).view(K)
            iou_mask = (iou_mask > iou_th).float()

            fc7_diff = topk_fc7 - topk_fc7[0]
            # Scores act as fixed weights: no gradient flows through them.
            score_diff = topk_scores.detach().view(K, 1)

            diff = fc7_diff * score_diff

            reg = reg + 0.5 * (torch.pow(diff, 2).sum(1) * iou_mask).sum()
        # Guard against images without positive labels (avoids ZeroDivisionError).
        return reg / max(cls_num, 1)

In [8]:
# Dataset, model and optimiser setup.
# LR, WD and DEVICE are not defined in this notebook; presumably they come from the
# `from utils import *` / `from model import *` star imports.
TRAIN_SEED = 0            # fixes fc8 initialisation and DataLoader shuffle order
LR_MILESTONES = [10, 20]  # epochs after which the learning rate is multiplied by LR_GAMMA
LR_GAMMA = 0.1

torch.manual_seed(TRAIN_SEED)
np.random.seed(TRAIN_SEED)  # in case the dataset draws from numpy's global RNG

propose_way = "edge_box"
voc_07_trainval = VOCDectectionDataset("~/data/", 2007, 'trainval', region_propose=propose_way)
train_loader = data.DataLoader(voc_07_trainval, 1, shuffle=True)  # one image per step

wsddn = WSDDN_ALEX().to(DEVICE)
wsddn.train()

optimizer = optim.Adam(wsddn.parameters(), lr=LR, weight_decay=WD)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=LR_GAMMA)
# Summed over the class scores of the single image in each batch.
bce_loss = nn.BCELoss(reduction="sum")
N = len(train_loader)  # iterations (images) per epoch, used to average the epoch loss


In [None]:
# Train WSDDN with image-level BCE plus the weighted spatial regulariser, and report
# per-class classification AP on the training predictions after every epoch.
alpha = 1e-3  # weight of the spatial regulariser in the total loss
for epoch in tqdm(range(EPOCHS), "Total"):
    wsddn.train()  # make sure dropout is active even if an eval cell ran in between
    epoch_loss = 0
    y_pred = []
    y_true = []

    for img, gt_box, gt_target, regions, scores in tqdm(train_loader, f"Epoch {epoch}"):
        optimizer.zero_grad()
        # img      : Tensor(1, 3, h, w)
        # gt_target: Tensor(1, C) image-level class labels (C must match the model's classes)
        # regions  : Tensor(1, R, 4)
        img = img.to(DEVICE)
        regions = regions.to(DEVICE)
        gt_target = gt_target.to(DEVICE)
        # Proposal scores are only used to rescale RoI features for edge boxes.
        scores = scores.to(DEVICE) if propose_way == "edge_box" else None
        combined, fc7 = wsddn(img, regions, scores=scores)

        # Image-level score per class: sum over regions, clamped into BCELoss's [0, 1].
        image_level_cls_score = torch.clamp(combined.sum(dim=0), min=0.0, max=1.0)

        reg = alpha * wsddn.spatial_regulariser(regions[0], fc7, combined, gt_target[0])
        if reg.item() > 1:
            print(reg)
        loss = bce_loss(image_level_cls_score, gt_target[0])
        out = loss + reg

        if not torch.isfinite(out):
            # Skip this step: back-propagating NaN/inf would corrupt the weights, and
            # NaN predictions would make average_precision_score fail below.
            print(image_level_cls_score)
            continue

        y_pred.append(image_level_cls_score.detach().cpu().tolist())
        y_true.append(gt_target[0].detach().cpu().tolist())

        epoch_loss += out.item()

        out.backward()
        optimizer.step()

    y_pred = np.array(y_pred)
    y_true = np.array(y_true)
    cls_ap = [average_precision_score(y_true[:, i], y_pred[:, i]) for i in range(y_true.shape[1])]

    print(f"Epoch {epoch} classify AP is {str(cls_ap)}")
    print(f"Epoch {epoch} classify mAP is {str(sum(cls_ap) / len(cls_ap))}")
    print(f"Epoch {epoch} Loss is {epoch_loss/N}")
    print("-" * 10)
    scheduler.step()


HBox(children=(FloatProgress(value=0.0, description='Total', max=20.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=5011.0, style=ProgressStyle(description_wid…


