In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the read-only Kaggle input directory.
input_file_paths = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## 基本环境的准备

In [None]:
!pip install streamlit
!pip install hydra-core
!pip install python-multipart
!pip install editdistance

In [None]:
from pathlib import Path
import json
import tarfile
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
from urllib.request import urlretrieve
import pytorch_lightning.tuner as tuner
from pytorch_lightning.tuner.lr_finder import _LRFinder

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

import random
from pathlib import Path
from typing import Optional

# albumentations 是一个适用于计算机视觉任务的图像增强库
import albumentations as A
import torch
from albumentations.pytorch.transforms import ToTensorV2
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from argparse import Namespace
from typing import List, Optional

import math
import torch.nn as nn
from torch import Tensor
from typing import Union
import torchvision.models
from pytorch_lightning import LightningModule
from typing import Set
import editdistance
from torchmetrics import Metric

import hydra
from omegaconf import DictConfig
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers.wandb import WandbLogger
# import os
# os.environ['CUDA_VISIBLE_DEVICES']='0'

## 运行时使用的一些工具函数，放在utils.py中

In [None]:
# utils.py

class TqdmUpTo(tqdm):
    """Progress bar whose `update_to` can be used as a download report hook.

    Adapted from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py.
    Subclasses `tqdm` to track download progress and refresh the progress bar.
    """

    def update_to(self, blocks=1, bsize=1, tsize=None) -> None:
        """Move the bar to the absolute position `blocks * bsize`.

        Args:
            blocks: Number of blocks transferred so far.
            bsize: Size of each block (in tqdm units).
            tsize: Total size (in tqdm units). When given, it replaces the
                bar's current total.
        """
        if tsize is not None:
            self.total = tsize
        transferred = blocks * bsize
        # `update` expects an increment, so subtract what is already displayed.
        self.update(transferred - self.n)


def download_url(url: str, filename: str) -> None:
    """Fetch `url` into the local file `filename` while showing a progress bar.

    Args:
        url: Address of the resource to download.
        filename: Destination path; also used as the progress bar description.
    """
    progress_bar = TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1)
    with progress_bar:
        progress_bar.set_description(filename)
        # `update_to` is invoked by urlretrieve as its report hook.
        urlretrieve(url, filename, reporthook=progress_bar.update_to, data=None)


def extract_tar_file(filename: str) -> None:
    """Extract a .tar or .tar.gz file into the current working directory.

    Every archive member is checked before anything is written: a member whose
    resolved path falls outside the current working directory (for example
    `../x` or an absolute path) aborts the extraction, so a malicious archive
    cannot overwrite files elsewhere on disk.

    Args:
        filename: Path of the archive to extract.

    Raises:
        tarfile.TarError: If a member would be extracted outside the current
            working directory, or if the archive cannot be read.
    """
    print(f"Extracting {filename}...")
    destination = Path.cwd().resolve()
    with tarfile.open(filename, "r") as f:
        for member in f.getmembers():
            target = (destination / member.name).resolve()
            if target != destination and destination not in target.parents:
                raise tarfile.TarError(
                    f"Refusing to extract {member.name!r}: path escapes {destination}"
                )
        if hasattr(tarfile, "data_filter"):
            # Newer Python versions ship extraction filters; "data" additionally
            # rejects unsafe links and special files.
            f.extractall(path=destination, filter="data")
        else:
            f.extractall(path=destination)


class BaseDataset(Dataset):
    """Dataset that pairs formula image files with their token lists.

    Subclasses `torch.utils.data.Dataset` to build a custom dataset.

    Args:
        root_dir: Directory that the image filenames are relative to.
        image_filenames: Image filename of every sample.
        formulas: Token list of every sample; must have the same length as
            `image_filenames`.
        transform: Optional callable invoked as `transform(image=ndarray)`;
            the `"image"` entry of its result replaces the loaded image.
    """

    def __init__(
        self,
        root_dir: Path,
        image_filenames: List[str],
        formulas: List[List[str]],
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        assert len(image_filenames) == len(formulas)
        self.root_dir = root_dir
        self.image_filenames = image_filenames
        self.formulas = formulas
        self.transform = transform

    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.formulas)

    def __getitem__(self, idx: int):
        """Return the `(image, formula)` pair stored at position `idx`.

        When the image file does not exist, a blank white 64x128 image and an
        empty formula are returned instead.
        """
        image_filename = self.image_filenames[idx]
        formula = self.formulas[idx]
        image_filepath = self.root_dir / image_filename

        if not image_filepath.is_file():
            # Missing file: substitute an all-white placeholder with no tokens.
            image = Image.fromarray(np.full((64, 128), 255, dtype=np.uint8))
            formula = []
        else:
            image = pil_loader(image_filepath, mode="L")

        if self.transform is None:
            return image, formula
        transformed = self.transform(image=np.array(image))
        return transformed["image"], formula


