Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP

We’re showing branches in this repository, but you can also compare across forks.

base fork: Hexren/TT2P
base: d8924a704d
...
head fork: Hexren/TT2P
compare: cc8f31c85f
  • 2 commits
  • 11 files changed
  • 0 commit comments
  • 1 contributor
Commits on Jun 12, 2012
Muench_s QLearning Agent implementiert. Variablentypen auf List statt ArrayLis…
…t geändert
487bdef
Commits on Jun 13, 2012
Muench_s QLearningAgent neuste Version cc8f31c
BIN  Aufgabe2/RLBlackjack/bin/praktikum/BlackJackEnvironment.class
View
Binary file not shown
BIN  Aufgabe2/RLBlackjack/bin/praktikum/Card.class
View
Binary file not shown
BIN  Aufgabe2/RLBlackjack/bin/praktikum/Experiment.class
View
Binary file not shown
BIN  Aufgabe2/RLBlackjack/bin/praktikum/SarsaAgent.class
View
Binary file not shown
BIN  Aufgabe2/RLBlackjack/bin/praktikum/StateActionPair.class
View
Binary file not shown
BIN  Aufgabe2/RLBlackjack/bin/praktikum/StopThread.class
View
Binary file not shown
BIN  Aufgabe2/RLBlackjack/bin/praktikum/Symbol.class
View
Binary file not shown
BIN  Aufgabe2/RLBlackjack/bin/praktikum/TableDescription.class
View
Binary file not shown
235 Aufgabe2/RLBlackjack/src/praktikum/QLearningAgent.java
View
@@ -2,6 +2,7 @@
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Random;
import org.rlcommunity.rlglue.codec.AgentInterface;
@@ -17,36 +18,23 @@
private boolean showOutput;
private Random randGenerator = new Random();
- private ArrayList<Integer> lastAction;
- private ArrayList<Integer> lastObservation;
+ private List<Integer> s;
+ private List<Integer> a;
- private ArrayList<Integer[]> numActions;
+ private List<Integer[]> numActions;
private double gamma = 0.0;
private double epsilon = 0.1;
private double alpha = 0.1;
- private HashMap<StateActionPair, Double> actionValueFunction;
+ private HashMap<StateActionPair, Double> qFunction;
public QLearningAgent() {
super();
showOutput = false;
// Action-Value-Function initialisieren n-dimensionales Array Q(s,a)
- actionValueFunction = new HashMap<StateActionPair, Double>();
- }
-
- @Override
- public void agent_cleanup() {
- // alles löschen
- lastAction = null;
- lastObservation = null;
- actionValueFunction = null;
- }
-
- @Override
- public void agent_end(double arg0) {
- // TODO Auto-generated method stub
+ qFunction = new HashMap<StateActionPair, Double>();
}
@Override
@@ -79,34 +67,6 @@ public void agent_init(String taskSpecification) {
}
@Override
- public String agent_message(String msg) {
- // Diese Methode reagiert auf Messages die eintreffen
-
- if(msg.equals("team name")) {
- return TEAM_NAME;
- } else if(msg.equals("team member")) {
- return TEAM_MEMBERS;
- } else if(msg.equals("training start")) {
- showOutput = false;
- return "Start training: Output " + showOutput;
- } else if(msg.equals("training end")) {
- showOutput = true;
- return "Training ended: Output " + showOutput;
- } else if(msg.equals("get stats")) {
- StringBuilder sb = new StringBuilder();
- for (StateActionPair saPair : actionValueFunction.keySet()) {
- sb.append(saPair.toString());
- sb.append(" Value: " + actionValueFunction.get(saPair));
- sb.append("\n");
- }
-
- return sb.toString();
- } else {
- return "Message not understood! You sent: " + msg;
- }
- }
-
- @Override
public Action agent_start(Observation observation) {
// Diese Methode startet den Agenten
// Zustandsinformationen dürfen nur durch Integer-Werte beschrieben werden
@@ -121,7 +81,7 @@ public Action agent_start(Observation observation) {
actualObservation.add(i, observation.getInt(i));
}
- ArrayList<Integer> actualAction = nextAction(actualObservation); //neue Aktion auf Basis der aktuellen Beobachtung wählen
+ List<Integer> actualAction = eGreedy(actualObservation); //neue Aktion auf Basis der aktuellen Beobachtung wählen
Action returnAction = new Action(actualAction.size(), 0, 0);
int[] actions = new int[actualAction.size()]; //Arraylist in Integer umwandeln
@@ -130,69 +90,184 @@ public Action agent_start(Observation observation) {
}
returnAction.intArray= actions;
- lastAction = actualAction; //aktuelle Aktion sichern
- saveObservation(actualObservation); //aktuelle Beobachtung sichern
+ a = actualAction;
+ s = actualObservation;
return returnAction;
}
@Override
- public Action agent_step(double reward, Observation observation) {
+ public Action agent_step(double r, Observation obs) {
//Diese Methode führt einen Step im Agenten durch
// Zustandsinformationen dürfen nur durch Integer-Werte beschrieben werden
// die Dimensionen sind dabei frei wählbar
- assert (observation.getNumInts()>=1);
- assert (observation.getNumDoubles()==0);
- assert (observation.getNumChars()==0);
+ assert (obs.getNumInts()>=1);
+ assert (obs.getNumDoubles()==0);
+ assert (obs.getNumChars()==0);
// Zustand für aktuelle Beobachtung erstellen
- ArrayList<Integer> actualObservation = new ArrayList<Integer>();
- for (int i = 0; i < observation.getNumInts(); i++) {
- actualObservation.add(i, observation.getInt(i));
+ List<Integer> s_new = new ArrayList<Integer>();
+ for (int i = 0; i < obs.getNumInts(); i++) {
+ s_new.add(i, obs.getInt(i));
}
- ArrayList<Integer> actualAction = nextAction(actualObservation); //neue Aktion auf Basis der aktuellen Beobachtung wählen
-
- Double Q_sa = actionValueFunction.get(new StateActionPair(lastObservation, lastAction)); // Q(s,a) bestimmen
+ List<Integer> a_new = optimalAction(s_new); //neue Aktion auf Basis der aktuellen Beobachtung wählen
+ Double Q_sa = qFunction.get(new StateActionPair(s, a)); // Q(s,a) bestimmen
Q_sa = (Q_sa == null) ? 0 : Q_sa; // Q_sa könnte null sein wenn es noch nicht berechnet wurde
- Double Q_sprime_aprime = actionValueFunction.get(new StateActionPair(actualObservation, actualAction)); // Q(s',a') bestimmen
- Q_sprime_aprime = (Q_sprime_aprime == null) ? 0 : Q_sa; // Q_sprime_aprime könnte null sein wenn es noch nicht berechnet wurde
+ Double Q_sa_next = qFunction.get(new StateActionPair(s_new, a_new));
+ Q_sa_next = (Q_sa_next == null) ? 0 : Q_sa_next;
// neuen Wert für Q(s,a) berechnen
- double new_Q_sa = Q_sa + alpha * (reward + gamma * Q_sprime_aprime - Q_sa);
+ double new_Q_sa = Q_sa + alpha * (r + gamma * Q_sa_next - Q_sa);
// mit berechneten Wert ActionValueFunction aktualisieren
- actionValueFunction.put(new StateActionPair(lastObservation, lastAction), new_Q_sa);
+ qFunction.put(new StateActionPair(s, a), new_Q_sa);
- Action returnAction = new Action(actualAction.size(), 0, 0);
- int[] actions = new int[actualAction.size()]; //Arraylist in Integer umwandeln
+ // Aktuelle Beobachtung speichern
+ s = s_new;
+
+ // Aktion für nächsten Schritt auswählen
+ List<Integer> a_next = eGreedy(s);
+
+ // Gewähte Aktion merken
+ a = a_next;
+
+ Action returnAction = new Action(a_next.size(), 0, 0);
+ int[] actions = new int[a_next.size()];
for (int i = 0; i < actions.length; i++) {
- actions[i] = actualAction.get(i);
+ actions[i] = a_next.get(i);
}
-
returnAction.intArray= actions;
+
+ return returnAction;
+ }
+
+ @Override
+ public void agent_end(double reward) {
+ // Die Episode wurde beendet, vom letzten Reward muss noch einmal gerlernt werden
+
+ Double Q_sa = qFunction.get(new StateActionPair(s, a)); // Q(s,a) bestimmen
+ Q_sa = (Q_sa == null) ? 0 : Q_sa; // Q_sa könnte null sein wenn es noch nicht berechnet wurde
+ Double new_Q_sa = Q_sa + alpha * (reward - Q_sa);
+
+ // mit berechneten Wert ActionValueFunction aktualisieren
+ qFunction.put(new StateActionPair(s, a), new_Q_sa);
- lastAction = actualAction; //aktuelle Aktion sichern
- saveObservation(actualObservation); //aktuelle Beobachtung sichern
+ s = null;
+ a = null;
+ }
+
+ @Override
+ public void agent_cleanup() {
+ // alles löschen
+ s = null;
+ a = null;
+ }
+
+ private List<Integer> optimalAction(List<Integer> actualObservation) {
+ ArrayList<Integer> chosenActions = new ArrayList<Integer>();
- return returnAction;
+ ArrayList<Integer> minValues = new ArrayList<Integer>();
+ ArrayList<Integer> maxValues = new ArrayList<Integer>();
+ for (int i = 0; i < numActions.size(); i++) {
+ minValues.add(i, numActions.get(i)[0]); //für jede Dimension kleinste Aktion auswählen
+ maxValues.add(i, numActions.get(i)[1]); //für jede Dimension größte Aktion auswählen
+ }
+
+ chosenActions = minValues;
+ ArrayList< ArrayList<Integer>> permutations = getPermutations(minValues, maxValues);
+
+ for (ArrayList<Integer> actions : permutations) {
+ Double val1 = qFunction.get(new StateActionPair(actualObservation, actions));
+ Double val2 = qFunction.get(new StateActionPair(actualObservation, chosenActions));
+ if (val1 != null && val2 != null && val1 > val2 ){
+ chosenActions = actions;
+ }
+ }
+
+ return chosenActions;
}
- private ArrayList<Integer> nextAction(ArrayList<Integer> actualObservation) {
- // TODO Auto-generated method stub
- return null;
+ private List<Integer> eGreedy(List<Integer> actualObservation){
+ //aus jeder Dimension eine zufällige Aktionen bestimmen und diese in der ArrayList speichern
+ ArrayList<Integer> chosenActions = new ArrayList<Integer>();
+
+ int minValue = 0;
+ int maxValue = 0;
+ int rndAction = 0;
+
+ // Zufallszahl generieren und prüfen ob zufällig eine andere
+ // Aktion ausgewählt wird als die optimale
+ if(randGenerator.nextDouble() <= epsilon){
+ for (int i = 0; i < numActions.size(); i++) {
+ minValue = numActions.get(i)[0]; //untere Grenze der Dimension bestimmen
+ maxValue = numActions.get(i)[1]; //obere Grenze der Dimension bestimmen
+ rndAction = minValue + (int)(Math.random() * ((maxValue - minValue) + 1)); //zufällige Aktion zwischen Min und Max der Dimension
+ chosenActions.add(i, rndAction);
+ }
+ return chosenActions;
+ } else {
+ return optimalAction(actualObservation);
+ }
+ }
+
+ /**
+ * Diese Methode dient zum Herstellen alle möglichen Permutationen von Aktionen die geprüft werden müssen, wenn es mehrdimensionale
+ * Aktionen gibt. Liegt zum Beispiel eine 2-dimensionale Aktion vor mit den Grenzen [0,1] und [1,3] müssen alle möglichen Kombinationen
+ * ([0,1][0,2][0,3][1,1][1,2][1,3]) geprüft werden.
+ * @param minValues - kleinste Werte je Dimension [0,1]
+ * @param maxValues - größte Werte je Dimension [1,3]
+ * @return Permutationen
+ */
+ private ArrayList<ArrayList<Integer>> getPermutations(ArrayList<Integer> minValues, ArrayList<Integer> maxValues){
+ ArrayList<ArrayList<Integer>> permutations = new ArrayList<ArrayList<Integer>>();
+ ArrayList<Integer> tempPerm = null;
+
+ for (int dim = 0; dim < minValues.size(); dim++) { //über alle Dimensionen laufen
+ for (int i = minValues.get(dim); i <= maxValues.get(dim); i++) { //über den Wertebereich der aktuellen Dimension laufen
+ tempPerm = new ArrayList<Integer>();
+ for (int j = 0; j < minValues.size(); j++) {
+ tempPerm.add(minValues.get(j));
+ }
+ tempPerm.set(dim,i);
+ permutations.add(tempPerm);
+ }
+ }
+
+ return permutations;
+ }
+
+ @Override
+ public String agent_message(String msg) {
+ // Diese Methode reagiert auf Messages die eintreffen
+
+ if(msg.equals("team name")) {
+ return TEAM_NAME;
+ } else if(msg.equals("team member")) {
+ return TEAM_MEMBERS;
+ } else if(msg.equals("training start")) {
+ showOutput = false;
+ return "QLearningAgent Start training: Output " + showOutput;
+ } else if(msg.equals("training end")) {
+ showOutput = true;
+ return "QLearningAgent Training ended: Output " + showOutput;
+ } else if(msg.equals("get stats")) {
+ StringBuilder sb = new StringBuilder();
+ for (StateActionPair saPair : qFunction.keySet()) {
+ sb.append(saPair.toString());
+ sb.append(" Value: " + qFunction.get(saPair));
+ sb.append("\n");
+ }
+
+ return sb.toString();
+ } else {
+ return "Message not understood! You sent: " + msg;
+ }
}
public static void main(String[] args) {
AgentLoader theLoader = new AgentLoader(new QLearningAgent());
theLoader.run();
}
-
- private void saveObservation(ArrayList<Integer> observation){
- lastObservation = new ArrayList<Integer>();
- for (Integer integer : observation) {
- lastObservation.add(integer);
- }
- }
}
4 Aufgabe2/RLBlackjack/src/praktikum/SarsaAgent.java
View
@@ -95,10 +95,10 @@ public String agent_message(String msg) {
return TEAM_MEMBERS;
} else if(msg.equals("training start")){
showOutput = false;
- return "Start training: Output " + showOutput;
+ return "SarsaAgent Start training: Output " + showOutput;
} else if(msg.equals("training end")){
showOutput = true;
- return "Training ended: Output " + showOutput;
+ return "SarsaAgent Training ended: Output " + showOutput;
} else if(msg.equals("get stats")){
StringBuilder sb = new StringBuilder();
for (StateActionPair saPair : actionValueFunction.keySet()) {
12 Aufgabe2/RLBlackjack/src/praktikum/StateActionPair.java
View
@@ -1,21 +1,21 @@
package praktikum;
-import java.util.ArrayList;
+import java.util.List;
public class StateActionPair {
- private ArrayList<Integer> state;
- private ArrayList<Integer> action;
+ private List<Integer> state;
+ private List<Integer> action;
- public StateActionPair(ArrayList<Integer> state, ArrayList<Integer> action){
+ public StateActionPair(List<Integer> state, List<Integer> action){
this.state = state;
this.action = action;
}
- public ArrayList<Integer> getState(){
+ public List<Integer> getState(){
return state;
}
- public ArrayList<Integer> getAction(){
+ public List<Integer> getAction(){
return action;
}

No commit comments for this range

Something went wrong with that request. Please try again.