# Using Donut Model for OCR.

In [None]:
import re
from pathlib import Path
from typing import List
from functools import partial

from transformers import (
    DonutProcessor,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)
import torch
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from datasets import Dataset
from datasets import Image as ds_img
from tqdm.auto import tqdm

In [None]:
class CFG:
    """Notebook-wide settings, read as plain class attributes."""

    # Input locations.
    model_dir = "/kaggle/input/donut-model"
    image_path = "/kaggle/input/benetech-making-graphs-accessible/test/images"

    # Donut inference: DataLoader batch size and generation length cap.
    max_length = 512
    batch_size = 4

    # When True, clean_preds prints the series before and after cleaning.
    debug_clean = False
    # Not referenced anywhere in the visible code.
    test_grayscale = True


# Special tokens searched for in the decoded Donut output.
# BOS_TOKEN is not referenced in the visible code.
BOS_TOKEN = "<|BOS|>"
X_START, X_END = "<x_start>", "<x_end>"
Y_START, Y_END = "<y_start>", "<y_end>"

# Values substituted when a prediction cannot be parsed or is empty.
PLACEHOLDER_CHART_TYPE = "line"
PLACEHOLDER_DATA_SERIES = "0;0"

In [None]:
def clean_preds(x: List[str], y: List[str]):
    """
    Clean the x and y values predicted by Donut.

    Because Donut is a generative model, it can insert any character of its
    vocabulary into the prediction string. Each series is handled on its own:

    * If at least half of its characters are digits, the series is treated as
      numeric. Every entry is converted to ``int`` (no period in the entry) or
      ``float`` (period present) after removing whitespace. When that fails,
      duplicated "-", "." and "e"/"E" characters are collapsed and the cast is
      retried; entries that still cannot be parsed become ``0``.
    * Otherwise the series is treated as categorical and every entry is only
      stripped of surrounding whitespace. An entirely empty series falls in
      this branch.

    Note that an entry without a period is cast with ``int``, so exponent
    notation such as "1e5" only parses when the entry also contains a period.

    Example:

    x = ["11", "12", "1E", "14", "15"]
    y = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]

    new_x, new_y = clean_preds(x, y)

    new_x = [11, 12, 0, 14, 15]   # "1E" cannot be parsed, so it becomes 0
    new_y = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]

    Args:
        x (List[str]): The x values predicted by Donut.
        y (List[str]): The y values predicted by Donut.

    Returns:
        x (list): ints/floats for a numeric series, stripped strings otherwise.
        y (list): ints/floats for a numeric series, stripped strings otherwise.
    """

    def repair(text):
        """Collapse duplicated '-', '.' and 'e' characters so the text may parse."""
        # Several minus signs: keep a single one in front of everything.
        if text.count("-") > 1:
            text = "-" + text.replace("-", "")

        # Several periods: keep only the first one.
        if text.count(".") > 1:
            head, *tail = text.split(".")
            text = head + "." + "".join(tail)

        # Several exponent markers: drop leading/trailing ones and keep the
        # last remaining one, e.g. "e1e-5" -> "1e-5".
        if len(re.findall(r"[eE]", text)) > 1:
            chunks = re.split(r"[eE]", text.strip("eE"))
            if len(chunks) == 1:
                text = chunks[0]
            else:
                text = "".join(chunks[:-1]) + "e" + chunks[-1]

        return text

    def clean(str_list):
        """Convert every entry of a numeric series to int/float (0 on failure)."""
        new_list = []
        for raw in str_list:
            dtype = float if "." in raw else int
            # Whitespace never belongs in a number, e.g. float("10 0") fails.
            compact = re.sub(r"\s", "", raw)
            try:
                value = dtype(compact)
            except ValueError:
                try:
                    value = dtype(repair(compact)) if compact else 0
                except ValueError:
                    value = 0
            new_list.append(value)
        return new_list

    def digit_fraction(values):
        """Share of digit characters in the concatenated series (0.0 if empty)."""
        chars = "".join(values)
        if not chars:
            return 0.0
        return len(re.sub(r"[^\d]", "", chars)) / len(chars)

    frac_num_x = digit_fraction(x)
    frac_num_y = digit_fraction(y)

    if CFG.debug_clean:
        print(f"x before clean (len={len(x)})", x)
        print(f"y before clean (len={len(y)})", y)

    x = clean(x) if frac_num_x >= 0.5 else [s.strip() for s in x]
    y = clean(y) if frac_num_y >= 0.5 else [s.strip() for s in y]

    if CFG.debug_clean:
        print(f"x after clean (len={len(x)})", x)
        print(f"y after clean (len={len(y)})", y)

    return x, y
    

