# In [None]:
import socket
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
import cv2
import threading
import time
import struct
import pickle
import os
from functools import reduce
import copy

class A2C_Agent:
    """Holder for the actor and critic networks of the A2C agent.

    The constructor stores the sizes and hyper-parameters and builds two
    Keras ``Sequential`` models:

    * ``main_actor``  -- ReLU hidden layer, softmax output of ``action_size`` units.
    * ``main_critic`` -- ReLU hidden layer, single linear output.

    Args:
        state_size: Length of the state vector; also the hidden-layer width.
        action_size: Number of outputs of the actor network.
    """

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.value_size = 1  # the critic emits one scalar per state

        # Hyper-parameters are stored on the instance; nothing in this class
        # reads them.
        self.discount_factor = 0.9
        self.actor_lr = 0.001
        self.critic_lr = 0.01

        self.main_actor = self.build_actor()
        self.main_critic = self.build_critic()

    def _build_network(self, output_units, output_activation):
        """Build the two-layer network shared by the actor and the critic."""
        network = tf.keras.models.Sequential()
        network.add(Dense(self.state_size, input_dim=self.state_size, activation='relu',
                          kernel_initializer='he_uniform'))
        network.add(Dense(output_units, activation=output_activation,
                          kernel_initializer='he_uniform'))
        return network

    def build_actor(self):
        """Return the policy network (softmax over ``action_size`` outputs)."""
        return self._build_network(self.action_size, 'softmax')

    def build_critic(self):
        """Return the value network (``value_size`` linear outputs)."""
        return self._build_network(self.value_size, 'linear')
    
class cache_env:
    """Cache environment: cached files, current request and request counters.

    Attributes:
        VN: Number of videos in the catalogue.
        V_list: Catalogue of video file names; requests are indices into it.
        cs: Cache capacity (stored only; not enforced by this class).
        K: Number of candidate actions evaluated per ``step``.
        s_len / l_len: Lengths of the short and long request windows.
        s_buffer / l_buffer: Requests currently inside each window, oldest first.
        s_cnt / l_cnt: Per-video request counts inside each window.
        a: Exponent of the popularity law ``P[i] ~ 1 / (i + 1) ** a``.
        P: Normalised request probabilities.
        state: File names currently listed in the ``Cache1`` directory.
        rq: Current request, a length-1 array produced by ``zipf``.
    """

    def __init__(self, VN, V_list, cs, s_len, l_len, K, a=1):
        self.VN = VN
        self.V_list = V_list
        self.cs = cs
        self.K = K
        self.s_len = s_len
        self.l_len = l_len
        self.s_buffer = []
        self.l_buffer = []
        self.s_cnt = np.zeros(VN)
        self.l_cnt = np.zeros(VN)
        self.a = a
        self.P = np.array([1/(i**self.a) for i in range(1, self.VN+1)])
        self.P /= sum(self.P)
        self.state = os.listdir('Cache1')
        self.rq = zipf(self.VN, self.P, 1)
        self.count()

    def step(self, a):
        """Apply an action, then draw and count the next request.

        Args:
            a: ``'pass'`` (do nothing), ``'append'`` (re-read ``Cache1`` into
                ``state``), or a sequence of at least ``K`` integers. An entry
                of 0 keeps the cache unchanged; an entry ``j > 0`` replaces
                the ``j``-th cached video with the current request.

        Returns:
            ``None`` for the string commands; otherwise a list of ``K``
            candidate caches, each a list of indices into ``V_list``. The
            candidates are hypothetical: ``state`` itself is not modified.

        Raises:
            ValueError: If ``a`` is a string other than the two commands.
        """
        states = None
        # Test the type first: comparing a NumPy action array with a string
        # is element-wise on recent NumPy, and using that array as an ``if``
        # condition raises ValueError.
        if isinstance(a, str):
            if a == 'append':
                self.state = os.listdir('Cache1')
            elif a != 'pass':
                raise ValueError('unknown action: {!r}'.format(a))
        else:
            cached = index(self.V_list, self.state)
            requested = self.rq[0]
            states = []
            for i in range(self.K):
                candidate = list(cached)
                if a[i] != 0:
                    del candidate[a[i] - 1]
                    candidate.append(requested)
                states.append(candidate)
        self.rq = zipf(self.VN, self.P, 1)
        self.count()
        return states

    def _record(self, counts, window, limit):
        """Add the current request to one window, evicting the oldest if full."""
        if sum(counts) == limit:
            counts[window.pop(0)] -= 1
        counts[self.rq] += 1
        window.append(self.rq[0])

    def count(self):
        """Record the current request in the short and the long window."""
        self._record(self.s_cnt, self.s_buffer, self.s_len)
        self._record(self.l_cnt, self.l_buffer, self.l_len)

