In [1]:
import os
import torch


In [2]:
import torch.nn as nn
import torch.nn.functional as F
from model import Stage2Model, FaceModel, SelectNet_resnet, SelectNet
from helper_funcs import affine_crop, stage2_pred_softmax, calc_centroid, affine_mapback
import os
import torch

class ModelEnd2End(nn.Module):
    """Two-stage face parsing pipeline wired end to end.

    ``modelA`` produces a 9-channel prediction on ``x`` (softmaxed here).
    ``stn_model`` maps those probabilities to ``theta``. ``affine_crop`` uses
    ``theta`` to cut parts out of ``orig``. ``modelC`` segments the parts, and
    ``affine_mapback`` projects the stage-2 probabilities back using ``theta``.
    """

    def __init__(self):
        super(ModelEnd2End, self).__init__()
        self.modelA = FaceModel()
        self.stn_model = SelectNet_resnet()
        self.modelC = Stage2Model()
        # Sub-models start in eval mode; this does not change self.training itself.
        self.modelA.eval()
        self.stn_model.eval()
        self.modelC.eval()

    @staticmethod
    def _fill_empty_brows(parts, parts_labels):
        """Replace, per sample, an eyebrow crop whose label is empty with its eye crop.

        Part index 0 (lbrow) falls back to part 2, and part 1 (rbrow) falls back
        to part 3, as in the original pipeline. The decision is made per sample,
        so with batch size > 1 only the samples whose brow crop is empty are
        replaced.

        Args:
            parts: tensor indexed as ``parts[sample, part, ...]`` (assumed
                layout, matching ``parts[:, k]`` usage); modified in place.
            parts_labels: mutable sequence of per-part label tensors, each
                assumed to be (N, C, ...) with the channel axis at dim 1.
                Brow and eye entries must share a shape.

        Returns:
            ``(parts, parts_labels)``, the same objects, updated.
        """
        for brow, eye in ((0, 2), (1, 3)):
            brow_label = parts_labels[brow]
            # True for samples in which no pixel of the brow crop is argmax'd to a
            # non-zero channel, i.e. the crop missed the eyebrow.
            hits = brow_label.argmax(dim=1).reshape(brow_label.shape[0], -1)
            empty = hits.sum(dim=1) == 0
            if not bool(empty.any()):
                continue
            parts[empty.to(parts.device), brow] = parts[empty.to(parts.device), eye]
            mask = empty.view(-1, *([1] * (brow_label.dim() - 1)))
            parts_labels[brow] = torch.where(mask, parts_labels[eye], brow_label)
        return parts, parts_labels

    def forward(self, x, orig, orig_label=None):
        """Run both stages and map the stage-2 result back.

        Args:
            x: stage-1 input; ``softmax(modelA(x))`` must be (N, 9, 128, 128).
            orig: 4-D batch (presumably N, C, H, W) that parts are cropped from.
            orig_label: optional labels for ``orig``. When given, an eyebrow crop
                whose label is empty is replaced by the corresponding eye crop.

        Returns:
            ``(final_pred, parts, softmax_stage2)``.
        """
        # Unpacking also enforces that ``orig`` is 4-D.
        N, _, _, _ = orig.shape
        stage1_pred = F.softmax(self.modelA(x), dim=1)
        assert stage1_pred.shape == (N, 9, 128, 128)
        theta = self.stn_model(stage1_pred)
        # List: [5x[torch.size(N, 2, 81, 81)], 1x [torch.size(N, 4, 81, 81)]]
        if orig_label is not None:
            parts, parts_labels, _ = affine_crop(orig, orig_label, theta_in=theta, map_location=x.device)
            # If the crop missed an eyebrow, use the corresponding eye crop instead.
            parts, parts_labels = self._fill_empty_brows(parts, parts_labels)
        else:
            parts, _ = affine_crop(orig, label=None, theta_in=theta, map_location=x.device)
        stage2_pred = self.modelC(parts)
        softmax_stage2 = stage2_pred_softmax(stage2_pred)
        final_pred = affine_mapback(softmax_stage2, theta, x.device)
        return final_pred, parts, softmax_stage2

    def load_pretrain(self, path, device):
        """Load pretrained weights into the three sub-models and set them to eval mode.

        Args:
            path: a single checkpoint path, or a sequence of 1-3 paths:
                * ``[ABC]``: one file with keys 'model1', 'select_net', 'model2';
                * ``[AB, C]``: 'model1' and 'select_net' from AB, 'model2' from C;
                * ``[A, B, C]``: 'model1' from A, 'select_net' from B, 'model2' from C.
            device: forwarded to ``torch.load`` as ``map_location``.

        Raises:
            RuntimeError: if no path is given or more than three are given.
        """
        if isinstance(path, (str, os.PathLike)):
            # A bare path used to fall through every length branch and load nothing.
            path = [os.fspath(path)]
        n_paths = len(path)
        if n_paths == 0:
            print("ERROR! No state path!")
            raise RuntimeError("No state path!")
        if n_paths > 3:
            # Previously ignored silently, leaving the models with random weights.
            raise RuntimeError(f"load_pretrain expects 1-3 state paths, got {n_paths}")
        # Which file supplies A (modelA), B (stn_model) and C (modelC).
        if n_paths == 1:
            pathA = pathB = pathC = path[0]
        elif n_paths == 2:
            pathA, pathC = path
            pathB = pathA
        else:
            pathA, pathB, pathC = path
        for p in path:
            print("load from" + p)
        # Read each distinct file once (e.g. a combined ABC checkpoint is loaded once).
        states = {p: torch.load(p, map_location=device) for p in dict.fromkeys(path)}
        self.modelA.load_state_dict(states[pathA]['model1'])
        self.stn_model.load_state_dict(states[pathB]['select_net'])
        self.modelC.load_state_dict(states[pathC]['model2'])
        self.modelA.eval()
        self.stn_model.eval()
        self.modelC.eval()


