## Import all dependencies 😊👌

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
import torchvision
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms # Transformations we can perform on our dataset
import torch.nn.functional as F # All functions that don't have any parameters
from torch.utils.data import DataLoader, Dataset # Gives easier dataset managment and creates mini batches
from torchvision.datasets import ImageFolder
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
from PIL import Image
import math

## Set device 🏎

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

## Preparing the dataset ✔

**Define the dataset class 🤷‍♂️**

In [None]:
from sklearn.model_selection import train_test_split

# Index the training folder; ImageLoader below reads the (path, class_index)
# pairs stored in dataset.imgs.
dataset = ImageFolder("../input/cat-and-dog/training_set/training_set/")

# 80/20 split with a fixed seed so the same images land in each split every run.
train_data, test_data, train_label, test_label = train_test_split(
    dataset.imgs,
    dataset.targets,
    test_size=0.2,
    random_state=42,
)

# ImageLoader Class

class ImageLoader(Dataset):
    """Dataset over ``(image_path, class_index)`` pairs that keeps only RGB images.

    Entries whose image bands are not exactly ``("R", "G", "B")`` (for example
    CMYK or grayscale files) are dropped once, at construction time.

    Args:
        dataset: sequence of ``(path, class_index)`` pairs, e.g. ``ImageFolder.imgs``.
        transform: optional callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = self.checkChannel(dataset)  # some images are CMYK/grayscale; keep only RGB
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``(image, class_index)`` for entry ``item``, transformed if configured."""
        entry = self.dataset[item]
        # Open in a context manager so the file handle is released immediately;
        # copy() loads the pixel data into a standalone image before the file closes.
        with Image.open(entry[0]) as img:
            image = img.copy()
        classCategory = entry[1]
        if self.transform:
            image = self.transform(image)
        return image, classCategory

    def checkChannel(self, dataset):
        """Return the entries of ``dataset`` whose image has exactly R, G, B bands."""
        datasetRGB = []
        for entry in dataset:
            # Only the header is needed for getbands(); the with-block closes the file.
            with Image.open(entry[0]) as img:
                if img.getbands() == ("R", "G", "B"):
                    datasetRGB.append(entry)
        return datasetRGB

**Image transformations: grayscale, resizing, and normalizing 😎**

In [None]:
# Train and test share one deterministic pipeline (no augmentation):
# PIL image -> single-channel grayscale -> 224x224 -> float tensor.
# Normalize with mean 0 / std 1 computes (x - 0) / 1, i.e. leaves values unchanged.
train_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0], [1]),
    ]
)  # train transform

test_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0], [1]),
    ]
)  # test transform

train_dataset = ImageLoader(dataset=train_data, transform=train_transform)
test_dataset = ImageLoader(dataset=test_data, transform=test_transform)

**Define the data loaders ⚙**

In [None]:
# Shuffle only the training data. Evaluation order does not change accuracy,
# and a fixed order keeps evaluation runs reproducible.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# **Loading the pretrained ResNet-50 network 🐱‍🏍**

In [None]:
from tqdm import tqdm
from torchvision import models

# Load an ImageNet-pretrained ResNet-50 and adapt it to this task.
model = models.resnet50(pretrained=True)

# Inputs are single-channel (grayscale) 224x224 images shaped (N, 1, H, W), so the
# 3-channel stem is replaced by a 1-channel one. This must be a 2-D convolution:
# nn.Conv1d expects (N, C, L) inputs. The attribute name and weight shape
# (64, 1, 7, 7) match the previous layer, so saved state dicts still load.
# Note: this new stem is randomly initialized, not pretrained.
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Replace the classification head with a 2-way output (cat vs. dog).
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

model.to(device)

# Training the greatest neural network ever, by Arman 👨‍🔬

In [None]:
# Loss: cross-entropy over the two logits produced by model.fc.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over every parameter of the network, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


# Train and test

