<a href="https://colab.research.google.com/github/abhijith-07/food-detection-resnet/blob/main/food_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from dropblock import DropBlock2D

In [14]:
class self_attention(nn.Module):

    def __init__(self, in_channels, out_channels, dk, dq, dv, Nh):
        super(self_attention, self).__init__()
        self.Cin = in_channels
        self.Cout = out_channels
        self.dq = dq
        self.dk = dk
        self.dv = dv
        self.Nh = Nh

        self.k = int(self.dk * self.Cin)
        self.q = int(self.dq * self.Cin)
        self.v = int(self.dv * self.Cin)

        self.kqv_conv = nn.Sequential(
            nn.Conv2d(self.Cin, self.k+self.q+self.v, kernel_size=1, stride=1, padding=0),
            #nn.BatchNorm2d(self.k+self.q+self.v,self.k+self.q+self.v)
        )
        self.attn = nn.Conv2d(self.v, self.Cout, kernel_size=1, stride=1)

    def split_heads_2d(self, x, Nh):
        batch, channels, height, width = x.size()
        ret_shape = (batch, Nh, channels // Nh, height, width)
        split = torch.reshape(x, ret_shape)
        return split

    #shape of flat_q: (N, Nh, dq//Nh, H*W)
    #shape of q:      (N, Nh, dq//Nh, H, W)
    def compute_flat_qkv(self, x, dq, dk, dv, Nh):
        qkv = self.kqv_conv(x)
        N, _, H, W = qkv.size()
        q, k, v = torch.split(qkv, [dq, dk, dv], dim=1)
        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)

        dkh = dk // Nh
        q *= dkh ** -0.5
        flat_q = torch.reshape(q, (N, Nh, dq // Nh, H * W))
        flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
        flat_v = torch.reshape(v, (N, Nh, dv // Nh, H * W))

        return flat_q, flat_k, flat_v, q, k, v

    def forward(self, inputs):
        batch, N, H, W = inputs.shape
        #print(inputs.shape)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(inputs, self.q, self.k,self.v,self.Nh)
        #print(flat_q.shape)
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
        weights = F.softmax(logits, dim=1)
        #print(weights.shape)
        #result = weights.cpu().detach().numpy()
        #np.save("visual/matrix"+str(H), result)
        #print(weights.shape)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = torch.reshape(attn_out, (batch, self.Nh, self.v // self.Nh, H, W))
        #print(attn_out.shape)
        attn_out = torch.reshape(attn_out, (batch, self.Nh * (self.v // self.Nh), H, W))
        #print(attn_out.shape)
        attn_out = self.attn(attn_out)
        #print(attn_out.shape)

        return attn_out


In [15]:
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5,
                                 momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

In [16]:
class PRENet(nn.Module):
    def __init__(self, model, feature_size, classes_num):
        super(PRENet, self).__init__()

        self.features = model

        self.num_ftrs = 2048 * 1 * 1
        self.elu = nn.ELU(inplace=True)

        self.dk = 0.5
        self.dq = 0.5
        self.dv = 0.5
        self.Nh = 8


        self.classifier_concat = nn.Sequential(
            nn.BatchNorm1d(1024 * 5),
            nn.Linear(1024 * 5, feature_size),
            nn.BatchNorm1d(feature_size),
            nn.ELU(inplace=True),
            nn.Linear(feature_size, classes_num),
        )

        self.conv_block0 = nn.Sequential(
            BasicConv(self.num_ftrs // 8, feature_size, kernel_size=1, stride=1, padding=0, relu=True),
            BasicConv(feature_size, self.num_ftrs // 2, kernel_size=3, stride=1, padding=1, relu=True)
        )
        self.classifier0 = nn.Sequential(
            nn.BatchNorm1d(self.num_ftrs // 2),
            nn.Linear(self.num_ftrs // 2, feature_size),
            nn.BatchNorm1d(feature_size),
            nn.ELU(inplace=True),
            nn.Linear(feature_size, classes_num),
        )

        self.conv_block1 = nn.Sequential(
            BasicConv(self.num_ftrs//4, feature_size, kernel_size=1, stride=1, padding=0, relu=True),
            BasicConv(feature_size, self.num_ftrs//2, kernel_size=3, stride=1, padding=1, relu=True)
        )
        self.classifier1 = nn.Sequential(
            nn.BatchNorm1d(self.num_ftrs//2),
            nn.Linear(self.num_ftrs//2, feature_size),
            nn.BatchNorm1d(feature_size),
            nn.ELU(inplace=True),
            nn.Linear(feature_size, classes_num),
        )

        self.conv_block2 = nn.Sequential(
            BasicConv(self.num_ftrs//2, feature_size, kernel_size=1, stride=1, padding=0, relu=True),
            BasicConv(feature_size, self.num_ftrs//2, kernel_size=3, stride=1, padding=1, relu=True)
        )
        self.classifier2 = nn.Sequential(
            nn.BatchNorm1d(self.num_ftrs//2),
            nn.Linear(self.num_ftrs//2, feature_size),
            nn.BatchNorm1d(feature_size),
            nn.ELU(inplace=True),
            nn.Linear(feature_size, classes_num),
        )

        self.conv_block3 = nn.Sequential(
            BasicConv(self.num_ftrs, feature_size, kernel_size=1, stride=1, padding=0, relu=True),
            BasicConv(feature_size, self.num_ftrs//2, kernel_size=3, stride=1, padding=1, relu=True)
        )
        self.classifier3 = nn.Sequential(
            nn.BatchNorm1d(self.num_ftrs//2),
            nn.Linear(self.num_ftrs//2, feature_size),
            nn.BatchNorm1d(feature_size),
            nn.ELU(inplace=True),
            nn.Linear(feature_size, classes_num),
        )

        self.Avgmax = nn.AdaptiveMaxPool2d(output_size=(1,1))

        self.attn1_1 = self_attention(self.num_ftrs // 2,self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)
        self.attn2_2 = self_attention(self.num_ftrs // 2,self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)
        self.attn3_3 = self_attention(self.num_ftrs // 2,self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)

        '''
        self.attn1_2 = layer_self_attention(self.num_ftrs // 2,self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)
        self.attn1_3 = layer_self_attention(self.num_ftrs // 2,self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)
        self.attn2_3 = layer_self_attention(self.num_ftrs // 2,self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)

        self.attn2_1 = layer_self_attention(self.num_ftrs // 2, self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)
        self.attn3_1 = layer_self_attention(self.num_ftrs // 2, self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)
        self.attn3_2 = layer_self_attention(self.num_ftrs // 2, self.num_ftrs // 2, self.dk, self.dq, self.dv, self.Nh)
        '''

        self.sconv1 = nn.Conv2d((self.num_ftrs // 2), self.num_ftrs // 2, kernel_size= 3, padding= 1)
        self.sconv2 = nn.Conv2d((self.num_ftrs // 2), self.num_ftrs // 2, kernel_size= 3, padding= 1)
        self.sconv3 = nn.Conv2d((self.num_ftrs // 2), self.num_ftrs // 2, kernel_size= 3, padding= 1)
        self.drop_block = DropBlock2D(block_size=3, drop_prob=0.5)

    def forward(self, x, label):
        xf1, xf2, xf3, xf4, xf5, xn = self.features(x)
        batch_size, _, _, _ = x.shape

        #get feature pyramid
        xl1 = self.conv_block1(xf3)
        xl2 = self.conv_block2(xf4)
        xl3 = self.conv_block3(xf5)

        xk1 = self.Avgmax(xl1)
        xk1 = xk1.view(xk1.size(0), -1)
        xc1 = self.classifier1(xk1)

        xk2 = self.Avgmax(xl2)
        xk2 = xk2.view(xk2.size(0), -1)
        xc2 = self.classifier2(xk2)

        xk3 = self.Avgmax(xl3)
        xk3 = xk3.view(xk3.size(0), -1)
        xc3 = self.classifier3(xk3)


        if label:
            # xs1_2 means that using x2 to strength x1
            #(batch, 1024, 56, 56)
            xs1 = self.attn1_1(xl1)
            #xs1_2 = self.attn1_2(xl1, xl2)
            #xs1_3 = self.attn1_3(xl1, xl3)
            # (batch, 1024, 28, 28)
            xs2 = self.attn1_1(xl2)
            #xs2_3 = self.attn2_3(xl2, xl3)
            #xs2_1 = self.attn2_1(xl2, xl1)
            # (batch, 1024, 14, 14)
            xs3 = self.attn1_1(xl3)
            #xs3_1 = self.attn2_1(xl3, xl1)
            #xs3_2 = self.attn2_1(xl3, xl2)

            #xr1 = self.drop_block(self.sconv1(torch.cat([xs1,xs1_2,xs1_3], dim=1)))
            #xr2 = self.drop_block(self.sconv2(torch.cat([xs2,xs2_3,xs2_1], dim=1)))
            #xr3 = self.drop_block(self.sconv3(torch.cat([xs3,xs3_1,xs3_2], dim=1)))
            xr1 = self.drop_block(self.sconv1(xs1))
            xr2 = self.drop_block(self.sconv2(xs2))
            xr3 = self.drop_block(self.sconv3(xs3))

            xm1 = self.Avgmax(xr1)
            xm1 = xm1.view(xm1.size(0), -1)
            #print(np.argmax(F.softmax(xm1, dim=1).cpu().detach().numpy(),axis=1))
            #input()

            xm2 = self.Avgmax(xr2)
            xm2 = xm2.view(xm2.size(0), -1)
            #print(np.argmax(F.softmax(xm2, dim=1).cpu().detach().numpy(),axis=1))
            #input()

            xm3 = self.Avgmax(xr3)
            xm3 = xm3.view(xm3.size(0), -1)
            #print(np.argmax(F.softmax(xm3, dim=1).cpu().detach().numpy(),axis=1))
            #input()

            x_concat = torch.cat((xm1, xm2, xm3, xn), dim=1)
            x_concat = self.classifier_concat(x_concat)
        else:
            x_concat = torch.cat((xk1, xk2, xk3, xn), dim=1)
            x_concat = self.classifier_concat(x_concat)

        #get origal feature vector


        return xk1, xk2, xk3, x_concat, xc1, xc2, xc3


In [20]:
import cv2
import os

In [23]:
categories = os.listdir('dataset/train_set')
categories

['bisibelebath',
 'biriyani',
 'butternaan',
 'chaat',
 '.ipynb_checkpoints',
 'chappati',
 'dosa']

In [35]:
img_size = 256

def get_data(data_dir):
    print("Data Dir:", data_dir)
    features = []
    labels = []
    for category in categories:
        path = os.path.join(data_dir, category)
        class_num = categories.index(category)
        print("Class number: ", class_num)
        for img in os.listdir(path):
            try:
                img_arr = cv2.imread(os.path.join(path, img))
                resized_arr = cv2.resize(img_arr, (img_size, img_size))

                features.append(resized_arr)
                labels.append(class_num)
            except Exception as e:
                print(e)
    return np.array(features), np.array(labels)

In [41]:
(train_features, train_labels) = get_data('dataset/train_set')
train_features.shape, train_labels.shape

Data Dir: dataset/train_set
Class number:  0
Class number:  1
Class number:  2
Class number:  3
Class number:  4
Class number:  5
Class number:  6


((420, 256, 256, 3), (420,))

In [42]:
(test_features, test_labels) = get_data('dataset/train_set')
test_features.shape, test_labels.shape

Data Dir: dataset/train_set
Class number:  0
Class number:  1
Class number:  2
Class number:  3
Class number:  4
Class number:  5
Class number:  6


((420, 256, 256, 3), (420,))

In [45]:
try:
  resnet_model = models.resnet50(pretrained=True)
except Exception as e:
  print("Exception: ", e)



In [46]:
num_class = 6
prenet_model = PRENet(resnet_model, 512, num_class)