In [None]:
from src.q_learning_training import *
from src.q_learning_assessment import *

In [None]:
# Sticks available at the start of a game: the integers 1 through 12
initial_sticks = [*range(1, 13)]

In [None]:
# Q-learning hyperparameters (passed positionally to train_agent in the next cell)
num_episodes = 5_000_000  # how many training episodes to run
alpha = 0.1  # learning rate
gamma = 0.9  # discount factor
epsilon = 1.0  # exploration rate at the start of training
min_epsilon = 0.01  # floor for the exploration rate
# Nominal decay factor for the exploration rate. The original note says it is
# "adjusted dynamically" -- presumably inside train_agent; confirm there.
decay_rate = 0.995

In [None]:
# Train the agent.
# All seven arguments are passed positionally, so their order must match
# train_agent's signature (presumably provided by the star-imports at the top of
# the notebook -- it is not visible here; verify before reordering).
# The call's result is unpacked into two names: the Q-table and `rewards`.
# NOTE(review): with num_episodes in the millions this cell is likely slow, and its
# result is not cached -- consider persisting q_table after training.
q_table, rewards = train_agent(
    initial_sticks, num_episodes, alpha, gamma, epsilon, min_epsilon, decay_rate
)

In [None]:
# Evaluate the trained Q-table over 100,000 simulated games and report the tally
print(
    "=====================================",
    f"Assess the performance of the agent after training for {num_episodes:,} episodes",
    sep="\n",
)
win_rate, wins, losses = simulate_games(q_table, 100_000)
print(
    f"Win rate: {win_rate*100:.2f}% - Wins: {wins} - Losses: {losses}",
    "=====================================",
    sep="\n",
)

In [None]:
# Baseline: a purely random strategy, for comparison with the trained agent.
# The description is printed BEFORE the (long) simulation starts so the reader
# knows what is running; the printed lines and their order are unchanged.
# NOTE: this rebinds win_rate / wins / losses, replacing the agent's figures
# computed in the previous cell.
print("=====================================")
print("Assess the performance of a completely random strategy for comparison")
win_rate, wins, losses = simulate_random_games(100000)
print(
    f"Random strategy - Win rate: {win_rate*100:.2f}% - Wins: {wins} - Losses: {losses}"
)
print("=====================================")

In [16]:
import pickle

# Open the pickle file and load the data.
# This rebinds `q_table`, replacing any table produced by train_agent earlier in
# the notebook. The path is relative, so it is resolved against the kernel's
# current working directory.
# SECURITY: pickle.load can execute arbitrary code while unpickling -- only load
# files you produced yourself or otherwise fully trust.
with open("results/q_table.pkl", "rb") as f:
    q_table = pickle.load(f)

In [3]:
from multiprocessing import Pool
import numpy as np

# Q-table shared with pool workers; set once per worker process by init_worker().
# Deliberately NOT named `q_table`: rebinding that name at module level would
# discard the table that was trained/loaded earlier in the notebook.
_worker_q_table = None


def init_worker(q_table_data):
    """Pool initializer: install the Q-table in this worker process.

    Intended to be passed as ``initializer=`` to ``multiprocessing.Pool`` so the
    table is transferred once per worker instead of once per task.

    Args:
        q_table_data: the Q-table object the worker should play with.
    """
    global _worker_q_table
    _worker_q_table = q_table_data
    # Workers created by fork() inherit the parent's NumPy RNG state and would
    # otherwise all replay the same dice rolls; reseed from OS entropy.
    np.random.seed()


def simulate_single_game(q_table=None, verbose=False):
    """Play one game, always taking the best-valued action, and report the result.

    Args:
        q_table: Q-table to play with. When ``None`` (which is what ``Pool.map``
            passes as a placeholder) the table installed by ``init_worker`` is
            used instead. Previously the ``None`` placeholder shadowed the worker
            global, so pooled games always ran without a table.
        verbose: when true, print the per-game / per-move trace. Off by default
            because the trace floods the output when thousands of games run.

    Returns:
        1 if every stick was removed (win), 0 if no valid action remained (loss).

    Raises:
        ValueError: if no Q-table was passed and none was installed.
    """
    table = _worker_q_table if q_table is None else q_table
    if table is None:
        raise ValueError(
            "No Q-table available: pass one explicitly or install it with init_worker()"
        )

    if verbose:
        print("Simulating a single game...")
    state = set(range(1, 13))  # Initialize the state with all sticks
    while state:
        if verbose:
            print(f"Current state: {state}")
        # randint's upper bound is exclusive: rolls are drawn from 2..12
        dice_roll = np.random.randint(2, 13)
        action = choose_best_action(state, table, dice_roll)

        if not action:  # No valid action, game over
            return 0  # Loss

        state -= set(action)  # Update state by removing selected sticks

    return 1  # Loop ended because every stick was removed: win