def train(num_epoch, model):
    """Train ``model`` for ``num_epoch`` epochs on ``train_loader``.

    Uses the module-level ``criterion``, ``optimizer``, ``train_loader`` and
    ``device``. After every epoch the model and optimizer state dicts are saved
    to ``checpoint_epoch_<epoch>.pt``. The filename is kept unchanged because
    the checkpoint-loading cell reads ``checpoint_epoch_4.pt``.

    Args:
        num_epoch: number of passes over the training data.
        model: network to train; expected to already live on ``device``.
    """
    for epoch in range(num_epoch):
        losses = []  # per-batch loss values (Python floats) for this epoch
        model.train()
        loop = tqdm(enumerate(train_loader), total=len(train_loader))  # progress bar
        for batch_idx, (data, targets) in loop:
            data = data.to(device=device)
            targets = targets.to(device=device)

            scores = model(data)
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Store a float rather than the loss tensor, so no device tensor or
            # autograd node stays referenced for the rest of the epoch.
            losses.append(loss.item())

            loop.set_description(f"Epoch {epoch+1}/{num_epoch} process: {int((batch_idx / len(train_loader)) * 100)}")
            loop.set_postfix(loss=losses[-1])

        # save model
        torch.save(
            {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            },
            'checpoint_epoch_' + str(epoch) + '.pt',
        )