def pil_loader(fp: Path, mode: str) -> Image.Image:
    """Load an image file with PIL and convert it to the requested mode.

    Args:
        fp: Path of the image file to open.
        mode: PIL mode passed to `Image.convert` (e.g. "L" or "RGBA").

    Returns:
        The converted PIL image.
    """
    with open(fp, "rb") as stream:
        # Convert while the file is still open so the pixel data gets read.
        return Image.open(stream).convert(mode)


class Tokenizer:
    """Convert formula tokens to integer indices and back.

    The vocabulary can be supplied as a token-to-index dictionary or built
    from scratch with `train`. Four special tokens (<PAD>, <SOS>, <EOS>,
    <UNK>) are always part of the vocabulary; their indices are collected in
    `self.ignore_indices` so later processing steps can skip them.

    Args:
        token_to_index: Optional existing mapping. It must contain the four
            special tokens. When omitted or empty, a fresh vocabulary holding
            only the special tokens is created.
    """

    def __init__(self, token_to_index: Optional[Dict[str, int]] = None) -> None:
        self.pad_token = "<PAD>"
        self.sos_token = "<SOS>"
        self.eos_token = "<EOS>"
        self.unk_token = "<UNK>"

        self.token_to_index: Dict[str, int]
        self.index_to_token: Dict[int, str]

        if token_to_index:
            self.token_to_index = token_to_index
            self.index_to_token = {index: token for token, index in self.token_to_index.items()}
            self.pad_index = self.token_to_index[self.pad_token]
            self.sos_index = self.token_to_index[self.sos_token]
            self.eos_index = self.token_to_index[self.eos_token]
            self.unk_index = self.token_to_index[self.unk_token]
        else:
            self.token_to_index = {}
            self.index_to_token = {}
            self.pad_index = self._add_token(self.pad_token)
            self.sos_index = self._add_token(self.sos_token)
            self.eos_index = self._add_token(self.eos_token)
            self.unk_index = self._add_token(self.unk_token)

        self.ignore_indices = {self.pad_index, self.sos_index, self.eos_index, self.unk_index}

    def _add_token(self, token: str) -> int:
        """Add one token to the vocabulary.

        Args:
            token: The token to be added.

        Returns:
            The index of the input token. A token that is already known keeps
            its existing index.
        """
        if token in self.token_to_index:
            return self.token_to_index[token]
        index = len(self)
        self.token_to_index[token] = index
        self.index_to_token[index] = token
        return index

    def __len__(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.token_to_index)

    def train(self, formulas: List[List[str]], min_count: int = 2) -> None:
        """Create a mapping from tokens to indices and vice versa.

        Tokens already in the vocabulary (including the special tokens) keep
        their index, so calling `train` repeatedly never reassigns an index or
        makes two tokens share one.

        Args:
            formulas: Lists of tokens.
            min_count: Tokens that appear fewer than `min_count` will not be
                included in the mapping.
        """
        # Count the frequency of each token
        counter: Dict[str, int] = {}
        for formula in formulas:
            for token in formula:
                counter[token] = counter.get(token, 0) + 1

        for token, count in counter.items():
            # Skip tokens that show up fewer than `min_count` times
            if count < min_count:
                continue
            # `_add_token` leaves known tokens untouched; writing the index
            # directly would remap them and corrupt the vocabulary.
            self._add_token(token)

    def encode(self, formula: List[str]) -> List[int]:
        """Convert a token list to indices, wrapped in <SOS> ... <EOS>.

        Tokens missing from the vocabulary are mapped to the <UNK> index.
        """
        indices = [self.sos_index]
        indices.extend(self.token_to_index.get(token, self.unk_index) for token in formula)
        indices.append(self.eos_index)
        return indices

    def decode(self, indices: List[int], inference: bool = True) -> List[str]:
        """Convert indices back to tokens, stopping at the first <EOS>.

        Args:
            indices: Indices to decode.
            inference: When True, the special tokens are left out of the result.

        Returns:
            The decoded tokens.

        Raises:
            RuntimeError: If an index is not part of the vocabulary.
        """
        tokens = []
        for index in indices:
            if index not in self.index_to_token:
                raise RuntimeError(f"Found an unknown index {index}")
            if index == self.eos_index:
                break
            if inference and index in self.ignore_indices:
                continue
            tokens.append(self.index_to_token[index])
        return tokens

    def save(self, filename: Union[Path, str]):
        """Save token-to-index mapping to a json file."""
        with open(filename, "w") as f:
            json.dump(self.token_to_index, f)

    @classmethod
    def load(cls, filename: Union[Path, str]) -> "Tokenizer":
        """Create a `Tokenizer` from a mapping file outputted by `save`.

        Args:
            filename: Path to the file to read from.

        Returns:
            A `Tokenizer` object.
        """
        with open(filename) as f:
            token_to_index = json.load(f)
        return cls(token_to_index)


