In [None]:
# Mount Google Drive so the EAST code, training data and weights stored there
# become reachable under /content/gdrive.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [None]:
import sys

# Make the project's packages (nets, util) importable from Google Drive.
# The membership check keeps this cell idempotent: re-running it does not
# append a duplicate entry to sys.path.
PROJECT_ROOT = '/content/gdrive/MyDrive/east-pytorch-main/east-pytorch-main'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import os
import tqdm
import argparse

import torch
import torch.nn as nn
from torch.utils import data
from torch.optim import lr_scheduler

from nets.nn import EAST
from util.loss import Loss
from util.dataset import EASTDataset

In [None]:
def train(opt):
    """Fine-tune an EAST text detector, starting from a pretrained checkpoint.

    Trains for ``opt.epoch_iter`` epochs with AdamW. The learning rate is
    multiplied by 0.1 once, after ``opt.epoch_iter // 2`` epochs. Every
    ``opt.interval`` epochs the unwrapped model weights are written to
    ``<opt.model_save>/model_classic_epoch_<epoch>.pth``.

    Args:
        opt: parsed options (``argparse.Namespace``) providing ``train_images``,
            ``train_labels``, ``source_model``, ``model_save``, ``batch_size``,
            ``lr``, ``num_workers``, ``epoch_iter`` and ``interval``. If
            ``source_model`` is missing or empty, no checkpoint is loaded and
            training starts from the freshly constructed ``EAST()`` weights.
    """
    dataset = EASTDataset(opt.train_images, opt.train_labels)
    # drop_last=True: a trailing partial batch is skipped, so every step sees a full batch.
    train_loader = data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True,
                                   num_workers=opt.num_workers, drop_last=True)

    criterion = Loss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = EAST()
    # Fine-tuning: load the source weights into the bare network BEFORE any
    # DataParallel wrapping. The wrapper prefixes every state_dict key with
    # "module.", so loading an unwrapped checkpoint (the format this function
    # itself saves below) into the wrapped model fails with missing/unexpected keys.
    # map_location='cpu' lets a GPU-saved checkpoint load on any runtime;
    # model.to(device) below moves the weights to their final device.
    if getattr(opt, 'source_model', None):
        model.load_state_dict(torch.load(opt.source_model, map_location='cpu'))

    data_parallel = torch.cuda.device_count() > 1
    if data_parallel:
        model = nn.DataParallel(model)
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[opt.epoch_iter // 2], gamma=0.1)

    # torch.save does not create missing directories.
    os.makedirs(opt.model_save, exist_ok=True)

    for epoch in range(opt.epoch_iter):
        model.train()
        epoch_loss = 0
        print(('\n' + '%10s' * 3) % ('epoch', 'loss', 'gpu'))
        progress_bar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader))
        for i, (img, gt_score, gt_geo, ignored_map) in progress_bar:
            img = img.to(device)
            gt_score = gt_score.to(device)
            gt_geo = gt_geo.to(device)
            ignored_map = ignored_map.to(device)

            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Progress line: epoch counter, running mean loss of this epoch, reserved GPU memory.
            mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)
            s = ('%10s' + '%10.4g' + '%10s') % ('%g/%g' % (epoch + 1, opt.epoch_iter), epoch_loss / (i + 1), mem)
            progress_bar.set_description(s)

        # Stepped once per epoch, so the MultiStepLR milestone is counted in epochs.
        scheduler.step()

        if (epoch + 1) % opt.interval == 0:
            # Save the unwrapped weights so the checkpoint loads into a plain EAST().
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
            torch.save(state_dict, os.path.join(opt.model_save, 'model_classic_epoch_{}.pth'.format(epoch + 1)))

# Run fine-tuning (parse options, then call `train`)