def string2preds(pred_string: str):
    """
    Convert the prediction string from Donut to a chart type and x and y values.

    The chart type is taken from the first chart token found, checked in the
    order dot, horizontal_bar, vertical_bar, scatter, line. Without any chart
    token the result is ``("vertical_bar", [], [])``.

    The x and y series are the ";"-separated text between the start/end
    tokens, cleaned with ``clean_preds``. Empty lists are returned when a
    special token is missing or when either series contains no text at all.
    The two lists are NOT truncated to a common length.

    Args:
        pred_string (str): The prediction string from Donut.

    Returns:
        chart_type (str): The chart type predicted by Donut.
        x (list): The cleaned x values (numbers or strings).
        y (list): The cleaned y values (numbers or strings).
    """

    chart_tokens = (
        ("<dot>", "dot"),
        ("<horizontal_bar>", "horizontal_bar"),
        ("<vertical_bar>", "vertical_bar"),
        ("<scatter>", "scatter"),
        ("<line>", "line"),
    )
    for token, name in chart_tokens:
        if token in pred_string:
            chart_type = name
            break
    else:
        return "vertical_bar", [], []

    if not all(tok in pred_string for tok in (X_START, X_END, Y_START, Y_END)):
        return chart_type, [], []

    pred_string = pred_string.replace("<one>", "1")

    x = pred_string.split(X_START)[1].split(X_END)[0].split(";")
    y = pred_string.split(Y_START)[1].split(Y_END)[0].split(";")

    # str.split never returns an empty list ("".split(";") == [""]), so test
    # for actual content instead of the list length.
    if not any(v.strip() for v in x) or not any(v.strip() for v in y):
        return chart_type, [], []

    x, y = clean_preds(x, y)

    return chart_type, x, y

In [None]:
# Collect every JPEG in the configured test-image directory.
image_dir = Path(CFG.image_path)
images = list(image_dir.glob("*.jpg"))

# One row per image: the file path plus the file stem, which serves as the id.
# The "image_path" column is cast to the `datasets` Image feature (ds_img);
# presumably rows then yield decoded images instead of path strings, which is
# what preprocess() converts with np.array() -- TODO confirm.
ds = Dataset.from_dict(
    {"image_path": [str(x) for x in images], "id": [x.stem for x in images]}
).cast_column("image_path", ds_img())

def preprocess(examples, processor):
    """
    Turn a batch of dataset rows into the pixel tensor fed to the model.

    Args:
        examples: Batch mapping whose "image_path" entry is a sequence of
            images that ``np.array`` can convert.
        processor: Callable invoked as ``processor(arr, random_padding=False)``
            whose result exposes ``pixel_values``.

    Returns:
        dict: ``{"pixel_values": tensor}`` with the per-image outputs stacked
        along the first axis.
    """
    pixel_values = []

    for sample in examples["image_path"]:
        arr = np.array(sample)

        # Grayscale images arrive as 2-D arrays; replicate the single channel
        # so every image has three channels.
        if arr.ndim == 2:
            arr = np.stack([arr] * 3, axis=-1)

        # random_padding is a training-time augmentation; keep it off here so
        # inference is deterministic.
        pixel_values.append(processor(arr, random_padding=False).pixel_values)

    return {
        "pixel_values": torch.tensor(np.vstack(pixel_values)),
    }


# Load the fine-tuned Donut model and put it in inference mode.
model = VisionEncoderDecoderModel.from_pretrained(CFG.model_dir)
model.eval()

# Use the first GPU when present, otherwise fall back to the CPU instead of
# failing on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)
decoder_start_token_id = model.config.decoder_start_token_id
processor = DonutProcessor.from_pretrained(CFG.model_dir)

# Read the ids before installing the transform: afterwards rows are produced
# by preprocess(), which only returns "pixel_values".
ids = ds["id"]
ds.set_transform(partial(preprocess, processor=processor))

data_loader = DataLoader(
    ds, batch_size=CFG.batch_size, shuffle=False
)


# One decoded string per image, in data_loader order. A batch whose generation
# fails contributes empty strings so the list stays aligned with `ids`.
all_generations = []
for batch in tqdm(data_loader):
    pixel_values = batch["pixel_values"].to(device)

    batch_size = pixel_values.shape[0]

    # Every sequence starts from the decoder start token.
    decoder_input_ids = torch.full(
        (batch_size, 1),
        decoder_start_token_id,
        device=pixel_values.device,
    )

    try:
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=CFG.max_length,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            temperature=1,
            top_k=1,
            return_dict_in_generate=True,
        )

        all_generations.extend(processor.batch_decode(outputs.sequences))

    except Exception as err:
        # Best effort: report the failure and keep going. Unlike a bare
        # `except:`, this still lets KeyboardInterrupt/SystemExit through.
        print(f"Generation failed for a batch of {batch_size} image(s):", err)
        all_generations.extend([""]*batch_size)
        

# Parse every generated string into a chart type plus ";"-joined x / y series.
# Anything that fails to parse, or parses to an empty series, is replaced by
# the placeholder values.
chart_types, x_preds, y_preds = [], [], []
for gen in all_generations:
    try:
        chart_type, x, y = string2preds(gen)
        new_chart_type = chart_type
        x_str = ";".join(str(value) for value in x)
        y_str = ";".join(str(value) for value in y)
    except Exception as e:
        print("Failed to convert to string:", gen)
        print(e)
        new_chart_type = PLACEHOLDER_CHART_TYPE
        x_str = y_str = PLACEHOLDER_DATA_SERIES

    # An empty joined string means "no values": fall back to the placeholder.
    x_str = x_str or PLACEHOLDER_DATA_SERIES
    y_str = y_str or PLACEHOLDER_DATA_SERIES

    chart_types.append(new_chart_type)
    x_preds.append(x_str)
    y_preds.append(y_str)

