# Making Predictions On A Single Image

In [None]:
import torch
import torchvision
import os
import cv2
import numpy as np
from torch import nn
from glob import glob
import timm
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Device on which the model and its input tensors are placed.
device = 'cpu'

In [None]:
# Collect every file stored one directory level below ./Data/images
# (i.e. ./Data/images/<folder>/<file>).
image_paths = glob('./Data/images/*/*')

In [None]:
# Build parallel lists holding, for every image path, the name of its parent
# directory (used as the class label) and its file name.
labels = []
images = []

for image_path in image_paths:
    # Take the last two path components instead of fixed split indices, so the
    # result does not depend on how deep the data root is or on which path
    # separator glob produced on the current platform.
    class_dir, image = os.path.split(image_path)
    label = os.path.basename(class_dir)
    labels.append(label)
    images.append(image)

In [None]:
# Display the distinct class labels collected from the image paths.
np.unique(labels)

In [None]:
# Build the label <-> class-index lookup tables.
labels = np.array(labels, dtype='str')

# np.unique returns the distinct labels in sorted order; a label's position in
# that array is its class index.
class_names = np.unique(labels)
n_classes = len(class_names)

# Derive both mappings from the labels actually present rather than from a
# hard-coded count, so every class gets an index however many there are.
label2pred = {name: index for index, name in enumerate(class_names)}
pred2label = dict(enumerate(class_names))

In [None]:
class ClassificationBase(nn.Module):
    """Base module providing training/validation step helpers for classifiers.

    Subclasses supply ``forward``. The helpers move each batch to the
    module-level ``device``, score it with cross-entropy, and measure it with
    the module-level ``accuracy`` function.
    """

    def _loss_and_accuracy(self, batch):
        """Run the model on ``batch`` and return ``(loss, accuracy)``."""
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits = self(inputs)
        return F.cross_entropy(logits, targets), accuracy(logits, targets)

    def training_step(self, batch):
        """Return the ``(loss, accuracy)`` pair for one training batch."""
        return self._loss_and_accuracy(batch)

    def validation_step(self, batch):
        """Return a dict with the detached loss and accuracy of one batch."""
        batch_loss, batch_acc = self._loss_and_accuracy(batch)
        return {'val_loss': batch_loss.detach(), 'val_acc': batch_acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch results from ``validation_step``.

        Returns a dict with the mean ``val_loss`` and ``val_acc`` as Python
        floats.
        """
        mean_loss = torch.stack([step['val_loss'] for step in outputs]).mean()
        mean_acc = torch.stack([step['val_acc'] for step in outputs]).mean()
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the validation loss and accuracy recorded for ``epoch``."""
        summary = "Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}\n".format(
            epoch, result['val_loss'], result['val_acc']
        )
        print(summary)
        

def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    ``outputs`` holds one row of class scores per sample and ``labels`` the
    matching class indices. The result is a scalar tensor.
    """
    preds = outputs.argmax(dim=1)
    n_correct = (preds == labels).sum().item()
    return torch.tensor(n_correct / preds.shape[0])

def evaluate(model, val_loader):
    """Run ``model`` over every batch of ``val_loader`` and aggregate the results.

    Args:
        model: Object exposing ``validation_step(batch)`` and
            ``validation_epoch_end(outputs)``.
        val_loader: Iterable yielding validation batches.

    Returns:
        Whatever ``model.validation_epoch_end`` returns for the list of
        per-batch results.
    """
    # Validation never back-propagates, so disable gradient tracking to avoid
    # building autograd graphs that would be thrown away immediately.
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

In [None]:
class EfficientNetB4(ClassificationBase):
    """EfficientNet-B4 backbone whose classifier head outputs ``n_classes`` scores."""

    def __init__(self):
        super().__init__()

        # "DEFAULT" requests the pretrained weights through the supported
        # string form; passing a boolean for `weights` is deprecated in
        # torchvision.
        self.network = torchvision.models.efficientnet_b4(weights="DEFAULT")

        # Read the input width of the stock head from the model itself rather
        # than hard-coding it, then swap in a head sized for this dataset.
        in_features = self.network.classifier[-1].in_features
        self.network.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, n_classes),
        )

    def forward(self, batch):
        """Move ``batch`` to the module-level ``device`` and return the class scores."""
        batch = batch.to(device)
        return self.network(batch)


model = EfficientNetB4()

In [None]:
# Place the model on the configured device.
model = model.to(device)

In [None]:
# Load the saved weights from 'EfficientNet-0.97.pth' (mapped onto the CPU) into the model.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load('EfficientNet-0.97.pth', map_location=torch.device('cpu')))
# Put the model in evaluation mode for inference.
model.eval()

In [None]:
# Read the sample image, convert it to RGB and display it.
image_file = './Data/images/point-flaky/359.jpg'
image = cv2.imread(image_file)
# cv2.imread signals a missing or unreadable file by returning None instead of
# raising, so fail here with a clear message rather than inside cvtColor.
if image is None:
    raise FileNotFoundError(f"Could not read image file: {image_file}")
# OpenCV loads pixels in BGR order; matplotlib expects RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image)

In [None]:
# Inference-time preprocessing: convert the image to a tensor, resize it to
# 400x400 and centre-crop to 380x380.
test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((400, 400)),
        torchvision.transforms.CenterCrop(380),
    ]
)

# Apply the pipeline and shape the result as a batch holding one
# 3 x 380 x 380 image.
image = test_transforms(image).reshape(1, 3, 380, 380)

In [None]:
# Classify the preprocessed image and display the predicted label.
model.eval()
# Inference needs no gradients, so skip autograd bookkeeping during the forward pass.
with torch.no_grad():
    logits = model(image)
# The batch holds a single image, so the arg-max over the class dimension
# yields one index, which pred2label maps back to the label name.
pred2label[logits.argmax(dim=1).item()]