def get_all_formulas(filename: Path) -> List[List[str]]:
    """Read the formula file and return one whitespace-split token list per line.

    Args:
        filename: Path of the UTF-8 encoded formula file.

    Returns:
        A list with one entry per line of the file, in file order.
    """
    all_formulas = []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            tokens = line.strip("\n").split()
            all_formulas.append(tokens)
    return all_formulas


def get_split(
    all_formulas: List[List[str]],
    filename: Path,
) -> Tuple[List[str], List[List[str]]]:
    """Read a split file and look up the formula of every listed image.

    Each non-blank line of the split file holds two whitespace-separated
    fields: an image name and the index of its formula in `all_formulas`.
    Blank lines (such as a trailing empty line) are skipped.

    Args:
        all_formulas: All formulas, indexed by the numbers in the split file.
        filename: Path of the split file, read as UTF-8 like the formula file.

    Returns:
        A tuple `(image_names, formulas)` of equal length, in file order.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields or
            its second field is not an integer.
        IndexError: If a formula index is out of range for `all_formulas`.
    """
    image_names = []
    formulas = []
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate empty lines instead of failing on the unpacking below.
                continue
            img_name, formula_idx = fields
            image_names.append(img_name)
            formulas.append(all_formulas[int(formula_idx)])
    return image_names, formulas


def first_and_last_nonzeros(arr):
    """Return the positions of the first and the last nonzero entry of `arr`.

    Args:
        arr: A one-dimensional indexable sequence of numbers.

    Returns:
        A tuple `(left, right)`. When `arr` has no nonzero entry, `left` is the
        last position and `right` is 0. An empty `arr` yields `(0, 0)`. In both
        cases `left >= right`, which is the condition `crop` uses to detect an
        image without text.
    """
    length = len(arr)
    if length == 0:
        # Without this guard the scans below would have nothing to fall back to.
        return 0, 0
    left = next((i for i in range(length) if arr[i] != 0), length - 1)
    right = next((i for i in reversed(range(length)) if arr[i] != 0), 0)
    return left, right


def crop(filename: Path, padding: int = 8) -> Optional[Image.Image]:
    """Crop an image to the bounding box of its content and pad it with white.

    Args:
        filename: Path of the image to process.
        padding: Width, in pixels, of the white border added on every side.

    Returns:
        The cropped and padded image in mode "L", or None when the image is
        judged to contain no text.
    """
    source = pil_loader(filename, mode="RGBA")

    # Flatten the transparency layer onto a white background
    canvas = Image.new("RGBA", source.size, "WHITE")
    canvas.paste(source, (0, 0), source)

    # Invert so the background becomes 0 and the text becomes nonzero
    inverted = 255 - np.array(canvas.convert("L"))

    # Rows/columns containing text have a nonzero sum
    y_start, y_end = first_and_last_nonzeros(np.sum(inverted, axis=1))
    x_start, x_end = first_and_last_nonzeros(np.sum(inverted, axis=0))

    # Some images have no text
    if not (y_start < y_end and x_start < x_end):
        print(f"{filename.name} is ignored because it does not contain any text")
        return None

    # Keep only the bounding box of the text
    text_region = inverted[y_start : y_end + 1, x_start : x_end + 1]
    H, W = text_region.shape

    # Surround the text with `padding` pixels of (inverted) background
    padded = np.zeros((H + padding * 2, W + padding * 2))
    padded[padding : H + padding, padding : W + padding] = text_region

    # Invert back to black text on a white background
    return Image.fromarray(255 - padded).convert("L")


