# Student

In [7]:
import numpy as np

In [13]:
class World:
    """Simulated curriculum of addition facts.

    Subjects are strings of the form "i+j=k" for every ordered pair
    1 <= i, j <= subj_num, so there are subj_num**2 subjects.

    Args:
        subj_num: largest addend used to generate subjects.
        seed: passed to ``np.random.seed``. NOTE: this reseeds numpy's
            *global* RNG, which other code here (e.g. ``Teacher``) also
            draws from, so constructing a World resets that shared state.
    """

    def __init__(self, subj_num, seed=None):
        np.random.seed(seed)
        self.subjs = ["%d+%d=%d" % (i + 1, j + 1, i + j + 2)
                      for i in range(subj_num) for j in range(subj_num)]

    def __repr__(self):
        return "%s(%r)" % (self.__class__, self.__dict__)

    def random_sim_subjs(self, group_num):
        """Randomly partition the subjects into groups of "similar" subjects.

        Each subject is assigned to one of ``group_num`` groups independently
        and uniformly at random (global numpy RNG), so some groups may be empty.

        Args:
            group_num: number of groups; must be >= 1.

        Returns:
            dict mapping every subject to the list of the *other* subjects
            that landed in the same group (in ``self.subjs`` order).

        Raises:
            ValueError: if ``group_num < 1``.
        """
        if group_num < 1:
            raise ValueError("group_num must be >= 1, got %r" % (group_num,))

        def split(subjs, n):
            # One independent randint draw per subject, in subject order,
            # so results are reproducible for a given seed.
            groups = [[] for _ in range(n)]
            for s in subjs:
                groups[np.random.randint(n)].append(s)
            return groups

        sims = {}
        for group in split(self.subjs, group_num):
            for s in group:
                sims[s] = [sim for sim in group if sim != s]
        return sims

    def evaluate(self, student):
        """Return the student's mean skill over all subjects.

        Subjects the student has never seen count as 0. A world with no
        subjects evaluates to 0.0 instead of dividing by zero.
        """
        if not self.subjs:
            return 0.0
        total = sum(student.skills.get(s, 0) for s in self.subjs)
        return total / len(self.subjs)

In [9]:
class Student:
    """A learner whose per-subject skill levels rise with lessons and decay.

    Args:
        sim_subjs: mapping from a subject to the list of subjects considered
            similar to it (e.g. the output of ``World.random_sim_subjs``).
            Defaults to an empty mapping (no transfer between subjects).
        learn_rate: skill gained on the taught subject per lesson.
        forget_rate: fraction of every skill lost per ``forget`` call.
        sim_rate: fraction of ``learn_rate`` gained on each similar subject.
    """

    def __init__(self, sim_subjs=None, learn_rate=.1, forget_rate=.1, sim_rate=.5):
        self.skills = {}
        # The old default was a shared mutable `[]`. It was also the wrong
        # type: `learn` calls `.get` on it, so a default Student crashed on
        # its first lesson. Use a fresh dict per instance instead.
        self.sim_subjs = {} if sim_subjs is None else sim_subjs
        self.lr = learn_rate
        self.fr = forget_rate
        self.sr = sim_rate

    def clear(self):
        """Reset every known skill to 0, keeping the set of known subjects."""
        for s in self.skills:
            self.skills[s] = 0

    def learn(self, subj):
        """Take one lesson on `subj`; similar subjects get partial credit.

        Skill values are capped at 1.
        """
        v = self.skills.get(subj, 0)
        self.skills[subj] = min(v + self.lr, 1)
        for sim in self.sim_subjs.get(subj, []):
            v = self.skills.get(sim, 0)
            self.skills[sim] = min(v + self.lr * self.sr, 1)

    def forget(self):
        """Decay every skill multiplicatively by ``forget_rate``."""
        for s, v in self.skills.items():
            # Only values are replaced (no keys added/removed), so mutating
            # during iteration is safe.
            self.skills[s] = v * (1 - self.fr)

    def __repr__(self):
        return "%s(%r)" % (self.__class__, self.__dict__)

In [10]:
class BaseTeacher:
    """Base class for teachers that give lessons from a fixed list of subjects.

    Subclasses implement ``_do_teach`` to give a single student one lesson.
    """

    def __init__(self, subjs):
        self.subjs = subjs

    def teach(self, student):
        """Give one lesson to `student`, or to each element if a list is passed."""
        if isinstance(student, list):
            for s in student:
                self.teach(s)
            # Without this return, the list itself was also passed to
            # `_do_teach`, which treats it as a single student and fails.
            return
        self._do_teach(student)

    def _do_teach(self, student):
        """Give a single student one lesson. Must be overridden."""
        raise NotImplementedError

        
class Teacher(BaseTeacher):
    """Teacher that picks each lesson's subject uniformly at random."""

    def _do_teach(self, student):
        # Draws from numpy's global RNG (the one World's constructor seeds).
        student.learn(np.random.choice(self.subjs))

In [18]:
# Demo: a fresh student before and after ten random lessons.
w = World(subj_num=3, seed=1969)
t = Teacher(subjs=w.subjs)
s = Student(sim_subjs=w.random_sim_subjs(group_num=6), learn_rate=1.)


def report(student):
    """Print the student's overall score followed by its raw skill table."""
    print(w.evaluate(student))
    print(student.skills, '\n')


report(s)
for _ in range(10):
    t.teach(s)
report(s)
    

0.0
{} 

0.6666666666666666
{'2+1=3': 1, '2+3=5': 1, '1+1=2': 1, '2+2=4': 1, '3+1=4': 1.0, '1+2=3': 0.5, '3+3=6': 0.5} 

