# Скачиваем данные

In [None]:
# Download the competition data archive (data.zip) into the current directory.
!wget https://storage.yandexcloud.net/datasouls-competitions/ai-nto-final-2022/data.zip

In [None]:
# Extract the downloaded archive in place.
!unzip data.zip

In [None]:
# Delete the archive once it has been extracted.
!rm -r data.zip

# Ставим библиотеки

In [None]:
# PyTorch / torchvision pinned to the CUDA 11.1 wheels (the "+cu111" builds).
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# Training utilities; albumentations is pinned to 1.0.0.
!pip install -q transformers datasets jiwer wandb albumentations==1.0.0
# Replace the headless OpenCV build with version 4.1.2.30.
# NOTE(review): presumably required for compatibility with the pinned albumentations - confirm.
!pip uninstall opencv-python-headless==4.5.5.62 -y
!pip install opencv-python-headless==4.1.2.30
# Augmentor is used by the Skew transform, augmixations by AugHWB.
!pip install Augmentor
!pip install augmixations==0.1.2 bezier==2020.5.19

# Импортируем все библиотеки

In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import wandb

import numpy as np
import cv2
import os
import json
from matplotlib import pyplot as plt
import albumentations as A
import Augmentor
from augmixations import HandWrittenBlot

In [None]:
# Mount Google Drive at /content/drive (only works inside Google Colab).
from google.colab import drive
drive.mount('/content/drive')

# Train/Test split

In [None]:
import pandas as pd

# Recognition labels; the 'file_name' and 'text' columns are used below to build the splits.
df = pd.read_csv('/content/data/train_recognition/labels.csv')

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 5% of the labelled rows for validation; the fixed seed keeps the split reproducible.
train_df, val_df = train_test_split(df, test_size=0.05, random_state=56)

In [None]:
# Display the validation split as the cell output.
val_df

In [None]:
import json

# Persist each split as a {file_name: text} JSON map for OCRDataset to load.
# The files are opened with a context manager so they are flushed and closed
# before anything reads them back.
for split_df, split_path in (
    (train_df, 'data/train_recognition/train_labels_splitted.json'),
    (val_df, 'data/train_recognition/val_labels_splitted.json'),
):
    with open(split_path, 'w') as labels_file:
        json.dump(dict(zip(split_df['file_name'], split_df['text'])), labels_file)

# Задаём параметры

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Experiment configuration consumed by train() and get_loaders().
config_json = {
    # Characters the tokenizer knows; anything else is encoded as the OOV token.
    "alphabet": " !\"#$%&'()*+,-./0123456789:;<=>?@[\\]^_`{|}~«»ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM",
    # Directory where train() writes checkpoints.
    "save_dir": "data/experiments/test",
    # Number of epochs train() iterates over.
    "num_epochs": 50,
    # Target size passed to ImageResize for both train and validation.
    "image": {
        "width": 256,
        "height": 64
    },
    # Image folder, label map and batch size for the training loader.
    "train": {
        "root_path": "data/train_recognition/images/",
        "json_path": "data/train_recognition/train_labels_splitted.json",
        "batch_size": 256
    },
    # Image folder, label map and batch size for the validation loader.
    "val": {
        "root_path": "data/train_recognition/images/",
        "json_path": "data/train_recognition/val_labels_splitted.json",
        "batch_size": 256
    }
}

# Теперь определяем класс датасета (torch.utils.data.Dataset) и другие вспомогательные функции

In [None]:
# Helper that merges images and target texts into a single batch.
def collate_fn(batch):
    """Collate ``(image, text, enc_text)`` samples into batch tensors.

    Returns:
        A tuple ``(images, texts, enc_pad_texts, text_lens)``: the stacked
        images, the tuple of raw strings, the encoded texts right-padded with
        0 to a common length, and the length of every raw string.
    """
    images, texts, enc_texts = zip(*batch)
    stacked_images = torch.stack(images, 0)
    lengths = torch.LongTensor(list(map(len, texts)))
    padded_targets = pad_sequence(
        enc_texts, batch_first=True, padding_value=0)
    return stacked_images, texts, padded_targets, lengths


