In [23]:
# ================================
# 1. CHECK GPU
# ================================
import torch
print("GPU Available:", torch.cuda.is_available())
print("GPU Type:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")

device = "cuda" if torch.cuda.is_available() else "cpu"


GPU Available: True
GPU Type: Tesla T4


In [24]:
import zipfile
import os

zip_path = "reduced_dataset.zip"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(".")

print("Dataset extracted!")
print(os.listdir("reduced_dataset"))


Dataset extracted!
['labels', 'images']


In [25]:
import os

print("Images:", len(os.listdir("reduced_dataset/images")))
print("Labels:", len(os.listdir("reduced_dataset/labels")))


Images: 500
Labels: 500


In [26]:
# --- 1. INSTALL AND IMPORT DEPENDENCIES ---
!pip install transformers torch pillow torchvision

import torch
import torch.nn as nn
from transformers import ViTModel
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torch.nn import CTCLoss
import os
from PIL import Image
import random
import string
import numpy as np
import sys



In [35]:
# --- 2. DATASET PATHS ---
# !!! IMPORTANT !!!
# Please ensure you have uploaded your dataset to these directories
# (e.g., recognition_dataset/images/001.png and recognition_dataset/labels/001.txt)

IMAGE_DIR = "reduced_dataset/images"
LABEL_DIR = "reduced_dataset/labels"
NUM_EPOCHS = 2 # Reduced number of epochs for quick demonstration
BATCH_SIZE = 4

# Check if data directories exist
if not os.path.exists(IMAGE_DIR) or not os.path.exists(LABEL_DIR):
    print("--- ⚠️ DATA WARNING ⚠️ ---")
    print(f"Image directory not found: {IMAGE_DIR}")
    print("Please upload your 'recognition_dataset' folder containing 'images' and 'labels' to the Colab environment.")
    sys.exit(1)

if not os.listdir(IMAGE_DIR):
    print("--- ⚠️ DATA WARNING ⚠️ ---")
    print(f"Image directory {IMAGE_DIR} is empty.")
    print("Please ensure your dataset files (.png and .txt) are uploaded.")
    sys.exit(1)
print(f"Found {len(os.listdir(IMAGE_DIR))} image files. Proceeding with training...")

Found 500 image files. Proceeding with training...


In [28]:
# --- 3. MODEL DEFINITION ---

class ViT_CTC_Model(nn.Module):
    """
    ViT-based model for sequence recognition using CTC loss.
    """
    def __init__(self, num_classes):
        super().__init__()
        # Load pre-trained Vision Transformer
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")

        # Linear layer for CTC output. ViT-base output hidden size is 768.
        self.fc = nn.Linear(768, num_classes)

    def forward(self, x):
        outputs = self.vit(pixel_values=x)
        # last_hidden_state: (B, sequence_length=197, hidden_size=768)
        x = outputs.last_hidden_state
        x = self.fc(x) # Output shape: (B, seq, num_classes)
        return x

In [29]:
# --- 4. DATASET DEFINITION ---

class RecognitionDataset(Dataset):
    """Custom Dataset for loading images and text labels."""
    def __init__(self, image_dir, label_dir, transforms=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.image_files = sorted(os.listdir(image_dir))
        # Default transforms for ViT input
        self.transforms = transforms if transforms else T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # Determine label file name
        txt_name = img_name.split('.')[0] + ".txt"
        label_path = os.path.join(self.label_dir, txt_name)

        # load image
        img = Image.open(img_path).convert("RGB")
        img = self.transforms(img)

        # load label (text)
        with open(label_path, "r", encoding="utf-8") as f:
            text = f.read().strip()

        return img, text

In [30]:
# --- 5. TRAINING SCRIPT ---

# Character set definition (must include all possible characters in your labels)
charset = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
# Mapping characters to indices (1-based, 0 is reserved for the CTC blank token)
char2idx = {c:i+1 for i,c in enumerate(charset)}
char2idx["<blank>"] = 0
num_classes = len(char2idx) # Final output dimension

# Function to convert a text string into a tensor of indices
def encode_text(text):
    return torch.tensor([char2idx[c] for c in text if c in char2idx], dtype=torch.long)

In [31]:
# Create dataset and dataloader
train_dataset = RecognitionDataset(IMAGE_DIR, LABEL_DIR)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [32]:
# Initialize Model, Optimizer, and Loss
model = ViT_CTC_Model(num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# blank=0 is crucial
criterion = CTCLoss(blank=0, reduction='mean')

# Setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViT_CTC_Model(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_feat

In [36]:
print(f"Training on device: {device}")
print("Training started...")

for epoch in range(NUM_EPOCHS):
    total_loss = 0
    model.train()
    for imgs, texts in train_loader:
        imgs = imgs.to(device)
        labels_list = [encode_text(t) for t in texts]

        # 1. Target labels (concatenated)
        labels = torch.cat(labels_list).to(device)
        # 2. Lengths of target labels before concatenation
        label_lens = torch.tensor([len(l) for l in labels_list], dtype=torch.long)

        # Forward pass
        outputs = model(imgs)  # (B, seq, num_classes)

        # 3. Permute output: CTC requires (sequence_length, Batch_size, num_classes)
        outputs = outputs.permute(1, 0, 2)

        # 4. Lengths of input sequence: full length of the ViT's patch sequence (197)
        input_lens = torch.full(
            size=(outputs.size(1),),
            fill_value=outputs.size(0),
            dtype=torch.long
        )

        # Calculate CTC Loss
        loss = criterion(outputs, labels, input_lens, label_lens)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: Loss = {total_loss / len(train_loader):.4f}")

Training on device: cuda
Training started...
Epoch 1/2: Loss = 2.8369
Epoch 2/2: Loss = 2.9975


In [34]:
# Save the trained model
MODEL_PATH = "vit_ctc.pth"
torch.save(model.state_dict(), MODEL_PATH)
print(f"Training complete! Model saved as {MODEL_PATH}")

Training complete! Model saved as vit_ctc.pth
