In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import os
import random
import csv

class dataset(Dataset):
    """Text-crop images whose label is their own file name (minus ``.png``).

    At construction time the directory is scanned (non-recursively) for
    ``*.png`` files. Each image's top-left pixel is read once. The image is
    flagged "red" when that pixel's R channel is above 128 and both its G and
    B channels are below 128. Labels of red images are reversed before
    encoding.

    Characters are encoded 1..52 (``a``-``z`` then ``A``-``Z``). Index 0 is
    reserved and appears as ``'<blank>'`` in ``index_to_char``.

    Args:
        data_dir: directory containing the ``.png`` images.
        transform: callable applied to each image after conversion to Pillow
            mode ``'L'``.

    ``__getitem__`` returns ``(image, label_tensor, is_red_background)``. It
    raises ``KeyError`` if a file name contains a character outside a-z/A-Z.
    """

    # Output vocabulary; the character at position i gets class index i + 1.
    _ALPHABET = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def __init__(self, data_dir, transform):
        self.data_dir = data_dir
        self.image_files = []
        # os.listdir order is arbitrary; sorting first makes a seeded shuffle
        # reproducible regardless of the filesystem.
        for file in sorted(os.listdir(self.data_dir)):
            if not file.endswith('.png'):
                continue
            # The context manager closes the file handle deterministically.
            with Image.open(os.path.join(self.data_dir, file)) as img:
                r, g, b = img.convert('RGB').getpixel((0, 0))
            is_red = r > 128 and g < 128 and b < 128
            self.image_files.append((file, is_red))
        random.shuffle(self.image_files)

        self.transform = transform
        self.char_to_index = {char: i for i, char in enumerate(self._ALPHABET, start=1)}
        self.index_to_char = {i: char for char, i in self.char_to_index.items()}
        self.index_to_char[0] = '<blank>'

    def __len__(self):
        return len(self.image_files)

    def _encode_label(self, img_name, is_red_background):
        """Return the class indices for ``img_name``, reversed when the background is red."""
        label = img_name[:-4]  # strip the '.png' suffix guaranteed by __init__'s filter
        if is_red_background:
            label = label[::-1]
        return [self.char_to_index[ch] for ch in label]

    def __getitem__(self, idx):
        img_name, is_red_background = self.image_files[idx]
        image_path = os.path.join(self.data_dir, img_name)
        with Image.open(image_path) as raw:
            image = raw.convert('L')
        image = self.transform(image)
        label_indices = self._encode_label(img_name, is_red_background)
        return image, torch.tensor(label_indices), is_red_background

# Per-image preprocessing, applied in order:
#   1. resize to 32 rows x 128 columns (aspect ratio is not preserved);
#   2. convert the PIL image to a float tensor (ToTensor scales 8-bit data to [0, 1]);
#   3. normalize the single channel with mean 0.5 / std 0.5, i.e. map [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

class CRNN(nn.Module):
    """CNN + bidirectional-LSTM text recogniser producing per-column class scores.

    The CNN halves the image height four times (H -> H//16) and the width twice
    (W -> W//4). Each remaining feature-map column becomes one time step. Its
    ``512 * (H//16)`` features pass through two bidirectional LSTMs, and the
    final linear layer emits ``num_chars + 1`` raw (un-normalised) scores per
    step.

    Args:
        num_chars: the output head has ``num_chars + 1`` classes.
        img_height: height in pixels of the input images. The default of 32
            reproduces the original fixed layout. Must be at least 16 so that
            a feature row survives the four height poolings.

    Raises:
        ValueError: if ``img_height`` is smaller than 16.
    """

    def __init__(self, num_chars, img_height=32):
        super().__init__()
        feature_height = img_height // 16
        if feature_height < 1:
            raise ValueError(f'img_height must be at least 16, got {img_height}')

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # H -> H/2, W -> W/2   (32x128 -> 16x64)

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # H/2 -> H/4, W/2 -> W/4   (16x64 -> 8x32)
            nn.Dropout(0.1),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # height only: H/4 -> H/8   (8 -> 4)
            nn.Dropout(0.1),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # height only: H/8 -> H/16   (4 -> 2)
            nn.Dropout(0.1),
        )

        # Each time step sees all 512 channels of every remaining feature row.
        self.lstm1 = nn.LSTM(512 * feature_height, 256, bidirectional=True, batch_first=True)
        self.linear1 = nn.Linear(512, 256)  # 2 directions x 256 hidden units -> 256
        self.lstm2 = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        self.linear2 = nn.Linear(512, num_chars + 1)

    def forward(self, x):
        """Map images ``[batch, 1, H, W]`` to scores ``[batch, W//4, num_chars + 1]``."""
        conv = self.cnn(x)  # [batch, 512, H//16, W//4], e.g. [batch, 512, 2, 32]
        batch, channel, height, width = conv.size()
        # Use width as the sequence axis: [batch, width, channel * height].
        conv = conv.permute(0, 3, 1, 2).reshape(batch, width, channel * height)

        lstm_out, _ = self.lstm1(conv)  # [batch, width, 512]
        lstm_out = self.linear1(lstm_out)  # [batch, width, 256]
        lstm_out, _ = self.lstm2(lstm_out)  # [batch, width, 512]
        return self.linear2(lstm_out)  # [batch, width, num_chars + 1]

