In [26]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim

In [28]:
class Actor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Actor, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, output_dim)
        )

    def forward(self, x):
        return self.model(x)


class Critic(nn.Module):
    def __init__(self, input_dim):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.model(x)


In [29]:
def flatten_input(raw_input):
    x, y, features, bbox, area = raw_input
    return [x, y] + features + list(bbox) + [area]

In [30]:
def train_actor_critic(data, num_epochs=10, lr=1e-3):
    input_dim = 9   # flattened fruit location input
    output_dim = 4  # cut array prediction

    actor = Actor(input_dim, output_dim)
    critic = Critic(input_dim)
    
    actor_optimizer = optim.Adam(actor.parameters(), lr=lr)
    critic_optimizer = optim.Adam(critic.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    for epoch in range(num_epochs):
        total_loss = 0
        for raw_input, cut, reward in data:
            state = torch.tensor(flatten_input(raw_input), dtype=torch.float32)
            target_cut = torch.tensor(cut, dtype=torch.float32)
            reward_tensor = torch.tensor([reward], dtype=torch.float32)

            predicted_cut = actor(state)
            value = critic(state)

            critic_loss = mse_loss(value, reward_tensor)
            advantage = reward_tensor - value.detach()
            actor_loss = -torch.mean(advantage * torch.sum((predicted_cut - target_cut) ** 2))

            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()

            total_loss += actor_loss.item() + critic_loss.item()

        print(f"Epoch {epoch + 1}, Total Loss: {total_loss:.4f}")

    return actor


In [31]:
def train_actor_critic(data, num_epochs=10, lr=1e-3):
    input_dim = 9   # flattened fruit location input
    output_dim = 4  # cut array prediction

    actor = Actor(input_dim, output_dim)
    critic = Critic(input_dim)
    
    actor_optimizer = optim.Adam(actor.parameters(), lr=lr)
    critic_optimizer = optim.Adam(critic.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    for epoch in range(num_epochs):
        total_loss = 0
        total_actor_loss = 0
        total_critic_loss = 0
        total_advantage = 0
        total_pred_error = 0
        count = 0

        for raw_input, cut, reward in data:
            state = torch.tensor(flatten_input(raw_input), dtype=torch.float32)
            target_cut = torch.tensor(cut, dtype=torch.float32)
            reward_tensor = torch.tensor([reward], dtype=torch.float32)

            predicted_cut = actor(state)
            value = critic(state)

            critic_loss = mse_loss(value, reward_tensor)
            advantage = reward_tensor - value.detach()
            actor_loss = -torch.mean(advantage * torch.sum((predicted_cut - target_cut) ** 2))
            pred_error = torch.mean((predicted_cut - target_cut) ** 2).item()

            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()

            total_loss += actor_loss.item() + critic_loss.item()
            total_actor_loss += actor_loss.item()
            total_critic_loss += critic_loss.item()
            total_advantage += advantage.item()
            total_pred_error += pred_error
            count += 1

        print(f"Epoch {epoch + 1}")
        print(f"  Total Loss:      {total_loss:.4f}")
        print(f"  Actor Loss:      {total_actor_loss / count:.4f}")
        print(f"  Critic Loss:     {total_critic_loss / count:.4f}")
        print(f"  Advantage Mean:  {total_advantage / count:.4f}")
        print(f"  Prediction Error:{total_pred_error / count:.4f}")
        print("")

    return actor


In [32]:
# Your training data format:
# [((fruit_location_array), [cut_array], reward), ...]
# fruit_location : (x, y, [R, G, B], (width, height, angle), velocity)


data = [
    ((576, 406, [155.0, 95.11111111111111, 20.11111111111111], (170, 100, 20), 19.999999999999996), [813.5, 493.5, 638.5, 318.5], 0),
    ((1080, 725, [223.0, 119.66666666666667, 3.4444444444444446], (230, 130, 0), 19.77777777777777), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((1044, 696, [176.88888888888889, 96.22222222222223, 10.666666666666666], (170, 100, 20), 19.999999999999993), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((936, 348, [232.0, 138.11111111111111, 8.88888888888889], (230, 130, 0), 18.000000000000004), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((720, 725, [175.66666666666666, 100.66666666666667, 11.777777777777779], (170, 100, 20), 14.55555555555555), [957.5, 812.5, 782.5, 637.5], 0),
    ((756, 667, [232.33333333333334, 139.0, 9.222222222222221], (230, 130, 0), 19.555555555555564), [957.5, 812.5, 782.5, 637.5], 0),
    ((828, 580, [136.22222222222223, 95.66666666666667, 23.22222222222222], (130, 100, 15), 18.77777777777778), [957.5, 812.5, 782.5, 637.5], 0),
    ((1044, 406, [179.66666666666666, 103.88888888888889, 13.555555555555555], (170, 100, 20), 19.999999999999986), [957.5, 812.5, 782.5, 637.5], 0),
    ((900, 522, [168.44444444444446, 110.77777777777777, 18.77777777777778], (170, 100, 20), 13.555555555555536), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1152, 899, [163.44444444444446, 108.0, 18.333333333333332], (170, 100, 20), 16.22222222222221), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1260, 899, [138.77777777777777, 92.11111111111111, 14.11111111111111], (130, 100, 15), 17.555555555555546), [1137.5, 609.5, 962.5, 434.5], 0),
    ((756, 464, [229.66666666666666, 136.0, 12.777777777777779], (230, 130, 0), 18.11111111111112), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1080, 493, [162.88888888888889, 92.88888888888889, 17.88888888888889], (170, 100, 20), 16.33333333333334), [1317.5, 580.5, 1142.5, 405.5], 0),
    ((468, 464, [204.77777777777777, 79.11111111111111, 58.111111111111114], (210, 70, 60), 15.222222222222229), [705.5, 551.5, 530.5, 376.5], 0),
    ((792, 464, [158.66666666666666, 94.44444444444444, 19.444444444444443], (170, 100, 20), 17.444444444444457), [705.5, 551.5, 530.5, 376.5], 0),
    ((432, 580, [160.55555555555554, 97.66666666666667, 15.777777777777779], (170, 100, 20), 16.000000000000007), [669.5, 667.5, 494.5, 492.5], 0),
    ((612, 609, [178.11111111111111, 106.77777777777777, 23.444444444444443], (170, 100, 20), 18.33333333333333), [669.5, 667.5, 494.5, 492.5], 0),
    ((396, 435, [227.11111111111111, 130.11111111111111, 7.888888888888889], (230, 130, 0), 9.88888888888889), [633.5, 522.5, 458.5, 347.5], 3),
    ((612, 580, [164.77777777777777, 105.33333333333333, 24.0], (170, 100, 20), 14.555555555555557), [633.5, 522.5, 458.5, 347.5], 3),
    ((432, 609, [161.88888888888889, 92.66666666666667, 16.22222222222222], (170, 100, 20), 19.22222222222222), [633.5, 522.5, 458.5, 347.5], 3),
    ((648, 319, [232.22222222222223, 139.22222222222223, 9.222222222222221], (230, 130, 0), 19.66666666666668), [633.5, 522.5, 458.5, 347.5], 3),
    ((504, 899, [231.88888888888889, 138.0, 9.0], (230, 130, 0), 19.111111111111118), [741.5, 986.5, 566.5, 811.5], 3),
    ((612, 812, [233.11111111111111, 139.22222222222223, 8.222222222222221], (230, 130, 0), 19.999999999999996), [741.5, 986.5, 566.5, 811.5], 3),
    ((432, 812, [233.77777777777777, 138.66666666666666, 8.88888888888889], (230, 130, 0), 17.66666666666667), [741.5, 986.5, 566.5, 811.5], 3),
    ((720, 812, [232.33333333333334, 139.88888888888889, 9.444444444444445], (230, 130, 0), 19.666666666666668), [741.5, 986.5, 566.5, 811.5], 3),
    ((792, 812, [229.44444444444446, 136.0, 11.333333333333334], (230, 130, 0), 19.444444444444443), [741.5, 986.5, 566.5, 811.5], 3),
    ((612, 870, [232.88888888888889, 140.11111111111111, 9.333333333333334], (230, 130, 0), 17.444444444444457), [741.5, 986.5, 566.5, 811.5], 3)
]
training_data = [
    ((576, 406, [155.0, 95.11111111111111, 20.11111111111111], (170, 100, 20), 19.999999999999996), [813.5, 493.5, 638.5, 318.5], 0),
    ((1080, 725, [223.0, 119.66666666666667, 3.4444444444444446], (230, 130, 0), 19.77777777777777), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((1044, 696, [176.88888888888889, 96.22222222222223, 10.666666666666666], (170, 100, 20), 19.999999999999993), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((936, 348, [232.0, 138.11111111111111, 8.88888888888889], (230, 130, 0), 18.000000000000004), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((720, 725, [175.66666666666666, 100.66666666666667, 11.777777777777779], (170, 100, 20), 14.55555555555555), [957.5, 812.5, 782.5, 637.5], 0),
    ((756, 667, [232.33333333333334, 139.0, 9.222222222222221], (230, 130, 0), 19.555555555555564), [957.5, 812.5, 782.5, 637.5], 0),
    ((828, 580, [136.22222222222223, 95.66666666666667, 23.22222222222222], (130, 100, 15), 18.77777777777778), [957.5, 812.5, 782.5, 637.5], 0),
    ((1044, 406, [179.66666666666666, 103.88888888888889, 13.555555555555555], (170, 100, 20), 19.999999999999986), [957.5, 812.5, 782.5, 637.5], 0),
    ((900, 522, [168.44444444444446, 110.77777777777777, 18.77777777777778], (170, 100, 20), 13.555555555555536), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1152, 899, [163.44444444444446, 108.0, 18.333333333333332], (170, 100, 20), 16.22222222222221), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1260, 899, [138.77777777777777, 92.11111111111111, 14.11111111111111], (130, 100, 15), 17.555555555555546), [1137.5, 609.5, 962.5, 434.5], 0),
    ((756, 464, [229.66666666666666, 136.0, 12.777777777777779], (230, 130, 0), 18.11111111111112), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1080, 493, [162.88888888888889, 92.88888888888889, 17.88888888888889], (170, 100, 20), 16.33333333333334), [1317.5, 580.5, 1142.5, 405.5], 0),
    ((468, 464, [204.77777777777777, 79.11111111111111, 58.111111111111114], (210, 70, 60), 15.222222222222229), [705.5, 551.5, 530.5, 376.5], 0),
    ((792, 464, [158.66666666666666, 94.44444444444444, 19.444444444444443], (170, 100, 20), 17.444444444444457), [705.5, 551.5, 530.5, 376.5], 0),
    ((432, 580, [160.55555555555554, 97.66666666666667, 15.777777777777779], (170, 100, 20), 16.000000000000007), [669.5, 667.5, 494.5, 492.5], 0),
    ((612, 609, [178.11111111111111, 106.77777777777777, 23.444444444444443], (170, 100, 20), 18.33333333333333), [669.5, 667.5, 494.5, 492.5], 0),
    ((396, 435, [227.11111111111111, 130.11111111111111, 7.888888888888889], (230, 130, 0), 9.88888888888889), [633.5, 522.5, 458.5, 347.5], 3),
    ((612, 580, [164.77777777777777, 105.33333333333333, 24.0], (170, 100, 20), 14.555555555555557), [633.5, 522.5, 458.5, 347.5], 3),
    ((432, 609, [161.88888888888889, 92.66666666666667, 16.22222222222222], (170, 100, 20), 19.22222222222222), [633.5, 522.5, 458.5, 347.5], 3),
    ((648, 319, [232.22222222222223, 139.22222222222223, 9.222222222222221], (230, 130, 0), 19.66666666666668), [633.5, 522.5, 458.5, 347.5], 3),
    ((504, 899, [231.88888888888889, 138.0, 9.0], (230, 130, 0), 19.111111111111118), [741.5, 986.5, 566.5, 811.5], 3),
    ((612, 812, [233.11111111111111, 139.22222222222223, 8.222222222222221], (230, 130, 0), 19.999999999999996), [741.5, 986.5, 566.5, 811.5], 3),
    ((432, 812, [233.77777777777777, 138.66666666666666, 8.88888888888889], (230, 130, 0), 17.66666666666667), [741.5, 986.5, 566.5, 811.5], 3),
    ((720, 812, [232.33333333333334, 139.88888888888889, 9.444444444444445], (230, 130, 0), 19.666666666666668), [741.5, 986.5, 566.5, 811.5], 3),
    ((792, 812, [229.44444444444446, 136.0, 11.333333333333334], (230, 130, 0), 19.444444444444443), [741.5, 986.5, 566.5, 811.5], 3),
    ((612, 870, [232.88888888888889, 140.11111111111111, 9.333333333333334], (230, 130, 0), 17.444444444444457), [741.5, 986.5, 566.5, 811.5], 3)
]


In [33]:
data = [
    ((576, 406, [155.0, 95.11111111111111, 20.11111111111111], (170, 100, 20), 19.999999999999996), [813.5, 493.5, 638.5, 318.5], 0),
    ((1080, 725, [223.0, 119.66666666666667, 3.4444444444444446], (230, 130, 0), 19.77777777777777), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((1044, 696, [176.88888888888889, 96.22222222222223, 10.666666666666666], (170, 100, 20), 19.999999999999993), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((936, 348, [232.0, 138.11111111111111, 8.88888888888889], (230, 130, 0), 18.000000000000004), [1317.5, 812.5, 1142.5, 637.5], 0),
    ((720, 725, [175.66666666666666, 100.66666666666667, 11.777777777777779], (170, 100, 20), 14.55555555555555), [957.5, 812.5, 782.5, 637.5], 0),
    ((756, 667, [232.33333333333334, 139.0, 9.222222222222221], (230, 130, 0), 19.555555555555564), [957.5, 812.5, 782.5, 637.5], 0),
    ((828, 580, [136.22222222222223, 95.66666666666667, 23.22222222222222], (130, 100, 15), 18.77777777777778), [957.5, 812.5, 782.5, 637.5], 0),
    ((1044, 406, [179.66666666666666, 103.88888888888889, 13.555555555555555], (170, 100, 20), 19.999999999999986), [957.5, 812.5, 782.5, 637.5], 0),
    ((900, 522, [168.44444444444446, 110.77777777777777, 18.77777777777778], (170, 100, 20), 13.555555555555536), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1152, 899, [163.44444444444446, 108.0, 18.333333333333332], (170, 100, 20), 16.22222222222221), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1260, 899, [138.77777777777777, 92.11111111111111, 14.11111111111111], (130, 100, 15), 17.555555555555546), [1137.5, 609.5, 962.5, 434.5], 0),
    ((756, 464, [229.66666666666666, 136.0, 12.777777777777779], (230, 130, 0), 18.11111111111112), [1137.5, 609.5, 962.5, 434.5], 0),
    ((1080, 493, [162.88888888888889, 92.88888888888889, 17.88888888888889], (170, 100, 20), 16.33333333333334), [1317.5, 580.5, 1142.5, 405.5], 0),
    ((468, 464, [204.77777777777777, 79.11111111111111, 58.111111111111114], (210, 70, 60), 15.222222222222229), [705.5, 551.5, 530.5, 376.5], 0),
    ((792, 464, [158.66666666666666, 94.44444444444444, 19.444444444444443], (170, 100, 20), 17.444444444444457), [705.5, 551.5, 530.5, 376.5], 0),
    ((432, 580, [160.55555555555554, 97.66666666666667, 15.777777777777779], (170, 100, 20), 16.000000000000007), [669.5, 667.5, 494.5, 492.5], 0),
    ((612, 609, [178.11111111111111, 106.77777777777777, 23.444444444444443], (170, 100, 20), 18.33333333333333), [669.5, 667.5, 494.5, 492.5], 0),
    ((396, 435, [227.11111111111111, 130.11111111111111, 7.888888888888889], (230, 130, 0), 9.88888888888889), [633.5, 522.5, 458.5, 347.5], 3),
    ((612, 580, [164.77777777777777, 105.33333333333333, 24.0], (170, 100, 20), 14.555555555555557), [633.5, 522.5, 458.5, 347.5], 3),
    ((432, 609, [161.88888888888889, 92.66666666666667, 16.22222222222222], (170, 100, 20), 19.22222222222222), [633.5, 522.5, 458.5, 347.5], 3),
    ((648, 319, [232.22222222222223, 139.22222222222223, 9.222222222222221], (230, 130, 0), 19.66666666666668), [633.5, 522.5, 458.5, 347.5], 3),
    ((504, 899, [231.88888888888889, 138.0, 9.0], (230, 130, 0), 19.111111111111118), [741.5, 986.5, 566.5, 811.5], 3),
    ((612, 812, [233.11111111111111, 139.22222222222223, 8.222222222222221], (230, 130, 0), 19.999999999999996), [741.5, 986.5, 566.5, 811.5], 3),
    ((432, 812, [233.77777777777777, 138.66666666666666, 8.88888888888889], (230, 130, 0), 17.66666666666667), [741.5, 986.5, 566.5, 811.5], 3),
    ((720, 812, [232.33333333333334, 139.88888888888889, 9.444444444444445], (230, 130, 0), 19.666666666666668), [741.5, 986.5, 566.5, 811.5], 3),
    ((792, 812, [229.44444444444446, 136.0, 11.333333333333334], (230, 130, 0), 19.444444444444443), [741.5, 986.5, 566.5, 811.5], 3),
    ((612, 870, [232.88888888888889, 140.11111111111111, 9.333333333333334], (230, 130, 0), 17.444444444444457), [741.5, 986.5, 566.5, 811.5], 3),
    ((756, 609, [176.22222222222223, 97.77777777777777, 9.666666666666666], (170, 100, 20), 18.777777777777793), [843.5, 696.5, 668.5, 521.5], 0),
    ((612, 667, [162.44444444444446, 106.55555555555556, 16.333333333333332], (170, 100, 20), 17.777777777777768), [699.5, 754.5, 524.5, 579.5], 4),
    ((828, 957, [160.11111111111111, 104.0, 23.22222222222222], (170, 100, 20), 17.111111111111107), [915.5, 1044.5, 740.5, 869.5], 4),
    ((1440, 580, [232.33333333333334, 138.77777777777777, 9.333333333333334], (230, 130, 0), 19.44444444444445), [1527.5, 667.5, 1352.5, 492.5], 0),
    ((1548, 580, [249.33333333333334, 200.66666666666666, 46.888888888888886], (255, 215, 50), 19.111111111111114), [1635.5, 667.5, 1460.5, 492.5], 0),
    ((1332, 725, [179.33333333333334, 99.11111111111111, 13.333333333333334], (170, 100, 20), 16.888888888888893), [1419.5, 812.5, 1244.5, 637.5], 0),
    ((1368, 696, [170.66666666666666, 100.66666666666667, 11.666666666666666], (170, 100, 20), 9.666666666666663), [1455.5, 783.5, 1280.5, 608.5], 0),
    ((1368, 696, [166.55555555555554, 113.44444444444444, 17.22222222222222], (170, 100, 20), 19.66666666666668), [1455.5, 783.5, 1280.5, 608.5], 0),
    ((504, 435, [216.44444444444446, 92.11111111111111, 21.666666666666668], (225, 100, 20), 17.111111111111097), [591.5, 522.5, 416.5, 347.5], 0),
    ((792, 1015, [232.11111111111111, 139.11111111111111, 9.222222222222221], (230, 130, 0), 19.44444444444445), [879.5, 1102.5, 704.5, 927.5], 4),
    ((1008, 609, [232.0, 138.44444444444446, 9.11111111111111], (230, 130, 0), 18.555555555555568), [1095.5, 696.5, 920.5, 521.5], 4),
    ((1152, 348, [164.55555555555554, 101.33333333333333, 19.11111111111111], (170, 100, 20), 7.666666666666675), [1239.5, 435.5, 1064.5, 260.5], 4),
    ((1152, 348, [227.77777777777777, 131.0, 10.777777777777779], (230, 130, 0), 13.000000000000007), [1239.5, 435.5, 1064.5, 260.5], 4),
    ((1368, 377, [173.77777777777777, 105.77777777777777, 17.333333333333332], (170, 100, 20), 12.22222222222221), [1455.5, 464.5, 1280.5, 289.5], 4),
    ((1368, 435, [132.88888888888889, 89.77777777777777, 21.333333333333332], (130, 100, 15), 19.444444444444446), [1455.5, 464.5, 1280.5, 289.5], 4),
    ((1512, 928, [176.44444444444446, 104.44444444444444, 11.777777777777779], (170, 100, 20), 19.11111111111112), [1599.5, 1015.5, 1424.5, 840.5], 4),
    ((1152, 899, [230.77777777777777, 136.22222222222223, 8.333333333333334], (230, 130, 0), 14.333333333333334), [1239.5, 986.5, 1064.5, 811.5], 4),
    ((1188, 870, [174.11111111111111, 97.11111111111111, 9.88888888888889], (170, 100, 20), 17.11111111111111), [1239.5, 986.5, 1064.5, 811.5], 4),
    ((1476, 696, [231.0, 136.88888888888889, 8.333333333333334], (230, 130, 0), 15.222222222222221), [1563.5, 783.5, 1388.5, 608.5], 4),
    ((1548, 580, [172.55555555555554, 99.22222222222223, 11.0], (170, 100, 20), 12.333333333333314), [1635.5, 667.5, 1460.5, 492.5], 4),
    ((576, 899, [177.33333333333334, 101.77777777777777, 10.88888888888889], (170, 100, 20), 18.222222222222225), [663.5, 986.5, 488.5, 811.5], 0),
    ((504, 928, [131.0, 93.88888888888889, 26.666666666666668], (130, 100, 15), 18.777777777777782), [663.5, 986.5, 488.5, 811.5], 0),
    ((1440, 725, [129.11111111111111, 89.55555555555556, 14.11111111111111], (130, 100, 15), 12.222222222222218), [1527.5, 812.5, 1352.5, 637.5], 0),
    ((504, 667, [210.0, 69.0, 41.55555555555556], (210, 70, 60), 18.444444444444443), [591.5, 754.5, 416.5, 579.5], 0),
    ((540, 870, [209.22222222222223, 67.22222222222223, 42.888888888888886], (210, 70, 60), 19.666666666666657), [591.5, 754.5, 416.5, 579.5], 0),
    ((1188, 841, [156.55555555555554, 16.555555555555557, 8.777777777777779], (160, 10, 0), 16.777777777777793), [1275.5, 928.5, 1100.5, 753.5], 0),
    ((1404, 696, [168.0, 108.0, 17.444444444444443], (170, 100, 20), 12.555555555555557), [1491.5, 783.5, 1316.5, 608.5], 0),
    ((432, 841, [225.11111111111111, 123.66666666666667, 4.333333333333333], (230, 130, 0), 14.555555555555546), [1491.5, 783.5, 1316.5, 608.5], 0),
    ((360, 870, [175.0, 112.22222222222223, 18.444444444444443], (170, 100, 20), 18.777777777777786), [1491.5, 783.5, 1316.5, 608.5], 0),
    ((504, 957, [164.55555555555554, 89.55555555555556, 16.666666666666668], (170, 100, 20), 19.222222222222232), [1491.5, 783.5, 1316.5, 608.5], 0),
    ((360, 377, [162.44444444444446, 95.66666666666667, 15.444444444444445], (170, 100, 20), 16.44444444444443), [447.5, 464.5, 272.5, 289.5], 0),
    ((1512, 406, [231.88888888888889, 138.0, 8.88888888888889], (230, 130, 0), 17.777777777777775), [447.5, 464.5, 272.5, 289.5], 0),
    ((396, 464, [223.44444444444446, 120.33333333333333, 3.7777777777777777], (230, 130, 0), 18.999999999999993), [447.5, 464.5, 272.5, 289.5], 0),
    ((1476, 464, [174.0, 108.55555555555556, 27.11111111111111], (170, 100, 20), 19.666666666666668), [447.5, 464.5, 272.5, 289.5], 0),
    ((1548, 319, [175.77777777777777, 102.55555555555556, 12.555555555555555], (170, 100, 20), 15.777777777777773), [1635.5, 406.5, 1460.5, 231.5], 0),
    ((360, 377, [178.44444444444446, 103.0, 12.555555555555555], (170, 100, 20), 18.8888888888889), [1635.5, 406.5, 1460.5, 231.5], 0),
    ((1548, 232, [231.33333333333334, 137.22222222222223, 8.555555555555555], (230, 130, 0), 16.11111111111113), [1635.5, 406.5, 1460.5, 231.5], 0),
    ((468, 522, [175.44444444444446, 99.77777777777777, 15.333333333333334], (170, 100, 20), 10.333333333333352), [555.5, 609.5, 380.5, 434.5], 0),
    ((360, 232, [128.55555555555554, 88.88888888888889, 22.11111111111111], (130, 100, 15), 19.666666666666682), [555.5, 609.5, 380.5, 434.5], 0),
    ((432, 754, [169.44444444444446, 115.22222222222223, 24.0], (170, 100, 20), 19.77777777777777), [519.5, 841.5, 344.5, 666.5], 0),
    ((1548, 609, [232.11111111111111, 138.55555555555554, 9.0], (230, 130, 0), 18.666666666666657), [519.5, 841.5, 344.5, 666.5], 0),
    ((1440, 957, [175.55555555555554, 100.44444444444444, 10.88888888888889], (170, 100, 20), 15.111111111111097), [1527.5, 1044.5, 1352.5, 869.5], 0),
    ((1476, 841, [172.33333333333334, 99.44444444444444, 8.11111111111111], (170, 100, 20), 14.77777777777779), [1563.5, 928.5, 1388.5, 753.5], 0),
    ((1296, 870, [179.44444444444446, 101.77777777777777, 15.777777777777779], (170, 100, 20), 15.44444444444445), [1563.5, 928.5, 1388.5, 753.5], 0),
    ((1440, 725, [158.0, 101.55555555555556, 25.0], (170, 100, 20), 18.555555555555557), [1527.5, 812.5, 1352.5, 637.5], 0),
    ((1188, 957, [144.0, 104.44444444444444, 14.11111111111111], (130, 100, 15), 19.333333333333332), [1527.5, 812.5, 1352.5, 637.5], 0),
    ((648, 667, [231.44444444444446, 137.88888888888889, 8.555555555555555], (230, 130, 0), 16.8888888888889), [735.5, 754.5, 560.5, 579.5], 0),
    ((504, 261, [224.77777777777777, 123.55555555555556, 4.444444444444445], (230, 130, 0), 15.111111111111114), [735.5, 754.5, 560.5, 579.5], 0),
    ((432, 261, [161.55555555555554, 97.0, 15.444444444444445], (170, 100, 20), 16.000000000000014), [735.5, 754.5, 560.5, 579.5], 0),
    ((576, 203, [223.44444444444446, 120.88888888888889, 3.6666666666666665], (230, 130, 0), 18.333333333333325), [663.5, 290.5, 488.5, 115.5], 0),
    ((612, 261, [175.33333333333334, 94.33333333333333, 11.222222222222221], (170, 100, 20), 19.777777777777793), [663.5, 290.5, 488.5, 115.5], 0),
    ((432, 203, [130.88888888888889, 91.0, 24.88888888888889], (130, 100, 15), 19.777777777777775), [663.5, 290.5, 488.5, 115.5], 0),
    ((648, 290, [228.55555555555554, 135.33333333333334, 13.222222222222221], (230, 130, 0), 19.00000000000002), [735.5, 377.5, 560.5, 202.5], 0),
    ((1116, 841, [127.0, 87.33333333333333, 17.444444444444443], (130, 100, 15), 18.111111111111114),[1203.5, 928.5, 1028.5, 753.5],0)
]


In [34]:
actor = train_actor_critic(data,10)

Epoch 1
  Total Loss:      -164333219.7214
  Actor Loss:      -2028822.7585
  Critic Loss:     17.5768
  Advantage Mean:  0.7044
  Prediction Error:776550.5660

Epoch 2
  Total Loss:      55149238.7145
  Actor Loss:      680851.5676
  Critic Loss:     3.2314
  Advantage Mean:  -0.0416
  Prediction Error:768080.0066

Epoch 3
  Total Loss:      33346460.4553
  Actor Loss:      411681.3372
  Critic Loss:     3.3598
  Advantage Mean:  0.0484
  Prediction Error:758052.4786

Epoch 4
  Total Loss:      -59551834.3355
  Actor Loss:      -735211.0114
  Critic Loss:     3.1801
  Advantage Mean:  0.3470
  Prediction Error:742329.6464

Epoch 5
  Total Loss:      45524765.5393
  Actor Loss:      562030.9569
  Critic Loss:     3.1855
  Advantage Mean:  -0.0329
  Prediction Error:743393.2228

Epoch 6
  Total Loss:      -60825286.8100
  Actor Loss:      -750932.7919
  Critic Loss:     3.3251
  Advantage Mean:  0.2840
  Prediction Error:743796.2070

Epoch 7
  Total Loss:      -50523736.4890
  Actor Los

In [35]:
actor = train_actor_critic(data,10)

Epoch 1
  Total Loss:      -23709749.9800
  Actor Loss:      -292726.3324
  Critic Loss:     13.3697
  Advantage Mean:  0.0309
  Prediction Error:740989.1867

Epoch 2
  Total Loss:      7790307.6086
  Actor Loss:      96173.1996
  Critic Loss:     3.4376
  Advantage Mean:  -0.0388
  Prediction Error:740705.7064

Epoch 3
  Total Loss:      8233120.8434
  Actor Loss:      101640.0599
  Critic Loss:     3.4073
  Advantage Mean:  0.0248
  Prediction Error:740581.3281

Epoch 4
  Total Loss:      7606520.3533
  Actor Loss:      93902.9259
  Critic Loss:     4.7328
  Advantage Mean:  0.0697
  Prediction Error:742397.5050

Epoch 5
  Total Loss:      -8170745.8497
  Actor Loss:      -100879.1327
  Critic Loss:     5.7272
  Advantage Mean:  0.1174
  Prediction Error:743508.0042

Epoch 6
  Total Loss:      23993440.7286
  Actor Loss:      296213.0505
  Critic Loss:     2.2671
  Advantage Mean:  -0.0137
  Prediction Error:742428.5596

Epoch 7
  Total Loss:      12532586.3463
  Actor Loss:      154

In [36]:
def predict_cut(actor, fruit_location):
    flat_input = flatten_input(fruit_location)
    input_tensor = torch.tensor(flat_input, dtype=torch.float32)
    return actor(input_tensor).detach().numpy()

In [37]:
def flatten_input(fruit_location):
    flat = []
    for item in fruit_location:
        if isinstance(item, (list, tuple)):
            flat.extend(item)
        else:
            flat.append(item)
    return flat


In [38]:
def predict_cut(actor, fruit_location):
    flat = flatten_input(fruit_location)
    x = torch.tensor(flat, dtype=torch.float32)
    return actor(x).detach().numpy()


In [39]:
fruit_location = (576, 406, [155.0, 95.11111111111111, 20.11111111111111], (170, 100, 20), 19.999999999999996)

In [40]:
predicted_cut = predict_cut(actor, fruit_location)
print(predicted_cut)

[34.438046 28.658175 19.68249  41.270603]


In [41]:
torch.save(actor.state_dict(), "actor_model.pth")