In [1]:
import os
import time
import math
import random
import shutil
import contextlib
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torch.utils.data as data
import torchvision.transforms as transforms

from utils.converter import LabelConverter, IndexConverter
from datasets.dataset import InMemoryDigitsDataset, DigitsDataset, collate_train, collate_dev, inmemory_train, inmemory_dev
from generate import gen_text_img

import arguments
from models.densenet_ import DenseNet

In [2]:
torch.cuda.current_device()

0

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.argv = ['main.py','--dataset-root','alphabet','--arch','densenet121','--alphabet','alphabet/alphabet_decode_5990.txt',
            '--lr','5e-5','--max-epoch','20','--optimizer','rmsprop','--gpu-id','0']
args = arguments.parse_args()
if args.gpu_id < 0:
    device = torch.device("cpu")
else:
#     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    device = torch.device("cuda")
#     torch.backends.cudnn.benchmark = True
    
if os.path.isfile(args.alphabet):
    alphabet = ''
    with open(args.alphabet, mode='r', encoding='utf-8') as f:
        for line in f.readlines():
            alphabet += line.strip()
    args.alphabet = alphabet

num = 2
dev_num = num
use_file = 1
text = "嘤嘤嘤"
text_length = 10
font_size = -1
font_id = -1
space_width = 1
text_color = '#282828'
thread_count = 8
channel = 3

random_skew = True
skew_angle = 2
random_blur = True
blur = 0.5

orientation = 0
distorsion = -1
distorsion_orientation = 2
background = 1

random_process = True
noise = 20
erode = 2
dilate = 2
incline = 10
transform = transforms.Compose([
    transforms.Resize((32, 280)),
    transforms.ToTensor(),
])

model = DenseNet(num_classes=len(args.alphabet) + 1).to(device)
text_meta, text_img = gen_text_img(num, use_file, text, text_length, font_size, font_id, space_width,
                                       background, text_color,
                                       orientation, blur, random_blur, distorsion, distorsion_orientation,
                                       skew_angle, random_skew,
                                       random_process, noise, erode, dilate, incline,
                                                                thread_count)
dev_meta, dev_img = text_meta, text_img

index_converter = IndexConverter(args.alphabet, ignore_case=True)

train_dataset = InMemoryDigitsDataset(mode='train', text=text_meta, img=text_img, total=num,
                                     transform=transform, converter = index_converter)
dev_dataset = InMemoryDigitsDataset(mode='dev', text=dev_meta, img=dev_img, total=num,
                                   transform=transform, converter = index_converter)

train_loader = data.DataLoader(dataset=train_dataset,batch_size=args.batch_size, num_workers=4, shuffle=True,
                               collate_fn=collate_train, pin_memory=True)
dev_loader = data.DataLoader(dataset=dev_dataset,batch_size=args.batch_size, num_workers=4, shuffle=False,
                             collate_fn=collate_dev, pin_memory=False)

for i, sample in enumerate(train_loader):
        images = sample.images.to(device)
        targets = sample.targets.type(torch.long).to(device)
        target_lengths = sample.target_lengths.type(torch.long).to(device)
        
        log_probs = model(images).to(device)
        input_lengths = torch.full((images.size(0),),log_probs.size(0), dtype=torch.long).to(device)
        print(log_probs.dtype)
        print(targets.dtype)
        print(input_lengths.dtype)
        print(target_lengths.dtype)
        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
        loss.backward()

torch.float32
torch.int64
torch.int32
torch.int64


In [35]:
T = 35      # Input sequence length
C = 5990      # Number of classes (excluding blank)
N = 64      # Batch size
S = 30      # Target sequence length of longest target in batch
S_min = 10  # Minimum target length, for demonstration purposes

input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()

target = torch.randint(low=1, high=C+1, size=(N, S))
input_lengths = torch.full(size=(N,), fill_value=T)
target_lengths = torch.randint(low=S_min, high=S, size=(N,))

device = torch.device("cuda")

input = input.to(device)
target = target.to(device)
input_lengths = input_lengths.to(device)
target_lengths = target_lengths.to(device)

loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

In [36]:
print(input.dtype)
print(target.dtype)
print(input_lengths.dtype)
print(target_lengths.dtype)

torch.float32
torch.int64
torch.float32
torch.int64


In [4]:
print(input.shape)

torch.Size([255, 16, 20])


In [5]:
print(target.shape)

torch.Size([16, 30])


In [6]:
input_lengths

tensor([255., 255., 255., 255., 255., 255., 255., 255., 255., 255., 255., 255.,
        255., 255., 255., 255.], device='cuda:0')

In [7]:
target_lengths

tensor([19, 19, 17, 11, 28, 11, 28, 10, 25, 10, 10, 27, 14, 21, 11, 14],
       device='cuda:0')

In [8]:
ctc_loss = nn.CTCLoss().to(device)
loss = ctc_loss(input, target, input_lengths, target_lengths)

In [46]:
import os
format(os.curdir.join("checkpoint"))


'c.h.e.c.k.p.o.i.n.t'

In [53]:
os.path.join(os.curdir,"checkpoint","new_prarams2.pth'")

".\\checkpoint\\new_prarams2.pth'"

In [51]:
index = "1234567890qwertyuiopasdfghjklzxcvbnm一二三四五六七八九十"
out=""
import random
for i in range(5000):
    out+=index[random.randint(0,len(index)-1)]

In [52]:
out

'rkxcga7kquflwl1wzyone8zviy30t9t1tx8lw5ovea7x7rtj5dzncwspgckrfl36cdgj6p7rp6mfqds13rcey0p6dtzt532v9ch1h1twni9rf0l0l2gpw1frxrxbe5n4olu19m8h83araj436wc4s2ryzyff4lz2w32rn7nnpynydrlk5sgcahk808q19takejbsf0033t8we6dzgaueybktdp44yvicfvzzpdxvwk6n0br5kni17d2y9duh1ulsrtw5fe1wsltkoe718oi68exl7lbl3dqb857k12qnkrlz29m5jdi1vjin2pzg90e05nz1a2gk5rilfy0b3fmt9vejp3bvsxspkyjvucjg3ylon4p2rornyff9yjzw3un7997jd3r8qdek1t39rmkx2ne4pcfu2j7tnvsbk0nxm1vljkdz7f5mafr8g6vm8eeqzgccyg0rp34gf16co15wym7p29u0rqpf6hxnmf2wqnrxekc1775x0ygvs1tmchswmuefoyzyymc1uu3oilsce80g80kbjymg7fnxxs4ezbx03s5t0ukoeitr8tk1rextz0kiqmncq12fetxl9jb2y6g67it6gv9zhxgfacc90uebs9w9ue37lq4ufvy1d4ocfk4uhlyc461e9cr639zrr5hczmik1sqlhy2iruyop6rcka8z3er49vridjw9rszcw3kejc4z7zzndr6pdgglj8hbnitnoa05m0zc7xdqyr5fz8nsdzv824xjj9u99vs7fky4b96g892pzrsy0jtaw2o1x39cmab3tgtbp1y6pm7xuw1vmicb4oiloeiwmpzdfqdquuy22dz6qlsjc2u8tms4kddmdobyevy2429kjugpqkhjieqcmtvbg0u5rnke1drro1dp1b5p2nen8od569sercuqdf0rt59pi7dolecbang0kf5n2qm88xv5qesi6f68qcu108u1ag9hak7vfyfmymiflkiph6ovf49