In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import datetime


from tqdm.notebook import tqdm
from utils import *
from torch import optim
from torchvision.ops import roi_pool, nms
from pretrained import VGG_CNN_F, VGG_CNN_M_1024, VGG_VD_1024
from VOCdatasets import VOCDectectionDataset


In [2]:
class WSDDN_S(nn.Module):
    """WSDDN-style two-stream detection head on a VGG_CNN_F backbone.

    Region features are ROI-pooled from the backbone feature map and passed
    through fc6/fc7. Two parallel heads then score every region:
      * a classification stream, softmax over classes (dim 1), and
      * a detection stream, softmax over regions (dim 0).
    Their element-wise product is returned. Each detection column sums to 1
    over regions and every classification score is <= 1, so summing the
    result over regions gives image-level class scores in [0, 1] (up to
    floating-point rounding).
    """

    def __init__(self, num_classes=20):
        """Build the backbone and the fully connected layers.

        Args:
            num_classes: number of classes scored by both streams. Defaults
                to 20, the previously hard-coded value.
        """
        super().__init__()
        self.num_classes = num_classes
        self.pretrain_net = VGG_CNN_F()
        # NOTE(review): presumably loads pretrained backbone weights from a
        # .mat file; see the project `pretrained` module.
        self.pretrain_net.load_mat()
        self.roi_output_size = (6, 6)

        # fc6 input: 256 channels x 6 x 6 pooled cells per region.
        self.fc6 = nn.Linear(6 * 6 * 256, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8c = nn.Linear(4096, num_classes)  # classification stream
        self.fc8d = nn.Linear(4096, num_classes)  # detection stream

    def forward(self, x, regions):
        """Score every proposal region for every class.

        Args:
            x: image batch, Tensor(bs, c, h, w); only bs == 1 is supported.
            regions: proposals, Tensor(bs, R, 4) (or a length-1 sequence of
                Tensor(R, 4)) in input-image coordinates; roi_pool interprets
                each box as (x1, y1, x2, y2).

        Returns:
            Tensor(R, num_classes): classification score * detection score.

        Raises:
            ValueError: if more than one image or region set is passed.
                Without this check, everything but the first would be
                silently ignored.
        """
        if len(x) != 1 or len(regions) != 1:
            raise ValueError(
                f"WSDDN_S expects a batch of exactly one image, got "
                f"{len(x)} image(s) and {len(regions)} region set(s)"
            )
        # roi_pool takes a list with one Tensor(K, 4) per image in the batch.
        boxes = [regions[0]]
        out = self.pretrain_net(x)  # bs, 256, h/16, w/16
        # spatial_scale 1/16 maps image coordinates onto the feature map.
        out = roi_pool(out, boxes, self.roi_output_size, 1.0 / 16)  # R, 256, 6, 6
        # flatten (unlike view(R, -1)) also works when there are zero regions.
        out = out.flatten(start_dim=1)  # R, 256*6*6
        out = F.relu(self.fc6(out))
        out = F.relu(self.fc7(out))
        cls_score = F.softmax(self.fc8c(out), dim=1)  # normalized over classes
        det_score = F.softmax(self.fc8d(out), dim=0)  # normalized over regions
        return cls_score * det_score  # R, num_classes
    

In [3]:
# Build the VOC 2007 trainval split rooted at "~/data/".
# NOTE(review): the next cell constructs this identical dataset again, so this
# cell only repeats the work on a Restart & Run All and could be deleted.
voc_07_trainval = VOCDectectionDataset(
    "~/data/",
    2007,
    "trainval",
)


In [3]:
# Data, model, optimizer, scheduler and loss for training.
# NOTE(review): DEVICE and LR are not defined in this notebook; presumably
# they come from `utils` via the star import.
voc_07_trainval = VOCDectectionDataset("~/data/", 2007, "trainval")
# One image per batch: WSDDN_S.forward only pools the regions of the first image.
voc07_train_loader = data.DataLoader(voc_07_trainval, batch_size=1, shuffle=True)

wsddn_s = WSDDN_S()
wsddn_s = wsddn_s.to(DEVICE)
wsddn_s.train()

optimizer = optim.SGD(wsddn_s.parameters(), lr=LR, momentum=0.9)
# Multiply the learning rate by 0.1 once 10 and again once 20 scheduler steps are taken.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10, 20],
    gamma=0.1,
)
bce_loss = nn.BCELoss(reduction="sum")

# Batches per epoch (equal to images, since batch_size is 1); used to average the loss.
N = len(voc07_train_loader)


