In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder

Initializing Hyperparameters

In [53]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training configuration shared by the cells below.
epochs, batch_size = 3, 16   # passes over the training set / samples per batch
learning_rate = 1e-3         # step size handed to the optimizer
num_outputs = 4              # number of target classes

Resizing the images to 512×512, converting them to single-channel grayscale tensors, and loading them into the train and test loaders

In [54]:
# Preprocessing applied to every image, in this order:
# resize to 512x512 -> collapse to one grayscale channel -> convert to a tensor.
preprocessing_steps = [
    transforms.Resize((512, 512)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Both splits share the same preprocessing pipeline.
train_dataset = ImageFolder("./dataset/Training", transform=transform)
test_dataset = ImageFolder("./dataset/Testing", transform=transform)

# Only the training data is shuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


This is the model architecture I originally built, but it did not reach a high accuracy — only around 65–70%. While looking for alternatives I came across ResNet, so the notebook trains a ResNet instead of my own model, since that gave 90%+ accuracy.


In [55]:
class CNN(nn.Module):
    """Plain 15-layer convolutional classifier for single-channel images.

    Every convolution is 3x3 with padding 1, so only the pooling layers change
    the spatial size. Five 2x max-pools and one final 2x average-pool shrink
    each side by a factor of 64; for a 512x512 input this leaves a 64x8x8
    feature map, i.e. the 4096 features expected by ``fc1``.

    Args:
        num_outputs: Size of the final linear layer (one logit per class).
    """

    def __init__(self, num_outputs):
        super().__init__()

        # (in_channels, out_channels) for conv1 .. conv15, registered in order.
        conv_channels = [
            (1, 64), (64, 64), (64, 64), (64, 64), (64, 64),
            (64, 128), (128, 256), (256, 64),
            (64, 64), (64, 64), (64, 128),
            (128, 128), (128, 256),
            (256, 64), (64, 64),
        ]
        for idx, (c_in, c_out) in enumerate(conv_channels, start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))

        # Parameter-free down-sampling layers, reused throughout forward().
        self.max_pool = nn.MaxPool2d(2, 2)
        self.avg_pool = nn.AvgPool2d(2, 2)

        # Fully connected head: 4096 -> 2000 -> 1000 -> num_outputs.
        fc_sizes = [(4096, 2000), (2000, 1000), (1000, num_outputs)]
        for idx, (n_in, n_out) in enumerate(fc_sizes, start=1):
            setattr(self, f"fc{idx}", nn.Linear(n_in, n_out))

        # Channel width of batch_norm1 .. batch_norm7; each width matches the
        # convolution it follows in forward().
        norm_widths = [64, 64, 64, 256, 64, 256, 64]
        for idx, width in enumerate(norm_widths, start=1):
            setattr(self, f"batch_norm{idx}", nn.BatchNorm2d(width))

    def forward(self, x):
        """Return raw class scores (no softmax) for a batch of images ``x``."""
        # Stage 1: conv1 + norm, then halve the resolution.
        feat = self.max_pool(F.relu(self.batch_norm1(self.conv1(x))))

        # Stage 2: conv2-conv5 (norm after conv2 and conv4), then pool.
        feat = F.relu(self.batch_norm2(self.conv2(feat)))
        feat = F.relu(self.conv3(feat))
        feat = F.relu(self.batch_norm3(self.conv4(feat)))
        feat = self.max_pool(F.relu(self.conv5(feat)))

        # Stage 3: conv6-conv8 (norm after conv7), then pool.
        feat = F.relu(self.conv6(feat))
        feat = F.relu(self.batch_norm4(self.conv7(feat)))
        feat = self.max_pool(F.relu(self.conv8(feat)))

        # Stage 4: conv9-conv11 (norm after conv10), then pool.
        feat = F.relu(self.conv9(feat))
        feat = F.relu(self.batch_norm5(self.conv10(feat)))
        feat = self.max_pool(F.relu(self.conv11(feat)))

        # Stage 5: conv12-conv13 (norm after conv13), then pool.
        feat = F.relu(self.conv12(feat))
        feat = self.max_pool(F.relu(self.batch_norm6(self.conv13(feat))))

        # Stage 6: conv14-conv15 (norm after conv15), then average-pool.
        feat = F.relu(self.conv14(feat))
        feat = self.avg_pool(F.relu(self.batch_norm7(self.conv15(feat))))

        # Flatten everything except the batch dimension for the linear head.
        flat = feat.view(feat.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# The hand-written network above can be used instead via: CNN(num_outputs)
# Start from an ImageNet-pretrained ResNet-18 and adapt both ends of it.
model = torchvision.models.resnet18(pretrained=True)

# The images are converted to grayscale, so the first convolution must accept
# 1 input channel instead of 3 (RGB). This replacement layer is freshly
# initialised; the pretrained weights of the original RGB layer are not reused.
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Replace the classification head so the network emits one score per class.
# Reading in_features from the existing head avoids hard-coding its width.
model.fc = nn.Linear(model.fc.in_features, num_outputs)

# The training and evaluation loops move every batch to `device`, so the
# model's parameters must live there too; otherwise the forward pass fails
# with a device mismatch on CUDA machines. Done before building the optimizer
# so it is created over the parameters in their final location.
model = model.to(device)

cost = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

Training the Model


In [None]:
# Per-batch training loss across all epochs; plotted at the end of the notebook.
loss_calc = []

model.train()
for epoch in range(epochs):
    # Running total of the batch losses for this epoch, and how many batches
    # contributed, so a mean can be reported once the epoch ends.
    epoch_loss = 0.0
    num_batches = 0
    for i, (images, labels) in enumerate(train_loader):
        if i % 10 == 0: print(i)  # lightweight progress indicator
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = cost(outputs, labels)

        # Clear gradients left over from the previous step before computing new ones.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_calc.append(batch_loss)
        epoch_loss += batch_loss
        num_batches += 1

    # Guard against an empty loader so the summary never divides by zero.
    if num_batches:
        print(f"Epoch {epoch + 1}/{epochs} - mean loss: {epoch_loss / num_batches:.4f}")
print("Finished")

Evaluation is run inside `torch.no_grad()` because backpropagation and gradient calculations are not needed when only measuring accuracy

In [None]:
# Put the model in evaluation mode before measuring accuracy.
model.eval()

correct = 0  # number of test samples whose predicted class equals the label
total = 0    # number of test samples seen

# Gradients are not needed for evaluation, so skip tracking them.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        # Index of the highest score along the class dimension = predicted class.
        predicted = torch.max(outputs, 1)[1]

        total += labels.shape[0]
        correct += predicted.eq(labels).sum().item()

accuracy = 100.0 * correct / total
print(f'Accuracy on the test set: {accuracy:.2f}%')

In [None]:
PATH = './romit.pth'

# Persist only the learned parameters (the state_dict), not the whole module.
torch.save(model.state_dict(), PATH)

# A state_dict can only be loaded into a module with the same architecture as
# the one that produced it. The trained model is the adapted ResNet-18 (1-channel
# first convolution, `num_outputs`-way head), so rebuild exactly that here;
# the hand-written CNN has different layer names and shapes and cannot take
# these weights. No pretrained download is needed: every weight is overwritten.
net = torchvision.models.resnet18()
net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
net.fc = nn.Linear(net.fc.in_features, num_outputs)

# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
net.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))

Plotting Loss vs Iteration Graph

In [None]:
# Draw the recorded per-batch training loss on an explicit figure/axes pair.
fig, ax = plt.subplots()
ax.plot(loss_calc, label='Loss')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.set_title('Loss Iteration')
ax.legend()
plt.show()