# Make submission predictions from final model

In [None]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import json
from tqdm import tqdm
from PIL import Image
import os

# === Character set used by the recognizer ===
# Digits followed by uppercase letters. Class index 0 is reserved for the CTC
# blank symbol, so real characters are numbered starting at 1.
charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
idx2char = dict(enumerate(charset, start=1))
num_classes = len(charset) + 1  # +1 for the blank class at index 0

# === CRNN model definition ===
class CRNN(nn.Module):
    """Convolutional-recurrent network that emits per-timestep class log-probabilities.

    A three-layer CNN (with two 2x2 max-pools) extracts features from a
    single-channel image, the height axis is averaged away, and the remaining
    width axis is treated as the time axis of a 2-layer bidirectional LSTM.
    A linear layer followed by log-softmax produces the class scores.

    Args:
        num_classes: Size of the output distribution at each timestep.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Layer order (and therefore the Sequential indices 0..6) must stay
        # fixed so that saved state_dict keys such as "cnn.5.weight" still match.
        conv_stack = [
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*conv_stack)
        # Collapse height to 1 while leaving the width untouched.
        self.proj_h = nn.AdaptiveAvgPool2d((1, None))
        self.rnn = nn.LSTM(
            input_size=256,
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
        )
        # 512 = 2 directions x 256 hidden units.
        self.fc = nn.Linear(512, num_classes)
        self.log_sm = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Map a (batch, 1, H, W) image batch to (W // 4, batch, num_classes) log-probs."""
        feature_map = self.proj_h(self.cnn(x))  # (batch, 256, 1, W')
        # Drop the unit height axis and move width first: (W', batch, 256),
        # the (seq, batch, feature) layout the LSTM expects by default.
        sequence = feature_map.squeeze(2).permute(2, 0, 1)
        recurrent_out, _ = self.rnn(sequence)
        return self.log_sm(self.fc(recurrent_out))

# === Preprocessing and decoding ===
def preprocess_image(path):
    """Load an image file as a normalized grayscale tensor ready for the model.

    The image is read as grayscale, resized to 128x32 (width x height),
    scaled to the range [0, 1] and given leading batch and channel axes.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        A float32 tensor of shape (1, 1, 32, 128).

    Raises:
        OSError: If OpenCV cannot read or decode the file at ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error surfaces later as an opaque assertion
        # inside cv2.resize that does not mention the offending file.
        raise OSError(f"Could not read image: {path}")
    # dsize is (width, height); presumably the size used during training.
    img = cv2.resize(img, (128, 32))
    img = img.astype(np.float32) / 255.0
    # Add batch and channel axes: (32, 128) -> (1, 1, 32, 128).
    return torch.from_numpy(img).unsqueeze(0).unsqueeze(0)

def decode_greedy(log_probs, idx2char, blank=0):
    """Greedy CTC decoding: best class per step, merge repeats, drop blanks.

    Args:
        log_probs: Tensor indexed as (time, batch, class).
        idx2char: Mapping from class index to character; indices missing
            from the mapping contribute an empty string.
        blank: Class index of the CTC blank symbol.

    Returns:
        A list with one decoded string per batch element.
    """
    # Best class at every step, rearranged so each row is one sample's path.
    best_paths = log_probs.argmax(2).permute(1, 0).tolist()
    decoded = []
    for path in best_paths:
        # Pair each step with the one before it; the first step is compared
        # against the blank, matching a "previous = blank" initial state.
        previous_steps = [blank] + path[:-1]
        chars = [
            idx2char.get(current, "")
            for current, before in zip(path, previous_steps)
            if current != blank and current != before
        ]
        decoded.append("".join(chars))
    return decoded

def predict_string(model, img_path, device, idx2char, blank=0):
    """Run the model on a single image file and return the decoded string.

    Args:
        model: Network whose output is accepted by ``decode_greedy``.
        img_path: Path of the image to recognize.
        device: Device the preprocessed image is moved to before inference.
        idx2char: Mapping from class index to character.
        blank: Class index of the CTC blank symbol.

    Returns:
        The decoded text for the image.
    """
    model.eval()  # inference mode for layers such as dropout / batch norm
    batch = preprocess_image(img_path).to(device)
    with torch.no_grad():
        decoded = decode_greedy(model(batch), idx2char, blank)
    # The batch holds exactly one image, so only the first result is relevant.
    return decoded[0]

def make_predictions_json(model, test_root, out_json_path, device, idx2char, blank=0):
    """Predict a string for every PNG under ``<test_root>/images`` and save as JSON.

    Files are processed in sorted name order. Names starting with "._" are
    skipped. Each output record contains the image's height and width, its
    file name without extension as ``image_id``, the predicted
    ``captcha_string`` and an empty ``annotations`` list.

    Args:
        model: Network passed through to ``predict_string``.
        test_root: Directory that contains an ``images`` sub-directory.
        out_json_path: Destination file; missing parent directories are created.
        device: Device used for inference.
        idx2char: Mapping from class index to character.
        blank: Class index of the CTC blank symbol.
    """
    image_dir = os.path.join(test_root, "images")
    files = sorted(
        f for f in os.listdir(image_dir)
        if f.endswith(".png") and not f.startswith("._")
    )
    results = []
    for fname in tqdm(files, desc="🔍 Predicting"):
        path = os.path.join(image_dir, fname)
        image_id = os.path.splitext(fname)[0]
        # Record the original dimensions; the model itself sees a resized copy.
        with Image.open(path) as im:
            w, h = im.size
        pred_str = predict_string(model, path, device, idx2char, blank)
        results.append({
            "height": h,
            "width": w,
            "image_id": image_id,
            "captcha_string": pred_str,
            "annotations": []
        })
    out_dir = os.path.dirname(out_json_path)
    if out_dir:
        # dirname is "" for a bare filename, and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when one is named.
        os.makedirs(out_dir, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

In [None]:
# --- Run configuration ---
blank_index = 0
model_path = "submissions/final_model.pth"
# NOTE(review): make_predictions_json appends "images" to test_root, so this
# value resolves to data/part2/test/images/images. Confirm that this matches
# the on-disk layout; otherwise test_root should probably stop at ".../test".
test_root = "data/part2/test/images"  # change to part3 and 4 later
output_path = "submissions/prediction_files/part2_test_predictions.json"  # change to part3 and 4 later

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the trained weights; the checkpoint keeps them under "model_state".
checkpoint = torch.load(model_path, map_location=device)
model = CRNN(num_classes=num_classes).to(device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

# Predict every test image and write the submission file.
make_predictions_json(model, test_root, output_path, device, idx2char, blank_index)

print(f"\nPrediction complete. Saved to: {output_path}")
