In [None]:
from google.colab import drive
import shutil
import torch
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm

In [None]:
import zipfile
import os

Dataset

In [None]:
# Mount Google Drive inside the Colab runtime at /content/gdrive; the dataset
# folder and the saved checkpoint used further down are read from this mount.
drive.mount ('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Number of samples per batch; passed to the train, test and validation DataLoaders.
batch_size = 4

Loading the dataset

In [None]:
# Root folder on the mounted Drive; the split names 'train', 'test' and 'val'
# are appended to this prefix when the datasets are loaded (note the trailing slash).
dataset_path = '/content/gdrive/MyDrive/Train-Test-Val/'

In [None]:
# Preprocessing pipeline shared by every split.
img_size = 256  # target height and width, in pixels

transformer = transforms.Compose(
    [
        # Force every image to an img_size x img_size square.
        transforms.Resize(size=(img_size, img_size)),
        # Convert the loaded image to a tensor (0-255 values are scaled to 0-1).
        transforms.ToTensor(),
    ]
)

In [None]:
def load_dataset(d_path):
    """Wrap the image folder at ``d_path`` in a default-configured DataLoader.

    The folder is read with ``torchvision.datasets.ImageFolder`` using the
    notebook-level ``transformer`` pipeline.

    Args:
        d_path: Path of the folder handed to ``ImageFolder``.

    Returns:
        A ``torch.utils.data.DataLoader`` built with default settings; the
        underlying dataset is reachable through its ``.dataset`` attribute.
    """
    image_folder = torchvision.datasets.ImageFolder(d_path, transform=transformer)
    return torch.utils.data.DataLoader(image_folder)

In [None]:
# Load the three splits; only the dataset object behind each helper-built
# loader is kept, the real loaders are created in the next cell.
train_dataset, test_dataset, valid_dataset = (
    load_dataset(dataset_path + split_name).dataset
    for split_name in ('train', 'test', 'val')
)

In [None]:
# Training batches are drawn in a new random order on every pass.
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2, shuffle=True)
# Evaluation loaders keep the dataset order: shuffling does not change the
# metrics, and a fixed order keeps predictions aligned with dataset indices
# so runs are reproducible and individual errors can be traced back to files.
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=2, shuffle=False)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2, shuffle=False)

In [None]:
# Report how many images and batches each split contains.
print(f'Train Set- {len(train_dataset)} images in {len(train_loader)} batches')
print(f'Testing Set - {len(test_dataset)} images in {len(test_loader)} batches')
print(f'Validation Set - {len(valid_dataset)} images in {len(valid_loader)} batches')

Train Set- 4168 images in 1042 batches
Testing Set - 1397 images in 350 batches
Validation Set - 1388 images in 347 batches


Network

