In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets,transforms,models
import matplotlib.pyplot as plt
import time
import os
import copy
import torchmetrics
from torchmetrics.classification import Accuracy, Precision, Recall, F1Score

In [10]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [11]:
# Per-channel normalization statistics (the standard ImageNet values),
# consumed by transforms.Normalize below.
mean = np.array([0.485,0.456,0.406])
std = np.array([0.229,0.224,0.225])


def _build_transform():
    """Return the preprocessing pipeline used for every split.

    Resizes to 30x30, converts to a tensor and normalizes each of the three
    channels with ``mean``/``std``. There is no augmentation, so the 'train'
    and 'val' pipelines are identical (each split still gets its own object).
    """
    return transforms.Compose([
        transforms.Resize((30, 30)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


data_transforms = {split: _build_transform() for split in ('train', 'val')}

In [12]:
data_dir = 'train_dataset'
sets = ['train','val']

# One ImageFolder per split, read from <data_dir>/<split>/ with one
# sub-directory per class.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in sets
}

# Shuffle only the training split: validation order does not affect the
# metrics, and a fixed order keeps evaluation reproducible.
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                   shuffle=(x == 'train'), num_workers=4)
    for x in sets
}


In [13]:
# Number of images per split (used to average loss/accuracy per epoch).
dataset_sizes = dict((split, len(image_datasets[split])) for split in ('train', 'val'))
# Class names come from the training folder's sub-directory names.
class_names = image_datasets['train'].classes
print(class_names)

['1', '2', '3']


In [14]:
def _run_phase(model, criterion, optimizer, phase):
    """Run one pass over ``dataloaders[phase]`` and return (mean loss, accuracy).

    Gradients are tracked, and the optimizer stepped, only when ``phase == 'train'``.
    Uses the module-level ``dataloaders``, ``dataset_sizes`` and ``device``.
    """
    is_train = phase == 'train'
    # train(False) is equivalent to eval(): it switches BatchNorm to running stats.
    model.train(is_train)

    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Track history only in train.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # criterion returns a batch mean, so re-weight by batch size before averaging.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    return running_loss / dataset_sizes[phase], running_corrects / dataset_sizes[phase]


def _collect_predictions(model, loader):
    """Return (predicted labels, true labels, number of classes) over ``loader``.

    Tensors are concatenated on the CPU. Returns ``(None, None, None)`` when
    the loader yields no batches.
    """
    model.eval()
    all_preds, all_labels = [], []
    num_classes = None
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            num_classes = outputs.size(1)
            _, preds = torch.max(outputs, 1)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    if not all_preds:
        return None, None, None
    return torch.cat(all_preds), torch.cat(all_labels), num_classes


def _report_metrics(preds, targets, num_classes):
    """Print accuracy, macro precision/recall/F1 and the confusion matrix."""
    metrics = [
        ('Accuracy', Accuracy(task="multiclass", num_classes=num_classes)),
        ('Precision', Precision(task="multiclass", average='macro', num_classes=num_classes)),
        ('Recall', Recall(task="multiclass", average='macro', num_classes=num_classes)),
        ('F1', F1Score(task="multiclass", average='macro', num_classes=num_classes)),
    ]
    for name, metric in metrics:
        metric.update(preds, targets)
        print(f'{name}: {metric.compute():.4f}')

    cm = torchmetrics.functional.confusion_matrix(
        preds, targets, num_classes=num_classes, task="multiclass")
    print(f'Confusion Matrix: \n{cm}')


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it with the weights of its best validation epoch.

    Each epoch runs a training pass (followed by ``scheduler.step()``) and a
    validation pass over the module-level ``dataloaders``. Whenever validation
    accuracy improves, the weights are snapshotted. After training, those best
    weights are reloaded and the whole validation set is re-scored. The run
    then prints accuracy, macro precision/recall/F1 and the confusion matrix
    for the model that is actually returned.

    Args:
        model: network to train; expected to already be on ``device``.
        criterion: loss taking ``(logits, labels)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped once per epoch after the train phase.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-'*10)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(model, criterion, optimizer, phase)
            if phase == 'train':
                scheduler.step()

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Deep-copy the weights: state_dict() returns live references.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights, then score exactly the model being returned on
    # the full validation set (not just the last batch of each phase).
    model.load_state_dict(best_model_wts)
    preds, targets, num_classes = _collect_predictions(model, dataloaders['val'])
    if preds is None:
        print('No validation batches; skipping metrics.')
    else:
        _report_metrics(preds, targets, num_classes)
    return model

