In [15]:
import os
import torch


In [16]:
import torch.nn as nn
import torch.nn.functional as F
from model import Stage2Model, FaceModel, SelectNet_resnet, SelectNet
from helper_funcs import affine_crop, stage2_pred_softmax, calc_centroid, affine_mapback
import os
import torch

class ModelEnd2End(nn.Module):
    """Three-stage face-parsing pipeline: stage-1 model -> selection/STN net -> stage-2 model.

    Sub-modules:
        modelA (FaceModel): stage-1 network applied to ``x``; its softmax output is
            asserted to have shape ``(N, 9, 128, 128)``.
        stn_model (SelectNet_resnet): maps the stage-1 probabilities to the affine
            parameters ``theta`` used (in inference mode) to crop parts out of ``orig``.
        modelC (Stage2Model): stage-2 network applied to the cropped parts.
    """

    def __init__(self):
        super(ModelEnd2End, self).__init__()
        self.modelA = FaceModel()
        self.stn_model = SelectNet_resnet()
        self.modelC = Stage2Model()

    def forward(self, x, orig, inference=True, orig_label=None):
        """Run the full pipeline.

        Args:
            x: input of the stage-1 model ``modelA``; its ``device`` is forwarded to
                ``affine_crop`` / ``affine_mapback``.
            orig: full-resolution input that parts are cropped from. Must be 4-D; its
                first dimension ``N`` must match the stage-1 batch (checked by an assert).
            inference: if True, parts are cropped with the ``theta`` predicted by
                ``stn_model``, and all three sub-modules are switched to eval mode
                (this persists after the call). If False, parts are cropped around the
                centroids of ``orig_label`` instead.
            orig_label: full-resolution labels; required when ``inference`` is False.

        Returns:
            inference=True:  ``(final_pred, parts, softmax_stage2)``
            inference=False: ``(final_pred, stage2_pred, parts, parts_labels)``

        Raises:
            ValueError: if ``inference`` is False and ``orig_label`` is None
                (this case used to fall through and silently return None).
        """
        N, _, _, _ = orig.shape  # unpacking also enforces that `orig` is 4-D

        if inference:
            # Side effect: the sub-modules stay in eval mode after this call.
            self.modelA.eval()
            self.stn_model.eval()
            self.modelC.eval()
            stage1_pred = F.softmax(self.modelA(x), dim=1)
            assert stage1_pred.shape == (N, 9, 128, 128)
            theta = self.stn_model(stage1_pred)
            parts, _ = affine_crop(orig, label=None, theta_in=theta, map_location=x.device)
            stage2_pred = self.modelC(parts)
            softmax_stage2 = stage2_pred_softmax(stage2_pred)
            final_pred = affine_mapback(softmax_stage2, theta, x.device)
            return final_pred, parts, softmax_stage2

        if orig_label is None:
            raise ValueError("orig_label is required when inference=False")

        stage1_pred = F.softmax(self.modelA(x), dim=1)
        assert stage1_pred.shape == (N, 9, 128, 128)
        # NOTE(review): the STN output is discarded. The crop below is driven by the
        # ground-truth centroids, and it uses the theta returned by affine_crop. The call
        # is kept so this branch's side effects are unchanged (e.g. BatchNorm
        # running-stat updates in train mode). Remove it if those are not wanted.
        self.stn_model(F.relu(stage1_pred))
        cens = calc_centroid(orig_label)
        parts, parts_labels, theta = affine_crop(orig, orig_label, points=cens,
                                                 map_location=x.device, floor=False)
        stage2_pred = self.modelC(parts)
        softmax_stage2 = stage2_pred_softmax(stage2_pred)
        final_pred = affine_mapback(softmax_stage2, theta, x.device)
        return final_pred, stage2_pred, parts, parts_labels

    def load_pretrain(self, path, device):
        """Load pretrained weights into the three sub-modules.

        Args:
            path: a single path, or a sequence of 1 to 3 checkpoint paths:
                - 1 path: a checkpoint with keys ``'model1'``, ``'select_net'`` and
                  ``'model2'``; or a checkpoint holding this whole module's state dict
                  under ``'model'`` (as written by
                  ``torch.save({'model': model.state_dict()}, ...)``).
                - 2 paths ``[AB, C]``: ``'model1'`` and ``'select_net'`` from AB,
                  ``'model2'`` from C.
                - 3 paths ``[A, B, C]``: ``'model1'`` from A, ``'select_net'`` from B,
                  ``'model2'`` from C.
            device: passed to ``torch.load`` as ``map_location``.

        Raises:
            RuntimeError: if no path is given.
            ValueError: if more than three paths are given (previously ignored silently).
        """
        # A bare string used to be iterated character by character; treat it as one path.
        if isinstance(path, (str, os.PathLike)):
            path = [path]
        paths = list(path)
        if not paths:
            print("ERROR! No state path!")
            raise RuntimeError("No state path!")
        if len(paths) > 3:
            raise ValueError(
                "load_pretrain expects 1 to 3 checkpoint paths, got %d" % len(paths))

        for p in paths:
            print("load from" + str(p))
        states = [torch.load(p, map_location=device) for p in paths]

        if len(states) == 1 and 'model1' not in states[0] and 'model' in states[0]:
            # Combined checkpoint holding the whole ModelEnd2End state dict.
            self.load_state_dict(states[0]['model'])
            return

        # Which loaded checkpoint supplies each sub-module:
        #   1 file -> (0, 0, 0); 2 files [AB, C] -> (0, 0, 1); 3 files [A, B, C] -> (0, 1, 2).
        idx_a, idx_b, idx_c = {1: (0, 0, 0), 2: (0, 0, 1), 3: (0, 1, 2)}[len(states)]
        self.modelA.load_state_dict(states[idx_a]['model1'])
        self.stn_model.load_state_dict(states[idx_b]['select_net'])
        self.modelC.load_state_dict(states[idx_c]['model2'])