## 把处理IM2LaTeX数据集的py文件搬运到了im2latex.py中，在这里设置了batch_size等超参数和训练时的dataloader

In [None]:
#im2latex.py

class Im2Latex(LightningDataModule):
    """Data processing for the Im2Latex-100K dataset.

    All files are read from fixed locations under `/kaggle/input/dataset1`:
    the vocabulary, formula list and split lists from `lst/`, and the images
    from `images_3/{train,val,test}`.

    Args:
        batch_size: The number of samples per batch.
        num_workers: The number of subprocesses to use for data loading.
        pin_memory: If True, the data loader will copy Tensors into CUDA pinned memory
            before returning them.
    """

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        pin_memory: bool = False,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # Dataset locations, defined once instead of being repeated per split.
        self.lst_dirname = Path("/kaggle/input/dataset1/lst")
        self.images_dirname = Path("/kaggle/input/dataset1/images_3")

        self.vocab_file = self.lst_dirname / "vocab.json"
        formula_file = self.lst_dirname / "im2latex_formulas.norm.new.lst"
        self.all_formulas = get_all_formulas(formula_file)
        self.transform = {
            # Random augmentation is applied to training images only.
            "train": A.Compose(
                [
                    A.Affine(scale=(0.6, 1.0), rotate=(-1, 1), cval=255, p=0.5),
                    A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
                    A.GaussianBlur(blur_limit=(1, 1), p=0.5),
                    ToTensorV2(),
                ]
            ),
            "val/test": ToTensorV2(),
        }

    def _build_dataset(self, split_filename: str, image_subdir: str, transform_key: str) -> BaseDataset:
        """Create the dataset of one split from its list file and image folder."""
        image_names, formulas = get_split(self.all_formulas, self.lst_dirname / split_filename)
        return BaseDataset(
            self.images_dirname / image_subdir,
            image_filenames=image_names,
            formulas=formulas,
            transform=self.transform[transform_key],
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Load images and formulas, and assign them to a `torch Dataset`.

        `self.train_dataset`, `self.val_dataset` and `self.test_dataset` will
        be assigned after this method is called.

        Args:
            stage: "fit" prepares the train and validation datasets, "test"
                prepares the test dataset, and None prepares all three.
        """
        self.tokenizer = Tokenizer.load(self.vocab_file)

        if stage in ("fit", None):
            self.train_dataset = self._build_dataset("im2latex_train_filter.lst", "train", "train")
            self.val_dataset = self._build_dataset("im2latex_validate_filter.lst", "val", "val/test")

        if stage in ("test", None):
            self.test_dataset = self._build_dataset("im2latex_test_filter.lst", "test", "val/test")

    def collate_fn(self, batch):
        """Combine `(image, formula)` samples into padded batch tensors.

        Every image is pasted at a random offset onto a zero-filled canvas
        sized to the tallest and widest image of the batch. Every formula is
        encoded with the tokenizer and right-padded with the tokenizer's
        padding index.

        Args:
            batch: Sequence of `(image, formula)` pairs, where `image` is
                indexed as (channels, height, width) and `formula` is a list
                of tokens.

        Returns:
            A tuple `(padded_images, batched_indices)` with shapes
            (B, 1, max_H, max_W) and (B, max_length + 2).
        """
        images, formulas = zip(*batch)
        B = len(images)
        max_H = max(image.shape[1] for image in images)
        max_W = max(image.shape[2] for image in images)
        max_length = max(len(formula) for formula in formulas)
        padded_images = torch.zeros((B, 1, max_H, max_W))
        # "+ 2" leaves room for the <SOS> and <EOS> indices added by `encode`.
        # Pad with the tokenizer's own padding index rather than a literal 0 so
        # the batch stays correct if the loaded vocabulary maps <PAD> elsewhere.
        batched_indices = torch.full(
            (B, max_length + 2),
            self.tokenizer.pad_index,
            dtype=torch.long,
        )
        for i in range(B):
            H, W = images[i].shape[1], images[i].shape[2]
            # NOTE(review): the random placement is also used for validation and
            # test batches, so evaluation inputs are not deterministic — confirm
            # this is intended.
            y, x = random.randint(0, max_H - H), random.randint(0, max_W - W)
            padded_images[i, :, y : y + H, x : x + W] = images[i]
            indices = self.tokenizer.encode(formulas[i])
            batched_indices[i, : len(indices)] = torch.tensor(indices, dtype=torch.long)
        return padded_images, batched_indices

    def _make_dataloader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        """Build a `DataLoader` with the settings shared by all splits."""
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled loader over the training dataset."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the unshuffled loader over the validation dataset."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return the unshuffled loader over the test dataset."""
        return self._make_dataloader(self.test_dataset, shuffle=False)

## 位置编码

In [None]:
# positional_encoding.py

class PositionalEncoding2D(nn.Module):
    """2-D positional encodings for the feature maps produced by the encoder.

    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.

    Reference:
    https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2021-labs/blob/main/lab9/text_recognizer/models/transformer_util.py

    Args:
        d_model: Number of channels of the feature map; must be even.
        max_h: Largest feature-map height that can be encoded.
        max_w: Largest feature-map width that can be encoded.
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        # The table is stored as a buffer of shape (d_model, max_h, max_w).
        self.register_buffer("pe", self.make_pe(d_model, max_h, max_w))

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> Tensor:
        """Build the (d_model, max_h, max_w) encoding table.

        The first half of the channels encodes the row position and the second
        half encodes the column position.
        """
        half_depth = d_model // 2

        rows = PositionalEncoding1D.make_pe(d_model=half_depth, max_len=max_h)  # (max_h, 1, d_model // 2)
        rows = rows.permute(2, 0, 1).expand(-1, -1, max_w)  # (d_model // 2, max_h, max_w)

        cols = PositionalEncoding1D.make_pe(d_model=half_depth, max_len=max_w)  # (max_w, 1, d_model // 2)
        cols = cols.permute(2, 1, 0).expand(-1, max_h, -1)  # (d_model // 2, max_h, max_w)

        return torch.cat([rows, cols], dim=0)  # (d_model, max_h, max_w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the positional encodings to a feature map.

        Args:
            x: (B, d_model, H, W)

        Returns:
            (B, d_model, H, W)
        """
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        height, width = x.size(2), x.size(3)
        # Use only the part of the table that covers this feature map.
        return x + self.pe[:, :height, :width]  # type: ignore


class PositionalEncoding1D(nn.Module):
    """Classic Attention-is-all-you-need positional encoding.

    Args:
        d_model: Size of the embedding dimension.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Largest sequence length that can be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = self.make_pe(d_model, max_len)  # (max_len, 1, d_model)
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> Tensor:
        """Compute the (max_len, 1, d_model) positional encoding table.

        Even channels hold sines and odd channels hold cosines. Both even and
        odd values of `d_model` are supported.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer cosine channel than sine
        # channels, so drop the surplus column; for an even d_model the slice
        # keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(1)
        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Add the positional encodings to a sequence and apply dropout.

        Args:
            x: (S, B, d_model)

        Returns:
            (S, B, d_model)
        """
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)

## ResNet&Transformer模型的定义

In [None]:
# resnet_transformer.py

class ResNetTransformer(nn.Module):
    """Image-to-sequence model: ResNet-50 encoder followed by a Transformer decoder.

    The encoder runs the ResNet-50 convolutional stages (through `layer4`),
    projects the 2048 output channels to `d_model` with a 1x1 convolution and
    adds 2-D positional encodings. The decoder embeds the target tokens, adds
    1-D positional encodings and attends over the flattened image features.

    Args:
        d_model: Feature size used by the decoder and the projected image features.
        dim_feedforward: Hidden size of the decoder's feed-forward layers.
        nhead: Number of attention heads.
        dropout: Dropout probability of the decoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        max_output_len: Maximum number of predicted tokens; two extra positions
            are reserved on top of it (presumably for the start and end tokens).
        sos_index: Index of the start-of-sequence token.
        eos_index: Index of the end-of-sequence token.
        pad_index: Index of the padding token.
        num_classes: Vocabulary size.
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_decoder_layers: int,
        max_output_len: int,
        sos_index: int,
        eos_index: int,
        pad_index: int,
        num_classes: int,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_output_len = max_output_len + 2
        self.sos_index = sos_index
        self.eos_index = eos_index
        self.pad_index = pad_index

        # Encoder
        # Called without arguments the network is randomly initialized; this
        # avoids the deprecated `pretrained=` keyword.
        resnet = torchvision.models.resnet50()
        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        )

        self.bottleneck = nn.Conv2d(2048, self.d_model, 1)
        self.image_positional_encoder = PositionalEncoding2D(self.d_model)

        # Decoder
        self.embedding = nn.Embedding(num_classes, self.d_model)
        self.y_mask = generate_square_subsequent_mask(self.max_output_len)
        self.word_positional_encoder = PositionalEncoding1D(self.d_model, max_len=self.max_output_len)
        transformer_decoder_layer = nn.TransformerDecoderLayer(self.d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(transformer_decoder_layer, num_decoder_layers)
        self.fc = nn.Linear(self.d_model, num_classes)

        # It is empirically important to initialize weights properly
        if self.training:
            self._init_weights()

    def _init_weights(self) -> None:
        """Initialize the embedding, output layer and bottleneck weights."""
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-init_range, init_range)

        nn.init.kaiming_normal_(
            self.bottleneck.weight.data,
            a=0,
            mode="fan_out",
            nonlinearity="relu",
        )
        if self.bottleneck.bias is not None:
            _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.bottleneck.weight.data)
            bound = 1 / math.sqrt(fan_out)
            # NOTE(review): `normal_` takes (mean, std), so this samples with
            # mean -bound and std bound; the (-bound, bound) pair suggests
            # `uniform_` may have been intended. Left unchanged so training
            # behavior stays the same — confirm.
            nn.init.normal_(self.bottleneck.bias, -bound, bound)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: (B, _E, _H, _W)
            y: (B, Sy) with elements in (0, num_classes - 1)

        Returns:
            (B, num_classes, Sy) logits
        """
        encoded_x = self.encode(x)  # (Sx, B, E)
        output = self.decode(y, encoded_x)  # (Sy, B, num_classes)
        output = output.permute(1, 2, 0)  # (B, num_classes, Sy)
        return output

    def encode(self, x: Tensor) -> Tensor:
        """Encode inputs.

        Args:
            x: (B, C, _H, _W)

        Returns:
            (Sx, B, E)
        """
        # Resnet expects 3 channels but training images are in gray scale
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x = self.backbone(x)  # (B, RESNET_DIM, H, W); H = _H // 32, W = _W // 32
        x = self.bottleneck(x)  # (B, E, H, W)
        x = self.image_positional_encoder(x)  # (B, E, H, W)
        x = x.flatten(start_dim=2)  # (B, E, H * W)
        x = x.permute(2, 0, 1)  # (Sx, B, E); Sx = H * W
        return x

    def decode(self, y: Tensor, encoded_x: Tensor) -> Tensor:
        """Decode encoded inputs with teacher-forcing.

        Args:
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in (0, num_classes - 1)

        Returns:
            (Sy, B, num_classes) logits
        """
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.d_model)  # (Sy, B, E)
        y = self.word_positional_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        # Causal mask: each position may only attend to itself and earlier ones.
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)  # (Sy, Sy)
        output = self.transformer_decoder(y, encoded_x, y_mask)  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, num_classes)
        return output

    def predict(self, x: Tensor) -> Tensor:
        """Make predictions at inference time with greedy decoding.

        Args:
            x: (B, C, H, W). Input images.

        Returns:
            (B, max_output_len) with elements in (0, num_classes - 1).
        """
        B = x.shape[0]
        S = self.max_output_len

        encoded_x = self.encode(x)  # (Sx, B, E)

        # Create the bookkeeping tensors directly on the input's device so the
        # decoding loop does not copy data between devices at every step.
        output_indices = torch.full((B, S), self.pad_index, dtype=torch.long, device=x.device)
        output_indices[:, 0] = self.sos_index
        has_ended = torch.zeros(B, dtype=torch.bool, device=x.device)

        for Sy in range(1, S):
            y = output_indices[:, :Sy]  # (B, Sy)
            logits = self.decode(y, encoded_x)  # (Sy, B, num_classes)
            # Select the token with the highest conditional probability
            output = torch.argmax(logits, dim=-1)  # (Sy, B)
            output_indices[:, Sy] = output[-1]  # Set the last output token

            # Early stopping of prediction loop to speed up prediction
            has_ended |= output_indices[:, Sy] == self.eos_index
            if torch.all(has_ended):
                break

        # Set all tokens after end token to be padding
        eos_positions = find_first(output_indices, self.eos_index)
        for i in range(B):
            j = int(eos_positions[i].item()) + 1
            output_indices[i, j:] = self.pad_index

        return output_indices


def generate_square_subsequent_mask(size: int) -> Tensor:
    """Generate a triangular (size, size) mask.

    Entries above the main diagonal are -inf; all other entries are 0.0.
    """
    # Ones strictly above the diagonal mark the positions to be blocked.
    blocked = torch.triu(torch.ones(size, size), diagonal=1)
    return blocked.masked_fill(blocked == 1, float("-inf"))


def find_first(x: Tensor, element: Union[int, float], dim: int = 1) -> Tensor:
    """Locate the first occurrence of `element` in `x` along dimension `dim`.

    Args:
        x: The input tensor to be searched.
        element: The number to look for.
        dim: The dimension to reduce.

    Returns:
        Indices of the first occurrence of the element in x. If not found, return the
        length of x along dim.

    Usage:
        >>> find_first(Tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1]]), 3)
        tensor([2, 1, 3])

    Reference:
        https://discuss.pytorch.org/t/first-nonzero-index/24769/9

        Compared with the referenced algorithm, a match at index 0 is reported
        as 0 instead of the length of x.
    """
    matches = x == element
    # True only at the first matching position along `dim`.
    is_first_match = matches & (matches.cumsum(dim) == 1)
    found, indices = is_first_match.max(dim)
    # An index of 0 is ambiguous: only overwrite it when nothing was found.
    not_found = (~found) & (indices == 0)
    indices[not_found] = x.shape[dim]
    return indices

## 计算字符错误率

In [None]:
# metrics.py

class CharacterErrorRate(Metric):
    """Character error rate metric, implemented as a torchmetrics `Metric`.

    For every sample, the edit distance between prediction and target
    (ignored indices removed) is divided by the length of the longer of the
    two; `compute` returns the mean over all samples seen.

    Args:
        ignore_indices: Token indices removed from predictions and targets
            before they are compared.
        *args: Forwarded to `Metric.__init__`.
    """
    def __init__(self, ignore_indices: Set[int], *args):
        super().__init__(*args)
        self.ignore_indices = ignore_indices
        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.error: Tensor
        self.total: Tensor

    def update(self, preds, targets):
        """Accumulate the normalized edit distance of every sample in the batch."""
        num_samples = preds.shape[0]
        for row in range(num_samples):
            pred_tokens = [tok for tok in preds[row].tolist() if tok not in self.ignore_indices]
            target_tokens = [tok for tok in targets[row].tolist() if tok not in self.ignore_indices]
            distance = editdistance.distance(pred_tokens, target_tokens)
            longest = max(len(pred_tokens), len(target_tokens))
            # Two empty sequences contribute no error (and avoid dividing by zero).
            if longest > 0:
                self.error += distance / longest
        self.total += num_samples

    def compute(self) -> Tensor:
        """Return the accumulated error divided by the number of samples."""
        return self.error / self.total

### 这是对模型、损失函数、优化器的整体封装。其中 training_step 和 validation_step 两个函数分别用来计算训练损失和验证损失，test_step 用于在测试集上进行预测并计算字符错误率。
### on_test_epoch_end 函数用于在测试结束后将预测结果和目标公式写入文件。

In [None]:
# lit_resnet_transformer.py

class LitResNetTransformer(LightningModule):
    """Lightning wrapper bundling the model, loss, metrics and optimizer.

    Args:
        d_model: Forwarded to `ResNetTransformer`.
        dim_feedforward: Forwarded to `ResNetTransformer`.
        nhead: Forwarded to `ResNetTransformer`.
        dropout: Forwarded to `ResNetTransformer`.
        num_decoder_layers: Forwarded to `ResNetTransformer`.
        max_output_len: Forwarded to `ResNetTransformer`.
        lr: Learning rate of the AdamW optimizer.
        weight_decay: Weight decay of the AdamW optimizer.
        milestones: Epochs at which `MultiStepLR` scales the learning rate.
        gamma: Factor applied to the learning rate at each milestone.
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_decoder_layers: int,
        max_output_len: int,
        lr: float = 0.001,
        weight_decay: float = 0.0001,
        milestones: List[int] = [5],
        gamma: float = 0.1,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        # Copy so the instance never shares (and cannot mutate) the default list.
        self.milestones = list(milestones)
        self.gamma = gamma

        vocab_file = "/kaggle/input/dataset1/lst/vocab.json"
        self.tokenizer = Tokenizer.load(vocab_file)
        self.model = ResNetTransformer(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            nhead=nhead,
            dropout=dropout,
            num_decoder_layers=num_decoder_layers,
            max_output_len=max_output_len,
            sos_index=self.tokenizer.sos_index,
            eos_index=self.tokenizer.eos_index,
            pad_index=self.tokenizer.pad_index,
            num_classes=len(self.tokenizer),
        )
        # Padding positions do not contribute to the loss.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_index)
        self.val_cer = CharacterErrorRate(self.tokenizer.ignore_indices)
        self.test_cer = CharacterErrorRate(self.tokenizer.ignore_indices)

        # Per-batch predictions and targets collected during a test run.
        self.test_step_outputs = []
        self.test_target = []

    def training_step(self, batch, batch_idx):
        """Compute and log the teacher-forced training loss of one batch."""
        imgs, targets = batch
        # Feed all tokens but the last; predict all tokens but the first.
        logits = self.model(imgs, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and the character error rate of one batch."""
        imgs, targets = batch
        logits = self.model(imgs, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)

        preds = self.model.predict(imgs)
        val_cer = self.val_cer(preds, targets)
        self.log("val/cer", val_cer)

    def test_step(self, batch, batch_idx):
        """Predict one test batch, log its character error rate and keep the results."""
        imgs, targets = batch
        preds = self.model.predict(imgs)
        test_cer = self.test_cer(preds, targets)
        self.log("test/cer", test_cer)
        # Keep CPU copies so the collected results do not hold accelerator memory.
        self.test_target.append(targets.detach().cpu())
        self.test_step_outputs.append(preds.detach().cpu())
        return preds

    def _write_decoded(self, filename: str, batches) -> None:
        """Decode every index sequence in `batches` and write one per line to `filename`."""
        with open(filename, "w", encoding="utf-8") as f:
            for batch in batches:
                for indices in batch:
                    decoded = self.tokenizer.decode(indices.tolist())
                    decoded.append("\n")
                    f.write(" ".join(decoded))

    def on_test_epoch_end(self):
        """Write the collected predictions and targets to text files.

        This hook takes the place of `test_epoch_end`, which was replaced in
        Lightning 2.0. The collected batches are cleared afterwards so that a
        later test run starts empty and the stored tensors can be released.
        """
        self._write_decoded("test_predictions.txt", self.test_step_outputs)
        self._write_decoded("test_targets.txt", self.test_target)
        self.test_step_outputs.clear()
        self.test_target.clear()

    def configure_optimizers(self):
        """Create the AdamW optimizer and its MultiStepLR scheduler."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.milestones, gamma=self.gamma)
        return [optimizer], [scheduler]

## 载入数据集完成训练

In [None]:
# Hydra keeps process-wide global state and refuses to be initialized twice,
# so clear any earlier initialization to keep this cell re-runnable.
from hydra.core.global_hydra import GlobalHydra

if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()

# Directory that contains the Hydra configuration files
hydra.initialize_config_dir(config_dir="/kaggle/input/config/")

# Load and merge the configuration files
cfg = hydra.compose(config_name="config")

datamodule = Im2Latex(**cfg.data)
datamodule.setup()
lit_model = LitResNetTransformer(**cfg.lit_model)

# Optional callbacks, enabled by the corresponding configuration sections
callbacks: List[Callback] = []
if cfg.callbacks.model_checkpoint:
    callbacks.append(ModelCheckpoint(**cfg.callbacks.model_checkpoint, dirpath=Path("/kaggle/working/")))
    print("cfg.callbacks.model_checkpoint")
if cfg.callbacks.early_stopping:
    callbacks.append(EarlyStopping(**cfg.callbacks.early_stopping))
    print("cfg.callbacks.early_stopping")

# Optional Weights & Biases logger
logger: Optional[WandbLogger] = None
if cfg.logger:
    logger = WandbLogger(**cfg.logger)
    print("cfg.logger")

trainer = Trainer(**cfg.trainer, callbacks=callbacks, logger=logger)

if trainer.logger:
    trainer.logger.log_hyperparams(Namespace(**cfg))

# Optional tuning step (disabled): `tuner.Tuner(trainer)` offers
# `lr_find(lit_model, datamodule=datamodule)` and
# `scale_batch_size(lit_model, datamodule=datamodule)`; their results can be
# assigned to `lit_model.lr` and `datamodule.batch_size` before fitting.

trainer.fit(lit_model, datamodule=datamodule)

## 生成测试数据

In [None]:
# Run the test loop on the datamodule's test split, passing an explicit checkpoint path.
# NOTE(review): this absolute ckpt_path looks like a placeholder — confirm it
# points to an existing checkpoint file before running this cell.
trainer.test(lit_model, datamodule=datamodule, ckpt_path="/los=0.04/cer=0.02/checkpoint.ckpt")