In [1]:
import os
import time
import math
import random
import shutil
import contextlib
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torch.utils.data as data
import torchvision.transforms as transforms

from utils.converter import LabelConverter, IndexConverter
from datasets.dataset import InMemoryDigitsDataset, DigitsDataset, collate_train, collate_dev, inmemory_train, inmemory_dev
from generate import gen_text_img

import arguments
from models.densenet_ import DenseNet

In [2]:
torch.cuda.current_device()

0

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.argv = ['main.py','--dataset-root','alphabet','--arch','densenet121','--alphabet','alphabet/alphabet_decode_5990.txt',
            '--lr','5e-5','--max-epoch','20','--optimizer','rmsprop','--gpu-id','0']
args = arguments.parse_args()
if args.gpu_id < 0:
    device = torch.device("cpu")
else:
#     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    device = torch.device("cuda")
#     torch.backends.cudnn.benchmark = True
    
if os.path.isfile(args.alphabet):
    alphabet = ''
    with open(args.alphabet, mode='r', encoding='utf-8') as f:
        for line in f.readlines():
            alphabet += line.strip()
    args.alphabet = alphabet

num = 2
dev_num = num
use_file = 1
text = "嘤嘤嘤"
text_length = 10
font_size = -1
font_id = -1
space_width = 1
text_color = '#282828'
thread_count = 8
channel = 3

random_skew = True
skew_angle = 2
random_blur = True
blur = 0.5

orientation = 0
distorsion = -1
distorsion_orientation = 2
background = 1

random_process = True
noise = 20
erode = 2
dilate = 2
incline = 10
transform = transforms.Compose([
    transforms.Resize((32, 280)),
    transforms.ToTensor(),
])

model = DenseNet(num_classes=len(args.alphabet) + 1).to(device)
text_meta, text_img = gen_text_img(num, use_file, text, text_length, font_size, font_id, space_width,
                                       background, text_color,
                                       orientation, blur, random_blur, distorsion, distorsion_orientation,
                                       skew_angle, random_skew,
                                       random_process, noise, erode, dilate, incline,
                                                                thread_count)
dev_meta, dev_img = text_meta, text_img

index_converter = IndexConverter(args.alphabet, ignore_case=True)

train_dataset = InMemoryDigitsDataset(mode='train', text=text_meta, img=text_img, total=num,
                                     transform=transform, converter = index_converter)
dev_dataset = InMemoryDigitsDataset(mode='dev', text=dev_meta, img=dev_img, total=num,
                                   transform=transform, converter = index_converter)

train_loader = data.DataLoader(dataset=train_dataset,batch_size=args.batch_size, num_workers=4, shuffle=True,
                               collate_fn=collate_train, pin_memory=True)
dev_loader = data.DataLoader(dataset=dev_dataset,batch_size=args.batch_size, num_workers=4, shuffle=False,
                             collate_fn=collate_dev, pin_memory=False)

for i, sample in enumerate(train_loader):
        images = sample.images.to(device)
        targets = sample.targets.type(torch.long).to(device)
        target_lengths = sample.target_lengths.type(torch.long).to(device)
        
        log_probs = model(images).to(device)
        input_lengths = torch.full((images.size(0),),log_probs.size(0), dtype=torch.long).to(device)
        print(log_probs.dtype)
        print(targets.dtype)
        print(input_lengths.dtype)
        print(target_lengths.dtype)
        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
        loss.backward()

torch.float32
torch.int64
torch.int32
torch.int64


In [35]:
T = 35      # Input sequence length
C = 5990      # Number of classes (excluding blank)
N = 64      # Batch size
S = 30      # Target sequence length of longest target in batch
S_min = 10  # Minimum target length, for demonstration purposes

input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()

target = torch.randint(low=1, high=C+1, size=(N, S))
input_lengths = torch.full(size=(N,), fill_value=T)
target_lengths = torch.randint(low=S_min, high=S, size=(N,))

device = torch.device("cuda")

input = input.to(device)
target = target.to(device)
input_lengths = input_lengths.to(device)
target_lengths = target_lengths.to(device)

loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

In [36]:
print(input.dtype)
print(target.dtype)
print(input_lengths.dtype)
print(target_lengths.dtype)

torch.float32
torch.int64
torch.float32
torch.int64


In [4]:
print(input.shape)

torch.Size([255, 16, 20])


In [5]:
print(target.shape)

torch.Size([16, 30])


In [6]:
input_lengths

tensor([255., 255., 255., 255., 255., 255., 255., 255., 255., 255., 255., 255.,
        255., 255., 255., 255.], device='cuda:0')

In [7]:
target_lengths

tensor([19, 19, 17, 11, 28, 11, 28, 10, 25, 10, 10, 27, 14, 21, 11, 14],
       device='cuda:0')

In [8]:
ctc_loss = nn.CTCLoss().to(device)
loss = ctc_loss(input, target, input_lengths, target_lengths)

In [46]:
import os
format(os.curdir.join("checkpoint"))


'c.h.e.c.k.p.o.i.n.t'

In [53]:
os.path.join(os.curdir,"checkpoint","new_prarams2.pth'")

".\\checkpoint\\new_prarams2.pth'"

In [54]:
index = "1234567890qwertyuiopasdfghjklzxcvbnm一二三四五六七八九十"
out=""
import random
for i in range(5000):
    out+=index[random.randint(0,len(index)-1)]

In [55]:
out

'uf三rmzwgmhk1z七四g5r五5tpf四zm0四ajhi66十1351五jp5二五h四17u53iig6v9xh二96dukg7r九a九u7u7h4z4十9o二zlf六m5r二l1qz3n6kff二四q三53ll七30wdom八7f721yvqvic四tn9lydi七ksj3cg七十aj七九g三z66pgk八v3四6y七m9yfkl五xnmt四g四ffvmg六wca85w八x2zab九36gv七4四j06四二f七w四3z2五f5vjql二0l二jvw7u6zm二j七unfs十一y2四zh9y8icj6sr八七9coh0三0五t六se0十四0九一mc91md1三3oj4vtbl5rmajag二13一1h七3cvie四u九五0g六六h七xvu8十三4九vp十k三8ry三z六qh三18五fxfk八3li6v二q一五7四da7xbgp3v03sa三h56x九ljw6heut2ui65v6dhhdeekblpan三z5三o4x93p5wue二三5y6cyy1六三tlq六vm8sls七a三06edt6zz七vsgu8z六7二5cdiykv16npytqqg1二二a85khad一az2三rkyq7七九九bs4qc九3t一5bwqj4kp3jy154二6irzlg6九1v二8u16pg6mcrui四三9qj1kixxm1五gv8h9t二jjgyd9ey7三00三f四fq9七u九ev二wpb三c6guhhn四ca四e4qmd六一mq3sw3wzb7adn044kwjj79fq一rpdn8e3y4d十k4lcrhfy十0r15十j九f23zipcm七bq7v19十fael五n七五一三o4b六1wp1八swhpepbfeo六0y9os十5i十0sxq96一十j八2八tt九九5m一一七7km七二ikn3三v六八七ak5十0cq91fk3mkpxbi八xuf三五mj7x十g8五8byhfge5c70f1ymtlxz一i五kzfqib83vp1fxusdovl6y一l5一l一okwh九h70一p二7lad八五hllrx十7gbojbneqn6九十8一z十p六24y九gc3pltxd59150四c一n54四九d九u八8fsca三四mgq3vvup三bger7ldr7八二九5wv1c一二mf00b3t四五五6四4e58r2nr30zq十2三九二九a5s七ex二0pc8vzgxswixo7