In [None]:
def test():
    """Evaluate the module-level ``model`` on ``test_loader`` and print the results.

    Prints the mean per-sample loss and the accuracy over the whole test set.
    Uses the module-level ``criterion``, ``test_loader`` and ``device``.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_samples = len(test_loader.dataset)
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            output = model(x)
            _, predictions = torch.max(output, 1)
            correct += (predictions == y).sum().item()
            # criterion (CrossEntropyLoss with default reduction) averages over the
            # batch; weight by batch size so dividing by num_samples gives the
            # mean over every test sample rather than only the last batch.
            total_loss += criterion(output, y).item() * y.size(0)

    if num_samples == 0:
        print("Test set is empty; nothing to evaluate.")
        return
    test_loss = total_loss / num_samples
    print("Average Loss: ", test_loss, "  Accuracy: ", correct, " / ",
          num_samples, "  ", int(correct / num_samples * 100), "%")

In [None]:
if __name__ == "__main__":
    # Fine-tune the ResNet classifier for 5 epochs; call test() afterwards to evaluate it.
    train(num_epoch=5, model=model)

In [None]:
print("----> Loading checkpoint")
# Restore the last checkpoint saved by train() above (epoch index 4).
# map_location lets a checkpoint written on a GPU load on a CPU-only machine
# (and vice versa) by placing tensors on the current `device`.
checkpoint = torch.load("./checpoint_epoch_4.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Attention Selector Network

**Patching images 👀**

In [None]:
# Notebook shell command (IPython `!` magic): install the third-party `patchify`
# package, which provides the patchify()/unpatchify() helpers imported below.
!pip install patchify
from patchify import patchify, unpatchify
# Example (not executed): patches = patchify(image, (16, 16), step=1)

**Positional encoding 🎈**

In [None]:
def positionalencoding1d(d_model, length):
    """Build a fixed sinusoidal positional-encoding table.

    Column ``2k`` holds ``sin(pos * w_k)`` and column ``2k + 1`` holds
    ``cos(pos * w_k)``, where ``w_k = 10000 ** (-2k / d_model)``.

    :param d_model: embedding dimension; must be even
    :param length: number of positions
    :return: float tensor of shape (length, d_model)
    :raises ValueError: if ``d_model`` is odd
    """
    if d_model % 2:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))

    positions = torch.arange(0, length).unsqueeze(1).float()  # (length, 1)
    scale = -(math.log(10000.0) / d_model)
    freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * scale)  # (d_model // 2,)
    angles = positions * freqs  # (length, d_model // 2)

    table = torch.zeros(length, d_model)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    return table

**Main class 🧨**

In [None]:
class attention_sellector(nn.Module):
    """Transformer that scores pixels and keeps a sampled subset of them.

    For an input batch ``x`` of shape (batch, 1, H, H):
      1. split each image into non-overlapping ``patch_size`` x ``patch_size``
         patches -> (batch, tokens, patch_size, patch_size);
      2. project each flattened patch to ``d_model`` features, add sinusoidal
         positional encodings and run a 6-layer Transformer encoder over the
         tokens of each image;
      3. softmax over the feature axis and use the result as per-pixel sampling
         weights (this requires ``d_model == patch_size ** 2``);
      4. sample ``attention_num`` pixels per image without replacement and zero
         every other pixel.

    NOTE(review): step 4 selects pixels with ``torch.multinomial`` (integer
    indices) and then masks the input image, so no gradient reaches this
    module's parameters through its output. Training it by backpropagating a
    downstream loss will not update its weights; a differentiable selection
    (e.g. scaling the image by the weights, or a Gumbel-softmax relaxation)
    would be needed for that.
    """

    def __init__(self, patch_size=16, d_model=256, attention_head=8):
        super(attention_sellector, self).__init__()
        # Device guessed at construction time; kept as an attribute because
        # callers use it (e.g. `attention_model.to(attention_model.device)`).
        # Inside the module, tensors follow the device of the input instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.patch_size = patch_size
        self.d_model = d_model
        # batch_first=True: inputs are (batch, tokens, d_model). With the default
        # (batch_first=False) the encoder would attend across images in the batch
        # instead of across the patches of one image. Parameter names and shapes
        # are unaffected, so existing state dicts still load.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model, nhead=attention_head, batch_first=True)
        self.transformer = nn.TransformerEncoder(self.encoder_layer, 6)
        self.fc = nn.Linear(self.patch_size ** 2, self.d_model)
        self.softmax = nn.Softmax(dim=2)

    def add_positionalencoding(self, x, input_shape):
        """Add sinusoidal positional encodings to token embeddings.

        x: (batch, tokens, d_model); the (tokens, d_model) table is broadcast
        over the batch. Assumes square images of side ``input_shape``.
        """
        pe = positionalencoding1d(self.d_model, (input_shape // self.patch_size) ** 2)
        pe = pe.to(device=x.device, dtype=x.dtype)
        return torch.add(x, pe)

    def patchify_image(self, imgs, input_shape=224):
        """Split square images into non-overlapping ``patch_size`` patches.

        imgs: (batch, 1, input_shape, input_shape) or (batch, input_shape, input_shape).
        Returns (batch, tokens, patch_size, patch_size) with tokens ordered
        row-major over the patch grid, on the same device as ``imgs``. Pixels
        beyond the last full patch are dropped (as patchify with step == patch
        size does).
        """
        p = self.patch_size
        n = input_shape // p
        imgs = torch.reshape(imgs, (-1, input_shape, input_shape))
        imgs = imgs[:, : n * p, : n * p]
        batch_size = imgs.size(0)
        # (batch, n, p, n, p) -> (batch, n, n, p, p) -> (batch, n*n, p, p)
        out = imgs.reshape(batch_size, n, p, n, p).permute(0, 1, 3, 2, 4)
        return out.reshape(batch_size, n * n, p, p)

    def unpatchify(self, input_tensor, input_shape=224):
        """Inverse of ``patchify_image``: reassemble patches into images.

        input_tensor: (batch, tokens, patch_size ** 2) or
            (batch, tokens, patch_size, patch_size), where
            tokens == (input_shape // patch_size) ** 2.
        Returns (batch, n * patch_size, n * patch_size) with
        n = input_shape // patch_size.
        """
        p = self.patch_size
        n = input_shape // p
        batch_size = input_tensor.size(0)
        out = torch.reshape(input_tensor, (batch_size, n, n, p, p))
        out = out.permute(0, 1, 3, 2, 4)  # (batch, n, p, n, p)
        return out.reshape(batch_size, n * p, n * p)

    def torch_delete(self, tensor, indices):
        """Return the flattened ``tensor`` with the entries at ``indices`` removed."""
        mask = torch.ones(tensor.numel(), dtype=torch.bool, device=tensor.device)
        mask[torch.as_tensor(indices, device=mask.device)] = False
        return tensor.reshape(-1)[mask]

    def select_attention(self, img, attention, attention_num=20000):
        """Keep ``attention_num`` sampled pixels per image and zero the rest.

        img: (batch, 1, H, H) or (batch, H, H) images.
        attention: (batch, tokens, patch_size ** 2) non-negative weights in
            patch layout (the transformer output after softmax).
        attention_num: pixels kept per image, sampled without replacement with
            probability proportional to their weight.
        Returns a new (batch, H, H) tensor; ``img`` itself is not modified.
        Raises ValueError if the attention does not cover every pixel.
        """
        if attention.numel() != img.numel():
            raise ValueError(
                "attention has {} elements but img has {}; d_model must equal "
                "patch_size ** 2 and the image side must be a multiple of "
                "patch_size".format(attention.numel(), img.numel()))
        batch_size = img.size(0)
        side = img.size(-1)
        # The attention is in patch-major order (token, pixel-in-patch); map it
        # back to row-major pixel order so index k means the same pixel in the
        # weights and in the flattened image.
        weights = self.unpatchify(attention, side).reshape(batch_size, -1)
        flat_img = img.reshape(batch_size, -1)
        idx = torch.multinomial(weights, attention_num)  # (batch, attention_num)
        keep = torch.zeros_like(weights).scatter_(1, idx, 1.0).bool()
        # masked_fill is out-of-place, so the caller's batch is left untouched.
        out = flat_img.masked_fill(~keep, 0)
        return out.reshape(batch_size, side, side)

    def apply_linear(self, input_tensor):
        """Project patches (batch, tokens, p, p) to embeddings (batch, tokens, d_model)."""
        # TODO: a ResNet-style patch embedding could replace this linear layer.
        out = torch.flatten(input_tensor, start_dim=2)  # (batch, tokens, p*p)
        return self.fc(out)

    def forward(self, x):
        """Map (batch, 1, H, H) images to (batch, 1, H, H) images with unselected pixels zeroed."""
        side = x.size(-1)
        out = self.patchify_image(x, side)            # (batch, tokens, p, p)
        out = self.apply_linear(out)                  # (batch, tokens, d_model)
        out = self.add_positionalencoding(out, side)  # (batch, tokens, d_model)
        out = self.transformer(out)                   # (batch, tokens, d_model)
        out = self.softmax(out)                       # weights over the feature axis
        out = self.select_attention(x, out)           # (batch, H, H)
        return torch.reshape(out, (-1, 1, side, side))
        

**Training this amazing model 🕶**

In [None]:
# Loss (cross-entropy on the classifier logits) and the attention network.
criterion = nn.CrossEntropyLoss()
attention_model = attention_sellector()
attention_model.to(attention_model.device)


def train_attention(num_epoch, attention_model, detection_model):
    """Train ``attention_model`` while ``detection_model`` stays in eval mode.

    Images masked by ``attention_model`` are classified by ``detection_model``;
    only ``attention_model``'s parameters are given to the optimizer. After each
    epoch the attention model and optimizer state dicts are saved to
    ``attention_checkpoint_epoch_<epoch>.pt``, a name that does not overwrite the
    classifier checkpoints.

    Uses the module-level ``criterion``, ``train_loader`` and ``device``.

    NOTE(review): if attention_model selects pixels non-differentiably (e.g.
    via torch.multinomial), its parameters get no gradient from this loss.
    Confirm the weights actually change.
    """
    # Create the optimizer once so Adam's moment estimates carry across epochs.
    optimizer = optim.Adam(attention_model.parameters(), lr=0.001)
    for epoch in range(num_epoch):
        losses = []  # per-batch loss values (Python floats) for this epoch
        attention_model.train()
        detection_model.eval()
        loop = tqdm(enumerate(train_loader), total=len(train_loader))  # progress bar
        for batch_idx, (data, targets) in loop:
            data = data.to(device=device)
            targets = targets.to(device=device)
            attention_output = attention_model(data)
            scores = detection_model(attention_output)

            loss = criterion(scores, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())  # float, so no tensor stays referenced

            loop.set_description(f"Epoch {epoch+1}/{num_epoch} process: {int((batch_idx / len(train_loader)) * 100)}")
            loop.set_postfix(loss=losses[-1])

        # save model
        torch.save(
            {
                'model_state_dict': attention_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            },
            'attention_checkpoint_epoch_' + str(epoch) + '.pt',
        )
        


In [None]:
def train_detection(num_epoch, attention_model, detection_model):
    """Train ``detection_model`` on images masked by a fixed ``attention_model``.

    ``attention_model`` runs in eval mode without autograd (it is not optimized
    here), and only ``detection_model``'s parameters are updated. After each
    epoch the classifier and optimizer state dicts are saved to
    ``detection_checkpoint_epoch_<epoch>.pt``, a name that does not overwrite the
    other checkpoints.

    Uses the module-level ``criterion``, ``train_loader`` and ``device``.
    """
    # Create the optimizer once so Adam's moment estimates carry across epochs.
    optimizer = optim.Adam(detection_model.parameters(), lr=0.001)
    for epoch in range(num_epoch):
        losses = []  # per-batch loss values (Python floats) for this epoch
        attention_model.eval()
        detection_model.train()
        loop = tqdm(enumerate(train_loader), total=len(train_loader))  # progress bar
        for batch_idx, (data, targets) in loop:
            data = data.to(device=device)
            targets = targets.to(device=device)
            # attention_model is not being trained here, so skip building its graph.
            with torch.no_grad():
                attention_output = attention_model(data)
            scores = detection_model(attention_output)

            loss = criterion(scores, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())  # float, so no tensor stays referenced

            loop.set_description(f"Epoch {epoch+1}/{num_epoch} process: {int((batch_idx / len(train_loader)) * 100)}")
            loop.set_postfix(loss=losses[-1])

        # save model
        torch.save(
            {
                'model_state_dict': detection_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            },
            'detection_checkpoint_epoch_' + str(epoch) + '.pt',
        )

In [None]:
if __name__ == "__main__":
    # Release unoccupied cached GPU memory left over from earlier cells.
    torch.cuda.empty_cache()
    # First fit the classifier on masked images, then train the attention network.
    train_detection(num_epoch=5, attention_model=attention_model, detection_model=model)
    train_attention(num_epoch=5, attention_model=attention_model, detection_model=model)

In [None]:
def test(attention_model, detection_model):
    """Evaluate ``detection_model`` on images masked by ``attention_model``.

    Prints the mean per-sample loss and the accuracy over ``test_loader``.
    Uses the module-level ``criterion``, ``test_loader`` and ``device``.
    """
    attention_model.eval()
    detection_model.eval()
    total_loss = 0.0
    correct = 0
    num_samples = len(test_loader.dataset)
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            x = attention_model(x)
            output = detection_model(x)
            _, predictions = torch.max(output, 1)
            correct += (predictions == y).sum().item()
            # criterion (CrossEntropyLoss with default reduction) averages over the
            # batch; weight by batch size so dividing by num_samples gives the
            # mean over every test sample rather than only the last batch.
            total_loss += criterion(output, y).item() * y.size(0)

    if num_samples == 0:
        print("Test set is empty; nothing to evaluate.")
        return
    test_loss = total_loss / num_samples
    print("Average Loss: ", test_loss, "  Accuracy: ", correct, " / ",
          num_samples, "  ", int(correct / num_samples * 100), "%")

In [None]:
# Evaluate the classifier on attention-masked test images.
test(attention_model=attention_model, detection_model=model)

In [None]:
# Check the test set: the held-out test_set folder, preprocessed with the same
# grayscale / resize / to-tensor steps as training. Normalize([0], [1]) is left
# out because (x - 0) / 1 leaves values unchanged.
holdout_preprocessing = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
test_dataset = ImageFolder(
    "../input/cat-and-dog/test_set/test_set/",
    transform=holdout_preprocessing,
)

# One image per batch, in folder order.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
import matplotlib.pyplot as plt

# Show the attention network's masked output, followed by the input image,
# for the first few test images.
num_examples = 4
with torch.no_grad():
    attention_model.eval()
    for i, (data, target) in enumerate(test_dataloader):
        if i >= num_examples:
            break
        data, target = data.to(device), target.to(device)
        # Pass a copy: the model may modify its input in place, which would make
        # the "input" plot below show the masked image instead of the original.
        output = attention_model(data.clone())
        side = data.size(-1)
        data = torch.reshape(data, (side, side))
        output = torch.reshape(output, (side, side))
        plt.imshow(output.cpu().numpy(), cmap='gray')
        plt.show()
        plt.imshow(data.cpu().numpy(), cmap='gray')
        plt.show()