def get_data_loader(
    transforms, json_path, root_path, tokenizer, batch_size, drop_last,
    a_transforms=None, shuffle=False, num_workers=8
):
    """Build a ``DataLoader`` over an ``OCRDataset``.

    Args:
        transforms: Callable applied to every image after ``a_transforms``.
        json_path: Path to the ``{file_name: text}`` JSON label map.
        root_path: Directory that contains the image files.
        tokenizer: Object whose ``encode`` turns texts into id sequences.
        batch_size: Number of samples per batch.
        drop_last: Whether to drop the final incomplete batch. Previously
            this argument was accepted but silently ignored.
        a_transforms: Optional albumentations pipeline applied first.
        shuffle: Whether to reshuffle the samples every epoch. Defaults to
            ``False``, which matches the previous behaviour.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    dataset = OCRDataset(json_path, root_path, tokenizer, a_transforms, transforms)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )


class OCRDataset(Dataset):
    """Dataset of text-line images paired with their transcriptions.

    Args:
        json_path: Path to a JSON file mapping image file name to text.
        root_path: Directory the image file names are relative to.
        tokenizer: Object whose ``encode(list_of_texts)`` returns one list of
            integer ids per text.
        a_transform: Optional albumentations-style callable, invoked as
            ``a_transform(image=image)['image']``.
        transform: Optional callable applied to the image afterwards.
    """

    def __init__(self, json_path, root_path, tokenizer, a_transform=None, transform=None):
        super().__init__()
        self.a_transform = a_transform
        self.transform = transform
        with open(json_path, 'r') as f:
            data = json.load(f)
        self.data_len = len(data)

        self.img_paths = [os.path.join(root_path, name) for name in data]
        self.texts = list(data.values())
        # All texts are encoded once up front rather than on every access.
        self.enc_texts = tokenizer.encode(self.texts)

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        """Return ``(image, text, enc_text)`` for the sample at ``idx``.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        img_path = self.img_paths[idx]
        text = self.texts[idx]
        enc_text = torch.LongTensor(self.enc_texts[idx])
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None; fail here with
            # the offending path instead of deep inside a transform.
            raise FileNotFoundError(f'Cannot read image: {img_path}')
        if self.a_transform is not None:
            image = self.a_transform(image=image)['image']
        if self.transform is not None:
            image = self.transform(image)
        return image, text, enc_text


class AverageMeter:
    """Running weighted average of a stream of values.

    Attributes ``sum``, ``count`` and ``avg`` hold the weighted total, the
    total weight and their ratio.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.avg, self.sum, self.count = 0, 0, 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the average."""
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

# Здесь определён Токенайзер — вспомогательный класс, который преобразует текст в числа

In [None]:
# Special tokens; get_char_map assigns them the ids 1, 0, 2 and 3 respectively.
OOV_TOKEN = '<OOV>'  # stands in for characters missing from the alphabet
CTC_BLANK = '<BLANK>'  # id 0, the blank index passed to CTCLoss
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'


def get_char_map(alphabet):
    """Build the character-to-id dictionary for ``alphabet``.

    Ids 0-3 are reserved for the CTC blank, out-of-vocabulary, start and end
    tokens; the alphabet characters follow, numbered from 4 in order.
    """
    char_map = {}
    for position, symbol in enumerate(alphabet):
        char_map[symbol] = position + 4
    char_map.update({
        CTC_BLANK: 0,
        OOV_TOKEN: 1,
        SOS_TOKEN: 2,
        EOS_TOKEN: 3,
    })
    return char_map