def simulate_games_parallel(q_table_data, num_games, num_processes=4):
    """Simulate ``num_games`` games across a pool of worker processes.

    Args:
        q_table_data: Q-table handed to every worker through ``init_worker``.
        num_games: number of games to play.
        num_processes: size of the worker pool.

    Returns:
        ``(win_rate, wins, losses)`` where ``win_rate`` is ``wins / num_games``.
        When ``num_games`` is zero or negative no pool is started and
        ``(0.0, 0, 0)`` is returned (previously this divided by zero).
    """
    if num_games <= 0:
        return 0.0, 0, 0

    with Pool(num_processes, initializer=init_worker, initargs=(q_table_data,)) as p:
        # One placeholder argument per game; the table itself is delivered to
        # each worker once, via the initializer, rather than with every task.
        results = p.map(simulate_single_game, [None] * num_games)

    wins = sum(results)
    losses = len(results) - wins
    win_rate = wins / num_games
    return win_rate, wins, losses

# Example usage: play 1,000 games through the worker pool (num_processes is left
# at its default) and unpack the (win_rate, wins, losses) result.
# NOTE(review): worker processes must be able to import the functions they run;
# under the "spawn" start method, functions defined in a notebook cell may not be
# importable -- confirm this works on your platform or move them into a module.
win_rate, wins, losses = simulate_games_parallel(q_table, 1000)

In [17]:
import random
import numpy as np
from src.q_learning_training import valid_actions


def simulate_games(q_table, num_games):
    """Play ``num_games`` games using the best-valued action from ``q_table``.

    Each game starts with sticks 1..12. A roll is drawn, ``choose_best_action``
    picks the sticks to remove, and the game ends in a win when no sticks remain
    or in a loss when no action is returned for the roll.

    Args:
        q_table: Q-table passed through to ``choose_best_action``.
        num_games: number of games to simulate.

    Returns:
        ``(win_rate, wins, losses)``; ``win_rate`` is ``wins / num_games`` and is
        ``0.0`` when ``num_games`` is not positive (previously a
        ZeroDivisionError for ``num_games == 0``).
    """
    wins = 0
    losses = 0

    for _ in range(num_games):
        state = set(range(1, 13))  # Initialize the state with all sticks
        while state:
            # Upper bound is exclusive, so rolls are uniform over 2..12.
            # NOTE(review): the sum of two physical dice is not uniform; left
            # as-is so evaluation matches whatever the training code assumed.
            dice_roll = np.random.randint(2, 13)
            action = choose_best_action(state, q_table, dice_roll)

            if not action:  # No valid action, game over
                losses += 1
                break

            state -= set(action)  # Update state by removing selected sticks
        else:
            # Loop finished without `break`: every stick was removed
            wins += 1

    win_rate = wins / num_games if num_games > 0 else 0.0
    return win_rate, wins, losses

def choose_best_action(state, q_table, dice_roll):
    """
    Pick the valid action with the highest Q-value for this state and dice roll.

    Candidate actions come from ``valid_actions(list(state), dice_roll)``. The
    Q-table is looked up with ``frozenset(state)`` as the key; a state or action
    missing from the table counts as 0. On ties the earliest candidate wins.
    Returns ``None`` when there are no candidates.

    NOTE(review): because missing entries silently score 0, a table keyed
    differently from ``frozenset`` states would make this pick the first action
    every time -- confirm the key format against the training code.
    """
    candidates = valid_actions(list(state), dice_roll)
    if not candidates:
        return None

    # Q-values recorded for this exact set of remaining sticks (may be empty)
    action_values = q_table.get(frozenset(state), {})

    top_action, top_value = None, float('-inf')
    for candidate in candidates:
        value = action_values.get(candidate, 0)
        if value > top_value:  # strict: the first of several equal values is kept
            top_action, top_value = candidate, value

    return top_action