def zipf(VN, P, n):
    """Sample ``n`` distinct values from ``range(VN)`` with probabilities ``P``.

    Returns a NumPy array of length ``n``; sampling is without replacement.
    """
    drawn = np.random.choice(VN, size=n, replace=False, p=P)
    return drawn

def index(array, elements):
    """Return the positions in ``array`` of each entry of ``elements``.

    Positions are listed in the order of ``elements``. An element that occurs
    several times in ``array`` contributes all of its positions; an element
    that does not occur contributes nothing.

    Args:
        array: Sequence to search (converted to a NumPy array once).
        elements: Iterable of values to look up.

    Returns:
        A list of NumPy integer positions.
    """
    haystack = np.asarray(array)  # convert once instead of once per element
    positions = []
    for element in elements:
        positions.extend(np.where(haystack == element)[0])
    return positions

def Request(r_v, c_v):
    """Open a TCP connection to the server and send a video request.

    After connecting, the function reads up to 16 bytes (printed as the
    client number), then sends the requested file name followed by the
    pickled cache listing.

    Args:
        r_v: Name of the requested video (``str``).
        c_v: Cache listing to send; any picklable object.

    Returns:
        The connected socket. The caller owns it and must close it.

    Raises:
        ConnectionError: If the server closes the connection before sending
            a client number.
        OSError: On any socket failure (the socket is closed first).
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    ip = '203.255.56.63'
    port = 4007
    print('Connecting to server...')
    try:
        sock.connect((ip, port))
        num = sock.recv(16).decode()  # receive the client number
        if num:
            print('Connection successful. You are Client[{}]'.format(num))
        else:
            # An empty read means the peer closed the connection, so every
            # later send/recv on this socket would fail anyway.
            print('Connection fail.')
            raise ConnectionError('server closed the connection during handshake')

        # sendall() retries until the whole payload is written; send() may
        # transmit only part of it.
        sock.sendall(r_v.encode())
        print('Sent r_v({}) to the server!'.format(r_v))
        # NOTE(review): the pause presumably keeps the two messages from
        # being read as one on the server side -- confirm against the server.
        time.sleep(0.001)
        sock.sendall(pickle.dumps(c_v))
        print('Sent a caching list{} to the server!'.format(c_v))
        print(' ')
    except Exception:
        sock.close()  # do not leak the descriptor when the handshake fails
        raise
    return sock
    
def indexing_mp4(xor_data, v_data, count):
    """XOR a received chunk with the matching slice of each buffer in ``v_data``.

    Args:
        xor_data: Received chunk (bytes-like).
        v_data: Sequence of bytes-like buffers to XOR against the chunk.
        count: Offset into each buffer at which its slice starts.

    Returns:
        ``xor_data`` itself when ``v_data`` is empty; otherwise the XOR result
        as ``bytes``. Where one operand is shorter than the other, the extra
        bytes of the longer operand are carried over unchanged.
    """
    if len(v_data) == 0:
        return xor_data

    span = len(xor_data)
    decoded = np.frombuffer(xor_data, dtype='uint8')
    for cached in v_data:
        piece = np.frombuffer(cached[count:count + span], dtype='uint8')
        # XOR only the overlapping prefix and keep the tail of the longer side.
        if len(decoded) > len(piece):
            overlap = len(piece)
            tail = decoded[overlap:]
        else:
            overlap = len(decoded)
            tail = piece[overlap:]
        mixed = np.bitwise_xor(decoded[:overlap], piece[:overlap])
        decoded = np.append(mixed, tail)

    return decoded.tobytes()

def Stream(cache, hit):
    """Play the currently requested video, downloading it first on a miss.

    On a hit the file ``Cache1/<name>`` is played directly. On a miss the
    video is requested through ``Request``; the server's reply selects either
    a multicast transfer (``'MC'``, chunks are XOR-decoded against locally
    cached files with ``indexing_mp4``) or a unicast transfer. In both cases
    the header chunks are written synchronously and the remaining chunks are
    appended by a background thread while playback runs.

    Args:
        cache: Object exposing ``V_list``, ``rq`` and ``state``.
        hit: Truthy when the requested video is already in ``Cache1``.

    Side effects:
        Writes ``Cache1/<name>``, opens an OpenCV window, and updates the
        module-level ``count`` (byte offset used for XOR decoding).
    """
    global count
    r_v = cache.V_list[cache.rq[0]]
    c_v = cache.state

    stream_path = 'Cache1/{}'.format(r_v)
    sock = None
    if hit:
        print('Local Play ({})'.format(r_v))
    else:
        sock = Request(r_v, c_v)
        # NOTE(review): the fixed-size recv(16) reads assume the server's
        # control messages arrive one per read -- confirm against the server.
        if sock.recv(16).decode() == 'MC':
            MCAST_GRP = sock.recv(16).decode()
            MCAST_PORT = int(sock.recv(16).decode())

            client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
            client.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            client.bind(('', MCAST_PORT))

            mreq = struct.pack("4sl", socket.inet_aton(MCAST_GRP), socket.INADDR_ANY)
            client.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)

            # SECURITY NOTE(review): unpickling data received from the network
            # can execute arbitrary code; only safe if the sender is trusted.
            xor_list = pickle.loads(client.recv(1024))
            if len(xor_list) > 1:
                print('Receive ({}) by XMC'.format(r_v))
                print('XOR list :', xor_list)
            else:
                print('Receive ({}) by MC'.format(r_v))
            xor_list.remove(r_v)

            # Load every other video named in the XOR list from the cache.
            v_data = []
            for name in xor_list:
                with open('Cache1/{}'.format(name), 'rb') as v:
                    v_data.append(v.read())

            n_h = int(client.recv(16).decode())  # number of header fragments
            n_d = int(client.recv(16).decode())  # number of data fragments

            header_parts = []
            count = 0

            for i in range(n_h):  # receive the header fragments and decode them
                xor_data = client.recv(65536)
                header_parts.append(indexing_mp4(xor_data, v_data, count))
                count += len(xor_data)

            with open(stream_path, 'wb') as out:
                out.write(b''.join(header_parts))

            def recv_xor():  # receive the data fragments and decode them
                global count
                try:
                    for i in range(n_d):
                        xor_data = client.recv(65536)
                        data = indexing_mp4(xor_data, v_data, count)
                        if data == b'\x00' * len(data):  # all-zero chunk ends the transfer
                            break
                        # Reopen per chunk so each write is flushed to disk
                        # before the player next reopens the file.
                        with open(stream_path, 'ab') as out:
                            out.write(data)
                        count += len(data)
                finally:
                    client.close()

            t = threading.Thread(target=recv_xor)
            t.start()
        else:
            print('Receive ({}) by UC'.format(r_v))
            port = int(sock.recv(16).decode())
            client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            client.bind(('', port))
            n_h = int(client.recv(16).decode())  # number of header fragments
            n_d = int(client.recv(16).decode())  # number of data fragments

            header_parts = []
            for i in range(n_h):
                header_parts.append(client.recv(65536))

            with open(stream_path, 'wb') as out:
                out.write(b''.join(header_parts))

            def recv_plain():
                try:
                    for i in range(n_d):
                        data = client.recv(65536)
                        with open(stream_path, 'ab') as out:
                            out.write(data)
                finally:
                    client.close()

            t = threading.Thread(target=recv_plain)
            t.start()

    frame_pos = 0  # next frame to show; kept across re-opens of the file
    cv2.namedWindow(r_v, cv2.WINDOW_NORMAL)

    while 1:
        video = cv2.VideoCapture(stream_path)

        fps = video.get(cv2.CAP_PROP_FPS)
        if fps == 0:  # nothing playable received yet: wait briefly and retry
            video.release()
            time.sleep(1)
            continue

        # waitKey delay derived from the frame rate; never below 1 ms because
        # waitKey(0) would block until a key is pressed.
        f = max(1, round(1/fps*1000) - 3)

        video.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)  # resume at the saved frame
        while 1:
            ret, frame = video.read()
            if ret:
                cv2.imshow(r_v, frame)
                cv2.waitKey(f)
            else:
                time.sleep(1)
                frame_pos = video.get(cv2.CAP_PROP_POS_FRAMES)  # remember the position
                break

        # Done once as many frames were shown as the file reports in total.
        finished = frame_pos >= video.get(cv2.CAP_PROP_FRAME_COUNT)
        video.release()
        if finished:
            break
        time.sleep(1)

    cv2.destroyWindow(r_v)
    if sock is not None:
        sock.close()

    print('----------------------------------------')

# In [None]:
# Driver: replay N requests against the cache, streaming each requested video
# and letting the trained actor/critic pick the file to evict on a miss.
N = 20  # number of requests to process

V_list = ['Beach{}.mp4'.format(i) for i in range(1,11)] + ['Effect{}.mp4'.format(i) for i in range(1,11)]
V_list += ['Light{}.mp4'.format(i) for i in range(1,11)] + ['Stars{}.mp4'.format(i) for i in range(1,11)]
num_of_video = len(V_list)
cache_size = 10
zipf_param = 1.2
s_len = 10    # short request-window length
l_len = 100   # long request-window length
K = 5         # candidate actions evaluated per miss
ch_p = 0.001  # per-request probability of a popularity shift
# NOTE(review): ``rho`` is used by the popularity-shift update below but is
# not defined anywhere in this file, so that branch raised NameError.
# 0.5 is a placeholder mixing weight -- confirm the intended value.
rho = 0.5

# State vector: short and long counts for the request plus each cached video.
state_size = 2 * (cache_size + 1)
action_size = cache_size + 1

with tf.Graph().as_default():
    Agent = A2C_Agent(state_size, action_size)
    Agent.main_actor.load_weights('actor_weights')
    Agent.main_critic.load_weights('critic_weights')
    cache = cache_env(num_of_video, V_list, cache_size, s_len, l_len, K, zipf_param)
    for i in range(N):
        if np.random.rand() < ch_p:
            # Blend the current popularity with a shuffled copy of the base law.
            new_P = np.array([1/(rank**cache.a) for rank in range(1, cache.VN+1)])
            new_P /= sum(new_P)
            np.random.shuffle(new_P)
            cache.P = rho * cache.P + (1-rho) * new_P

        if cache.V_list[cache.rq[0]] in cache.state:
            Stream(cache, 1)
            cache.step('pass')
        elif len(cache.state) < cache.cs:
            Stream(cache, 0)
            cache.step('append')
        else:
            Stream(cache, 0)
            rq = list(cache.rq)
            print(cache.state)
            cache_index = index(cache.V_list, cache.state)
            state = np.hstack((cache.s_cnt[rq + cache_index], cache.l_cnt[rq + cache_index]))

            pred = Agent.main_actor.predict(np.array([state]))[0]
            # Renormalise in float64: float32 softmax output can miss
            # np.random.choice's "probabilities sum to 1" tolerance.
            pred = np.asarray(pred, dtype=np.float64)
            pred /= pred.sum()
            # Pass plain ints so cache.step's string comparison on the action
            # never sees a NumPy array.
            a_list = np.random.choice(action_size, K, False, p = pred).tolist()

            states = cache.step(a_list)
            state_list = [np.hstack((cache.s_cnt[rq + states[i]], cache.l_cnt[rq + states[i]])) for i in range(K)]
            critics = Agent.main_critic.predict(np.vstack(state_list))
            idx = int(np.argmax(critics))  # first candidate with the highest value

            # Compare catalogue indices with catalogue indices. The previous
            # code subtracted an index set from a set of file names, which
            # removed nothing and so deleted an arbitrary cached file.
            evicted = set(cache_index) - set(states[idx])
            if evicted:
                remove_video = cache.V_list[evicted.pop()]
            else:
                # Action 0 keeps the cache as it was, so drop the file that
                # Stream just wrote for this request.
                # NOTE(review): assumes action 0 means "do not admit" -- confirm.
                remove_video = cache.V_list[rq[0]]
            os.remove('Cache1/{}'.format(remove_video))
            cache.state = os.listdir('Cache1')