In [None]:
# Train WSDDN_S on VOC07 trainval and append per-epoch mean losses to a text log.
# NOTE(review): EPOCHS, DEVICE, LOG_PATH and SAVE_PATH are not defined in this
# notebook; presumably they come from `utils` via the star import.
with open(LOG_PATH + "ssw_wsddn_s.txt", 'a') as fp:
    # writelines() adds no line separators, so terminate each record explicitly.
    fp.write(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + "\n")

    for epoch in tqdm(range(EPOCHS), "Total"):
        epoch_loss = 0.0
        for img, _gt_box, gt_target, regions in tqdm(voc07_train_loader, f"Epoch {epoch}"):
            optimizer.zero_grad()
            # img      : Tensor(1, 3, h, w)
            # gt_target: Tensor(1, 20); gt_target[0] must match the (20,)
            #            image-level scores below for BCELoss
            # regions  : Tensor(1, R, 4)
            img = img.to(DEVICE)
            regions = regions.to(DEVICE)
            gt_target = gt_target.to(DEVICE)

            combined = wsddn_s(img, regions)  # R, 20
            # y = sum over regions. In exact arithmetic this is in [0, 1], but
            # rounding can push it slightly above 1, which BCELoss rejects.
            image_level_cls_score = torch.clamp(torch.sum(combined, dim=0), 0.0, 1.0)
            out = bce_loss(image_level_cls_score, gt_target[0])

            epoch_loss += out.item()
            out.backward()
            optimizer.step()

        message = f"Epoch {epoch} Loss is {epoch_loss / N}"
        print(message)
        fp.write(message + "\n")
        fp.flush()  # keep the log current while training is still running
        scheduler.step()
        # Checkpoint every epoch so an interrupted run still leaves usable weights.
        torch.save(wsddn_s.state_dict(), SAVE_PATH + "ssw_wsddn_s.pt")

HBox(children=(FloatProgress(value=0.0, description='Total', max=20.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=5011.0, style=ProgressStyle(description_wid…


Epoch 0 Loss is 4.736465042090849


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=5011.0, style=ProgressStyle(description_wid…


Epoch 1 Loss is 4.651713859080983


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=5011.0, style=ProgressStyle(description_wid…


Epoch 2 Loss is 4.501806497918801


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=5011.0, style=ProgressStyle(description_wid…


Epoch 3 Loss is 4.363589323416272


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=5011.0, style=ProgressStyle(description_wid…


Epoch 4 Loss is 4.261662473104786


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=5011.0, style=ProgressStyle(description_wid…


Epoch 5 Loss is 4.179469488738088


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=5011.0, style=ProgressStyle(description_wid…

In [11]:
# Summed binary cross-entropy over the image-level class scores. The original
# name `bcd_loss` was a typo: the training loops only reference `bce_loss`.
bce_loss = nn.BCELoss(reduction="sum")

In [7]:
# Manually checkpoint the current weights (presumably after interrupting training).
# NOTE(review): this writes under a hard-coded "../models/" while the training
# loop saves under SAVE_PATH; confirm both point at the same directory.
save_path = "../models/"
torch.save(
    wsddn_s.state_dict(),
    f"{save_path}ssw_wsddn_s.pt",
)

In [None]:
    for epoch in tqdm(range(EPOCHS), "Total"):
        epoch_loss = 0
        for img, gt_box, gt_target, regions, scores in tqdm(train_loader, f"Epoch {epoch}"):
        optimizer.zero_grad()
        # img   : Tensor(1, 3, h, w)
        # gt_tar: Tensor(1, R_gt)
        # region: Tensor(1, R, 4)
        img = img.to(DEVICE)
        regions = regions.to(DEVICE)
        gt_target = gt_target.to(DEVICE)
        if scores:
        scores = scores.to(DEVICE)

        combined = wsddn_s(img, regions, scores=scores)
        image_level_cls_score = torch.sum(combined, dim=0) # y
        out = bce_loss(image_level_cls_score, gt_target[0])

        epoch_loss += out.item()
        out.backward()
        optimizer.step()

        print(f"Epoch {epoch} Loss is {epoch_loss/N}")
    scheduler.step()
    torch.save(wsddn_s.state_dict(),
               SAVE_PATH + get_model_name(propose_way, args.year, "wsddn_s") + ".pt")

In [7]:
# Sanity check, presumably prompted by the optional proposal `scores`: a tensor
# cannot hold None, so "no scores" must be signalled by None itself rather than
# a tensor. The expected TypeError is caught and displayed instead of leaving
# a failing cell in the notebook.
try:
    torch.Tensor([None])
    none_tensor_error = None
except TypeError as err:
    none_tensor_error = err
none_tensor_error

TypeError: must be real number, not NoneType