In [54]:
import os
import PIL
from PIL import Image
from torch.utils.data import Dataset,DataLoader
import json
import torch
import matplotlib.pyplot as plt
from torch import nn
from torchvision import transforms

In [63]:
# Recognisable symbols; index 0 ('_') doubles as the padding / blank class.
alphabet = list('_ABEKMHOPCTYX0123456789')

# symbol -> class index
let2int = {symbol: index for index, symbol in enumerate(alphabet)}
# class index -> symbol
int2let = dict(enumerate(alphabet))

In [109]:
class NumberDataset(Dataset):
    """Dataset of plate images paired with their JSON annotations.

    ``path`` must contain two sub-directories: ``img`` (image files) and
    ``ann`` (JSON files). An image and an annotation belong together when
    their file names match after removing the extension; files without a
    counterpart in the other directory are ignored.

    Args:
        path: Root directory holding the ``img`` and ``ann`` folders.
        number_len: Maximum number of leading characters of the annotation's
            ``'name'`` field that are encoded as the label.
    """

    def __init__(self, path, number_len):
        super(NumberDataset, self).__init__()
        self.number_len = number_len
        img_path = os.path.join(path, 'img')
        label_path = os.path.join(path, 'ann')

        # List each directory once so both passes below see the same snapshot.
        img_files = os.listdir(img_path)
        label_files = os.listdir(label_path)

        # File stems (names without extension). splitext copes with any
        # extension length instead of assuming '.xxx' / '.json'.
        self.image_numbers = [os.path.splitext(img)[0] for img in img_files]
        self.label_numbers = [os.path.splitext(label)[0] for label in label_files]

        # Sets give O(1) membership tests while filtering.
        image_stems = set(self.image_numbers)
        label_stems = set(self.label_numbers)

        # Images and labels: keep only files that have a counterpart.
        self.images = [
            os.path.join(img_path, img)
            for img in img_files
            if os.path.splitext(img)[0] in label_stems
        ]
        self.labels = [
            os.path.join(label_path, label)
            for label in label_files
            if os.path.splitext(label)[0] in image_stems
        ]

        # Same ordering on both lists so index i refers to the same sample.
        self.images.sort(reverse=True)
        self.labels.sort(reverse=True)

        self.trans = transforms.Compose([
            transforms.Resize((64, 128)),
            transforms.ToTensor()
        ])

    def __len__(self):
        """Return the number of image files that have an annotation."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with keys
            ``'img'``: the image converted to RGB, resized to (64, 128) and
            turned into a tensor;
            ``'label'``: integer tensor with the ``let2int`` codes of the first
            ``number_len`` characters of the annotation's ``'name'`` field;
            ``'label_len'``: number of entries in ``'label'``.

        Raises:
            KeyError: if the annotation has no ``'name'`` field or contains a
                character missing from ``let2int``.
        """
        # Close the underlying file as soon as the pixel data is converted.
        with Image.open(self.images[idx]) as raw_img:
            idx_img = raw_img.convert('RGB')

        idx_label = self.labels[idx]
        with open(idx_label, 'r') as file_option:
            jf = json.load(file_option)

        # Explicit dtype keeps an empty label integer-typed as well.
        tensor_label = torch.tensor(
            [let2int[let] for let in jf['name'][0:self.number_len]],
            dtype=torch.long,
        )
        tensor_img = self.trans(idx_img)

        return {
            'img': tensor_img,
            'label': tensor_label,
            # Length of the encoded label, not of the annotation file path.
            'label_len': len(tensor_label)
        }

In [126]:
def collate_fn(batch):
    """Merge a list of dataset samples into batched tensors.

    Args:
        batch: Sequence of dicts with keys ``'img'``, ``'label'`` and
            ``'label_len'``.

    Returns:
        Tuple ``(imgs, label, label_lens)``: the images stacked along a new
        leading dimension, the labels right-padded with 0 to the longest
        label in the batch (batch first), and a tensor of the per-sample
        ``'label_len'`` values.
    """
    image_list, target_list, length_list = [], [], []
    for sample in batch:
        image_list.append(sample['img'])
        target_list.append(sample['label'])
        length_list.append(sample['label_len'])

    stacked_images = torch.stack(image_list)
    lengths = torch.tensor(length_list)
    padded_targets = nn.utils.rnn.pad_sequence(
        target_list,
        batch_first=True,
        padding_value=0,
    )
    return stacked_images, padded_targets, lengths

In [119]:
# Training split; labels use at most the first 9 characters of each annotation name.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
number_data = NumberDataset(
    path='/home/artemybombastic/MyGit/VNR_Data/train',
    number_len=9,
)

# Fixed-order batches of 16; a trailing incomplete batch is dropped.
number_dataloader = DataLoader(
    dataset=number_data,
    batch_size=16,
    shuffle=False,
    drop_last=True,
    collate_fn=collate_fn,
)

In [127]:
# CTC criterion. The arguments spell out nn.CTCLoss's defaults;
# blank=0 is the index of '_' in `alphabet`.
loss_func = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)