In [None]:
if __name__ == '__main__':
    # Default locations inside the project folder on the mounted Google Drive.
    drive_root = '/content/gdrive/MyDrive/east-pytorch-main'
    train_root = drive_root + '/data_train_test_classic/train'
    weights_root = drive_root + '/east-pytorch-main/weights'

    parser = argparse.ArgumentParser(description='EAST: An Efficient and Accurate Scene Text Detector')
    parser.add_argument('--train_images', type=str, default=train_root + '/images', help='path to train images')
    parser.add_argument('--train_labels', type=str, default=train_root + '/labels', help='path to train labels')
    # Fine-tuning: checkpoint whose weights initialise the network before training.
    parser.add_argument('--source_model', type=str, default=weights_root + '/east.pth', help='path to source model')
    parser.add_argument('--model_save', type=str, default=weights_root, help='path to save checkpoints')
    parser.add_argument('--batch_size', type=int, default=12, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers in dataloader')
    parser.add_argument('--epoch_iter', type=int, default=300, help='number of training epochs')
    parser.add_argument('--interval', type=int, default=50, help='saving interval of checkpoints (in epochs)')
    # Jupyter/Colab kernels are launched with "-f <connection-file>.json" in
    # sys.argv; declaring it keeps parse_args() from aborting. Hidden from --help.
    parser.add_argument('-f', help=argparse.SUPPRESS)

    opt = parser.parse_args()
    train(opt)
    
  



     epoch      loss       gpu


     1/300     5.547       14G: 100%|██████████| 4/4 [02:42<00:00, 40.59s/it]


     epoch      loss       gpu



     2/300     5.362       14G: 100%|██████████| 4/4 [02:39<00:00, 39.85s/it]


     epoch      loss       gpu



     3/300     5.193       14G: 100%|██████████| 4/4 [02:42<00:00, 40.71s/it]


     epoch      loss       gpu



     4/300     5.042       14G: 100%|██████████| 4/4 [02:34<00:00, 38.66s/it]


     epoch      loss       gpu



     5/300      4.76       14G: 100%|██████████| 4/4 [02:33<00:00, 38.27s/it]


     epoch      loss       gpu



     6/300     4.608       14G: 100%|██████████| 4/4 [02:30<00:00, 37.60s/it]


     epoch      loss       gpu



     7/300     4.623       14G: 100%|██████████| 4/4 [02:44<00:00, 41.14s/it]


     epoch      loss       gpu



     8/300     4.426       14G: 100%|██████████| 4/4 [02:39<00:00, 39.79s/it]


     epoch      loss       gpu



     9/300     4.343       14G: 100%|██████████| 4/4 [02:39<00:00, 39.90s/it]


     epoch      loss       gpu



    10/300     4.301       14G: 100%|██████████| 4/4 [02:31<00:00, 37.90s/it]


     epoch      loss       gpu



    11/300     4.182       14G: 100%|██████████| 4/4 [02:29<00:00, 37.33s/it]


     epoch      loss       gpu



    12/300     4.114       14G: 100%|██████████| 4/4 [02:25<00:00, 36.45s/it]


     epoch      loss       gpu



    13/300      4.07       14G: 100%|██████████| 4/4 [02:36<00:00, 39.07s/it]


     epoch      loss       gpu



    14/300      4.03       14G: 100%|██████████| 4/4 [02:30<00:00, 37.65s/it]


     epoch      loss       gpu



    15/300     4.129       14G: 100%|██████████| 4/4 [02:29<00:00, 37.47s/it]


     epoch      loss       gpu



    16/300     4.055       14G: 100%|██████████| 4/4 [02:32<00:00, 38.04s/it]


     epoch      loss       gpu



    17/300      4.07       14G: 100%|██████████| 4/4 [02:31<00:00, 37.96s/it]


     epoch      loss       gpu



    18/300     3.993       14G: 100%|██████████| 4/4 [02:33<00:00, 38.47s/it]


     epoch      loss       gpu



    19/300     3.957       14G: 100%|██████████| 4/4 [02:31<00:00, 37.82s/it]


     epoch      loss       gpu



    20/300      3.97       14G: 100%|██████████| 4/4 [02:35<00:00, 38.86s/it]


     epoch      loss       gpu



    21/300     3.969       14G: 100%|██████████| 4/4 [02:30<00:00, 37.70s/it]


     epoch      loss       gpu



    22/300     3.749       14G: 100%|██████████| 4/4 [02:26<00:00, 36.70s/it]


     epoch      loss       gpu



    23/300     3.984       14G: 100%|██████████| 4/4 [02:38<00:00, 39.66s/it]


     epoch      loss       gpu



  0%|          | 0/4 [00:00<?, ?it/s]

# Nouvelle section