class Tokenizer:
    """Convert words to sequences of integer ids and back, using an alphabet."""

    def __init__(self, alphabet):
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = {}
        for symbol, code in self.char_map.items():
            self.rev_char_map[code] = symbol

    def encode(self, word_list):
        """Return one list of ids per word; unknown characters become OOV."""
        lookup = self.char_map
        encoded = []
        for word in word_list:
            codes = []
            for symbol in word:
                if symbol in lookup:
                    codes.append(lookup[symbol])
                else:
                    codes.append(lookup[OOV_TOKEN])
            encoded.append(codes)
        return encoded

    def get_num_chars(self):
        """Return the vocabulary size, special tokens included."""
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Return one string per id sequence.

        An id is dropped when it is a special token (OOV, blank, SOS, EOS) or
        when it repeats the id immediately before it in the raw sequence.
        """
        special_codes = [
            self.char_map[OOV_TOKEN],
            self.char_map[CTC_BLANK],
            self.char_map[SOS_TOKEN],
            self.char_map[EOS_TOKEN],
        ]
        decoded = []
        for word in enc_word_list:
            kept = []
            for position, code in enumerate(word):
                # position > 0 guards against comparing with word[-1]
                is_repeat = position > 0 and code == word[position - 1]
                if code in special_codes or is_repeat:
                    continue
                kept.append(self.rev_char_map[code])
            decoded.append(''.join(kept))
        return decoded

# Ставим метрику 

In [None]:
from datasets import load_metric

# Character error rate metric; get_accuracy reports 1 - CER.
# NOTE(review): load_metric is reportedly deprecated in newer `datasets` releases - confirm the installed version still provides it.
cer_metric = load_metric("cer")

In [None]:
def get_accuracy(y_true, y_pred):
    """Return ``1 - CER`` of the predictions against the reference texts."""
    error_rate = cer_metric.compute(predictions=y_pred, references=y_true)
    return 1 - error_rate

# Делаем аугментации

In [None]:
import random
class Normalize:
    """Cast an image to float32 and divide every value by 255."""

    def __call__(self, img):
        as_float = img.astype(np.float32)
        return as_float / 255


class ToTensor:
    """Wrap a NumPy array in a torch tensor via ``torch.from_numpy``."""

    def __call__(self, arr):
        return torch.from_numpy(arr)


class MoveChannels:
    """Move the channel axis between the last and the first position.

    With ``to_channels_first=True`` the last axis becomes the first (the
    layout PyTorch expects); otherwise the first axis becomes the last.
    """

    def __init__(self, to_channels_first=True):
        self.to_channels_first = to_channels_first

    def __call__(self, image):
        source, destination = (-1, 0) if self.to_channels_first else (0, -1)
        return np.moveaxis(image, source, destination)


class ImageResize:
    """Resize a 3-channel image to ``height`` x ``width``.

    The image is first scaled to the target height with its aspect ratio
    kept. A result narrower than ``width`` is padded on the right with white
    (255); a wider one is squeezed to ``width``.

    Args:
        height: Target number of rows.
        width: Target number of columns.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width

    def __call__(self, img):
        """Return the resized image as a float32 ``(height, width, 3)`` array.

        The input array is left untouched.
        """
        rows, cols = img.shape[:2]
        if rows > cols * 2 and cols > 80:
            # Tall, narrow crops are treated as vertical text and laid flat.
            # cv2.ROTATE_90_CLOCKWISE is the public constant; the former
            # cv2.cv2 alias is missing from newer OpenCV wheels.
            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        else:
            # Work on a copy so the caller's array is not modified below.
            img = img.copy()
        # Pixels that are exactly black in every channel are turned white.
        img[img.sum(axis=2) == 0] = np.array([255, 255, 255])

        rows, cols = img.shape[:2]
        # Keep at least one column so cv2.resize never gets a zero size.
        scaled_cols = max(1, int(cols * (self.height / rows)))
        img = cv2.resize(img, (scaled_cols, self.height))
        img = img.astype('float32')

        cols = img.shape[1]
        if cols < self.width:
            # Pad with the image's own dtype so the result stays float32.
            padding = np.full(
                (img.shape[0], self.width - cols, 3), 255, dtype=img.dtype)
            img = np.concatenate((img, padding), axis=1)
        elif cols > self.width:
            img = cv2.resize(img, (self.width, self.height))
        return img

class AlbuHandWrittenBlot(A.DualTransform):
    """Albumentations wrapper that paints blots produced by ``hwb``.

    Every pixel that ``hwb`` changed is recoloured to a fixed colour in the
    input image, which is modified in place and returned.

    Note that ``__call__`` is overridden: it does not consult ``p`` and it
    returns the image array directly.
    """

    def __init__(self, hwb, always_apply=False, p=0.5):
        super().__init__(always_apply, p)
        self.hwb = hwb

    def apply(self, image, **params):
        blotted = self.hwb(image)
        difference = blotted - image
        changed = difference.sum(axis=2) != 0
        image[changed] = np.array([69, 61, 96])
        return image

    def __call__(self, image, **params):
        return self.apply(image, **params)

class UseWithProb:
    """Apply ``transform`` with probability ``prob``, else pass the image through."""

    def __init__(self, transform, prob=0.5):
        self.transform = transform
        self.prob = prob

    def __call__(self, image):
        should_apply = random.random() < self.prob
        return self.transform(image) if should_apply else image


