In [150]:
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision 
from torchvision import transforms

In [151]:
# Preprocessing applied to every image: resize to a fixed 256x256 resolution,
# convert to a tensor, then normalize each of the three channels with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [152]:
#get data
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder

# The images live in the 'crop' folder inside the current working directory.
base_dir = os.path.join(os.getcwd(), 'crop')

# Build the labelled image dataset rooted at that folder; `transform` is the
# preprocessing pipeline passed through to the dataset.
dataset = ImageFolder(root=base_dir, transform=transform)


In [153]:
# Split the dataset 80/20 into training and test subsets.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Seed the split so the same samples land in train/test on every
# "Restart & Run All"; without a fixed generator the membership (and therefore
# the reported accuracy) changes from run to run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
# The evaluation only sums correct predictions over the whole split, so the
# order of test samples is irrelevant and shuffling is unnecessary.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [154]:
# Sanity check: pull the first training sample and show its tensor shape and label index.
image, label = train_dataset[0]
print(image.shape, label)

torch.Size([3, 256, 256]) 25


In [155]:
# Human-readable aircraft class names (74 entries).
# NOTE(review): presumably ordered to match the label indices produced by
# `dataset` — verify against `dataset.classes` before using it to decode predictions.
class_names = [
    "A-10", "A-400M", "AG-600", "AH-64", "AV-8B", "An-124", "An-22", "An-225", "An-72",
    "B-1", "B-2", "B-21", "B-52", "Be-200",
    "C-130", "C-17", "C-2", "C-390", "C-5", "CH-47", "CL-415",
    "E-2", "E-7", "EF-2000",
    "F-117", "F-14", "F-15", "F-16", "F-22", "F-35", "F-4", "F/A-18",
    "H-6",
    "J-10", "J-20", "JAS-39", "JF-17", "JH-7",
    "KC-135", "KF-21", "KJ-600", "Ka-27", "Ka-52",
    "MQ-9", "Mi-24", "Mi-26", "Mi-28", "Mig-29", "Mig-31", "Mirage2000",
    "P-3",
    "RQ-4", "Rafale",
    "SR-71", "Su-24", "Su-25", "Su-34", "Su-57",
    "TB-001", "TB-2", "Tornado", "Tu-160", "Tu-22M", "Tu-95",
    "U-2", "UH-60", "US-2",
    "V-22", "Vulcan",
    "WZ-7",
    "XB-70",
    "Y-20", "YF-23",
    "Z-19",
]

In [156]:
class NeuralNet(nn.Module):
    """Small two-convolution CNN classifier for 3-channel 256x256 images.

    Architecture: conv(5x5) -> ReLU -> max-pool(2) -> conv(5x5) -> ReLU ->
    max-pool(2) -> three fully connected layers. The first fully connected
    layer is sized for a 16 x 61 x 61 feature map, which is what a
    3 x 256 x 256 input produces after the two conv/pool stages.

    Args:
        num_classes: Number of output logits. Defaults to 74, the number of
            classes this notebook trains on.
    """

    def __init__(self, num_classes=74):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 3 x 256 x 256 -> 6 x 252 x 252
        self.pool = nn.MaxPool2d(2, 2)  # halves height and width (252 -> 126, 122 -> 61)
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6 x 126 x 126 -> 16 x 122 x 122
        self.fc1 = nn.Linear(16 * 61 * 61, 120)  # consumes the flattened 16 x 61 x 61 map
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)  # one logit per class

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for a batch of images.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not yield the
                16 x 61 x 61 feature map expected by ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # `view(-1, 16 * 61 * 61)`, this can never silently fold surplus
        # features into extra "batch" rows when the input has the wrong size;
        # a mismatch surfaces as a shape error in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [157]:
# Cross-entropy loss for the multi-class classification objective.
loss_fn = nn.CrossEntropyLoss()

# Fresh, untrained network.
net = NeuralNet()

# SGD with momentum over all network parameters
# (learning rate was increased from 0.001 to 0.01).
optimizer = optim.SGD(params=net.parameters(), lr=0.01, momentum=0.9)

In [158]:
# Use the GPU when one is available, otherwise fall back to the CPU, and move
# the model's parameters onto that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

# Explicitly select training mode. The evaluation cell calls net.eval(), so
# re-running this cell afterwards would otherwise keep training in eval mode.
# (The current layers behave the same in both modes, but this keeps the cell
# correct if dropout/batch-norm layers are ever added.)
net.train()

for epoch in range(10):
    print(f'Training epoch {epoch + 1}')
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses over this epoch.
    print(f'Loss: {running_loss / len(train_loader): .4f}')

Training epoch 1
Loss:  4.0135
Training epoch 2
Loss:  3.7932
Training epoch 3
Loss:  3.4933
Training epoch 4
Loss:  3.0291
Training epoch 5
Loss:  2.4722
Training epoch 6
Loss:  1.8325
Training epoch 7
Loss:  1.2630
Training epoch 8
Loss:  0.8843
Training epoch 9
Loss:  0.6821
Training epoch 10
Loss:  0.5454


In [159]:
# Evaluate the network on the test split and report top-1 accuracy.
correct = 0
total = 0

# Switch the network to evaluation mode.
net.eval()

# No gradients are needed for evaluation.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # Predicted class = index of the largest logit per sample. Using argmax
        # avoids the legacy `.data` access and avoids assigning to `_`, which
        # would shadow IPython's last-output variable in a notebook.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty test split instead of raising ZeroDivisionError.
accuracy = correct / total if total else 0.0
# The original cell exposed the result under the misspelled name `accruacy`;
# keep it as an alias so any later cell that reads it keeps working.
accruacy = accuracy
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 25.68%