In [None]:
# One row per image: predicted chart type, image id, and the x / y series
# split back into lists of strings.
df = pd.DataFrame(chart_types)
df.columns = ['chart_types']
x = [series.split(';') for series in x_preds]
y = [series.split(';') for series in y_preds]
df['image_id'] = ids
df['x_val'] = x
df['y_val'] = y

# Axis points

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.models as models
import numpy as np
import pandas as pd
import cv2
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
from torchvision import models
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence

class CNN(pl.LightningModule):
    """
    ResNet-18 backbone with two linear heads over a fixed number of point slots.

    For each of ``max_num_points`` slots the network predicts a 2-value point
    and 4 class logits.

    Args:
        max_num_points: Number of point slots predicted per image.
    """

    def __init__(self, max_num_points):
        super().__init__()
        # ResNet-18 without pretrained weights, minus its last child module.
        resnet = models.resnet18(weights=None)
        self.resnet_layers = nn.Sequential(*list(resnet.children())[:-1])
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, max_num_points * 2)  # point head
        self.fc3 = nn.Linear(128, max_num_points * 4)  # label head

    def forward(self, x):
        """
        Run the backbone and both heads.

        Returns:
            dict: ``'points'`` reshaped to ``(batch, -1, 2)`` and ``'labels'``
            reshaped to ``(batch, -1, 4)``.
        """
        x = self.resnet_layers(x)
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc1(x))
        points = self.fc2(x)
        labels = self.fc3(x)
        points = points.view(points.size(0), -1, 2)
        labels = labels.view(labels.size(0), -1, 4)
        return {'points': points, 'labels': labels}

    def _shared_step(self, batch, log_name):
        """Compute point loss + label loss for a batch and log it as ``log_name``."""
        images, targets = batch
        outputs = self(images)
        loss_points = self.compute_loss_points(outputs['points'], targets['points'], targets['mask'])
        loss_labels = self.compute_loss_labels(outputs['labels'], targets['labels'])
        loss = loss_points + loss_labels
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss, logged as ``'train_loss'``."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss, logged as ``'val_loss'``."""
        return self._shared_step(batch, 'val_loss')

    def compute_loss_points(self, outputs, targets, mask):
        """
        Masked squared error: summed over masked-in slots, divided by the mask sum.

        The divisor is clamped to at least 1 so that a batch whose mask is all
        zero yields a loss of 0 instead of NaN.
        """
        criterion = nn.MSELoss(reduction='none')
        loss = criterion(outputs, targets.float())
        masked_loss = loss * mask.unsqueeze(-1).float()
        mean_loss = torch.sum(masked_loss) / torch.sum(mask).clamp(min=1)
        return mean_loss

    def compute_loss_labels(self, outputs, targets):
        """Mean cross-entropy over every (sample, slot) pair."""
        criterion = nn.CrossEntropyLoss(reduction='none')
        batch_size, seq_length, num_classes = outputs.size()
        reshaped_outputs = outputs.view(batch_size * seq_length, num_classes)
        reshaped_targets = targets.view(batch_size * seq_length)
        loss = criterion(reshaped_outputs, reshaped_targets)
        mean_loss = torch.mean(loss)
        return mean_loss

    def configure_optimizers(self):
        """Adam over this module's own parameters (lr=0.001, weight_decay=0.001)."""
        # Must be self.parameters(): a bare `model` would refer to whatever
        # module-level object happens to carry that name.
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=0.001)
        return optimizer

In [None]:
# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Two instances of the same architecture, each with its own weights file.
# map_location makes the checkpoints loadable on whichever device was
# selected, including the CPU fallback.
model_x = CNN(max_num_points=25)
model_x.load_state_dict(torch.load('/kaggle/input/x-axis-model-10/model (1).pth', map_location=device))
model_x.eval()
model_x.to(device)
model_y = CNN(max_num_points=25)
model_y.load_state_dict(torch.load('/kaggle/input/y-axis-model-10/Y_Point_generation_weights_1.0.pth', map_location=device))
model_y.eval()
model_y.to(device)

In [None]:
import os

# Image ids are the ".jpg" file names of the test folder without the extension.
# NOTE(review): this rebinds `ids`, in os.listdir() order.
test_dataset_dir = CFG.image_path
ids = [
    filename[:-4]
    for filename in os.listdir(test_dataset_dir)
    if filename.endswith(".jpg")
]
test_df = pd.DataFrame(ids)
test_df.columns = ['image_id']

In [None]:
from PIL import Image
from torchvision import transforms

class CustDataset(Dataset):
    """
    Dataset yielding ``(image, image_id)`` pairs for the unique ids of a frame.

    Args:
        dataframe: Frame with an ``'image_id'`` column.
        image_dir: Directory containing ``<image_id>.jpg`` files.
        image_transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, dataframe, image_dir, image_transform=None):
        super().__init__()
        self.image_ids = dataframe['image_id'].unique()
        self.df = dataframe
        self.image_dir = image_dir
        self.image_transform = image_transform

    def convert_to_rgb(self, image):
        """Return ``image`` in RGB mode (the same object if it already is RGB)."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image

    def __getitem__(self, index: int):
        """Load image number ``index`` as RGB, apply the transform, return it with its id."""
        image_id = self.image_ids[index]

        # Image.open keeps the file open until the pixels are read; the
        # context manager releases the handle deterministically.
        with Image.open(f'{self.image_dir}/{image_id}.jpg') as raw:
            image = self.convert_to_rgb(raw)
            if image is raw:
                # Already RGB: take an independent copy that outlives the file.
                image = raw.copy()

        if self.image_transform is not None:
            image = self.image_transform(image)

        return image, image_id

    def __len__(self) -> int:
        """Number of unique image ids."""
        return self.image_ids.shape[0]

