In [2]:
# Environment setup: uncomment and run once in a fresh environment.
# NOTE(review): these installs are unpinned. The captured log below shows the
# resolver picked transformers 4.46.2 (with huggingface-hub 0.26.2 and
# tokenizers 0.20.3); consider pinning, e.g. `%pip install transformers==4.46.2`.
# NOTE(review): the log is stale. Every install line here is commented out, so
# the output came from an earlier revision of this cell.
# NOTE(review): pandas is imported by the global-imports cell but is not listed here.
#%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
#%pip install torch
#%pip install torchinfo
#%pip install matplotlib
#%pip install numpy
#%pip install tqdm
#%pip install transformers

Collecting transformers
  Downloading transformers-4.46.2-py3-none-any.whl.metadata (44 kB)
Collecting huggingface-hub<1.0,>=0.23.2 (from transformers)
  Downloading huggingface_hub-0.26.2-py3-none-any.whl.metadata (13 kB)
Collecting pyyaml>=5.1 (from transformers)
  Downloading PyYAML-6.0.2-cp312-cp312-win_amd64.whl.metadata (2.1 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2024.11.6-cp312-cp312-win_amd64.whl.metadata (41 kB)
Collecting safetensors>=0.4.1 (from transformers)
  Downloading safetensors-0.4.5-cp312-none-win_amd64.whl.metadata (3.9 kB)
Collecting tokenizers<0.21,>=0.20 (from transformers)
  Downloading tokenizers-0.20.3-cp312-none-win_amd64.whl.metadata (6.9 kB)
Downloading transformers-4.46.2-py3-none-any.whl (10.0 MB)
   ---------------------------------------- 0.0/10.0 MB ? eta -:--:--
   ----------- ---------------------------- 2.9/10.0 MB 16.8 MB/s eta 0:00:01
   ---------------------- ----------------- 5.8/10.0 MB 14.7 MB/s eta 0:00:01
  

# Global Imports

In [1]:
import datetime
import os
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn, Tensor, tensor
from torch import save
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
from transformers import TrOCRProcessor, VisionEncoderDecoderModel


# Local Imports

In [3]:
from models.trocr_apl import TrocrApl
from training.train import EpochLogs, LogPoint, train, grid_search
from dataset.dataset import HandwrittenLineOfCodeDataset
from training.train import ThresholdData


In [4]:

# Runtime configuration shared by the rest of the notebook.

# Absolute path of the kernel's working directory (normally the notebook's
# folder); project paths are resolved relative to it.
__filedir__: str = os.path.abspath(".")

# Prefer the GPU but fall back to CPU when no CUDA device is visible, so the
# notebook still runs on machines without a GPU.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
# Very large value, so effectively an upper bound on training length.
# NOTE(review): confirm the training loop has an early-stopping criterion.
EPOCHS: int = 100000
BATCH_SIZE: int = 50
# Image geometry: single channel, IMAGE_WIDTH x IMAGE_HEIGHT pixels.
# Presumably the size dataset images are rendered/resized to; TODO confirm.
IMAGE_CHANNELS: int = 1
IMAGE_WIDTH: int = 64
IMAGE_HEIGHT: int = 64
# (min, max) rotation bounds, presumably in degrees, for augmentation.
# TODO confirm against the dataset code.
ROTATE_RANGE: tuple[int, int] = (-90, 90)

In [5]:
# Project directory layout, derived from the notebook's working directory.
# Paths are plain strings joined with os.path.join and are not normalized,
# so the parent-directory component is kept verbatim.

# Project root: one level above the notebook's directory.
root_dirpath: str = os.path.join(__filedir__, os.pardir)

# <root>/dataset
data_root_dirpath: str = os.path.join(root_dirpath, "dataset")

# <root>/dataset/flattened_dataset
flattened_dataset_dirpath: str = os.path.join(data_root_dirpath, "flattened_dataset")

# <root>/dataset/flattened_dataset/dataset_info.csv
dataset_info_csv_path: str = os.path.join(flattened_dataset_dirpath, "dataset_info.csv")


In [6]:

# Output folders for training artifacts, placed directly under the project root.
log_dirpath: str = os.path.join(root_dirpath, "logs")
checkpoint_dirpath: str = os.path.join(root_dirpath, "checkpoints")

# Create both up front. exist_ok=True keeps re-running this cell harmless.
os.makedirs(log_dirpath, exist_ok=True)
os.makedirs(checkpoint_dirpath, exist_ok=True)

In [None]:
# Fine-tune the pretrained TrOCR handwriting checkpoint on handwritten lines of
# code. This cell is a small smoke test: its hyperparameters (batch size 4,
# 3 epochs, lr 5e-5) are local and independent of the global BATCH_SIZE / EPOCHS.

# Load the model and processor
model_name = "microsoft/trocr-base-handwritten"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = TrOCRProcessor.from_pretrained(model_name)

# When `labels` are passed, VisionEncoderDecoderModel builds the decoder
# inputs by shifting the labels right. That requires `decoder_start_token_id`
# and `pad_token_id` on the top-level config; transformers raises a ValueError
# if they are unset. Fill them from the tokenizer only when the checkpoint's
# config leaves them unset, so an existing value is never overridden.
if model.config.decoder_start_token_id is None:
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
if model.config.pad_token_id is None:
    model.config.pad_token_id = processor.tokenizer.pad_token_id

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# TODO: placeholder inputs. The character map only covers "p" and "r", and the
# image paths are illustrative, so most characters in the sample lines have no
# image. Replace with the real data (presumably described by
# dataset_info_csv_path) before training for real.
dataset = HandwrittenLineOfCodeDataset(
    dataset_line_text=["print('Hello, World!')", "x = y + 2"],  # Replace with your lines
    unicode_character_filepath_map={
        # Example: Map of unicode characters to image paths
        "p": ["path_to_image_p.png"], "r": ["path_to_image_r.png"],  # Fill this out properly
        # Add mappings for all characters...
    },
    eol_char="<EOL>"
)

# Wrap dataset in a DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# The per-epoch average below divides by the batch count. Fail fast with a
# clear message instead of a ZeroDivisionError after a no-op epoch.
if len(dataloader) == 0:
    raise ValueError("DataLoader yields no batches; the dataset is empty.")

# Define optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
epochs = 3
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        # Each batch is an (images, labels) pair, moved to the training device.
        text_image_tensors, text_label_tensors = batch
        text_image_tensors = text_image_tensors.to(device)
        text_label_tensors = text_label_tensors.to(device)

        # NOTE(review): images are fed to the encoder as-is, bypassing the
        # processor's resizing/normalization. The global config uses
        # IMAGE_CHANNELS = 1 at 64x64, while the TrOCR encoder presumably
        # expects processor-sized 3-channel input. Confirm the dataset output
        # matches, or run the images through
        # `processor(images=..., return_tensors="pt")`.
        pixel_values = text_image_tensors
        # Collapse one-hot labels (last dim) into integer indices.
        # NOTE(review): these are indices into the dataset's own label space,
        # not TrOCR tokenizer ids. Unless the two vocabularies are aligned, the
        # loss targets the wrong tokens. Consider tokenizing the raw line text
        # with processor.tokenizer instead. TODO confirm what the dataset encodes.
        labels = text_label_tensors.argmax(dim=-1)

        # Forward pass (the model computes the cross-entropy loss from `labels`)
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss

        # Backward pass and parameter update, then clear grads for the next batch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Track loss
        epoch_loss += loss.item()

    print(f"Epoch {epoch + 1} completed. Loss: {epoch_loss / len(dataloader):.4f}")