class RandomGaussianBlur:
    """Gaussian blur whose kernel size is drawn at random on every call.

    Args:
        max_ksize (int): largest kernel size that may be used, must be odd
        sigma_x (int): standard deviation passed to ``cv2.GaussianBlur``
    """

    def __init__(self, max_ksize=5, sigma_x=20):
        assert max_ksize % 2 == 1, "max_ksize should be odd"
        # Stored as the exclusive upper bound for the half-size draw below.
        self.max_ksize = max_ksize // 2 + 1
        self.sigma_x = sigma_x

    def __call__(self, image):
        half_sizes = np.random.randint(0, self.max_ksize, 2)
        # 2 * k + 1 keeps both kernel dimensions odd.
        kernel_size = tuple(2 * half_sizes + 1)
        return cv2.GaussianBlur(image, kernel_size, self.sigma_x)


def img_crop(img, bbox):
    """Return the ``(x1, y1, x2, y2)`` region of ``img`` as a slice."""
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    return img[top:bottom, left:right]


def random_crop(img, size):
    """Cut a ``size = (width, height)`` window out of ``img`` at a random spot.

    The offset is randomised only when the window is strictly smaller than
    the image in both dimensions; otherwise the window starts at ``(0, 0)``.

    Returns:
        ``(cropped_image, x1, y1)`` where ``x1, y1`` is the window's top-left.
    """
    crop_w, crop_h = size[0], size[1]
    img_h, img_w = img.shape[:2]
    x1 = y1 = 0
    if (img_w - crop_w) > 0 and (img_h - crop_h) > 0:
        x1 = random.randint(0, img_w - crop_w)
        y1 = random.randint(0, img_h - crop_h)
    window = (x1, y1, x1 + crop_w, y1 + crop_h)
    return img_crop(img, window), x1, y1


class RandomCrop:
    """Crop a random window whose sides are a random fraction of the image.

    Args:
        rnd_crop_min: Lower bound of the size factor.
        rnd_crop_max: Upper bound of the size factor.
    """

    def __init__(self, rnd_crop_min, rnd_crop_max=1):
        self.factor_max = rnd_crop_max
        self.factor_min = rnd_crop_min

    def __call__(self, img):
        factor = random.uniform(self.factor_min, self.factor_max)
        img_h, img_w = img.shape[0], img.shape[1]
        target_size = (int(img_w * factor), int(img_h * factor))
        cropped, _, _ = random_crop(img, target_size)
        return cropped


def largest_rotated_rect(w, h, angle):
    """Size of an axis-aligned rectangle that fits inside a rotated one.

    Adapted from https://stackoverflow.com/a/16770343 (original JS code by
    'Andri' and Magnus Hoff, converted to Python by Aaron Snoswell).

    Args:
        w: Width of the rectangle before rotation.
        h: Height of the rectangle before rotation.
        angle: Rotation angle in radians.

    Returns:
        ``(width, height)`` of the inner axis-aligned rectangle.
    """
    half_pi = math.pi / 2
    quadrant = int(math.floor(angle / half_pi)) & 3
    if quadrant & 1:
        sign_alpha = math.pi - angle
    else:
        sign_alpha = angle
    # Fold the angle into the range [0, pi).
    alpha = (sign_alpha % math.pi + math.pi) % math.pi

    cos_alpha = math.cos(alpha)
    sin_alpha = math.sin(alpha)

    # Bounding box of the rotated rectangle.
    bb_w = w * cos_alpha + h * sin_alpha
    bb_h = w * sin_alpha + h * cos_alpha

    # The original chose between two branches on (w < h), but both were
    # atan2(bb_w, bb_w), so the result never depended on the test.
    # NOTE(review): the JS source may have used bb_h here - confirm before
    # changing, since this form is what the pipeline was tuned with.
    gamma = math.atan2(bb_w, bb_w)

    delta = math.pi - alpha - gamma

    length = h if (w < h) else w

    d = length * cos_alpha
    a = d * sin_alpha / math.sin(delta)

    y = a * math.cos(gamma)
    x = y * math.tan(gamma)

    return bb_w - 2 * x, bb_h - 2 * y


def crop_around_center(image, width, height):
    """Crop a NumPy / OpenCV image to ``width`` x ``height`` around its centre.

    Adapted from https://stackoverflow.com/a/16770343. A requested size
    larger than the image is clamped to the image size.
    """
    img_h, img_w = image.shape[0], image.shape[1]
    center_x = int(img_w * 0.5)
    center_y = int(img_h * 0.5)

    width = img_w if width > img_w else width
    height = img_h if height > img_h else height

    half_w = width * 0.5
    half_h = height * 0.5
    left, right = int(center_x - half_w), int(center_x + half_w)
    top, bottom = int(center_y - half_h), int(center_y + half_h)

    return image[top:bottom, left:right]