In [None]:
from torchvision import transforms

# Preprocessing for the axis-point models: resize to 256x256, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
image_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [None]:
# Despite its name, DIR_TRAIN is simply the configured test-image directory.
DIR_INPUT = CFG.image_path
DIR_TRAIN = f'{DIR_INPUT}'
test_dataset = CustDataset(test_df,DIR_TRAIN,image_transform=image_transform)

In [None]:
# Unshuffled loader: results come out in dataset order.
test_data_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

In [None]:
model_x.eval()
model_y.eval()

# For every image keep the predicted points whose arg-max label is 1, once
# from the x-axis model and once from the y-axis model.
results = []
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    for images, image_ids in test_data_loader:
        images = images.to(device)
        outputs_x = model_x(images)
        outputs_y = model_y(images)
        for i, image_id in enumerate(image_ids):
            points_x = outputs_x['points'][i].cpu().numpy()
            labels_x = torch.argmax(outputs_x['labels'][i], dim=1).cpu().numpy()
            points_x = points_x[labels_x == 1]
            points_y = outputs_y['points'][i].cpu().numpy()
            labels_y = torch.argmax(outputs_y['labels'][i], dim=1).cpu().numpy()
            points_y = points_y[labels_y == 1]
            result = {
                'image_id': image_id,
                'x_points': points_x,
                'y_points': points_y,
            }
            results.append(result)
        # Drop the batch tensors before the next batch is moved to the device.
        del images, outputs_x, outputs_y

In [None]:
# Axis-point predictions: columns image_id, x_points, y_points.
df2 = pd.DataFrame(results)

In [None]:
# Notebook display: preview the first rows of the axis-point predictions.
df2.head()

In [None]:
# Visual check: draw the predicted axis points of one image on top of it.
import matplotlib.pyplot as plt
i = 2
print(df['x_val'][i])
print(df['y_val'][i])
points_x = df2['x_points'][i]
points_y = df2['y_points'][i]
img_path = CFG.image_path+'/'+df2['image_id'][i]+'.jpg'
# NOTE(review): cv2.imread presumably returns BGR while imshow expects RGB, so
# colours may appear swapped -- confirm if colour matters here.
image = cv2.imread(img_path)
# Same 256x256 size as the model input, so point coordinates line up.
img = cv2.resize(image, (256, 256))
fig, ax = plt.subplots()
ax.imshow(img)
# x-axis points as red dots.
for point in points_x:
    x = point[0]
    y = point[1]
    ax.plot(x, y, 'ro')

# y-axis points, also as red dots.
for point in points_y:
    x = point[0]
    y = point[1]
    ax.plot(x, y, 'ro')

ax.axis('off')
plt.show()

In [None]:
def extend_y_axis(y_points,y_labels):
    """
    Extrapolate an axis from its first two points and first two labels.

    Starting at ``[y_points[0], y_labels[0]]``, 25 further entries are
    generated; each adds ``y_points[1] - y_points[0]`` to the previous point
    and subtracts ``y_labels[0] - y_labels[1]`` from the previous label.

    Returns:
        list: 26 ``[point, label]`` pairs.
    """
    point_step = y_points[1] - y_points[0]
    label_step = y_labels[0] - y_labels[1]

    point, label = y_points[0], y_labels[0]
    ticks = [[point, label]]
    for _ in range(25):
        # Accumulate step by step so rounding matches repeated addition.
        point = point + point_step
        label = label - label_step
        ticks.append([point, label])
    return ticks

In [None]:
def _label_to_float(label):
    """Parse a label as a finite float; anything else becomes 0.0."""
    # str.isdigit() would reject decimals ("12.5") and negatives ("-3"), so
    # parse with float() and only fall back when that really fails.
    try:
        value = float(label)
    except (TypeError, ValueError):
        return 0.0
    return value if np.isfinite(value) else 0.0


# For every image build the extrapolated y axis: pixel positions from the
# y-axis model paired with label values derived from the Donut y series.
extended_y = []
for i in range(len(df)):
    points_y = df2['y_points'][i]
    y = sorted([sub_array[1] for sub_array in points_y])
    labels = [_label_to_float(label) for label in df['y_val'][i]]
    if len(labels) < 2:
        labels = [0.0, 0.0]
    if len(y) < 2:
        # extend_y_axis needs two points. Without them there is no usable
        # axis, so use a dummy unit spacing with all-zero labels; every value
        # read off such an axis is zero.
        y = [0.0, 1.0]
        labels = [0.0, 0.0]
    extended_y.append(extend_y_axis(y,labels))

# Marker generation

In [None]:
import pandas as pd
import numpy as np
import cv2
import os
import re

from PIL import Image

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torchvision

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SequentialSampler

from matplotlib import pyplot as plt

In [None]:
import os