def simulate_random_games(num_games):
    """Play ``num_games`` games choosing a random valid action on every roll.

    Mirrors ``simulate_games`` but delegates move selection to
    ``choose_random_action``, giving a baseline to compare the agent against.

    Args:
        num_games: number of games to simulate.

    Returns:
        ``(win_rate, wins, losses)``; ``win_rate`` is ``wins / num_games`` and is
        ``0.0`` when ``num_games`` is not positive (previously a
        ZeroDivisionError for ``num_games == 0``).
    """
    wins = 0
    losses = 0

    for _ in range(num_games):
        state = set(range(1, 13))  # Initialize the state with all sticks
        while state:
            # Upper bound is exclusive, so rolls are uniform over 2..12
            dice_roll = np.random.randint(2, 13)
            action = choose_random_action(state, dice_roll)

            if not action:  # No valid action, game over
                losses += 1
                break

            state -= set(action)  # Update state by removing selected sticks
        else:
            # Loop finished without `break`: every stick was removed
            wins += 1

    win_rate = wins / num_games if num_games > 0 else 0.0
    return win_rate, wins, losses

def choose_random_action(state, dice_roll):
    """
    Return one of the valid actions for this state and dice roll, chosen with
    ``random.choice``, or ``None`` when ``valid_actions`` yields nothing.
    """
    candidates = valid_actions(list(state), dice_roll)
    if not candidates:
        return None
    return random.choice(candidates)

In [18]:
# Score the loaded Q-table over n simulated games
n = 100_000
win_rate, wins, losses = simulate_games(q_table, n)
print(f"Win rate: {win_rate*100:.4f}% - Wins: {wins} - Losses: {losses}")

Win rate: 0.0380% - Wins: 38 - Losses: 99962


In [19]:
# Score the random baseline over the same number of simulated games
n = 100_000
win_rate, wins, losses = simulate_random_games(n)
print(
    f"Assess the performance of the random strategy over {n:,} simulated games",
    f"Random strategy - Win rate: {win_rate*100:.4f}% - Wins: {wins} - Losses: {losses}",
    sep="\n",
)

Assess the performance of the random strategy over 100,000 simulated games
Random strategy - Win rate: 0.0730% - Wins: 73 - Losses: 99927


In [1]:
import csv
import torch
import easyocr
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from torchvision import transforms
import matplotlib.patches as patches
from pdf2image import convert_from_path
from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection


class PDFTableAnalyzer:
    """Find tables in a PDF, recover their row/column grid, OCR each cell, export CSV.

    Pipeline per page (see ``process_pdf``): render the page to an image, run the
    detection model, display the detections, then for every detection that meets
    ``min_confidence`` crop it, run the structure model, build one bounding box
    per (row, column) pair, OCR each box and write the text rows to a CSV file.
    """

    # Per-channel mean / std used to normalise images for both models
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, detection_model_name, structure_model_name, pdf_path, min_confidence=0.6):
        """Load both models and the OCR reader.

        Args:
            detection_model_name: name passed to ``from_pretrained`` for table detection.
            structure_model_name: name passed to ``from_pretrained`` for structure recognition.
            pdf_path: path of the PDF that ``process_pdf`` will read.
            min_confidence: minimum detection score for a table to be drawn and processed.
        """
        self.detection_model_name = detection_model_name
        self.structure_model_name = structure_model_name
        self.pdf_path = pdf_path
        self.min_confidence = min_confidence
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load models
        self.detection_model = AutoModelForObjectDetection.from_pretrained(detection_model_name, revision="no_timm").to(self.device)
        self.structure_model = TableTransformerForObjectDetection.from_pretrained(structure_model_name).to(self.device)

        # OCR reader -- created once here because it is reused for every cell
        self.reader = easyocr.Reader(['en'])

    def process_pdf(self, output_filename='output.csv'):
        """Run the full pipeline over every page of ``self.pdf_path``.

        Args:
            output_filename: CSV file that receives the extracted tables. The
                first table starts the file afresh; each further table is
                appended after a blank line. (Previously every table was written
                to the same file in 'w' mode, so only the last one survived.)
        """
        # Convert PDF to images
        pages = convert_from_path(self.pdf_path)
        tables_written = 0

        for page in pages:
            image = page.convert("RGB")
            pixel_values = self.preprocess_image(image)
            objects = self.detect_tables(pixel_values, image.size)
            self.visualize_detected_tables(image, objects, self.min_confidence)

            for obj in objects:
                if obj['score'] < self.min_confidence:
                    continue
                cropped_table = self.crop_table(image, obj)
                cells = self.recognize_structure(cropped_table)
                cell_coordinates = self.get_cell_coordinates_by_row(cells)
                data = self.apply_ocr(cell_coordinates, cropped_table)
                self.save_data_as_csv(data, output_filename, append=tables_written > 0)
                tables_written += 1

    def _to_model_input(self, image, max_size):
        """Resize, tensorise and normalise ``image``; return a 1-image batch on ``self.device``."""
        pipeline = transforms.Compose([
            MaxResize(max_size),
            transforms.ToTensor(),
            transforms.Normalize(self._NORM_MEAN, self._NORM_STD)
        ])
        return pipeline(image).unsqueeze(0).to(self.device)

    def preprocess_image(self, image):
        """Prepare a page image for the detection model (longest side resized to 800)."""
        return self._to_model_input(image, 800)

    def detect_tables(self, pixel_values, img_size):
        """Run the detection model and return its detections as dicts.

        Args:
            pixel_values: batch produced by ``preprocess_image``.
            img_size: ``(width, height)`` of the original page image, used to
                scale boxes back to page pixels.

        Returns:
            List of ``{'label', 'score', 'bbox'}`` dicts, restricted to labels
            that the detection model's ``id2label`` actually defines. Anything
            else (e.g. a placeholder label for a class index without a name) is
            dropped so it is never cropped and OCR'd as if it were a table.
        """
        with torch.no_grad():
            outputs = self.detection_model(pixel_values)
        id2label = self.detection_model.config.id2label
        known_labels = set(id2label.values())
        detections = outputs_to_objects(outputs, img_size, id2label)
        return [det for det in detections if det['label'] in known_labels]

    def visualize_detected_tables(self, img, det_tables, min_confidence):
        """Show ``img`` with a red rectangle around every detection scoring >= ``min_confidence``."""
        plt.imshow(img, interpolation="lanczos")
        ax = plt.gca()
        for det_table in det_tables:
            if det_table['score'] >= min_confidence:
                bbox = det_table['bbox']
                rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2]-bbox[0], bbox[3]-bbox[1],
                                        linewidth=2, edgecolor='red', facecolor='none')
                ax.add_patch(rect)
        plt.axis('off')
        plt.show()

    def crop_table(self, image, table_object):
        """Return the region of ``image`` covered by the detection, grown by 10 px per side."""
        bbox = table_object['bbox']
        bbox_padded = [bbox[0]-10, bbox[1]-10, bbox[2]+10, bbox[3]+10]  # Add padding
        return image.crop(bbox_padded)

    def recognize_structure(self, cropped_table):
        """Run the structure model on a cropped table image.

        Returns:
            List of ``{'label', 'score', 'bbox'}`` dicts with boxes expressed in
            the cropped image's pixel coordinates.
        """
        # Longest side resized to 1000 for the structure model
        pixel_values = self._to_model_input(cropped_table, 1000)

        with torch.no_grad():
            outputs = self.structure_model(pixel_values)

        return outputs_to_objects(outputs, cropped_table.size, self.structure_model.config.id2label)

    def get_cell_coordinates_by_row(self, cells):
        """Build one cell box per (row, column) pair from the structure output.

        Only entries labelled ``'table row'`` and ``'table column'`` are used.
        Rows are ordered top-to-bottom (by y-min) and columns left-to-right (by
        x-min); a cell box takes its x-extent from the column and its y-extent
        from the row.

        Returns:
            List with one ``{'row': row_bbox, 'cells': [{'cell': bbox}, ...]}``
            entry per row, in top-to-bottom order.
        """
        rows = sorted(
            (cell for cell in cells if cell['label'] == 'table row'),
            key=lambda item: item['bbox'][1],
        )
        columns = sorted(
            (cell for cell in cells if cell['label'] == 'table column'),
            key=lambda item: item['bbox'][0],
        )

        cell_coordinates = []
        for row in rows:
            row_y1, row_y2 = row['bbox'][1], row['bbox'][3]
            row_cells = [
                {'cell': [column['bbox'][0], row_y1, column['bbox'][2], row_y2]}
                for column in columns
            ]
            cell_coordinates.append({'row': row['bbox'], 'cells': row_cells})

        return cell_coordinates

    def apply_ocr(self, cell_coordinates, cropped_table):
        """OCR every cell box and return the text grouped by row.

        Args:
            cell_coordinates: output of ``get_cell_coordinates_by_row``.
            cropped_table: the table image the cell boxes refer to.

        Returns:
            Dict mapping row index to the list of cell texts for that row. The
            text fragments found in one cell are joined with single spaces.
        """
        data = {}
        for idx, row in enumerate(tqdm(cell_coordinates)):
            row_text = []
            for cell in row["cells"]:
                cell_image = np.array(cropped_table.crop(cell["cell"]))
                result = self.reader.readtext(cell_image)
                # Index 1 of each OCR result entry is taken as the recognised text
                row_text.append(" ".join(fragment[1] for fragment in result))
            data[idx] = row_text
        return data

    def save_data_as_csv(self, data, filename='output.csv', append=False):
        """Write the rows in ``data`` to ``filename`` as CSV.

        Args:
            data: dict whose values are lists of cell texts, one list per row.
            filename: destination file.
            append: when false (default) the file is overwritten; when true the
                rows are added to the end of the file after one blank line, so
                several tables can share a file.
        """
        with open(filename, 'a' if append else 'w', newline='', encoding='utf-8') as result_file:
            wr = csv.writer(result_file, dialect='excel')
            if append:
                wr.writerow([])  # blank line between consecutive tables
            for row_text in data.values():
                wr.writerow(row_text)