class RandomRotate:
    """Rotate an image about its centre by a random angle, then crop.

    After rotating, the image is cropped around its centre to the size
    returned by ``largest_rotated_rect``.

    Args:
        max_ang (float): largest rotation in degrees, in either direction
    """

    def __init__(self, max_ang=0):
        self.max_ang = max_ang

    def __call__(self, img):
        height, width, _ = img.shape

        angle_deg = np.random.uniform(-self.max_ang, self.max_ang)
        rotation = cv2.getRotationMatrix2D(
            (width / 2, height / 2), angle_deg, 1)
        rotated = cv2.warpAffine(img, rotation, (width, height))

        crop_w, crop_h = largest_rotated_rect(
            width, height, math.radians(angle_deg))
        return crop_around_center(rotated, crop_w, crop_h)

class Skew:
    """Left/right perspective skew performed with an Augmentor pipeline.

    Args:
        magnitude: Skew strength forwarded to ``skew_left_right``.
    """

    def __init__(self, magnitude):
        self.m = magnitude

    def __call__(self, img):
        pipeline = Augmentor.DataPipeline([[img]])
        pipeline.skew_left_right(probability=1, magnitude=self.m)
        samples = pipeline.sample(1)
        # One sample holding one image was requested; return a copy of it.
        return samples[0][0].copy()

class AugHWB:
    """Draw random handwritten-style blots on an image with ``HandWrittenBlot``.

    Args:
        coef: Fraction of the image width used as the maximum blot width.
    """

    def __init__(self, coef):
        self.coef = coef

    def __call__(self, img):
        max_blot_width = int(img.shape[1] * self.coef)
        # Each entry is a (minimum, maximum) pair; a single int is also accepted.
        # NOTE(review): None presumably leaves the bound to augmixations - confirm.
        rectangle_info = {
            'x': (None, None),  # blot X position
            'y': (None, None),  # blot Y position
            'h': (None, None),  # blot height
            'w': (None, max_blot_width),  # blot width
        }
        blot_params = {
            # Left or right blot points are shifted by this value.
            'incline': (-10, 10),
            # Share of points generated for a blot, a float in (0, 1).
            'intensivity': (0.4, 0.8),
            # Blot transparency, a float in (0, 1).
            'transparency': (0.2, 0.6),
            # Number of blots.
            'count': 3,
        }
        blotter = HandWrittenBlot(rectangle_info, blot_params)
        return blotter(img)



def get_train_transforms(height, width):
    """Build the training augmentations.

    Returns:
        ``(a_transforms, transforms)``: an albumentations pipeline of
        photometric augmentations, and a torchvision pipeline of geometric
        augmentations followed by resizing to ``height`` x ``width``,
        channel reordering, normalisation and tensor conversion.
    """
    # Each group picks one of its members and fires with probability 0.1.
    colour_shift = A.OneOf([
        A.RGBShift(p=1),
        A.HueSaturationValue(p=1),
        A.CLAHE(p=1),
    ], p=0.1)
    blur = A.OneOf([
        A.Blur(blur_limit=7, p=1),
        A.GaussianBlur(p=1),
        A.MedianBlur(blur_limit=7, p=1),
    ], p=0.1)
    tone = A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=1),
        A.RandomGamma(gamma_limit=(80, 150), p=1),
        A.RandomToneCurve(p=1),
    ], p=0.1)
    jitter_or_compression = A.OneOf([
        A.ColorJitter(p=1),
        A.JpegCompression(p=1),
    ], p=0.1)
    noise = A.OneOf([
        A.ISONoise(p=1),
        A.GaussNoise(p=1),
        A.MultiplicativeNoise(p=1),
    ], p=0.1)
    a_transforms = A.Compose([
        colour_shift,
        blur,
        tone,
        jitter_or_compression,
        noise,
        A.Sharpen(p=0.1),
        A.RandomShadow(p=0.1),
    ])

    # Geometric augmentations each fire with probability 0.2 and run before
    # the deterministic preprocessing steps.
    steps = [
        UseWithProb(AugHWB(0.3), 0.2),
        UseWithProb(RandomRotate(5), 0.2),
        UseWithProb(RandomCrop(rnd_crop_min=0.85), 0.2),
        UseWithProb(Skew(0.2), 0.2),
        ImageResize(height, width),
        MoveChannels(to_channels_first=True),
        Normalize(),
        ToTensor(),
    ]
    transforms = torchvision.transforms.Compose(steps)
    return a_transforms, transforms