# Rebuild the id list from the ".jpg" file names (extension removed).
# NOTE(review): `ids` is rebound here, in os.listdir() order.
test_dataset_dir = CFG.image_path
ids = [
    filename[:-4]
    for filename in os.listdir(test_dataset_dir)
    if filename.endswith(".jpg")
]
test_df = pd.DataFrame(ids)
test_df.columns = ['image_id']

In [None]:
class CustDataset(Dataset):
    """
    Dataset yielding ``(image, image_id)`` pairs for the marker detector.

    Images are read as grayscale, replicated to three channels and scaled to
    float32 values in [0, 1] before the optional transform is applied.

    Args:
        dataframe: Frame with an ``'image_id'`` column.
        image_dir: Directory containing ``<image_id>.jpg`` files.
        transforms: Optional callable invoked as ``transforms(image=...)``
            whose result is indexed with ``'image'``.
    """

    def __init__(self, dataframe, image_dir, transforms=None):
        super().__init__()

        self.image_ids = dataframe['image_id'].unique()
        self.df = dataframe
        self.image_dir = image_dir
        self.transforms = transforms

    def convert_to_rgb(self, image):
        """Replicate a 2-D (grayscale) array into three channels; pass others through."""
        if len(image.shape) == 2:
            image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2)
        return image

    def __getitem__(self, index: int):
        """Load image number ``index`` and return it with its id.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        image_id = self.image_ids[index]
        image_path = f'{self.image_dir}/{image_id}.jpg'

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')

        image = self.convert_to_rgb(image)
        image = image.astype(np.float32) / 255.0

        if self.transforms:
            image = self.transforms(image=image)['image']

        return image, image_id

    def __len__(self) -> int:
        """Number of unique image ids."""
        return self.image_ids.shape[0]

In [None]:
def get_test_transform():
    """Build the test-time pipeline: resize to 256x256, then convert to a tensor."""
    steps = [
        A.Resize(256,256),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps)

In [None]:
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field."""
    fields = zip(*batch)
    return tuple(fields)

# Despite its name, DIR_TRAIN is simply the configured test-image directory.
DIR_INPUT = CFG.image_path
DIR_TRAIN = f'{DIR_INPUT}'
test_dataset = CustDataset(test_df, DIR_TRAIN, get_test_transform())

In [None]:
# Unshuffled loader; collate_fn keeps each batch as tuples (images, image_ids)
# instead of stacking the images into a single tensor.
test_data_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

In [None]:
# Faster R-CNN (MobileNetV3-Large FPN) built without any pretrained weights.
# This rebinds the module-level name `model`.
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)

In [None]:
# Replace the box-prediction head with one that outputs 4 classes
# (presumably background plus three object classes -- TODO confirm).
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 4)

In [None]:
# map_location makes the checkpoint loadable on whichever device was
# selected, including the CPU fallback.
model.load_state_dict(torch.load('/kaggle/input/marker-model/Marker_weights.pth', map_location=device))

In [None]:
model.to(device)
model.eval()

# Run the detector over every test image and keep its raw boxes, labels and
# scores; filtering by label and score happens later.
results = []
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    for images, image_ids in test_data_loader:
        images = list(image.to(device) for image in images)
        outputs = model(images)
        for i, image_id in enumerate(image_ids):
            boxes = outputs[i]['boxes'].cpu().numpy()
            scores = outputs[i]['scores'].cpu().numpy()
            labels = outputs[i]['labels'].cpu().numpy()
            result = {
                'image_id': image_id,
                'boxes': boxes,
                'labels': labels,
                'scores': scores
            }
            results.append(result)

        # Drop the batch tensors before the next batch is moved to the device.
        del images, outputs

In [None]:
# Detector output: columns image_id, boxes, labels, scores.
df3 = pd.DataFrame(results)

In [None]:
# Visual check: axis points plus the detected marker boxes for one image.
import matplotlib.pyplot as plt
i = 1
print(df['x_val'][i])
print(df['y_val'][i])
points_x = df2['x_points'][i]
points_y = df2['y_points'][i]
boxes = df3['boxes'][i]
# Keep only boxes with label 3 and a score of at least 0.5, as integer pixels.
boxes = boxes[np.logical_and(df3['labels'][i] == 3, df3['scores'][i] >= 0.5)].astype(np.int32)
img_path = CFG.image_path+'/'+df2['image_id'][i]+'.jpg'
# NOTE(review): cv2.imread presumably returns BGR while imshow expects RGB, so
# colours may appear swapped -- confirm if colour matters here.
image = cv2.imread(img_path)
# Same 256x256 size as the model inputs, so coordinates line up.
img = cv2.resize(image, (256, 256))
fig, ax = plt.subplots(1, 1, figsize=(16, 8))
for point in points_x:
    x = point[0]
    y = point[1]
    ax.plot(x, y, 'ro')

for point in points_y:
    x = point[0]
    y = point[1]
    ax.plot(x, y, 'ro')
# Boxes are drawn into the image array itself, before it is displayed.
for box in boxes:
    cv2.rectangle(img,
                  (box[0], box[1]),
                  (box[2], box[3]),
                  (220,0, 0), 1)
ax.axis('off')
ax.imshow(img)