In [15]:
class ConvNet(nn.Module):
    """Small CNN classifier for 3-channel 30x30 images.

    Architecture: conv(3->12) + BN + ReLU -> 2x2 max-pool (30->15)
    -> conv(12->20) + ReLU -> conv(20->32) + BN + ReLU -> linear(32*15*15 -> num_classes).

    The attribute names are the ``state_dict`` keys used by saved checkpoints,
    so they must not be renamed.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=3):
        super().__init__()

        # Shapes are (N, C, H, W). Every conv uses kernel 3, stride 1, padding 1,
        # which preserves H and W: (w - f + 2p)/s + 1 = (30 - 3 + 2)/1 + 1 = 30.
        # (N, 3, 30, 30) -> (N, 12, 30, 30)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=12)
        self.relu1 = nn.ReLU()

        # (N, 12, 30, 30) -> (N, 12, 15, 15)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # (N, 12, 15, 15) -> (N, 20, 15, 15)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=20, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()

        # (N, 20, 15, 15) -> (N, 32, 15, 15)
        self.conv3 = nn.Conv2d(in_channels=20, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=32)
        self.relu3 = nn.ReLU()

        # (N, 32*15*15) -> (N, num_classes)
        self.fc = nn.Linear(in_features=15*15*32, out_features=num_classes)

    def forward(self, input):
        """Return logits of shape (N, num_classes) for an (N, 3, 30, 30) batch."""
        output = self.relu1(self.bn1(self.conv1(input)))
        output = self.maxpool1(output)
        output = self.relu2(self.conv2(output))
        output = self.relu3(self.bn3(self.conv3(output)))

        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, 15*15*32), a wrongly sized input now raises in fc instead
        # of silently producing a different number of rows than the batch.
        output = torch.flatten(output, 1)

        return self.fc(output)

In [16]:
# Load the weights of a previously trained ConvNet. The checkpoint's head has
# 3 outputs, so the model must be built with num_classes=3 for a strict load.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load('best_checkpoint.model', map_location=device)
pretrained_model = ConvNet(num_classes=3)
pretrained_model.load_state_dict(checkpoint)

# Freeze the pretrained feature extractor.
# NOTE(review): requires_grad=False does not freeze BatchNorm running stats;
# bn1/bn3 still update them whenever train_model puts the model in train mode.
for param in pretrained_model.parameters():
    param.requires_grad = False

# Exchange the last layer for a fresh, trainable head sized for this dataset.
num_ftrs = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(num_ftrs, len(class_names))
pretrained_model.to(device)

criterion = nn.CrossEntropyLoss()
# Only the new head has requires_grad=True, so optimize just its parameters.
optimizer = optim.SGD(pretrained_model.fc.parameters(), lr=0.001)

# Decay the learning rate by 10x every 7 epochs.
step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

model_ft = train_model(pretrained_model, criterion, optimizer, step_lr_scheduler, num_epochs=25)

Epoch 0/24
----------


train Loss: 1.0004 Acc: 0.5556
val Loss: 1.1024 Acc: 0.6667

Epoch 1/24
----------
train Loss: 0.5684 Acc: 0.8222
val Loss: 0.5087 Acc: 0.8000

Epoch 2/24
----------
train Loss: 0.3413 Acc: 0.9111
val Loss: 0.4808 Acc: 0.7333

Epoch 3/24
----------
train Loss: 0.2801 Acc: 0.9556
val Loss: 0.2670 Acc: 1.0000

Epoch 4/24
----------
train Loss: 0.1867 Acc: 0.9778
val Loss: 0.2505 Acc: 0.9333

Epoch 5/24
----------
train Loss: 0.1961 Acc: 1.0000
val Loss: 0.1937 Acc: 1.0000

Epoch 6/24
----------
train Loss: 0.1555 Acc: 0.9556
val Loss: 0.1962 Acc: 1.0000

Epoch 7/24
----------
train Loss: 0.1542 Acc: 0.9778
val Loss: 0.1896 Acc: 1.0000

Epoch 8/24
----------
train Loss: 0.1621 Acc: 0.9778
val Loss: 0.1756 Acc: 1.0000

Epoch 9/24
----------
train Loss: 0.1601 Acc: 0.9778
val Loss: 0.1595 Acc: 1.0000

Epoch 10/24
----------
train Loss: 0.1267 Acc: 1.0000
val Loss: 0.1611 Acc: 1.0000

Epoch 11/24
----------
train Loss: 0.1350 Acc: 0.9778
val Loss: 0.1558 Acc: 1.0000

Epoch 12/24
----------
t