In [3]:
model = ModelEnd2End()
# Prefer GPU index 1 (the original setup), but fall back to GPU 0 / CPU when that
# index does not exist. Otherwise torch.load(map_location=device) fails on a 1-GPU box.
# NOTE: `device` is only used as the checkpoint map_location; the model itself is
# not moved with .to(device) in this notebook.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Single checkpoint file for run ea0ac45c-0, loaded through the 1-path branch of
# load_pretrain (keys 'model1', 'select_net' and 'model2' in one file).
CHECKPOINT_ROOT = "/home/yinzi/data4/STN-iCNN/checkpoints_ABC"
RUN_ID = "ea0ac45c-0"
path_ABC = os.path.join(CHECKPOINT_ROOT, RUN_ID, 'best.pth.tar')
model.eval()
model.load_pretrain(path=[path_ABC], device=device)

load from/home/yinzi/data4/STN-iCNN/checkpoints_ABC/ea0ac45c-0/best.pth.tar


In [5]:
# Training command for the checkpoint loaded above (run id ea0ac45c-0):
#  python3 3_end2end_tunning_all.py --batch_size 32 --cuda 8 --select_net 1 --pretrainA 1 --pretrainB 1 --pretrainC 0 --lr 0 --lr2 0.0025 --lr_s 0 --epoch 3000 --f1_eval 1

In [6]:
from dataset import HelenDataset
from torchvision import transforms
from preprocess import ToPILImage, ToTensor, OrigPad, Resize
from torch.utils.data import DataLoader

# Test-split preprocessing: resize to the 128x128 stage-1 input, convert to tensors,
# then apply OrigPad.
test_transform = transforms.Compose([
    ToPILImage(),
    Resize((128, 128)),
    ToTensor(),
    OrigPad(),
])
testDataset = HelenDataset(
    txt_file='testing.txt',
    root_dir="/data1/yinzi/datas",
    parts_root_dir="/home/yinzi/data3/recroped_parts",
    transform=test_transform,
)
# One image per batch, in file order, for deterministic evaluation.
dataloader = DataLoader(testDataset, batch_size=1, shuffle=False, num_workers=4)

# The overall F1 score on the test set is computed in the next cell.

In [7]:
import matplotlib.pyplot as plt
%matplotlib inline
import torch.nn.functional as F
from helper_funcs import F1Accuracy
# Accumulates predictions over the whole test set. num=9 presumably is the number
# of label classes (stage 1 outputs 9 channels); confirm in helper_funcs.
f1 = F1Accuracy(num=9)
model.eval()
# Inference only: disable autograd so no graph is built and kept for every batch.
with torch.no_grad():
    for batch in dataloader:
        image, orig, orig_label = batch['image'], batch['orig'], batch['orig_label']
        pred, parts, stage2_pred = model(image, orig=orig, orig_label=None)
        pred_arg = pred.argmax(dim=1, keepdim=False)
        f1.collect(pred_arg, orig_label.argmax(dim=1, keepdim=False))

f1_accu = f1.calc()
print(f1_accu)

0.9000639554324783