In [None]:
class PredBarPlot():
    """
    Read one value per detected marker box off the y axis.

    Boxes with label 3 and a score of at least ``detection_threshold`` are
    kept, overlapping duplicates are removed, and the vertical centre of each
    remaining box is converted to an axis value by linear interpolation.

    Args:
        marker: Array of boxes, one ``[x1, y1, x2, y2]`` row per detection.
        scores: Array of detection scores, aligned with ``marker``.
        labels: Array of detection labels, aligned with ``marker``.
        y_points: List of axis tick positions (pixels); expected ascending.
        y_labels: List of axis values, aligned with ``y_points``.
    """

    def __init__(self,marker,scores,labels,y_points,y_labels):
        self.marker = marker
        self.scores = scores
        self.labels = labels
        self.y_points = y_points
        self.detection_threshold = 0.5
        self.y_labels = y_labels

    def calculate_center(self,box):
        """Return the ``(x, y)`` centre of a box."""
        center_x = (box[0] + box[2]) / 2
        center_y = (box[1] + box[3]) / 2
        return center_x, center_y

    def calculate_iou(self,box1, box2):
        """Intersection over union of two boxes, using inclusive (+1) pixel extents."""
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
        box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
        box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
        iou = intersection_area / float(box1_area + box2_area - intersection_area)
        return iou

    def remove_high_iou_boxes(self,boxes, threshold):
        """Drop every box whose IoU with an earlier box exceeds ``threshold``."""
        redundant_indices = set()
        for i in range(len(boxes)):
            for j in range(i + 1, len(boxes)):
                if self.calculate_iou(boxes[i], boxes[j]) > threshold:
                    redundant_indices.add(j)
        return [box for i, box in enumerate(boxes) if i not in redundant_indices]

    def find_element_above(self,number,lst):
        """Return the smallest element of ``lst`` greater than ``number``, or None."""
        # Iterate over a sorted copy so the caller's list is left untouched.
        for element in sorted(lst):
            if element > number:
                return element
        return None

    def _pixel_to_label(self, pixel_y):
        """Linearly interpolate the axis value at vertical position ``pixel_y``."""
        el = self.find_element_above(pixel_y, self.y_points)
        if el is None:
            # Beyond the last tick: extrapolate from the first one.
            el = self.y_points[0]
        index = self.y_points.index(el)
        span = self.y_points[index - 1] - self.y_points[index]
        if span == 0:
            # Degenerate axis (coinciding ticks): no slope can be computed.
            return float("nan")
        diff = self.y_points[index] - pixel_y
        least_count = (self.y_labels[index] - self.y_labels[index - 1]) / span
        return self.y_labels[index] + diff * least_count

    def generate_output(self):
        """Return the axis value of every kept marker, ordered by x position."""
        keep = np.logical_and(self.labels == 3, self.scores >= self.detection_threshold)
        marker = self.remove_high_iou_boxes(self.marker[keep], 0.0)
        marker_p = [list(self.calculate_center(box)) for box in marker]
        marker_p = sorted(marker_p, key=lambda p: p[0])
        return [self._pixel_to_label(center[1]) for center in marker_p]

In [None]:
class PredLinePlot():
    """
    Read one value per x-axis tick off the y axis, using the nearest marker.

    For every entry of ``x_points`` the label-3 marker whose centre is closest
    in x is selected, and the vertical centre of that marker is converted to
    an axis value by linear interpolation.

    Args:
        marker: Array of boxes, one ``[x1, y1, x2, y2]`` row per detection.
        scores: Array of detection scores, aligned with ``marker``.
        labels: Array of detection labels, aligned with ``marker``.
        x_points: List of x positions (pixels) to read values at.
        y_points: List of axis tick positions (pixels); expected ascending.
        y_labels: List of axis values, aligned with ``y_points``.
    """

    def __init__(self,marker,scores,labels,x_points,y_points,y_labels):
        self.marker = marker
        self.scores = scores
        self.labels = labels
        self.x_points = x_points
        self.y_points = y_points
        self.detection_threshold = 0.5
        self.y_labels = y_labels

    def calculate_center(self,box):
        """Return the ``(x, y)`` centre of a box."""
        center_x = (box[0] + box[2]) / 2
        center_y = (box[1] + box[3]) / 2
        return center_x, center_y

    def closest_point(self,number, point_list):
        """Return the element of ``point_list`` nearest to ``number`` (None if empty)."""
        closest = None
        min_difference = float('inf')
        for point in point_list:
            difference = abs(number - point)
            if difference < min_difference:
                min_difference = difference
                closest = point
        return closest

    def find_element_above(self,number,lst):
        """Return the smallest element of ``lst`` greater than ``number``, or None."""
        # Iterate over a sorted copy so the caller's list is left untouched.
        for element in sorted(lst):
            if element > number:
                return element
        return None

    def _pixel_to_label(self, pixel_y):
        """Linearly interpolate the axis value at vertical position ``pixel_y``."""
        el = self.find_element_above(pixel_y, self.y_points)
        if el is None:
            # Beyond the last tick: extrapolate from the first one.
            el = self.y_points[0]
        index = self.y_points.index(el)
        span = self.y_points[index - 1] - self.y_points[index]
        if span == 0:
            # Degenerate axis (coinciding ticks): no slope can be computed.
            return float("nan")
        diff = self.y_points[index] - pixel_y
        least_count = (self.y_labels[index] - self.y_labels[index - 1]) / span
        return self.y_labels[index] + diff * least_count

    def generate_output(self):
        """Return one axis value per x point, ordered by marker x position.

        Returns an empty list when there is no label-3 marker at all.
        """
        # NOTE(review): the score cut-off here is 0.0, i.e. detection_threshold
        # is not applied in this class -- confirm that is intended.
        marker = self.marker[np.logical_and(self.labels == 3, self.scores >= 0.0)]
        centers = [self.calculate_center(box) for box in marker]
        if not centers:
            return []

        marker_points_x = [center[0] for center in centers]
        marker_points_y = [center[1] for center in centers]

        data = []
        for x_point in self.x_points:
            el = self.closest_point(x_point, marker_points_x)
            index = marker_points_x.index(el)
            data.append([marker_points_x[index], marker_points_y[index]])
        data = sorted(data, key=lambda p: p[0])

        return [self._pixel_to_label(point[1]) for point in data]

