# **Install libraries**

In [None]:
# Colab-only dependencies that are not preinstalled.
# TODO(review): pin versions (the install log below shows tensorboardX 2.6 was picked up);
# `evaluate` and `seqeval` do not appear to be used anywhere beyond `import evaluate`.
!pip install tensorboardX
!pip install -q evaluate seqeval

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorboardX
  Downloading tensorboardX-2.6-py2.py3-none-any.whl (114 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/21

# **Import libraries**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as utils
import torchvision
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from sklearn.metrics import accuracy_score
from sklearn import metrics
from datetime import datetime
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import cv2
import math
import os
import numpy as np
from tqdm.auto import tqdm
import evaluate
import pandas as pd

In [None]:
# Mount Google Drive at /gdrive; the dataset, checkpoint and test-image paths used later
# in this notebook all live under this mount point.
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


# **Build Attention ResNet**

## **Attention layers**

In [None]:
class ProjectorBlock(nn.Module):
    """Bias-free 1x1 convolution mapping ``in_features`` channels to ``out_features``.

    AttnResNet uses it to lift the 128-channel layer1 map to the 512 channels its
    attention heads expect. The layer lives in ``self.op`` so checkpoints keep the
    ``op.weight`` key.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.op = nn.Conv2d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        projected = self.op(x)
        return projected

class SpatialAttn(nn.Module):
    """Spatial attention that scores each position of a local map ``l`` against ``g``.

    A bias-free 1x1 conv (``self.op``) turns ``l + g`` into one logit per position.
    With ``normalize_attn=True`` the logits are softmax-normalised over all H*W
    positions and the attended vector is the weighted *sum* of ``l``; otherwise a
    sigmoid gate is applied and the gated map is spatially *averaged*.

    ``g`` must broadcast against ``l`` (AttnResNet passes a (N, C, 1, 1) descriptor).

    Returns:
        ``(c, g_out)``: the raw attention logits reshaped to (N, 1, H, W) and the
        attended feature vector of shape (N, C).
    """

    def __init__(self, in_features, normalize_attn=True):
        super().__init__()
        self.normalize_attn = normalize_attn
        self.op = nn.Conv2d(
            in_channels=in_features,
            out_channels=1,
            kernel_size=1,
            padding=0,
            bias=False,
        )

    def forward(self, l, g):
        batch, channels, height, width = l.size()
        logits = self.op(l + g)  # (batch, 1, H, W)
        if self.normalize_attn:
            flat_weights = F.softmax(logits.view(batch, 1, -1), dim=2)
            weights = flat_weights.view(batch, 1, height, width)
            weighted = weights.expand_as(l) * l
            pooled = weighted.view(batch, channels, -1).sum(dim=2)
        else:
            weights = torch.sigmoid(logits)
            weighted = weights.expand_as(l) * l
            pooled = F.adaptive_avg_pool2d(weighted, (1, 1)).view(batch, channels)
        return logits.view(batch, 1, height, width), pooled

## **Residual block**

In [None]:
class ResidualBlock(nn.Module):
    """Basic residual block: (conv3x3-BN-ReLU) -> (conv3x3-BN), add the skip path, ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the skip path when its
            shape differs from the main path; ``None`` uses the input unchanged.
    """

    def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
        super().__init__()
        # Attribute names and the layer order inside each Sequential are baked into the
        # state_dict keys (conv1.0.weight, conv1.1.running_mean, ...); keep them as-is.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        main_path = self.conv2(self.conv1(x))
        skip_path = x if self.downsample is None else self.downsample(x)
        return self.relu(main_path + skip_path)

## **Build the network**

In [None]:
class AttnResNet(nn.Module):
    """ResNet-style image classifier with optional spatial attention over three feature maps.

    Layout: 7x7 stride-2 conv + BN + ReLU, 3x3 stride-2 max-pool, then four residual
    stages with 64, 128, 512 and 512 channels (strides 1, 2, 2, 2). The last stage is
    pooled to a 1x1 global descriptor and passed through a 1x1 conv (``dense``) to give
    ``g``. With ``attention=True``, ``g`` scores every position of the layer1 map
    (projected 128 -> 512), the layer2 map and the layer3 map via ``SpatialAttn``. The
    three attended 512-d vectors are concatenated and classified. Otherwise ``g`` alone
    is classified.

    Args:
        sample_size: accepted for backward compatibility; not used.
        block: residual block class, called as ``block(in_ch, out_ch, stride, downsample)``
            for the first block of a stage and ``block(in_ch, out_ch)`` for the rest.
        layers: number of blocks in each of the four stages.
        num_classes: number of output classes.
        attention: enable the three attention heads.
        normalize_attn: forwarded to ``SpatialAttn`` (softmax vs. sigmoid attention).

    ``forward`` returns ``[logits, c1, c2, c3]``. ``logits`` has shape (N, num_classes).
    The ``c*`` entries are the raw (N, 1, H, W) attention maps of each head, or ``None``
    when attention is disabled.
    """

    def __init__(self, sample_size, block, layers, num_classes, attention=True, normalize_attn=True):
        super(AttnResNet, self).__init__()
        # conv blocks
        self.inplanes = 64
        self.conv1 = nn.Sequential(
                        nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
                        nn.BatchNorm2d(64),
                        nn.ReLU())
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer2 = self._make_layer(block, 512, layers[2], stride = 2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride = 2)
        # Global pooling must produce a 1x1 map: g is broadcast against whole feature
        # maps in the attention heads. The former AvgPool2d(5, stride=2) did so only when
        # layer3 was 5x5 or 6x6 (inputs of 129-192 px), and at 6x6 it ignored the last row
        # and column. For the 150x150 inputs used here both give the same result. The pool
        # has no parameters, so existing checkpoints load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dense = nn.Conv2d(in_channels=512, out_channels=512, kernel_size= 1, padding=0, bias=True)
        # attention blocks
        self.attention = attention
        if self.attention:
            self.projector = ProjectorBlock(128, 512)  # layer1 has 128 channels
            self.attn1 = SpatialAttn(in_features=512, normalize_attn=normalize_attn)
            self.attn2 = SpatialAttn(in_features=512, normalize_attn=normalize_attn)
            self.attn3 = SpatialAttn(in_features=512, normalize_attn=normalize_attn)
        # final classification layer: three concatenated 512-d attention outputs, or g alone
        if self.attention:
            self.classify = nn.Linear(in_features=512*3, out_features=num_classes, bias=True)
        else:
            self.classify = nn.Linear(in_features=512, out_features=num_classes, bias=True)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks producing ``planes`` channels.

        The first block receives ``stride`` and a 1x1 conv + BN shortcut whenever the
        stride or channel count changes. ``self.inplanes`` is advanced to ``planes``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)

        l1 = self.layer1(x)   # (N, 128, H1, W1)
        l2 = self.layer2(l1)  # (N, 512, H2, W2)
        l3 = self.layer3(l2)  # (N, 512, H3, W3)

        g = self.dense(self.avgpool(l3))  # (N, 512, 1, 1) global descriptor
        if self.attention:
            c1, g1 = self.attn1(self.projector(l1), g)
            c2, g2 = self.attn2(l2, g)
            c3, g3 = self.attn3(l3, g)
            # (N, 3*512) -> (N, num_classes)
            x = self.classify(torch.cat((g1, g2, g3), dim=1))
        else:
            c1, c2, c3 = None, None, None
            # flatten(1) keeps the batch axis; torch.squeeze(g) also removed it for a
            # batch of one, yielding logits of shape (num_classes,) instead of (1, num_classes).
            x = self.classify(torch.flatten(g, 1))
        return [x, c1, c2, c3]

## **Train / Val epochs**

In [None]:
def train_epoch(model, criterion, optimizer, dataloader, device, epoch, log_interval, writer):
    """Run one training epoch, print progress and log epoch loss/accuracy.

    Args:
        model: network to train. Its forward may return a logits tensor, or a list whose
            first element is the logits (as AttnResNet does).
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to.
        epoch: zero-based epoch index; printed and logged as ``epoch + 1``.
        log_interval: print the current batch's loss/accuracy every ``log_interval`` batches.
        writer: object with ``add_scalars(tag, values, step)`` (e.g. a SummaryWriter).

    Returns:
        ``(training_loss, training_acc)``: the mean of the per-batch losses and the
        accuracy over every sample seen this epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    losses = []
    all_label = []
    all_pred = []

    for batch_idx, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        # reshape(-1) rather than squeeze(): squeeze() turns a batch of one into a 0-d
        # tensor, which breaks the loss (target batch size 0) and label collection.
        labels = labels.reshape(-1)

        optimizer.zero_grad()
        outputs = model(inputs)
        if isinstance(outputs, list):
            outputs = outputs[0]

        loss = criterion(outputs, labels)
        losses.append(loss.item())

        prediction = outputs.argmax(dim=1)
        # Store whole batches on the CPU instead of one GPU tensor per sample.
        all_label.append(labels.detach().cpu())
        all_pred.append(prediction.detach().cpu())

        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % log_interval == 0:
            score = (all_pred[-1] == all_label[-1]).sum().item() / all_label[-1].numel()
            print("epoch {:3d} | iteration {:5d} | Loss {:.6f} | Acc {:.2f}%".format(epoch+1, batch_idx+1, loss.item(), score*100))

    if not losses:
        raise ValueError("dataloader yielded no batches")

    # Average loss over batches & accuracy over all samples (same value as sklearn's accuracy_score).
    training_loss = sum(losses) / len(losses)
    labels_cat = torch.cat(all_label)
    preds_cat = torch.cat(all_pred)
    training_acc = (preds_cat == labels_cat).sum().item() / labels_cat.numel()
    # Log
    writer.add_scalars('Loss', {'train': training_loss}, epoch+1)
    writer.add_scalars('Accuracy', {'train': training_acc}, epoch+1)
    print("Average Training Loss of Epoch {}: {:.6f} | Acc: {:.2f}%".format(epoch+1, training_loss, training_acc*100))
    return training_loss, training_acc


def val_epoch(model, criterion, dataloader, device, epoch, writer):
    """Evaluate the model on a validation loader and log loss/accuracy for ``epoch``.

    Args:
        model: network to evaluate. Its forward may return a logits tensor, or a list
            whose first element is the logits.
        criterion: loss called as ``criterion(logits, labels)``.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to.
        epoch: zero-based epoch index; logged as ``epoch + 1``.
        writer: object with ``add_scalars(tag, values, step)`` (e.g. a SummaryWriter).

    Returns:
        ``(val_loss, val_acc)``: the mean of the per-batch losses and the accuracy over
        all validation samples.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    losses = []
    all_label = []
    all_pred = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # reshape(-1) rather than squeeze(): a batch of one must stay 1-D.
            labels = labels.reshape(-1)
            outputs = model(inputs)
            if isinstance(outputs, list):
                outputs = outputs[0]
            losses.append(criterion(outputs, labels).item())
            # Store whole batches on the CPU instead of one GPU tensor per sample.
            all_label.append(labels.cpu())
            all_pred.append(outputs.argmax(dim=1).cpu())

    if not losses:
        raise ValueError("dataloader yielded no batches")

    # Average loss over batches & accuracy over all samples (same value as sklearn's accuracy_score).
    val_loss = sum(losses) / len(losses)
    labels_cat = torch.cat(all_label)
    preds_cat = torch.cat(all_pred)
    val_acc = (preds_cat == labels_cat).sum().item() / labels_cat.numel()
    # Log
    writer.add_scalars('Loss', {'val': val_loss}, epoch+1)
    writer.add_scalars('Accuracy', {'val': val_acc}, epoch+1)
    print("Average Validation Loss: {:.6f} | Acc: {:.2f}%".format(val_loss, val_acc*100))
    return val_loss, val_acc

# **Load the dataset**

In [None]:
#train and test data directory
data_dir = "/gdrive/MyDrive/data/train_resnet/train/"
test_data_dir = "/gdrive/MyDrive/data/train_resnet/test/"

#load the train and test data with augmentation
dataset = ImageFolder(data_dir,transform = transforms.Compose([
    transforms.Resize((150,150)),
    transforms.ToTensor(),

]))
test_dataset = ImageFolder(test_data_dir,transforms.Compose([
    transforms.Resize((150,150)),
    transforms.ToTensor()
]))

In [None]:
batch_size = 7
val_size = 500
train_size = len(dataset) - val_size
assert train_size > 0, "dataset has no more than val_size images; lower val_size"

# Seed the split so the same images land in the validation set on every run; an
# unseeded split silently changes train/val membership each time the notebook is re-run.
split_generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(dataset, [train_size, val_size], generator=split_generator)
print(f"Length of Train Data : {len(train_data)}")
print(f"Length of Validation Data : {len(val_data)}")

Length of Train Data : 5063
Length of Validation Data : 500


In [None]:
#load the train and validation into batches.
# Batch the splits; only training data is shuffled, evaluation uses twice the batch size.
loader_options = dict(num_workers=4, pin_memory=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, **loader_options)
val_loader = DataLoader(val_data, batch_size=batch_size * 2, **loader_options)
test_loader = DataLoader(test_dataset, batch_size=batch_size * 2, **loader_options)



In [None]:
# switch to cuda
# Run on CUDA when it is available, otherwise on the CPU; the bare name displays the choice.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# **Let's train!**

In [None]:
# Create model
# Create the attention ResNet: [3, 4, 6, 3] residual blocks per stage, 3 output classes.
# NOTE(review): `sample_size` is accepted but never used by AttnResNet.__init__.
# NOTE(review): once wrapped in nn.DataParallel, model.state_dict() keys gain a "module."
# prefix, so checkpoints saved from `model` will not load into a plain AttnResNet unless
# the wrapper is unwrapped (model.module) or the prefix is stripped.
model = AttnResNet(sample_size=32, block = ResidualBlock, layers = [3, 4, 6, 3], num_classes=3).to(device)
if torch.cuda.device_count() > 1:
    print("Using {} GPUs".format(torch.cuda.device_count()))
    model = nn.DataParallel(model)

In [None]:
# TensorBoard writer; each run logs to its own timestamped directory under runs/.
writer = SummaryWriter("runs/cnn_attention_{:%Y-%m-%d_%H-%M-%S}".format(datetime.now()))

In [None]:
# Training hyper-parameters and checkpoint location.
num_epochs = 20  # NOTE(review): the saved training log only shows 10 epochs; confirm which value was run
lr = 1e-4  # Adam learning rate
no_save = False  # True disables the per-epoch checkpoints
log_interval = 100  # print batch loss/accuracy every this many iterations
weight_decay = 1e-4  # Adam weight decay
save_path = "/gdrive/MyDrive/saved_models/resNet/"  # directory for per-epoch checkpoints

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

if not no_save:
    # torch.save does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

for epoch in range(num_epochs):
    train_epoch(model, criterion, optimizer, train_loader, device, epoch, log_interval, writer)
    val_epoch(model, criterion, val_loader, device, epoch, writer)
    if not no_save:
        # Unwrap nn.DataParallel so the checkpoint keys carry no "module." prefix and
        # load directly into a plain AttnResNet.
        to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(to_save.state_dict(), os.path.join(save_path, "cnn_epoch{:03d}.pth".format(epoch+1)))
        print("Saving Model of Epoch {}".format(epoch+1))



epoch   1 | iteration   100 | Loss 0.771114 | Acc 71.43%
epoch   1 | iteration   200 | Loss 0.864530 | Acc 42.86%
epoch   1 | iteration   300 | Loss 0.089553 | Acc 100.00%
epoch   1 | iteration   400 | Loss 0.107847 | Acc 100.00%
epoch   1 | iteration   500 | Loss 0.584193 | Acc 71.43%
epoch   1 | iteration   600 | Loss 0.305859 | Acc 71.43%
epoch   1 | iteration   700 | Loss 0.055511 | Acc 100.00%
Average Training Loss of Epoch 1: 0.526812 | Acc: 79.08%




Average Validation Loss: 0.272510 | Acc: 88.60%
Saving Model of Epoch 1




epoch   2 | iteration   100 | Loss 0.120884 | Acc 100.00%
epoch   2 | iteration   200 | Loss 0.509735 | Acc 71.43%
epoch   2 | iteration   300 | Loss 0.229015 | Acc 85.71%
epoch   2 | iteration   400 | Loss 0.240785 | Acc 85.71%
epoch   2 | iteration   500 | Loss 0.760544 | Acc 71.43%
epoch   2 | iteration   600 | Loss 0.346308 | Acc 85.71%
epoch   2 | iteration   700 | Loss 0.010384 | Acc 100.00%
Average Training Loss of Epoch 2: 0.260325 | Acc: 89.75%




Average Validation Loss: 0.213171 | Acc: 89.80%
Saving Model of Epoch 2




epoch   3 | iteration   100 | Loss 0.035274 | Acc 100.00%
epoch   3 | iteration   200 | Loss 0.133436 | Acc 100.00%
epoch   3 | iteration   300 | Loss 0.109736 | Acc 100.00%
epoch   3 | iteration   400 | Loss 0.085499 | Acc 100.00%
epoch   3 | iteration   500 | Loss 0.261006 | Acc 85.71%
epoch   3 | iteration   600 | Loss 0.018720 | Acc 100.00%
epoch   3 | iteration   700 | Loss 0.078668 | Acc 100.00%
Average Training Loss of Epoch 3: 0.198429 | Acc: 92.22%




Average Validation Loss: 0.212705 | Acc: 93.80%
Saving Model of Epoch 3




epoch   4 | iteration   100 | Loss 0.025829 | Acc 100.00%
epoch   4 | iteration   200 | Loss 0.319995 | Acc 85.71%
epoch   4 | iteration   300 | Loss 0.165207 | Acc 85.71%
epoch   4 | iteration   400 | Loss 0.154185 | Acc 85.71%
epoch   4 | iteration   500 | Loss 0.130651 | Acc 85.71%
epoch   4 | iteration   600 | Loss 0.019386 | Acc 100.00%
epoch   4 | iteration   700 | Loss 0.018602 | Acc 100.00%
Average Training Loss of Epoch 4: 0.172361 | Acc: 93.66%




Average Validation Loss: 0.362194 | Acc: 85.20%
Saving Model of Epoch 4




epoch   5 | iteration   100 | Loss 0.014172 | Acc 100.00%
epoch   5 | iteration   200 | Loss 0.024898 | Acc 100.00%
epoch   5 | iteration   300 | Loss 0.042838 | Acc 100.00%
epoch   5 | iteration   400 | Loss 0.057089 | Acc 100.00%
epoch   5 | iteration   500 | Loss 0.108253 | Acc 100.00%
epoch   5 | iteration   600 | Loss 0.010850 | Acc 100.00%
epoch   5 | iteration   700 | Loss 0.015313 | Acc 100.00%
Average Training Loss of Epoch 5: 0.143764 | Acc: 94.84%




Average Validation Loss: 0.154902 | Acc: 94.20%
Saving Model of Epoch 5




epoch   6 | iteration   100 | Loss 0.009737 | Acc 100.00%
epoch   6 | iteration   200 | Loss 0.027866 | Acc 100.00%
epoch   6 | iteration   300 | Loss 0.027989 | Acc 100.00%
epoch   6 | iteration   400 | Loss 0.022718 | Acc 100.00%
epoch   6 | iteration   500 | Loss 0.298004 | Acc 85.71%
epoch   6 | iteration   600 | Loss 0.007175 | Acc 100.00%
epoch   6 | iteration   700 | Loss 0.001348 | Acc 100.00%
Average Training Loss of Epoch 6: 0.120098 | Acc: 95.56%




Average Validation Loss: 1.518639 | Acc: 65.80%
Saving Model of Epoch 6




epoch   7 | iteration   100 | Loss 0.052919 | Acc 100.00%
epoch   7 | iteration   200 | Loss 0.041425 | Acc 100.00%
epoch   7 | iteration   300 | Loss 0.013586 | Acc 100.00%
epoch   7 | iteration   400 | Loss 0.705341 | Acc 85.71%
epoch   7 | iteration   500 | Loss 0.050600 | Acc 100.00%
epoch   7 | iteration   600 | Loss 0.005113 | Acc 100.00%
epoch   7 | iteration   700 | Loss 0.078827 | Acc 100.00%
Average Training Loss of Epoch 7: 0.127966 | Acc: 95.34%




Average Validation Loss: 0.122462 | Acc: 95.80%
Saving Model of Epoch 7




epoch   8 | iteration   100 | Loss 0.009550 | Acc 100.00%
epoch   8 | iteration   200 | Loss 0.004395 | Acc 100.00%
epoch   8 | iteration   300 | Loss 0.166238 | Acc 85.71%
epoch   8 | iteration   400 | Loss 0.022646 | Acc 100.00%
epoch   8 | iteration   500 | Loss 0.004125 | Acc 100.00%
epoch   8 | iteration   600 | Loss 0.225101 | Acc 85.71%
epoch   8 | iteration   700 | Loss 0.338227 | Acc 85.71%
Average Training Loss of Epoch 8: 0.097778 | Acc: 96.31%




Average Validation Loss: 0.138703 | Acc: 96.20%
Saving Model of Epoch 8




epoch   9 | iteration   100 | Loss 0.008712 | Acc 100.00%
epoch   9 | iteration   200 | Loss 0.003710 | Acc 100.00%
epoch   9 | iteration   300 | Loss 0.028333 | Acc 100.00%
epoch   9 | iteration   400 | Loss 0.003344 | Acc 100.00%
epoch   9 | iteration   500 | Loss 0.135930 | Acc 85.71%
epoch   9 | iteration   600 | Loss 0.045928 | Acc 100.00%
epoch   9 | iteration   700 | Loss 0.118588 | Acc 100.00%
Average Training Loss of Epoch 9: 0.098993 | Acc: 96.17%




Average Validation Loss: 0.084155 | Acc: 97.40%
Saving Model of Epoch 9




epoch  10 | iteration   100 | Loss 0.465709 | Acc 85.71%
epoch  10 | iteration   200 | Loss 0.020143 | Acc 100.00%
epoch  10 | iteration   300 | Loss 0.007004 | Acc 100.00%
epoch  10 | iteration   400 | Loss 0.000446 | Acc 100.00%
epoch  10 | iteration   500 | Loss 0.003053 | Acc 100.00%
epoch  10 | iteration   600 | Loss 0.006609 | Acc 100.00%
epoch  10 | iteration   700 | Loss 0.011141 | Acc 100.00%
Average Training Loss of Epoch 10: 0.076892 | Acc: 97.29%




Average Validation Loss: 0.351149 | Acc: 89.60%
Saving Model of Epoch 10


## **Save and reload the model**

In [None]:
# Save the trained weights. nn.DataParallel (used when more than one GPU is visible) is
# unwrapped so the keys have no "module." prefix and load straight into a plain AttnResNet.
final_model_path = "/gdrive/MyDrive/saved_models/resNet/resNet_model_state.pth"
os.makedirs(os.path.dirname(final_model_path), exist_ok=True)
torch.save((model.module if isinstance(model, nn.DataParallel) else model).state_dict(), final_model_path)

In [None]:
# Rebuild the network and load the saved weights for inference.
modelB = AttnResNet(sample_size=32, block = ResidualBlock, layers = [3, 4, 6, 3], num_classes=3).to(device)
checkpoint_state = torch.load("/gdrive/MyDrive/saved_models/resNet/resNet_model_state.pth", map_location=device)
# Strip the "module." prefix that nn.DataParallel adds, so multi-GPU checkpoints load too.
checkpoint_state = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in checkpoint_state.items()
}
# strict loading: with strict=False a complete key mismatch would silently leave the
# network at its random initialisation.
load_result = modelB.load_state_dict(checkpoint_state)
# Inference mode: BatchNorm must use its running statistics, not per-image statistics.
modelB.eval()
load_result

<All keys matched successfully>

# **Evaluation**

In [None]:
# Evaluation setup.
# NOTE(review): of these names only `criterion` and `writer` are passed to evaluate_model
# in the next cells (and evaluate_model does not use `writer`). num_epochs, lr, no_save,
# log_interval, weight_decay and the optimizer look like leftovers from the training
# configuration. Re-creating the SummaryWriter presumably starts a new, empty run directory.
criterion = nn.CrossEntropyLoss()
num_epochs = 10
lr = 1e-4
no_save = False
log_interval = 100
weight_decay = 1e-4
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
writer = SummaryWriter("runs/cnn_attention_{:%Y-%m-%d_%H-%M-%S}".format(datetime.now()))

In [None]:
def evaluate_model(model, criterion, dataloader, device,  writer,
                   target_names=("handwiriting", "other", "table")):
    """Evaluate ``model`` on ``dataloader``, print the average loss and a sklearn classification report.

    Args:
        model: network to evaluate. Its forward may return a logits tensor, or a list
            whose first element is the logits.
        criterion: loss called as ``criterion(logits, labels)``; its batch average is printed.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to.
        writer: accepted for call compatibility; not used.
        target_names: display name for class index 0, 1, 2, ... in the report.
            NOTE(review): the default list is hard-coded (and "handwiriting" is misspelled);
            prefer passing ``dataset.classes``.

    Returns:
        The classification report as a dict (``output_dict=True``), keyed by class name
        plus "accuracy", "macro avg" and "weighted avg".

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    losses = []
    all_label = []
    all_pred = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # reshape(-1) rather than squeeze(): a batch of one must stay 1-D. The old
            # squeeze() failure was hidden by a bare `except: pass` that silently dropped
            # whole batches from the report.
            labels = labels.reshape(-1)
            outputs = model(inputs)
            if isinstance(outputs, list):
                outputs = outputs[0]
            losses.append(criterion(outputs, labels).item())
            all_label.append(labels.cpu())
            all_pred.append(outputs.argmax(dim=1).cpu())

    if not losses:
        raise ValueError("dataloader yielded no batches")

    y_true = torch.cat(all_label).numpy()
    y_pred = torch.cat(all_pred).numpy()
    names = list(target_names)
    # Pin the label set so the report still has one row per name when a class never
    # appears in y_true/y_pred (otherwise sklearn rejects the target_names length).
    report_kwargs = dict(labels=list(range(len(names))), target_names=names)

    print("Average Test Loss: {:.6f}".format(sum(losses) / len(losses)))
    # Classification report
    print('\n\n\t\tCLASSIFICATIION METRICS\n')
    print(metrics.classification_report(y_true, y_pred, **report_kwargs))
    return metrics.classification_report(y_true, y_pred, output_dict=True, **report_kwargs)

In [None]:
# Evaluate on the held-out test set and keep the per-class report as a DataFrame.
# NOTE(review): this evaluates the in-memory `model`, not the reloaded `modelB`.
report = evaluate_model(model, criterion, test_loader, device, writer)
df = pd.DataFrame(report).transpose()





		CLASSIFICATIION METRICS

              precision    recall  f1-score   support

handwiriting       1.00      0.95      0.97       424
       other       0.80      1.00      0.89       465
       table       0.99      0.81      0.89       502

    accuracy                           0.91      1391
   macro avg       0.93      0.92      0.92      1391
weighted avg       0.93      0.91      0.91      1391



In [None]:
# Display the classification report as a table.
# NOTE(review): the saved table (support 1390, accuracy 0.97) disagrees with the report
# printed by the previous cell (support 1391, accuracy 0.91). The two outputs appear to come
# from different runs; re-run both cells before trusting either.
print('\t\tCLASSIFICATIION METRICS\n\tCNN with attention mechanism for pipeline ')
print('______________________________________________________')
df

		CLASSIFICATIION METRICS
	CNN with attention mechanism for pipeline 
______________________________________________________


Unnamed: 0,precision,recall,f1-score,support
handwiriting,1.0,0.990566,0.995261,424.0
other,0.965293,0.956989,0.961123,465.0
table,0.954813,0.97006,0.962376,501.0
accuracy,0.971942,0.971942,0.971942,0.971942
macro avg,0.973369,0.972538,0.97292,1390.0
weighted avg,0.972103,0.971942,0.971988,1390.0


# **Inference**

In [None]:
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to ``device``.

    Sequences come back as lists with the same nesting; copies are requested with
    ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

In [None]:
def predict_img_class(img, model, classes=None, target_device=None):
    """Predict the class name of a single image tensor.

    Args:
        img: one image tensor, expected as (C, H, W); a batch axis is added here.
        model: classifier whose forward returns logits of shape (1, num_classes), or a
            list/tuple whose first element is those logits (as AttnResNet does).
            It is switched to eval mode.
        classes: sequence mapping class index -> name. Defaults to the global
            ``dataset.classes`` from training.
        target_device: device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has none).

    Returns:
        The predicted class name.
    """
    if classes is None:
        classes = dataset.classes
    if target_device is None:
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else torch.device('cpu')

    # eval(): BatchNorm must use its running statistics rather than the statistics of
    # this one image (and must not update them). no_grad(): no autograd graph for inference.
    model.eval()
    with torch.no_grad():
        outputs = model(img.unsqueeze(0).to(target_device))
    if isinstance(outputs, (list, tuple)):
        outputs = outputs[0]
    return classes[outputs.argmax(dim=1)[0].item()]

In [None]:
from PIL import Image

# Path of the single image to classify; it must exist on the mounted Drive.
test_image_path = "/gdrive/MyDrive/test_images/handwriting5.jpg"

# Preprocess the way the training data was: convert to RGB first (as torchvision's
# ImageFolder loader does), then apply the same Resize((150, 150)) + ToTensor transform.
# PIL's Image.resize defaults to a different resampling filter than transforms.Resize
# (bilinear), so resizing with PIL would feed the network different pixels than in training.
with Image.open(test_image_path) as pil_img:
    rgb_img = pil_img.convert('RGB')
img = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
])(rgb_img)

# Show the resized image (CHW -> HWC for matplotlib).
plt.imshow(img.permute(1, 2, 0))

# Predict the image label.
print(f"Predicted Class : {predict_img_class(img, modelB)}")

FileNotFoundError: ignored