def get_val_transforms(height, width):
    """Build the validation preprocessing: resize, channels first, scale, tensor."""
    steps = [
        ImageResize(height, width),
        MoveChannels(to_channels_first=True),
        Normalize(),
        ToTensor(),
    ]
    return torchvision.transforms.Compose(steps)

# Здесь определяем саму модель - CRNN

Подробнее об архитектуре можно почитать в статье https://arxiv.org/abs/1507.05717

In [None]:
import math
def get_backbone(pretrained=True):
    """Build the convolutional feature extractor from ResNet-34.

    The network keeps ResNet-34's ``bn1``, ``relu``, ``maxpool`` and
    ``layer1``-``layer3``, preceded by a freshly initialised 7x7 input
    convolution with stride 1 and padding 3.

    Args:
        pretrained: Whether to load pretrained ResNet-34 weights. Previously
            this argument was ignored and the weights were always loaded.

    Returns:
        An ``nn.Sequential`` feature extractor.
    """
    resnet = torchvision.models.resnet34(pretrained=pretrained)
    # The unused 1x1 ``output_conv`` that used to be built here is gone: it
    # was never added to the returned network.
    input_conv = nn.Conv2d(3, 64, 7, 1, 3)
    blocks = [input_conv, resnet.bn1, resnet.relu,
              resnet.maxpool, resnet.layer1, resnet.layer2, resnet.layer3]
    return nn.Sequential(*blocks)


class BiLSTM(nn.Module):
    """Bidirectional, batch-first LSTM that returns only the output sequence."""

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )

    def forward(self, x):
        # The final hidden and cell states are not needed downstream.
        sequence_out, _states = self.lstm(x)
        return sequence_out

class LSTMBlock(nn.Module):
    """Recognition head: adaptive pooling, BiLSTM and a per-step classifier.

    Args:
        number_class_symbols: Size of the output vocabulary.
        width: Number of columns after pooling; it becomes the sequence
            length seen by the LSTM.
        time_feature_count: Number of rows after pooling; it becomes the
            feature size of every LSTM step.
        lstm_hidden: Hidden size of each LSTM direction.
        lstm_len: Number of stacked LSTM layers.
        gelu: Use a two-layer classifier with GELU and dropout instead of a
            single linear layer.
    """

    def __init__(self, number_class_symbols, width=64, time_feature_count=256, lstm_hidden=256, lstm_len=2, gelu=False):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((time_feature_count, width))
        self.bilstm = BiLSTM(time_feature_count, lstm_hidden, lstm_len)
        # The BiLSTM concatenates both directions, hence lstm_hidden * 2.
        if not gelu:
            self.classifier = nn.Linear(lstm_hidden * 2, number_class_symbols)
        else:
            self.classifier = nn.Sequential(
                nn.Linear(lstm_hidden * 2, lstm_hidden),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(lstm_hidden, number_class_symbols)
            )

    def forward(self, x):
        """Return log-probabilities with the sequence axis first."""
        pooled = self.avg_pool(x)
        sequence = pooled.transpose(1, 2)
        logits = self.classifier(self.bilstm(sequence))
        log_probs = nn.functional.log_softmax(logits, dim=2)
        # CTCLoss expects (time, batch, classes).
        return log_probs.permute(1, 0, 2)