In [None]:
class PredScatterPlot():
    """
    Read one value per detected marker box off the y axis.

    Boxes with label 3 and a score of at least ``detection_threshold`` are
    kept, and the vertical centre of each is converted to an axis value by
    linear interpolation.

    Args:
        marker: Array of boxes, one ``[x1, y1, x2, y2]`` row per detection.
        scores: Array of detection scores, aligned with ``marker``.
        labels: Array of detection labels, aligned with ``marker``.
        y_points: List of axis tick positions (pixels); expected ascending.
        y_labels: List of axis values, aligned with ``y_points``.
    """

    def __init__(self,marker,scores,labels,y_points,y_labels):
        self.marker = marker
        self.scores = scores
        self.labels = labels
        self.y_points = y_points
        self.detection_threshold = 0.5
        self.y_labels = y_labels

    def calculate_center(self,box):
        """Return the ``(x, y)`` centre of a box."""
        center_x = (box[0] + box[2]) / 2
        center_y = (box[1] + box[3]) / 2
        return center_x, center_y

    def find_element_above(self,number,lst):
        """Return the smallest element of ``lst`` greater than ``number``, or None."""
        # Iterate over a sorted copy so the caller's list is left untouched.
        for element in sorted(lst):
            if element > number:
                return element
        return None

    def _pixel_to_label(self, pixel_y):
        """Linearly interpolate the axis value at vertical position ``pixel_y``."""
        el = self.find_element_above(pixel_y, self.y_points)
        if el is None:
            # Beyond the last tick: extrapolate from the first one.
            el = self.y_points[0]
        index = self.y_points.index(el)
        span = self.y_points[index - 1] - self.y_points[index]
        if span == 0:
            # Degenerate axis (coinciding ticks): no slope can be computed.
            return float("nan")
        diff = self.y_points[index] - pixel_y
        least_count = (self.y_labels[index] - self.y_labels[index - 1]) / span
        return self.y_labels[index] + diff * least_count

    def generate_output(self):
        """Return the axis value of every kept marker, ordered by x position."""
        keep = np.logical_and(self.labels == 3, self.scores >= self.detection_threshold)
        marker_p = [list(self.calculate_center(box)) for box in self.marker[keep]]
        marker_p = sorted(marker_p, key=lambda p: p[0])
        return [self._pixel_to_label(center[1]) for center in marker_p]

In [None]:
class PredDotPlot():
    """
    Count the markers stacked at each x position of a dot plot.

    The value for every x point is the number of kept marker centres within
    10 pixels of it in x, multiplied by the absolute difference of the first
    two y labels and truncated to ``int``.
    """

    def __init__(self,marker,scores,labels,x_points,y_points,y_labels):
        self.detection_threshold = 0.5
        self.marker, self.scores, self.labels = marker, scores, labels
        self.x_points, self.y_points = x_points, y_points
        self.y_labels = y_labels

    def calculate_center(self,box):
        """Return the ``(x, y)`` centre of a box."""
        return (box[0] + box[2]) / 2, (box[1] + box[3]) / 2

    def count_points_with_same_x(self,target_point, points, tolerance):
        """Count the points whose x lies within ``tolerance`` of the target's x."""
        return sum(
            1 for point in points
            if abs(point[0] - target_point[0]) <= tolerance
        )

    def calculate_iou(self,box1, box2):
        """Intersection over union of two boxes, using inclusive (+1) pixel extents."""
        left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
        right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])
        intersection_area = max(0, right - left + 1) * max(0, bottom - top + 1)
        box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
        box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
        return intersection_area / float(box1_area + box2_area - intersection_area)

    def remove_high_iou_boxes(self,boxes, threshold):
        """Drop every box whose IoU with an earlier box exceeds ``threshold``."""
        redundant = set()
        count = len(boxes)
        for first in range(count):
            for second in range(first + 1, count):
                if self.calculate_iou(boxes[first], boxes[second]) > threshold:
                    redundant.add(second)
        return [box for position, box in enumerate(boxes) if position not in redundant]

    def generate_output(self):
        """Return one integer value per entry of ``x_points``."""
        keep = np.logical_and(self.labels == 3, self.scores >= self.detection_threshold)
        kept_boxes = self.remove_high_iou_boxes(self.marker[keep], 0.0)
        marker_points = [list(self.calculate_center(box)) for box in kept_boxes]

        # One marker is worth the spacing between the first two labels.
        scale = abs(self.y_labels[0] - self.y_labels[1])
        y_min = 0
        return [
            int(y_min + scale * self.count_points_with_same_x(x_point, marker_points, 10))
            for x_point in self.x_points
        ]