In [17]:
# Build the end-to-end model. NOTE: nothing below moves it off the CPU; `device` is
# only passed as `map_location` when the checkpoints are loaded.
model = ModelEnd2End()
# Prefer the second GPU (the original choice), but fall back to the first GPU or the
# CPU when it does not exist. Unconditionally using "cuda:1" makes torch.load fail
# with an invalid device ordinal on single-GPU machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [18]:
# Stage-1 + selection-network ("AB") checkpoint.
# Earlier AB run, kept for reference:
#   /home/yinzi/data4/STN-iCNN/checkpoints_AB_res/23903dfc/best.pth.tar
path_AB = os.path.join(
    "/home/yinzi/data4/STN-iCNN/checkpoints_AB_res/e0de5954",
    'best.pth.tar',
)

# Stage-2 ("C") checkpoint c1f2ab1a: best_error_all 0.245, trained with
#   python3 train_stage2.py  --batch_size 64 --cuda 6 --lr0 0.0008 --lr1 0.0008 --lr2 0.0008 --lr3 0.0008 --epochs 3000
#   "1 without augmentation + 4 with augmentation" (presumably the data-augmentation setup; TODO confirm)
# Alternative C checkpoint, kept for reference:
#   /home/yinzi/data3/stn-new/checkpoints_C/9b41a676/best.pth.tar
path_C = os.path.join(
    "/home/yinzi/data4/STN-iCNN/checkpoints_C/c1f2ab1a",
    'best.pth.tar',
)

model.eval()
model.load_pretrain(path=[path_AB, path_C], device=device)

load from/home/yinzi/data4/STN-iCNN/checkpoints_AB_res/e0de5954/best.pth.tar
load from/home/yinzi/data4/STN-iCNN/checkpoints_C/c1f2ab1a/best.pth.tar


In [19]:
# Save the two loaded checkpoints (AB and C) together as a single file.
# The whole ModelEnd2End state dict is stored under the single key 'model'.
fname = os.path.join("/home/yinzi/data4/STN-iCNN/", 'before_end2end.pth.tar')
state = {'model': model.state_dict()}
torch.save(state, fname)

In [20]:
from dataset import HelenDataset
from torchvision import transforms
from preprocess import ToPILImage, ToTensor, OrigPad, Resize
from torch.utils.data import DataLoader

# Helen test split. The network input is resized to 128x128; OrigPad runs last
# (presumably it pads the full-resolution `orig` tensors; TODO confirm).
# Evaluated one image per batch, in file order.
testDataset = HelenDataset(
    txt_file='testing.txt',
    root_dir="/data1/yinzi/datas",
    parts_root_dir="/home/yinzi/data3/recroped_parts",
    transform=transforms.Compose([
        ToPILImage(),
        Resize((128, 128)),
        ToTensor(),
        OrigPad(),
    ]),
)
dataloader = DataLoader(testDataset, batch_size=1, shuffle=False, num_workers=4)

## F1 overall without the AB model (parts cropped around ground-truth centroids): 0.9104
不考虑 AB 模型时（即使用 ground truth 指导切割）的 F1 overall 为 0.9104

In [27]:
import matplotlib.pyplot as plt
%matplotlib inline
import torch.nn.functional as F
from helper_funcs import F1Accuracy
# F1 when parts are cropped around ground-truth centroids (the inference=False branch):
# here the stage-1/STN prediction does not drive the crop.
f1 = F1Accuracy(num=9)
model.eval()
# Evaluation only: disable autograd so no computation graph is built for every batch.
with torch.no_grad():
    for batch in dataloader:
        image, orig, orig_label = batch['image'], batch['orig'], batch['orig_label']
        pred, _, _, _ = model(image, orig=orig, inference=False, orig_label=orig_label)
        pred_arg = pred.argmax(dim=1)
        # orig_label is reduced with argmax over dim 1 -- presumably per-class maps; TODO confirm
        f1.collect(pred_arg, orig_label.argmax(dim=1))

f1_accu = f1.calc()
print(f1_accu)

0.9104196111000484


## F1 overall with the AB model (parts cropped using the predicted transform): 0.893
加上 AB 模型时（即采用预测信息指导切割）的 F1 overall 为 0.893

In [22]:
import matplotlib.pyplot as plt
%matplotlib inline
import torch.nn.functional as F
from helper_funcs import F1Accuracy
# F1 when parts are cropped with the theta predicted by the STN (the inference=True branch).
f1 = F1Accuracy(num=9)
model.eval()
# Evaluation only: disable autograd so no computation graph is built for every batch.
with torch.no_grad():
    for batch in dataloader:
        image, orig, orig_label = batch['image'], batch['orig'], batch['orig_label']
        pred, _, _ = model(image, orig=orig, inference=True)
        pred_arg = pred.argmax(dim=1)
        # orig_label is reduced with argmax over dim 1 -- presumably per-class maps; TODO confirm
        f1.collect(pred_arg, orig_label.argmax(dim=1))

f1_accu = f1.calc()
print(f1_accu)

0.8928423572029238
