In [5]:
# coding=utf-8
import os
import time
import math

import torch
from tqdm import tqdm
from torchvision.models import resnet34, resnet50, vgg11, resnet152, vgg19, vgg16, resnext101_32x8d, resnet18, \
    densenet201
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
from visdom import Visdom
import numpy as np
from torch.nn.parameter import Parameter, UninitializedParameter
import torch.nn.functional as F

In [8]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single feature dimension.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    A 1-D input of shape ``(N,)`` becomes ``(N, 1)``.

    Attributes:
        shape: number of features per sample computed by the most recent
            ``forward`` call (``0`` before the first call).
    """

    def __init__(self):
        super(Flatten, self).__init__()
        self.shape = 0

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0], prod(x.shape[1:]))``.

        Args:
            x: tensor with at least one dimension.

        Returns:
            A 2-D tensor holding the same elements as ``x``.
        """
        # Multiply the trailing dimensions with plain Python ints; this avoids
        # building a tensor on every call and keeps the result an int even
        # when there are no trailing dimensions (empty product == 1).
        flat = 1
        for dim in x.shape[1:]:
            flat *= int(dim)
        self.shape = flat
        # reshape (unlike view) also accepts non-contiguous inputs, and naming
        # the batch size explicitly (instead of -1) keeps empty batches valid.
        return x.reshape(x.shape[0], flat)

In [9]:
class SelfDefineModel(nn.Module):
    """Four-branch model whose branch outputs are fused by one linear layer.

    Every branch receives the same input tensor:

    * ``modelA`` - the first 17 layers of the first child of a pretrained
      VGG16, followed by extra convolutions and a classifier head (3 outputs).
    * ``modelS`` - the first 23 layers of that same VGG16 child, followed by
      extra convolutions, a transposed convolution and a classifier head
      (3 outputs).
    * ``vgg``    - every child of the pretrained VGG16 except the last one,
      followed by a classifier head (3 outputs).
    * ``res34``  - every child of a pretrained ResNet34 except the last one,
      followed by a classifier head (9 outputs).

    ``model3`` maps the 18 concatenated branch outputs to 3 values.

    The VGG16 layer objects are reused rather than copied, so ``modelA``,
    ``modelS`` and ``vgg`` share those layers (and their weights) with
    ``trained_model`` and with each other.
    """

    def __init__(self):
        super().__init__()

        # Pretrained VGG16 backbone; its children are sliced by three branches.
        self.trained_model = vgg16(pretrained=True)
        vgg_children = list(self.trained_model.children())
        vgg_front = vgg_children[0]

        # ---- branch A ------------------------------------------------------
        layers_a = list(vgg_front[:17])
        layers_a.append(nn.Conv2d(256, 64, kernel_size=1))
        layers_a.append(nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2, groups=2))
        # Three identical ungrouped 5x5 convolutions.
        layers_a.extend(
            nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2) for _ in range(3)
        )
        layers_a.extend([
            nn.AdaptiveAvgPool2d(output_size=(7, 7)),
            Flatten(),
            nn.Dropout(p=0.5),
            # 3136 = 64 channels * 7 * 7 after the adaptive pooling.
            # Alternative left by the author:
            # ChannelWise(3136, 512, device=torch.device('cuda'), input_split_size=4)
            nn.Linear(3136, 512),
            nn.Dropout(p=0.4),
            nn.ReLU(inplace=True),
            nn.Linear(512, 3),
            nn.ReLU(inplace=True),
        ])
        self.modelA = nn.Sequential(*layers_a)

        # ---- branch S ------------------------------------------------------
        layers_s = list(vgg_front[:23])
        layers_s.extend([
            nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, groups=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1, padding=0),
            nn.ConvTranspose2d(512, 256, kernel_size=8, stride=2, padding=2),
            nn.AdaptiveAvgPool2d(output_size=(7, 7)),
            Flatten(),
            # 12544 = 256 channels * 7 * 7 after the adaptive pooling.
            # Alternative left by the author:
            # ChannelWise(12544, 512, device=torch.device('cuda'), input_split_size=8)
            nn.Linear(12544, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.4),
            nn.Linear(128, 16),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(16, 3),
        ])
        self.modelS = nn.Sequential(*layers_s)

        # ---- full VGG branch -----------------------------------------------
        # Original note (translated): "check the output shape [b, 512, 1, 1]".
        # NOTE(review): the Linear below takes 25088 inputs, which does not
        # match a [b, 512, 1, 1] tensor - the note looks copied from the
        # ResNet branch; confirm the actual shape.
        self.vgg = nn.Sequential(
            *vgg_children[:-1],
            Flatten(),
            nn.Linear(25088, 512),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(512, 128),
            nn.Dropout(p=0.3),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

        # ---- fusion layer: 3 + 3 + 3 + 9 branch outputs -> 3 ---------------
        self.model3 = nn.Sequential(nn.Linear(18, 3))

        # ---- ResNet34 branch -----------------------------------------------
        self.trained_model2 = resnet34(pretrained=True)  # .to(device)
        resnet_children = list(self.trained_model2.children())
        # Original note (translated): "check the output shape [b, 512, 1, 1]".
        self.res34 = nn.Sequential(
            *resnet_children[:-1],
            Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 128),
            nn.Dropout(p=0.4),
            nn.ReLU(),
            nn.Linear(128, 9),
        )

    def forward(self, input1):
        """Run all four branches on ``input1`` and fuse their outputs.

        The branch outputs are concatenated along dimension 1 in the order
        modelA, modelS, vgg, res34 and passed through ``model3``.
        """
        branches = (self.modelA, self.modelS, self.vgg, self.res34)
        fused = torch.cat([branch(input1) for branch in branches], dim=1)
        return self.model3(fused)

In [10]:
# Instantiate the model; the constructor requests pretrained VGG16 and ResNet34 weights.
a = SelfDefineModel()

In [17]:
# Display the second-to-last layer of branch A.
a.modelA[-2]

Linear(in_features=512, out_features=3, bias=True)

In [13]:
# Print every parameter tensor of the first layer in branch A.
# The tensors are printed in full, so the output below is long.
for f in a.modelA[0].parameters():
    print(f)

Parameter containing:
tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],
          [-5.8312e-01,  3.5655e-01,  7.6566e-01],
          [-6.9022e-01, -4.8019e-02,  4.8409e-01]],

         [[ 1.7548e-01,  9.8630e-03, -8.1413e-02],
          [ 4.4089e-02, -7.0323e-02, -2.6035e-01],
          [ 1.3239e-01, -1.7279e-01, -1.3226e-01]],

         [[ 3.1303e-01, -1.6591e-01, -4.2752e-01],
          [ 4.7519e-01, -8.2677e-02, -4.8700e-01],
          [ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],


        [[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],
          [-4.2805e-01, -2.4349e-01,  2.4628e-01],
          [-2.5066e-01,  1.4177e-01, -5.4864e-03]],

         [[-1.4076e-01, -2.1903e-01,  1.5041e-01],
          [-8.4127e-01, -3.5176e-01,  5.6398e-01],
          [-2.4194e-01,  5.1928e-01,  5.3915e-01]],

         [[-3.1432e-01, -3.7048e-01, -1.3094e-01],
          [-4.7144e-01, -1.5503e-01,  3.4589e-01],
          [ 5.4384e-02,  5.8683e-01,  4.9580e-01]]],


        [[[ 1.7715e-01,  5.2149e-01,  9.8740