In [None]:
# Final predictions per image: the x series comes straight from the Donut
# output (df['x_val']); the y series is recomputed from the detected markers,
# depending on the chart type. Rows of df, df3 and extended_y are matched by
# position i.
# NOTE(review): boxes / scores / labels / y_points / y_labels / x_points are
# module-level names here; keep them stable in case other code reads them as
# globals.
chart_types, x_preds, y_preds = [], [], []
for i in range(len(df)):
    # Chart types not handled below (e.g. horizontal_bar) keep this empty list.
    result = []
    if df['chart_types'][i] == 'line':
        # The string literal below is disabled code; line charts currently
        # always get the fixed [0.0, 0.0] series.
        '''
        result = [0.0,0.0]
        boxes = df3['boxes'][i]
        scores = df3['scores'][i]
        labels = df3['labels'][i]
        y_points = [sub_array[0] for sub_array in extended_y[i]]
        y_labels = [sub_array[1] for sub_array in extended_y[i]]
        x_points = df2['x_points'][i]
        x_points = [sub_array[0] for sub_array in x_points]
        x_points = sorted(x_points)
        pred_line_plot = PredLinePlot(boxes,scores,labels,x_points,y_points,y_labels)
        result = pred_line_plot.generate_output()'''
        result = [0.0,0.0]
    
    elif df['chart_types'][i] == 'vertical_bar':
        boxes = df3['boxes'][i]
        scores = df3['scores'][i]
        labels = df3['labels'][i]
        # extended_y[i] holds [pixel position, label value] pairs.
        y_points = [sub_array[0] for sub_array in extended_y[i]]
        y_labels = [sub_array[1] for sub_array in extended_y[i]]
        pred_bar_plot = PredBarPlot(boxes,scores,labels,y_points,y_labels)
        result = pred_bar_plot.generate_output()
    
    elif df['chart_types'][i] == 'dot':
        boxes = df3['boxes'][i]
        scores = df3['scores'][i]
        labels = df3['labels'][i]
        y_points = [sub_array[0] for sub_array in extended_y[i]]
        y_labels = [sub_array[1] for sub_array in extended_y[i]]
        # Dot plots also need the predicted x-axis points (kept as [x, y] rows).
        x_points = df2['x_points'][i]
        pred_dot_plot = PredDotPlot(boxes,scores,labels,x_points,y_points,y_labels)
        result = pred_dot_plot.generate_output()
    
    elif df['chart_types'][i] == 'scatter':
        # Overwritten by generate_output() below.
        result = [0.0,0.0]
        boxes = df3['boxes'][i]
        scores = df3['scores'][i]
        labels = df3['labels'][i]
        y_points = [sub_array[0] for sub_array in extended_y[i]]
        y_labels = [sub_array[1] for sub_array in extended_y[i]]
        pred_scatter_plot = PredScatterPlot(boxes,scores,labels,y_points,y_labels)
        result = pred_scatter_plot.generate_output()
        
    chart_types.append(df['chart_types'][i])
    x_labels = df['x_val'][i]
    # The string literal below is disabled code (it would have trimmed
    # x_labels to the number of detected label-1 boxes).
    '''
    boxes = df2[df2['image_id']==df['image_id'][i]]['boxes'][i]
    scores = df2[df2['image_id']==df['image_id'][i]]['scores'][i]
    labels = df2[df2['image_id']==df['image_id'][i]]['labels'][i]
    x_boxes = boxes[np.logical_and(labels == 1,scores >= 0.5)]
    l = len(x_boxes)
    x_labels = x_labels[:l]'''
    # Both series are stored as ";"-joined strings.
    x_str = ";".join(list(map(str, x_labels)))
    y_str = ";".join(list(map(str, result)))
    x_preds.append(x_str)
    y_preds.append(y_str)

In [None]:
# Submission frame: one "<id>_x" row and one "<id>_y" row per image, all x rows
# first, with the chart type repeated for both halves.
# NOTE(review): `ids` is matched to x_preds / y_preds / chart_types purely by
# position. `ids` was last rebuilt from os.listdir(), while the predictions
# follow the row order of df -- confirm both orders agree.
sub_df = pd.DataFrame(
    data={
        "id": [f"{id_}_x" for id_ in ids] + [f"{id_}_y" for id_ in ids],
        "data_series": x_preds + y_preds,
        "chart_type": chart_types * 2,
    }
)

In [None]:
def _sanitize_series(series):
    """Replace non-finite tokens of a ";"-joined series with "0.0"."""
    return ";".join(
        "0.0" if token in ("nan", "inf", "-inf") else token
        for token in series.split(";")
    )


# Assign the whole column at once. Element-wise `sub_df[col][i] = ...` is
# chained assignment, which pandas does not guarantee to write back (and which
# never updates the frame under Copy-on-Write).
sub_df['data_series'] = sub_df['data_series'].map(_sanitize_series)

In [None]:
# Rows whose series is an empty string (no predicted values) get 0.0 instead.
sub_df['data_series'] = sub_df['data_series'].replace('', 0.0)

In [None]:
# Write the submission file without the DataFrame index column.
sub_df.to_csv('submission.csv', index=False)