class CRNN(nn.Module):
    """Convolutional backbone followed by recurrent CTC heads.

    ``forward`` returns the outputs of ``lstm1`` and ``lstm2`` only;
    ``lstm3`` and ``lstm4`` are constructed but not used by ``forward``.
    The ``time_feature_count``, ``lstm_hidden`` and ``lstm_len`` arguments
    are accepted but not read: every head uses fixed sizes.
    """

    def __init__(
        self, number_class_symbols, time_feature_count=256, lstm_hidden=256,
        lstm_len=2,
    ):
        super().__init__()
        self.feature_extractor = get_backbone(pretrained=True)
        self.lstm1 = LSTMBlock(
            number_class_symbols, width=64, time_feature_count=256,
            lstm_hidden=256, lstm_len=2)
        self.lstm2 = LSTMBlock(
            number_class_symbols, width=64, time_feature_count=256,
            lstm_hidden=512, lstm_len=2, gelu=True)
        self.lstm3 = LSTMBlock(
            number_class_symbols, width=64, time_feature_count=256,
            lstm_hidden=256, lstm_len=3)
        self.lstm4 = LSTMBlock(
            number_class_symbols, width=64, time_feature_count=512,
            lstm_hidden=256, lstm_len=2)

    def forward(self, x):
        """Return the CTC log-probabilities of the first two heads."""
        features = self.feature_extractor(x)
        batch, channels, rows, cols = features.size()
        # Fold the channel and row axes together; columns stay separate.
        flat = features.view(batch, channels * rows, cols)
        first_head = self.lstm1(flat)
        second_head = self.lstm2(flat)
        return first_head, second_head


# Переходим к самому скрипту обучения - циклы трейна и валидации

In [None]:
class CtcAttentionLoss():
    """Compute a CTC loss and a cross-entropy (attention) loss together.

    NOTE(review): the CTC loss is returned twice, as the first and the third
    element; the first slot looks like it was meant for a combined loss -
    confirm with the author.
    """

    def __init__(self):
        self.ctc_criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
        self.att_criterion = nn.CrossEntropyLoss(ignore_index=0)

    def __call__(self, ctc_outputs, enc_pad_texts, output_lenghts, text_lens, att_outputs):
        """Return ``(ctc_loss, att_loss, ctc_loss)``."""
        class_count = att_outputs.size(2)
        att_logits = att_outputs.view(-1, class_count)
        # Targets skip the first token and are flattened step-major to line
        # up with the flattened logits.
        att_targets = enc_pad_texts[:, 1:].permute(1, 0).reshape(-1)
        att_loss = self.att_criterion(att_logits, att_targets)

        ctc_loss = self.ctc_criterion(
            ctc_outputs, enc_pad_texts, output_lenghts, text_lens)
        return ctc_loss, att_loss, ctc_loss

In [None]:
from tqdm.notebook import tqdm
import transformers
def val_loop(data_loader, model, tokenizer, device):
    """Evaluate ``model`` on ``data_loader`` and return the mean accuracy.

    Accuracy is computed per batch with ``get_accuracy`` and averaged with
    the batch sizes as weights. The references and predictions of the last
    batch are printed as a quick sanity check.

    Args:
        data_loader: Iterable of ``(images, texts, _, _)`` batches.
        model: Model passed to ``predict``.
        tokenizer: Tokenizer passed to ``predict``.
        device: Device passed to ``predict``.

    Returns:
        The weighted average accuracy (0 for an empty loader).
    """
    acc_avg = AverageMeter()
    texts, text_preds = None, None
    for images, texts, _, _ in tqdm(data_loader):
        text_preds = predict(images, model, tokenizer, device)
        acc_avg.update(get_accuracy(texts, text_preds), len(texts))
    # An empty loader yields no batch to show; skip the preview rather than
    # failing on unbound loop variables.
    if text_preds is not None:
        print(texts)
        print(text_preds)
    print(f'Validation, acc: {acc_avg.avg:.4f}')
    return acc_avg.avg


def train_loop(data_loader, model, criterion, optimizer, epoch, scheduler):
    """Train ``model`` for one epoch and return the average loss.

    Every batch is pushed through both heads of the model; the two CTC
    losses are summed and back-propagated together. The scheduler is stepped
    after every batch.

    Args:
        data_loader: Iterable of ``(images, texts, enc_pad_texts, text_lens)``.
        model: Model returning two CTC outputs.
        criterion: CTC criterion called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the log line.
        scheduler: Learning-rate scheduler stepped once per batch.

    Returns:
        The batch-size-weighted average of the summed losses.
    """
    loss_avg = AverageMeter()
    model.train()
    for images, texts, enc_pad_texts, text_lens in tqdm(data_loader):
        model.zero_grad()
        images = images.to(device=DEVICE)
        enc_pad_texts = enc_pad_texts.to(device=DEVICE)
        ctc_output1, ctc_output2 = model(images)

        head_losses = []
        for ctc_output in (ctc_output1, ctc_output2):
            # Every sample in the batch uses the full output sequence length.
            output_lenghts = torch.full(
                size=(ctc_output.size(1),),
                fill_value=ctc_output.size(0),
                dtype=torch.long
            )
            head_losses.append(
                criterion(ctc_output, enc_pad_texts, output_lenghts, text_lens))
        loss1, loss2 = head_losses

        loss_avg.update(loss1.item() + loss2.item(), len(texts))
        total_loss = loss1 + loss2
        total_loss.backward()
        optimizer.step()
        scheduler.step()
    # Report the learning rate of the last parameter group.
    for param_group in optimizer.param_groups:
        lr = param_group['lr']
    print(f'\nEpoch {epoch}, Loss: {loss_avg.avg:.5f}, LR: {lr:.7f}')
    return loss_avg.avg