def decode_predictions(output, index_to_char):
    """Greedy (best-path) CTC decoding of network scores into text.

    Picks the arg-max class at every time step, merges runs of the same
    class, then drops index 0 and any index that ``index_to_char`` lacks.

    Args:
        output: tensor with class scores on the last axis. A 3-D tensor is
            iterated as ``[batch, time, classes]``; a 2-D ``[time, classes]``
            tensor is treated as a batch of one.
        index_to_char: mapping from class index to character.

    Returns:
        A single string when the batch holds exactly one sequence, otherwise a
        list with one string per sequence.
    """
    if output.dim() == 2:
        output = output.unsqueeze(0)

    # log_softmax is monotonic, so it does not change which class wins.
    best_path = nn.functional.log_softmax(output, dim=-1).argmax(dim=-1)

    decoded = []
    for sequence in best_path:
        # Collapse consecutive repeats first, then discard blanks and unknown indices.
        merged = torch.unique_consecutive(sequence).tolist()
        decoded.append(''.join(index_to_char[i] for i in merged if i != 0 and i in index_to_char))

    if len(decoded) == 1:
        return decoded[0]
    return decoded



def _train_one_epoch(model, train_data, optimizer, loss_func, device):
    """Run one CTC training pass over ``train_data``; return the mean loss per batch."""
    model.train()
    total_loss = 0.0
    for images, labels, _ in train_data:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)  # [batch, width, num_classes]
        batch_size, width = outputs.size(0), outputs.size(1)
        # Every image yields the same number of time steps (the CNN's output width).
        input_lengths = torch.full(size=(batch_size,), fill_value=width, dtype=torch.long)
        # collate_fn right-pads labels with 0 and real characters are indexed
        # from 1, so the non-zero entries give each target's length.
        target_lengths = torch.sum(labels != 0, dim=1)

        # CTCLoss expects log-probabilities shaped [time, batch, classes].
        log_probs = outputs.log_softmax(2).permute(1, 0, 2)

        optimizer.zero_grad()
        loss = loss_func(log_probs, labels, input_lengths, target_lengths)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        optimizer.step()

        total_loss += loss.item()
    return total_loss / max(len(train_data), 1)


def _labels_to_text(label_row, index_to_char):
    """Turn one row of padded label indices back into text, skipping 0 padding."""
    return ''.join(index_to_char[idx] for idx in label_row.tolist() if idx != 0)


def _evaluate(model, test_data, device, index_to_char):
    """Greedy-decode the whole test set.

    Works for any test batch size.

    Returns:
        ``(correct, results, red_flags)``:
        - ``correct`` is the number of exact matches.
        - ``results`` has one ``{'original', 'predicted', 'correct'}`` dict per
          sample. Red-background samples are flipped back to file-name
          orientation.
        - ``red_flags`` is a parallel list of per-sample red-background flags.
    """
    model.eval()
    correct = 0
    results = []
    red_flags = []
    with torch.no_grad():
        for test_img, test_label, is_red_background in test_data:
            decoded = decode_predictions(model(test_img.to(device)), index_to_char)
            # decode_predictions returns a bare string for a batch of one.
            if isinstance(decoded, str):
                decoded = [decoded]
            # is_red_background is a tuple with one flag per sample (from
            # collate_fn). Test each flag; a non-empty tuple is always truthy.
            for predicted_text, label_row, is_red in zip(decoded, test_label, is_red_background):
                true_text = _labels_to_text(label_row, index_to_char)
                if is_red:
                    # Red-background labels are stored reversed; flip both back.
                    predicted_text = predicted_text[::-1]
                    true_text = true_text[::-1]
                is_correct = predicted_text == true_text
                correct += int(is_correct)
                results.append({
                    'original': true_text,
                    'predicted': predicted_text,
                    'correct': is_correct,
                })
                red_flags.append(bool(is_red))
    return correct, results, red_flags


