In [None]:
import torch
import pandas as pd
import numpy as np
import os
from torch import nn
from pathlib import Path
from torchvision import transforms
import torchvision
import matplotlib.pyplot as plt

# Pick the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

In [None]:
# Hyperparameters and paths shared by the cells below.
BATCH_SIZE = 64
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to
# 0 (work in the main process) so the value is always a valid worker count.
NUM_WORKERS = os.cpu_count() or 0
SEED = 42
EPOCHS = 100
# Images are resized to IMG_SIZE x IMG_SIZE; TinyVGG's classifier size depends on it.
IMG_SIZE = 128
LR = 3e-4
# TODO: set to the saved TinyVGG state_dict file before running the inference cell.
PATH = ""

In [None]:
class TinyVGG(nn.Module):
    """
    Model architecture copying TinyVGG from:
    https://poloclub.github.io/cnn-explainer/

    Two conv blocks (conv-ReLU-conv-ReLU-maxpool) followed by a linear classifier.

    Args:
        input_shape: number of input channels.
        hidden_units: number of channels in every conv layer.
        output_shape: number of output logits (classes).
        img_size: side length of the square input images. Defaults to the
            notebook-level IMG_SIZE when omitted.

    Input to forward is expected as (N, input_shape, img_size, img_size); the batch
    dimension is required because Flatten keeps dim 0 and flattens the rest.
    """
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int,
                 img_size: "int | None" = None) -> None:
        super().__init__()
        if img_size is None:
            img_size = IMG_SIZE
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape, 
                      out_channels=hidden_units, 
                      kernel_size=3, 
                      stride=1, 
                      padding=1), 
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units, 
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,
                         stride=2) 
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # The padded 3x3 convs keep the spatial size; each of the two 2x2 max-pools
        # floors it by half, leaving (img_size // 4) x (img_size // 4) positions.
        # (For img_size=128 this equals the previous hidden_units*128*8, so existing
        # checkpoints still load.)
        pooled_side = img_size // 4
        self.classifier = nn.Sequential( 
            nn.Flatten(),
            nn.Linear(in_features=hidden_units * pooled_side * pooled_side,
                      out_features=output_shape)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits of shape (N, output_shape)."""
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.classifier(x)
        return x

In [None]:
# Single-image inference with a trained TinyVGG checkpoint.
#
# TinyVGG requires explicit sizes, so rebuild it from the checkpoint itself: the
# hidden width and the number of classes are read from the saved weight shapes,
# which guarantees load_state_dict sees matching layers.
# NOTE(review): assumes PATH holds a state_dict saved from the TinyVGG class above.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
state_dict = torch.load(PATH, map_location=device, weights_only=True)
model = TinyVGG(
    input_shape=3,  # RGB input, matching the 3-channel Normalize statistics below
    hidden_units=state_dict["conv_block_1.0.weight"].shape[0],
    output_shape=state_dict["classifier.1.weight"].shape[0],
)

model.to(device)

model.load_state_dict(state_dict)

model.eval()

# Deterministic preprocessing only: the random flip / TrivialAugmentWide used for
# training would make the prediction change from run to run.
# read_image already returns a uint8 tensor, which ToTensor rejects (it only accepts
# PIL images / ndarrays); ConvertImageDtype converts to float and scales to [0, 1].
data_transforms = transforms.Compose([
    transforms.Resize([IMG_SIZE, IMG_SIZE]),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img_path = Path("data/anger/IR_IR_anger_cuonga1_1730.jpg")

# Force 3 channels so grayscale or RGBA files still match the model's input_shape.
img_tensor = torchvision.io.read_image(
    str(img_path), mode=torchvision.io.ImageReadMode.RGB
)

img_after_transforms = data_transforms(img_tensor)

with torch.inference_mode():

    # The classifier's Flatten needs a batch dimension, (C, H, W) -> (1, C, H, W),
    # and the input must be on the same device as the model.
    batch = img_after_transforms.unsqueeze(0).to(device)

    pred_logits = model(batch)

    pred_labels = pred_logits.argmax(dim=1)

    print(pred_labels)

# Show the original (un-normalized) image; permute CHW -> HWC for matplotlib.
fig, ax = plt.subplots()
ax.imshow(img_tensor.permute(1, 2, 0).numpy())
ax.set_title(f"Predicted class index: {pred_labels.item()}")
ax.axis("off");