In [1]:
import cv2
import numpy as np
import os
import pandas as pd
import pickle
from tqdm import tqdm

import torch
from torch.utils.data import Dataset
from torch import nn
from torch.nn import functional as F

from torchvision.models import resnet

import pytorch_lightning as pl

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pickled label set; its keys are used as the class labels throughout this notebook.
UCODE_DICT = '../NomDataset/HWDB1.1-bitmap64-ucode-hannom-v2-tst_seen-label-set-ucode.pkl'

# Dataset

## NomImageDataset - For loading raw-cropped images

In [43]:
# Dataset class for inputting YoloV5
class NomImageDataset(Dataset):
    """Dataset of pre-cropped character images listed in an annotation file.

    Each non-empty line of ``annotation_file`` has the form ``image_name,label``.
    Labels are mapped to integer class ids using the *sorted* keys of the
    pickled dictionary at ``unicode_dict_path``; labels that are not in that
    dictionary are mapped to the id of the ``'UNK'`` key.

    Args:
        image_dir (str): Directory that contains the images.
        annotation_file (str): Path to the ``image_name,label`` text file.
        unicode_dict_path (str): Path to the pickled label dictionary.
        transform (callable, optional): Applied to the image returned by
            ``cv2.imread``. When omitted, the image is scaled to [0, 1] and
            converted to a float ``(C, H, W)`` tensor.

    Raises:
        KeyError: If a label is unknown and the dictionary has no ``'UNK'`` key.
        ValueError: If a non-empty annotation line does not have exactly two
            comma-separated fields.
    """

    def __init__(self, image_dir, annotation_file, unicode_dict_path, transform=None):
        self.root_dir = image_dir
        self.label_list = list()
        self.image_list = list()
        self.transform = transform
        self.n_crop = 0

        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(unicode_dict_path, 'rb') as f:
            label_set = pickle.load(f)
        # Class ids follow the sorted order of the label keys.
        self.unicode_dict = {key: idx for idx, key in enumerate(sorted(label_set.keys()))}

        with open(annotation_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                image_name, label = line.split(',')
                label = label.strip()

                self.image_list.append(os.path.join(self.root_dir, image_name))
                if label in self.unicode_dict:
                    self.label_list.append(self.unicode_dict[label])
                else:
                    # Unknown labels fall back to the dedicated 'UNK' class.
                    self.label_list.append(self.unicode_dict['UNK'])

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        image_path = self.image_list[idx]
        x_image = cv2.imread(image_path)
        if x_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        y_label = self.label_list[idx]

        if self.transform:
            x_image = self.transform(x_image)
        else:
            x_image = x_image * 1.0 / 255
            x_image = torch.from_numpy(x_image).permute(2, 0, 1).float()
        y_label = torch.tensor(y_label, dtype=torch.long)
        return x_image, y_label


# opt = dict(
#     image_dir = '../TempResources/ToK1871/Tok1871_raw_crops',
#     annotation_file = '../TempResources/ToK1871/ToK1871_crops.txt',
#     unicode_dict_path = '../NomDataset/HWDB1.1-bitmap64-ucode-hannom-v2-tst_seen-label-set-ucode.pkl',
#     transform = None,
# )
# dataset = NomImageDataset(**opt)

# label_dict = dict()
# with open(opt['unicode_dict_path'], 'rb') as f:
#     tmp = pickle.load(f)
# for idx, (k, v) in enumerate(tmp.items()):
#     label_dict[idx] = k
    


# from matplotlib import pyplot as plt
# img = dataset[2][0].permute(1, 2, 0).numpy()
# label = dataset[2][1]

# plt.imshow(img)
# plt.show()
# print(label.item())
# print(label_dict[label.item()])
# print(chr(int(label_dict[label.item()], 16)))



## YoloCropDataset
YOLO inference produces new crops that do not have labels. This class exists solely to assign labels to those crops, by matching each detected box to the annotated box with the highest IoU.

In [3]:
import pybboxes as pybbx

class YoloCropDataset(Dataset):
    """Dataset of character crops cut from page images using YOLO detections.

    For every page image the YOLO ``.txt`` detections are converted to pixel
    boxes, the crop is cut from the page, and the crop is labelled with the
    annotation (from the page's ``.xlsx`` file) whose box has the highest IoU
    with the detection. Detections that overlap no annotation, or whose
    annotation label is not in the label dictionary, get the ``'UNK'`` class.

    Images, detection files and label files are paired by position after
    sorting each directory listing by file name, so the three directories are
    expected to use matching file stems.

    Args:
        image_file_path (str): Directory with the page images (``.jpg``).
        annotation_file_path (str): Directory with YOLO ``.txt`` detections.
        label_file_path (str): Directory with ``.xlsx`` annotations that have
            ``LEFT``, ``TOP``, ``RIGHT``, ``BOTTOM`` and ``UNICODE`` columns.
        unicode_dict_path (str): Pickled label dictionary; class ids follow
            the dictionary's iteration order.
        image_size (tuple[int, int]): Target crop size passed to ``cv2.resize``.
        transform (callable, optional): Applied to the raw crop instead of the
            default resize + normalisation.
        scale (float): Annotation coordinates are floor-divided by this value.
    """

    def __init__(self, image_file_path : str, annotation_file_path : str, label_file_path : str, unicode_dict_path : str, image_size : tuple[int, int], transform = None, scale = 1.0):
        self.image_file_path = image_file_path
        self.annotation_file_path = annotation_file_path
        self.label_file_path = label_file_path
        self.unicode_dict_path = unicode_dict_path
        self.image_size = image_size    # Target crop image size
        self.scale = scale

        self.image_files = []
        self.annotation_files = []
        self.label_files = []

        self.transform = transform
        self.load_files_list()

        # Only 'crops' and 'labels' are filled by load_crops(); the other keys are kept for compatibility.
        self.crop_dict = {'crops': [], 'original_images_name': [], 'labels': [], 'unicode_labels': []}
        self.load_crops()

    @staticmethod
    def _list_files(directory, extension):
        """Return the names in ``directory`` ending with ``extension``, sorted by name."""
        # os.listdir returns entries in arbitrary order; sorting keeps the three listings aligned.
        return sorted(name for name in os.listdir(directory) if name.endswith(extension))

    @staticmethod
    def _iou(box1, box2):
        """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes."""
        x1, y1, x2, y2 = box1
        x3, y3, x4, y4 = box2
        x5, y5 = max(x1, x3), max(y1, y3)
        x6, y6 = min(x2, x4), min(y2, y4)
        intersection = max(0, x6 - x5) * max(0, y6 - y5)
        area1 = (x2 - x1) * (y2 - y1)
        area2 = (x4 - x3) * (y4 - y3)
        union = area1 + area2 - intersection
        if union <= 0:
            # Degenerate (zero-area) boxes cannot overlap anything.
            return 0.0
        return intersection / union

    @staticmethod
    def _best_match(ref_box, boxes):
        """Return ``(best_iou, best_box, best_index)`` for ``ref_box`` against ``boxes``.

        When no box overlaps ``ref_box`` the result is ``(0, None, -1)``.
        """
        best_iou = 0
        best_box = None
        best_index = -1
        for index, box in enumerate(boxes):
            iou = YoloCropDataset._iou(ref_box, box)
            if iou > best_iou:
                best_iou = iou
                best_box = box
                best_index = index
        return best_iou, best_box, best_index

    @staticmethod
    def _resolve_label(match_index, labels, unicode_labels):
        """Map the matched annotation to a class id, falling back to ``'UNK'``."""
        # A negative index means "no match"; it must not be used to index the
        # list, because labels[-1] would silently pick the last annotation.
        if match_index >= 0:
            class_id = unicode_labels.get(labels[match_index])
            if class_id is not None:
                return class_id
        return unicode_labels['UNK']

    def load_files_list(self) -> None:
        """Collect image, detection and label file names and check the counts agree."""
        self.image_files = self._list_files(self.image_file_path, '.jpg')
        self.annotation_files = self._list_files(self.annotation_file_path, '.txt')
        assert len(self.image_files) == len(self.annotation_files), "Number of image files and annotation files do not match"

        self.label_files = self._list_files(self.label_file_path, '.xlsx')
        assert len(self.image_files) == len(self.label_files), f"Number of image files and label files do not match. {len(self.image_files)} != {len(self.label_files)}"

    def load_crops(self) -> None:
        """Cut every detected box out of its page and store it with its class id.

        Raises:
            FileNotFoundError: If a page image cannot be read.
            KeyError: If a crop needs the ``'UNK'`` class and the label
                dictionary has no such key.
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(self.unicode_dict_path, 'rb') as f:
            label_set = pickle.load(f)
        # Class ids follow the dictionary's own iteration order.
        unicode_labels = {key: i for i, key in enumerate(label_set)}

        total_n = len(self.image_files)
        pages = zip(self.image_files, self.annotation_files, self.label_files)
        # zip() has no length, so the total is passed explicitly for the progress bar.
        for image_file, txt_file, excel_file in tqdm(pages, total=total_n):
            image_path = os.path.join(self.image_file_path, image_file)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f'Could not read image: {image_path}')
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)    # OpenCV loads BGR; keep crops in RGB
            h, w, _ = image.shape
            df = pd.read_excel(os.path.join(self.label_file_path, excel_file))

            # NOTE(review): 'UNICODE' cells are used as dictionary keys as-is; confirm pandas reads them as the same strings the label dictionary uses.
            label_dict = {'boxes': [], 'labels': []}
            for _, row in df.iterrows():
                x1, y1, x2, y2 = row['LEFT'], row['TOP'], row['RIGHT'], row['BOTTOM']
                label = row['UNICODE']

                x1, y1, x2, y2 = x1 // self.scale, y1 // self.scale, x2 // self.scale, y2 // self.scale

                label_dict['boxes'].append((x1, y1, x2, y2))
                label_dict['labels'].append(label)

            with open(os.path.join(self.annotation_file_path, txt_file), 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue
                    # Fields are: class, x_center, y_center, width, height (any extra fields are ignored).
                    _, x, y, b_w, b_h = map(float, fields[:5])
                    bbox = pybbx.YoloBoundingBox(x, y, b_w, b_h, image_size=(w, h)).to_voc(return_values=True)
                    x1, y1, x2, y2 = bbox

                    # Label the crop with the annotation that overlaps it the most.
                    _, _, idx = self._best_match(bbox, label_dict['boxes'])

                    crop_img = image[int(y1):int(y2), int(x1):int(x2)]
                    self.crop_dict['crops'].append(crop_img)
                    self.crop_dict['labels'].append(
                        self._resolve_label(idx, label_dict['labels'], unicode_labels)
                    )

        assert len(self.crop_dict['crops']) == len(self.crop_dict['labels']), "Number of crops and labels do not match"

    def __len__(self) -> int:
        """Return the number of crops."""
        return len(self.crop_dict['crops'])

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, label)`` for crop ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if not -len(self) <= index < len(self):
            raise IndexError("Index out of range")

        image = self.crop_dict['crops'][index]
        label = self.crop_dict['labels'][index]

        if self.transform:
            image = self.transform(image)
        else:
            # Resize the crop to the configured target size
            image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_LANCZOS4)
            image = image * 1.0 / 255

            # TODO: This is the mean and std of ImageNet dataset, need to change to the mean and std of the dataset
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]

            # mean = [0.799, 0.818, 0.829]
            # std = [0.183, 0.179, 0.179]

            image = (image - mean) / std
            image = torch.from_numpy(image).permute(2, 0, 1).float()
        label = torch.tensor(label, dtype=torch.long)

        return image, label
    
# opt = dict(
#     image_file_path = '../NomDataset/datasets/mono-domain-datasets/tale-of-kieu/1871/1871-raw-images',
#     annotation_file_path = YOLO_ANNOTATION,
#     label_file_path = '../NomDataset/datasets/mono-domain-datasets/tale-of-kieu/1871/1871-annotation/annotation-mynom',
#     unicode_dict_path = '../NomDataset/HWDB1.1-bitmap64-ucode-hannom-v2-tst-label-set-ucode.pkl',
#     image_size = (224, 224),
#     transform = None,
# )

# dataset = YoloCropDataset(**opt)

In [3]:
# img = dataset[3][0].permute(1, 2, 0).numpy()
# label = dataset[3][1].item()
# from matplotlib import pyplot as plt
# cv2.imwrite('test.jpg', img * 255)

# new_unicode_dict = dict()
# with open('../NomDataset/HWDB1.1-bitmap64-ucode-hannom-v2-tst-label-set-ucode.pkl', 'rb') as f:
#     unicode_dict = pickle.load(f)
# for idx, (k, v) in enumerate(unicode_dict.items()):
#     new_unicode_dict[idx] = k
# print(new_unicode_dict[label])
# print(chr(int(new_unicode_dict[label], 16)))


# Architectures

## Detector : YoloV5

In [3]:
# from yolov5.models.common import DetectMultiBackend
# from yolov5.utils.general import non_max_suppression, scale_coords, check_img_size, Profile, increment_path
# from yolov5.utils.dataloaders import LoadImages

# from pathlib import Path

# args = {
#     'weights': '../Backup/pretrained_model/yolov5_Nom.pt',
#     'source': '../NomDataset/datasets/mono-domain-datasets/tale-of-kieu/1871/1871-raw-images',
#     'project': 'runs/detect',
#     'name': 'exp',
#     'imgsz': (640, 640),
#     'conf_thres': 0.5,
#     'iou_thres': 0.5,   
#     'device': '',       # Let YOLO decide
#     'save_txt': True,
#     'save_crop': True,  # Save cropped prediction boxes, for debugging
#     'exist_ok': True,
#     'hide_labels': True,    # Hide labels from output images
#     'hide_conf': True,      # Hide confidence, these two ommited for better visualization
# }

# # Directories
# save_dir = increment_path(Path(args['project']) / args['name'], exist_ok=args['exist_ok'])  # increment run
# (save_dir / 'labels' if args['save_txt'] else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

# # Load Model
# model = DetectMultiBackend(weights=args['weights'], device=DEVICE, dnn=False, data=None, fp16=False)
# strides, names, pt = model.strides, model.names, model.pt
# imgsz = check_img_size((640, 640), s=strides)  # check img_size

# # Dataloader
# bs = 1
# dataset = LoadImages(
#     source = '../NomDataset/datasets/mono-domain-datasets/tale-of-kieu/1871/1871-raw-images',
#     img_size = imgsz,
#     stride = strides,
#     auto = pt,
#     vid_stride=1,
# )

# # Run inference
# model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
# seen, windows, dt = 0, [], (Profile(device=DEVICE), Profile(device=DEVICE), Profile(device=DEVICE))

# for path, im, im0s, vid_cap, s in dataset:
#     with dt[0]:
#         img = torch.from_numpy(im).to(DEVICE)
#         im = im.half() if model.fp16 else im.float()
#         im /= 255.0
#         if len(im.shape) == 3:
#             im = im[None]
#         if model.xml and im.shape[0] > 1:
#                 ims = torch.chunk(im, im.shape[0], 0)

#         # Inference
#         with dt[1]:
#             visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
#             if model.xml and im.shape[0] > 1:
#                 pred = None
#                 for image in ims:
#                     if pred is None:
#                         pred = model(image, augment=False, visualize=visualize).unsqueeze(0)
#                     else:
#                         pred = torch.cat((pred, model(image, augment=False, visualize=visualize).unsqueeze(0)), dim=0)
#                 pred = [pred, None]
#             else:
#                 pred = model(im, augment=False, visualize=visualize)
#         # NMS
#         with dt[2]:
#             pred = non_max_suppression(pred, args['conf_thres'], args['iou_thres'], None, False, max_det=1000)



PROJECT = './detect'
EXP_NAME = 'yolo_tok1871'
YOLO_RESULTS = os.path.join(PROJECT, EXP_NAME)
YOLO_WEIGHTS = '../Backup/pretrained_model/yolov5_Nom.pt'
SOURCE_DIR = '../TempResources/ToK1871/ToK1871_LRcubicx4'
YOLO_CROPS = YOLO_RESULTS + '/crops/nom_char' # Yolo will save the cropped images here
YOLO_ANNOTATION = YOLO_RESULTS + '/labels' # Yolo will save the labels here

# #%%
from yolov5.detect import run as YoloInference
args = {
    'weights': YOLO_WEIGHTS,
    'source': SOURCE_DIR,
    'imgsz': (640, 640),
    'conf_thres': 0.5,
    'iou_thres': 0.5,
    'device': '',       # Let YOLO decide
    'save_txt': True,
    'save_crop': True,  # Save cropped prediction boxes, for debugging
    'project': PROJECT,
    'name': EXP_NAME,
    'exist_ok': False,
    'hide_labels': True,    # Hide labels from output images
    'hide_conf': True,      # Hide confidence, these two omitted for better visualization
}

# With exist_ok=False YOLOv5 writes a re-run into an auto-incremented folder
# (e.g. 'yolo_tok18712'), while YOLO_CROPS / YOLO_ANNOTATION above would keep
# pointing at the first run. Skip inference when the results folder already
# exists so those paths always refer to the output that is actually used.
if os.path.exists(YOLO_RESULTS):
    print(f'{YOLO_RESULTS} already exists; skipping YOLO inference. Delete the folder to re-run.')
else:
    YoloInference(**args)

YOLOv5  v7.0-326-gec331cbd Python-3.11.5 torch-2.2.1 CUDA:0 (NVIDIA GeForce GTX 1650, 4096MiB)

Fusing layers... 
Model summary: 224 layers, 7053910 parameters, 0 gradients
image 1/10 C:\Users\Soppo\Documents\GitHub\Thesis\TempResources\SR_Nom_Text\hd\che (copy)\1.JPG: 224x640 106 nom_chars, 97.9ms
image 2/10 C:\Users\Soppo\Documents\GitHub\Thesis\TempResources\SR_Nom_Text\hd\che (copy)\10.JPG: 256x640 39 nom_chars, 58.7ms
image 3/10 C:\Users\Soppo\Documents\GitHub\Thesis\TempResources\SR_Nom_Text\hd\che (copy)\2.JPG: 224x640 156 nom_chars, 14.6ms
image 4/10 C:\Users\Soppo\Documents\GitHub\Thesis\TempResources\SR_Nom_Text\hd\che (copy)\3.JPG: 256x640 193 nom_chars, 12.4ms
image 5/10 C:\Users\Soppo\Documents\GitHub\Thesis\TempResources\SR_Nom_Text\hd\che (copy)\4.JPG: 224x640 170 nom_chars, 20.9ms
image 6/10 C:\Users\Soppo\Documents\GitHub\Thesis\TempResources\SR_Nom_Text\hd\che (copy)\5.JPG: 224x640 168 nom_chars, 14.1ms
image 7/10 C:\Users\Soppo\Documents\GitHub\Thesis\TempResources\S

In [7]:
# Evaluate the YoloV5 detector on the test split and report its metrics
from yolov5.val import run as YoloVal

YAML_DATA = 'config.yaml'
PROJECT = './test'
EXP_NAME = 'yolo_LRcubicx4'

# Keyword arguments forwarded verbatim to yolov5.val.run
args = dict(
    weights=YOLO_WEIGHTS,
    data=YAML_DATA,
    imgsz=640,
    task='test',
    batch_size=4,
    device='',
    project=PROJECT,
    name=EXP_NAME,
    exist_ok=True,
)
YoloVal(**args)


YOLOv5  v7.0-326-gec331cbd Python-3.11.5 torch-2.2.1 CUDA:0 (NVIDIA GeForce GTX 1650, 4096MiB)

Fusing layers... 
Model summary: 224 layers, 7053910 parameters, 0 gradients
[34m[1mtest: [0mScanning C:\Users\Soppo\Documents\GitHub\Thesis\Yolo_SR_Resnet\data\LRcubicx4\test\labels... 6 images, 0 backgrounds, 0 corrupt: 100%|██████████| 6/6 [00:10<00:00,  1.79s/it]
[34m[1mtest: [0mNew cache created: C:\Users\Soppo\Documents\GitHub\Thesis\Yolo_SR_Resnet\data\LRcubicx4\test\labels.cache
                 Class     Images  Instances          P          R      mAP50   mAP50-95: 100%|██████████| 2/2 [00:10<00:00,  5.03s/it]
                   all          6        618      0.995      0.694      0.779      0.505
Speed: 0.2ms pre-process, 159.7ms inference, 4.6ms NMS per image at shape (4, 3, 640, 640)
Results saved to [1mtest\yolo_LRcubicx4[0m


((0.9953615700307287,
  0.6944671902600705,
  0.7792180158933966,
  0.5046449782428649,
  0.0,
  0.0,
  0.0),
 array([    0.50464]),
 (0.17305215199788412, 159.7349246342977, 4.605650901794434))

## Recognizer : AlexNet

## Recognizer : Nom_Resnet101

In [None]:
class Nom_Resnet101(nn.Module):
    """ResNet-101 classifier whose final layer is resized to ``n_classes``.

    Args:
        n_classes (int): Number of output classes.
        pretrained (bool): When True (default) the backbone starts from the
            default torchvision ResNet-101 weights; when False it is randomly
            initialised, which avoids fetching weights that a checkpoint
            would overwrite anyway.
    """

    def __init__(self, n_classes, pretrained=True):
        super(Nom_Resnet101, self).__init__()
        # Honour the `pretrained` flag instead of always requesting the default weights.
        weights = resnet.ResNet101_Weights.DEFAULT if pretrained else None
        self.model = resnet.resnet101(weights=weights)

        # Replace the classification head so it outputs n_classes logits
        self.model.fc = nn.Linear(self.model.fc.in_features, n_classes)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.model(x)

## Super-Resolution Generator: RRDB

In [None]:
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialise the weights of convolution, linear and batch-norm layers.

    Conv2d and Linear weights are drawn with Kaiming-normal initialisation and
    then multiplied by ``scale``; batch-norm weights are set to 1. Every bias
    that exists is filled with ``bias_fill``.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Factor applied to the freshly initialised weights,
            especially for residual blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Extra arguments forwarded to ``init.kaiming_normal_``.
    """
    roots = module_list if isinstance(module_list, list) else [module_list]
    for root in roots:
        for layer in root.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(layer.weight, **kwargs)
                layer.weight.data *= scale
            elif isinstance(layer, _BatchNorm):
                init.constant_(layer.weight, 1)
            else:
                # Any other layer type is left untouched.
                continue
            if layer.bias is not None:
                layer.bias.data.fill_(bias_fill)
                
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` fresh instances of ``basic_block``.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
        kwarg (dict): Keyword arguments passed to every block constructor.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    blocks = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)

def pixel_unshuffle(x, scale):
    """Move ``scale x scale`` spatial blocks into the channel dimension.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio; must divide both hh and hw.

    Returns:
        Tensor: the pixel unshuffled feature with shape
            (b, c * scale**2, hh // scale, hw // scale).
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    # Split each spatial axis into (coarse position, offset inside the block) ...
    blocks = x.view(batch, channels, out_h, scale, out_w, scale)
    # ... then move the two in-block offsets next to the channel axis.
    blocks = blocks.permute(0, 1, 3, 5, 2, 4)
    return blocks.reshape(batch, channels * (scale**2), out_h, out_w)



class ResidualDenseBlock(nn.Module):
    """Densely connected block of five 3x3 convolutions with a scaled residual.

    Used in RRDB block in ESRGAN. Each convolution receives the block input
    concatenated with the outputs of all previous convolutions.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, **conv_kwargs)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, **conv_kwargs)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Re-initialise the five convolutions with weights scaled down by 0.1
        convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
        default_init_weights(convs, 0.1)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the output of the dense convolution stack."""
        # `dense` accumulates the input and every intermediate activation.
        dense = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            dense.append(self.lrelu(conv(torch.cat(dense, 1))))
        residual = self.conv5(torch.cat(dense, 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return residual * 0.2 + x


class RRDB(nn.Module):
    """Three chained residual dense blocks wrapped in an outer scaled residual.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the output of the three dense blocks."""
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x


class RRDBNet(nn.Module):
    """Generator built from Residual in Residual Dense Blocks, as used in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    The network always upsamples its features by 4 (two nearest-neighbour x2
    steps). To support scale x2 and scale x1, the input is first
    pixel-unshuffled (the inverse of pixel-shuffle: spatial size is reduced
    and channel count enlarged) by 2 or 4 respectively before it enters the
    main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Overall upsampling factor. Values 1 and 2 enable the
            pixel-unshuffle front end; any other value feeds the input
            directly. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super().__init__()
        self.scale = scale
        # Pixel-unshuffling by r multiplies the channel count by r**2.
        if scale == 1:
            num_in_ch *= 16
        elif scale == 2:
            num_in_ch *= 4
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, **conv_kwargs)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, **conv_kwargs)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, **conv_kwargs)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, **conv_kwargs)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, **conv_kwargs)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, **conv_kwargs)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return the super-resolved image for the input batch ``x``."""
        if self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        elif self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        else:
            feat = x
        shallow = self.conv_first(feat)
        # Global residual around the RRDB trunk
        feat = shallow + self.conv_body(self.body(shallow))
        # upsample: two nearest-neighbour x2 steps, each followed by conv + LeakyReLU
        for up_conv in (self.conv_up1, self.conv_up2):
            enlarged = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(up_conv(enlarged))
        high_res = self.lrelu(self.conv_hr(feat))
        return self.conv_last(high_res)

# Training Loop

In [None]:
# Blank

# Testing

In [None]:
# Separate cell because Dataset loading is slow
dataset = YoloCropDataset(
    image_file_path = '../NomDataset/datasets/mono-domain-datasets/tale-of-kieu/1871/1871-raw-images',
    annotation_file_path = YOLO_ANNOTATION,
    label_file_path = '../NomDataset/datasets/mono-domain-datasets/tale-of-kieu/1871/1871-annotation/annotation-mynom',
    unicode_dict_path = UCODE_DICT,
    image_size = (56, 56),
    transform = None,
)
# Evaluation only: keep the dataset order so that the "<batch>_<position>" ids
# recorded for failure cases map back to the same crops on every run.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)

In [None]:
# Map class id -> label, using the same ordering as YoloCropDataset
# (iteration order of the pickled dictionary).
unicode_dict = dict()
with open(UCODE_DICT, 'rb') as f:
    temp = pickle.load(f)
for idx, k in enumerate(temp):
    unicode_dict[idx] = k

# Load the SR model
# map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines;
# the models are moved to DEVICE later, right before evaluation.
sr_model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23, num_grow_ch=32)
sr_model.load_state_dict(torch.load('../Backup/pretrained_model/RealESRGAN_x4plus.pth', map_location='cpu')['params_ema'])
sr_model.eval()

# Load the recognizer model
# pretrained=False: the checkpoint below overwrites the backbone weights, so
# there is no need to request the ImageNet weights first.
recognizer_model = Nom_Resnet101(n_classes=len(unicode_dict), pretrained=False)
recognizer_model.model.load_state_dict(torch.load('../Backup/pretrained_model/NomResnet101.pth', map_location='cpu'))
recognizer_model.eval()

from torchsummary import summary
# Both models are still on the CPU here; torchsummary builds a CUDA input by
# default when a GPU is available, so the device is given explicitly.
summary(sr_model, (3, 56, 56), device='cpu')
summary(recognizer_model, (3, 224, 224), device='cpu')

In [None]:
from matplotlib import pyplot as plt

# The dataset was built with transform=None, so its crops are normalised with
# these ImageNet statistics; undo that before displaying them.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

batch = next(iter(dataloader))
imgs, labels = batch
plt.figure()
# Show at most the first 16 crops of the batch on a 4x4 grid
for idx, i in enumerate(imgs[:16], 1):
    img = i.permute(1, 2, 0).numpy()
    img = img * IMAGENET_STD + IMAGENET_MEAN   # back to the [0, 1] range
    img = img * 255
    img = img.clip(0, 255).astype('uint8')
    plt.subplot(4, 4, idx)
    plt.imshow(img)
plt.show()

labels = labels.tolist()
print("Labels:", [unicode_dict[i] for i in labels][:16])
print("Labels:", end=' ')
for i in labels[:16]:
    if unicode_dict[i] == 'UNK':
        print("UNK", end=', ')
    else:
        # Labels are hexadecimal code points; print the character itself
        print(chr(int(unicode_dict[i], 16)), end=', ')


In [None]:
sr_model.to(DEVICE)
recognizer_model.to(DEVICE)

# NOTE(review): with transform=None the dataset applies ImageNet mean/std
# normalisation before the crops reach `sr_model`; confirm that the
# super-resolution weights were trained on inputs in that range.

correct_pred = 0
incorrect_pred = []   # (batch_position id, predicted class id, true class id) as plain Python values
with torch.no_grad():
    for idx, (imgs, labels) in enumerate(tqdm(dataloader, desc='Testing'), 1):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        sr_imgs = sr_model(imgs)
        # argmax over the logits picks the same class as argmax over softmax
        preds = torch.argmax(recognizer_model(sr_imgs), dim=1)

        correct_pred += torch.sum(preds == labels).item()
        # Record failure cases as Python ints so no device tensors are kept alive
        for i, (pred, label) in enumerate(zip(preds.tolist(), labels.tolist())):
            if pred != label:
                incorrect_pred.append((f'{idx}_{i}', pred, label))

# Guard against an empty dataset instead of dividing by zero
accuracy = correct_pred / len(dataset) if len(dataset) else float('nan')
print("Accuracy:", accuracy)