def train_model(model, train_data, test_data, optimizer, loss_func, device, index_to_char, num_epochs=100):
    """Train ``model`` with a CTC loss, evaluating on ``test_data`` after each epoch.

    Each epoch prints one randomly chosen test prediction, shown in file-name
    orientation, followed by the mean training loss and the exact-match test
    accuracy. Whenever accuracy improves on the best so far, two things are
    written:
    - ``best_model.pth``: the epoch, model/optimizer state and accuracy.
    - ``bonus4488.csv``: the per-sample test results.

    Args:
        model: network returning ``[batch, width, classes]`` scores.
        train_data: loader yielding ``(images, padded_labels, red_flags)``.
        test_data: loader with the same item layout; any batch size.
        optimizer: optimizer over ``model``'s parameters.
        loss_func: a CTC loss taking ``(log_probs, targets, input_lengths, target_lengths)``.
        device: device the model lives on.
        index_to_char: class index -> character mapping used for decoding.
        num_epochs: number of training epochs.

    The model is left in training mode on return.
    """
    best_accuracy = 0
    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, train_data, optimizer, loss_func, device)
        correct, all_results, red_flags = _evaluate(model, test_data, device, index_to_char)

        total_samples = len(test_data.dataset)
        accuracy = correct / total_samples if total_samples else 0.0

        # Pick the random sample from results already computed, rather than
        # re-iterating the whole test loader.
        if all_results:
            pick = random.randrange(len(all_results))
            sample = all_results[pick]
            print('Random Sample:')
            print('Red' if red_flags[pick] else 'Green')
            print(f"Original text: {sample['original']}")
            print(f"Predicted: {sample['predicted']}\n")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, 'best_model.pth')

            with open('bonus4488.csv', 'w', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=['original', 'predicted', 'correct'])
                writer.writeheader()
                writer.writerows(all_results)

            print(f'New best model saved with accuracy: {accuracy:.4f}')

        model.train()
        print(f'Epoch {epoch+1} complete, Average Loss: {avg_loss:.4f}')
        print(f'Accuracy: {accuracy:.4f}')


def collate_fn(batch):
    """Merge ``(image, label, is_red_background)`` samples into one padded batch.

    The incoming ``batch`` list is sorted in place by label length, longest
    first. The function then returns three items, all in that sorted order:
    - the images, stacked along a new leading dimension;
    - the labels, right-padded with 0 into an int64 ``[batch, max_label_len]``
      tensor;
    - the red-background flags, as a tuple.
    """
    batch.sort(key=lambda sample: len(sample[1]), reverse=True)
    images, labels, is_red_background = zip(*batch)

    stacked_images = torch.stack(images, 0)
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        list(labels), batch_first=True, padding_value=0
    ).long()

    return stacked_images, padded_labels, is_red_background

data = dataset('/kaggle/input/bonus-5000/bonus_5000', transform)
# CRNN's head emits num_chars + 1 classes, and index 0 is the CTC blank
# (blank=0 below). So count only the real characters. index_to_char also holds
# the '<blank>' entry and would add an output class no target ever uses.
num_chars = len(data.char_to_index)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CRNN(num_chars).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 80 / 20 train/test split.
train_size = int(0.8 * len(data))
test_size = len(data) - train_size
train_dataset, test_dataset = random_split(data, [train_size, test_size])
train_data = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
# Evaluation order does not affect accuracy, so the test set is not shuffled.
test_data = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
loss_func = nn.CTCLoss(blank=0, zero_infinity=True, reduction='mean')

train_model(model, train_data, test_data, optimizer, loss_func, device, data.index_to_char, num_epochs=100)

Random Sample:
Green
Original text: CWBkYG
Predicted: 

Epoch 1 complete, Average Loss: 4.4889
Accuracy: 0.0000
Random Sample:
Green
Original text: CVMg
Predicted: 

Epoch 2 complete, Average Loss: 4.2463
Accuracy: 0.0000
Random Sample:
Red
Original text: qtmDOEo
Predicted: 

Epoch 3 complete, Average Loss: 4.1947
Accuracy: 0.0000
Random Sample:
Red
Original text: fDVtk
Predicted: L

Epoch 4 complete, Average Loss: 4.1228
Accuracy: 0.0000
Random Sample:
Red
Original text: ITQYy
Predicted: BBBf

Epoch 5 complete, Average Loss: 3.8182
Accuracy: 0.0000
Random Sample:
Green
Original text: uueh
Predicted: dKeh

Epoch 6 complete, Average Loss: 3.5316
Accuracy: 0.0000
Random Sample:
Green
Original text: DcKA
Predicted: DcKh

Epoch 7 complete, Average Loss: 3.0053
Accuracy: 0.0000
Random Sample:
Green
Original text: KlJOJz
Predicted: EIJOIf

New best model saved with accuracy: 0.0020
Epoch 8 complete, Average Loss: 2.3866
Accuracy: 0.0020
Random Sample:
Red
Original text: ckjM
Predicted: cMi