# Define MaxResize transform
class MaxResize:
    """Callable transform that rescales an image so its longer side equals ``max_size``.

    Both sides are multiplied by the same factor (then rounded to whole pixels),
    so images smaller than ``max_size`` are enlarged and larger ones are shrunk.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image.resize(...)`` with the longer side set to ``self.max_size``."""
        width, height = image.size
        factor = self.max_size / max(width, height)
        target_size = (int(round(factor * width)), int(round(factor * height)))
        return image.resize(target_size)

# Function to convert model outputs to object detections
def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw detection outputs for the first image of a batch into dicts.

    Args:
        outputs: model output exposing ``.logits`` and ``['pred_boxes']``.
        img_size: ``(width, height)`` used to scale boxes to pixel coordinates.
        id2label: mapping from class index to label name.

    Returns:
        List of ``{'label': str, 'score': float, 'bbox': [x0, y0, x1, y1]}``,
        one per prediction whose most likely class is a real label.

    Predictions are skipped when their class is ``"no object"`` OR has no entry
    in ``id2label``. DETR-style heads (presumably including the Table Transformer
    models used here -- confirm) emit one extra "no object" logit that the model
    config usually does not name; previously such predictions were labelled
    ``"unknown"`` and returned as if they were real detections.
    """
    best = outputs.logits.softmax(-1).max(-1)
    pred_labels = best.indices.detach().cpu().numpy()[0]
    pred_scores = best.values.detach().cpu().numpy()[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label.get(int(label))
        if class_label is None or class_label == "no object":
            continue  # background / unnamed class: not a detection
        bbox = rescale_bboxes(bbox, img_size).tolist()
        objects.append({'label': class_label, 'score': float(score), 'bbox': bbox})
    return objects

# Function to rescale bounding boxes to the image size
def rescale_bboxes(out_bbox, size):
    """Convert a centre/size box to corner form and scale it by the image size.

    ``size`` is ``(width, height)``; x-coordinates are multiplied by the width
    and y-coordinates by the height.
    """
    width, height = size
    corners = box_cxcywh_to_xyxy(out_bbox)
    per_coordinate_scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    return corners * per_coordinate_scale

# Convert model's bounding box format to [x_min, y_min, x_max, y_max]
def box_cxcywh_to_xyxy(x):
    """Convert one ``(x_center, y_center, width, height)`` box to corner form.

    Accepts a 1-D tensor or a 2-D tensor whose first dimension is 1; any other
    shape raises ``ValueError``. Returns a 1-D tensor
    ``[x_min, y_min, x_max, y_max]``.
    """
    if x.dim() == 2 and x.size(0) == 1:
        # Single bounding box wrapped in a batch dimension: unwrap it
        box = x.squeeze(0)
    elif x.dim() == 1:
        box = x
    else:
        raise ValueError("Unexpected bounding box shape: {}".format(x.shape))

    x_c, y_c, w, h = box
    half_w, half_h = 0.5 * w, 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=-1)



  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Build the analyzer; min_confidence is left at its default
analyzer = PDFTableAnalyzer(
    detection_model_name="microsoft/table-transformer-detection",
    structure_model_name="microsoft/table-structure-recognition-v1.1-all",
    pdf_path="7279_test_short.pdf",
)

config.json: 100%|██████████| 76.5k/76.5k [00:00<00:00, 25.5MB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
model.safetensors: 100%|██████████| 115M/115M [00:01<00:00, 66.1MB/s] 
config.json: 100%|██████████| 76.8k/76.8k [00:00<00:00, 1.18MB/s]
model.safetensors: 100%|██████████| 115M/115M [00:01<00:00, 66.6MB/s] 
Downloading detection model, please wait. This may take several minutes depending upon your network connection.


Progress: |██████████████████████████████████████████████████| 100.0% Complete

Downloading recognition model, please wait. This may take several minutes depending upon your network connection.


Progress: |██████████████████████████████████████████████████| 100.0% Complete