def predict(images, model, tokenizer, device):
    """Greedy-decode the texts for a batch of images.

    The model is switched to eval mode and only its second head is used:
    the most likely class at every step is taken and the id sequences are
    handed to ``tokenizer.decode``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the batch of id sequences.
    """
    model.eval()
    batch = images.to(device=device)
    with torch.no_grad():
        _first_head, second_head = model(batch)
        ctc_output = 1.0 * second_head
    best_ids = torch.argmax(ctc_output.detach().cpu(), -1)
    # (time, batch) -> (batch, time) so that each row is one sample.
    return tokenizer.decode(best_ids.permute(1, 0).numpy())


def get_loaders(tokenizer, config):
    """Build the training and validation loaders described by ``config``.

    Both loaders resize to ``config['image']``; only the training loader
    gets augmentations and asks for incomplete batches to be dropped.

    Returns:
        ``(train_loader, val_loader)``.
    """
    image_cfg = config['image']
    target_size = dict(height=image_cfg['height'], width=image_cfg['width'])

    a_transforms, train_transforms = get_train_transforms(**target_size)
    train_cfg = config['train']
    train_loader = get_data_loader(
        json_path=train_cfg['json_path'],
        root_path=train_cfg['root_path'],
        a_transforms=a_transforms,
        transforms=train_transforms,
        tokenizer=tokenizer,
        batch_size=train_cfg['batch_size'],
        drop_last=True
    )

    val_transforms = get_val_transforms(**target_size)
    val_cfg = config['val']
    val_loader = get_data_loader(
        transforms=val_transforms,
        json_path=val_cfg['json_path'],
        root_path=val_cfg['root_path'],
        tokenizer=tokenizer,
        batch_size=val_cfg['batch_size'],
        drop_last=False
    )
    return train_loader, val_loader


def train(config):
    """Run the full training procedure described by ``config``.

    After every epoch the model is validated. The weights are saved whenever
    the validation accuracy improves, and a full checkpoint (model,
    optimizer and scheduler state) is written regardless.

    Args:
        config: Dictionary with the keys ``alphabet``, ``save_dir``,
            ``num_epochs``, ``image``, ``train`` and ``val``.
    """
    tokenizer = Tokenizer(config['alphabet'])
    save_dir = config['save_dir']
    os.makedirs(save_dir, exist_ok=True)
    train_loader, val_loader = get_loaders(tokenizer, config)
    model = CRNN(number_class_symbols=tokenizer.get_num_chars())
    model.to(device=DEVICE)

    criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=0.001, weight_decay=0.01)
    # Warm up for 5 epochs' worth of steps; the cosine schedule spans 100.
    # NOTE(review): the horizon is fixed at 100 epochs whatever
    # config['num_epochs'] says - confirm this is intentional.
    steps_per_epoch = len(train_loader)
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=5 * steps_per_epoch,
        num_training_steps=100 * steps_per_epoch,
    )
    best_acc = -np.inf

    for epoch in tqdm(range(config['num_epochs'])):
        train_loop(train_loader, model, criterion, optimizer, epoch, scheduler)
        acc_avg = val_loop(val_loader, model, tokenizer, DEVICE)

        if acc_avg > best_acc:
            best_acc = acc_avg
            best_path = os.path.join(
                save_dir, f'model-{epoch}-{acc_avg:.4f}.ckpt')
            torch.save(model.state_dict(), best_path)
            print('Model weights saved')

        # A resumable checkpoint is written every epoch.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        checkpoint_path = os.path.join(
            save_dir, f'all_model-{epoch}-{acc_avg:.4f}.ckpt')
        torch.save(checkpoint, checkpoint_path)

# Учимся!

In [None]:
# Start training with the configuration defined above.
train(config_json)