In [1]:
import pandas as pd
import cv2
import numpy as np
import torch
import albumentations
import matplotlib.pyplot as plt
import glob
import math
from PIL import Image as Image
import torchvision
from torchvision import transforms
import sklearn.metrics
from sklearn.model_selection import StratifiedKFold
import torch.nn as nn
import torch.nn.functional as FT
import torch.optim as optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm as tqdm
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.augmentations import functional as F
import albumentations as A
device = torch.device("cuda")
import time
import random

In [6]:
from efficientnet_pytorch import EfficientNet
import torchvision.models as models

sigmoid = nn.Sigmoid()
class Swish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * sigmoid(i)
        ctx.save_for_backward(i)
        return result
    
    def backward(ctx, grad_output):
        i = ctx.saved_variables[0]
        sigmoid_i = sigmoid(i)
        return grad_output * (sigmoid_i + i*(1 - sigmoid_i))
swish = Swish.apply
class Swish_module(nn.Module):
    def forward(self, x):
        return swish(x)
    
swish_layer = Swish_module()

def relu_fn(x):
    return swish_layer(x)

class GlobalAvgPool(nn.Module):
        def __init__(self):
            super(GlobalAvgPool, self).__init__()
        def forward(self, x):
            return x.view(*(x.shape[:-2]),-1).mean(-1)


class Seq_Ex_Block(nn.Module):
        def __init__(self, in_ch, r):
            super(Seq_Ex_Block, self).__init__()
            self.se = nn.Sequential(
                GlobalAvgPool(),
                nn.Linear(in_ch, in_ch//r),
                nn.ReLU(inplace=True),
                nn.Linear(in_ch//r, in_ch),
                nn.Sigmoid()
            )

        def forward(self, x):
            se_weight = self.se(x).unsqueeze(-1).unsqueeze(-1)
            #print(f'x:{x.sum()}, x_se:{x.mul(se_weight).sum()}')
            return x.mul(se_weight)

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
      
class ClassifierNew(nn.Module):
    def __init__(self, inp = 2208, h1=1024, out = 102, d=0.35):
        super().__init__()
        self.ap = nn.AdaptiveAvgPool2d((1,1))
        self.mp = nn.AdaptiveMaxPool2d((1,1))
        self.fla = Flatten()
        self.bn0 = nn.BatchNorm1d(inp*2,eps=1e-05, momentum=0.1, affine=True)
        self.dropout0 = nn.Dropout(d)
        self.fc1 = nn.Linear(inp*2, h1)
        self.bn1 = nn.BatchNorm1d(h1,eps=1e-05, momentum=0.1, affine=True)
        self.dropout1 = nn.Dropout(d)
        self.fc2 = nn.Linear(h1, out)
        self.activation = nn.Softmax()
        
    def forward(self, x):
        ap = self.ap(x)
        mp = self.mp(x)
        x = torch.cat((ap,mp),dim=1)
        x = self.fla(x)
        x = self.bn0(x)
        x = self.dropout0(x)
        x = FT.relu(self.fc1(x))
        x = self.bn1(x)
        x = self.dropout1(x)         
        x = self.fc2(x)
        x = self.activation(x)
        return x
class EfficientNet_NeuralNet(nn.Module):
    def __init__(self, pretrained = True, Freeze_base = False, layers_freeze = None):
        super(EfficientNet_NeuralNet, self).__init__()
        
        self.cnn = EfficientNet.from_pretrained('efficientnet-b7')
        self.cnn._avg_pooling = nn.Identity()
        self.cnn._dropout = nn.Identity()
        self.cnn._swish = nn.Identity()
        if Freeze_base:
            if layers_freeze == None:
                for p in self.cnn.parameters():
                    p.requires_grad = False
            else:
                c = 0
                for p in self.cnn.parameters():
                    c+=1
                    if c < layers_freeze:
                        p.requires_grad = False
                    else:
                        p.requires_grad = True
        self.fc = ClassifierNew(2560, 1024, 4, 0.35)
        self.cnn._fc = nn.Identity()
    def forward(self, input):
        x = self.cnn.extract_features(input)
        x = self.fc(x)
        return x

In [15]:
model = EfficientNet_NeuralNet(True, True, 600).to(device)

Loaded pretrained weights for efficientnet-b7


In [16]:
from torchsummary import summary
summary(model, (3,224, 224))

MemoryEfficientSwish-645             [-1, 56, 1, 1]               0
        Identity-646             [-1, 56, 1, 1]               0
Conv2dStaticSamePadding-647           [-1, 1344, 1, 1]          76,608
        Identity-648           [-1, 1344, 7, 7]               0
Conv2dStaticSamePadding-649            [-1, 384, 7, 7]         516,096
     BatchNorm2d-650            [-1, 384, 7, 7]             768
     MBConvBlock-651            [-1, 384, 7, 7]               0
        Identity-652            [-1, 384, 7, 7]               0
Conv2dStaticSamePadding-653           [-1, 2304, 7, 7]         884,736
     BatchNorm2d-654           [-1, 2304, 7, 7]           4,608
MemoryEfficientSwish-655           [-1, 2304, 7, 7]               0
       ZeroPad2d-656         [-1, 2304, 11, 11]               0
Conv2dStaticSamePadding-657           [-1, 2304, 7, 7]          57,600
     BatchNorm2d-658           [-1, 2304, 7, 7]           4,608
MemoryEfficientSwish-659           [-1, 2304, 7, 7]               0


In [17]:
c =0 
for i in model.parameters():
    c+=1
    print(i.requires_grad)
print(c)

False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
Fals