In [11]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import glob
import PIL
from PIL import Image
from torch.utils import data as D
from torch.utils.data.sampler import SubsetRandomSampler
import random

In [12]:
batch_size = 32
validation_ratio = 0.1
random_seed = 10

In [13]:
transform_train = transforms.Compose([
        transforms.Resize(299),
        transforms.RandomCrop(299, padding=38),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

transform_validation = transforms.Compose([
        transforms.Resize(299),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


transform_test = transforms.Compose([
        transforms.Resize(299),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

validset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_validation)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

#trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
#                                          shuffle=True, num_workers=0)

num_train = len(trainset)
indices = list(range(num_train))
split = int(np.floor(validation_ratio * num_train))

np.random.seed(random_seed)
np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, sampler=train_sampler, num_workers=0
)

valid_loader = torch.utils.data.DataLoader(
    validset, batch_size=batch_size, sampler=valid_sampler, num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=0
)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

initial_lr = 0.045

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## Reduced Dataset

In [14]:
import torchvision.transforms as transforms
import torch

# included_classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog']
included_classes = ['plane', 'car', 'bird', 'cat', 'deer']


transform_train = transforms.Compose([
        transforms.Resize(299),
        transforms.RandomCrop(299, padding=38),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

transform_validation = transforms.Compose([
        transforms.Resize(299),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


transform_test = transforms.Compose([
        transforms.Resize(299),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainset = torch.utils.data.Subset(trainset, [idx for idx in range(len(trainset)) if trainset.targets[idx] < len(included_classes)])

validset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_validation)
validset = torch.utils.data.Subset(validset, [idx for idx in range(len(validset)) if validset.targets[idx] < len(included_classes)])

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testset = torch.utils.data.Subset(testset, [idx for idx in range(len(testset)) if testset.targets[idx] < len(included_classes)])

num_train = len(trainset)
indices = list(range(num_train))
split = int(np.floor(validation_ratio * num_train))

np.random.seed(random_seed)
np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, sampler=train_sampler, num_workers=0
)

valid_loader = torch.utils.data.DataLoader(
    validset, batch_size=batch_size, sampler=valid_sampler, num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=0
)

classes = ('plane', 'car')

initial_lr = 0.045

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout, kernel_size, padding, bias=False):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=kernel_size, padding=padding, groups=nin, bias=bias)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

class Xception(nn.Module):
    def __init__(self, input_channel, num_classes=10):
        super(Xception, self).__init__()
        
        # Entry flow with explicit skip connections
        self.entry_flow_1 = nn.Sequential(
            nn.Conv2d(input_channel, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Entry flow 2 with residual connection
        self.entry_flow_2 = nn.Sequential(
            depthwise_separable_conv(64, 128, 3, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            depthwise_separable_conv(128, 128, 3, 1),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.entry_flow_2_residual = nn.Conv2d(64, 128, kernel_size=1, stride=2, padding=0)
        
        # Entry flow 3 with residual connection
        self.entry_flow_3 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(128, 256, 3, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            depthwise_separable_conv(256, 256, 3, 1),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.entry_flow_3_residual = nn.Conv2d(128, 256, kernel_size=1, stride=2, padding=0)
        
        # Entry flow 4 with residual connection
        self.entry_flow_4 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(256, 728, 3, 1),
            nn.BatchNorm2d(728),
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.entry_flow_4_residual = nn.Conv2d(256, 728, kernel_size=1, stride=2, padding=0)
        
        # Middle flow repeated 8 times with residual connections
        self.middle_flow = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728)
        )
        
        # Exit flow with explicit skip connections
        self.exit_flow_1 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            nn.ReLU(True),
            depthwise_separable_conv(728, 1024, 3, 1),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.exit_flow_1_residual = nn.Conv2d(728, 1024, kernel_size=1, stride=2, padding=0)
        self.exit_flow_2 = nn.Sequential(
            depthwise_separable_conv(1024, 1536, 3, 1),
            nn.BatchNorm2d(1536),
            nn.ReLU(True),
            depthwise_separable_conv(1536, 2048, 3, 1),
            nn.BatchNorm2d(2048),
            nn.ReLU(True)
        )
        
        self.linear = nn.Linear(2048, num_classes)
        
    def forward(self, x):
        # Apply each block, adding skip connections as defined
        x = self.entry_flow_1(x)
        x = self.entry_flow_2(x) + self.entry_flow_2_residual(x)
        x = self.entry_flow_3(x) + self.entry_flow_3_residual(x)
        x = self.entry_flow_4(x) + self.entry_flow_4_residual(x)
        
        # Middle flow iterations
        middle_out = x
        for _ in range(8):
            middle_out = self.middle_flow(middle_out) + middle_out

        x = self.exit_flow_1(middle_out) + self.exit_flow_1_residual(middle_out)
        x = self.exit_flow_2(x)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        output = self.linear(x)
        
        return output


In [16]:
class Xception(nn.Module):
    def __init__(self, input_channel, num_classes=10):
        super(Xception, self).__init__()
        
        # Entry Flow
        self.entry_flow_1 = nn.Sequential(
            nn.Conv2d(input_channel, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        
        self.entry_flow_2 = nn.Sequential(
            depthwise_separable_conv(64, 128, 3, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            depthwise_separable_conv(128, 128, 3, 1),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        self.entry_flow_2_residual = nn.Conv2d(64, 128, kernel_size=1, stride=2, padding=0)
        
        self.entry_flow_3 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(128, 256, 3, 1),
            nn.BatchNorm2d(256),
            
            nn.ReLU(True),
            depthwise_separable_conv(256, 256, 3, 1),
            nn.BatchNorm2d(256),
            
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        self.entry_flow_3_residual = nn.Conv2d(128, 256, kernel_size=1, stride=2, padding=0)
        
        self.entry_flow_4 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(256, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        self.entry_flow_4_residual = nn.Conv2d(256, 728, kernel_size=1, stride=2, padding=0)
        
        # Middle Flow
        self.middle_flow = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728)
        )
        
        # Exit Flow
        self.exit_flow_1 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.ReLU(True),
            depthwise_separable_conv(728, 1024, 3, 1),
            nn.BatchNorm2d(1024),
            
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.exit_flow_1_residual = nn.Conv2d(728, 1024, kernel_size=1, stride=2, padding=0)
        self.exit_flow_2 = nn.Sequential(
            depthwise_separable_conv(1024, 1536, 3, 1),
            nn.BatchNorm2d(1536),
            nn.ReLU(True),
            
            depthwise_separable_conv(1536, 2048, 3, 1),
            nn.BatchNorm2d(2048),
            nn.ReLU(True)
        )
        
        self.linear = nn.Linear(2048, num_classes)
        
    def forward(self, x):
        entry_out1 = self.entry_flow_1(x)
        entry_out2 = self.entry_flow_2(entry_out1) + self.entry_flow_2_residual(entry_out1)
        entry_out3 = self.entry_flow_3(entry_out2) + self.entry_flow_3_residual(entry_out2)
        entry_out = self.entry_flow_4(entry_out3) + self.entry_flow_4_residual(entry_out3)
        
        middle_out = self.middle_flow(entry_out) + entry_out
        
        for i in range(7):
          middle_out = self.middle_flow(middle_out) + middle_out

        exit_out1 = self.exit_flow_1(middle_out) + self.exit_flow_1_residual(middle_out)
        exit_out2 = self.exit_flow_2(exit_out1)

        exit_avg_pool = F.adaptive_avg_pool2d(exit_out2, (1, 1))                
        exit_avg_pool_flat = exit_avg_pool.view(exit_avg_pool.size(0), -1)

        output = self.linear(exit_avg_pool_flat)
        
        return output

In [17]:
net = Xception(3, 5) #ResNet-18

## Original Xception

In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [19]:
net.to(device)

Xception(
  (entry_flow_1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (entry_flow_2): Sequential(
    (0): depthwise_separable_conv(
      (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
      (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): depthwise_separable_conv(
      (depthwise): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
      (pointwise): Conv2d(128, 128, 

In [20]:
from tqdm import tqdm

# Define the number of epochs
num_epochs = 20

# Outer loop over epochs
for epoch in range(num_epochs):
    # Logic for adjusting learning rate every 2 epochs
    if epoch == 0:
        lr = initial_lr
    elif epoch % 2 == 0 and epoch != 0:
        lr *= 0.94
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

    # Initialize the running loss
    running_loss = 0.0

    # Wrap train_loader with tqdm to create a progress bar
    with tqdm(train_loader, unit='batch') as t:
        # Inner loop over batches
        for i, data in enumerate(t):
            # Get the inputs and labels
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update the running loss
            running_loss += loss.item()

            # Update the progress bar description
            t.set_description(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / (i + 1):.4f}')

    # Validation part
    correct = 0
    total = 0
    with torch.no_grad():
        for i, data in enumerate(valid_loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Print accuracy after each epoch
    print('[%d epoch] Accuracy of the network on the validation images: %d %%' %
          (epoch, 100 * correct / total))

# Print a message when finished training
print('Finished Training')


Epoch 1/20, Loss: 1.6352: 100%|██████████| 704/704 [06:55<00:00,  1.70batch/s]


[0 epoch] Accuracy of the network on the validation images: 20 %


Epoch 2/20, Loss: 1.6352: 100%|██████████| 704/704 [06:54<00:00,  1.70batch/s]


[1 epoch] Accuracy of the network on the validation images: 20 %


Epoch 3/20, Loss: 1.2508: 100%|██████████| 704/704 [06:54<00:00,  1.70batch/s]


[2 epoch] Accuracy of the network on the validation images: 48 %


Epoch 4/20, Loss: 0.9558: 100%|██████████| 704/704 [06:54<00:00,  1.70batch/s]


[3 epoch] Accuracy of the network on the validation images: 60 %


Epoch 5/20, Loss: 0.7844: 100%|██████████| 704/704 [06:54<00:00,  1.70batch/s]


[4 epoch] Accuracy of the network on the validation images: 70 %


Epoch 6/20, Loss: 0.6533: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[5 epoch] Accuracy of the network on the validation images: 74 %


Epoch 7/20, Loss: 0.5405: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[6 epoch] Accuracy of the network on the validation images: 77 %


Epoch 8/20, Loss: 0.4786: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[7 epoch] Accuracy of the network on the validation images: 80 %


Epoch 9/20, Loss: 0.3990: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[8 epoch] Accuracy of the network on the validation images: 84 %


Epoch 10/20, Loss: 0.3641: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[9 epoch] Accuracy of the network on the validation images: 81 %


Epoch 11/20, Loss: 0.3157: 100%|██████████| 704/704 [06:54<00:00,  1.70batch/s]


[10 epoch] Accuracy of the network on the validation images: 84 %


Epoch 12/20, Loss: 0.2919: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[11 epoch] Accuracy of the network on the validation images: 85 %


Epoch 13/20, Loss: 0.2589: 100%|██████████| 704/704 [06:54<00:00,  1.70batch/s]


[12 epoch] Accuracy of the network on the validation images: 85 %


Epoch 14/20, Loss: 0.2397: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[13 epoch] Accuracy of the network on the validation images: 85 %


Epoch 15/20, Loss: 0.2144: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[14 epoch] Accuracy of the network on the validation images: 84 %


Epoch 16/20, Loss: 0.2017: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[15 epoch] Accuracy of the network on the validation images: 86 %


Epoch 17/20, Loss: 0.1798: 100%|██████████| 704/704 [06:54<00:00,  1.70batch/s]


[16 epoch] Accuracy of the network on the validation images: 88 %


Epoch 18/20, Loss: 0.1696: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[17 epoch] Accuracy of the network on the validation images: 88 %


Epoch 19/20, Loss: 0.1561: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[18 epoch] Accuracy of the network on the validation images: 88 %


Epoch 20/20, Loss: 0.1474: 100%|██████████| 704/704 [06:53<00:00,  1.70batch/s]


[19 epoch] Accuracy of the network on the validation images: 89 %
Finished Training


In [21]:
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the test images: 89 %


In [23]:
import torch
import time
from sklearn.metrics import precision_score, recall_score, f1_score
correct = 0
total = 0
all_labels = []
all_predictions = []

# To measure inference time
start_time = time.time()

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass to get outputs
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)

        # Collect all labels and predictions for precision, recall, F1 score calculation
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

        # Calculating correct predictions
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Calculate inference time
inference_time = time.time() - start_time

# Convert lists to numpy arrays for metric calculation
all_labels = np.array(all_labels)
all_predictions = np.array(all_predictions)

# Calculate precision, recall, and F1 score
precision = precision_score(all_labels, all_predictions, average='macro')
recall = recall_score(all_labels, all_predictions, average='macro')
f1 = f1_score(all_labels, all_predictions, average='macro')

print('Accuracy of the network on the test images: {:.2f} %'.format(100 * correct / total))
print('Precision: {:.4f}'.format(precision))
print('Recall: {:.4f}'.format(recall))
print('F1 Score: {:.4f}'.format(f1))
print('Inference Time: {:.2f} seconds'.format(inference_time))

Accuracy of the network on the test images: 89.42 %
Precision: 0.8953
Recall: 0.8942
F1 Score: 0.8943
Inference Time: 33.75 seconds


In [22]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
                
        for i in range(labels.shape[0]):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of plane : 90 %
Accuracy of   car : 96 %


IndexError: tuple index out of range