In [None]:
class depthwise_separable_conv(nn.Module):
    """Depthwise-separable convolution: a per-channel 3x3 conv, then a 1x1 conv.

    Args:
        nin: Number of input channels.
        nout: Number of output channels.

    The 3x3 stage uses ``padding=1``, so the spatial size is unchanged.
    """

    def __init__(self, nin, nout):
        super().__init__()
        # groups=nin gives every input channel its own 3x3 filter.
        self.depthwise = nn.Conv2d(
            in_channels=nin, out_channels=nin, kernel_size=3, padding=1, groups=nin
        )
        # The 1x1 convolution mixes channels and sets the output channel count.
        self.pointwise = nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise stage followed by the pointwise stage."""
        return self.pointwise(self.depthwise(x))
  

In [None]:
class channel_shuffle(nn.Module):
    """Interleave the channels of ``groups`` consecutive channel groups.

    Args:
        groups: Number of groups the channel axis is split into. The channel
            count of the input must be divisible by it for the reshape to work.
    """

    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Return ``x`` (a 4-D tensor) with its channels shuffled across groups."""
        n, c, h, w = x.size()

        # Split channels into (groups, channels_per_group), swap those two
        # axes, then merge them back into a single channel axis.
        grouped = x.view(n, self.groups, c // self.groups, h, w)
        swapped = grouped.transpose(1, 2).contiguous()
        return swapped.view(n, -1, h, w)


In [None]:
class GDSW(nn.Module):
    """Grouped conv -> channel shuffle -> depthwise-separable conv -> grouped conv.

    Args:
        dim_in: Input channel count.
        dim_out: Output channel count.

    Both grouped convolutions use ``groups=3``, so ``dim_in`` and ``dim_out``
    have to be multiples of 3. Every stage keeps the spatial size (3x3 kernels
    with ``padding=1``).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # dim_in -> 6 channels, computed in three independent groups.
        self.gc1 = nn.Conv2d(dim_in, 6, kernel_size=(3, 3), padding=1, groups=3)
        # Mix channels between the three groups before the next stage.
        self.cs = channel_shuffle(groups=3)
        # 6 -> 12 channels.
        self.DSWC = depthwise_separable_conv(6, 12)
        # 12 -> dim_out channels, again in three groups.
        self.gc2 = nn.Conv2d(12, dim_out, kernel_size=(3, 3), padding=1, groups=3)

    def forward(self, x):
        """Run ``x`` through the four stages in order."""
        for stage in (self.gc1, self.cs, self.DSWC, self.gc2):
            x = stage(x)
        return x

In [None]:
class FPN (nn.Module):
    """Encoder/decoder with skip connections over an 18-channel feature map.

    The encoder has three conv -> ReLU -> BatchNorm -> 2x2 max-pool stages
    (18 -> 24 -> 32 -> 64 channels) followed by a 64 -> 128 bottleneck conv.
    The decoder upsamples by 2 three times; before each decoder conv the
    upsampled tensor is concatenated with the pre-pool output of the matching
    encoder stage.

    Input:  ``(N, 18, H, W)``; H and W must be divisible by 8 so the upsampled
            tensors line up with the skip tensors.
    Output: ``(N, 32, H, W)``.
    """

    def __init__ (self):
        super().__init__()
        self.enc_conv0 = nn.Conv2d(in_channels=18, out_channels=24, kernel_size=(3,3), padding=1)
        self.act0 = nn.ReLU()
        self.bn0 = nn.BatchNorm2d(24)
        self.pool0 = nn.MaxPool2d(kernel_size=(2,2))

        self.enc_conv1 = nn.Conv2d(in_channels=24, out_channels=32, kernel_size=(3,3), padding=1)
        self.act1 = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=(2,2))

        self.enc_conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3), padding=1)
        self.act2 = nn.ReLU()
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 =  nn.MaxPool2d(kernel_size=(2,2))

        self.bottleneck_conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3,3), padding=1)

        # 192 = 128 (upsampled bottleneck) + 64 (skip from encoder stage 2).
        self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.dec_conv2 = nn.Conv2d(in_channels=192, out_channels=32, kernel_size=(3,3), padding=1)
        self.dec_act2 = nn.ReLU()
        self.dec_bn2 = nn.BatchNorm2d(32)

        # 64 = 32 (upsampled decoder output) + 32 (skip from encoder stage 1).
        self.upsample3 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.dec_conv3 = nn.Conv2d(in_channels=64, out_channels=16, kernel_size=(3,3), padding=1)
        self.dec_act3 = nn.ReLU()
        self.dec_bn3 = nn.BatchNorm2d(16)

        # 40 = 16 (upsampled decoder output) + 24 (skip from encoder stage 0).
        self.upsample4 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.dec_conv4 = nn.Conv2d(in_channels=40, out_channels=32, kernel_size=(1,1))
        self.dec_act4 = nn.ReLU()
        self.dec_bn4 = nn.BatchNorm2d(32)

        # Not used by forward(); kept so the module's attributes stay the same
        # as in previously saved models.
        self.poolg = nn.MaxPool2d(kernel_size=(2,2))
        self.avgpool = nn.AdaptiveAvgPool2d (8)

    def forward (self, x):
        """Return the decoded feature map for ``x``."""
        # Each encoder stage is evaluated exactly once: the pre-pool activation
        # serves as the skip tensor and is then pooled for the next stage.
        # Evaluating a stage twice (once per use) doubles the encoder cost and,
        # in training mode, updates the BatchNorm running statistics twice per
        # forward pass.
        cat0 = self.bn0(self.act0(self.enc_conv0(x)))
        e0 = self.pool0(cat0)

        cat1 = self.bn1(self.act1(self.enc_conv1(e0)))
        e1 = self.pool1(cat1)

        cat2 = self.bn2(self.act2(self.enc_conv2(e1)))
        e2 = self.pool2(cat2)

        b = self.bottleneck_conv(e2)

        # Decoder: upsample, concatenate the matching skip tensor on the channel
        # axis, then conv -> ReLU -> BatchNorm.
        d2 = self.dec_bn2(self.dec_act2(self.dec_conv2(torch.cat((self.upsample2(b), cat2), dim=1))))
        d3 = self.dec_bn3(self.dec_act3(self.dec_conv3(torch.cat((self.upsample3(d2), cat1), dim=1))))
        d4 = self.dec_bn4(self.dec_act4(self.dec_conv4(torch.cat((self.upsample4(d3), cat0), dim=1))))

        return d4

In [None]:
class CNN_Branch(nn.Module):
    """Local-feature branch: three GDSW stages whose outputs feed an FPN.

    Each stage is GDSW -> BatchNorm -> 2x2 max-pool (3 -> 6 -> 9 -> 12
    channels). The first stage's output is pooled by a further factor of 4 and
    concatenated with the third stage's output (6 + 12 = 18 channels), passed
    through the FPN, and finally pooled by 4 again.
    """

    def __init__(self):
        super().__init__()

        self.GDSW1 = GDSW(dim_in=3, dim_out=6)
        self.GDSW2 = GDSW(dim_in=6, dim_out=9)
        self.GDSW3 = GDSW(dim_in=9, dim_out=12)

        self.bnn1 = nn.BatchNorm2d(6)
        self.bnn2 = nn.BatchNorm2d(9)
        self.bnn3 = nn.BatchNorm2d(12)

        self.FPN = FPN()

        self.poolg = nn.MaxPool2d(kernel_size=(2,2))
        self.poolg4 = nn.MaxPool2d(kernel_size=(4,4))
        # avgpool and fc are created here but are not referenced by forward().
        self.avgpool = nn.AdaptiveAvgPool2d(8)

        self.fc = nn.Linear(128 * img_size, 7)

    def forward(self, x):
        """Return the pooled FPN output for the image batch ``x``."""
        stage_outputs = []
        feat = x
        stages = (
            (self.GDSW1, self.bnn1),
            (self.GDSW2, self.bnn2),
            (self.GDSW3, self.bnn3),
        )
        for block, norm in stages:
            feat = self.poolg(norm(block(feat)))
            stage_outputs.append(feat)

        shallow, deep = stage_outputs[0], stage_outputs[-1]
        # Pool the first-stage output by 4 so it matches the third stage's
        # spatial size, then stack both on the channel axis.
        fpn_input = torch.cat((self.poolg4(shallow), deep), dim=1)
        return self.poolg4(self.FPN(fpn_input))

In [None]:
class Transformer_Branch(nn.Module):
    """Global-feature branch built on torchvision's Swin-B model.

    The model is created with ``torchvision.models.swin_b()`` using its default
    arguments, and its ``avgpool``, ``flatten`` and ``head`` sub-modules are
    replaced by ``nn.Identity`` so the forward pass returns the feature map
    that would otherwise be pooled and classified.
    """

    def __init__(self):
        super().__init__()
        self.SWIN = torchvision.models.swin_b()
        # Disable the pooling / flattening / classification tail.
        for tail_name in ('avgpool', 'flatten', 'head'):
            setattr(self.SWIN, tail_name, nn.Identity())

    def forward(self, x):
        """Return the Swin output for the image batch ``x``."""
        return self.SWIN(x)

In [None]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

class sa_layer(nn.Module):
    """Shuffle-attention block (see the SA-Net entry in this notebook's references).

    The channel axis is split into ``groups`` groups. Inside each group, one
    half of the channels is gated by channel attention (global average pool)
    and the other half by spatial attention (GroupNorm). The halves are then
    re-joined and the channels shuffled.

    Args:
        channel: Number of input channels.
        groups: Number of groups; the learnable scale/shift parameters have
            ``channel // (2 * groups)`` channels each.
    """

    def __init__(self, channel, groups=66):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Channels handled by one attention branch inside one group.
        half = channel // (2 * groups)
        shape = (1, half, 1, 1)
        # Zero scale and unit shift: at initialisation both gates are sigmoid(1).
        self.cweight = Parameter(torch.zeros(*shape))
        self.cbias = Parameter(torch.ones(*shape))
        self.sweight = Parameter(torch.zeros(*shape))
        self.sbias = Parameter(torch.ones(*shape))

        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(half, half)

    @staticmethod
    def channel_shuffle(x, groups):
        """Interleave the channels of ``groups`` consecutive channel groups of ``x``."""
        n, _, h, w = x.shape
        regrouped = x.reshape(n, groups, -1, h, w).permute(0, 2, 1, 3, 4)
        # Flatten the two channel axes back into one.
        return regrouped.reshape(n, -1, h, w)

    def forward(self, x):
        """Apply grouped channel/spatial attention to the 4-D tensor ``x``."""
        n, _, h, w = x.shape

        # Fold the groups into the batch axis and split each group in two.
        grouped = x.reshape(n * self.groups, -1, h, w)
        chan_half, spat_half = grouped.chunk(2, dim=1)

        # Channel attention: gate computed from the global average of each channel.
        chan_gate = self.sigmoid(self.cweight * self.avg_pool(chan_half) + self.cbias)
        chan_out = chan_half * chan_gate

        # Spatial attention: gate computed from the group-normalised features.
        spat_gate = self.sigmoid(self.sweight * self.gn(spat_half) + self.sbias)
        spat_out = spat_half * spat_gate

        # Re-join both halves on the channel axis and restore the batch axis.
        merged = torch.cat([chan_out, spat_out], dim=1).reshape(n, -1, h, w)

        return self.channel_shuffle(merged, 2)
    


In [None]:
class Overall_Arch(nn.Module):
    """Two-branch classifier: CNN and Swin features fused by shuffle attention.

    The outputs of ``CNN_Branch`` and ``Transformer_Branch`` are concatenated
    on the channel axis, passed through an ``sa_layer`` built for 1056
    channels, average-pooled with an 8x8 window, flattened, and classified by a
    ``Linear(1056, 7)`` layer followed by a softmax over the class axis.
    """

    def __init__(self):
        super().__init__()

        self.CNN_Branch = CNN_Branch()
        self.Transformer_Branch = Transformer_Branch()
        self.SA_Block = sa_layer(channel=1056)

        self.gap = nn.AvgPool2d(kernel_size=(8,8))
        # The head ends in a softmax, so forward() returns probabilities.
        classifier = nn.Linear(1056, 7)
        self.fc = nn.Sequential(classifier, nn.Softmax(dim=1))

    def forward(self, x):
        """Return the per-class probabilities for the image batch ``x``."""
        branch_features = [self.CNN_Branch(x), self.Transformer_Branch(x)]
        fused = self.SA_Block(torch.cat(branch_features, dim=1))

        pooled = torch.flatten(self.gap(fused), 1)
        return self.fc(pooled)

Model

In [None]:
# When a GPU is available, ask cuDNN to use deterministic algorithms.
# NOTE(review): no random seed is set anywhere in the visible cells.
if torch.cuda.is_available():
  torch.backends.cudnn.deterministic = True

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [None]:
# Instantiate the full two-branch architecture and move it to `device`.
# NOTE(review): the following cell rebinds `model` to the object loaded from
# the checkpoint, so this instance is discarded when that cell runs.
model = Overall_Arch().to(device)

In [None]:
# Load the latest saved model.
# The checkpoint holds the whole pickled model object (it is assigned straight
# to `model`), so:
#   * map_location=device remaps the stored tensors onto the device chosen
#     above; without it a checkpoint saved on a GPU cannot be loaded when the
#     notebook falls back to the CPU;
#   * weights_only=False is required on recent PyTorch releases, whose default
#     refuses to unpickle full model objects.
# SECURITY: unpickling can execute arbitrary code; only load checkpoints
# that come from a trusted source.
save_path = '/content/gdrive/MyDrive/Train-Test-Val/Epochn35'
model = torch.load(save_path, map_location=device, weights_only=False)

In [None]:
# Put the model in evaluation mode (the BatchNorm layers then use their
# running statistics). eval() returns the module, so the model structure is
# shown as this cell's output.
model.eval()

Testing

In [None]:
from sklearn.metrics import classification_report, accuracy_score
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
y_pred = []
y_true = []

# Make this cell safe to run on its own: evaluation mode is idempotent.
model.eval()

# Iterate over the test data. Gradients are not needed for inference, so
# autograd is disabled to avoid building a graph for every batch.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)

        output = model(images)  # Feed network; one row of class scores per image

        # The predicted class is the index of the largest score in each row.
        y_pred.extend(output.argmax(dim=1).cpu().tolist())  # Save prediction

        # Labels are only collected, so they never need to leave the CPU.
        y_true.extend(labels.tolist())  # Save truth

In [None]:
# Per-class and averaged precision / recall / F1 returned as a nested dict
# (output_dict=True). zero_division=0 reports 0 for metrics whose denominator
# is zero, e.g. the precision of a class that was never predicted.
r = classification_report(y_true, y_pred,zero_division=0,output_dict=True)

In [None]:
# Print the overall accuracy and the two averaged metric rows of the report.
for heading, report_key in (
    ('Accuracy - ', 'accuracy'),
    ('Weighted Average - ', 'weighted avg'),
    ('Macro Average - ', 'macro avg'),
):
    print(heading, r.get(report_key))

Accuracy -  0.7430207587687903
Weighted Average -  {'precision': 0.6072968325168648, 'recall': 0.7430207587687903, 'f1-score': 0.6671792962913379, 'support': 1397}
Macro Average -  {'precision': 0.3188943398935498, 'recall': 0.3911214377737827, 'f1-score': 0.3507525565623589, 'support': 1397}


In [None]:
### References:

### [1] Adeel H, “Focal Loss,” GitHub, https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py 
### [2] wofmanaf, “SA-Net,” GitHub, https://github.com/wofmanaf/SA-Net/blob/main/models/sa_resnet.py  
### [3] Microsoft, “Semantic Segmentation Pytorch,” GitHub, https://github.com/microsoft/AI-For-Beginners/blob/main/lessons/4-ComputerVision/12-Segmentation/SemanticSegmentationPytorch.ipynb 
### [4] Beijing Technology and Business University, “Fe-net,” GitHub, https://github.com/btbuIntelliSense/Fe-net/blob/main/Model/FENet.py
### [5] Shicai, “How to modify a conv2d to depthwise separable convolution?,” PyTorch Forums, https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7